llama_cpp 0.15.0 → 0.15.2

@@ -17,6 +17,83 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───┐
+ *     0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───────────────────┐
+ *     0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ *       ┌sign
+ *       │
+ *       │  ┌exponent
+ *       │  │
+ *       │  │    ┌mantissa
+ *       │  │    │
+ *       │┌─┴─┐┌─┴──────┐
+ *     0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
+ * Subnormals shall be flushed to zero, and NANs will be quiet.
+ * This code should vectorize nicely if using modern compilers.
+ */
+static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
+    ggml_bf16_t h;
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.f = s;
+    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+        h.bits = (u.i >> 16) | 64; /* force to quiet */
+        return h;
+    }
+    if (!(u.i & 0x7f800000)) { /* subnormal */
+        h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
+        return h;
+    }
+    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+    return h;
+}
+
+#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
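
For intuition about the rounding behavior of the new conversion pair, here is a minimal, self-contained C sketch of the same round-to-nearest-even truncation. It uses a stand-in bf16_t struct and memcpy instead of the union, and omits the NaN and subnormal branches; the names are illustrative, not from this diff.

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    typedef struct { uint16_t bits; } bf16_t;   /* stand-in for ggml_bf16_t */

    static bf16_t fp32_to_bf16(float s) {       /* NaN/subnormal branches omitted */
        uint32_t i;
        memcpy(&i, &s, sizeof(i));              /* type-pun without a union */
        bf16_t h = { (uint16_t)((i + (0x7fff + ((i >> 16) & 1))) >> 16) };
        return h;                               /* round to nearest, ties to even */
    }

    static float bf16_to_fp32(bf16_t h) {
        uint32_t i = (uint32_t)h.bits << 16;    /* bf16 is the top half of fp32 */
        float f;
        memcpy(&f, &i, sizeof(f));
        return f;
    }

    int main(void) {
        float x = 3.14159265f;
        /* only 7 mantissa bits survive: prints 3.141593 -> 3.140625 */
        printf("%f -> %f\n", x, bf16_to_fp32(fp32_to_bf16(x)));
        return 0;
    }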
@@ -43,9 +120,16 @@ extern "C" {
 #ifndef __F16C__
 #define __F16C__
 #endif
+#endif
+
+// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
+#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
 #ifndef __SSE3__
 #define __SSE3__
 #endif
+#ifndef __SSSE3__
+#define __SSSE3__
+#endif
 #endif
 
 // 16-bit float
@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 case GGML_OP_SOFT_MAX:
     {
         float scale;
-        memcpy(&scale, dst->op_params, sizeof(float));
+        float max_bias;
 
-        #pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+        memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
+        memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+        #pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
         #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
         GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-        GGML_ASSERT(src2 == nullptr);
+
+        #pragma message("TODO: add ALiBi support")
+        #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+        GGML_ASSERT(max_bias == 0.0f);
 
         ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
     } break;
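
The Vulkan (kompute) soft-max path now reads two floats from dst->op_params instead of one. In ggml, op_params is a small int32 scratch array on the tensor, and floats are written into its 4-byte slots with memcpy to stay strict-aliasing safe. A sketch of the producer side (hypothetical helper name, not from this diff):

    #include <stdint.h>
    #include <string.h>

    /* hypothetical helper: pack soft-max parameters the way the kernel reads them */
    static void softmax_set_params(int32_t op_params[], float scale, float max_bias) {
        memcpy((float *) op_params + 0, &scale,    sizeof(float));  /* slot 0 */
        memcpy((float *) op_params + 1, &max_bias, sizeof(float));  /* slot 1 */
    }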
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
-    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -265,11 +265,20 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
 
 static void * ggml_metal_host_malloc(size_t n) {
     void * data = NULL;
+
+#if TARGET_OS_OSX
+    kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
+    if (err != KERN_SUCCESS) {
+        GGML_METAL_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
+        return NULL;
+    }
+#else
     const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
     if (result != 0) {
         GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
         return NULL;
     }
+#endif
 
     return data;
 }
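
Memory obtained from vm_allocate must be released with vm_deallocate rather than free(), which is why a matching #if TARGET_OS_OSX branch appears in ggml_backend_metal_buffer_free_buffer further down. A minimal sketch of the pairing (illustrative, error handling trimmed):

    #include <mach/mach.h>
    #include <stddef.h>

    static void * page_alloc(size_t n) {
        vm_address_t addr = 0;
        /* page-aligned, zero-filled memory from the Mach VM subsystem */
        if (vm_allocate(mach_task_self(), &addr, n, VM_FLAGS_ANYWHERE) != KERN_SUCCESS) {
            return NULL;
        }
        return (void *) addr;
    }

    static void page_free(void * p, size_t n) {
        /* the size must match the original allocation */
        vm_deallocate(mach_task_self(), (vm_address_t) p, n);
    }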
@@ -485,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -614,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -624,14 +633,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -723,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
 switch (ggml_get_unary_op(op)) {
     case GGML_UNARY_OP_TANH:
     case GGML_UNARY_OP_RELU:
+    case GGML_UNARY_OP_SIGMOID:
     case GGML_UNARY_OP_GELU:
     case GGML_UNARY_OP_GELU_QUICK:
     case GGML_UNARY_OP_SILU:
@@ -750,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
     case GGML_OP_GROUP_NORM:
         return ctx->support_simdgroup_reduction;
     case GGML_OP_NORM:
-    case GGML_OP_ALIBI:
     case GGML_OP_ROPE:
     case GGML_OP_IM2COL:
         return true;
@@ -763,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_ARGSORT:
     case GGML_OP_LEAKY_RELU:
-    case GGML_OP_FLASH_ATTN_EXT:
         return true;
+    case GGML_OP_FLASH_ATTN_EXT:
+        return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
     case GGML_OP_MUL_MAT:
     case GGML_OP_MUL_MAT_ID:
         return ctx->support_simdgroup_reduction &&
@@ -803,7 +813,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
     case GGML_OP_DIAG_MASK_INF:
     case GGML_OP_GET_ROWS:
         {
-            return op->ne[3] == 1;
+            return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
         }
     default:
         return false;
@@ -1185,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute(
         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
     } break;
 case GGML_OP_CLAMP:
-        {
-            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+    {
+        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
 
-            float min;
-            float max;
-            memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
-            memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+        float min;
+        float max;
+        memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+        memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
 
-            [encoder setComputePipelineState:pipeline];
-            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-            [encoder setBytes:&min length:sizeof(min) atIndex:2];
-            [encoder setBytes:&max length:sizeof(max) atIndex:3];
+        [encoder setComputePipelineState:pipeline];
+        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+        [encoder setBytes:&min length:sizeof(min) atIndex:2];
+        [encoder setBytes:&max length:sizeof(max) atIndex:3];
 
-            const int64_t n = ggml_nelements(dst);
+        const int64_t n = ggml_nelements(dst);
 
-            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-        } break;
+        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+    } break;
 case GGML_OP_UNARY:
     switch (ggml_get_unary_op(gf->nodes[i])) {
         // we are not taking into account the strides, so for now require contiguous tensors
@@ -1230,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
 
         const int64_t n = ggml_nelements(dst);
 
+        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+    } break;
+case GGML_UNARY_OP_SIGMOID:
+    {
+        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+        [encoder setComputePipelineState:pipeline];
+        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+        const int64_t n = ggml_nelements(dst);
+
         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
     } break;
 case GGML_UNARY_OP_GELU:
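
The new SIGMOID case follows the same one-thread-per-element pattern as the neighboring TANH/RELU dispatches, with no parameters beyond the two buffers. The CPU analogue of what the kernel computes is simply this (a sketch, not the Metal shader from this diff):

    #include <math.h>
    #include <stdint.h>

    /* elementwise logistic sigmoid over a contiguous f32 tensor */
    static void sigmoid_f32(const float * x, float * y, int64_t n) {
        for (int64_t i = 0; i < n; ++i) {
            y[i] = 1.0f / (1.0f + expf(-x[i]));
        }
    }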
@@ -1348,16 +1370,15 @@ static enum ggml_status ggml_metal_graph_compute(
 case GGML_OP_SOFT_MAX:
     {
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
         int nth = 32; // SIMD width
 
         id<MTLComputePipelineState> pipeline = nil;
 
-        const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+        const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
         if (ne00%4 == 0) {
-            while (nth < ne00/4 && nth < 256) {
+            while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
                 nth *= 2;
             }
             if (use_f16) {
@@ -1366,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
             }
         } else {
-            while (nth < ne00 && nth < 1024) {
+            while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
                 nth *= 2;
             }
             if (use_f16) {
@@ -1385,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
         const int64_t nrows_x = ggml_nrows(src0);
         const int64_t nrows_y = src0->ne[1];
 
-        const uint32_t n_head_kv = nrows_x/nrows_y;
-        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+        const uint32_t n_head = nrows_x/nrows_y;
+        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
         const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
         const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
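
m0 and m1 are the two geometric bases of the ALiBi slope schedule: the first n_head_log2 heads take successive powers of m0, and any remaining heads take odd powers of m1. A sketch of the per-head slope this yields (following the scheme ggml uses elsewhere; the function name is hypothetical):

    #include <math.h>
    #include <stdint.h>

    static float alibi_slope(float max_bias, uint32_t n_head, uint32_t h) {
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        /* heads past the largest power of two fall back to odd powers of m1 */
        return h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
    }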
@@ -1398,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
         } else {
             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
         }
-        if (id_src2) {
-            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-        } else {
-            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
-        }
-        [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
-        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
-        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
-        [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
-        [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
-        [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
-        [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
-        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+        [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+        [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+        [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+        [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+        [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+        [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+        [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+        [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
         [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
         [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2216,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
 } break;
-case GGML_OP_ALIBI:
-    {
-        GGML_ASSERT((src0t == GGML_TYPE_F32));
-
-        const int nth = MIN(1024, ne00);
-
-        //const int n_past = ((int32_t *) dst->op_params)[0];
-        const int n_head = ((int32_t *) dst->op_params)[1];
-
-        float max_bias;
-        memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-        const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
-        [encoder setComputePipelineState:pipeline];
-        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-        [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-        [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
-        [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
-        [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
-        [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
-        [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
-        [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
-        [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
-        [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-        [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
-        [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
-        [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
-        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-    } break;
 case GGML_OP_ROPE:
     {
         GGML_ASSERT(ne10 == ne02);
@@ -2380,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
 {
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-    const int sf = dst->op_params[0];
+    const float sf0 = (float)ne0/src0->ne[0];
+    const float sf1 = (float)ne1/src0->ne[1];
+    const float sf2 = (float)ne2/src0->ne[2];
+    const float sf3 = (float)ne3/src0->ne[3];
 
     const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
@@ -2403,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
     [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
     [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
     [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-    [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+    [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+    [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+    [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+    [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
 
     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
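Replacing the single integer sf with four per-dimension float factors lets the kernel scale each axis independently, and by non-integer ratios. The destination-to-source index mapping this implies for nearest-neighbor sampling is roughly the following (a sketch, assuming sf = dst extent / src extent per axis):

    #include <stdint.h>

    /* map a destination coordinate back to its source coordinate on one axis */
    static int64_t upscale_src_index(int64_t i_dst, float sf) {
        return (int64_t) (i_dst / sf);
    }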
@@ -2539,13 +2518,14 @@ static enum ggml_status ggml_metal_graph_compute(
 } break;
 case GGML_OP_FLASH_ATTN_EXT:
     {
-        GGML_ASSERT(ne00 % 4 == 0);
+        GGML_ASSERT(ne00 % 4  == 0);
+        GGML_ASSERT(ne11 % 32 == 0);
+
         GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-        struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+        GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
-        GGML_ASSERT(ggml_are_same_shape(src1, src2));
-        GGML_ASSERT(src3);
+        struct ggml_tensor * src3 = gf->nodes[i]->src[3];
 
         size_t offs_src3 = 0;
 
@@ -2555,8 +2535,13 @@ static enum ggml_status ggml_metal_graph_compute(
         GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
                 "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
+        const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+        const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+        const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+        const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
         const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-        const int64_t ne31 = src3 ? src3->ne[1] : 0;
+        //const int64_t ne31 = src3 ? src3->ne[1] : 0;
         const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
         const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
 
@@ -2568,7 +2553,16 @@ static enum ggml_status ggml_metal_graph_compute(
         const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
 
         float scale;
-        memcpy(&scale, dst->op_params, sizeof(float));
+        float max_bias;
+
+        memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+        memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+        const uint32_t n_head = src0->ne[2];
+        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+        const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
         id<MTLComputePipelineState> pipeline = nil;
 
@@ -2605,34 +2599,38 @@ static enum ggml_status ggml_metal_graph_compute(
     }
 
     [encoder setComputePipelineState:pipeline];
-    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-    [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
-    [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
-    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
-    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
-    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
-    [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
-    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
-    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
-    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
-    [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
-    [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
-    [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
-    [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
-    [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
-    [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
-    [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
-    [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
-    [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
-    [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
-    [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
-    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
-    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
-    [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
-    [encoder setBytes:&scale length:sizeof( float) atIndex:27];
+    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+    if (id_src3) {
+        [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+    } else {
+        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+    }
+    [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+    [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+    [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+    [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+    [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+    [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+    [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+    [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+    [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+    [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+    [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+    [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+    [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+    [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+    [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+    [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+    [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+    [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+    [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+    [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+    [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+    [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+    [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
 
     if (!use_vec_kernel) {
         // half8x8 kernel
@@ -2840,7 +2838,11 @@ GGML_CALL static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_
     ggml_backend_metal_free_device();
 
     if (ctx->owned) {
+#if TARGET_OS_OSX
+        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ctx->all_data, ctx->all_size);
+#else
         free(ctx->all_data);
+#endif
     }
 
     free(ctx);
@@ -2944,14 +2946,16 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
     ctx->owned = true;
     ctx->n_buffers = 1;
 
-    ctx->buffers[0].data = ctx->all_data;
-    ctx->buffers[0].size = size;
-    ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
-                                                      length:size_aligned
-                                                     options:MTLResourceStorageModeShared
-                                                 deallocator:nil];
+    if (ctx->all_data != NULL) {
+        ctx->buffers[0].data = ctx->all_data;
+        ctx->buffers[0].size = size;
+        ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
+                                                          length:size_aligned
+                                                         options:MTLResourceStorageModeShared
+                                                     deallocator:nil];
+    }
 
-    if (ctx->buffers[0].metal == nil) {
+    if (ctx->all_data == NULL || ctx->buffers[0].metal == nil) {
         GGML_METAL_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
         ggml_backend_metal_free_device();