llama_cpp 0.7.1 → 0.8.0

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 6688a7296f7a7e7ba4aa593b2d9b792beb1d569f7f2e0e872e1dbda64a336b57
- data.tar.gz: 3f683714c3b11b8f247d9ef40774b90e297c25f3bf2ab478e763bda9c983d73a
+ metadata.gz: 8045208b5f7801979212a4f6ed395217e78f06bcfbc2d0362aaaa04c529745cd
+ data.tar.gz: 4011dfe279d8d4041c6c79dc5a6bad199777f83b5f0559f11ccd2f68c957e462
  SHA512:
- metadata.gz: d7dc061516e688624f4090b956fd40999c9e2e5d2ae41fe8a1baac3caaf61ed9aef3ef31e8ca971e0a210a592cb3618f67533483e5808e2e9205e2ba9a7dfcf8
- data.tar.gz: aae1a4952d19aa186aa2ea97ce59af1dac7295f5430108aaf6545949218851b31c266472cf6111a62f7a5784c5f23fd3e3697f1181d5e659c217975890eed299
+ metadata.gz: d15e74da491773961006eca8ca6c6d80b30ffc995c56a9140961be0002eb09134f1a029c4e8ee192497fb7256fe36cf1c3ed928967ce57ece4c7a0904392c8fe
+ data.tar.gz: a863596304ddb9ac5e4be2b2b65bebc7d3913705b8a0f516debfee0ca213f9dca69707edda8d70cfafb15500fcb6e70cffb6d5d1119302d24e05059c50f0da77
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
+ ## [[0.8.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.1...v0.8.0)] - 2023-10-21
+
+ **Breaking Changes**
+ - Bump bundled llama.cpp from b1380 to b1405
+ - Add column index argument to `set_seq_id` and `get_seq_id` methods in Batch.
+ - Add `special` keyword argument to `tokenize` method in Model.
+ - Add `n_seq_max` keyword argument to `initialize` method in Batch.
+
  ## [[0.7.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.7.0...v0.7.1)] - 2023-10-14

  - Bump bundled llama.cpp from b1334 to b1380.
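Note: the breaking changes above track the llama.cpp batch API as bundled at b1405, where `llama_batch_init` takes a third `n_seq_max` argument and `seq_id` becomes a per-token array (`seq_id[i][j]`, with `n_seq_id[i]` entries for token `i`), as the extension diff below shows. A minimal C++ sketch of that layout against the bundled C API follows; the sizes and token value are illustrative and not taken from this diff.

```cpp
#include "llama.h"

// Hedged sketch of the b1405-style batch layout the updated bindings wrap.
void batch_layout_example(llama_token tok) {
    // the third argument (n_seq_max) is new relative to b1380
    llama_batch batch = llama_batch_init(/*n_tokens=*/512, /*embd=*/0, /*n_seq_max=*/2);

    batch.n_tokens     = 1;
    batch.token[0]     = tok;
    batch.pos[0]       = 0;
    batch.n_seq_id[0]  = 2;   // this token belongs to two sequences
    batch.seq_id[0][0] = 0;   // column index j = 0 (the new second argument of get_seq_id)
    batch.seq_id[0][1] = 1;   // column index j = 1
    batch.logits[0]    = true;

    llama_batch_free(batch);
}
```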
@@ -63,8 +63,8 @@ public:
  rb_define_method(rb_cLLaMABatch, "get_token", RUBY_METHOD_FUNC(_llama_batch_get_token), 1);
  rb_define_method(rb_cLLaMABatch, "set_pos", RUBY_METHOD_FUNC(_llama_batch_set_pos), 2);
  rb_define_method(rb_cLLaMABatch, "get_pos", RUBY_METHOD_FUNC(_llama_batch_get_pos), 1);
- rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 2);
- rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 1);
+ rb_define_method(rb_cLLaMABatch, "set_seq_id", RUBY_METHOD_FUNC(_llama_batch_set_seq_id), 3);
+ rb_define_method(rb_cLLaMABatch, "get_seq_id", RUBY_METHOD_FUNC(_llama_batch_get_seq_id), 2);
  rb_define_method(rb_cLLaMABatch, "set_logits", RUBY_METHOD_FUNC(_llama_batch_set_logits), 2);
  rb_define_method(rb_cLLaMABatch, "get_logits", RUBY_METHOD_FUNC(_llama_batch_get_logits), 1);
  }
@@ -74,10 +74,10 @@ private:

  static VALUE _llama_batch_initialize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[2] = { rb_intern("n_tokens"), rb_intern("embd") };
- VALUE kw_values[2] = { Qundef, Qundef };
+ ID kw_table[3] = { rb_intern("n_tokens"), rb_intern("embd"), rb_intern("n_seq_max") };
+ VALUE kw_values[3] = { Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 2, 0, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 3, 0, kw_values);

  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
  rb_raise(rb_eArgError, "n_tokens must be an integer");
@@ -87,12 +87,17 @@ private:
  rb_raise(rb_eArgError, "embd must be an integer");
  return Qnil;
  }
+ if (!RB_INTEGER_TYPE_P(kw_values[2])) {
+ rb_raise(rb_eArgError, "n_seq_max must be an integer");
+ return Qnil;
+ }

  const int32_t n_tokens = NUM2INT(kw_values[0]);
  const int32_t embd = NUM2INT(kw_values[1]);
+ const int32_t n_seq_max = NUM2INT(kw_values[2]);

  LLaMABatchWrapper* ptr = get_llama_batch(self);
- ptr->batch = llama_batch_init(n_tokens, embd);
+ ptr->batch = llama_batch_init(n_tokens, embd, n_seq_max);

  return Qnil;
  }
@@ -190,25 +195,35 @@ private:
  }

  // seq_id
- static VALUE _llama_batch_set_seq_id(VALUE self, VALUE idx, VALUE value) {
+ static VALUE _llama_batch_set_seq_id(VALUE self, VALUE i_, VALUE j_, VALUE value) {
  LLaMABatchWrapper* ptr = get_llama_batch(self);
- const int32_t id = NUM2INT(idx);
- if (id < 0 || id >= ptr->batch.n_tokens) {
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+ const int32_t i = NUM2INT(i_);
+ if (i < 0 || i >= ptr->batch.n_tokens) {
+ rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+ return Qnil;
+ }
+ const int32_t j = NUM2INT(j_);
+ if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+ rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
  return Qnil;
  }
- ptr->batch.seq_id[id] = NUM2INT(value);
- return INT2NUM(ptr->batch.seq_id[id]);
+ ptr->batch.seq_id[i][j] = NUM2INT(value);
+ return INT2NUM(ptr->batch.seq_id[i][j]);
  }

- static VALUE _llama_batch_get_seq_id(VALUE self, VALUE idx) {
+ static VALUE _llama_batch_get_seq_id(VALUE self, VALUE i_, VALUE j_) {
  LLaMABatchWrapper* ptr = get_llama_batch(self);
- const int32_t id = NUM2INT(idx);
- if (id < 0 || id >= ptr->batch.n_tokens) {
- rb_raise(rb_eArgError, "id must be in [0, n_tokens)");
+ const int32_t i = NUM2INT(i_);
+ if (i < 0 || i >= ptr->batch.n_tokens) {
+ rb_raise(rb_eArgError, "i must be in [0, n_tokens)");
+ return Qnil;
+ }
+ const int32_t j = NUM2INT(j_);
+ if (j < 0 || j >= ptr->batch.n_seq_id[i]) {
+ rb_raise(rb_eArgError, "j must be in [0, n_seq_id[i])");
  return Qnil;
  }
- return INT2NUM(ptr->batch.seq_id[id]);
+ return INT2NUM(ptr->batch.seq_id[i][j]);
  }

  // logits
@@ -1319,10 +1334,10 @@ private:

  static VALUE _llama_model_tokenize(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
- ID kw_table[3] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos") };
- VALUE kw_values[3] = { Qundef, Qundef, Qundef };
+ ID kw_table[4] = { rb_intern("text"), rb_intern("n_max_tokens"), rb_intern("add_bos"), rb_intern("special") };
+ VALUE kw_values[4] = { Qundef, Qundef, Qundef, Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
- rb_get_kwargs(kw_args, kw_table, 1, 2, kw_values);
+ rb_get_kwargs(kw_args, kw_table, 1, 3, kw_values);

  if (!RB_TYPE_P(kw_values[0], T_STRING)) {
  rb_raise(rb_eArgError, "text must be a String");
@@ -1336,15 +1351,20 @@ private:
  rb_raise(rb_eArgError, "add_bos must be a boolean");
  return Qnil;
  }
+ if (kw_values[3] != Qundef && (kw_values[3] != Qtrue && kw_values[3] != Qfalse)) {
+ rb_raise(rb_eArgError, "special must be a boolean");
+ return Qnil;
+ }

  VALUE text_ = kw_values[0];
  std::string text = StringValueCStr(text_);
  const bool add_bos = kw_values[2] == Qtrue ? true : false;
+ const bool special = kw_values[3] == Qtrue ? true : false;
  const int n_max_tokens = kw_values[1] != Qundef ? NUM2INT(kw_values[1]) : text.size() + (add_bos ? 1 : 0);

  llama_token* tokens = ALLOCA_N(llama_token, n_max_tokens);
  LLaMAModelWrapper* ptr = get_llama_model(self);
- const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos);
+ const int n_tokens = llama_tokenize(ptr->model, text.c_str(), text.size(), tokens, n_max_tokens, add_bos, special);

  if (n_tokens < 0) {
  rb_raise(rb_eRuntimeError, "failed to tokenize. The numebr of tokens (%d) is greater than n_max_tokens.", -n_tokens);
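The `special` keyword is forwarded straight to `llama_tokenize`, which returns a negative count when `n_max_tokens` is too small (the error path above). A minimal C++ sketch of that call pattern using the b1405 signature shown in the diff; the resize-and-retry loop is a common idiom, not code from this gem.

```cpp
#include <string>
#include <vector>
#include "llama.h"

// Hedged sketch: tokenize with special-token parsing enabled, growing the
// buffer when llama_tokenize reports (as a negative value) the required size.
std::vector<llama_token> tokenize_text(const llama_model * model,
                                       const std::string & text,
                                       bool add_bos, bool special) {
    std::vector<llama_token> tokens(text.size() + (add_bos ? 1 : 0));
    int n = llama_tokenize(model, text.c_str(), (int) text.size(),
                           tokens.data(), (int) tokens.size(), add_bos, special);
    if (n < 0) {               // buffer too small; -n is the number of tokens needed
        tokens.resize(-n);
        n = llama_tokenize(model, text.c_str(), (int) text.size(),
                           tokens.data(), (int) tokens.size(), add_bos, special);
    }
    tokens.resize(n > 0 ? n : 0);
    return tokens;
}
```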
@@ -73,6 +73,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(get_rows_f16);
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DECL_KERNEL(get_rows_q5_1);
  GGML_METAL_DECL_KERNEL(get_rows_q8_0);
  GGML_METAL_DECL_KERNEL(get_rows_q2_K);
  GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -87,6 +89,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
  GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
@@ -97,6 +101,8 @@ struct ggml_metal_context {
  GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -254,6 +260,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(get_rows_f16);
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+ GGML_METAL_ADD_KERNEL(get_rows_q5_1);
  GGML_METAL_ADD_KERNEL(get_rows_q8_0);
  GGML_METAL_ADD_KERNEL(get_rows_q2_K);
  GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -268,6 +276,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
  GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
@@ -278,8 +288,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
  GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
  GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
@@ -346,6 +358,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(get_rows_f16);
  GGML_METAL_DEL_KERNEL(get_rows_q4_0);
  GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+ GGML_METAL_DEL_KERNEL(get_rows_q5_1);
  GGML_METAL_DEL_KERNEL(get_rows_q8_0);
  GGML_METAL_DEL_KERNEL(get_rows_q2_K);
  GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -360,6 +374,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
  GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
@@ -370,8 +386,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
  GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
- GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
  GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
@@ -1052,6 +1070,8 @@ void ggml_metal_graph_compute(
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -1121,6 +1141,24 @@ void ggml_metal_graph_compute(
  nth1 = 8;
  [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
  } break;
+ case GGML_TYPE_Q5_0:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ GGML_ASSERT(ne02 == 1);
+ GGML_ASSERT(ne12 == 1);
+
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+ } break;
  case GGML_TYPE_Q8_0:
  {
  GGML_ASSERT(ne02 == 1);
@@ -1201,7 +1239,8 @@ void ggml_metal_graph_compute(
  [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
  [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
  src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
  [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
  }
@@ -1233,6 +1272,8 @@ void ggml_metal_graph_compute(
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+ case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+ case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
  case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
  case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
  case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -18,6 +18,21 @@ typedef struct {
  uint8_t qs[QK4_1 / 2]; // nibbles / quants
  } block_q4_1;

+ #define QK5_0 32
+ typedef struct {
+ half d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+ } block_q5_0;
+
+ #define QK5_1 32
+ typedef struct {
+ half d; // delta
+ half m; // min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+ } block_q5_1;
+
  #define QK8_0 32
  typedef struct {
  half d; // delta
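The new `block_q5_0` stores 32 weights as 4-bit nibbles in `qs` plus their fifth bits packed into `qh`, all scaled by `d`. As a reading aid, here is a rough host-side C++ sketch of the decoding the new Metal kernels perform (value = d * (q - 16)); it mirrors ggml's reference dequantization, with the fp16 delta assumed to be converted to `float` outside the sketch.

```cpp
#include <cstdint>
#include <cstring>

#define QK5_0 32
struct block_q5_0_host {                 // host-side mirror of the struct above
    uint16_t d;                          // fp16 delta (raw bits)
    uint8_t  qh[4];                      // 5th bit of each quant
    uint8_t  qs[QK5_0 / 2];              // low 4 bits, two quants per byte
};

// Hedged sketch: expand one q5_0 block to 32 floats; `d` is the already
// converted float value of the block's fp16 delta.
void dequantize_block_q5_0(const block_q5_0_host * x, float d, float * y) {
    uint32_t qh;
    std::memcpy(&qh, x->qh, sizeof(qh));
    for (int j = 0; j < QK5_0 / 2; ++j) {
        const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;  // 5th bit, low half
        const uint8_t xh_1 =  (qh >> (j + 12))      & 0x10;  // 5th bit, high half
        const int32_t x0 = ((x->qs[j] & 0x0F) | xh_0) - 16;
        const int32_t x1 = ((x->qs[j] >>   4) | xh_1) - 16;
        y[j]             = x0 * d;
        y[j + QK5_0 / 2] = x1 * d;
    }
}
```

`block_q5_1` follows the same scheme, except the fixed -16 offset is replaced by the per-block minimum `m`.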
@@ -399,8 +414,11 @@ kernel void kernel_rms_norm(
  // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
  inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
+
  float2 acc = 0.f;
+
  device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
  for (int i = 0; i < 8; i+=2) {
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
  inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
  float d = qb_curr->d;
  float m = qb_curr->m;
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
  float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
  for (int i = 0; i < 8; i+=2) {
  acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
  + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +449,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
  return d * (acc[0] + acc[1]) + sumy * m;
  }

+ // function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q5 quants begin (0 or QK5_0/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (sumy * -16.f + acc[0] + acc[1]);
+ }
+
+ // function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+ // il indicates where the q5 quants begin (0 or QK5_1/4)
+ // we assume that the yl's have been multiplied with the appropriate scale factor
+ // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+ }
+
  // putting them in the kernel cause a significant performance penalty
  #define N_DST 4 // each SIMD group works on 4 rows
  #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +589,43 @@ kernel void kernel_mul_mv_q4_1_f32(
  mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
  }

+ kernel void kernel_mul_mv_q5_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ }
+
+ kernel void kernel_mul_mv_q5_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+ }
+
+
  #define NB_Q8_0 8

  kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2250,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
  }
  }

+ template <typename type4x4>
+ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
+ }
+ }
+
+ template <typename type4x4>
+ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
+ }
+ }
+
  template <typename type4x4>
  void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
  device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2647,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
  template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
  template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2677,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<f
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
  template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
  template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+ template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+ template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
  template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
  template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
  template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -1395,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
  const int64_t ne01 = src0->ne[1];
  const int64_t ne02 = src0->ne[2];
  const int64_t ne03 = src0->ne[3];
- const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
  const int64_t ne10 = src1->ne[0];
  const int64_t ne11 = src1->ne[1];
  const int64_t ne12 = src1->ne[2];
  const int64_t ne13 = src1->ne[3];
- const int64_t nb10 = src1->nb[0];
  const int nb2 = dst->nb[2];
  const int nb3 = dst->nb[3];
  size_t x_size;
  size_t d_size;

- cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+ cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
  cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
- cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+ cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst


  for (int64_t i03 = 0; i03 < ne03; i03++) {
  for (int64_t i02 = 0; i02 < ne02; i02++) {
- const int i0 = i03*ne02 + i02;
-
  cl_event ev;

  // copy src0 to device
- CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev));
-
- if (nb10 == sizeof(float)) {
- // Contiguous, avoid overhead from queueing many kernel runs
- const int64_t i13 = i03%ne13;
- const int64_t i12 = i02%ne12;
- const int i1 = i13*ne12*ne11 + i12*ne11;
-
- cl_int x_offset = 0;
- cl_int y_offset = i1*ne10;
- cl_int d_offset = 0;
-
- size_t global = ne00 * ne01;
- cl_int ky = ne10;
- CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
- CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
- } else {
- for (int64_t i01 = 0; i01 < ne01; i01++) {
- const int64_t i13 = i03%ne13;
- const int64_t i12 = i02%ne12;
- const int64_t i11 = i01%ne11;
- const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
-
- cl_int x_offset = i01*ne00;
- cl_int y_offset = i1*ne10;
- cl_int d_offset = i01*ne00;
-
- // compute
- size_t global = ne00;
- cl_int ky = ne10;
- CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
- CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
- CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
- }
- }
+ CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
+
+ const int64_t i13 = i03%ne13;
+ const int64_t i12 = i02%ne12;
+ const int i1 = i13*ne12*ne11 + i12*ne11;
+
+ cl_int x_offset = 0;
+ cl_int y_offset = i1*ne10;
+ cl_int d_offset = 0;
+
+ size_t global = ne00 * ne01;
+ cl_int ky = ne10 * ne11;
+
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+ CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+ CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));

  CL_CHECK(clReleaseEvent(ev));
  CL_CHECK(clFinish(queue));
@@ -1568,7 +1539,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
  ggml_cl_pool_free(d_D, d_size);
  }

- static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
  GGML_ASSERT(fp16_support);

  const int64_t ne00 = src0->ne[0];
@@ -1598,6 +1569,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
  const int y_ne = ne11 * ne10;
  const int d_ne = ne11 * ne01;

+ GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne);
+ GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne);
+ ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
+
  size_t x_size;
  size_t y_size;
  size_t d_size;
@@ -1634,7 +1609,6 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr

  // convert src1 to fp16
  // TODO: use multiple threads
- ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
  char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
  if (src1_cont_rows) {
  if (src1_cont_cols) {
@@ -1897,8 +1871,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
  }

  size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
- if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
- return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+ if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+ return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]);
  }
  return 0;
  }
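The final hunk shrinks the f16 scratch buffer: the rewritten `ggml_cl_mul_mat_f16` reuses one slice-sized fp16 buffer both for the converted `src1` slice and for the fp16 copy of `dst` it reads back from the device, so the scratch only needs to hold the larger of the two 2-D slices, and only when `src0` is F16. A one-function C++ sketch of that sizing rule, assuming ggml's usual `ne` layout (`ne[0]` = row length, `ne[1]` = number of rows) and using `uint16_t` as a stand-in for `ggml_fp16_t`:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Hedged sketch of the new scratch-size rule: enough fp16 elements for the
// larger of a src1 slice (ne10 * ne11) and a dst slice (ne0 * ne1).
size_t cl_mul_mat_f16_wsize(int64_t ne10, int64_t ne11, int64_t ne0, int64_t ne1) {
    return sizeof(uint16_t) * (size_t) std::max(ne10 * ne11, ne0 * ne1);
}
```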