llama_cpp 0.12.7 → 0.14.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/ext/llama_cpp/llama_cpp.cpp +131 -288
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +29 -29
- data/vendor/tmp/llama.cpp/Makefile +10 -6
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +32 -23
- data/vendor/tmp/llama.cpp/ggml-backend.h +17 -16
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +949 -168
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +159 -22
- data/vendor/tmp/llama.cpp/ggml-metal.metal +1195 -139
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +27 -27
- data/vendor/tmp/llama.cpp/ggml-quants.c +1971 -271
- data/vendor/tmp/llama.cpp/ggml-quants.h +52 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +3586 -1201
- data/vendor/tmp/llama.cpp/ggml-sycl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +39336 -43461
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +1391 -825
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +1 -0
- data/vendor/tmp/llama.cpp/ggml.c +545 -210
- data/vendor/tmp/llama.cpp/ggml.h +65 -23
- data/vendor/tmp/llama.cpp/llama.cpp +1458 -763
- data/vendor/tmp/llama.cpp/llama.h +81 -75
- data/vendor/tmp/llama.cpp/unicode.h +310 -1
- metadata +2 -2
@@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
|
|
1959
1959
|
}
|
1960
1960
|
}
|
1961
1961
|
|
1962
|
+
kernel void kernel_arange_f32(
|
1963
|
+
device char * dst,
|
1964
|
+
constant int64_t & ne0,
|
1965
|
+
constant float & start,
|
1966
|
+
constant float & step,
|
1967
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1968
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1969
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1970
|
+
|
1971
|
+
device float * dst_ptr = (device float *) dst;
|
1972
|
+
|
1973
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
1974
|
+
dst_ptr[i0] = start + step * i0;
|
1975
|
+
}
|
1976
|
+
}
|
1977
|
+
|
1978
|
+
kernel void kernel_timestep_embedding_f32(
|
1979
|
+
device const char * src0,
|
1980
|
+
device char * dst,
|
1981
|
+
constant uint64_t & nb1,
|
1982
|
+
constant int & dim,
|
1983
|
+
constant int & max_period,
|
1984
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
1985
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
1986
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
1987
|
+
|
1988
|
+
int i = tgpig.x;
|
1989
|
+
device float * embed_data = (device float *)(dst + i*nb1);
|
1990
|
+
|
1991
|
+
int half_ = dim / 2;
|
1992
|
+
for (int j = tpitg.x; j < half_; j += ntg.x) {
|
1993
|
+
float timestep = ((device float *)src0)[i];
|
1994
|
+
float freq = (float)exp(-log((float)max_period) * j / half_);
|
1995
|
+
float arg = timestep * freq;
|
1996
|
+
embed_data[j ] = cos(arg);
|
1997
|
+
embed_data[j + half_] = sin(arg);
|
1998
|
+
}
|
1999
|
+
|
2000
|
+
if (dim % 2 != 0 && tpitg.x == 0) {
|
2001
|
+
embed_data[dim] = 0.f;
|
2002
|
+
}
|
2003
|
+
}
|
2004
|
+
|
1962
2005
|
// bitonic sort implementation following the CUDA kernels as reference
|
1963
2006
|
typedef void (argsort_t)(
|
1964
2007
|
device const float * x,
|
@@ -2519,12 +2562,34 @@ typedef struct {
|
|
2519
2562
|
} block_iq2_xs;
|
2520
2563
|
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
2521
2564
|
|
2565
|
+
// 2.5625 bpw quants
|
2566
|
+
typedef struct {
|
2567
|
+
half d;
|
2568
|
+
uint8_t qs[QK_K/4];
|
2569
|
+
uint8_t qh[QK_K/32];
|
2570
|
+
uint8_t scales[QK_K/32];
|
2571
|
+
} block_iq2_s;
|
2572
|
+
|
2522
2573
|
typedef struct {
|
2523
2574
|
half d;
|
2524
2575
|
uint8_t qs[3*QK_K/8];
|
2525
2576
|
} block_iq3_xxs;
|
2526
2577
|
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
2527
2578
|
|
2579
|
+
// 3.4375 bpw
|
2580
|
+
#if QK_K == 64
|
2581
|
+
#define IQ3S_N_SCALE 2
|
2582
|
+
#else
|
2583
|
+
#define IQ3S_N_SCALE QK_K/64
|
2584
|
+
#endif
|
2585
|
+
typedef struct {
|
2586
|
+
half d;
|
2587
|
+
uint8_t qs[QK_K/4];
|
2588
|
+
uint8_t qh[QK_K/32];
|
2589
|
+
uint8_t signs[QK_K/8];
|
2590
|
+
uint8_t scales[IQ3S_N_SCALE];
|
2591
|
+
} block_iq3_s;
|
2592
|
+
|
2528
2593
|
typedef struct {
|
2529
2594
|
half d;
|
2530
2595
|
uint8_t qs[QK_K/8];
|
@@ -2538,6 +2603,17 @@ typedef struct {
|
|
2538
2603
|
uint8_t qs[QK4_NL/2];
|
2539
2604
|
} block_iq4_nl;
|
2540
2605
|
|
2606
|
+
#if QK_K == 64
|
2607
|
+
#define block_iq4_xs block_iq4_nl
|
2608
|
+
#else
|
2609
|
+
typedef struct {
|
2610
|
+
half d;
|
2611
|
+
uint16_t scales_h;
|
2612
|
+
uint8_t scales_l[QK_K/64];
|
2613
|
+
uint8_t qs[QK_K/2];
|
2614
|
+
} block_iq4_xs;
|
2615
|
+
#endif
|
2616
|
+
|
2541
2617
|
//====================================== dot products =========================
|
2542
2618
|
|
2543
2619
|
void kernel_mul_mv_q2_K_f32_impl(
|
@@ -3760,6 +3836,265 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
|
|
3760
3836
|
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
3761
3837
|
};
|
3762
3838
|
|
3839
|
+
constexpr constant static uint64_t iq2s_grid[1024] = {
|
3840
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
3841
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
3842
|
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
3843
|
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
3844
|
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
3845
|
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
|
3846
|
+
0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
|
3847
|
+
0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
|
3848
|
+
0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
|
3849
|
+
0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
|
3850
|
+
0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
|
3851
|
+
0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
|
3852
|
+
0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
|
3853
|
+
0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
|
3854
|
+
0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
|
3855
|
+
0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
|
3856
|
+
0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
|
3857
|
+
0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
|
3858
|
+
0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
|
3859
|
+
0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
|
3860
|
+
0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
|
3861
|
+
0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
|
3862
|
+
0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
|
3863
|
+
0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
|
3864
|
+
0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
|
3865
|
+
0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
|
3866
|
+
0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
|
3867
|
+
0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
|
3868
|
+
0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
|
3869
|
+
0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
|
3870
|
+
0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
|
3871
|
+
0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
|
3872
|
+
0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
|
3873
|
+
0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
|
3874
|
+
0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
|
3875
|
+
0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
|
3876
|
+
0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
|
3877
|
+
0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
|
3878
|
+
0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
|
3879
|
+
0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
|
3880
|
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
|
3881
|
+
0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
|
3882
|
+
0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
|
3883
|
+
0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
|
3884
|
+
0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
|
3885
|
+
0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
|
3886
|
+
0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
|
3887
|
+
0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
|
3888
|
+
0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
|
3889
|
+
0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
|
3890
|
+
0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
|
3891
|
+
0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
|
3892
|
+
0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
|
3893
|
+
0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
|
3894
|
+
0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
|
3895
|
+
0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
|
3896
|
+
0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
|
3897
|
+
0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
|
3898
|
+
0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
|
3899
|
+
0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
|
3900
|
+
0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
|
3901
|
+
0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
|
3902
|
+
0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
|
3903
|
+
0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
|
3904
|
+
0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
|
3905
|
+
0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
|
3906
|
+
0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
|
3907
|
+
0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
|
3908
|
+
0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
|
3909
|
+
0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
|
3910
|
+
0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
|
3911
|
+
0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
|
3912
|
+
0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
|
3913
|
+
0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
|
3914
|
+
0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
|
3915
|
+
0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
|
3916
|
+
0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
|
3917
|
+
0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
|
3918
|
+
0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
|
3919
|
+
0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
|
3920
|
+
0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
|
3921
|
+
0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
|
3922
|
+
0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
|
3923
|
+
0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
|
3924
|
+
0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
|
3925
|
+
0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
|
3926
|
+
0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
|
3927
|
+
0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
|
3928
|
+
0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
|
3929
|
+
0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
|
3930
|
+
0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
|
3931
|
+
0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
|
3932
|
+
0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
|
3933
|
+
0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
|
3934
|
+
0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
|
3935
|
+
0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
|
3936
|
+
0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
|
3937
|
+
0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
|
3938
|
+
0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
|
3939
|
+
0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
|
3940
|
+
0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
|
3941
|
+
0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
|
3942
|
+
0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
|
3943
|
+
0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
|
3944
|
+
0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
|
3945
|
+
0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
|
3946
|
+
0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
|
3947
|
+
0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
|
3948
|
+
0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
|
3949
|
+
0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
|
3950
|
+
0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
|
3951
|
+
0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
|
3952
|
+
0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
|
3953
|
+
0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
|
3954
|
+
0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
|
3955
|
+
0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
|
3956
|
+
0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
|
3957
|
+
0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
|
3958
|
+
0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
|
3959
|
+
0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
|
3960
|
+
0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
|
3961
|
+
0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
|
3962
|
+
0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
|
3963
|
+
0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
|
3964
|
+
0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
|
3965
|
+
0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
|
3966
|
+
0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
|
3967
|
+
0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
|
3968
|
+
0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
|
3969
|
+
0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
|
3970
|
+
0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
|
3971
|
+
0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
|
3972
|
+
0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
|
3973
|
+
0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
|
3974
|
+
0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
|
3975
|
+
0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
|
3976
|
+
0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
|
3977
|
+
0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
|
3978
|
+
0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
|
3979
|
+
0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
|
3980
|
+
0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
|
3981
|
+
0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
|
3982
|
+
0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
|
3983
|
+
0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
|
3984
|
+
0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
|
3985
|
+
0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
|
3986
|
+
0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
|
3987
|
+
0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
|
3988
|
+
0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
|
3989
|
+
0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
|
3990
|
+
0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
|
3991
|
+
0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
|
3992
|
+
0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
|
3993
|
+
0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
|
3994
|
+
0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
|
3995
|
+
0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
|
3996
|
+
0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
|
3997
|
+
0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
|
3998
|
+
0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
|
3999
|
+
0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
|
4000
|
+
0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
|
4001
|
+
0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
|
4002
|
+
0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
|
4003
|
+
0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
|
4004
|
+
0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
|
4005
|
+
0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
|
4006
|
+
0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
|
4007
|
+
0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
|
4008
|
+
0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
|
4009
|
+
0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
|
4010
|
+
0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
|
4011
|
+
0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
|
4012
|
+
0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
|
4013
|
+
0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
|
4014
|
+
0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
|
4015
|
+
0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
|
4016
|
+
0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
|
4017
|
+
0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
|
4018
|
+
0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
|
4019
|
+
0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
|
4020
|
+
0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
|
4021
|
+
0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
|
4022
|
+
0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
|
4023
|
+
0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
|
4024
|
+
0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
|
4025
|
+
0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
|
4026
|
+
0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
|
4027
|
+
0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
|
4028
|
+
0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
|
4029
|
+
0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
|
4030
|
+
0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
|
4031
|
+
0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
|
4032
|
+
0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
|
4033
|
+
0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
|
4034
|
+
0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
|
4035
|
+
0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
|
4036
|
+
0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
|
4037
|
+
0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
|
4038
|
+
0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
|
4039
|
+
0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
|
4040
|
+
0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
|
4041
|
+
0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
|
4042
|
+
0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
|
4043
|
+
0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
|
4044
|
+
0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
|
4045
|
+
0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
|
4046
|
+
0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
|
4047
|
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
|
4048
|
+
0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
|
4049
|
+
0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
|
4050
|
+
0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
|
4051
|
+
0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
|
4052
|
+
0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
|
4053
|
+
0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
|
4054
|
+
0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
|
4055
|
+
0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
|
4056
|
+
0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
|
4057
|
+
0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
|
4058
|
+
0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
|
4059
|
+
0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
|
4060
|
+
0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
|
4061
|
+
0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
|
4062
|
+
0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
|
4063
|
+
0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
|
4064
|
+
0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
|
4065
|
+
0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
|
4066
|
+
0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
|
4067
|
+
0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
|
4068
|
+
0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
|
4069
|
+
0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
|
4070
|
+
0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
|
4071
|
+
0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
|
4072
|
+
0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
|
4073
|
+
0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
|
4074
|
+
0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
|
4075
|
+
0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
|
4076
|
+
0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
|
4077
|
+
0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
|
4078
|
+
0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
|
4079
|
+
0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
|
4080
|
+
0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
|
4081
|
+
0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
|
4082
|
+
0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
|
4083
|
+
0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
|
4084
|
+
0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
|
4085
|
+
0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
|
4086
|
+
0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
|
4087
|
+
0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
|
4088
|
+
0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
|
4089
|
+
0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
|
4090
|
+
0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
|
4091
|
+
0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
|
4092
|
+
0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
|
4093
|
+
0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
|
4094
|
+
0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
|
4095
|
+
0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
|
4096
|
+
};
|
4097
|
+
|
3763
4098
|
constexpr constant static uint32_t iq3xxs_grid[256] = {
|
3764
4099
|
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
3765
4100
|
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
@@ -3795,6 +4130,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
3795
4130
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
3796
4131
|
};
|
3797
4132
|
|
4133
|
+
constexpr constant static uint32_t iq3s_grid[512] = {
|
4134
|
+
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
4135
|
+
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
4136
|
+
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
4137
|
+
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
4138
|
+
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
4139
|
+
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
4140
|
+
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
4141
|
+
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
4142
|
+
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
4143
|
+
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
4144
|
+
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
4145
|
+
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
4146
|
+
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
4147
|
+
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
4148
|
+
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
4149
|
+
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
4150
|
+
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
4151
|
+
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
4152
|
+
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
4153
|
+
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
4154
|
+
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
4155
|
+
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
4156
|
+
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
4157
|
+
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
4158
|
+
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
4159
|
+
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
4160
|
+
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
4161
|
+
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
4162
|
+
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
4163
|
+
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
4164
|
+
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
4165
|
+
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
4166
|
+
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
4167
|
+
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
4168
|
+
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
4169
|
+
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
4170
|
+
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
4171
|
+
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
4172
|
+
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
4173
|
+
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
4174
|
+
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
4175
|
+
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
4176
|
+
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
4177
|
+
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
4178
|
+
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
4179
|
+
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
4180
|
+
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
4181
|
+
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
4182
|
+
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
4183
|
+
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
4184
|
+
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
4185
|
+
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
4186
|
+
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
4187
|
+
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
4188
|
+
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
4189
|
+
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
4190
|
+
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
4191
|
+
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
4192
|
+
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
4193
|
+
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
4194
|
+
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
4195
|
+
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
4196
|
+
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
4197
|
+
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
4198
|
+
};
|
4199
|
+
|
3798
4200
|
#define NGRID_IQ1S 512
|
3799
4201
|
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
3800
4202
|
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
@@ -3991,7 +4393,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
3991
4393
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
3992
4394
|
}
|
3993
4395
|
|
3994
|
-
#if QK_K == 256
|
3995
4396
|
const int ix = tiisg;
|
3996
4397
|
|
3997
4398
|
device const float * y4 = y + 32 * ix;
|
@@ -4032,12 +4433,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
4032
4433
|
|
4033
4434
|
y4 += 32 * 32;
|
4034
4435
|
}
|
4035
|
-
#else
|
4036
|
-
(void) x;
|
4037
|
-
(void) y;
|
4038
|
-
(void) yl;
|
4039
|
-
(void) nb32;
|
4040
|
-
#endif
|
4041
4436
|
|
4042
4437
|
for (int row = 0; row < N_DST; ++row) {
|
4043
4438
|
all_sum = simd_sum(sumf[row]);
|
@@ -4127,7 +4522,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
4127
4522
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4128
4523
|
}
|
4129
4524
|
|
4130
|
-
#if QK_K == 256
|
4131
4525
|
const int ix = tiisg;
|
4132
4526
|
|
4133
4527
|
device const float * y4 = y + 32 * ix;
|
@@ -4178,12 +4572,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
4178
4572
|
|
4179
4573
|
y4 += 32 * 32;
|
4180
4574
|
}
|
4181
|
-
#else
|
4182
|
-
(void) x;
|
4183
|
-
(void) y;
|
4184
|
-
(void) yl;
|
4185
|
-
(void) nb32;
|
4186
|
-
#endif
|
4187
4575
|
|
4188
4576
|
for (int row = 0; row < N_DST; ++row) {
|
4189
4577
|
all_sum = simd_sum(sumf[row]);
|
@@ -4273,7 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
4273
4661
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4274
4662
|
}
|
4275
4663
|
|
4276
|
-
#if QK_K == 256
|
4277
4664
|
const int ix = tiisg;
|
4278
4665
|
|
4279
4666
|
device const float * y4 = y + 32 * ix;
|
@@ -4317,12 +4704,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
4317
4704
|
|
4318
4705
|
y4 += 32 * 32;
|
4319
4706
|
}
|
4320
|
-
#else
|
4321
|
-
(void) x;
|
4322
|
-
(void) y;
|
4323
|
-
(void) yl;
|
4324
|
-
(void) nb32;
|
4325
|
-
#endif
|
4326
4707
|
|
4327
4708
|
for (int row = 0; row < N_DST; ++row) {
|
4328
4709
|
all_sum = simd_sum(sumf[row]);
|
@@ -4361,7 +4742,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
4361
4742
|
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
4362
4743
|
}
|
4363
4744
|
|
4364
|
-
void
|
4745
|
+
void kernel_mul_mv_iq3_s_f32_impl(
|
4365
4746
|
device const void * src0,
|
4366
4747
|
device const float * src1,
|
4367
4748
|
device float * dst,
|
@@ -4374,6 +4755,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4374
4755
|
constant int64_t & ne1,
|
4375
4756
|
constant uint & r2,
|
4376
4757
|
constant uint & r3,
|
4758
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4377
4759
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4378
4760
|
uint tiisg[[thread_index_in_simdgroup]],
|
4379
4761
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -4390,59 +4772,70 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4390
4772
|
const uint i13 = im/ne12;
|
4391
4773
|
|
4392
4774
|
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
4393
|
-
|
4775
|
+
|
4776
|
+
device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
|
4394
4777
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4395
4778
|
|
4396
|
-
float yl[
|
4779
|
+
float yl[32];
|
4397
4780
|
float sumf[N_DST]={0.f}, all_sum;
|
4398
4781
|
|
4399
4782
|
const int nb32 = nb * (QK_K / 32);
|
4400
4783
|
|
4401
|
-
|
4402
|
-
|
4403
|
-
|
4784
|
+
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
4785
|
+
{
|
4786
|
+
int nval = 8;
|
4787
|
+
int pos = (32*sgitg + tiisg)*nval;
|
4788
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
|
4789
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4790
|
+
}
|
4404
4791
|
|
4405
|
-
|
4792
|
+
const int ix = tiisg;
|
4406
4793
|
|
4407
|
-
|
4794
|
+
device const float * y4 = y + 32 * ix;
|
4408
4795
|
|
4409
|
-
|
4796
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
4797
|
+
|
4798
|
+
for (int i = 0; i < 32; ++i) {
|
4410
4799
|
yl[i] = y4[i];
|
4411
4800
|
}
|
4412
4801
|
|
4413
4802
|
const int ibl = ib32 / (QK_K / 32);
|
4414
4803
|
const int ib = ib32 % (QK_K / 32);
|
4415
4804
|
|
4416
|
-
device const
|
4417
|
-
device const uint8_t * qs = xr->qs +
|
4418
|
-
device const uint8_t *
|
4419
|
-
device const
|
4805
|
+
device const block_iq3_s * xr = x + ibl;
|
4806
|
+
device const uint8_t * qs = xr->qs + 8 * ib;
|
4807
|
+
device const uint8_t * qh = xr->qh + ib;
|
4808
|
+
device const uint8_t * sc = xr->scales + (ib/2);
|
4809
|
+
device const uint8_t * signs = xr->signs + 4 * ib;
|
4810
|
+
device const half * dh = &xr->d;
|
4420
4811
|
|
4421
4812
|
for (int row = 0; row < N_DST; row++) {
|
4422
4813
|
|
4423
|
-
|
4424
|
-
|
4814
|
+
const float db = dh[0];
|
4815
|
+
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
|
4425
4816
|
|
4426
4817
|
float2 sum = {0};
|
4427
|
-
for (int
|
4428
|
-
|
4429
|
-
|
4818
|
+
for (int l = 0; l < 4; ++l) {
|
4819
|
+
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
|
4820
|
+
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
|
4821
|
+
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
|
4822
|
+
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
|
4823
|
+
for (int j = 0; j < 4; ++j) {
|
4824
|
+
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
4825
|
+
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
4826
|
+
}
|
4430
4827
|
}
|
4431
|
-
sumf[row] +=
|
4828
|
+
sumf[row] += d * (sum[0] + sum[1]);
|
4432
4829
|
|
4433
|
-
dh
|
4434
|
-
qs
|
4435
|
-
|
4830
|
+
dh += nb*sizeof(block_iq3_s)/2;
|
4831
|
+
qs += nb*sizeof(block_iq3_s);
|
4832
|
+
qh += nb*sizeof(block_iq3_s);
|
4833
|
+
sc += nb*sizeof(block_iq3_s);
|
4834
|
+
signs += nb*sizeof(block_iq3_s);
|
4436
4835
|
}
|
4437
4836
|
|
4438
|
-
y4 +=
|
4837
|
+
y4 += 32 * 32;
|
4439
4838
|
}
|
4440
|
-
#else
|
4441
|
-
(void) x;
|
4442
|
-
(void) y;
|
4443
|
-
(void) yl;
|
4444
|
-
(void) nb32;
|
4445
|
-
#endif
|
4446
4839
|
|
4447
4840
|
for (int row = 0; row < N_DST; ++row) {
|
4448
4841
|
all_sum = simd_sum(sumf[row]);
|
@@ -4452,11 +4845,36 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4452
4845
|
}
|
4453
4846
|
}
|
4454
4847
|
|
4455
|
-
|
4456
|
-
|
4457
|
-
|
4848
|
+
[[host_name("kernel_mul_mv_iq3_s_f32")]]
|
4849
|
+
kernel void kernel_mul_mv_iq3_s_f32(
|
4850
|
+
device const void * src0,
|
4851
|
+
device const float * src1,
|
4852
|
+
device float * dst,
|
4853
|
+
constant int64_t & ne00,
|
4854
|
+
constant int64_t & ne01,
|
4855
|
+
constant int64_t & ne02,
|
4856
|
+
constant uint64_t & nb00,
|
4857
|
+
constant uint64_t & nb01,
|
4858
|
+
constant uint64_t & nb02,
|
4859
|
+
constant int64_t & ne10,
|
4860
|
+
constant int64_t & ne11,
|
4861
|
+
constant int64_t & ne12,
|
4862
|
+
constant uint64_t & nb10,
|
4863
|
+
constant uint64_t & nb11,
|
4864
|
+
constant uint64_t & nb12,
|
4865
|
+
constant int64_t & ne0,
|
4866
|
+
constant int64_t & ne1,
|
4867
|
+
constant uint & r2,
|
4868
|
+
constant uint & r3,
|
4869
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4870
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
4871
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
4872
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4458
4873
|
|
4459
|
-
|
4874
|
+
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
4875
|
+
}
|
4876
|
+
|
4877
|
+
void kernel_mul_mv_iq2_s_f32_impl(
|
4460
4878
|
device const void * src0,
|
4461
4879
|
device const float * src1,
|
4462
4880
|
device float * dst,
|
@@ -4469,88 +4887,99 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
4469
4887
|
constant int64_t & ne1,
|
4470
4888
|
constant uint & r2,
|
4471
4889
|
constant uint & r3,
|
4472
|
-
threadgroup
|
4890
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4473
4891
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4474
4892
|
uint tiisg[[thread_index_in_simdgroup]],
|
4475
4893
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4476
4894
|
|
4477
|
-
const int nb = ne00/
|
4895
|
+
const int nb = ne00/QK_K;
|
4478
4896
|
const int r0 = tgpig.x;
|
4479
4897
|
const int r1 = tgpig.y;
|
4480
4898
|
const int im = tgpig.z;
|
4481
|
-
|
4899
|
+
|
4900
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
4482
4901
|
const int ib_row = first_row * nb;
|
4483
4902
|
|
4484
4903
|
const uint i12 = im%ne12;
|
4485
4904
|
const uint i13 = im/ne12;
|
4486
4905
|
|
4487
4906
|
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
4488
|
-
device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
|
4489
|
-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4490
|
-
|
4491
|
-
const int ix = tiisg/2; // 0...15
|
4492
|
-
const int it = tiisg%2; // 0 or 1
|
4493
4907
|
|
4494
|
-
|
4495
|
-
|
4496
|
-
|
4497
|
-
float4 yl[4];
|
4498
|
-
float sumf[2]={0.f}, all_sum;
|
4908
|
+
device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
|
4909
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4499
4910
|
|
4500
|
-
|
4911
|
+
float yl[32];
|
4912
|
+
float sumf[N_DST]={0.f}, all_sum;
|
4501
4913
|
|
4502
|
-
|
4503
|
-
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
4914
|
+
const int nb32 = nb * (QK_K / 32);
|
4504
4915
|
|
4505
|
-
|
4916
|
+
//threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
4917
|
+
//{
|
4918
|
+
// int nval = 32;
|
4919
|
+
// int pos = (32*sgitg + tiisg)*nval;
|
4920
|
+
// for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
|
4921
|
+
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
4922
|
+
//}
|
4506
4923
|
|
4507
|
-
|
4924
|
+
const int ix = tiisg;
|
4508
4925
|
|
4509
|
-
|
4510
|
-
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
4926
|
+
device const float * y4 = y + 32 * ix;
|
4511
4927
|
|
4512
|
-
|
4928
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
4513
4929
|
|
4514
|
-
|
4515
|
-
|
4930
|
+
for (int i = 0; i < 32; ++i) {
|
4931
|
+
yl[i] = y4[i];
|
4932
|
+
}
|
4516
4933
|
|
4517
|
-
|
4934
|
+
const int ibl = ib32 / (QK_K / 32);
|
4935
|
+
const int ib = ib32 % (QK_K / 32);
|
4518
4936
|
|
4519
|
-
|
4520
|
-
|
4521
|
-
|
4522
|
-
|
4523
|
-
|
4524
|
-
|
4525
|
-
acc2 += yl[1] * qf2;
|
4937
|
+
device const block_iq2_s * xr = x + ibl;
|
4938
|
+
device const uint8_t * qs = xr->qs + 4 * ib;
|
4939
|
+
device const uint8_t * qh = xr->qh + ib;
|
4940
|
+
device const uint8_t * sc = xr->scales + ib;
|
4941
|
+
device const uint8_t * signs = qs + QK_K/8;
|
4942
|
+
device const half * dh = &xr->d;
|
4526
4943
|
|
4527
|
-
|
4528
|
-
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
4529
|
-
aux32[0] &= 0x0f0f0f0f;
|
4530
|
-
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
4531
|
-
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
4532
|
-
acc1 += yl[2] * qf1;
|
4533
|
-
acc2 += yl[3] * qf2;
|
4944
|
+
for (int row = 0; row < N_DST; row++) {
|
4534
4945
|
|
4535
|
-
|
4946
|
+
const float db = dh[0];
|
4947
|
+
const float d1 = db * (0.5f + (sc[0] & 0xf));
|
4948
|
+
const float d2 = db * (0.5f + (sc[0] >> 4));
|
4536
4949
|
|
4537
|
-
|
4950
|
+
float2 sum = {0};
|
4951
|
+
for (int l = 0; l < 2; ++l) {
|
4952
|
+
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
4953
|
+
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
4954
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
4955
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
4956
|
+
for (int j = 0; j < 8; ++j) {
|
4957
|
+
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
|
4958
|
+
sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
|
4959
|
+
}
|
4960
|
+
}
|
4961
|
+
sumf[row] += d1 * sum[0] + d2 * sum[1];
|
4538
4962
|
|
4963
|
+
dh += nb*sizeof(block_iq2_s)/2;
|
4964
|
+
qs += nb*sizeof(block_iq2_s);
|
4965
|
+
qh += nb*sizeof(block_iq2_s);
|
4966
|
+
sc += nb*sizeof(block_iq2_s);
|
4967
|
+
signs += nb*sizeof(block_iq2_s);
|
4539
4968
|
}
|
4540
4969
|
|
4541
|
-
|
4970
|
+
y4 += 32 * 32;
|
4542
4971
|
}
|
4543
4972
|
|
4544
|
-
for (int row = 0; row <
|
4973
|
+
for (int row = 0; row < N_DST; ++row) {
|
4545
4974
|
all_sum = simd_sum(sumf[row]);
|
4546
4975
|
if (tiisg == 0) {
|
4547
|
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4976
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
4548
4977
|
}
|
4549
4978
|
}
|
4550
4979
|
}
|
4551
4980
|
|
4552
|
-
[[host_name("
|
4553
|
-
kernel void
|
4981
|
+
[[host_name("kernel_mul_mv_iq2_s_f32")]]
|
4982
|
+
kernel void kernel_mul_mv_iq2_s_f32(
|
4554
4983
|
device const void * src0,
|
4555
4984
|
device const float * src1,
|
4556
4985
|
device float * dst,
|
@@ -4570,67 +4999,406 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
4570
4999
|
constant int64_t & ne1,
|
4571
5000
|
constant uint & r2,
|
4572
5001
|
constant uint & r3,
|
5002
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4573
5003
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4574
5004
|
uint tiisg[[thread_index_in_simdgroup]],
|
4575
5005
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4576
5006
|
|
4577
|
-
|
5007
|
+
kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
4578
5008
|
}
|
4579
5009
|
|
4580
|
-
|
4581
|
-
kernel void kernel_mul_mv_iq4_nl_f32(
|
5010
|
+
void kernel_mul_mv_iq1_s_f32_impl(
|
4582
5011
|
device const void * src0,
|
4583
5012
|
device const float * src1,
|
4584
5013
|
device float * dst,
|
4585
5014
|
constant int64_t & ne00,
|
4586
5015
|
constant int64_t & ne01,
|
4587
5016
|
constant int64_t & ne02,
|
4588
|
-
constant uint64_t & nb00,
|
4589
|
-
constant uint64_t & nb01,
|
4590
|
-
constant uint64_t & nb02,
|
4591
5017
|
constant int64_t & ne10,
|
4592
|
-
constant int64_t & ne11,
|
4593
5018
|
constant int64_t & ne12,
|
4594
|
-
constant uint64_t & nb10,
|
4595
|
-
constant uint64_t & nb11,
|
4596
|
-
constant uint64_t & nb12,
|
4597
5019
|
constant int64_t & ne0,
|
4598
5020
|
constant int64_t & ne1,
|
4599
5021
|
constant uint & r2,
|
4600
5022
|
constant uint & r3,
|
4601
|
-
threadgroup float * shared_values [[threadgroup(0)]],
|
4602
5023
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4603
|
-
uint
|
4604
|
-
uint
|
5024
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
5025
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4605
5026
|
|
4606
|
-
|
4607
|
-
|
5027
|
+
const int nb = ne00/QK_K;
|
5028
|
+
const int r0 = tgpig.x;
|
5029
|
+
const int r1 = tgpig.y;
|
5030
|
+
const int im = tgpig.z;
|
4608
5031
|
|
4609
|
-
|
5032
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
5033
|
+
const int ib_row = first_row * nb;
|
4610
5034
|
|
4611
|
-
|
4612
|
-
|
4613
|
-
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
4614
|
-
float4x4 temp = *(((device float4x4 *)src));
|
4615
|
-
for (int i = 0; i < 16; i++){
|
4616
|
-
reg[i/4][i%4] = temp[i/4][i%4];
|
4617
|
-
}
|
4618
|
-
}
|
5035
|
+
const uint i12 = im%ne12;
|
5036
|
+
const uint i13 = im/ne12;
|
4619
5037
|
|
4620
|
-
|
4621
|
-
|
4622
|
-
|
4623
|
-
for (int i = 0; i < 16; i++){
|
4624
|
-
reg[i/4][i%4] = temp[i/4][i%4];
|
4625
|
-
}
|
4626
|
-
}
|
5038
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
5039
|
+
device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
|
5040
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4627
5041
|
|
4628
|
-
|
4629
|
-
|
4630
|
-
|
4631
|
-
const
|
4632
|
-
|
4633
|
-
const
|
5042
|
+
float yl[16];
|
5043
|
+
float sumf[N_DST]={0.f}, all_sum;
|
5044
|
+
|
5045
|
+
const int nb32 = nb * (QK_K / 32);
|
5046
|
+
|
5047
|
+
const int ix = tiisg/2;
|
5048
|
+
const int il = tiisg%2;
|
5049
|
+
|
5050
|
+
device const float * y4 = y + 32 * ix + 16 * il;
|
5051
|
+
|
5052
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 16) {
|
5053
|
+
|
5054
|
+
for (int i = 0; i < 16; ++i) {
|
5055
|
+
yl[i] = y4[i];
|
5056
|
+
}
|
5057
|
+
|
5058
|
+
const int ibl = ib32 / (QK_K / 32);
|
5059
|
+
const int ib = ib32 % (QK_K / 32);
|
5060
|
+
|
5061
|
+
device const block_iq1_s * xr = x + ibl;
|
5062
|
+
device const uint8_t * qs = xr->qs + 4 * ib + 2 * il;
|
5063
|
+
device const uint8_t * sc = xr->scales + 2 * ib + il;
|
5064
|
+
device const half * dh = &xr->d;
|
5065
|
+
|
5066
|
+
for (int row = 0; row < N_DST; row++) {
|
5067
|
+
|
5068
|
+
constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
|
5069
|
+
constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
|
5070
|
+
|
5071
|
+
float2 sum = {0};
|
5072
|
+
for (int j = 0; j < 8; ++j) {
|
5073
|
+
sum[0] += yl[j+ 0] * grid1[j];
|
5074
|
+
sum[1] += yl[j+ 8] * grid2[j];
|
5075
|
+
}
|
5076
|
+
sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1));
|
5077
|
+
|
5078
|
+
dh += nb*sizeof(block_iq1_s)/2;
|
5079
|
+
qs += nb*sizeof(block_iq1_s);
|
5080
|
+
sc += nb*sizeof(block_iq1_s);
|
5081
|
+
}
|
5082
|
+
|
5083
|
+
y4 += 16 * 32;
|
5084
|
+
}
|
5085
|
+
|
5086
|
+
for (int row = 0; row < N_DST; ++row) {
|
5087
|
+
all_sum = simd_sum(sumf[row]);
|
5088
|
+
if (tiisg == 0) {
|
5089
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
5090
|
+
}
|
5091
|
+
}
|
5092
|
+
}
|
5093
|
+
|
5094
|
+
constexpr constant static float kvalues_iq4nl_f[16] = {
|
5095
|
+
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
|
5096
|
+
};
|
5097
|
+
|
5098
|
+
void kernel_mul_mv_iq4_nl_f32_impl(
|
5099
|
+
device const void * src0,
|
5100
|
+
device const float * src1,
|
5101
|
+
device float * dst,
|
5102
|
+
constant int64_t & ne00,
|
5103
|
+
constant int64_t & ne01,
|
5104
|
+
constant int64_t & ne02,
|
5105
|
+
constant int64_t & ne10,
|
5106
|
+
constant int64_t & ne12,
|
5107
|
+
constant int64_t & ne0,
|
5108
|
+
constant int64_t & ne1,
|
5109
|
+
constant uint & r2,
|
5110
|
+
constant uint & r3,
|
5111
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
5112
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
5113
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
5114
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5115
|
+
|
5116
|
+
const int nb = ne00/QK4_NL;
|
5117
|
+
const int r0 = tgpig.x;
|
5118
|
+
const int r1 = tgpig.y;
|
5119
|
+
const int im = tgpig.z;
|
5120
|
+
const int first_row = (r0 * 2 + sgitg) * 2;
|
5121
|
+
const int ib_row = first_row * nb;
|
5122
|
+
|
5123
|
+
const uint i12 = im%ne12;
|
5124
|
+
const uint i13 = im/ne12;
|
5125
|
+
|
5126
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
5127
|
+
device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
|
5128
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
5129
|
+
|
5130
|
+
const int ix = tiisg/2; // 0...15
|
5131
|
+
const int it = tiisg%2; // 0 or 1
|
5132
|
+
|
5133
|
+
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
5134
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
5135
|
+
|
5136
|
+
float4 yl[4];
|
5137
|
+
float sumf[2]={0.f}, all_sum;
|
5138
|
+
|
5139
|
+
device const float * yb = y + ix * QK4_NL + it * 8;
|
5140
|
+
|
5141
|
+
uint32_t aux32[2];
|
5142
|
+
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
5143
|
+
|
5144
|
+
float4 qf1, qf2;
|
5145
|
+
|
5146
|
+
for (int ib = ix; ib < nb; ib += 16) {
|
5147
|
+
|
5148
|
+
device const float4 * y4 = (device const float4 *)yb;
|
5149
|
+
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
5150
|
+
|
5151
|
+
for (int row = 0; row < 2; ++row) {
|
5152
|
+
|
5153
|
+
device const block_iq4_nl & xb = x[row*nb + ib];
|
5154
|
+
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
|
5155
|
+
|
5156
|
+
float4 acc1 = {0.f}, acc2 = {0.f};
|
5157
|
+
|
5158
|
+
aux32[0] = q4[0] | (q4[1] << 16);
|
5159
|
+
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
5160
|
+
aux32[0] &= 0x0f0f0f0f;
|
5161
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
5162
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
5163
|
+
acc1 += yl[0] * qf1;
|
5164
|
+
acc2 += yl[1] * qf2;
|
5165
|
+
|
5166
|
+
aux32[0] = q4[2] | (q4[3] << 16);
|
5167
|
+
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
5168
|
+
aux32[0] &= 0x0f0f0f0f;
|
5169
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
5170
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
5171
|
+
acc1 += yl[2] * qf1;
|
5172
|
+
acc2 += yl[3] * qf2;
|
5173
|
+
|
5174
|
+
acc1 += acc2;
|
5175
|
+
|
5176
|
+
sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
5177
|
+
|
5178
|
+
}
|
5179
|
+
|
5180
|
+
yb += 16 * QK4_NL;
|
5181
|
+
}
|
5182
|
+
|
5183
|
+
for (int row = 0; row < 2; ++row) {
|
5184
|
+
all_sum = simd_sum(sumf[row]);
|
5185
|
+
if (tiisg == 0) {
|
5186
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
5187
|
+
}
|
5188
|
+
}
|
5189
|
+
}
|
5190
|
+
|
5191
|
+
#if QK_K != 64
|
5192
|
+
void kernel_mul_mv_iq4_xs_f32_impl(
|
5193
|
+
device const void * src0,
|
5194
|
+
device const float * src1,
|
5195
|
+
device float * dst,
|
5196
|
+
constant int64_t & ne00,
|
5197
|
+
constant int64_t & ne01,
|
5198
|
+
constant int64_t & ne02,
|
5199
|
+
constant int64_t & ne10,
|
5200
|
+
constant int64_t & ne12,
|
5201
|
+
constant int64_t & ne0,
|
5202
|
+
constant int64_t & ne1,
|
5203
|
+
constant uint & r2,
|
5204
|
+
constant uint & r3,
|
5205
|
+
threadgroup float * shared_values [[threadgroup(0)]],
|
5206
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
5207
|
+
uint tiisg[[thread_index_in_simdgroup]],
|
5208
|
+
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
5209
|
+
|
5210
|
+
const int nb = ne00/QK_K;
|
5211
|
+
const int r0 = tgpig.x;
|
5212
|
+
const int r1 = tgpig.y;
|
5213
|
+
const int im = tgpig.z;
|
5214
|
+
const int first_row = (r0 * 2 + sgitg) * 2;
|
5215
|
+
const int ib_row = first_row * nb;
|
5216
|
+
|
5217
|
+
const uint i12 = im%ne12;
|
5218
|
+
const uint i13 = im/ne12;
|
5219
|
+
|
5220
|
+
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
5221
|
+
device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
|
5222
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
5223
|
+
|
5224
|
+
const int ix = tiisg/16; // 0 or 1
|
5225
|
+
const int it = tiisg%16; // 0...15
|
5226
|
+
const int ib = it/2;
|
5227
|
+
const int il = it%2;
|
5228
|
+
|
5229
|
+
shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
|
5230
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
5231
|
+
|
5232
|
+
float4 yl[4];
|
5233
|
+
float sumf[2]={0.f}, all_sum;
|
5234
|
+
|
5235
|
+
device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
|
5236
|
+
|
5237
|
+
uint32_t aux32[2];
|
5238
|
+
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
5239
|
+
|
5240
|
+
float4 qf1, qf2;
|
5241
|
+
|
5242
|
+
for (int ibl = ix; ibl < nb; ibl += 2) {
|
5243
|
+
|
5244
|
+
device const float4 * y4 = (device const float4 *)yb;
|
5245
|
+
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
5246
|
+
|
5247
|
+
for (int row = 0; row < 2; ++row) {
|
5248
|
+
|
5249
|
+
device const block_iq4_xs & xb = x[row*nb + ibl];
|
5250
|
+
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
|
5251
|
+
|
5252
|
+
float4 acc1 = {0.f}, acc2 = {0.f};
|
5253
|
+
|
5254
|
+
aux32[0] = q4[0] & 0x0f0f0f0f;
|
5255
|
+
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
|
5256
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
5257
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
5258
|
+
acc1 += yl[0] * qf1;
|
5259
|
+
acc2 += yl[1] * qf2;
|
5260
|
+
|
5261
|
+
aux32[0] = q4[1] & 0x0f0f0f0f;
|
5262
|
+
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
|
5263
|
+
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
5264
|
+
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
5265
|
+
acc1 += yl[2] * qf1;
|
5266
|
+
acc2 += yl[3] * qf2;
|
5267
|
+
|
5268
|
+
acc1 += acc2;
|
5269
|
+
|
5270
|
+
const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
|
5271
|
+
sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
|
5272
|
+
|
5273
|
+
}
|
5274
|
+
|
5275
|
+
yb += 2 * QK_K;
|
5276
|
+
}
|
5277
|
+
|
5278
|
+
for (int row = 0; row < 2; ++row) {
|
5279
|
+
all_sum = simd_sum(sumf[row]);
|
5280
|
+
if (tiisg == 0) {
|
5281
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
5282
|
+
}
|
5283
|
+
}
|
5284
|
+
}
|
5285
|
+
#endif
|
5286
|
+
|
5287
|
+
// Host-visible entry point "kernel_mul_mv_iq1_s_f32".
// The kernel only delegates to kernel_mul_mv_iq1_s_f32_impl. nb00..nb02, ne11 and
// nb10..nb12 are part of the argument list but are not forwarded to the implementation.
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq1_s_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        tgpig, tiisg, sgitg);
}
|
5314
|
+
|
5315
|
+
// Host-visible entry point "kernel_mul_mv_iq4_nl_f32".
// Delegates to kernel_mul_mv_iq4_nl_f32_impl, handing over the threadgroup buffer bound at
// threadgroup(0) as shared_values. nb00..nb02, ne11 and nb10..nb12 are accepted but not forwarded.
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        threadgroup float  * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq4_nl_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
5343
|
+
|
5344
|
+
// Host-visible entry point "kernel_mul_mv_iq4_xs_f32".
// When QK_K == 64 the work is delegated to kernel_mul_mv_iq4_nl_f32_impl, otherwise to
// kernel_mul_mv_iq4_xs_f32_impl; both receive exactly the same argument list.
// nb00..nb02, ne11 and nb10..nb12 are accepted but not forwarded.
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant      uint & r2,
        constant      uint & r3,
        threadgroup float  * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    // Only the callee differs between the two configurations.
#if QK_K == 64
    kernel_mul_mv_iq4_nl_f32_impl(
#else
    kernel_mul_mv_iq4_xs_f32_impl(
#endif
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
5376
|
+
|
5377
|
+
//============================= templates and their specializations =============================
|
5378
|
+
|
5379
|
+
// NOTE: nothing is dequantized here - this function exists so that f32 data fits the
// same (src, il, reg) template shape as the quantized types.
// Loads one float4x4 from src and copies its 16 elements into reg, converting each
// element to reg's element type on assignment. il is unused.
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
    const float4x4 loaded = *(((device float4x4 *)src));
    for (int outer = 0; outer < 4; outer++) {
        for (int inner = 0; inner < 4; inner++) {
            reg[outer][inner] = loaded[outer][inner];
        }
    }
}
|
5387
|
+
|
5388
|
+
// f16 counterpart of the pass-through "dequantizer": loads one half4x4 from src and
// copies its 16 elements into reg, converting each element to reg's element type on
// assignment. il is unused.
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    const half4x4 loaded = *(((device half4x4 *)src));
    for (int outer = 0; outer < 4; outer++) {
        for (int inner = 0; inner < 4; inner++) {
            reg[outer][inner] = loaded[outer][inner];
        }
    }
}
|
5395
|
+
|
5396
|
+
template <typename type4x4>
|
5397
|
+
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
5398
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
5399
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
5400
|
+
const float d2 = d1 / 256.f;
|
5401
|
+
const float md = -8.h * xb->d;
|
4634
5402
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
4635
5403
|
const ushort mask1 = mask0 << 8;
|
4636
5404
|
|
@@ -4952,6 +5720,50 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
|
|
4952
5720
|
}
|
4953
5721
|
}
|
4954
5722
|
|
5723
|
+
// Expands 16 values of an iq3_s block into reg (4 rows of 4 values).
// il is 0...15 for QK_K = 256: il/2 selects the block of 32, il%2 selects the first
// or second group of 16 quants inside that block.
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    const int ib32 = il/2;   // index of the block of 32
    const int sub  = il%2;   // 0 => first 16 quants of the block, 1 => second 16

    device const uint8_t * qs    = xb->qs    + 8*ib32 + 4*sub;
    device const uint8_t * signs = xb->signs + 4*ib32 + 2*sub;

    // the four bits of qh used below each supply bit 8 of one grid index
    const uint8_t qh = xb->qh[ib32] >> 4*sub;

    const float d  = xb->d;
    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));

    for (int k = 0; k < 4; ++k) {
        // row k: low 8 bits of the grid index come from qs[k], bit 8 from bit k of qh
        constant uint8_t * grid = (constant uint8_t *)(iq3s_grid + (qs[k] | ((qh << (8 - k)) & 256)));

        // rows 0,1 share signs[0], rows 2,3 share signs[1];
        // even rows test mask entries 0..3, odd rows test mask entries 4..7
        const uint8_t sign_bits = signs[k/2];
        const int     mask_base = 4*(k%2);

        for (int i = 0; i < 4; ++i) {
            reg[k][i] = dl * grid[i] * select(1, -1, sign_bits & kmask_iq2xs[i + mask_base]);
        }
    }
}
|
5747
|
+
|
5748
|
+
// Expands 16 values of an iq2_s block into reg (4 rows of 4 values).
// il is 0...15 for QK_K = 256: il/2 selects the block of 32, il%2 selects the first
// or second group of 16 quants inside that block.
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    const int ib32 = il/2;   // index of the block of 32
    const int sub  = il%2;   // 0 => first 16 quants of the block, 1 => second 16

    device const uint8_t * qs    = xb->qs + 4*ib32 + 2*sub;
    device const uint8_t * signs = qs + QK_K/8;   // sign bytes sit QK_K/8 bytes after the matching qs bytes

    const uint8_t qh = xb->qh[ib32] >> 4*sub;

    const float d  = xb->d;
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*sub) & 0xf)) * 0.25f;

    for (int p = 0; p < 2; ++p) {
        // 8 grid bytes per index: low 8 bits from qs[p], bits 8..9 from a 2-bit field of qh
        constant uint8_t * grid = (constant uint8_t *)(iq2s_grid + (qs[p] | ((qh << (8 - 2*p)) & 0x300)));
        const uint8_t sign_bits = signs[p];

        // each grid entry fills two consecutive rows of reg
        for (int i = 0; i < 8; ++i) {
            reg[2*p + i/4][i%4] = dl * grid[i] * select(1, -1, sign_bits & kmask_iq2xs[i]);
        }
    }
}
|
5766
|
+
|
4955
5767
|
template <typename type4x4>
|
4956
5768
|
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
4957
5769
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
@@ -4983,6 +5795,30 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|
4983
5795
|
}
|
4984
5796
|
}
|
4985
5797
|
|
5798
|
+
// Expands 16 values of an iq4_xs block into reg (4 rows of 4 values).
// With QK_K == 64 the call is simply forwarded to dequantize_iq4_nl.
// Otherwise il is 0...15 for QK_K = 256: il/2 selects the block of 32, il%2 selects
// the first or second group of 16 quants inside that block.
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
#if QK_K == 64
    dequantize_iq4_nl(xb, il, reg);
#else
    const int ib32  = il/2;         // index of the block of 32
    const int shift = 4*(il%2);     // 0 => low nibbles, 4 => high nibbles

    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;

    // 6-bit scale: low 4 bits from scales_l, high 2 bits from scales_h; stored with an offset of 32
    const int ls_lo = (xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf;
    const int ls_hi = ((xb->scales_h >> 2*ib32) & 3) << 4;
    const int ls    = ls_lo | ls_hi;
    const float d   = (float)xb->d * (ls - 32);

    // the four nibbles selected from one 32-bit word are read back as individual bytes
    uint32_t aux32;
    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;

    for (int i = 0; i < 4; ++i) {
        aux32 = (q4[i] >> shift) & 0x0f0f0f0f;
        for (int j = 0; j < 4; ++j) {
            reg[i][j] = d * kvalues_iq4nl_f[q8[j]];
        }
    }
#endif
}
|
5821
|
+
|
4986
5822
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
4987
5823
|
kernel void kernel_get_rows(
|
4988
5824
|
device const void * src0,
|
@@ -5525,8 +6361,15 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
|
5525
6361
|
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5526
6362
|
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5527
6363
|
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
6364
|
+
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
6365
|
+
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5528
6366
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5529
|
-
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2,
|
6367
|
+
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6368
|
+
// Explicit instantiation of kernel_get_rows exported to the host as "kernel_get_rows_iq4_xs".
// Both branches use block_iq4_xs and dequantize_iq4_xs; only the second template argument
// differs: 2 when QK_K == 64 (the value the iq4_nl instantiation uses), QK_NL otherwise.
#if QK_K == 64
|
6369
|
+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6370
|
+
#else
|
6371
|
+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6372
|
+
#endif
|
5530
6373
|
|
5531
6374
|
//
|
5532
6375
|
// matrix-matrix multiplication
|
@@ -5566,8 +6409,15 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
5566
6409
|
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5567
6410
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5568
6411
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
6412
|
+
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
6413
|
+
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5569
6414
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5570
|
-
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2,
|
6415
|
+
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6416
|
+
// Explicit instantiation of kernel_mul_mm exported to the host as "kernel_mul_mm_iq4_xs_f32".
// Both branches name block_iq4_xs, matching the parameter type of dequantize_iq4_xs and the
// sibling kernel_get_rows / kernel_mul_mm_id instantiations for iq4_xs. Only the second
// template argument differs: 2 when QK_K == 64 (the value the iq4_nl instantiation uses),
// QK_NL otherwise.
#if QK_K == 64
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, 2, dequantize_iq4_xs>;
#else
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
#endif
|
5571
6421
|
|
5572
6422
|
//
|
5573
6423
|
// indirect matrix-matrix multiplication
|
@@ -5619,8 +6469,15 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
5619
6469
|
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5620
6470
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5621
6471
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
6472
|
+
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
6473
|
+
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5622
6474
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5623
|
-
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2,
|
6475
|
+
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6476
|
+
// Explicit instantiation of kernel_mul_mm_id exported to the host as "kernel_mul_mm_id_iq4_xs_f32".
// Both branches use block_iq4_xs and dequantize_iq4_xs; only the second template argument
// differs: 2 when QK_K == 64 (the value the iq4_nl instantiation uses), QK_NL otherwise.
#if QK_K == 64
|
6477
|
+
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6478
|
+
#else
|
6479
|
+
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6480
|
+
#endif
|
5624
6481
|
|
5625
6482
|
//
|
5626
6483
|
// matrix-vector multiplication
|
@@ -6589,6 +7446,136 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6589
7446
|
sgitg);
|
6590
7447
|
}
|
6591
7448
|
|
7449
|
+
// Host-visible entry point "kernel_mul_mv_id_iq3_s_f32".
// Picks one of the eight src0x buffers using a 32-bit index read from ids, then delegates
// to kernel_mul_mv_iq3_s_f32_impl.
//   - bid is tgpig.z divided by ne12*ne13; tgpig.z is reduced modulo ne12*ne13 before it
//     is handed to the implementation.
//   - the index is element idx of the int32 array located at ids + bid*nbi1.
//   - src1 is advanced by bid*nb11 bytes and dst by bid*ne0 floats.
// nb00..nb02, ne11, nb10, nb12, nb1 and tiitg are accepted but not used here.
// NOTE(review): the index read from ids is not range-checked against the 8-entry table.
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
kernel void kernel_mul_mv_id_iq3_s_f32(
        device const char * ids,
        device const char * src1,
        device      float * dst,
        constant uint64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant uint64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * src0_table[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // split the z grid coordinate into (bid, remainder)
    const int64_t z_span = ne12*ne13;
    const int64_t bid    = tgpig.z/z_span;
    tgpig.z = tgpig.z%z_span;

    // look up which src0 buffer this invocation works on
    device int32_t * id_row = (device int32_t *) (ids + bid*nbi1);
    const int32_t sel = id_row[idx];

    device const float * src1_bid = (device const float *) (src1 + bid*nb11);
    device       float * dst_bid  = dst + bid*ne0;

    kernel_mul_mv_iq3_s_f32_impl(
        src0_table[sel], src1_bid, dst_bid,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
7513
|
+
|
7514
|
+
// Host-visible entry point "kernel_mul_mv_id_iq2_s_f32".
// Picks one of the eight src0x buffers using a 32-bit index read from ids, then delegates
// to kernel_mul_mv_iq2_s_f32_impl.
//   - bid is tgpig.z divided by ne12*ne13; tgpig.z is reduced modulo ne12*ne13 before it
//     is handed to the implementation.
//   - the index is element idx of the int32 array located at ids + bid*nbi1.
//   - src1 is advanced by bid*nb11 bytes and dst by bid*ne0 floats.
// nb00..nb02, ne11, nb10, nb12, nb1 and tiitg are accepted but not used here.
// NOTE(review): the index read from ids is not range-checked against the 8-entry table.
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
kernel void kernel_mul_mv_id_iq2_s_f32(
        device const char * ids,
        device const char * src1,
        device      float * dst,
        constant uint64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant uint64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * src0_table[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // split the z grid coordinate into (bid, remainder)
    const int64_t z_span = ne12*ne13;
    const int64_t bid    = tgpig.z/z_span;
    tgpig.z = tgpig.z%z_span;

    // look up which src0 buffer this invocation works on
    device int32_t * id_row = (device int32_t *) (ids + bid*nbi1);
    const int32_t sel = id_row[idx];

    device const float * src1_bid = (device const float *) (src1 + bid*nb11);
    device       float * dst_bid  = dst + bid*ne0;

    kernel_mul_mv_iq2_s_f32_impl(
        src0_table[sel], src1_bid, dst_bid,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
7578
|
+
|
6592
7579
|
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
6593
7580
|
kernel void kernel_mul_mv_id_iq1_s_f32(
|
6594
7581
|
device const char * ids,
|
@@ -6716,3 +7703,72 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
6716
7703
|
tiisg,
|
6717
7704
|
sgitg);
|
6718
7705
|
}
|
7706
|
+
|
7707
|
+
// Host-visible entry point "kernel_mul_mv_id_iq4_xs_f32".
// Picks one of the eight src0x buffers using a 32-bit index read from ids, then delegates
// to kernel_mul_mv_iq4_nl_f32_impl when QK_K == 64 and to kernel_mul_mv_iq4_xs_f32_impl
// otherwise (same argument list in both cases).
//   - bid is tgpig.z divided by ne12*ne13; tgpig.z is reduced modulo ne12*ne13 before it
//     is handed to the implementation.
//   - the index is element idx of the int32 array located at ids + bid*nbi1.
//   - src1 is advanced by bid*nb11 bytes and dst by bid*ne0 floats.
// nb00..nb02, ne11, nb10, nb12, nb1 and tiitg are accepted but not used here.
// NOTE(review): the index read from ids is not range-checked against the 8-entry table.
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
kernel void kernel_mul_mv_id_iq4_xs_f32(
        device const char * ids,
        device const char * src1,
        device      float * dst,
        constant uint64_t & nbi1,
        constant  int64_t & ne00,
        constant  int64_t & ne01,
        constant  int64_t & ne02,
        constant uint64_t & nb00,
        constant uint64_t & nb01,
        constant uint64_t & nb02,
        constant  int64_t & ne10,
        constant  int64_t & ne11,
        constant  int64_t & ne12,
        constant  int64_t & ne13,
        constant uint64_t & nb10,
        constant uint64_t & nb11,
        constant uint64_t & nb12,
        constant  int64_t & ne0,
        constant  int64_t & ne1,
        constant uint64_t & nb1,
        constant     uint & r2,
        constant     uint & r3,
        constant      int & idx,
        device const char * src00,
        device const char * src01,
        device const char * src02,
        device const char * src03,
        device const char * src04,
        device const char * src05,
        device const char * src06,
        device const char * src07,
        threadgroup float * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    device const char * src0_table[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    // split the z grid coordinate into (bid, remainder)
    const int64_t z_span = ne12*ne13;
    const int64_t bid    = tgpig.z/z_span;
    tgpig.z = tgpig.z%z_span;

    // look up which src0 buffer this invocation works on
    device int32_t * id_row = (device int32_t *) (ids + bid*nbi1);
    const int32_t sel = id_row[idx];

    device const float * src1_bid = (device const float *) (src1 + bid*nb11);
    device       float * dst_bid  = dst + bid*ne0;

    // Only the callee differs between the two configurations.
#if QK_K == 64
    kernel_mul_mv_iq4_nl_f32_impl(
#else
    kernel_mul_mv_iq4_xs_f32_impl(
#endif
        src0_table[sel], src1_bid, dst_bid,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|