llama_cpp 0.12.7 → 0.14.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/ext/llama_cpp/llama_cpp.cpp +131 -288
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +29 -29
- data/vendor/tmp/llama.cpp/Makefile +10 -6
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +6 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +32 -23
- data/vendor/tmp/llama.cpp/ggml-backend.h +17 -16
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +949 -168
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +9 -3
- data/vendor/tmp/llama.cpp/ggml-metal.m +159 -22
- data/vendor/tmp/llama.cpp/ggml-metal.metal +1195 -139
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +27 -27
- data/vendor/tmp/llama.cpp/ggml-quants.c +1971 -271
- data/vendor/tmp/llama.cpp/ggml-quants.h +52 -0
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +3586 -1201
- data/vendor/tmp/llama.cpp/ggml-sycl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +39336 -43461
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +1391 -825
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +1 -0
- data/vendor/tmp/llama.cpp/ggml.c +545 -210
- data/vendor/tmp/llama.cpp/ggml.h +65 -23
- data/vendor/tmp/llama.cpp/llama.cpp +1458 -763
- data/vendor/tmp/llama.cpp/llama.h +81 -75
- data/vendor/tmp/llama.cpp/unicode.h +310 -1
- metadata +2 -2
@@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
|
|
1959
1959
|
}
|
1960
1960
|
}
|
1961
1961
|
|
1962
|
+
// Writes the arithmetic sequence start + step*i into dst for i in [0, ne0).
//
//   dst   - output buffer; accessed as an array of float
//   ne0   - number of elements to write
//   start - value stored at index 0
//   step  - difference between consecutive elements
//
// Each thread begins at its index inside the threadgroup (tpitg.x) and
// advances by the threadgroup width (ntg.x). tgpig is not read, so the
// indices written do not depend on the threadgroup's position in the grid.
kernel void kernel_arange_f32(
        device        char * dst,
        constant   int64_t & ne0,
        constant     float & start,
        constant     float & step,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {

    device float * out = (device float *) dst;

    int idx = tpitg.x;
    while (idx < ne0) {
        out[idx] = start + step * idx;
        idx += ntg.x;
    }
}
|
1977
|
+
|
1978
|
+
// Computes a sinusoidal embedding for one timestep per threadgroup.
//
//   src0       - input buffer; read as an array of float, one value per
//                threadgroup (indexed by tgpig.x)
//   dst        - output buffer; row i starts at byte offset i*nb1
//   nb1        - byte stride between consecutive output rows
//   dim        - embedding dimension
//   max_period - base of the frequency schedule
//
// For j in [0, dim/2) the row receives
//   row[j]         = cos(t * exp(-log(max_period) * j / (dim/2)))
//   row[j + dim/2] = sin(t * exp(-log(max_period) * j / (dim/2)))
// with the j values distributed over the threads of the threadgroup
// (start at tpitg.x, advance by ntg.x).
kernel void kernel_timestep_embedding_f32(
        device  const char * src0,
        device        char * dst,
        constant  uint64_t & nb1,
        constant       int & dim,
        constant       int & max_period,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {

    int i = tgpig.x;
    device float * embed_data = (device float *)(dst + i*nb1);

    int half_ = dim / 2;

    // Both values are identical for every j handled by this thread, so they
    // are computed once instead of on every loop iteration.
    // assumes src0 holds at least one float per threadgroup — TODO confirm
    // against the host-side dispatch
    const float timestep       = ((device const float *) src0)[i];
    const float neg_log_period = -log((float) max_period);

    for (int j = tpitg.x; j < half_; j += ntg.x) {
        float freq = (float)exp(neg_log_period * j / half_);
        float arg  = timestep * freq;
        embed_data[j        ] = cos(arg);
        embed_data[j + half_] = sin(arg);
    }

    // Odd dim: a single thread zeroes element `dim` of the row.
    // NOTE(review): the loop above writes indices [0, 2*half_) = [0, dim-1),
    // so for odd dim index dim-1 is not written by this kernel, and the row
    // must hold at least dim+1 floats — confirm against the host-side tensor
    // allocation before changing this index.
    if (dim % 2 != 0 && tpitg.x == 0) {
        embed_data[dim] = 0.f;
    }
}
|
2004
|
+
|
1962
2005
|
// bitonic sort implementation following the CUDA kernels as reference
|
1963
2006
|
typedef void (argsort_t)(
|
1964
2007
|
device const float * x,
|
@@ -2519,12 +2562,34 @@ typedef struct {
|
|
2519
2562
|
} block_iq2_xs;
|
2520
2563
|
// 74 bytes / block for QK_K = 256, so 2.3125 bpw
|
2521
2564
|
|
2565
|
+
// 2.5625 bpw quants
//
// Storage layout of one block_iq2_s. Array lengths scale with QK_K.
typedef struct {
    half    d;
    uint8_t qs[QK_K/4];
    uint8_t qh[QK_K/32];
    uint8_t scales[QK_K/32];
} block_iq2_s;
// 82 bytes / block for QK_K = 256, so 2.5625 bpw
|
2572
|
+
|
2522
2573
|
// Storage layout of one block_iq3_xxs: a half-precision value `d` followed by
// 3*QK_K/8 packed bytes.
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
typedef struct {
    half    d;
    uint8_t qs[3*QK_K/8];
} block_iq3_xxs;
|
2527
2578
|
|
2579
|
+
// 3.4375 bpw
//
// IQ3S_N_SCALE is the length of block_iq3_s::scales. The expression form is
// parenthesized so the macro expands as a single value inside any
// surrounding expression (e.g. `x % IQ3S_N_SCALE`).
#if QK_K == 64
#define IQ3S_N_SCALE 2
#else
#define IQ3S_N_SCALE (QK_K/64)
#endif
// Storage layout of one block_iq3_s. Array lengths scale with QK_K.
typedef struct {
    half    d;
    uint8_t qs[QK_K/4];
    uint8_t qh[QK_K/32];
    uint8_t signs[QK_K/8];
    uint8_t scales[IQ3S_N_SCALE];
} block_iq3_s;
// 110 bytes / block for QK_K = 256, so 3.4375 bpw
|
2592
|
+
|
2528
2593
|
typedef struct {
|
2529
2594
|
half d;
|
2530
2595
|
uint8_t qs[QK_K/8];
|
@@ -2538,6 +2603,17 @@ typedef struct {
|
|
2538
2603
|
uint8_t qs[QK4_NL/2];
|
2539
2604
|
} block_iq4_nl;
|
2540
2605
|
|
2606
|
+
// block_iq4_xs: a dedicated struct unless QK_K == 64, in which case the name
// is an alias for block_iq4_nl.
#if QK_K != 64
typedef struct {
    half     d;
    uint16_t scales_h;
    uint8_t  scales_l[QK_K/64];
    uint8_t  qs[QK_K/2];
} block_iq4_xs;
// 136 bytes / block for QK_K = 256, so 4.25 bpw
#else
#define block_iq4_xs block_iq4_nl
#endif
|
2616
|
+
|
2541
2617
|
//====================================== dot products =========================
|
2542
2618
|
|
2543
2619
|
void kernel_mul_mv_q2_K_f32_impl(
|
@@ -3760,6 +3836,265 @@ constexpr constant static uint64_t iq2xs_grid[512] = {
|
|
3760
3836
|
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
|
3761
3837
|
};
|
3762
3838
|
|
3839
|
+
constexpr constant static uint64_t iq2s_grid[1024] = {
|
3840
|
+
0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
|
3841
|
+
0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
|
3842
|
+
0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
|
3843
|
+
0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
|
3844
|
+
0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
|
3845
|
+
0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x08080808192b192b,
|
3846
|
+
0x08080808192b2b19, 0x080808082b080808, 0x080808082b08082b, 0x080808082b081919,
|
3847
|
+
0x080808082b082b08, 0x080808082b190819, 0x080808082b191908, 0x080808082b2b0808,
|
3848
|
+
0x080808082b2b1919, 0x080808082b2b2b2b, 0x0808081908080819, 0x0808081908081908,
|
3849
|
+
0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, 0x080808190819082b,
|
3850
|
+
0x0808081908191919, 0x0808081908192b08, 0x08080819082b0819, 0x08080819082b1908,
|
3851
|
+
0x0808081919080808, 0x080808191908082b, 0x0808081919081919, 0x0808081919082b08,
|
3852
|
+
0x0808081919190819, 0x0808081919191908, 0x080808191919192b, 0x0808081919192b19,
|
3853
|
+
0x08080819192b0808, 0x08080819192b1919, 0x08080819192b2b08, 0x080808192b080819,
|
3854
|
+
0x080808192b081908, 0x080808192b190808, 0x080808192b19082b, 0x080808192b191919,
|
3855
|
+
0x080808192b2b0819, 0x080808192b2b1908, 0x0808082b08080808, 0x0808082b0808082b,
|
3856
|
+
0x0808082b08081919, 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908,
|
3857
|
+
0x0808082b082b0808, 0x0808082b082b2b2b, 0x0808082b19080819, 0x0808082b19081908,
|
3858
|
+
0x0808082b1908192b, 0x0808082b19082b19, 0x0808082b19190808, 0x0808082b19191919,
|
3859
|
+
0x0808082b2b080808, 0x0808082b2b081919, 0x0808082b2b082b2b, 0x0808082b2b191908,
|
3860
|
+
0x0808082b2b2b082b, 0x0808190808080819, 0x0808190808081908, 0x080819080808192b,
|
3861
|
+
0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, 0x0808190808191919,
|
3862
|
+
0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908, 0x08081908082b192b,
|
3863
|
+
0x08081908082b2b19, 0x0808190819080808, 0x080819081908082b, 0x0808190819081919,
|
3864
|
+
0x0808190819082b08, 0x0808190819082b2b, 0x0808190819190819, 0x0808190819191908,
|
3865
|
+
0x080819081919192b, 0x0808190819192b19, 0x08081908192b0808, 0x08081908192b082b,
|
3866
|
+
0x08081908192b1919, 0x080819082b080819, 0x080819082b081908, 0x080819082b08192b,
|
3867
|
+
0x080819082b082b19, 0x080819082b190808, 0x080819082b191919, 0x080819082b192b08,
|
3868
|
+
0x080819082b2b0819, 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b,
|
3869
|
+
0x0808191908081919, 0x0808191908082b08, 0x0808191908082b2b, 0x0808191908190819,
|
3870
|
+
0x0808191908191908, 0x080819190819192b, 0x0808191908192b19, 0x08081919082b0808,
|
3871
|
+
0x08081919082b1919, 0x08081919082b2b08, 0x0808191919080819, 0x0808191919081908,
|
3872
|
+
0x080819191908192b, 0x0808191919082b19, 0x0808191919190808, 0x080819191919082b,
|
3873
|
+
0x0808191919191919, 0x0808191919192b08, 0x08081919192b0819, 0x08081919192b1908,
|
3874
|
+
0x080819192b080808, 0x080819192b08082b, 0x080819192b081919, 0x080819192b082b08,
|
3875
|
+
0x080819192b190819, 0x080819192b191908, 0x080819192b2b0808, 0x0808192b08080819,
|
3876
|
+
0x0808192b08081908, 0x0808192b0808192b, 0x0808192b08082b19, 0x0808192b08190808,
|
3877
|
+
0x0808192b08191919, 0x0808192b19080808, 0x0808192b19081919, 0x0808192b19082b08,
|
3878
|
+
0x0808192b19190819, 0x0808192b19191908, 0x0808192b192b0808, 0x0808192b2b080819,
|
3879
|
+
0x0808192b2b081908, 0x0808192b2b190808, 0x08082b0808080808, 0x08082b080808082b,
|
3880
|
+
0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808190819, 0x08082b0808191908,
|
3881
|
+
0x08082b080819192b, 0x08082b0808192b19, 0x08082b08082b0808, 0x08082b08082b1919,
|
3882
|
+
0x08082b08082b2b2b, 0x08082b0819080819, 0x08082b0819081908, 0x08082b081908192b,
|
3883
|
+
0x08082b0819082b19, 0x08082b0819190808, 0x08082b081919082b, 0x08082b0819191919,
|
3884
|
+
0x08082b0819192b08, 0x08082b08192b0819, 0x08082b08192b1908, 0x08082b082b080808,
|
3885
|
+
0x08082b082b081919, 0x08082b082b191908, 0x08082b082b2b2b2b, 0x08082b1908080819,
|
3886
|
+
0x08082b1908081908, 0x08082b1908190808, 0x08082b190819082b, 0x08082b1908191919,
|
3887
|
+
0x08082b1908192b08, 0x08082b19082b0819, 0x08082b1919080808, 0x08082b1919081919,
|
3888
|
+
0x08082b1919082b08, 0x08082b1919190819, 0x08082b1919191908, 0x08082b19192b0808,
|
3889
|
+
0x08082b192b080819, 0x08082b192b190808, 0x08082b2b08080808, 0x08082b2b08190819,
|
3890
|
+
0x08082b2b08191908, 0x08082b2b082b082b, 0x08082b2b082b2b08, 0x08082b2b082b2b2b,
|
3891
|
+
0x08082b2b19190808, 0x08082b2b2b192b19, 0x0819080808080819, 0x0819080808081908,
|
3892
|
+
0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b,
|
3893
|
+
0x0819080808191919, 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908,
|
3894
|
+
0x08190808082b192b, 0x0819080819080808, 0x081908081908082b, 0x0819080819081919,
|
3895
|
+
0x0819080819082b08, 0x0819080819190819, 0x0819080819191908, 0x081908081919192b,
|
3896
|
+
0x0819080819192b19, 0x08190808192b0808, 0x08190808192b082b, 0x08190808192b1919,
|
3897
|
+
0x08190808192b2b08, 0x081908082b080819, 0x081908082b081908, 0x081908082b08192b,
|
3898
|
+
0x081908082b190808, 0x081908082b191919, 0x081908082b192b08, 0x081908082b2b0819,
|
3899
|
+
0x081908082b2b1908, 0x0819081908080808, 0x081908190808082b, 0x0819081908081919,
|
3900
|
+
0x0819081908082b08, 0x0819081908082b2b, 0x0819081908190819, 0x0819081908191908,
|
3901
|
+
0x081908190819192b, 0x0819081908192b19, 0x08190819082b0808, 0x08190819082b082b,
|
3902
|
+
0x08190819082b1919, 0x08190819082b2b08, 0x0819081919080819, 0x0819081919081908,
|
3903
|
+
0x081908191908192b, 0x0819081919082b19, 0x0819081919190808, 0x081908191919082b,
|
3904
|
+
0x0819081919191919, 0x0819081919192b08, 0x08190819192b0819, 0x08190819192b1908,
|
3905
|
+
0x081908192b080808, 0x081908192b08082b, 0x081908192b081919, 0x081908192b082b08,
|
3906
|
+
0x081908192b190819, 0x081908192b191908, 0x0819082b08080819, 0x0819082b08081908,
|
3907
|
+
0x0819082b08082b19, 0x0819082b08190808, 0x0819082b08191919, 0x0819082b082b0819,
|
3908
|
+
0x0819082b082b1908, 0x0819082b19080808, 0x0819082b19081919, 0x0819082b19190819,
|
3909
|
+
0x0819082b19191908, 0x0819082b2b080819, 0x0819082b2b081908, 0x0819082b2b190808,
|
3910
|
+
0x0819190808080808, 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08,
|
3911
|
+
0x0819190808190819, 0x0819190808191908, 0x081919080819192b, 0x0819190808192b19,
|
3912
|
+
0x08191908082b0808, 0x08191908082b1919, 0x08191908082b2b08, 0x0819190819080819,
|
3913
|
+
0x0819190819081908, 0x081919081908192b, 0x0819190819082b19, 0x0819190819190808,
|
3914
|
+
0x081919081919082b, 0x0819190819191919, 0x0819190819192b08, 0x08191908192b0819,
|
3915
|
+
0x08191908192b1908, 0x081919082b080808, 0x081919082b08082b, 0x081919082b081919,
|
3916
|
+
0x081919082b082b08, 0x081919082b190819, 0x081919082b191908, 0x081919082b2b0808,
|
3917
|
+
0x0819191908080819, 0x0819191908081908, 0x081919190808192b, 0x0819191908082b19,
|
3918
|
+
0x0819191908190808, 0x081919190819082b, 0x0819191908191919, 0x0819191908192b08,
|
3919
|
+
0x08191919082b0819, 0x08191919082b1908, 0x0819191919080808, 0x081919191908082b,
|
3920
|
+
0x0819191919081919, 0x0819191919082b08, 0x0819191919190819, 0x0819191919191908,
|
3921
|
+
0x08191919192b0808, 0x081919192b080819, 0x081919192b081908, 0x081919192b190808,
|
3922
|
+
0x0819192b08080808, 0x0819192b08081919, 0x0819192b08082b08, 0x0819192b08190819,
|
3923
|
+
0x0819192b08191908, 0x0819192b082b0808, 0x0819192b19080819, 0x0819192b19081908,
|
3924
|
+
0x0819192b19190808, 0x0819192b2b080808, 0x0819192b2b2b2b2b, 0x08192b0808080819,
|
3925
|
+
0x08192b0808081908, 0x08192b080808192b, 0x08192b0808082b19, 0x08192b0808190808,
|
3926
|
+
0x08192b0808191919, 0x08192b0808192b08, 0x08192b08082b0819, 0x08192b0819080808,
|
3927
|
+
0x08192b081908082b, 0x08192b0819081919, 0x08192b0819082b08, 0x08192b0819190819,
|
3928
|
+
0x08192b0819191908, 0x08192b08192b0808, 0x08192b082b080819, 0x08192b082b081908,
|
3929
|
+
0x08192b1908080808, 0x08192b190808082b, 0x08192b1908081919, 0x08192b1908082b08,
|
3930
|
+
0x08192b1908190819, 0x08192b1908191908, 0x08192b19082b0808, 0x08192b1919080819,
|
3931
|
+
0x08192b1919081908, 0x08192b1919190808, 0x08192b19192b2b19, 0x08192b192b2b082b,
|
3932
|
+
0x08192b2b08081908, 0x08192b2b08190808, 0x08192b2b19080808, 0x08192b2b1919192b,
|
3933
|
+
0x082b080808080808, 0x082b08080808082b, 0x082b080808081919, 0x082b080808082b08,
|
3934
|
+
0x082b080808190819, 0x082b080808191908, 0x082b08080819192b, 0x082b080808192b19,
|
3935
|
+
0x082b0808082b0808, 0x082b0808082b1919, 0x082b0808082b2b2b, 0x082b080819080819,
|
3936
|
+
0x082b080819081908, 0x082b080819190808, 0x082b08081919082b, 0x082b080819191919,
|
3937
|
+
0x082b0808192b1908, 0x082b08082b080808, 0x082b08082b082b2b, 0x082b08082b191908,
|
3938
|
+
0x082b08082b2b2b2b, 0x082b081908080819, 0x082b081908081908, 0x082b081908190808,
|
3939
|
+
0x082b08190819082b, 0x082b081908191919, 0x082b0819082b0819, 0x082b081919080808,
|
3940
|
+
0x082b08191908082b, 0x082b081919081919, 0x082b081919190819, 0x082b081919191908,
|
3941
|
+
0x082b0819192b0808, 0x082b08192b080819, 0x082b08192b081908, 0x082b08192b190808,
|
3942
|
+
0x082b082b08080808, 0x082b082b08082b2b, 0x082b082b082b082b, 0x082b082b082b2b08,
|
3943
|
+
0x082b082b082b2b2b, 0x082b082b19081908, 0x082b082b19190808, 0x082b082b2b082b08,
|
3944
|
+
0x082b082b2b082b2b, 0x082b082b2b2b2b08, 0x082b190808080819, 0x082b190808081908,
|
3945
|
+
0x082b19080808192b, 0x082b190808082b19, 0x082b190808190808, 0x082b190808191919,
|
3946
|
+
0x082b190808192b08, 0x082b1908082b0819, 0x082b1908082b1908, 0x082b190819080808,
|
3947
|
+
0x082b19081908082b, 0x082b190819081919, 0x082b190819082b08, 0x082b190819190819,
|
3948
|
+
0x082b190819191908, 0x082b1908192b0808, 0x082b19082b080819, 0x082b19082b081908,
|
3949
|
+
0x082b19082b190808, 0x082b191908080808, 0x082b191908081919, 0x082b191908082b08,
|
3950
|
+
0x082b191908190819, 0x082b191908191908, 0x082b1919082b0808, 0x082b191919080819,
|
3951
|
+
0x082b191919081908, 0x082b191919190808, 0x082b1919192b192b, 0x082b19192b080808,
|
3952
|
+
0x082b192b08080819, 0x082b192b08081908, 0x082b192b08190808, 0x082b192b19080808,
|
3953
|
+
0x082b192b19192b19, 0x082b2b0808080808, 0x082b2b0808081919, 0x082b2b0808190819,
|
3954
|
+
0x082b2b0808191908, 0x082b2b0819080819, 0x082b2b0819081908, 0x082b2b0819190808,
|
3955
|
+
0x082b2b082b082b2b, 0x082b2b082b2b2b2b, 0x082b2b1908080819, 0x082b2b1908081908,
|
3956
|
+
0x082b2b1908190808, 0x082b2b192b191919, 0x082b2b2b08082b2b, 0x082b2b2b082b082b,
|
3957
|
+
0x082b2b2b192b1908, 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819,
|
3958
|
+
0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808,
|
3959
|
+
0x190808080819082b, 0x1908080808191919, 0x1908080808192b08, 0x1908080808192b2b,
|
3960
|
+
0x19080808082b0819, 0x19080808082b1908, 0x19080808082b192b, 0x1908080819080808,
|
3961
|
+
0x190808081908082b, 0x1908080819081919, 0x1908080819082b08, 0x1908080819082b2b,
|
3962
|
+
0x1908080819190819, 0x1908080819191908, 0x190808081919192b, 0x1908080819192b19,
|
3963
|
+
0x19080808192b0808, 0x19080808192b082b, 0x19080808192b1919, 0x190808082b080819,
|
3964
|
+
0x190808082b081908, 0x190808082b190808, 0x190808082b191919, 0x190808082b192b08,
|
3965
|
+
0x190808082b2b0819, 0x190808082b2b1908, 0x1908081908080808, 0x190808190808082b,
|
3966
|
+
0x1908081908081919, 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908,
|
3967
|
+
0x190808190819192b, 0x1908081908192b19, 0x19080819082b0808, 0x19080819082b082b,
|
3968
|
+
0x19080819082b1919, 0x1908081919080819, 0x1908081919081908, 0x190808191908192b,
|
3969
|
+
0x1908081919082b19, 0x1908081919190808, 0x190808191919082b, 0x1908081919191919,
|
3970
|
+
0x1908081919192b08, 0x19080819192b0819, 0x19080819192b1908, 0x190808192b080808,
|
3971
|
+
0x190808192b08082b, 0x190808192b081919, 0x190808192b082b08, 0x190808192b190819,
|
3972
|
+
0x190808192b191908, 0x190808192b2b0808, 0x1908082b08080819, 0x1908082b08081908,
|
3973
|
+
0x1908082b08190808, 0x1908082b0819082b, 0x1908082b08191919, 0x1908082b08192b08,
|
3974
|
+
0x1908082b082b1908, 0x1908082b19080808, 0x1908082b19081919, 0x1908082b19082b08,
|
3975
|
+
0x1908082b19190819, 0x1908082b19191908, 0x1908082b192b0808, 0x1908082b2b080819,
|
3976
|
+
0x1908082b2b081908, 0x1908190808080808, 0x190819080808082b, 0x1908190808081919,
|
3977
|
+
0x1908190808082b08, 0x1908190808082b2b, 0x1908190808190819, 0x1908190808191908,
|
3978
|
+
0x190819080819192b, 0x1908190808192b19, 0x19081908082b0808, 0x19081908082b082b,
|
3979
|
+
0x19081908082b1919, 0x19081908082b2b08, 0x1908190819080819, 0x1908190819081908,
|
3980
|
+
0x190819081908192b, 0x1908190819082b19, 0x1908190819190808, 0x190819081919082b,
|
3981
|
+
0x1908190819191919, 0x1908190819192b08, 0x19081908192b0819, 0x19081908192b1908,
|
3982
|
+
0x190819082b080808, 0x190819082b08082b, 0x190819082b081919, 0x190819082b082b08,
|
3983
|
+
0x190819082b190819, 0x190819082b191908, 0x190819082b2b0808, 0x1908191908080819,
|
3984
|
+
0x1908191908081908, 0x190819190808192b, 0x1908191908082b19, 0x1908191908190808,
|
3985
|
+
0x190819190819082b, 0x1908191908191919, 0x1908191908192b08, 0x19081919082b0819,
|
3986
|
+
0x19081919082b1908, 0x1908191919080808, 0x190819191908082b, 0x1908191919081919,
|
3987
|
+
0x1908191919082b08, 0x1908191919190819, 0x1908191919191908, 0x19081919192b0808,
|
3988
|
+
0x19081919192b2b2b, 0x190819192b080819, 0x190819192b081908, 0x190819192b190808,
|
3989
|
+
0x1908192b08080808, 0x1908192b0808082b, 0x1908192b08081919, 0x1908192b08082b08,
|
3990
|
+
0x1908192b08190819, 0x1908192b08191908, 0x1908192b082b0808, 0x1908192b19080819,
|
3991
|
+
0x1908192b19081908, 0x1908192b19190808, 0x1908192b2b080808, 0x1908192b2b2b1919,
|
3992
|
+
0x19082b0808080819, 0x19082b0808081908, 0x19082b0808082b19, 0x19082b0808190808,
|
3993
|
+
0x19082b080819082b, 0x19082b0808191919, 0x19082b0808192b08, 0x19082b08082b0819,
|
3994
|
+
0x19082b08082b1908, 0x19082b0819080808, 0x19082b081908082b, 0x19082b0819081919,
|
3995
|
+
0x19082b0819082b08, 0x19082b0819190819, 0x19082b0819191908, 0x19082b08192b0808,
|
3996
|
+
0x19082b082b081908, 0x19082b082b190808, 0x19082b1908080808, 0x19082b190808082b,
|
3997
|
+
0x19082b1908081919, 0x19082b1908082b08, 0x19082b1908190819, 0x19082b1908191908,
|
3998
|
+
0x19082b19082b0808, 0x19082b1919080819, 0x19082b1919081908, 0x19082b1919190808,
|
3999
|
+
0x19082b192b080808, 0x19082b192b19192b, 0x19082b2b08080819, 0x19082b2b08081908,
|
4000
|
+
0x19082b2b08190808, 0x19082b2b19080808, 0x1919080808080808, 0x191908080808082b,
|
4001
|
+
0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, 0x1919080808191908,
|
4002
|
+
0x191908080819192b, 0x1919080808192b19, 0x19190808082b0808, 0x19190808082b082b,
|
4003
|
+
0x19190808082b1919, 0x19190808082b2b08, 0x1919080819080819, 0x1919080819081908,
|
4004
|
+
0x191908081908192b, 0x1919080819082b19, 0x1919080819190808, 0x191908081919082b,
|
4005
|
+
0x1919080819191919, 0x1919080819192b08, 0x19190808192b0819, 0x19190808192b1908,
|
4006
|
+
0x191908082b080808, 0x191908082b08082b, 0x191908082b081919, 0x191908082b082b08,
|
4007
|
+
0x191908082b190819, 0x191908082b191908, 0x1919081908080819, 0x1919081908081908,
|
4008
|
+
0x191908190808192b, 0x1919081908082b19, 0x1919081908190808, 0x191908190819082b,
|
4009
|
+
0x1919081908191919, 0x1919081908192b08, 0x19190819082b0819, 0x19190819082b1908,
|
4010
|
+
0x1919081919080808, 0x191908191908082b, 0x1919081919081919, 0x1919081919082b08,
|
4011
|
+
0x1919081919190819, 0x1919081919191908, 0x19190819192b0808, 0x191908192b080819,
|
4012
|
+
0x191908192b081908, 0x191908192b190808, 0x1919082b08080808, 0x1919082b08081919,
|
4013
|
+
0x1919082b08082b08, 0x1919082b08190819, 0x1919082b08191908, 0x1919082b082b0808,
|
4014
|
+
0x1919082b19080819, 0x1919082b19081908, 0x1919082b19190808, 0x1919082b192b2b19,
|
4015
|
+
0x1919082b2b080808, 0x1919190808080819, 0x1919190808081908, 0x191919080808192b,
|
4016
|
+
0x1919190808082b19, 0x1919190808190808, 0x191919080819082b, 0x1919190808191919,
|
4017
|
+
0x1919190808192b08, 0x19191908082b0819, 0x19191908082b1908, 0x1919190819080808,
|
4018
|
+
0x191919081908082b, 0x1919190819081919, 0x1919190819082b08, 0x1919190819190819,
|
4019
|
+
0x1919190819191908, 0x19191908192b0808, 0x191919082b080819, 0x191919082b081908,
|
4020
|
+
0x191919082b190808, 0x1919191908080808, 0x191919190808082b, 0x1919191908081919,
|
4021
|
+
0x1919191908082b08, 0x1919191908190819, 0x1919191908191908, 0x19191919082b0808,
|
4022
|
+
0x1919191919080819, 0x1919191919081908, 0x1919191919190808, 0x191919192b080808,
|
4023
|
+
0x1919192b08080819, 0x1919192b08081908, 0x1919192b08190808, 0x1919192b082b192b,
|
4024
|
+
0x1919192b19080808, 0x19192b0808080808, 0x19192b080808082b, 0x19192b0808081919,
|
4025
|
+
0x19192b0808082b08, 0x19192b0808190819, 0x19192b0808191908, 0x19192b08082b0808,
|
4026
|
+
0x19192b0819080819, 0x19192b0819081908, 0x19192b0819190808, 0x19192b0819192b2b,
|
4027
|
+
0x19192b082b080808, 0x19192b1908080819, 0x19192b1908081908, 0x19192b1908190808,
|
4028
|
+
0x19192b1919080808, 0x19192b2b08080808, 0x19192b2b08192b19, 0x19192b2b2b081919,
|
4029
|
+
0x19192b2b2b2b2b08, 0x192b080808080819, 0x192b080808081908, 0x192b08080808192b,
|
4030
|
+
0x192b080808190808, 0x192b08080819082b, 0x192b080808191919, 0x192b080808192b08,
|
4031
|
+
0x192b0808082b0819, 0x192b0808082b1908, 0x192b080819080808, 0x192b080819081919,
|
4032
|
+
0x192b080819082b08, 0x192b080819190819, 0x192b080819191908, 0x192b0808192b0808,
|
4033
|
+
0x192b08082b081908, 0x192b08082b190808, 0x192b081908080808, 0x192b08190808082b,
|
4034
|
+
0x192b081908081919, 0x192b081908082b08, 0x192b081908190819, 0x192b081908191908,
|
4035
|
+
0x192b0819082b0808, 0x192b081919080819, 0x192b081919081908, 0x192b081919190808,
|
4036
|
+
0x192b08192b080808, 0x192b08192b192b19, 0x192b082b08081908, 0x192b082b08190808,
|
4037
|
+
0x192b082b19080808, 0x192b082b1919192b, 0x192b082b2b2b0819, 0x192b190808080808,
|
4038
|
+
0x192b190808081919, 0x192b190808082b08, 0x192b190808190819, 0x192b190808191908,
|
4039
|
+
0x192b1908082b0808, 0x192b190819080819, 0x192b190819081908, 0x192b190819190808,
|
4040
|
+
0x192b19082b080808, 0x192b191908080819, 0x192b191908081908, 0x192b191908190808,
|
4041
|
+
0x192b191919080808, 0x192b191919082b2b, 0x192b1919192b2b08, 0x192b19192b19082b,
|
4042
|
+
0x192b192b08080808, 0x192b192b2b191908, 0x192b2b0808080819, 0x192b2b0808081908,
|
4043
|
+
0x192b2b0808190808, 0x192b2b08192b1919, 0x192b2b082b192b08, 0x192b2b1908080808,
|
4044
|
+
0x192b2b19082b2b2b, 0x192b2b2b1908082b, 0x192b2b2b2b2b0819, 0x2b08080808080808,
|
4045
|
+
0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819,
|
4046
|
+
0x2b08080808191908, 0x2b08080808192b19, 0x2b080808082b0808, 0x2b080808082b1919,
|
4047
|
+
0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808081919082b,
|
4048
|
+
0x2b08080819191919, 0x2b08080819192b08, 0x2b080808192b0819, 0x2b0808082b080808,
|
4049
|
+
0x2b0808082b081919, 0x2b0808082b190819, 0x2b0808082b191908, 0x2b08081908080819,
|
4050
|
+
0x2b08081908081908, 0x2b08081908082b19, 0x2b08081908190808, 0x2b0808190819082b,
|
4051
|
+
0x2b08081908191919, 0x2b08081908192b08, 0x2b080819082b0819, 0x2b080819082b1908,
|
4052
|
+
0x2b08081919080808, 0x2b0808191908082b, 0x2b08081919081919, 0x2b08081919082b08,
|
4053
|
+
0x2b08081919190819, 0x2b08081919191908, 0x2b0808192b080819, 0x2b0808192b081908,
|
4054
|
+
0x2b0808192b190808, 0x2b0808192b2b2b19, 0x2b08082b08080808, 0x2b08082b08081919,
|
4055
|
+
0x2b08082b08082b2b, 0x2b08082b08190819, 0x2b08082b08191908, 0x2b08082b19080819,
|
4056
|
+
0x2b08082b19081908, 0x2b08082b19190808, 0x2b08190808080819, 0x2b08190808081908,
|
4057
|
+
0x2b0819080808192b, 0x2b08190808082b19, 0x2b08190808190808, 0x2b0819080819082b,
|
4058
|
+
0x2b08190808191919, 0x2b08190808192b08, 0x2b081908082b0819, 0x2b08190819080808,
|
4059
|
+
0x2b0819081908082b, 0x2b08190819081919, 0x2b08190819082b08, 0x2b08190819190819,
|
4060
|
+
0x2b08190819191908, 0x2b081908192b0808, 0x2b0819082b080819, 0x2b0819082b081908,
|
4061
|
+
0x2b0819082b190808, 0x2b08191908080808, 0x2b0819190808082b, 0x2b08191908081919,
|
4062
|
+
0x2b08191908082b08, 0x2b08191908190819, 0x2b08191908191908, 0x2b081919082b0808,
|
4063
|
+
0x2b08191919080819, 0x2b08191919081908, 0x2b08191919190808, 0x2b0819192b080808,
|
4064
|
+
0x2b0819192b082b2b, 0x2b08192b08080819, 0x2b08192b08081908, 0x2b08192b08190808,
|
4065
|
+
0x2b08192b082b2b19, 0x2b08192b19080808, 0x2b082b0808080808, 0x2b082b0808081919,
|
4066
|
+
0x2b082b0808190819, 0x2b082b0808191908, 0x2b082b0819080819, 0x2b082b0819081908,
|
4067
|
+
0x2b082b0819190808, 0x2b082b082b2b082b, 0x2b082b1908080819, 0x2b082b1908081908,
|
4068
|
+
0x2b082b1919080808, 0x2b082b19192b1919, 0x2b082b2b082b082b, 0x2b082b2b19192b08,
|
4069
|
+
0x2b082b2b19192b2b, 0x2b082b2b2b08082b, 0x2b082b2b2b2b082b, 0x2b19080808080819,
|
4070
|
+
0x2b19080808081908, 0x2b19080808082b19, 0x2b19080808190808, 0x2b1908080819082b,
|
4071
|
+
0x2b19080808191919, 0x2b19080808192b08, 0x2b190808082b1908, 0x2b19080819080808,
|
4072
|
+
0x2b1908081908082b, 0x2b19080819081919, 0x2b19080819082b08, 0x2b19080819190819,
|
4073
|
+
0x2b19080819191908, 0x2b190808192b0808, 0x2b1908082b080819, 0x2b1908082b081908,
|
4074
|
+
0x2b1908082b190808, 0x2b19081908080808, 0x2b19081908081919, 0x2b19081908190819,
|
4075
|
+
0x2b19081908191908, 0x2b19081919080819, 0x2b19081919081908, 0x2b19081919190808,
|
4076
|
+
0x2b19081919192b2b, 0x2b19082b08080819, 0x2b19082b08081908, 0x2b19082b08190808,
|
4077
|
+
0x2b19082b19080808, 0x2b19082b2b2b192b, 0x2b19190808080808, 0x2b1919080808082b,
|
4078
|
+
0x2b19190808081919, 0x2b19190808082b08, 0x2b19190808190819, 0x2b19190808191908,
|
4079
|
+
0x2b191908082b0808, 0x2b19190819080819, 0x2b19190819081908, 0x2b19190819190808,
|
4080
|
+
0x2b1919082b080808, 0x2b1919082b19192b, 0x2b19191908080819, 0x2b19191908081908,
|
4081
|
+
0x2b19191908190808, 0x2b19191919080808, 0x2b1919192b192b08, 0x2b1919192b2b0819,
|
4082
|
+
0x2b19192b08080808, 0x2b19192b1908192b, 0x2b19192b192b1908, 0x2b192b0808080819,
|
4083
|
+
0x2b192b0808081908, 0x2b192b0808190808, 0x2b192b08082b192b, 0x2b192b0819080808,
|
4084
|
+
0x2b192b082b2b2b19, 0x2b192b1908080808, 0x2b192b1919082b19, 0x2b192b191919082b,
|
4085
|
+
0x2b192b2b2b190808, 0x2b2b080808080808, 0x2b2b080808081919, 0x2b2b080808082b2b,
|
4086
|
+
0x2b2b080808191908, 0x2b2b0808082b082b, 0x2b2b0808082b2b2b, 0x2b2b080819080819,
|
4087
|
+
0x2b2b080819081908, 0x2b2b080819190808, 0x2b2b08082b2b082b, 0x2b2b08082b2b2b2b,
|
4088
|
+
0x2b2b081919080808, 0x2b2b0819192b1919, 0x2b2b082b0808082b, 0x2b2b082b08082b2b,
|
4089
|
+
0x2b2b082b082b082b, 0x2b2b082b082b2b08, 0x2b2b082b082b2b2b, 0x2b2b082b2b08082b,
|
4090
|
+
0x2b2b082b2b082b08, 0x2b2b082b2b082b2b, 0x2b2b082b2b2b2b08, 0x2b2b190808080819,
|
4091
|
+
0x2b2b190808081908, 0x2b2b190808190808, 0x2b2b190819080808, 0x2b2b19082b082b19,
|
4092
|
+
0x2b2b19082b2b1908, 0x2b2b191908080808, 0x2b2b191908192b19, 0x2b2b192b19190819,
|
4093
|
+
0x2b2b2b0808082b2b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b082b, 0x2b2b2b1919191908,
|
4094
|
+
0x2b2b2b192b08192b, 0x2b2b2b2b08082b08, 0x2b2b2b2b08082b2b, 0x2b2b2b2b082b0808,
|
4095
|
+
0x2b2b2b2b082b082b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b2b082b08, 0x2b2b2b2b2b2b2b2b,
|
4096
|
+
};
|
4097
|
+
|
3763
4098
|
constexpr constant static uint32_t iq3xxs_grid[256] = {
|
3764
4099
|
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
3765
4100
|
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
@@ -3795,6 +4130,73 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
3795
4130
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
3796
4131
|
};
|
3797
4132
|
|
4133
|
+
constexpr constant static uint32_t iq3s_grid[512] = {
|
4134
|
+
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
4135
|
+
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
4136
|
+
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
4137
|
+
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
4138
|
+
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
4139
|
+
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
4140
|
+
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
4141
|
+
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
4142
|
+
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
4143
|
+
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
4144
|
+
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
4145
|
+
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
4146
|
+
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
4147
|
+
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
4148
|
+
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
4149
|
+
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
4150
|
+
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
4151
|
+
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
4152
|
+
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
4153
|
+
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
4154
|
+
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
4155
|
+
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
4156
|
+
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
4157
|
+
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
4158
|
+
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
4159
|
+
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
4160
|
+
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
4161
|
+
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
4162
|
+
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
4163
|
+
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
4164
|
+
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
4165
|
+
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
4166
|
+
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
4167
|
+
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
4168
|
+
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
4169
|
+
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
4170
|
+
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
4171
|
+
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
4172
|
+
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
4173
|
+
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
4174
|
+
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
4175
|
+
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
4176
|
+
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
4177
|
+
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
4178
|
+
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
4179
|
+
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
4180
|
+
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
4181
|
+
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
4182
|
+
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
4183
|
+
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
4184
|
+
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
4185
|
+
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
4186
|
+
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
4187
|
+
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
4188
|
+
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
4189
|
+
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
4190
|
+
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
4191
|
+
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
4192
|
+
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
4193
|
+
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
4194
|
+
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
4195
|
+
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
4196
|
+
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
4197
|
+
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
4198
|
+
};
|
4199
|
+
|
3798
4200
|
#define NGRID_IQ1S 512
|
3799
4201
|
constexpr constant static uint64_t iq1s_grid[NGRID_IQ1S] = {
|
3800
4202
|
0xffffffffffff0101, 0xffffffffff01ff00, 0xffffffffff010100, 0xffffffff00000000,
|
@@ -3991,7 +4393,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
3991
4393
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
3992
4394
|
}
|
3993
4395
|
|
3994
|
-
#if QK_K == 256
|
3995
4396
|
const int ix = tiisg;
|
3996
4397
|
|
3997
4398
|
device const float * y4 = y + 32 * ix;
|
@@ -4032,12 +4433,6 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
4032
4433
|
|
4033
4434
|
y4 += 32 * 32;
|
4034
4435
|
}
|
4035
|
-
#else
|
4036
|
-
(void) x;
|
4037
|
-
(void) y;
|
4038
|
-
(void) yl;
|
4039
|
-
(void) nb32;
|
4040
|
-
#endif
|
4041
4436
|
|
4042
4437
|
for (int row = 0; row < N_DST; ++row) {
|
4043
4438
|
all_sum = simd_sum(sumf[row]);
|
@@ -4127,7 +4522,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
4127
4522
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4128
4523
|
}
|
4129
4524
|
|
4130
|
-
#if QK_K == 256
|
4131
4525
|
const int ix = tiisg;
|
4132
4526
|
|
4133
4527
|
device const float * y4 = y + 32 * ix;
|
@@ -4178,12 +4572,6 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
4178
4572
|
|
4179
4573
|
y4 += 32 * 32;
|
4180
4574
|
}
|
4181
|
-
#else
|
4182
|
-
(void) x;
|
4183
|
-
(void) y;
|
4184
|
-
(void) yl;
|
4185
|
-
(void) nb32;
|
4186
|
-
#endif
|
4187
4575
|
|
4188
4576
|
for (int row = 0; row < N_DST; ++row) {
|
4189
4577
|
all_sum = simd_sum(sumf[row]);
|
@@ -4273,7 +4661,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
4273
4661
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4274
4662
|
}
|
4275
4663
|
|
4276
|
-
#if QK_K == 256
|
4277
4664
|
const int ix = tiisg;
|
4278
4665
|
|
4279
4666
|
device const float * y4 = y + 32 * ix;
|
@@ -4317,12 +4704,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
4317
4704
|
|
4318
4705
|
y4 += 32 * 32;
|
4319
4706
|
}
|
4320
|
-
#else
|
4321
|
-
(void) x;
|
4322
|
-
(void) y;
|
4323
|
-
(void) yl;
|
4324
|
-
(void) nb32;
|
4325
|
-
#endif
|
4326
4707
|
|
4327
4708
|
for (int row = 0; row < N_DST; ++row) {
|
4328
4709
|
all_sum = simd_sum(sumf[row]);
|
@@ -4361,7 +4742,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
4361
4742
|
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
4362
4743
|
}
|
4363
4744
|
|
4364
|
-
void
|
4745
|
+
void kernel_mul_mv_iq3_s_f32_impl(
|
4365
4746
|
device const void * src0,
|
4366
4747
|
device const float * src1,
|
4367
4748
|
device float * dst,
|
@@ -4374,6 +4755,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4374
4755
|
constant int64_t & ne1,
|
4375
4756
|
constant uint & r2,
|
4376
4757
|
constant uint & r3,
|
4758
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4377
4759
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4378
4760
|
uint tiisg[[thread_index_in_simdgroup]],
|
4379
4761
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
@@ -4390,59 +4772,70 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4390
4772
|
const uint i13 = im/ne12;
|
4391
4773
|
|
4392
4774
|
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
4393
|
-
|
4775
|
+
|
4776
|
+
device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
|
4394
4777
|
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4395
4778
|
|
4396
|
-
float yl[
|
4779
|
+
float yl[32];
|
4397
4780
|
float sumf[N_DST]={0.f}, all_sum;
|
4398
4781
|
|
4399
4782
|
const int nb32 = nb * (QK_K / 32);
|
4400
4783
|
|
4401
|
-
|
4402
|
-
|
4403
|
-
|
4784
|
+
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
4785
|
+
{
|
4786
|
+
int nval = 8;
|
4787
|
+
int pos = (32*sgitg + tiisg)*nval;
|
4788
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
|
4789
|
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
4790
|
+
}
|
4404
4791
|
|
4405
|
-
|
4792
|
+
const int ix = tiisg;
|
4406
4793
|
|
4407
|
-
|
4794
|
+
device const float * y4 = y + 32 * ix;
|
4408
4795
|
|
4409
|
-
|
4796
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
4797
|
+
|
4798
|
+
for (int i = 0; i < 32; ++i) {
|
4410
4799
|
yl[i] = y4[i];
|
4411
4800
|
}
|
4412
4801
|
|
4413
4802
|
const int ibl = ib32 / (QK_K / 32);
|
4414
4803
|
const int ib = ib32 % (QK_K / 32);
|
4415
4804
|
|
4416
|
-
device const
|
4417
|
-
device const uint8_t * qs = xr->qs +
|
4418
|
-
device const uint8_t *
|
4419
|
-
device const
|
4805
|
+
device const block_iq3_s * xr = x + ibl;
|
4806
|
+
device const uint8_t * qs = xr->qs + 8 * ib;
|
4807
|
+
device const uint8_t * qh = xr->qh + ib;
|
4808
|
+
device const uint8_t * sc = xr->scales + (ib/2);
|
4809
|
+
device const uint8_t * signs = xr->signs + 4 * ib;
|
4810
|
+
device const half * dh = &xr->d;
|
4420
4811
|
|
4421
4812
|
for (int row = 0; row < N_DST; row++) {
|
4422
4813
|
|
4423
|
-
|
4424
|
-
|
4814
|
+
const float db = dh[0];
|
4815
|
+
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
|
4425
4816
|
|
4426
4817
|
float2 sum = {0};
|
4427
|
-
for (int
|
4428
|
-
|
4429
|
-
|
4818
|
+
for (int l = 0; l < 4; ++l) {
|
4819
|
+
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
|
4820
|
+
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
|
4821
|
+
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
|
4822
|
+
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
|
4823
|
+
for (int j = 0; j < 4; ++j) {
|
4824
|
+
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
4825
|
+
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
4826
|
+
}
|
4430
4827
|
}
|
4431
|
-
sumf[row] +=
|
4828
|
+
sumf[row] += d * (sum[0] + sum[1]);
|
4432
4829
|
|
4433
|
-
dh
|
4434
|
-
qs
|
4435
|
-
|
4830
|
+
dh += nb*sizeof(block_iq3_s)/2;
|
4831
|
+
qs += nb*sizeof(block_iq3_s);
|
4832
|
+
qh += nb*sizeof(block_iq3_s);
|
4833
|
+
sc += nb*sizeof(block_iq3_s);
|
4834
|
+
signs += nb*sizeof(block_iq3_s);
|
4436
4835
|
}
|
4437
4836
|
|
4438
|
-
y4 +=
|
4837
|
+
y4 += 32 * 32;
|
4439
4838
|
}
|
4440
|
-
#else
|
4441
|
-
(void) x;
|
4442
|
-
(void) y;
|
4443
|
-
(void) yl;
|
4444
|
-
(void) nb32;
|
4445
|
-
#endif
|
4446
4839
|
|
4447
4840
|
for (int row = 0; row < N_DST; ++row) {
|
4448
4841
|
all_sum = simd_sum(sumf[row]);
|
@@ -4452,11 +4845,36 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
4452
4845
|
}
|
4453
4846
|
}
|
4454
4847
|
|
4455
|
-
|
4456
|
-
|
4457
|
-
|
4848
|
+
// Compute entry point exported to the host as "kernel_mul_mv_iq3_s_f32".
//
// The kernel does no work of its own: it hands src0/src1/dst, the ne* extents, the r2/r3
// divisors, the threadgroup scratch buffer bound at threadgroup(0) and the thread position
// built-ins to kernel_mul_mv_iq3_s_f32_impl.
// nb00/nb01/nb02, ne11 and nb10/nb11/nb12 are part of the argument list but are not forwarded.
[[host_name("kernel_mul_mv_iq3_s_f32")]]
kernel void kernel_mul_mv_iq3_s_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    kernel_mul_mv_iq3_s_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
4876
|
+
|
4877
|
+
void kernel_mul_mv_iq2_s_f32_impl(
|
4460
4878
|
device const void * src0,
|
4461
4879
|
device const float * src1,
|
4462
4880
|
device float * dst,
|
@@ -4469,88 +4887,99 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
4469
4887
|
constant int64_t & ne1,
|
4470
4888
|
constant uint & r2,
|
4471
4889
|
constant uint & r3,
|
4472
|
-
threadgroup
|
4890
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4473
4891
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4474
4892
|
uint tiisg[[thread_index_in_simdgroup]],
|
4475
4893
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4476
4894
|
|
4477
|
-
const int nb = ne00/
|
4895
|
+
const int nb = ne00/QK_K;
|
4478
4896
|
const int r0 = tgpig.x;
|
4479
4897
|
const int r1 = tgpig.y;
|
4480
4898
|
const int im = tgpig.z;
|
4481
|
-
|
4899
|
+
|
4900
|
+
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
4482
4901
|
const int ib_row = first_row * nb;
|
4483
4902
|
|
4484
4903
|
const uint i12 = im%ne12;
|
4485
4904
|
const uint i13 = im/ne12;
|
4486
4905
|
|
4487
4906
|
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
4488
|
-
device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
|
4489
|
-
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4490
|
-
|
4491
|
-
const int ix = tiisg/2; // 0...15
|
4492
|
-
const int it = tiisg%2; // 0 or 1
|
4493
4907
|
|
4494
|
-
|
4495
|
-
|
4496
|
-
|
4497
|
-
float4 yl[4];
|
4498
|
-
float sumf[2]={0.f}, all_sum;
|
4908
|
+
device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
|
4909
|
+
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
4499
4910
|
|
4500
|
-
|
4911
|
+
float yl[32];
|
4912
|
+
float sumf[N_DST]={0.f}, all_sum;
|
4501
4913
|
|
4502
|
-
|
4503
|
-
thread const uint8_t * q8 = (thread const uint8_t *)aux32;
|
4914
|
+
const int nb32 = nb * (QK_K / 32);
|
4504
4915
|
|
4505
|
-
|
4916
|
+
//threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
4917
|
+
//{
|
4918
|
+
// int nval = 32;
|
4919
|
+
// int pos = (32*sgitg + tiisg)*nval;
|
4920
|
+
// for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
|
4921
|
+
// threadgroup_barrier(mem_flags::mem_threadgroup);
|
4922
|
+
//}
|
4506
4923
|
|
4507
|
-
|
4924
|
+
const int ix = tiisg;
|
4508
4925
|
|
4509
|
-
|
4510
|
-
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
|
4926
|
+
device const float * y4 = y + 32 * ix;
|
4511
4927
|
|
4512
|
-
|
4928
|
+
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
4513
4929
|
|
4514
|
-
|
4515
|
-
|
4930
|
+
for (int i = 0; i < 32; ++i) {
|
4931
|
+
yl[i] = y4[i];
|
4932
|
+
}
|
4516
4933
|
|
4517
|
-
|
4934
|
+
const int ibl = ib32 / (QK_K / 32);
|
4935
|
+
const int ib = ib32 % (QK_K / 32);
|
4518
4936
|
|
4519
|
-
|
4520
|
-
|
4521
|
-
|
4522
|
-
|
4523
|
-
|
4524
|
-
|
4525
|
-
acc2 += yl[1] * qf2;
|
4937
|
+
device const block_iq2_s * xr = x + ibl;
|
4938
|
+
device const uint8_t * qs = xr->qs + 4 * ib;
|
4939
|
+
device const uint8_t * qh = xr->qh + ib;
|
4940
|
+
device const uint8_t * sc = xr->scales + ib;
|
4941
|
+
device const uint8_t * signs = qs + QK_K/8;
|
4942
|
+
device const half * dh = &xr->d;
|
4526
4943
|
|
4527
|
-
|
4528
|
-
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
|
4529
|
-
aux32[0] &= 0x0f0f0f0f;
|
4530
|
-
qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
|
4531
|
-
qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
|
4532
|
-
acc1 += yl[2] * qf1;
|
4533
|
-
acc2 += yl[3] * qf2;
|
4944
|
+
for (int row = 0; row < N_DST; row++) {
|
4534
4945
|
|
4535
|
-
|
4946
|
+
const float db = dh[0];
|
4947
|
+
const float d1 = db * (0.5f + (sc[0] & 0xf));
|
4948
|
+
const float d2 = db * (0.5f + (sc[0] >> 4));
|
4536
4949
|
|
4537
|
-
|
4950
|
+
float2 sum = {0};
|
4951
|
+
for (int l = 0; l < 2; ++l) {
|
4952
|
+
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
4953
|
+
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
4954
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
|
4955
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
|
4956
|
+
for (int j = 0; j < 8; ++j) {
|
4957
|
+
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
|
4958
|
+
sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
|
4959
|
+
}
|
4960
|
+
}
|
4961
|
+
sumf[row] += d1 * sum[0] + d2 * sum[1];
|
4538
4962
|
|
4963
|
+
dh += nb*sizeof(block_iq2_s)/2;
|
4964
|
+
qs += nb*sizeof(block_iq2_s);
|
4965
|
+
qh += nb*sizeof(block_iq2_s);
|
4966
|
+
sc += nb*sizeof(block_iq2_s);
|
4967
|
+
signs += nb*sizeof(block_iq2_s);
|
4539
4968
|
}
|
4540
4969
|
|
4541
|
-
|
4970
|
+
y4 += 32 * 32;
|
4542
4971
|
}
|
4543
4972
|
|
4544
|
-
for (int row = 0; row <
|
4973
|
+
for (int row = 0; row < N_DST; ++row) {
|
4545
4974
|
all_sum = simd_sum(sumf[row]);
|
4546
4975
|
if (tiisg == 0) {
|
4547
|
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
4976
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
|
4548
4977
|
}
|
4549
4978
|
}
|
4550
4979
|
}
|
4551
4980
|
|
4552
|
-
[[host_name("
|
4553
|
-
kernel void
|
4981
|
+
[[host_name("kernel_mul_mv_iq2_s_f32")]]
|
4982
|
+
kernel void kernel_mul_mv_iq2_s_f32(
|
4554
4983
|
device const void * src0,
|
4555
4984
|
device const float * src1,
|
4556
4985
|
device float * dst,
|
@@ -4570,67 +4999,406 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
4570
4999
|
constant int64_t & ne1,
|
4571
5000
|
constant uint & r2,
|
4572
5001
|
constant uint & r3,
|
5002
|
+
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
4573
5003
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
4574
5004
|
uint tiisg[[thread_index_in_simdgroup]],
|
4575
5005
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
4576
5006
|
|
4577
|
-
|
5007
|
+
kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
4578
5008
|
}
|
4579
5009
|
|
4580
|
-
|
4581
|
-
kernel void kernel_mul_mv_iq4_nl_f32(
|
5010
|
+
// Matrix-vector product of IQ1_S-quantized src0 rows with float src1, written to dst.
//
// Work split, as expressed by the indexing below:
//   - each simdgroup accumulates N_DST consecutive rows starting at
//     (tgpig.x * N_SIMDGROUP + sgitg) * N_DST;
//   - lanes are paired: tiisg/2 picks a group of 32 values, tiisg%2 picks one 16-value half
//     of that group, and the outer loop advances 16 groups per pass;
//   - per-lane partial sums are combined with simd_sum and lane 0 stores the result.
//
// tgpig.y offsets src1 by ne10 elements and dst by ne0 elements; tgpig.z is split into
// i12/i13 (divided by r2/r3) to locate the src0 slice.
void kernel_mul_mv_iq1_s_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int n_blk = ne00/QK_K;   // number of QK_K-sized blocks along ne00
    const int tg_x  = tgpig.x;
    const int tg_y  = tgpig.y;
    const int tg_z  = tgpig.z;

    const int row0     = (tg_x * N_SIMDGROUP + sgitg) * N_DST;
    const int blk_row0 = row0 * n_blk;

    const uint i12 = tg_z%ne12;
    const uint i13 = tg_z/ne12;

    const uint src0_off = (i12/r2)*(n_blk*ne01) + (i13/r3)*(n_blk*ne01*ne02);

    device const block_iq1_s * x = (device const block_iq1_s *) src0 + blk_row0 + src0_off;
    device const float       * y = (device const float       *) src1 + tg_y*ne10 + tg_z*ne00*ne1;

    const int n_grp = n_blk * (QK_K / 32);   // number of 32-value groups along ne00

    const int grp     = tiisg/2;
    const int half_id = tiisg%2;

    device const float * y_src = y + 32 * grp + 16 * half_id;

    float y_cache[16];
    float partial[N_DST] = {0.f};

    for (int g = grp; g < n_grp; g += 16) {

        // Stage this lane's 16 src1 values once; they are reused for every row below.
        for (int i = 0; i < 16; ++i) {
            y_cache[i] = y_src[i];
        }

        const int blk = g / (QK_K / 32);
        const int sub = g % (QK_K / 32);

        device const block_iq1_s * xb = x + blk;
        device const uint8_t * qs = xb->qs     + 4 * sub + 2 * half_id;
        device const uint8_t * sc = xb->scales + 2 * sub + half_id;
        device const half    * dh = &xb->d;

        for (int row = 0; row < N_DST; row++) {

            const uint8_t s = sc[0];

            // Bits 0x08 and 0x80 of the scale byte supply bit 8 of the two grid indices.
            constant int8_t * grid_lo = (constant int8_t *)(iq1s_grid + (qs[0] | ((s & 0x08) << 5)));
            constant int8_t * grid_hi = (constant int8_t *)(iq1s_grid + (qs[1] | ((s & 0x80) << 1)));

            float2 sum = {0};
            for (int j = 0; j < 8; ++j) {
                sum[0] += y_cache[j + 0] * grid_lo[j];
                sum[1] += y_cache[j + 8] * grid_hi[j];
            }

            // The low 3 bits of each nibble of the scale byte give an odd integer weight 2*v + 1.
            const int w_lo = 2*(s & 7) + 1;
            const int w_hi = 2*((s >> 4) & 7) + 1;

            partial[row] += (float)dh[0] * (sum[0] * w_lo + sum[1] * w_hi);

            // Step every pointer forward by n_blk blocks, i.e. to the same block of the next row.
            dh += n_blk*sizeof(block_iq1_s)/2;
            qs += n_blk*sizeof(block_iq1_s);
            sc += n_blk*sizeof(block_iq1_s);
        }

        y_src += 16 * 32;
    }

    for (int row = 0; row < N_DST; ++row) {
        const float total = simd_sum(partial[row]);
        if (tiisg == 0) {
            dst[tg_y*ne0 + tg_z*ne0*ne1 + row0 + row] = total;
        }
    }
}
|
5093
|
+
|
5094
|
+
// 16-entry float lookup table. The iq4_nl and iq4_xs mat-vec implementations copy it into
// threadgroup memory and index it with 4-bit values extracted from the quantized data.
constexpr constant static float kvalues_iq4nl_f[16] = {
    -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f,
       1.f,   13.f,  25.f,  38.f,  53.f,  69.f,  89.f, 113.f
};
|
5097
|
+
|
5098
|
+
// Matrix-vector product of IQ4_NL-quantized src0 rows with float src1, written to dst.
//
// Work split, as expressed by the indexing below:
//   - each simdgroup accumulates 2 consecutive rows starting at (tgpig.x * 2 + sgitg) * 2;
//   - tiisg/2 selects a QK4_NL-sized block and tiisg%2 selects which 8 bytes of xb.qs (and
//     which 8-float offset into src1) the lane handles; the outer loop advances 16 blocks
//     per pass;
//   - each lane writes one entry of shared_values (kvalues_iq4nl_f[tiisg%16]) and a
//     threadgroup barrier follows before any lookup;
//   - per-lane partial sums are combined with simd_sum and lane 0 stores the result.
void kernel_mul_mv_iq4_nl_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup float  * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int n_blk = ne00/QK4_NL;
    const int tg_x  = tgpig.x;
    const int tg_y  = tgpig.y;
    const int tg_z  = tgpig.z;

    const int row0     = (tg_x * 2 + sgitg) * 2;
    const int blk_row0 = row0 * n_blk;

    const uint i12 = tg_z%ne12;
    const uint i13 = tg_z/ne12;

    const uint src0_off = (i12/r2)*(n_blk*ne01) + (i13/r3)*(n_blk*ne01*ne02);

    device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + blk_row0 + src0_off;
    device const float        * y = (device const float        *) src1 + tg_y*ne10 + tg_z*ne00*ne1;

    const int blk0   = tiisg/2;  // 0...15
    const int part   = tiisg%2;  // 0 or 1

    // Publish the lookup table to threadgroup memory before anyone reads it.
    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 y_vec[4];
    float  partial[2] = {0.f};

    device const float * y_src = y + blk0 * QK4_NL + part * 8;

    // nib[0] receives the low nibbles and nib[1] the high nibbles; nib_bytes views both words
    // as eight byte-sized table indices.
    uint32_t nib[2];
    thread const uint8_t * nib_bytes = (thread const uint8_t *)nib;

    for (int b = blk0; b < n_blk; b += 16) {

        device const float4 * y4 = (device const float4 *)y_src;
        y_vec[0] = y4[0];
        y_vec[1] = y4[4];
        y_vec[2] = y4[1];
        y_vec[3] = y4[5];

        for (int row = 0; row < 2; ++row) {

            device const block_iq4_nl & xb = x[row*n_blk + b];
            device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*part);

            float4 acc_lo = {0.f}, acc_hi = {0.f};

            // Two rounds of 4 packed bytes each: pair k uses q4[2k], q4[2k+1] with
            // y_vec[2k] (low nibbles) and y_vec[2k+1] (high nibbles).
            for (int k = 0; k < 2; ++k) {
                nib[0]  = q4[2*k + 0] | (q4[2*k + 1] << 16);
                nib[1]  = (nib[0] >> 4) & 0x0f0f0f0f;
                nib[0] &= 0x0f0f0f0f;

                const float4 v_lo = float4(shared_values[nib_bytes[0]], shared_values[nib_bytes[1]],
                                           shared_values[nib_bytes[2]], shared_values[nib_bytes[3]]);
                const float4 v_hi = float4(shared_values[nib_bytes[4]], shared_values[nib_bytes[5]],
                                           shared_values[nib_bytes[6]], shared_values[nib_bytes[7]]);

                acc_lo += y_vec[2*k + 0] * v_lo;
                acc_hi += y_vec[2*k + 1] * v_hi;
            }

            acc_lo += acc_hi;

            partial[row] += (float)xb.d * (acc_lo[0] + acc_lo[1] + acc_lo[2] + acc_lo[3]);
        }

        y_src += 16 * QK4_NL;
    }

    for (int row = 0; row < 2; ++row) {
        const float total = simd_sum(partial[row]);
        if (tiisg == 0) {
            dst[tg_y*ne0 + tg_z*ne0*ne1 + row0 + row] = total;
        }
    }
}
|
5190
|
+
|
5191
|
+
#if QK_K != 64
|
5192
|
+
// Matrix-vector product of IQ4_XS-quantized src0 rows with float src1, written to dst.
//
// Work split, as expressed by the indexing below:
//   - each simdgroup accumulates 2 consecutive rows starting at (tgpig.x * 2 + sgitg) * 2;
//   - tiisg/16 selects which of two QK_K-sized blocks the lane starts on, and tiisg%16 is
//     split again into a sub-block index (it/2) and an 8-byte half (it%2); the outer loop
//     advances 2 blocks per pass;
//   - each lane writes one entry of shared_values (kvalues_iq4nl_f[tiisg%16]) and a
//     threadgroup barrier follows before any lookup;
//   - per-lane partial sums are combined with simd_sum and lane 0 stores the result.
void kernel_mul_mv_iq4_xs_f32_impl(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne10,
        constant   int64_t & ne12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   uint    & r2,
        constant   uint    & r3,
        threadgroup float  * shared_values [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int n_blk = ne00/QK_K;
    const int tg_x  = tgpig.x;
    const int tg_y  = tgpig.y;
    const int tg_z  = tgpig.z;

    const int row0     = (tg_x * 2 + sgitg) * 2;
    const int blk_row0 = row0 * n_blk;

    const uint i12 = tg_z%ne12;
    const uint i13 = tg_z/ne12;

    const uint src0_off = (i12/r2)*(n_blk*ne01) + (i13/r3)*(n_blk*ne01*ne02);

    device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + blk_row0 + src0_off;
    device const float        * y = (device const float        *) src1 + tg_y*ne10 + tg_z*ne00*ne1;

    const int blk0 = tiisg/16;  // 0 or 1
    const int lane = tiisg%16;  // 0...15
    const int sub  = lane/2;    // sub-block index within the block
    const int part = lane%2;    // which 8-byte half of the sub-block

    // Publish the lookup table to threadgroup memory before anyone reads it.
    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float4 y_vec[4];
    float  partial[2] = {0.f};

    device const float * y_src = y + blk0 * QK_K + sub * 32 + part * 8;

    // nib[0] receives the low nibbles and nib[1] the high nibbles; nib_bytes views both words
    // as eight byte-sized table indices.
    uint32_t nib[2];
    thread const uint8_t * nib_bytes = (thread const uint8_t *)nib;

    for (int b = blk0; b < n_blk; b += 2) {

        device const float4 * y4 = (device const float4 *)y_src;
        y_vec[0] = y4[0];
        y_vec[1] = y4[4];
        y_vec[2] = y4[1];
        y_vec[3] = y4[5];

        for (int row = 0; row < 2; ++row) {

            device const block_iq4_xs & xb = x[row*n_blk + b];
            device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*sub + 8*part);

            float4 acc_lo = {0.f}, acc_hi = {0.f};

            // Two packed words per lane: word k pairs with y_vec[2k] (low nibbles) and
            // y_vec[2k+1] (high nibbles).
            for (int k = 0; k < 2; ++k) {
                nib[0] = q4[k] & 0x0f0f0f0f;
                nib[1] = (q4[k] >> 4) & 0x0f0f0f0f;

                const float4 v_lo = float4(shared_values[nib_bytes[0]], shared_values[nib_bytes[1]],
                                           shared_values[nib_bytes[2]], shared_values[nib_bytes[3]]);
                const float4 v_hi = float4(shared_values[nib_bytes[4]], shared_values[nib_bytes[5]],
                                           shared_values[nib_bytes[6]], shared_values[nib_bytes[7]]);

                acc_lo += y_vec[2*k + 0] * v_lo;
                acc_hi += y_vec[2*k + 1] * v_hi;
            }

            acc_lo += acc_hi;

            // 6-bit sub-block scale: low 4 bits come from scales_l, high 2 bits from scales_h;
            // the combined value is offset by -32.
            const int ls = (((xb.scales_l[sub/2] >> 4*(sub%2)) & 0xf) | (((xb.scales_h >> 2*sub) & 3) << 4)) - 32;

            partial[row] += (float)xb.d * ls * (acc_lo[0] + acc_lo[1] + acc_lo[2] + acc_lo[3]);
        }

        y_src += 2 * QK_K;
    }

    for (int row = 0; row < 2; ++row) {
        const float total = simd_sum(partial[row]);
        if (tiisg == 0) {
            dst[tg_y*ne0 + tg_z*ne0*ne1 + row0 + row] = total;
        }
    }
}
|
5285
|
+
#endif
|
5286
|
+
|
5287
|
+
// Host-visible entry point kernel_mul_mv_iq1_s_f32.
// The kernel does no work of its own: it forwards to kernel_mul_mv_iq1_s_f32_impl.
// The byte strides (nb00..nb02, nb10..nb12) and ne11 belong to the argument list
// but are not handed on to the implementation.
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
        device const void * src0, device const float * src1, device float * dst,
        constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02,
        constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02,
        constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12,
        constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12,
        constant int64_t & ne0, constant int64_t & ne1,
        constant uint & r2, constant uint & r3,
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq1_s_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        tgpig, tiisg, sgitg);
}
|
5314
|
+
|
5315
|
+
// Host-visible entry point kernel_mul_mv_iq4_nl_f32.
// Forwards to kernel_mul_mv_iq4_nl_f32_impl, additionally handing over the
// threadgroup float buffer bound at threadgroup(0). The byte strides
// (nb00..nb02, nb10..nb12) and ne11 are accepted but not forwarded.
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
        device const void * src0, device const float * src1, device float * dst,
        constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02,
        constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02,
        constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12,
        constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12,
        constant int64_t & ne0, constant int64_t & ne1,
        constant uint & r2, constant uint & r3,
        threadgroup float * shared_values [[threadgroup(0)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]]) {
    kernel_mul_mv_iq4_nl_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
5343
|
+
|
5344
|
+
// Host-visible entry point kernel_mul_mv_iq4_xs_f32.
// Picks the implementation at compile time: the iq4_nl implementation when
// QK_K == 64, the iq4_xs implementation otherwise. Both receive the same
// argument list; the byte strides (nb00..nb02, nb10..nb12) and ne11 are not forwarded.
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
        device const void * src0, device const float * src1, device float * dst,
        constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02,
        constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02,
        constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12,
        constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12,
        constant int64_t & ne0, constant int64_t & ne1,
        constant uint & r2, constant uint & r3,
        threadgroup float * shared_values [[threadgroup(0)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]]) {
#if QK_K != 64
    kernel_mul_mv_iq4_xs_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
#else
    kernel_mul_mv_iq4_nl_f32_impl(
        src0, src1, dst,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
#endif
}
|
5376
|
+
|
5377
|
+
//============================= templates and their specializations =============================
|
5378
|
+
|
5379
|
+
// NOTE: this is not dequantizing - we are simply fitting the template
|
5380
|
+
// Copies one float4x4 from device memory into reg, converting each of the 16
// elements to reg's element type on assignment. il is unused; it is only part
// of the signature shared with the other dequantize_* routines.
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
    float4x4 tile = *src;
    for (int v = 0; v < 4; ++v) {
        for (int c = 0; c < 4; ++c) {
            reg[v][c] = tile[v][c];
        }
    }
}
|
5387
|
+
|
5388
|
+
// Copies one half4x4 from device memory into reg, converting each of the 16
// elements to reg's element type on assignment. il is unused; it is only part
// of the signature shared with the other dequantize_* routines.
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    half4x4 tile = *src;
    for (int v = 0; v < 4; ++v) {
        for (int c = 0; c < 4; ++c) {
            reg[v][c] = tile[v][c];
        }
    }
}
|
5395
|
+
|
5396
|
+
template <typename type4x4>
|
5397
|
+
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
5398
|
+
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
5399
|
+
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
5400
|
+
const float d2 = d1 / 256.f;
|
5401
|
+
const float md = -8.h * xb->d;
|
4634
5402
|
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
4635
5403
|
const ushort mask1 = mask0 << 8;
|
4636
5404
|
|
@@ -4952,6 +5720,50 @@ void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x
|
|
4952
5720
|
}
|
4953
5721
|
}
|
4954
5722
|
|
5723
|
+
// Expands 16 quants of an iq3_s block into the 4x4 register tile reg.
// il is 0...15 for QK_K = 256: il/2 is the index of the block of 32 and il%2
// selects the first (0) or second (1) group of 16 quants inside that block.
template <typename type4x4>
void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
    const int ib32 = il/2;
    const int isub = il%2;

    // four index bytes and two sign bytes belong to this group of 16
    device const uint8_t * qs    = xb->qs    + 8*ib32 + 4*isub;
    device const uint8_t * signs = xb->signs + 4*ib32 + 2*isub;

    // after the shift, bits 0..3 of qh hold the 9th index bit of the four lookups
    const uint8_t qh = xb->qh[ib32] >> 4*isub;

    const float d  = xb->d;
    const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));

    for (int k = 0; k < 4; ++k) {
        // bit k of qh is moved to bit 8 of the grid index
        constant uint8_t * grid = (constant uint8_t *)(iq3s_grid + (qs[k] | ((qh << (8 - k)) & 256)));

        // lookups 0,1 use signs[0] and lookups 2,3 use signs[1];
        // even lookups test mask entries 0..3, odd lookups entries 4..7
        const uint8_t sign_bits = signs[k/2];
        const int     mask_off  = 4*(k%2);

        for (int i = 0; i < 4; ++i) {
            reg[k][i] = dl * grid[i] * select(1, -1, sign_bits & kmask_iq2xs[i + mask_off]);
        }
    }
}
|
5747
|
+
|
5748
|
+
// Expands 16 quants of an iq2_s block into the 4x4 register tile reg.
// il is 0...15 for QK_K = 256: il/2 is the index of the block of 32 and il%2
// selects the first (0) or second (1) group of 16 quants inside that block.
template <typename type4x4>
void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
    const int ib32 = il/2;
    const int isub = il%2;

    // two index bytes per group of 16; the matching sign bytes sit QK_K/8 bytes further on
    device const uint8_t * qs    = xb->qs + 4*ib32 + 2*isub;
    device const uint8_t * signs = qs + QK_K/8;

    const uint8_t qh = xb->qh[ib32] >> 4*isub;

    const float d  = xb->d;
    const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*isub) & 0xf)) * 0.25f;

    for (int p = 0; p < 2; ++p) {
        // lookup p takes bits 2p and 2p+1 of qh as bits 8 and 9 of the grid index
        constant uint8_t * grid = (constant uint8_t *)(iq2s_grid + (qs[p] | ((qh << (8 - 2*p)) & 0x300)));
        const uint8_t sign_bits = signs[p];

        // each lookup yields 8 values, filling two consecutive vectors of reg
        for (int i = 0; i < 8; ++i) {
            reg[2*p + i/4][i%4] = dl * grid[i] * select(1, -1, sign_bits & kmask_iq2xs[i]);
        }
    }
}
|
5766
|
+
|
4955
5767
|
template <typename type4x4>
|
4956
5768
|
void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
|
4957
5769
|
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
@@ -4983,6 +5795,30 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
|
|
4983
5795
|
}
|
4984
5796
|
}
|
4985
5797
|
|
5798
|
+
// Expands 16 quants of an iq4_xs block into the 4x4 register tile reg.
// With QK_K == 64 the call is handed straight to dequantize_iq4_nl.
// Otherwise il is 0...15 for QK_K = 256: il/2 is the index of the block of 32
// and il%2 selects the first (0) or second (1) group of 16 quants in that block.
template <typename type4x4>
void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
#if QK_K == 64
    dequantize_iq4_nl(xb, il, reg);
#else
    const int ib32  = il/2;
    const int shift = 4*(il%2);   // which nibble of every byte is used

    // four 32-bit words of packed nibbles per block of 32
    device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;

    // the 6-bit block scale is split: low 4 bits in scales_l, high 2 bits in scales_h
    const int ls_lo = (xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf;
    const int ls_hi = (xb->scales_h >> 2*ib32) & 3;
    const int ls    = ls_lo | (ls_hi << 4);
    const float d   = (float)xb->d * (ls - 32);

    // the masked word is read back byte by byte through q8
    uint32_t packed;
    thread const uint8_t * q8 = (thread const uint8_t *)&packed;

    for (int i = 0; i < 4; ++i) {
        packed = (q4[i] >> shift) & 0x0f0f0f0f;
        for (int j = 0; j < 4; ++j) {
            reg[i][j] = d * kvalues_iq4nl_f[q8[j]];
        }
    }
#endif
}
|
5821
|
+
|
4986
5822
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
4987
5823
|
kernel void kernel_get_rows(
|
4988
5824
|
device const void * src0,
|
@@ -5525,8 +6361,15 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
|
|
5525
6361
|
// Explicit instantiations of kernel_get_rows for the i-quant block types.
// Each statement binds a host-visible kernel name to
// kernel_get_rows<block type, nl, dequantize function>.
template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5526
6362
|
template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5527
6363
|
template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
6364
|
+
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
6365
|
+
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5528
6366
|
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5529
|
-
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2,
|
6367
|
+
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6368
|
+
// iq4_xs: nl is 2 when QK_K == 64 (same value as the iq4_nl instantiation) and QK_NL otherwise
#if QK_K == 64
|
6369
|
+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6370
|
+
#else
|
6371
|
+
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6372
|
+
#endif
|
5530
6373
|
|
5531
6374
|
//
|
5532
6375
|
// matrix-matrix multiplication
|
@@ -5566,8 +6409,15 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|
5566
6409
|
// Explicit instantiations of kernel_mul_mm for the i-quant block types.
// Each statement binds a host-visible kernel name to kernel_mul_mm with a
// block type, a second template argument (QK_NL or 2) and a dequantize function.
template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5567
6410
|
template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5568
6411
|
template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
6412
|
+
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
6413
|
+
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5569
6414
|
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5570
|
-
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2,
|
6415
|
+
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6416
|
+
// iq4_xs instantiation of kernel_mul_mm: the second template argument is 2 when
// QK_K == 64 and QK_NL otherwise. Both branches name block_iq4_xs as the block
// type, the type dequantize_iq4_xs is declared to take, which also matches the
// kernel_get_rows and kernel_mul_mm_id instantiations for iq4_xs.
#if QK_K == 64
|
6417
|
+
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6418
|
+
#else
|
6419
|
+
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6420
|
+
#endif
|
5571
6421
|
|
5572
6422
|
//
|
5573
6423
|
// indirect matrix-matrix multiplication
|
@@ -5619,8 +6469,15 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
|
|
5619
6469
|
// Explicit instantiations of kernel_mul_mm_id for the i-quant block types.
// Each statement binds a host-visible kernel name to kernel_mul_mm_id with a
// block type, a second template argument (QK_NL or 2) and a dequantize function.
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
|
5620
6470
|
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
|
5621
6471
|
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
|
6472
|
+
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
6473
|
+
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
5622
6474
|
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
5623
|
-
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2,
|
6475
|
+
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
6476
|
+
// iq4_xs: the second template argument is 2 when QK_K == 64 and QK_NL otherwise
#if QK_K == 64
|
6477
|
+
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
6478
|
+
#else
|
6479
|
+
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
|
6480
|
+
#endif
|
5624
6481
|
|
5625
6482
|
//
|
5626
6483
|
// matrix-vector multiplication
|
@@ -6589,6 +7446,136 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|
6589
7446
|
sgitg);
|
6590
7447
|
}
|
6591
7448
|
|
7449
|
+
// Host-visible entry point kernel_mul_mv_id_iq3_s_f32.
// Chooses one of the eight src0 buffers through the ids table and then
// forwards to kernel_mul_mv_iq3_s_f32_impl. tgpig.z is split in two: the
// quotient by ne12*ne13 (bid) selects the ids row and the src1/dst offsets,
// the remainder is what the implementation receives as tgpig.z.
// nb00..nb02, ne11, nb10, nb12, nb1 and tiitg are accepted but not used here.
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
kernel void kernel_mul_mv_id_iq3_s_f32(
        device const char * ids, device const char * src1, device float * dst,
        constant uint64_t & nbi1,
        constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02,
        constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02,
        constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13,
        constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12,
        constant int64_t & ne0, constant int64_t & ne1,
        constant uint64_t & nb1,
        constant uint & r2, constant uint & r3,
        constant int & idx,
        device const char * src00, device const char * src01, device const char * src02, device const char * src03,
        device const char * src04, device const char * src05, device const char * src06, device const char * src07,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiitg [[thread_index_in_threadgroup]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]]) {
    // collect the candidate src0 buffers so one can be picked by index
    device const char * src0_tab[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    const int64_t nz  = ne12*ne13;
    const int64_t bid = tgpig.z/nz;
    tgpig.z = tgpig.z%nz;

    // entry idx of the ids row at byte offset bid*nbi1 names the src0 buffer to use
    const int32_t isrc = ((device int32_t *) (ids + bid*nbi1))[idx];

    device const float * src1_f32 = (device const float *) (src1 + bid*nb11);
    device       float * dst_out  = dst + bid*ne0;

    kernel_mul_mv_iq3_s_f32_impl(
        src0_tab[isrc], src1_f32, dst_out,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
7513
|
+
|
7514
|
+
// Host-visible entry point kernel_mul_mv_id_iq2_s_f32.
// Chooses one of the eight src0 buffers through the ids table and then
// forwards to kernel_mul_mv_iq2_s_f32_impl. tgpig.z is split in two: the
// quotient by ne12*ne13 (bid) selects the ids row and the src1/dst offsets,
// the remainder is what the implementation receives as tgpig.z.
// nb00..nb02, ne11, nb10, nb12, nb1 and tiitg are accepted but not used here.
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
kernel void kernel_mul_mv_id_iq2_s_f32(
        device const char * ids, device const char * src1, device float * dst,
        constant uint64_t & nbi1,
        constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02,
        constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02,
        constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13,
        constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12,
        constant int64_t & ne0, constant int64_t & ne1,
        constant uint64_t & nb1,
        constant uint & r2, constant uint & r3,
        constant int & idx,
        device const char * src00, device const char * src01, device const char * src02, device const char * src03,
        device const char * src04, device const char * src05, device const char * src06, device const char * src07,
        threadgroup int8_t * shared_values [[threadgroup(0)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiitg [[thread_index_in_threadgroup]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]]) {
    // collect the candidate src0 buffers so one can be picked by index
    device const char * src0_tab[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    const int64_t nz  = ne12*ne13;
    const int64_t bid = tgpig.z/nz;
    tgpig.z = tgpig.z%nz;

    // entry idx of the ids row at byte offset bid*nbi1 names the src0 buffer to use
    const int32_t isrc = ((device int32_t *) (ids + bid*nbi1))[idx];

    device const float * src1_f32 = (device const float *) (src1 + bid*nb11);
    device       float * dst_out  = dst + bid*ne0;

    kernel_mul_mv_iq2_s_f32_impl(
        src0_tab[isrc], src1_f32, dst_out,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
}
|
7578
|
+
|
6592
7579
|
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
6593
7580
|
kernel void kernel_mul_mv_id_iq1_s_f32(
|
6594
7581
|
device const char * ids,
|
@@ -6716,3 +7703,72 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|
6716
7703
|
tiisg,
|
6717
7704
|
sgitg);
|
6718
7705
|
}
|
7706
|
+
|
7707
|
+
// Host-visible entry point kernel_mul_mv_id_iq4_xs_f32.
// Chooses one of the eight src0 buffers through the ids table and then
// forwards to the iq4_nl implementation when QK_K == 64, or to the iq4_xs
// implementation otherwise. tgpig.z is split in two: the quotient by
// ne12*ne13 (bid) selects the ids row and the src1/dst offsets, the remainder
// is what the implementation receives as tgpig.z.
// nb00..nb02, ne11, nb10, nb12, nb1 and tiitg are accepted but not used here.
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
kernel void kernel_mul_mv_id_iq4_xs_f32(
        device const char * ids, device const char * src1, device float * dst,
        constant uint64_t & nbi1,
        constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02,
        constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02,
        constant int64_t & ne10, constant int64_t & ne11, constant int64_t & ne12, constant int64_t & ne13,
        constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12,
        constant int64_t & ne0, constant int64_t & ne1,
        constant uint64_t & nb1,
        constant uint & r2, constant uint & r3,
        constant int & idx,
        device const char * src00, device const char * src01, device const char * src02, device const char * src03,
        device const char * src04, device const char * src05, device const char * src06, device const char * src07,
        threadgroup float * shared_values [[threadgroup(0)]],
        uint3 tgpig [[threadgroup_position_in_grid]],
        uint  tiitg [[thread_index_in_threadgroup]],
        uint  tiisg [[thread_index_in_simdgroup]],
        uint  sgitg [[simdgroup_index_in_threadgroup]]) {
    // collect the candidate src0 buffers so one can be picked by index
    device const char * src0_tab[8] = {src00, src01, src02, src03, src04, src05, src06, src07};

    const int64_t nz  = ne12*ne13;
    const int64_t bid = tgpig.z/nz;
    tgpig.z = tgpig.z%nz;

    // entry idx of the ids row at byte offset bid*nbi1 names the src0 buffer to use
    const int32_t isrc = ((device int32_t *) (ids + bid*nbi1))[idx];

    device const float * src1_f32 = (device const float *) (src1 + bid*nb11);
    device       float * dst_out  = dst + bid*ne0;

#if QK_K == 64
    kernel_mul_mv_iq4_nl_f32_impl(
        src0_tab[isrc], src1_f32, dst_out,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
#else
    kernel_mul_mv_iq4_xs_f32_impl(
        src0_tab[isrc], src1_f32, dst_out,
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
        r2, r3,
        shared_values,
        tgpig, tiisg, sgitg);
#endif
}
|