llama_cpp 0.15.1 → 0.15.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ce6d72aeb5fb9aff775d44284bf934e164f8470973619507ef6e6eb1ac0bec4d
4
- data.tar.gz: 7c1ae823c90f957219b3edbc20f091b65a50caa984c1a6f4d137a46c376b2f0c
3
+ metadata.gz: 30dd4c29b86098faf7c78de5fa8e57021b631bb5eb3d14c93f63f1d186383ab8
4
+ data.tar.gz: b011d891f1cd725f84821428a8db24004b52c9614e785f493f721f7abde71029
5
5
  SHA512:
6
- metadata.gz: d23cb6a63b7734df2547c5e61a699fa206878c747e274e004c829b77335a7cc7434e92168a55d8ab0a617b11eddb5d45d5057a91b92e848735fd9e852b2476cd
7
- data.tar.gz: f54b09de3cc60de81be977e9706a9beb3bf28e7740a19a57f6add543fe10cd6dc4101cbbe22dd5b62870c78a1ad4d10f57dd29b7c3e3e12b950e6575cf67b0c7
6
+ metadata.gz: 6c1628f93762747688f802db8593946e8581c869f63c610669b45759f644b3d19b061825b788e328b6b984977112837586ed398b6118a8f8e5f0c7f6fd0eb2dd
7
+ data.tar.gz: 2f8c3d9f1e6c0f6db7e0682995c8d34179d5405d32784bf00f04a3408cb5bf4c95557bfa1692026f8d3dc9e672d6b15dec5d33cbd76ddc1d94e5ec964a9d0409
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## [[0.15.2](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.1...v0.15.2)] - 2024-05-18
2
+
3
+ - Bump llama.cpp from b2839 to b2917.
4
+
5
+ The binding implementation for rpc_servers in llama_model_params has been skipped.
6
+
1
7
  ## [[0.15.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.0...v0.15.1)] - 2024-05-11
2
8
 
3
9
  - Bump llama.cpp from b2781 to b2839.
@@ -3,8 +3,8 @@
3
3
  # llama_cpp.rb provides Ruby bindings for the llama.cpp.
4
4
  module LLaMACpp
5
5
  # The version of llama_cpp.rb you install.
6
- VERSION = '0.15.1'
6
+ VERSION = '0.15.2'
7
7
 
8
8
  # The version of llama.cpp bundled with llama_cpp.rb.
9
- LLAMA_CPP_VERSION = 'b2839'
9
+ LLAMA_CPP_VERSION = 'b2917'
10
10
  end
@@ -562,10 +562,10 @@ endif # LLAMA_VULKAN
562
562
  ifdef LLAMA_HIPBLAS
563
563
  ifeq ($(wildcard /opt/rocm),)
564
564
  ROCM_PATH ?= /usr
565
- GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
565
+ AMDGPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
566
566
  else
567
567
  ROCM_PATH ?= /opt/rocm
568
- GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
568
+ AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
569
569
  endif
570
570
  HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc
571
571
  LLAMA_CUDA_DMMV_X ?= 32
@@ -577,7 +577,7 @@ ifdef LLAMA_HIP_UMA
577
577
  endif # LLAMA_HIP_UMA
578
578
  MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
579
579
  MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
580
- HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
580
+ HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
581
581
  HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
582
582
  HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
583
583
  HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
@@ -1182,9 +1182,9 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
1182
1182
  static char * fmt_size(size_t size) {
1183
1183
  static char buffer[128];
1184
1184
  if (size >= 1024*1024) {
1185
- sprintf(buffer, "%zuM", size/1024/1024);
1185
+ snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024);
1186
1186
  } else {
1187
- sprintf(buffer, "%zuK", size/1024);
1187
+ snprintf(buffer, sizeof(buffer), "%zuK", size/1024);
1188
1188
  }
1189
1189
  return buffer;
1190
1190
  }
@@ -1895,7 +1895,6 @@ void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * t
1895
1895
 
1896
1896
  tensor->buffer = buffer;
1897
1897
  tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
1898
- tensor->backend = tensor->view_src->backend;
1899
1898
  ggml_backend_buffer_init_tensor(buffer, tensor);
1900
1899
  }
1901
1900
 
@@ -4,7 +4,6 @@
4
4
 
5
5
  #include "ggml-cuda/common.cuh"
6
6
  #include "ggml-cuda/acc.cuh"
7
- #include "ggml-cuda/alibi.cuh"
8
7
  #include "ggml-cuda/arange.cuh"
9
8
  #include "ggml-cuda/argsort.cuh"
10
9
  #include "ggml-cuda/binbcast.cuh"
@@ -2205,6 +2204,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2205
2204
  case GGML_UNARY_OP_RELU:
2206
2205
  ggml_cuda_op_relu(ctx, dst);
2207
2206
  break;
2207
+ case GGML_UNARY_OP_SIGMOID:
2208
+ ggml_cuda_op_sigmoid(ctx, dst);
2209
+ break;
2208
2210
  case GGML_UNARY_OP_HARDSIGMOID:
2209
2211
  ggml_cuda_op_hardsigmoid(ctx, dst);
2210
2212
  break;
@@ -2277,9 +2279,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
2277
2279
  case GGML_OP_ROPE:
2278
2280
  ggml_cuda_op_rope(ctx, dst);
2279
2281
  break;
2280
- case GGML_OP_ALIBI:
2281
- ggml_cuda_op_alibi(ctx, dst);
2282
- break;
2283
2282
  case GGML_OP_IM2COL:
2284
2283
  ggml_cuda_op_im2col(ctx, dst);
2285
2284
  break;
@@ -2559,7 +2558,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2559
2558
  }
2560
2559
 
2561
2560
  // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
2562
- if (cuda_graph_update_required) {
2561
+ if (use_cuda_graph && cuda_graph_update_required) {
2563
2562
  cuda_ctx->cuda_graph->number_consecutive_updates++;
2564
2563
  } else {
2565
2564
  cuda_ctx->cuda_graph->number_consecutive_updates = 0;
@@ -2714,12 +2713,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
2714
2713
  }
2715
2714
 
2716
2715
  GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
2716
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
2717
2717
  switch (op->op) {
2718
2718
  case GGML_OP_UNARY:
2719
2719
  switch (ggml_get_unary_op(op)) {
2720
2720
  case GGML_UNARY_OP_GELU:
2721
2721
  case GGML_UNARY_OP_SILU:
2722
2722
  case GGML_UNARY_OP_RELU:
2723
+ case GGML_UNARY_OP_SIGMOID:
2723
2724
  case GGML_UNARY_OP_HARDSIGMOID:
2724
2725
  case GGML_UNARY_OP_HARDSWISH:
2725
2726
  case GGML_UNARY_OP_GELU_QUICK:
@@ -2829,7 +2830,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
2829
2830
  case GGML_OP_DIAG_MASK_INF:
2830
2831
  case GGML_OP_SOFT_MAX:
2831
2832
  case GGML_OP_ROPE:
2832
- case GGML_OP_ALIBI:
2833
2833
  case GGML_OP_IM2COL:
2834
2834
  case GGML_OP_POOL_2D:
2835
2835
  case GGML_OP_SUM_ROWS:
@@ -2841,8 +2841,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
2841
2841
  case GGML_OP_ARANGE:
2842
2842
  case GGML_OP_TIMESTEP_EMBEDDING:
2843
2843
  case GGML_OP_LEAKY_RELU:
2844
- case GGML_OP_FLASH_ATTN_EXT:
2845
2844
  return true;
2845
+ case GGML_OP_FLASH_ATTN_EXT:
2846
+ #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
2847
+ return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
2848
+ #else
2849
+ if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
2850
+ return true;
2851
+ }
2852
+ return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
2853
+ #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
2846
2854
  default:
2847
2855
  return false;
2848
2856
  }
@@ -120,9 +120,16 @@ extern "C" {
120
120
  #ifndef __F16C__
121
121
  #define __F16C__
122
122
  #endif
123
+ #endif
124
+
125
+ // __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
126
+ #if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
123
127
  #ifndef __SSE3__
124
128
  #define __SSE3__
125
129
  #endif
130
+ #ifndef __SSSE3__
131
+ #define __SSSE3__
132
+ #endif
126
133
  #endif
127
134
 
128
135
  // 16-bit float
@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
1559
1559
  case GGML_OP_SOFT_MAX:
1560
1560
  {
1561
1561
  float scale;
1562
- memcpy(&scale, dst->op_params, sizeof(float));
1562
+ float max_bias;
1563
1563
 
1564
- #pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
1564
+ memcpy(&scale, (float *)dst->op_params + 0, sizeof(float));
1565
+ memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
1566
+
1567
+ #pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
1565
1568
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
1566
1569
  GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
1567
- GGML_ASSERT(src2 == nullptr);
1570
+
1571
+ #pragma message("TODO: add ALiBi support")
1572
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
1573
+ GGML_ASSERT(max_bias == 0.0f);
1568
1574
 
1569
1575
  ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
1570
1576
  } break;
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
40
40
  GGML_METAL_KERNEL_TYPE_CLAMP,
41
41
  GGML_METAL_KERNEL_TYPE_TANH,
42
42
  GGML_METAL_KERNEL_TYPE_RELU,
43
+ GGML_METAL_KERNEL_TYPE_SIGMOID,
43
44
  GGML_METAL_KERNEL_TYPE_GELU,
44
45
  GGML_METAL_KERNEL_TYPE_GELU_4,
45
46
  GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
169
170
  GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
170
171
  GGML_METAL_KERNEL_TYPE_ROPE_F32,
171
172
  GGML_METAL_KERNEL_TYPE_ROPE_F16,
172
- GGML_METAL_KERNEL_TYPE_ALIBI_F32,
173
173
  GGML_METAL_KERNEL_TYPE_IM2COL_F16,
174
174
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
175
175
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -494,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
494
494
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
495
495
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
496
496
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
497
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
497
498
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
498
499
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
499
500
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -623,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
623
624
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
624
625
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
625
626
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
626
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
627
627
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
628
628
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
629
629
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -633,14 +633,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
633
633
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
634
634
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
635
635
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
636
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
637
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
638
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
639
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
640
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
641
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
642
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
643
- GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
636
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
637
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
638
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
639
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
640
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
641
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
642
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
643
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
644
644
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
645
645
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
646
646
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -732,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
732
732
  switch (ggml_get_unary_op(op)) {
733
733
  case GGML_UNARY_OP_TANH:
734
734
  case GGML_UNARY_OP_RELU:
735
+ case GGML_UNARY_OP_SIGMOID:
735
736
  case GGML_UNARY_OP_GELU:
736
737
  case GGML_UNARY_OP_GELU_QUICK:
737
738
  case GGML_UNARY_OP_SILU:
@@ -759,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
759
760
  case GGML_OP_GROUP_NORM:
760
761
  return ctx->support_simdgroup_reduction;
761
762
  case GGML_OP_NORM:
762
- case GGML_OP_ALIBI:
763
763
  case GGML_OP_ROPE:
764
764
  case GGML_OP_IM2COL:
765
765
  return true;
@@ -772,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
772
772
  case GGML_OP_TIMESTEP_EMBEDDING:
773
773
  case GGML_OP_ARGSORT:
774
774
  case GGML_OP_LEAKY_RELU:
775
- case GGML_OP_FLASH_ATTN_EXT:
776
775
  return true;
776
+ case GGML_OP_FLASH_ATTN_EXT:
777
+ return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
777
778
  case GGML_OP_MUL_MAT:
778
779
  case GGML_OP_MUL_MAT_ID:
779
780
  return ctx->support_simdgroup_reduction &&
@@ -1194,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute(
1194
1195
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1195
1196
  } break;
1196
1197
  case GGML_OP_CLAMP:
1197
- {
1198
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1198
+ {
1199
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
1199
1200
 
1200
- float min;
1201
- float max;
1202
- memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1203
- memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1201
+ float min;
1202
+ float max;
1203
+ memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
1204
+ memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
1204
1205
 
1205
- [encoder setComputePipelineState:pipeline];
1206
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1207
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1208
- [encoder setBytes:&min length:sizeof(min) atIndex:2];
1209
- [encoder setBytes:&max length:sizeof(max) atIndex:3];
1206
+ [encoder setComputePipelineState:pipeline];
1207
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1208
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1209
+ [encoder setBytes:&min length:sizeof(min) atIndex:2];
1210
+ [encoder setBytes:&max length:sizeof(max) atIndex:3];
1210
1211
 
1211
- const int64_t n = ggml_nelements(dst);
1212
+ const int64_t n = ggml_nelements(dst);
1212
1213
 
1213
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1214
- } break;
1214
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1215
+ } break;
1215
1216
  case GGML_OP_UNARY:
1216
1217
  switch (ggml_get_unary_op(gf->nodes[i])) {
1217
1218
  // we are not taking into account the strides, so for now require contiguous tensors
@@ -1239,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
1239
1240
 
1240
1241
  const int64_t n = ggml_nelements(dst);
1241
1242
 
1243
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1244
+ } break;
1245
+ case GGML_UNARY_OP_SIGMOID:
1246
+ {
1247
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
1248
+
1249
+ [encoder setComputePipelineState:pipeline];
1250
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
1251
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
1252
+
1253
+ const int64_t n = ggml_nelements(dst);
1254
+
1242
1255
  [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1243
1256
  } break;
1244
1257
  case GGML_UNARY_OP_GELU:
@@ -1357,16 +1370,15 @@ static enum ggml_status ggml_metal_graph_compute(
1357
1370
  case GGML_OP_SOFT_MAX:
1358
1371
  {
1359
1372
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
1360
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
1361
1373
 
1362
1374
  int nth = 32; // SIMD width
1363
1375
 
1364
1376
  id<MTLComputePipelineState> pipeline = nil;
1365
1377
 
1366
- const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
1378
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
1367
1379
 
1368
1380
  if (ne00%4 == 0) {
1369
- while (nth < ne00/4 && nth < 256) {
1381
+ while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1370
1382
  nth *= 2;
1371
1383
  }
1372
1384
  if (use_f16) {
@@ -1375,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
1375
1387
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
1376
1388
  }
1377
1389
  } else {
1378
- while (nth < ne00 && nth < 1024) {
1390
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1379
1391
  nth *= 2;
1380
1392
  }
1381
1393
  if (use_f16) {
@@ -1394,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
1394
1406
  const int64_t nrows_x = ggml_nrows(src0);
1395
1407
  const int64_t nrows_y = src0->ne[1];
1396
1408
 
1397
- const uint32_t n_head_kv = nrows_x/nrows_y;
1398
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1409
+ const uint32_t n_head = nrows_x/nrows_y;
1410
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1399
1411
 
1400
1412
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1401
1413
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -1407,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
1407
1419
  } else {
1408
1420
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
1409
1421
  }
1410
- if (id_src2) {
1411
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
1412
- } else {
1413
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
1414
- }
1415
- [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
1416
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
1417
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
1418
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1419
- [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
1420
- [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
1421
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
1422
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
1423
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
1422
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1423
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
1424
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
1425
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
1426
+ [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
1427
+ [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
1428
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
1429
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
1430
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
1424
1431
  [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
1425
1432
 
1426
1433
  [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2225,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
2225
2232
 
2226
2233
  [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2227
2234
  } break;
2228
- case GGML_OP_ALIBI:
2229
- {
2230
- GGML_ASSERT((src0t == GGML_TYPE_F32));
2231
-
2232
- const int nth = MIN(1024, ne00);
2233
-
2234
- //const int n_past = ((int32_t *) dst->op_params)[0];
2235
- const int n_head = ((int32_t *) dst->op_params)[1];
2236
-
2237
- float max_bias;
2238
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2239
-
2240
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
2241
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
2242
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
2243
-
2244
- id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
2245
-
2246
- [encoder setComputePipelineState:pipeline];
2247
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2248
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2249
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2250
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
2251
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
2252
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
2253
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
2254
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
2255
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
2256
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
2257
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
2258
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
2259
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
2260
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
2261
- [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
2262
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
2263
- [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
2264
- [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
2265
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
2266
- [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
2267
- [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
2268
-
2269
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2270
- } break;
2271
2235
  case GGML_OP_ROPE:
2272
2236
  {
2273
2237
  GGML_ASSERT(ne10 == ne02);
@@ -2389,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
2389
2353
  {
2390
2354
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2391
2355
 
2392
- const int sf = dst->op_params[0];
2356
+ const float sf0 = (float)ne0/src0->ne[0];
2357
+ const float sf1 = (float)ne1/src0->ne[1];
2358
+ const float sf2 = (float)ne2/src0->ne[2];
2359
+ const float sf3 = (float)ne3/src0->ne[3];
2393
2360
 
2394
2361
  const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
2395
2362
 
@@ -2412,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
2412
2379
  [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
2413
2380
  [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
2414
2381
  [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
2415
- [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
2382
+ [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
2383
+ [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
2384
+ [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
2385
+ [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
2416
2386
 
2417
2387
  const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
2418
2388
 
@@ -2548,13 +2518,14 @@ static enum ggml_status ggml_metal_graph_compute(
2548
2518
  } break;
2549
2519
  case GGML_OP_FLASH_ATTN_EXT:
2550
2520
  {
2551
- GGML_ASSERT(ne00 % 4 == 0);
2521
+ GGML_ASSERT(ne00 % 4 == 0);
2522
+ GGML_ASSERT(ne11 % 32 == 0);
2523
+
2552
2524
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
2553
2525
 
2554
- struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2526
+ GGML_ASSERT(ggml_are_same_shape (src1, src2));
2555
2527
 
2556
- GGML_ASSERT(ggml_are_same_shape(src1, src2));
2557
- GGML_ASSERT(src3);
2528
+ struct ggml_tensor * src3 = gf->nodes[i]->src[3];
2558
2529
 
2559
2530
  size_t offs_src3 = 0;
2560
2531
 
@@ -2564,8 +2535,13 @@ static enum ggml_status ggml_metal_graph_compute(
2564
2535
  GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
2565
2536
  "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
2566
2537
 
2538
+ const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
2539
+ const uint64_t nb21 = src2 ? src2->nb[1] : 0;
2540
+ const uint64_t nb22 = src2 ? src2->nb[2] : 0;
2541
+ const uint64_t nb23 = src2 ? src2->nb[3] : 0;
2542
+
2567
2543
  const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
2568
- const int64_t ne31 = src3 ? src3->ne[1] : 0;
2544
+ //const int64_t ne31 = src3 ? src3->ne[1] : 0;
2569
2545
  const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
2570
2546
  const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
2571
2547
 
@@ -2577,7 +2553,16 @@ static enum ggml_status ggml_metal_graph_compute(
2577
2553
  const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
2578
2554
 
2579
2555
  float scale;
2580
- memcpy(&scale, dst->op_params, sizeof(float));
2556
+ float max_bias;
2557
+
2558
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
2559
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
2560
+
2561
+ const uint32_t n_head = src0->ne[2];
2562
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2563
+
2564
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2565
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2581
2566
 
2582
2567
  id<MTLComputePipelineState> pipeline = nil;
2583
2568
 
@@ -2614,34 +2599,38 @@ static enum ggml_status ggml_metal_graph_compute(
2614
2599
  }
2615
2600
 
2616
2601
  [encoder setComputePipelineState:pipeline];
2617
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2618
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2619
- [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2620
- [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2621
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2622
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
2623
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
2624
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
2625
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
2626
- [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
2627
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
2628
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
2629
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
2630
- [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
2631
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
2632
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
2633
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
2634
- [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
2635
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
2636
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
2637
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
2638
- [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
2639
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
2640
- [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
2641
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
2642
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
2643
- [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
2644
- [encoder setBytes:&scale length:sizeof( float) atIndex:27];
2602
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2603
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2604
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
2605
+ if (id_src3) {
2606
+ [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
2607
+ } else {
2608
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
2609
+ }
2610
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
2611
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
2612
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
2613
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
2614
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
2615
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
2616
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
2617
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
2618
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
2619
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
2620
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
2621
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
2622
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
2623
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
2624
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
2625
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
2626
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
2627
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
2628
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
2629
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
2630
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
2631
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
2632
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
2633
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
2645
2634
 
2646
2635
  if (!use_vec_kernel) {
2647
2636
  // half8x8 kernel