llama_cpp 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/ext/llama_cpp/llama_cpp.cpp +93 -15
- data/ext/llama_cpp/src/ggml-cuda.h +2 -0
- data/ext/llama_cpp/src/ggml-opencl.c +85 -122
- data/ext/llama_cpp/src/ggml.c +6268 -4208
- data/ext/llama_cpp/src/ggml.h +205 -12
- data/ext/llama_cpp/src/llama.cpp +159 -79
- data/ext/llama_cpp/src/llama.h +10 -10
- data/lib/llama_cpp/client.rb +1 -3
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -4
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 33b146badd1bebdf9588e48c0adac1f9924a0653aa5ec806fdf5dd288ef665d8
+  data.tar.gz: 134606db2b9fb10b51fc82f410d6653a6481b828d9fd05390b1570d6e198526a
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 462d9e00121408c7af3934b0a663b29f99d5ad28f60a3471155509463bf26a14792c484d1fdc6054460941ae011d39b510774e225ad4ec03d60ce20a1dfef667
+  data.tar.gz: 4bf447ac55bba2b62d204dc975528de6664fe53af89df8ba4aa4172d4dbff709ac5b14a944326be5c71d64baa2cde00b60f7ba5e916e1fb68123c595f74ce24f
data/CHANGELOG.md
CHANGED
@@ -1,5 +1,15 @@
 ## [Unreleased]
 
+## [[0.1.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.1.0...v0.1.1)] - 2023-05-21
+
+- Add load_session_file method to Context
+- Add save_session_file method to Context
+
+**Breaking Changes**
+
+- Bump bundled llama.cpp from master-173d0e6 to master-6986c78
+- bump LLAMA_FILE_VERSION to 2
+
 ## [[0.1.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.0.7...v0.1.0)] - 2023-05-20
 
 **Breaking Changes**
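The two methods added in 0.1.1 are keyword-argument based: load_session_file takes session_path and returns the Array of token ids stored in the file, while save_session_file takes session_path and session_tokens (see the binding code under data/ext/llama_cpp/llama_cpp.cpp below). A minimal usage sketch, assuming `context` is an already initialized LLaMACpp::Context and `tokens` is an Array of Integer token ids that the context has already evaluated; the file name and variable names are placeholders, not part of the package:

# Sketch only: `context` and `tokens` are assumed to exist already.
context.save_session_file(session_path: 'prompt.session', session_tokens: tokens)

# A later run can restore the cached state instead of re-evaluating the prompt;
# the return value is the Array of token ids stored in the session file.
restored_tokens = context.load_session_file(session_path: 'prompt.session')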
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -292,8 +292,6 @@ public:
     // rb_define_method(rb_cLLaMAContextParams, "initialize", RUBY_METHOD_FUNC(_llama_context_params_init), 0);
     rb_define_method(rb_cLLaMAContextParams, "n_ctx=", RUBY_METHOD_FUNC(_llama_context_params_set_n_ctx), 1);
     rb_define_method(rb_cLLaMAContextParams, "n_ctx", RUBY_METHOD_FUNC(_llama_context_params_get_n_ctx), 0);
-    rb_define_method(rb_cLLaMAContextParams, "n_parts=", RUBY_METHOD_FUNC(_llama_context_params_set_n_parts), 1);
-    rb_define_method(rb_cLLaMAContextParams, "n_parts", RUBY_METHOD_FUNC(_llama_context_params_get_n_parts), 0);
     rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
     rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
     rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
@@ -331,18 +329,6 @@ private:
     return INT2NUM(ptr->params.n_ctx);
   };
 
-  // n_parts
-  static VALUE _llama_context_params_set_n_parts(VALUE self, VALUE n_parts) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    ptr->params.n_parts = NUM2INT(n_parts);
-    return INT2NUM(ptr->params.n_parts);
-  };
-
-  static VALUE _llama_context_params_get_n_parts(VALUE self) {
-    LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
-    return INT2NUM(ptr->params.n_parts);
-  };
-
   // seed
   static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
     LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
@@ -494,6 +480,8 @@ public:
     rb_define_method(rb_cLLaMAContext, "apply_lora_from_file", RUBY_METHOD_FUNC(_llama_context_apply_lora_from_file), -1);
     rb_define_method(rb_cLLaMAContext, "kv_cache_token_count", RUBY_METHOD_FUNC(_llama_context_kv_cache_token_count), 0);
     rb_define_method(rb_cLLaMAContext, "set_rng_seed", RUBY_METHOD_FUNC(_llama_context_set_rng_seed), 1);
+    rb_define_method(rb_cLLaMAContext, "load_session_file", RUBY_METHOD_FUNC(_llama_context_load_session_file), -1);
+    rb_define_method(rb_cLLaMAContext, "save_session_file", RUBY_METHOD_FUNC(_llama_context_save_session_file), -1);
     rb_define_method(rb_cLLaMAContext, "sample_repetition_penalty", RUBY_METHOD_FUNC(_llama_context_sample_repetition_penalty), -1);
     rb_define_method(rb_cLLaMAContext, "sample_frequency_and_presence_penalties", RUBY_METHOD_FUNC(_llama_context_sample_frequency_and_presence_penalties), -1);
     rb_define_method(rb_cLLaMAContext, "sample_softmax", RUBY_METHOD_FUNC(_llama_context_sample_softmax), 1);
@@ -870,6 +858,97 @@ private:
     return Qnil;
   };
 
+  static VALUE _llama_context_load_session_file(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[1] = { rb_intern("session_path") };
+    VALUE kw_values[1] = { Qundef };
+    VALUE candidates = Qnil;
+    VALUE last_n_tokens = Qnil;
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);
+
+    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+      rb_raise(rb_eArgError, "session_path must be a String");
+      return Qnil;
+    }
+
+    VALUE filename = kw_values[0];
+
+    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+    if (ctx_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+
+    LLaMAContextParamsWrapper* prms_ptr = RbLLaMAContextParams::get_llama_context_params(rb_iv_get(self, "@params"));
+    const int n_ctx = prms_ptr->params.n_ctx;
+
+    std::vector<llama_token> session_tokens(n_ctx);
+    size_t n_token_count_out = 0;
+
+    try {
+      bool res = llama_load_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), session_tokens.capacity(), &n_token_count_out);
+      if (!res) {
+        rb_raise(rb_eRuntimeError, "Failed to load session file");
+        return Qnil;
+      }
+      session_tokens.resize(n_token_count_out);
+    } catch (const std::runtime_error& e) {
+      rb_raise(rb_eRuntimeError, "%s", e.what());
+      return Qnil;
+    }
+
+    VALUE ary_session_tokens = rb_ary_new2(n_token_count_out);
+    for (size_t i = 0; i < n_token_count_out; i++) {
+      rb_ary_store(ary_session_tokens, i, INT2NUM(session_tokens[i]));
+    }
+
+    RB_GC_GUARD(filename);
+    return ary_session_tokens;
+  }
+
+  static VALUE _llama_context_save_session_file(int argc, VALUE* argv, VALUE self) {
+    VALUE kw_args = Qnil;
+    ID kw_table[2] = { rb_intern("session_path"), rb_intern("session_tokens") };
+    VALUE kw_values[2] = { Qundef, Qundef };
+    VALUE candidates = Qnil;
+    VALUE last_n_tokens = Qnil;
+    rb_scan_args(argc, argv, ":", &kw_args);
+    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+
+    if (!RB_TYPE_P(kw_values[0], T_STRING)) {
+      rb_raise(rb_eArgError, "session_path must be a String");
+      return Qnil;
+    }
+    if (!RB_TYPE_P(kw_values[1], T_ARRAY)) {
+      rb_raise(rb_eArgError, "session_tokens must be an Array");
+      return Qnil;
+    }
+
+    VALUE filename = kw_values[0];
+    const size_t sz_session_tokens = RARRAY_LEN(kw_values[1]);
+    std::vector<llama_token> session_tokens(sz_session_tokens);
+    for (size_t i = 0; i < sz_session_tokens; i++) {
+      session_tokens[i] = NUM2INT(rb_ary_entry(kw_values[1], i));
+    }
+
+    LLaMAContextWrapper* ctx_ptr = get_llama_context(self);
+    if (ctx_ptr->ctx == NULL) {
+      rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
+      return Qnil;
+    }
+
+    bool res = llama_save_session_file(ctx_ptr->ctx, StringValueCStr(filename), session_tokens.data(), sz_session_tokens);
+
+    if (!res) {
+      rb_raise(rb_eRuntimeError, "Failed to save session file");
+      return Qnil;
+    }
+
+    RB_GC_GUARD(filename);
+    return Qnil;
+  }
+
   static VALUE _llama_context_sample_repetition_penalty(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
     ID kw_table[1] = { rb_intern("penalty") };
@@ -1411,7 +1490,6 @@ extern "C" void Init_llama_cpp(void) {
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_0));
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1));
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16));
- rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q4_2", INT2NUM(LLAMA_FTYPE_MOSTLY_Q4_2));
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q8_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q8_0));
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_0", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_0));
  rb_define_const(rb_mLLaMACpp, "LLAMA_FTYPE_MOSTLY_Q5_1", INT2NUM(LLAMA_FTYPE_MOSTLY_Q5_1));
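Because ContextParams#n_parts= / #n_parts and the LLAMA_FTYPE_MOSTLY_Q4_2 constant are removed above, code written against 0.1.0 may need a small adjustment when upgrading. A hypothetical sketch (the n_ctx= and seed= accessors registered in the binding remain available):

params = LLaMACpp::ContextParams.new
params.n_ctx = 512
params.seed  = 42
# params.n_parts = 1  # removed in 0.1.1; ContextParams no longer exposes n_parts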
data/ext/llama_cpp/src/ggml-cuda.h
CHANGED
@@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void ggml_cuda_host_free(void * ptr);
 
+void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
+
 #ifdef __cplusplus
 }
 #endif
data/ext/llama_cpp/src/ggml-opencl.c
CHANGED
@@ -12,129 +12,129 @@
 #define MULTILINE_QUOTE(...) #__VA_ARGS__
 const char * clblast_dequant = MULTILINE_QUOTE(

+typedef uchar uint8_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+constant uint QK4_0 = 32;
 struct block_q4_0
 {
     float d;
-
+    uint8_t qs[QK4_0 / 2];
 };

-
-
-
-
-
+constant uint QK4_1 = 32;
+struct block_q4_1
+{
+    float d;
+    float m;
+    uint8_t qs[QK4_1 / 2];
+};

-
+constant uint QK5_0 = 32;
+struct __attribute__ ((packed)) block_q5_0
+{
+    half d;
+    uint32_t qh;
+    uint8_t qs[QK5_0 / 2];
+};

-
-
-
-
+constant uint QK5_1 = 32;
+struct block_q5_1
+{
+    half d;
+    half m;
+    uint32_t qh;
+    uint8_t qs[QK5_1 / 2];
+};

-
+constant uint QK8_0 = 32;
+struct block_q8_0
 {
     float d;
-
-    uchar qs[16];
+    uint8_t qs[QK8_0];
 };

-__kernel void dequantize_row_q4_1(__global struct block_q4_1* blocks, __global float* result) {
-    const uint i = get_global_id(0) / 32;
-    const uint l = get_local_id(0);

-
-
+__kernel void dequantize_row_q4_0(__global struct block_q4_0* x, __global float* y) {
+    constant uint qk = QK4_0;
+
+    const uint i = get_global_id(0) / qk;
+    const uint j = get_local_id(0);
+
+    const float d = x[i].d;

-    const
+    const int x0 = (x[i].qs[j] & 0xf) - 8;
+    const int x1 = (x[i].qs[j] >> 4) - 8;

-
-
-    result[index + 1] = (vi >> 4) * d + m;
+    y[i*qk + j + 0 ] = x0*d;
+    y[i*qk + j + qk/2] = x1*d;
 }

-struct
-
-    ushort d;
-    uchar qs[8];
-};
+__kernel void dequantize_row_q4_1(__global struct block_q4_1* x, __global float* y) {
+    constant uint qk = QK4_1;

-
-    const uint
-    const uint l = get_local_id(0);
+    const uint i = get_global_id(0) / qk;
+    const uint j = get_local_id(0);

-    const float d =
+    const float d = x[i].d;
+    const float m = x[i].m;

-    const
+    const int x0 = (x[i].qs[j] & 0xf);
+    const int x1 = (x[i].qs[j] >> 4);

-
-
-    result[index + 1] = ((vi >> 4) - 8)*d;
+    y[i*qk + j + 0 ] = x0*d + m;
+    y[i*qk + j + qk/2] = x1*d + m;
 }

+__kernel void dequantize_row_q5_0(__global struct block_q5_0* x, __global float* y) {
+    constant uint qk = QK5_0;

-
-
-    float d;
-    uint qh;
-    uchar qs[16];
-};
+    const uint i = get_global_id(0) / qk;
+    const uint j = get_local_id(0);

-
-    const uint i = get_global_id(0) / 32;
-    const uint l = get_local_id(0);
+    const float d = vload_half(0, (__global half*) &x[i].d);

-
+    uint32_t qh = x[i].qh;

-    const
+    const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+    const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;

-    const
+    const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
+    const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;

-
-
-
-    const uint index = i*32 + l2;
-    result[index + 0] = (((vi & 0xf) | vh0) - 16)*d;
-    result[index + 1] = (((vi >> 4) | vh1) - 16)*d;
+    y[i*qk + j + 0 ] = x0*d;
+    y[i*qk + j + qk/2] = x1*d;
 }

-struct block_q5_1
-
-    ushort d;
-    ushort m;
-    uint qh;
-    uchar qs[16];
-};
+__kernel void dequantize_row_q5_1(__global struct block_q5_1* x, __global float* y) {
+    constant uint qk = QK5_1;

-
-    const uint
-    const uint l = get_local_id(0);
+    const uint i = get_global_id(0) / qk;
+    const uint j = get_local_id(0);

-    const float d = vload_half(0, (__global half*) &
-    const float m = vload_half(0, (__global half*) &
+    const float d = vload_half(0, (__global half*) &x[i].d);
+    const float m = vload_half(0, (__global half*) &x[i].m);

-
+    uint32_t qh = x[i].qh;

-    const
+    const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+    const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;

-    const
-    const
+    const int x0 = (x[i].qs[j] & 0xf) | xh_0;
+    const int x1 = (x[i].qs[j] >> 4) | xh_1;

-
-
-    result[index + 1] = ((vi >> 4) | vh1)*d + m;
+    y[i*qk + j + 0 ] = x0*d + m;
+    y[i*qk + j + qk/2] = x1*d + m;
 }

-struct block_q8_0
-
-
-
-};
-
-__kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global float* result) {
-    const uint i = get_global_id(0) / 32;
-    const uint l = get_local_id(0);
+__kernel void dequantize_row_q8_0(__global struct block_q8_0* x, __global float* y) {
+    constant uint qk = QK8_0;
+    const uint i = get_global_id(0) / qk;
+    const uint j = get_local_id(0);

-
+    const float d = x[i].d;
+    y[i*qk + j] = x[i].qs[j]*d;
 }

 );
@@ -148,26 +148,12 @@ __kernel void dequantize_row_q8_0(__global struct block_q8_0* blocks, __global f
 } \
 } while (0)
 
-#define QK5_0 32
-typedef struct {
-    ggml_fp16_t d;          // delta
-    uint8_t qh[4];          // 5-th bit of quants
-    uint8_t qs[QK5_0 / 2];  // nibbles / quants
-} block_q5_0;
-
-
-typedef struct {
-    float d;                // delta
-    uint32_t qh;            // 5-th bit of quants
-    uint8_t qs[QK5_0 / 2];  // nibbles / quants
-} cl_block_q5_0;
-
 static cl_platform_id platform;
 static cl_device_id device;
 static cl_context context;
 static cl_command_queue queue;
 static cl_program program;
-static cl_kernel kernel_q4_0, kernel_q4_1,
+static cl_kernel kernel_q4_0, kernel_q4_1, kernel_q5_0, kernel_q5_1, kernel_q8_0;
 static cl_mem cl_buffer_a, cl_buffer_qb, cl_buffer_b, cl_buffer_c;
 static size_t cl_size_a = 0, cl_size_qb = 0, cl_size_b = 0, cl_size_c = 0;
 
@@ -238,8 +224,6 @@ void ggml_cl_init(void) {
     CL_CHECK(err, "clCreateKernel");
     kernel_q4_1 = clCreateKernel(program, "dequantize_row_q4_1", &err);
     CL_CHECK(err, "clCreateKernel");
-    kernel_q4_2 = clCreateKernel(program, "dequantize_row_q4_2", &err);
-    CL_CHECK(err, "clCreateKernel");
     kernel_q5_0 = clCreateKernel(program, "dequantize_row_q5_0", &err);
     CL_CHECK(err, "clCreateKernel");
     kernel_q5_1 = clCreateKernel(program, "dequantize_row_q5_1", &err);
@@ -274,7 +258,6 @@ void ggml_cl_sgemm_wrapper(
     cl_kernel kernel;
     size_t global = n * k, local, size_qb;
     bool dequant;
-    cl_block_q5_0* cl_host_b;
 
     switch (btype) {
         case GGML_TYPE_F32:
@@ -292,28 +275,11 @@ void ggml_cl_sgemm_wrapper(
             local = 16;
             size_qb = global * (sizeof(float) * 2 + local) / 32;
             break;
-        case GGML_TYPE_Q4_2:
-            dequant = true;
-            kernel = kernel_q4_2;
-            local = 8;
-            size_qb = global * (sizeof(ggml_fp16_t) + local) / 16;
-            break;
         case GGML_TYPE_Q5_0:
            dequant = true;
            kernel = kernel_q5_0;
            local = 16;
-
-            // 20 and 24 bytes are fine. Workaround to do the fp16 to fp32 step on CPU...
-            // TODO Find the reason, fix and remove workaround.
-            const block_q5_0* b = (const block_q5_0*) host_b;
-            cl_host_b = (cl_block_q5_0*) malloc(sizeof(cl_block_q5_0) * global / 32);
-            for (size_t i = 0; i < global / 32; i++) {
-                cl_host_b[i].d = ggml_fp16_to_fp32(b[i].d);
-                memcpy(&cl_host_b[i].qh, b[i].qh, sizeof(uint32_t));
-                memcpy(&cl_host_b[i].qs, b[i].qs, QK5_0 / 2);
-            }
-            host_b = (const float*) cl_host_b;
-            size_qb = global * (sizeof(float) + sizeof(uint32_t) + local) / 32;
+            size_qb = global * (sizeof(ggml_fp16_t) + sizeof(uint32_t) + local) / 32;
            break;
        case GGML_TYPE_Q5_1:
            dequant = true;
@@ -392,7 +358,4 @@ void ggml_cl_sgemm_wrapper(
     clWaitForEvents(1, &ev_c);
     clReleaseEvent(ev_sgemm);
     clReleaseEvent(ev_c);
-    if (btype == GGML_TYPE_Q5_0) {
-        free((void*) cl_host_b);
-    }
 }