llama_cpp 0.15.0 → 0.15.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -17,6 +17,83 @@
17
17
  #define MIN(a, b) ((a) < (b) ? (a) : (b))
18
18
  #define MAX(a, b) ((a) > (b) ? (a) : (b))
19
19
 
20
+ /**
21
+ * Converts brain16 to float32.
22
+ *
23
+ * The bfloat16 floating point format has the following structure:
24
+ *
25
+ * ┌sign
26
+ * │
27
+ * │ ┌exponent
28
+ * │ │
29
+ * │ │ ┌mantissa
30
+ * │ │ │
31
+ * │┌──┴───┐┌─┴───┐
32
+ * 0b0000000000000000 brain16
33
+ *
34
+ * Since bf16 has the same number of exponent bits as a 32bit float,
35
+ * encoding and decoding numbers becomes relatively straightforward.
36
+ *
37
+ * ┌sign
38
+ * │
39
+ * │ ┌exponent
40
+ * │ │
41
+ * │ │ ┌mantissa
42
+ * │ │ │
43
+ * │┌──┴───┐┌─┴───────────────────┐
44
+ * 0b00000000000000000000000000000000 IEEE binary32
45
+ *
46
+ * For comparison, the standard fp16 format has fewer exponent bits.
47
+ *
48
+ * ┌sign
49
+ * │
50
+ * │ ┌exponent
51
+ * │ │
52
+ * │ │ ┌mantissa
53
+ * │ │ │
54
+ * │┌─┴─┐┌─┴──────┐
55
+ * 0b0000000000000000 IEEE binary16
56
+ *
57
+ * @see IEEE 754-2008
58
+ */
59
+ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
60
+ union {
61
+ float f;
62
+ uint32_t i;
63
+ } u;
64
+ u.i = (uint32_t)h.bits << 16;
65
+ return u.f;
66
+ }
67
+
68
+ /**
69
+ * Converts float32 to brain16.
70
+ *
71
+ * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
72
+ * Subnormals shall be flushed to zero, and NANs will be quiet.
73
+ * This code should vectorize nicely if using modern compilers.
74
+ */
75
+ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
76
+ ggml_bf16_t h;
77
+ union {
78
+ float f;
79
+ uint32_t i;
80
+ } u;
81
+ u.f = s;
82
+ if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
83
+ h.bits = (u.i >> 16) | 64; /* force to quiet */
84
+ return h;
85
+ }
86
+ if (!(u.i & 0x7f800000)) { /* subnormal */
87
+ h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
88
+ return h;
89
+ }
90
+ h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
91
+ return h;
92
+ }
93
+
94
+ #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
95
+ #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
96
+
20
97
  #ifdef __cplusplus
21
98
  extern "C" {
22
99
  #endif
@@ -43,9 +120,16 @@ extern "C" {
43
120
  #ifndef __F16C__
44
121
  #define __F16C__
45
122
  #endif
123
+ #endif
124
+
125
+ // __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
126
+ #if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
46
127
  #ifndef __SSE3__
47
128
  #define __SSE3__
48
129
  #endif
130
+ #ifndef __SSSE3__
131
+ #define __SSSE3__
132
+ #endif
49
133
  #endif
50
134
 
51
135
  // 16-bit float
@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1559
1559
  case GGML_OP_SOFT_MAX:
1560
1560
  {
1561
1561
  float scale;
1562
- memcpy(&scale, dst->op_params, sizeof(float));
1562
+ float max_bias;
1563
1563
 
1564
- #pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
1564
+ memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
1565
+ memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
1566
+
1567
+ #pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
1565
1568
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
1566
1569
  GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
1567
- GGML_ASSERT(src2 == nullptr);
1570
+
1571
+ #pragma message("TODO: add ALiBi support")
1572
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
1573
+ GGML_ASSERT(max_bias == 0.0f);
1568
1574
 
1569
1575
  ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
1570
1576
  } break;
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
40
40
  GGML_METAL_KERNEL_TYPE_CLAMP,
41
41
  GGML_METAL_KERNEL_TYPE_TANH,
42
42
  GGML_METAL_KERNEL_TYPE_RELU,
43
+ GGML_METAL_KERNEL_TYPE_SIGMOID,
43
44
  GGML_METAL_KERNEL_TYPE_GELU,
44
45
  GGML_METAL_KERNEL_TYPE_GELU_4,
45
46
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
169
170
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
170
171
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
171
172
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
172
- GGML_METAL_KERNEL_TYPE_ALIBI_F32,
173
173
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
174
174
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
175
175
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -265,11 +265,20 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
265
265
 
266
266
  static void * ggml_metal_host_malloc(size_t n) {
267
267
  void * data = NULL;
268
+
269
+ #if TARGET_OS_OSX
270
+ kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
271
+ if (err != KERN_SUCCESS) {
272
+ GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
273
+ return NULL;
274
+ }
275
+ #else
268
276
  const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
269
277
  if (result != 0) {
270
278
  GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
271
279
  return NULL;
272
280
  }
281
+ #endif
273
282
 
274
283
  return data;
275
284
  }
@@ -485,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
485
494
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
486
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
487
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
497
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
488
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
489
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
490
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -614,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
614
624
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
615
625
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
616
626
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
617
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
618
627
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
619
628
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
620
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -624,14 +633,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
624
633
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
625
634
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
626
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
627
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
628
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
629
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
630
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
631
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
632
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
633
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
634
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
636
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
637
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
638
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
639
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
640
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
641
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
642
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
643
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
635
644
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
636
645
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
637
646
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -723,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
723
732
  switch (ggml_get_unary_op(op)) {
724
733
  case GGML_UNARY_OP_TANH:
725
734
  case GGML_UNARY_OP_RELU:
735
+ case GGML_UNARY_OP_SIGMOID:
726
736
  case GGML_UNARY_OP_GELU:
727
737
  case GGML_UNARY_OP_GELU_QUICK:
728
738
  case GGML_UNARY_OP_SILU:
@@ -750,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
750
760
  case GGML_OP_GROUP_NORM:
751
761
  return ctx->support_simdgroup_reduction;
752
762
  case GGML_OP_NORM:
753
- case GGML_OP_ALIBI:
754
763
  case GGML_OP_ROPE:
755
764
  case GGML_OP_IM2COL:
756
765
  return true;
@@ -763,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
763
772
  case GGML_OP_TIMESTEP_EMBEDDING:
764
773
  case GGML_OP_ARGSORT:
765
774
  case GGML_OP_LEAKY_RELU:
766
- case GGML_OP_FLASH_ATTN_EXT:
767
775
  return true;
776
+ case GGML_OP_FLASH_ATTN_EXT:
777
+ return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
768
778
  case GGML_OP_MUL_MAT:
769
779
  case GGML_OP_MUL_MAT_ID:
770
780
  return ctx->support_simdgroup_reduction &&
@@ -803,7 +813,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
803
813
  case GGML_OP_DIAG_MASK_INF:
804
814
  case GGML_OP_GET_ROWS:
805
815
  {
806
- return op->ne[3] == 1;
816
+ return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
807
817
  }
808
818
  default:
809
819
  return false;
@@ -1185,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute(
1185
1195
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1186
1196
  } break;
1187
1197
  case GGML_OP_CLAMP:
1188
- {
1189
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1198
+ {
1199
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1190
1200
 
1191
- float min;
1192
- float max;
1193
- memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1194
- memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1201
+ float min;
1202
+ float max;
1203
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1204
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1195
1205
 
1196
- [encoder setComputePipelineState:pipeline];
1197
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1198
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1199
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
1200
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
1206
+ [encoder setComputePipelineState:pipeline];
1207
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1208
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1209
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1210
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1201
1211
 
1202
- const int64_t n = ggml_nelements(dst);
1212
+ const int64_t n = ggml_nelements(dst);
1203
1213
 
1204
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1205
- } break;
1214
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1215
+ } break;
1206
1216
  case GGML_OP_UNARY:
1207
1217
  switch (ggml_get_unary_op(gf->nodes[i])) {
1208
1218
  // we are not taking into account the strides, so for now require contiguous tensors
@@ -1230,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
1230
1240
 
1231
1241
  const int64_t n = ggml_nelements(dst);
1232
1242
 
1243
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1244
+ } break;
1245
+ case GGML_UNARY_OP_SIGMOID:
1246
+ {
1247
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
1248
+
1249
+ [encoder setComputePipelineState:pipeline];
1250
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1251
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1252
+
1253
+ const int64_t n = ggml_nelements(dst);
1254
+
1233
1255
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1234
1256
  } break;
1235
1257
  case GGML_UNARY_OP_GELU:
@@ -1348,16 +1370,15 @@ static enum ggml_status ggml_metal_graph_compute(
1348
1370
  case GGML_OP_SOFT_MAX:
1349
1371
  {
1350
1372
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
1351
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
1352
1373
 
1353
1374
  int nth = 32; // SIMD width
1354
1375
 
1355
1376
  id<MTLComputePipelineState> pipeline = nil;
1356
1377
 
1357
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
1378
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
1358
1379
 
1359
1380
  if (ne00%4 == 0) {
1360
- while (nth < ne00/4 && nth < 256) {
1381
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1361
1382
  nth *= 2;
1362
1383
  }
1363
1384
  if (use_f16) {
@@ -1366,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
1366
1387
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
1367
1388
  }
1368
1389
  } else {
1369
- while (nth < ne00 && nth < 1024) {
1390
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1370
1391
  nth *= 2;
1371
1392
  }
1372
1393
  if (use_f16) {
@@ -1385,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
1385
1406
  const int64_t nrows_x = ggml_nrows(src0);
1386
1407
  const int64_t nrows_y = src0->ne[1];
1387
1408
 
1388
- const uint32_t n_head_kv = nrows_x/nrows_y;
1389
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1409
+ const uint32_t n_head = nrows_x/nrows_y;
1410
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1390
1411
 
1391
1412
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1392
1413
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -1398,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
1398
1419
  } else {
1399
1420
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1400
1421
  }
1401
- if (id_src2) {
1402
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
1403
- } else {
1404
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
1405
- }
1406
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1407
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
1408
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
1409
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1410
- [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
1411
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
1412
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
1413
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
1414
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
1422
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1423
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1424
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1425
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1426
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1427
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1428
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1429
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1430
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1415
1431
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1416
1432
 
1417
1433
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2216,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
2216
2232
 
2217
2233
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2218
2234
  } break;
2219
- case GGML_OP_ALIBI:
2220
- {
2221
- GGML_ASSERT((src0t == GGML_TYPE_F32));
2222
-
2223
- const int nth = MIN(1024, ne00);
2224
-
2225
- //const int n_past = ((int32_t *) dst->op_params)[0];
2226
- const int n_head = ((int32_t *) dst->op_params)[1];
2227
-
2228
- float max_bias;
2229
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2230
-
2231
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2232
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2233
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
2234
-
2235
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
2236
-
2237
- [encoder setComputePipelineState:pipeline];
2238
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2239
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2240
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2241
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2242
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2243
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2244
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2245
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2246
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2247
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2248
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2249
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2250
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2251
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2252
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2253
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2254
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2255
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2256
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2257
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2258
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
2259
-
2260
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2261
- } break;
2262
2235
  case GGML_OP_ROPE:
2263
2236
  {
2264
2237
  GGML_ASSERT(ne10 == ne02);
@@ -2380,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
2380
2353
  {
2381
2354
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2382
2355
 
2383
- const int sf = dst->op_params[0];
2356
+ const float sf0 = (float)ne0/src0->ne[0];
2357
+ const float sf1 = (float)ne1/src0->ne[1];
2358
+ const float sf2 = (float)ne2/src0->ne[2];
2359
+ const float sf3 = (float)ne3/src0->ne[3];
2384
2360
 
2385
2361
  const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
2386
2362
 
@@ -2403,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
2403
2379
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2404
2380
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2405
2381
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2406
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2382
+ [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
2383
+ [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
2384
+ [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
2385
+ [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
2407
2386
 
2408
2387
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2409
2388
 
@@ -2539,13 +2518,14 @@ static enum ggml_status ggml_metal_graph_compute(
2539
2518
  } break;
2540
2519
  case GGML_OP_FLASH_ATTN_EXT:
2541
2520
  {
2542
- GGML_ASSERT(ne00 % 4 == 0);
2521
+ GGML_ASSERT(ne00 % 4 == 0);
2522
+ GGML_ASSERT(ne11 % 32 == 0);
2523
+
2543
2524
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2544
2525
 
2545
- struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2526
+ GGML_ASSERT(ggml_are_same_shape (src1, src2));
2546
2527
 
2547
- GGML_ASSERT(ggml_are_same_shape(src1, src2));
2548
- GGML_ASSERT(src3);
2528
+ struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2549
2529
 
2550
2530
  size_t offs_src3 = 0;
2551
2531
 
@@ -2555,8 +2535,13 @@ static enum ggml_status ggml_metal_graph_compute(
2555
2535
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2556
2536
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2557
2537
 
2538
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
2539
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
2540
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
2541
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
2542
+
2558
2543
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2559
- const int64_t ne31 = src3 ? src3->ne[1] : 0;
2544
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2560
2545
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2561
2546
  const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2562
2547
 
@@ -2568,7 +2553,16 @@ static enum ggml_status ggml_metal_graph_compute(
2568
2553
  const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2569
2554
 
2570
2555
  float scale;
2571
- memcpy(&scale, dst->op_params, sizeof(float));
2556
+ float max_bias;
2557
+
2558
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2559
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2560
+
2561
+ const uint32_t n_head = src0->ne[2];
2562
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2563
+
2564
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2565
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2572
2566
 
2573
2567
  id<MTLComputePipelineState> pipeline = nil;
2574
2568
 
@@ -2605,34 +2599,38 @@ static enum ggml_status ggml_metal_graph_compute(
2605
2599
  }
2606
2600
 
2607
2601
  [encoder setComputePipelineState:pipeline];
2608
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2609
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2610
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2611
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2612
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2613
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
2614
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
2615
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
2616
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
2617
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
2618
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
2619
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
2620
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
2621
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
2622
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
2623
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
2624
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
2625
- [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
2626
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
2627
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
2628
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
2629
- [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
2630
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
2631
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
2632
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
2633
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
2634
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
2635
- [encoder setBytes:&scale length:sizeof( float) atIndex:27];
2602
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2603
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2604
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2605
+ if (id_src3) {
2606
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2607
+ } else {
2608
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
2609
+ }
2610
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2611
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2612
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2613
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2614
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2615
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2616
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2617
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2618
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2619
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2620
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2621
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2622
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2623
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2624
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2625
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2626
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2627
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2628
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2629
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2630
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2631
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2632
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2633
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2636
2634
 
2637
2635
  if (!use_vec_kernel) {
2638
2636
  // half8x8 kernel
@@ -2840,7 +2838,11 @@ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_
2840
2838
  ggml_backend_metal_free_device();
2841
2839
 
2842
2840
  if (ctx->owned) {
2841
+ #if TARGET_OS_OSX
2842
+ vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
2843
+ #else
2843
2844
  free(ctx->all_data);
2845
+ #endif
2844
2846
  }
2845
2847
 
2846
2848
  free(ctx);
@@ -2944,14 +2946,16 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
2944
2946
  ctx->owned = true;
2945
2947
  ctx->n_buffers = 1;
2946
2948
 
2947
- ctx->buffers[0].data = ctx->all_data;
2948
- ctx->buffers[0].size = size;
2949
- ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
2950
- length:size_aligned
2951
- options:MTLResourceStorageModeShared
2952
- deallocator:nil];
2949
+ if (ctx->all_data != NULL) {
2950
+ ctx->buffers[0].data = ctx->all_data;
2951
+ ctx->buffers[0].size = size;
2952
+ ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
2953
+ length:size_aligned
2954
+ options:MTLResourceStorageModeShared
2955
+ deallocator:nil];
2956
+ }
2953
2957
 
2954
- if (ctx->buffers[0].metal == nil) {
2958
+ if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
2955
2959
  GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
2956
2960
  free(ctx);
2957
2961
  ggml_backend_metal_free_device();