llama_cpp 0.7.1 → 0.8.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/ext/llama_cpp/llama_cpp.cpp +41 -21
- data/ext/llama_cpp/src/ggml-metal.m +44 -3
- data/ext/llama_cpp/src/ggml-metal.metal +162 -1
- data/ext/llama_cpp/src/ggml-opencl.cpp +30 -56
- data/ext/llama_cpp/src/ggml.c +13 -9
- data/ext/llama_cpp/src/ggml.h +3 -2
- data/ext/llama_cpp/src/k_quants.c +12 -20
- data/ext/llama_cpp/src/llama.cpp +359 -58
- data/ext/llama_cpp/src/llama.h +18 -12
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +4 -4
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 8045208b5f7801979212a4f6ed395217e78f06bcfbc2d0362aaaa04c529745cd
+  data.tar.gz: 4011dfe279d8d4041c6c79dc5a6bad199777f83b5f0559f11ccd2f68c957e462
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: d15e74da491773961006eca8ca6c6d80b30ffc995c56a9140961be0002eb09134f1a029c4e8ee192497fb7256fe36cf1c3ed928967ce57ece4c7a0904392c8fe
+  data.tar.gz: a863596304ddb9ac5e4be2b2b65bebc7d3913705b8a0f516debfee0ca213f9dca69707edda8d70cfafb15500fcb6e70cffb6d5d1119302d24e05059c50f0da77
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,11 @@
+## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
+
+**Breaking Changes**
+- Bump bundled llama.cpp from b1380 to b1405.
+- Add column index argument to `set_seq_id` and `get_seq_id` methods in Batch.
+- Add `special` keyword argument to `tokenize` method in Model.
+- Add `n_seq_max` keyword argument to `initialize` method in Batch.
+
 ## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14
 
 - Bump bundled llama.cpp from b1334 to b1380.
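The three Batch-related breaking changes above show up together in user code. A minimal usage sketch, assuming the classes are exposed as `LLaMACpp::Batch` as in the extension bindings below; the sizes and ids here are illustrative only:

```ruby
require 'llama_cpp'

# 0.8.0: n_seq_max is a new required keyword; it is forwarded to llama_batch_init.
batch = LLaMACpp::Batch.new(n_tokens: 512, embd: 0, n_seq_max: 1)

# set_seq_id/get_seq_id now take two indices: i selects the token position
# (validated against n_tokens) and j selects the seq-id slot for that token
# (validated against that token's n_seq_id), followed by the value itself.
batch.set_seq_id(0, 0, 1)
batch.get_seq_id(0, 0) # => 1
```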
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -63,8 +63,8 @@ public:
     rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
     rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
     rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
-    rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id),
-    rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id),
+    rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
+    rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
     rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
     rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
   }
@@ -74,10 +74,10 @@ private:

   static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[
-    VALUE kw_values[
+    ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
+    VALUE kw_values[3] = { Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table,
+    rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

     if (!RB_INTEGER_TYPE_P(kw_values[0])) {
       rb_raise(rb_eArgError, "n_tokens must be an integer");
@@ -87,12 +87,17 @@ private:
       rb_raise(rb_eArgError, "embd must be an integer");
       return Qnil;
     }
+    if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+      rb_raise(rb_eArgError, "n_seq_max must be an integer");
+      return Qnil;
+    }

     const int32_t n_tokens = NUM2INT(kw_values[0]);
     const int32_t embd = NUM2INT(kw_values[1]);
+    const int32_t n_seq_max = NUM2INT(kw_values[2]);

     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    ptr->batch = llama_batch_init(n_tokens, embd);
+    ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);

     return Qnil;
   }
@@ -190,25 +195,35 @@ private:
   }

   // seq_id
-  static VALUE _llama_batch_set_seq_id(VALUE self, VALUE
+  static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    const int32_t
-    if (
-    rb_raise(rb_eArgError, "
+    const int32_t i = NUM2INT(i_);
+    if (i < 0 || i >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+      return Qnil;
+    }
+    const int32_t j = NUM2INT(j_);
+    if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+      rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
       return Qnil;
     }
-    ptr->batch.seq_id[
-    return INT2NUM(ptr->batch.seq_id[
+    ptr->batch.seq_id[i][j] = NUM2INT(value);
+    return INT2NUM(ptr->batch.seq_id[i][j]);
   }

-  static VALUE _llama_batch_get_seq_id(VALUE self, VALUE
+  static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
     LLaMABatchWrapper* ptr = get_llama_batch(self);
-    const int32_t
-    if (
-    rb_raise(rb_eArgError, "
+    const int32_t i = NUM2INT(i_);
+    if (i < 0 || i >= ptr->batch.n_tokens) {
+      rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+      return Qnil;
+    }
+    const int32_t j = NUM2INT(j_);
+    if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+      rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
       return Qnil;
     }
-    return INT2NUM(ptr->batch.seq_id[
+    return INT2NUM(ptr->batch.seq_id[i][j]);
   }

   // logits
@@ -1319,10 +1334,10 @@ private:

   static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
     VALUE kw_args = Qnil;
-    ID kw_table[
-    VALUE kw_values[
+    ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
+    VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
     rb_scan_args(argc, argv, ":", &kw_args);
-    rb_get_kwargs(kw_args, kw_table, 1,
+    rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);

     if (!RB_TYPE_P(kw_values[0], T_STRING)) {
       rb_raise(rb_eArgError, "text must be a String");
@@ -1336,15 +1351,20 @@ private:
       rb_raise(rb_eArgError, "add_bos must be a boolean");
       return Qnil;
     }
+    if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
+      rb_raise(rb_eArgError, "special must be a boolean");
+      return Qnil;
+    }

     VALUE text_ = kw_values[0];
     std::string text = StringValueCStr(text_);
     const bool add_bos = kw_values[2] == Qtrue ? true : false;
+    const bool special = kw_values[3] == Qtrue ? true : false;
     const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);

     llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
     LLaMAModelWrapper* ptr = get_llama_model(self);
-    const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
+    const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);

     if (n_tokens < 0) {
       rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
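On the Ruby side, the new `special` flag is simply forwarded as the last argument of `llama_tokenize`, so special/control tokens in the input can be parsed as tokens rather than plain text. A hedged sketch, assuming `model` is an already-loaded `LLaMACpp::Model` instance (the text and `n_max_tokens` values are placeholders):

```ruby
# Keyword names follow the kw_table in the extension code above;
# special defaults to false when omitted.
tokens = model.tokenize(text: 'Hello </s>', n_max_tokens: 16, add_bos: true, special: true)
```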
data/ext/llama_cpp/src/ggml-metal.m
CHANGED
@@ -73,6 +73,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_1);
     GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -87,6 +89,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
     GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
@@ -97,6 +101,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -254,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_1);
         GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -268,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
         GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
@@ -278,8 +288,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
@@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_1);
    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
     GGML_METAL_DEL_KERNEL(get_rows_q2_K);
     GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
     GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
@@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
@@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute(
                     case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                     case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                    case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+                    case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
                     case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
                     case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
                     case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute(
                             nth1 = 8;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                         } break;
+                    case GGML_TYPE_Q5_0:
+                        {
+                            GGML_ASSERT(ne02 == 1);
+                            GGML_ASSERT(ne12 == 1);
+
+                            nth0 = 8;
+                            nth1 = 8;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        {
+                            GGML_ASSERT(ne02 == 1);
+                            GGML_ASSERT(ne12 == 1);
+
+                            nth0 = 8;
+                            nth1 = 8;
+                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+                        } break;
                     case GGML_TYPE_Q8_0:
                         {
                             GGML_ASSERT(ne02 == 1);
@@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute(
                 [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
                 [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

-                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                    src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
@@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute(
                 case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+                case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
                 case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
data/ext/llama_cpp/src/ggml-metal.metal
CHANGED
@@ -18,6 +18,21 @@ typedef struct {
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;

+#define QK5_0 32
+typedef struct {
+    half d;                // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
 #define QK8_0 32
 typedef struct {
     half d; // delta
@@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
+
     float2 acc = 0.f;
+
     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-
+
     float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1]) + sumy * m;
 }

+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }

+kernel void kernel_mul_mv_q5_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
 #define NB_Q8_0 8

 kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     }
 }

+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
 template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
data/ext/llama_cpp/src/ggml-opencl.cpp
CHANGED
@@ -1395,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
-    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
     const int64_t ne12 = src1->ne[2];
     const int64_t ne13 = src1->ne[3];
-    const int64_t nb10 = src1->nb[0];
     const int nb2 = dst->nb[2];
     const int nb3 = dst->nb[3];
     size_t x_size;
     size_t d_size;

-    cl_mem d_X = ggml_cl_pool_malloc(
+    cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
     cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
-    cl_mem d_D = ggml_cl_pool_malloc(
+    cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst


     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const int i0 = i03*ne02 + i02;
-
             cl_event ev;

             // copy src0 to device
-            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
-            } else {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const int64_t i13 = i03%ne13;
-                    const int64_t i12 = i02%ne12;
-                    const int64_t i11 = i01%ne11;
-                    const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
-
-                    cl_int x_offset = i01*ne00;
-                    cl_int y_offset = i1*ne10;
-                    cl_int d_offset = i01*ne00;
-
-                    // compute
-                    size_t global = ne00;
-                    cl_int ky = ne10;
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
-                    CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
-                }
-            }
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
+
+            const int64_t i13 = i03%ne13;
+            const int64_t i12 = i02%ne12;
+            const int i1 = i13*ne12*ne11 + i12*ne11;
+
+            cl_int x_offset = 0;
+            cl_int y_offset = i1*ne10;
+            cl_int d_offset = 0;
+
+            size_t global = ne00 * ne01;
+            cl_int ky = ne10 * ne11;
+
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+            CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));

             CL_CHECK(clReleaseEvent(ev));
             CL_CHECK(clFinish(queue));
@@ -1568,7 +1539,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_cl_pool_free(d_D, d_size);
 }

-static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t
+static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
     GGML_ASSERT(fp16_support);

     const int64_t ne00 = src0->ne[0];
@@ -1598,6 +1569,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;

+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne);
+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne);
+    ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
+
     size_t x_size;
     size_t y_size;
     size_t d_size;
@@ -1634,7 +1609,6 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr

             // convert src1 to fp16
             // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
             char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
             if (src1_cont_rows) {
                 if (src1_cont_cols) {
@@ -1897,8 +1871,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
 }

 size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
-        return
+    if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+        return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]);
     }
     return 0;
 }