llama_cpp 0.7.1 → 0.8.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 6688a7296f7a7e7ba4aa593b2d9b792beb1d569f7f2e0e872e1dbda64a336b57
- data.tar.gz: 3f683714c3b11b8f247d9ef40774b90e297c25f3bf2ab478e763bda9c983d73a
+ metadata.gz: 8045208b5f7801979212a4f6ed395217e78f06bcfbc2d0362aaaa04c529745cd
+ data.tar.gz: 4011dfe279d8d4041c6c79dc5a6bad199777f83b5f0559f11ccd2f68c957e462
  SHA512:
- metadata.gz: d7dc061516e688624f4090b956fd40999c9e2e5d2ae41fe8a1baac3caaf61ed9aef3ef31e8ca971e0a210a592cb3618f67533483e5808e2e9205e2ba9a7dfcf8
- data.tar.gz: aae1a4952d19aa186aa2ea97ce59af1dac7295f5430108aaf6545949218851b31c266472cf6111a62f7a5784c5f23fd3e3697f1181d5e659c217975890eed299
+ metadata.gz: d15e74da491773961006eca8ca6c6d80b30ffc995c56a9140961be0002eb09134f1a029c4e8ee192497fb7256fe36cf1c3ed928967ce57ece4c7a0904392c8fe
+ data.tar.gz: a863596304ddb9ac5e4be2b2b65bebc7d3913705b8a0f516debfee0ca213f9dca69707edda8d70cfafb15500fcb6e70cffb6d5d1119302d24e05059c50f0da77
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
+ ## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
+
+ **Breaking Changes**
+ - Bump bundled llama.cpp from b1380 to b1405
+ - Add column index argument to `set_seq_id` and `get_seq_id` methods in Batch.
+ - Add `special` keyword argument to `tokenize` method in Model.
+ - Add `n_seq_max` keyword argument to `initialize` method in Batch.
+
  ## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14

  - Bump bundled llama.cpp from b1334 to b1380.
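For callers of the Ruby bindings, the second and fourth breaking changes above look roughly as follows. This is a minimal illustrative sketch, not documentation: it assumes the batch class is exposed as LLaMACpp::Batch (as the rb_define_method calls in the extension diff below suggest) and that all three keyword arguments are required integers. The new `special` keyword for `tokenize` is sketched after the corresponding binding change further down.

    require 'llama_cpp'

    # 0.8.0: n_seq_max is now a required keyword argument of Batch#initialize.
    batch = LLaMACpp::Batch.new(n_tokens: 512, embd: 0, n_seq_max: 1)

    # 0.8.0: the seq_id accessors take a token index and a column index;
    # before 0.8.0 they took only the token index.
    batch.set_seq_id(0, 0, 1)
    batch.get_seq_id(0, 0) # => 1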
@@ -63,8 +63,8 @@ public:
  rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
  rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
  rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
- rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
- rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
+ rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
+ rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
  rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
  rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
  }
@@ -74,10 +74,10 @@ private:

  static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
- VALUE kw_values[2] = { Qundef, Qundef };
+ ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
  rb_raise(rb_eArgError, "n_tokens must be an integer");
@@ -87,12 +87,17 @@ private:
  rb_raise(rb_eArgError, "embd must be an integer");
  return Qnil;
  }
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+ rb_raise(rb_eArgError, "n_seq_max must be an integer");
+ return Qnil;
+ }

  const int32_t n_tokens = NUM2INT(kw_values[0]);
  const int32_t embd = NUM2INT(kw_values[1]);
+ const int32_t n_seq_max = NUM2INT(kw_values[2]);

  LLaMABatchWrapper* ptr = get_llama_batch(self);
- ptr->batch = llama_batch_init(n_tokens, embd);
+ ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);

  return Qnil;
  }
@@ -190,25 +195,35 @@ private:
  }

  // seq_id
- static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
+ static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
  LLaMABatchWrapper* ptr = get_llama_batch(self);
- const int32_t id = NUM2INT(idx);
- if (id < 0 || id >= ptr->batch.n_tokens) {
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+ const int32_t i = NUM2INT(i_);
+ if (i < 0 || i >= ptr->batch.n_tokens) {
+ rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+ return Qnil;
+ }
+ const int32_t j = NUM2INT(j_);
+ if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+ rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
  return Qnil;
  }
- ptr->batch.seq_id[id] = NUM2INT(value);
- return INT2NUM(ptr->batch.seq_id[id]);
+ ptr->batch.seq_id[i][j] = NUM2INT(value);
+ return INT2NUM(ptr->batch.seq_id[i][j]);
  }

- static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
+ static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
  LLaMABatchWrapper* ptr = get_llama_batch(self);
- const int32_t id = NUM2INT(idx);
- if (id < 0 || id >= ptr->batch.n_tokens) {
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+ const int32_t i = NUM2INT(i_);
+ if (i < 0 || i >= ptr->batch.n_tokens) {
+ rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+ return Qnil;
+ }
+ const int32_t j = NUM2INT(j_);
+ if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+ rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
  return Qnil;
  }
- return INT2NUM(ptr->batch.seq_id[id]);
+ return INT2NUM(ptr->batch.seq_id[i][j]);
  }

  // logits
@@ -1319,10 +1334,10 @@ private:

  static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+ ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);

  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
  rb_raise(rb_eArgError, "text must be a String");
@@ -1336,15 +1351,20 @@ private:
  rb_raise(rb_eArgError, "add_bos must be a boolean");
  return Qnil;
  }
+ if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
+ rb_raise(rb_eArgError, "special must be a boolean");
+ return Qnil;
+ }

  VALUE text_ = kw_values[0];
  std::string text = StringValueCStr(text_);
  const bool add_bos = kw_values[2] == Qtrue ? true : false;
+ const bool special = kw_values[3] == Qtrue ? true : false;
  const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);

  llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
  LLaMAModelWrapper* ptr = get_llama_model(self);
- const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
+ const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);

  if (n_tokens < 0) {
  rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
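At the Ruby level, the hunk above adds an optional `special` keyword to Model#tokenize, passed straight through to llama_tokenize, where it controls whether special/control tokens embedded in the text are tokenized instead of being treated as plain text. A hedged sketch, assuming `model` is an already-loaded LLaMACpp::Model instance:

    # special: defaults to false when omitted.
    tokens = model.tokenize(text: '<s>Hello, world.', n_max_tokens: 32, add_bos: false, special: true)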
@@ -73,6 +73,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_f16);
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_1);
  GGML_METAL_DECL_KERNEL(get_rows_q8_0);
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -87,6 +89,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
  GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
@@ -97,6 +101,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -254,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_f16);
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_1);
  GGML_METAL_ADD_KERNEL(get_rows_q8_0);
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -268,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
  GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
@@ -278,8 +288,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
@@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_f16);
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_1);
  GGML_METAL_DEL_KERNEL(get_rows_q8_0);
  GGML_METAL_DEL_KERNEL(get_rows_q2_K);
  GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
  GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
@@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
@@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute(
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute(
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
+ case GGML_TYPE_Q5_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+ } break;
  case GGML_TYPE_Q8_0:
  {
  GGML_ASSERT(ne02 == 1);
@@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
@@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute(
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -18,6 +18,21 @@ typedef struct {
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
  } block_q4_1;

+ #define QK5_0 32
+ typedef struct {
+ half d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+ } block_q5_0;
+
+ #define QK5_1 32
+ typedef struct {
+ half d; // delta
+ half m; // min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+ } block_q5_1;
+
  #define QK8_0 32
  typedef struct {
  half d; // delta
@@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
  // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
+
  float2 acc = 0.f;
+
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
  for (int i = 0; i < 8; i+=2) {
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
  inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
  float m = qb_curr->m;
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
  float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
  for (int i = 0; i < 8; i+=2) {
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
  return d * (acc[0] + acc[1]) + sumy * m;
  }

+ // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q5 quants begin (0 or QK5_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (sumy * -16.f + acc[0] + acc[1]);
+ }
+
+ // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q5 quants begin (0 or QK5_1/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+ }
+
  // putting them in the kernel cause a significant performance penalty
  #define N_DST 4 // each SIMD group works on 4 rows
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
  mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
  }

+ kernel void kernel_mul_mv_q5_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ }
+
+ kernel void kernel_mul_mv_q5_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ }
+
+
  #define NB_Q8_0 8

  kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
  }
  }

+ template <typename type4x4>
+ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
+ }
+ }
+
+ template <typename type4x4>
+ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
+ }
+ }
+
  template <typename type4x4>
  void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
  device const int8_t * qs = ((device const int8_t *)xb->qs);
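For readers new to the q5_0 layout used above (a half-precision delta d, 32 high bits packed into qh, and 16 bytes of packed 4-bit quants in qs), the following is an illustrative scalar dequantizer for one 32-weight block, written in Ruby rather than Metal. It follows the standard ggml q5_0 scheme; the Metal dequantize_q5_0 above does the same unpacking but on half a block at a time, selected by il. The function name and argument representation here are hypothetical, chosen only for illustration.

    # Scalar sketch of q5_0 dequantization for a single block.
    #   d  : Float   -- the block delta (already converted from half)
    #   qh : Integer -- the 32 packed high bits
    #   qs : Array   -- 16 bytes, each holding two 4-bit quants
    def dequantize_q5_0_block(d, qh, qs)
      y = Array.new(32)
      16.times do |j|
        xh0 = ((qh >> j) << 4) & 0x10   # 5th bit of weight j
        xh1 = (qh >> (j + 12)) & 0x10   # 5th bit of weight j + 16
        y[j]      = (((qs[j] & 0x0F) | xh0) - 16) * d
        y[j + 16] = (((qs[j] >> 4)   | xh1) - 16) * d
      end
      y
    end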
@@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -1395,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
  const int64_t ne01 = src0->ne[1];
  const int64_t ne02 = src0->ne[2];
  const int64_t ne03 = src0->ne[3];
- const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
  const int64_t ne10 = src1->ne[0];
  const int64_t ne11 = src1->ne[1];
  const int64_t ne12 = src1->ne[2];
  const int64_t ne13 = src1->ne[3];
- const int64_t nb10 = src1->nb[0];
  const int nb2 = dst->nb[2];
  const int nb3 = dst->nb[3];
  size_t x_size;
  size_t d_size;

- cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+ cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
  cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
- cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+ cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst


  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
- const int i0 = i03*ne02 + i02;
-
  cl_event ev;

  // copy src0 to device
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev));
-
- if (nb10 == sizeof(float)) {
- // Contiguous, avoid overhead from queueing many kernel runs
- const int64_t i13 = i03%ne13;
- const int64_t i12 = i02%ne12;
- const int i1 = i13*ne12*ne11 + i12*ne11;
-
- cl_int x_offset = 0;
- cl_int y_offset = i1*ne10;
- cl_int d_offset = 0;
-
- size_t global = ne00 * ne01;
- cl_int ky = ne10;
- CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
- CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
- } else {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- const int64_t i13 = i03%ne13;
- const int64_t i12 = i02%ne12;
- const int64_t i11 = i01%ne11;
- const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
-
- cl_int x_offset = i01*ne00;
- cl_int y_offset = i1*ne10;
- cl_int d_offset = i01*ne00;
-
- // compute
- size_t global = ne00;
- cl_int ky = ne10;
- CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
- CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
- }
- }
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
+
+ const int64_t i13 = i03%ne13;
+ const int64_t i12 = i02%ne12;
+ const int i1 = i13*ne12*ne11 + i12*ne11;
+
+ cl_int x_offset = 0;
+ cl_int y_offset = i1*ne10;
+ cl_int d_offset = 0;
+
+ size_t global = ne00 * ne01;
+ cl_int ky = ne10 * ne11;
+
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+ CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));

  CL_CHECK(clReleaseEvent(ev));
  CL_CHECK(clFinish(queue));
@@ -1568,7 +1539,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
  ggml_cl_pool_free(d_D, d_size);
  }

- static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
  GGML_ASSERT(fp16_support);

  const int64_t ne00 = src0->ne[0];
@@ -1598,6 +1569,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
  const int y_ne = ne11 * ne10;
  const int d_ne = ne11 * ne01;

+ GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne);
+ GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne);
+ ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
+
  size_t x_size;
  size_t y_size;
  size_t d_size;
@@ -1634,7 +1609,6 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr

  // convert src1 to fp16
  // TODO: use multiple threads
- ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
  char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
  if (src1_cont_rows) {
  if (src1_cont_cols) {
@@ -1897,8 +1871,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
  }

  size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
- if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
- return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+ if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+ return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]);
  }
  return 0;
  }