llama_cpp 0.7.1 → 0.8.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +41 -21
- data/ext/llama_cpp/src/ggml-metal.m +44 -3
- data/ext/llama_cpp/src/ggml-metal.metal +162 -1
- data/ext/llama_cpp/src/ggml-opencl.cpp +30 -56
- data/ext/llama_cpp/src/ggml.c +13 -9
- data/ext/llama_cpp/src/ggml.h +3 -2
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/llama.cpp +359 -58
- data/ext/llama_cpp/src/llama.h +18 -12
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -4
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 8045208b5f7801979212a4f6ed395217e78f06bcfbc2d0362aaaa04c529745cd
+  data.tar.gz: 4011dfe279d8d4041c6c79dc5a6bad199777f83b5f0559f11ccd2f68c957e462
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: d15e74da491773961006eca8ca6c6d80b30ffc995c56a9140961be0002eb09134f1a029c4e8ee192497fb7256fe36cf1c3ed928967ce57ece4c7a0904392c8fe
+  data.tar.gz: a863596304ddb9ac5e4be2b2b65bebc7d3913705b8a0f516debfee0ca213f9dca69707edda8d70cfafb15500fcb6e70cffb6d5d1119302d24e05059c50f0da77
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,11 @@
+## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
+
+**Breaking Changes**
+- Bump bundled llama.cpp from b1380 to b1405
+- Add column index argument to `set_seq_id` and `get_seq_id` methods in Batch.
+- Add `special` keyword argument to `tokenize` method in Model.
+- Add `n_seq_max` keyword argument to `initialize` method in Batch.
+
 ## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14
 
 - Bump bundled llama.cpp from b1334 to b1380.
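These breaking changes track signature changes in the bundled llama.cpp C API, which the llama_cpp.cpp hunks below forward to. A minimal, hypothetical C++ sketch of the new underlying calls (placeholder values only, not the gem's own code):

```cpp
#include "llama.h"

void batch_sketch() {
    // Batch.new(n_tokens:, embd:, n_seq_max:) now forwards to the three-argument init.
    llama_batch batch = llama_batch_init(/*n_tokens=*/512, /*embd=*/0, /*n_seq_max=*/1);

    // seq_id is an array per token in this llama.cpp revision, hence the extra
    // column index in Batch#set_seq_id / Batch#get_seq_id.
    batch.n_tokens     = 1;
    batch.token[0]     = 1;     // some token id
    batch.pos[0]       = 0;
    batch.n_seq_id[0]  = 1;
    batch.seq_id[0][0] = 0;     // Ruby: batch.set_seq_id(0, 0, 0)
    batch.logits[0]    = true;

    llama_batch_free(batch);
}
```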
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -63,8 +63,8 @@ public:
     rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
     rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
     rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
-    rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
-    rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
+    rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
+    rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
     rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
     rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
   }
@@ -74,10 +74,10 @@ private:
 
   static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
-    VALUE kw_values[2] = { Qundef, Qundef };
+    ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+    rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);
 
     if (!RB_INTEGER_TYPE_P(kw_values[0])) {
       rb_raise(rb_eArgError, "n_tokens must be an integer");
@@ -87,12 +87,17 @@ private:
       rb_raise(rb_eArgError, "embd must be an integer");
       return Qnil;
     }
+    if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+      rb_raise(rb_eArgError, "n_seq_max must be an integer");
+      return Qnil;
+    }
 
     const int32_t n_tokens = NUM2INT(kw_values[0]);
     const int32_t embd = NUM2INT(kw_values[1]);
+    const int32_t n_seq_max = NUM2INT(kw_values[2]);
 
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    ptr->batch = llama_batch_init(n_tokens, embd);
+    ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);
 
     return Qnil;
   }
@@ -190,25 +195,35 @@ private:
   }
 
   // seq_id
-  static VALUE _llama_batch_set_seq_id(VALUE self, VALUE
+  static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    const int32_t
-    if (
-      rb_raise(rb_eArgError, "
+    const int32_t i = NUM2INT(i_);
+    if (i < 0 || i >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+      return Qnil;
+    }
+    const int32_t j = NUM2INT(j_);
+    if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+      rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
       return Qnil;
     }
-    ptr->batch.seq_id[
-    return INT2NUM(ptr->batch.seq_id[
+    ptr->batch.seq_id[i][j] = NUM2INT(value);
+    return INT2NUM(ptr->batch.seq_id[i][j]);
   }
 
-  static VALUE _llama_batch_get_seq_id(VALUE self, VALUE
+  static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    const int32_t
-    if (
-      rb_raise(rb_eArgError, "
+    const int32_t i = NUM2INT(i_);
+    if (i < 0 || i >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+      return Qnil;
+    }
+    const int32_t j = NUM2INT(j_);
+    if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+      rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
       return Qnil;
     }
-    return INT2NUM(ptr->batch.seq_id[
+    return INT2NUM(ptr->batch.seq_id[i][j]);
  }
 
   // logits
@@ -1319,10 +1334,10 @@ private:
 
   static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
-    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+    ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
+    VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+    rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);
 
     if (!RB_TYPE_P(kw_values[0], T_STRING)) {
       rb_raise(rb_eArgError, "text must be a String");
@@ -1336,15 +1351,20 @@ private:
       rb_raise(rb_eArgError, "add_bos must be a boolean");
       return Qnil;
     }
+    if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
+      rb_raise(rb_eArgError, "special must be a boolean");
+      return Qnil;
+    }
 
     VALUE text_ = kw_values[0];
     std::string text = StringValueCStr(text_);
     const bool add_bos = kw_values[2] == Qtrue ? true : false;
+    const bool special = kw_values[3] == Qtrue ? true : false;
     const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);
 
     llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
     LLaMAModelWrapper* ptr = get_llama_model(self);
-    const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
+    const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);
 
     if (n_tokens < 0) {
       rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
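The new `special` flag is forwarded to `llama_tokenize`, which reports an undersized buffer through a negative return value (the error message above alludes to this). A hypothetical retry pattern in C++, shown only to clarify the call, not part of the extension:

```cpp
#include "llama.h"
#include <string>
#include <vector>

std::vector<llama_token> tokenize_sketch(const llama_model * model, const std::string & text,
                                         bool add_bos, bool special) {
    std::vector<llama_token> tokens(text.size() + (add_bos ? 1 : 0));
    int n = llama_tokenize(model, text.c_str(), (int) text.size(),
                           tokens.data(), (int) tokens.size(), add_bos, special);
    if (n < 0) {               // -n is the number of tokens actually required
        tokens.resize(-n);
        n = llama_tokenize(model, text.c_str(), (int) text.size(),
                           tokens.data(), (int) tokens.size(), add_bos, special);
    }
    tokens.resize(n);
    return tokens;
}
```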
data/ext/llama_cpp/src/ggml-metal.m
CHANGED
@@ -73,6 +73,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_1);
     GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -87,6 +89,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
     GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
@@ -97,6 +101,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -254,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_1);
         GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -268,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
         GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
@@ -278,8 +288,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
@@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_1);
     GGML_METAL_DEL_KERNEL(get_rows_q8_0);
     GGML_METAL_DEL_KERNEL(get_rows_q2_K);
     GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
     GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
@@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
@@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute(
                         case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                         case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                        case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+                        case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
                        case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
                         case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
                         case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute(
                             nth1 = 8;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                         } break;
+                    case GGML_TYPE_Q5_0:
+                        {
+                            GGML_ASSERT(ne02 == 1);
+                            GGML_ASSERT(ne12 == 1);
+
+                            nth0 = 8;
+                            nth1 = 8;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        {
+                            GGML_ASSERT(ne02 == 1);
+                            GGML_ASSERT(ne12 == 1);
+
+                            nth0 = 8;
+                            nth1 = 8;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+                        } break;
                     case GGML_TYPE_Q8_0:
                         {
                             GGML_ASSERT(ne02 == 1);
@@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute(
                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
 
-                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                    src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
@@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute(
                 case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+                case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
                 case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
data/ext/llama_cpp/src/ggml-metal.metal
CHANGED
@@ -18,6 +18,21 @@ typedef struct {
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 
+#define QK5_0 32
+typedef struct {
+    half d;                // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
 #define QK8_0 32
 typedef struct {
     half d; // delta
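The two new block types pack each 32-weight block as 4-bit nibbles in `qs` plus one extra high bit per weight collected in the 32-bit `qh` field (Q5_1 additionally stores a `min`). A small C-style sketch of how one Q5_0 block expands to floats, mirroring the `dequantize_q5_0` kernel added further down (illustration only, not the gem's code; converting the fp16 `d` field to float is omitted):

```cpp
#include <stdint.h>
#include <string.h>

static void dequantize_block_q5_0_sketch(const uint8_t qs[16], const uint8_t qh_bytes[4],
                                         float d, float out[32]) {
    uint32_t qh;
    memcpy(&qh, qh_bytes, sizeof(qh));          // the 32 "5-th bits", one per weight

    for (int j = 0; j < 16; ++j) {
        const uint8_t xh_0 = ((qh >> (j +  0)) << 4) & 0x10;  // high bit of weight j
        const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;  // high bit of weight j + 16

        const int32_t x0 = ((qs[j] & 0x0F) | xh_0) - 16;      // 5-bit value, centered on zero
        const int32_t x1 = ((qs[j] >>   4) | xh_1) - 16;

        out[j]      = x0 * d;
        out[j + 16] = x1 * d;
    }
}
```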
@@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
+
     float2 acc = 0.f;
+
     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1]) + sumy * m;
 }
 
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mv_q5_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
 #define NB_Q8_0 8
 
 kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     }
 }
 
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
 template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
data/ext/llama_cpp/src/ggml-opencl.cpp
CHANGED
@@ -1395,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
-    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
     const int64_t ne12 = src1->ne[2];
     const int64_t ne13 = src1->ne[3];
-    const int64_t nb10 = src1->nb[0];
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
     size_t x_size;
     size_t d_size;
 
-    cl_mem d_X = ggml_cl_pool_malloc(
+    cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
     cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
-    cl_mem d_D = ggml_cl_pool_malloc(
+    cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst
 
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const int i0 = i03*ne02 + i02;
-
             cl_event ev;
 
             // copy src0 to device
-            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X,
-                CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
-            } else {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const int64_t i13 = i03%ne13;
-                    const int64_t i12 = i02%ne12;
-                    const int64_t i11 = i01%ne11;
-                    const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
-
-                    cl_int x_offset = i01*ne00;
-                    cl_int y_offset = i1*ne10;
-                    cl_int d_offset = i01*ne00;
-
-                    // compute
-                    size_t global = ne00;
-                    cl_int ky = ne10;
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
-                    CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
-                }
-            }
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
+
+            const int64_t i13 = i03%ne13;
+            const int64_t i12 = i02%ne12;
+            const int i1 = i13*ne12*ne11 + i12*ne11;
+
+            cl_int x_offset = 0;
+            cl_int y_offset = i1*ne10;
+            cl_int d_offset = 0;
+
+            size_t global = ne00 * ne01;
+            cl_int ky = ne10 * ne11;
+
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+            CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
 
             CL_CHECK(clReleaseEvent(ev));
             CL_CHECK(clFinish(queue));
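`ggml_cl_mul_f32` now issues a single kernel launch per (i03, i02) slice, covering `ne00 * ne01` elements at once instead of looping over rows on the host. A reference sketch of what that one enqueue computes, under the assumption that the `mul_f32` kernel indexes `src1` modulo `ky` (which is what the flattened `global`/`ky` arguments imply):

```cpp
#include <cstdint>

void mul_f32_slice_reference(const float * x,      // src0 slice: ne01 rows of ne00 floats
                             const float * y,      // src1 data starting at y_offset
                             float * dst,          // output slice, same shape as x
                             int64_t ne00, int64_t ne01, int64_t ky) {
    for (int64_t i = 0; i < ne00 * ne01; ++i) {
        dst[i] = x[i] * y[i % ky];                 // row-wise broadcast of src1
    }
}
```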
@@ -1568,7 +1539,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_cl_pool_free(d_D, d_size);
 }
 
-static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t
+static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
     GGML_ASSERT(fp16_support);
 
     const int64_t ne00 = src0->ne[0];
@@ -1598,6 +1569,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne);
+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne);
+    ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
+
     size_t x_size;
     size_t y_size;
     size_t d_size;
@@ -1634,7 +1609,6 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
 
         // convert src1 to fp16
         // TODO: use multiple threads
-        ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
         char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
         if (src1_cont_rows) {
             if (src1_cont_cols) {
@@ -1897,8 +1871,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
 }
 
 size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
-        return
+    if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+        return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]);
     }
     return 0;
 }