cui-llama.rn 1.1.2 → 1.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/android/src/main/CMakeLists.txt +1 -2
- package/android/src/main/jni.cpp +26 -21
- package/cpp/common.cpp +2028 -1520
- package/cpp/common.h +134 -18
- package/cpp/ggml-aarch64.c +612 -0
- package/cpp/ggml-alloc.h +2 -2
- package/cpp/ggml-backend.c +33 -6
- package/cpp/ggml-backend.h +2 -0
- package/cpp/ggml-common.h +20 -0
- package/cpp/ggml-impl.h +4 -7
- package/cpp/ggml-metal.m +63 -2
- package/cpp/ggml-quants.c +690 -2
- package/cpp/ggml-quants.h +15 -0
- package/cpp/ggml.c +1650 -317
- package/cpp/ggml.h +155 -48
- package/cpp/llama-grammar.cpp +721 -122
- package/cpp/llama-grammar.h +120 -15
- package/cpp/llama-impl.h +132 -1
- package/cpp/llama-sampling.cpp +1361 -356
- package/cpp/llama-sampling.h +20 -48
- package/cpp/llama-vocab.cpp +140 -7
- package/cpp/llama-vocab.h +3 -2
- package/cpp/llama.cpp +810 -307
- package/cpp/llama.h +213 -259
- package/cpp/rn-llama.hpp +17 -14
- package/cpp/sampling.cpp +347 -355
- package/cpp/sampling.h +106 -135
- package/cpp/sgemm.cpp +153 -0
- package/package.json +1 -1
- package/cpp/grammar-parser.cpp +0 -539
- package/cpp/grammar-parser.h +0 -29
package/cpp/sampling.h
CHANGED
@@ -2,163 +2,134 @@
|
|
2
2
|
|
3
3
|
#include "llama.h"
|
4
4
|
|
5
|
-
#include "grammar-parser.h"
|
6
|
-
|
7
|
-
#include <random>
|
8
5
|
#include <string>
|
9
|
-
#include <unordered_map>
|
10
6
|
#include <vector>
|
11
7
|
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
8
|
+
enum gpt_sampler_type {
|
9
|
+
GPT_SAMPLER_TYPE_NONE = 0,
|
10
|
+
GPT_SAMPLER_TYPE_TOP_K = 1,
|
11
|
+
GPT_SAMPLER_TYPE_TOP_P = 2,
|
12
|
+
GPT_SAMPLER_TYPE_MIN_P = 3,
|
13
|
+
GPT_SAMPLER_TYPE_TFS_Z = 4,
|
14
|
+
GPT_SAMPLER_TYPE_TYPICAL_P = 5,
|
15
|
+
GPT_SAMPLER_TYPE_TEMPERATURE = 6,
|
16
|
+
GPT_SAMPLER_TYPE_XTC = 7,
|
21
17
|
};
|
22
18
|
|
23
19
|
// sampling parameters
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
int32_t
|
28
|
-
int32_t
|
29
|
-
|
30
|
-
|
31
|
-
float
|
32
|
-
float
|
33
|
-
float
|
34
|
-
float
|
35
|
-
float
|
36
|
-
float
|
37
|
-
float
|
38
|
-
|
39
|
-
float
|
40
|
-
|
41
|
-
float
|
42
|
-
|
43
|
-
float
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
20
|
+
struct gpt_sampler_params {
|
21
|
+
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
22
|
+
|
23
|
+
int32_t n_prev = 64; // number of previous tokens to remember
|
24
|
+
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
25
|
+
int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
|
26
|
+
int32_t top_k = 40; // <= 0 to use vocab size
|
27
|
+
float top_p = 0.95f; // 1.0 = disabled
|
28
|
+
float min_p = 0.05f; // 0.0 = disabled
|
29
|
+
float tfs_z = 1.00f; // 1.0 = disabled
|
30
|
+
float xtc_t = 0.0f; // 0.0 = disabled
|
31
|
+
float xtc_p = 0.0f;
|
32
|
+
float typ_p = 1.00f; // typical_p, 1.0 = disabled
|
33
|
+
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
|
34
|
+
float dynatemp_range = 0.00f; // 0.0 = disabled
|
35
|
+
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
|
36
|
+
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
|
37
|
+
float penalty_repeat = 1.00f; // 1.0 = disabled
|
38
|
+
float penalty_freq = 0.00f; // 0.0 = disabled
|
39
|
+
float penalty_present = 0.00f; // 0.0 = disabled
|
40
|
+
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
41
|
+
float mirostat_tau = 5.00f; // target entropy
|
42
|
+
float mirostat_eta = 0.10f; // learning rate
|
43
|
+
bool penalize_nl = false; // consider newlines as a repeatable token
|
44
|
+
bool ignore_eos = false;
|
45
|
+
|
46
|
+
std::vector<enum gpt_sampler_type> samplers = {
|
47
|
+
GPT_SAMPLER_TYPE_TOP_K,
|
48
|
+
GPT_SAMPLER_TYPE_TFS_Z,
|
49
|
+
GPT_SAMPLER_TYPE_TYPICAL_P,
|
50
|
+
GPT_SAMPLER_TYPE_TOP_P,
|
51
|
+
GPT_SAMPLER_TYPE_MIN_P,
|
52
|
+
GPT_SAMPLER_TYPE_XTC,
|
53
|
+
GPT_SAMPLER_TYPE_TEMPERATURE
|
56
54
|
};
|
57
55
|
|
58
|
-
std::string grammar;
|
59
|
-
|
60
|
-
// Classifier-Free Guidance
|
61
|
-
// https://arxiv.org/abs/2306.17806
|
62
|
-
std::string cfg_negative_prompt; // string to help guidance
|
63
|
-
float cfg_scale = 1.f; // how strong is guidance
|
64
|
-
|
65
|
-
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
66
|
-
|
67
|
-
std::vector<llama_token> penalty_prompt_tokens;
|
68
|
-
bool use_penalty_prompt_tokens = false;
|
69
|
-
} llama_sampling_params;
|
70
|
-
|
71
|
-
// general sampler context
|
72
|
-
// TODO: move to llama.h
|
73
|
-
struct llama_sampling_context {
|
74
|
-
// parameters that will be used for sampling
|
75
|
-
llama_sampling_params params;
|
76
|
-
|
77
|
-
// mirostat sampler state
|
78
|
-
float mirostat_mu;
|
56
|
+
std::string grammar; // optional BNF-like grammar to constrain sampling
|
79
57
|
|
80
|
-
|
58
|
+
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
81
59
|
|
82
|
-
//
|
83
|
-
|
60
|
+
// print the parameters into a string
|
61
|
+
std::string print() const;
|
62
|
+
};
|
84
63
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
64
|
+
// gpt_sampler extends llama_sampler with additional functionality:
|
65
|
+
//
|
66
|
+
// - grammar support
|
67
|
+
// - custom sampler logic based on the parameters
|
68
|
+
// - history of the last accepted tokens
|
69
|
+
// - performance metrics
|
70
|
+
//
|
71
|
+
// The goal is to have a common implementation of the sampling logic shared across the examples.
|
72
|
+
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
|
73
|
+
// complex (top-k, top-p, etc).
|
74
|
+
//
|
75
|
+
// Another example is related to the grammar. In general, the grammar constraints applied on the full
|
76
|
+
// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
|
77
|
+
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
|
78
|
+
// grammar constraints are applied to the full vocabulary and the token is resampled.
|
79
|
+
//
|
80
|
+
// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
|
81
|
+
// be moved into the core llama library.
|
82
|
+
//
|
83
|
+
// For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
|
84
|
+
// This can be used to access the probabilities of the rest of the non-sampled tokens.
|
85
|
+
//
|
86
|
+
// TODO: measure grammar performance
|
87
|
+
//
|
89
88
|
|
90
|
-
|
91
|
-
};
|
89
|
+
struct gpt_sampler;
|
92
90
|
|
93
|
-
|
91
|
+
// llama_sampler API overloads
|
94
92
|
|
95
|
-
|
96
|
-
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
|
93
|
+
struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
|
97
94
|
|
98
|
-
void
|
95
|
+
void gpt_sampler_free(struct gpt_sampler * gsmpl);
|
99
96
|
|
100
|
-
//
|
101
|
-
|
102
|
-
|
103
|
-
|
97
|
+
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
98
|
+
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
|
99
|
+
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
|
100
|
+
struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
|
104
101
|
|
105
|
-
//
|
106
|
-
void
|
102
|
+
// arguments can be nullptr to skip printing
|
103
|
+
void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl);
|
107
104
|
|
108
|
-
//
|
109
|
-
|
105
|
+
// extended sampling implementation:
|
106
|
+
//
|
107
|
+
// - set logits
|
108
|
+
// - apply the configured sampler chain
|
109
|
+
// - check if the token fits the grammar (if any)
|
110
|
+
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
|
111
|
+
//
|
112
|
+
// if grammar_first is true, the grammar is applied before the samplers (slower)
|
113
|
+
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
|
114
|
+
//
|
115
|
+
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
|
110
116
|
|
111
|
-
//
|
112
|
-
llama_token llama_sampling_last(llama_sampling_context * ctx);
|
117
|
+
// helpers
|
113
118
|
|
114
|
-
//
|
115
|
-
|
119
|
+
// access the internal list of current candidate tokens
|
120
|
+
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
|
116
121
|
|
117
|
-
//
|
118
|
-
|
122
|
+
// get the last accepted token
|
123
|
+
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
|
119
124
|
|
120
|
-
//
|
121
|
-
std::string
|
125
|
+
// print the sampler chain into a string
|
126
|
+
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
|
122
127
|
|
123
|
-
|
128
|
+
// get a string representation of the last accepted tokens
|
129
|
+
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
|
124
130
|
|
125
|
-
|
126
|
-
std::
|
131
|
+
char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
|
132
|
+
std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
|
127
133
|
|
128
|
-
|
129
|
-
|
130
|
-
// Note: When using multiple sequences, it is the caller's responsibility to call
|
131
|
-
// llama_sampling_reset when a sequence ends
|
132
|
-
//
|
133
|
-
// required:
|
134
|
-
// - ctx_main: context to use for sampling
|
135
|
-
// - ctx_sampling: sampling-specific context
|
136
|
-
//
|
137
|
-
// optional:
|
138
|
-
// - ctx_cfg: context to use for classifier-free guidance
|
139
|
-
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
140
|
-
//
|
141
|
-
// returns:
|
142
|
-
// - token: sampled token
|
143
|
-
// - candidates: vector of candidate tokens
|
144
|
-
//
|
145
|
-
llama_token llama_sampling_sample(
|
146
|
-
struct llama_sampling_context * ctx_sampling,
|
147
|
-
struct llama_context * ctx_main,
|
148
|
-
struct llama_context * ctx_cfg,
|
149
|
-
int idx = -1);
|
150
|
-
|
151
|
-
// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
|
152
|
-
llama_token_data_array llama_sampling_prepare(
|
153
|
-
struct llama_sampling_context * ctx_sampling,
|
154
|
-
struct llama_context * ctx_main,
|
155
|
-
struct llama_context * ctx_cfg,
|
156
|
-
int idx = 0,
|
157
|
-
bool apply_grammar = true,
|
158
|
-
std::vector<float> * original_logits = nullptr);
|
159
|
-
|
160
|
-
void llama_sampling_accept(
|
161
|
-
struct llama_sampling_context * ctx_sampling,
|
162
|
-
struct llama_context * ctx_main,
|
163
|
-
llama_token id,
|
164
|
-
bool apply_grammar);
|
134
|
+
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
|
135
|
+
std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);
|
package/cpp/sgemm.cpp
CHANGED
@@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX {
|
|
606
606
|
case 0x44:
|
607
607
|
mc = 4;
|
608
608
|
nc = 4;
|
609
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
610
|
+
gemm4xN<4>(m0, m, n0, n);
|
611
|
+
#else
|
609
612
|
gemm<4, 4>(m0, m, n0, n);
|
613
|
+
#endif
|
610
614
|
break;
|
611
615
|
case 0x43:
|
612
616
|
mc = 4;
|
613
617
|
nc = 3;
|
618
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
619
|
+
gemm4xN<3>(m0, m, n0, n);
|
620
|
+
#else
|
614
621
|
gemm<4, 3>(m0, m, n0, n);
|
622
|
+
#endif
|
615
623
|
break;
|
616
624
|
case 0x34:
|
617
625
|
mc = 3;
|
618
626
|
nc = 4;
|
627
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
628
|
+
gemmMx4<3>(m0, m, n0, n);
|
629
|
+
#else
|
619
630
|
gemm<3, 4>(m0, m, n0, n);
|
631
|
+
#endif
|
620
632
|
break;
|
621
633
|
case 0x33:
|
622
634
|
mc = 3;
|
@@ -626,12 +638,20 @@ class tinyBLAS_Q0_AVX {
|
|
626
638
|
case 0x42:
|
627
639
|
mc = 4;
|
628
640
|
nc = 2;
|
641
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
642
|
+
gemm4xN<2>(m0, m, n0, n);
|
643
|
+
#else
|
629
644
|
gemm<4, 2>(m0, m, n0, n);
|
645
|
+
#endif
|
630
646
|
break;
|
631
647
|
case 0x24:
|
632
648
|
mc = 2;
|
633
649
|
nc = 4;
|
650
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
651
|
+
gemmMx4<2>(m0, m, n0, n);
|
652
|
+
#else
|
634
653
|
gemm<2, 4>(m0, m, n0, n);
|
654
|
+
#endif
|
635
655
|
break;
|
636
656
|
#else
|
637
657
|
case 0x44:
|
@@ -639,13 +659,21 @@ class tinyBLAS_Q0_AVX {
|
|
639
659
|
case 0x42:
|
640
660
|
mc = 4;
|
641
661
|
nc = 2;
|
662
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
663
|
+
gemm4xN<2>(m0, m, n0, n);
|
664
|
+
#else
|
642
665
|
gemm<4, 2>(m0, m, n0, n);
|
666
|
+
#endif
|
643
667
|
break;
|
644
668
|
case 0x34:
|
645
669
|
case 0x24:
|
646
670
|
mc = 2;
|
647
671
|
nc = 4;
|
672
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
673
|
+
gemmMx4<2>(m0, m, n0, n);
|
674
|
+
#else
|
648
675
|
gemm<2, 4>(m0, m, n0, n);
|
676
|
+
#endif
|
649
677
|
break;
|
650
678
|
case 0x33:
|
651
679
|
#endif
|
@@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX {
|
|
662
690
|
case 0x41:
|
663
691
|
mc = 4;
|
664
692
|
nc = 1;
|
693
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
694
|
+
gemm4xN<1>(m0, m, n0, n);
|
695
|
+
#else
|
665
696
|
gemm<4, 1>(m0, m, n0, n);
|
697
|
+
#endif
|
666
698
|
break;
|
667
699
|
case 0x22:
|
668
700
|
mc = 2;
|
@@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX {
|
|
672
704
|
case 0x14:
|
673
705
|
mc = 1;
|
674
706
|
nc = 4;
|
707
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
708
|
+
gemmMx4<1>(m0, m, n0, n);
|
709
|
+
#else
|
675
710
|
gemm<1, 4>(m0, m, n0, n);
|
711
|
+
#endif
|
676
712
|
break;
|
677
713
|
case 0x31:
|
678
714
|
mc = 3;
|
@@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX {
|
|
708
744
|
mnpack(m0, m, np, n);
|
709
745
|
}
|
710
746
|
|
747
|
+
#if defined(__AVX2__) && defined(__F16C__)
|
748
|
+
// Templated functions for gemm of dimensions 4xN
|
749
|
+
template <int RN>
|
750
|
+
NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
751
|
+
int64_t ytiles = (m - m0) / 4;
|
752
|
+
int64_t xtiles = (n - n0) / RN;
|
753
|
+
int64_t tiles = xtiles * ytiles;
|
754
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
755
|
+
int64_t start = duty * ith;
|
756
|
+
int64_t end = start + duty;
|
757
|
+
if (end > tiles)
|
758
|
+
end = tiles;
|
759
|
+
for (int64_t job = start; job < end; ++job) {
|
760
|
+
int64_t ii = m0 + job / xtiles * 4;
|
761
|
+
int64_t jj = n0 + job % xtiles * RN;
|
762
|
+
__m256 Cv[RN][4] = {};
|
763
|
+
for (int64_t l = 0; l < k; ++l) {
|
764
|
+
uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
|
765
|
+
// Convert delta values for four blocks to float values
|
766
|
+
__m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
|
767
|
+
__m256i avec0 = load(A + lda * (ii + 0) + l);
|
768
|
+
__m256i avec1 = load(A + lda * (ii + 1) + l);
|
769
|
+
__m256i avec2 = load(A + lda * (ii + 2) + l);
|
770
|
+
__m256i avec3 = load(A + lda * (ii + 3) + l);
|
771
|
+
for (int64_t j = 0; j < RN; ++j) {
|
772
|
+
__m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
|
773
|
+
// Compute the products of the delta values for the four blocks and replicate them across both 128-bit halves of the 256-bit vector
|
774
|
+
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
|
775
|
+
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
|
776
|
+
// Computation of dot product and multiplication with appropriate delta value products
|
777
|
+
Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
|
778
|
+
updot(_mm256_sign_epi8(avec0, avec0),
|
779
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
|
780
|
+
Cv[j][0]);
|
781
|
+
Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
|
782
|
+
updot(_mm256_sign_epi8(avec1, avec1),
|
783
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
|
784
|
+
Cv[j][1]);
|
785
|
+
Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
|
786
|
+
updot(_mm256_sign_epi8(avec2, avec2),
|
787
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
|
788
|
+
Cv[j][2]);
|
789
|
+
Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
|
790
|
+
updot(_mm256_sign_epi8(avec3, avec3),
|
791
|
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
|
792
|
+
Cv[j][3]);
|
793
|
+
}
|
794
|
+
}
|
795
|
+
|
796
|
+
for (int64_t j = 0; j < RN; ++j)
|
797
|
+
for (int64_t i = 0; i < 4; ++i)
|
798
|
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
799
|
+
}
|
800
|
+
}
|
801
|
+
|
802
|
+
// Templated functions for gemm of dimensions Mx4
|
803
|
+
template <int RM>
|
804
|
+
NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
805
|
+
int64_t ytiles = (m - m0) / RM;
|
806
|
+
int64_t xtiles = (n - n0) / 4;
|
807
|
+
int64_t tiles = xtiles * ytiles;
|
808
|
+
int64_t duty = (tiles + nth - 1) / nth;
|
809
|
+
int64_t start = duty * ith;
|
810
|
+
int64_t end = start + duty;
|
811
|
+
if (end > tiles)
|
812
|
+
end = tiles;
|
813
|
+
for (int64_t job = start; job < end; ++job) {
|
814
|
+
int64_t ii = m0 + job / xtiles * RM;
|
815
|
+
int64_t jj = n0 + job % xtiles * 4;
|
816
|
+
__m256 Cv[4][RM] = {};
|
817
|
+
for (int64_t l = 0; l < k; ++l) {
|
818
|
+
uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
|
819
|
+
// Convert delta values for four blocks to float values
|
820
|
+
__m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
|
821
|
+
__m256i bvec0 = load(B + ldb * (jj + 0) + l);
|
822
|
+
__m256i bvec1 = load(B + ldb * (jj + 1) + l);
|
823
|
+
__m256i bvec2 = load(B + ldb * (jj + 2) + l);
|
824
|
+
__m256i bvec3 = load(B + ldb * (jj + 3) + l);
|
825
|
+
for (int64_t i = 0; i < RM; ++i) {
|
826
|
+
__m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
|
827
|
+
// Compute the products of the delta values for the four blocks and replicate them across both 128-bit halves of the 256-bit vector
|
828
|
+
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
|
829
|
+
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
|
830
|
+
// Computation of dot product and multiplication with appropriate delta value products
|
831
|
+
Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
|
832
|
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
833
|
+
load(A + lda * (ii + i) + l)),
|
834
|
+
_mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
|
835
|
+
Cv[0][i]);
|
836
|
+
Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
|
837
|
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
838
|
+
load(A + lda * (ii + i) + l)),
|
839
|
+
_mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
|
840
|
+
Cv[1][i]);
|
841
|
+
Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
|
842
|
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
843
|
+
load(A + lda * (ii + i) + l)),
|
844
|
+
_mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
|
845
|
+
Cv[2][i]);
|
846
|
+
Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
|
847
|
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
848
|
+
load(A + lda * (ii + i) + l)),
|
849
|
+
_mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
|
850
|
+
Cv[3][i]);
|
851
|
+
}
|
852
|
+
}
|
853
|
+
for (int64_t j = 0; j < 4; ++j)
|
854
|
+
for (int64_t i = 0; i < RM; ++i)
|
855
|
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
856
|
+
}
|
857
|
+
}
|
858
|
+
#endif
|
859
|
+
|
711
860
|
template <int RM, int RN>
|
712
861
|
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
713
862
|
int64_t ytiles = (m - m0) / RM;
|
@@ -857,6 +1006,10 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|
857
1006
|
assert(nth > 0);
|
858
1007
|
assert(ith < nth);
|
859
1008
|
|
1009
|
+
// only enable sgemm for prompt processing
|
1010
|
+
if (n < 2)
|
1011
|
+
return false;
|
1012
|
+
|
860
1013
|
if (Ctype != LM_GGML_TYPE_F32)
|
861
1014
|
return false;
|
862
1015
|
|