llama_cpp 0.15.1 → 0.15.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: ce6d72aeb5fb9aff775d44284bf934e164f8470973619507ef6e6eb1ac0bec4d
-  data.tar.gz: 7c1ae823c90f957219b3edbc20f091b65a50caa984c1a6f4d137a46c376b2f0c
+  metadata.gz: 30dd4c29b86098faf7c78de5fa8e57021b631bb5eb3d14c93f63f1d186383ab8
+  data.tar.gz: b011d891f1cd725f84821428a8db24004b52c9614e785f493f721f7abde71029
 SHA512:
-  metadata.gz: d23cb6a63b7734df2547c5e61a699fa206878c747e274e004c829b77335a7cc7434e92168a55d8ab0a617b11eddb5d45d5057a91b92e848735fd9e852b2476cd
-  data.tar.gz: f54b09de3cc60de81be977e9706a9beb3bf28e7740a19a57f6add543fe10cd6dc4101cbbe22dd5b62870c78a1ad4d10f57dd29b7c3e3e12b950e6575cf67b0c7
+  metadata.gz: 6c1628f93762747688f802db8593946e8581c869f63c610669b45759f644b3d19b061825b788e328b6b984977112837586ed398b6118a8f8e5f0c7f6fd0eb2dd
+  data.tar.gz: 2f8c3d9f1e6c0f6db7e0682995c8d34179d5405d32784bf00f04a3408cb5bf4c95557bfa1692026f8d3dc9e672d6b15dec5d33cbd76ddc1d94e5ec964a9d0409
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+## [[0.15.2](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.1...v0.15.2)] - 2024-05-18
+
+- Bump llama.cpp from b2839 to b2917.
+
+Implementing the binding for rpc_servers in llama_model_params has been skipped.
+
 ## [[0.15.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.0...v0.15.1)] - 2024-05-11
 
 - Bump llama.cpp from b2781 to b2839.
data/lib/llama_cpp/version.rb CHANGED
@@ -3,8 +3,8 @@
 # llama_cpp.rb provides Ruby bindings for llama.cpp.
 module LLaMACpp
   # The version of llama_cpp.rb you install.
-  VERSION = '0.15.1'
+  VERSION = '0.15.2'
 
   # The version of llama.cpp bundled with llama_cpp.rb.
-  LLAMA_CPP_VERSION = 'b2839'
+  LLAMA_CPP_VERSION = 'b2917'
 end
Makefile CHANGED
@@ -562,10 +562,10 @@ endif # LLAMA_VULKAN
 ifdef LLAMA_HIPBLAS
 ifeq ($(wildcard /opt/rocm),)
     ROCM_PATH ?= /usr
-    GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
+    AMDGPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
 else
     ROCM_PATH ?= /opt/rocm
-    GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
+    AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
 endif
     HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc
     LLAMA_CUDA_DMMV_X ?= 32
@@ -577,7 +577,7 @@ ifdef LLAMA_HIP_UMA
 endif # LLAMA_HIP_UMA
     MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
     MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
-    HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
+    HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
     HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
     HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
     HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
ggml-backend.c CHANGED
@@ -1182,9 +1182,9 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
 static char * fmt_size(size_t size) {
     static char buffer[128];
     if (size >= 1024*1024) {
-        sprintf(buffer, "%zuM", size/1024/1024);
+        snprintf(buffer, sizeof(buffer), "%zuM", size/1024/1024);
     } else {
-        sprintf(buffer, "%zuK", size/1024);
+        snprintf(buffer, sizeof(buffer), "%zuK", size/1024);
     }
     return buffer;
 }
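
The change above replaces unbounded sprintf calls with snprintf, which never writes more than the given buffer size (including the terminating NUL) and returns the length it would have needed. A minimal standalone demo of that contract (buffer size and values here are illustrative, not from the source):

    #include <stdio.h>

    int main(void) {
        char buf[8];
        // snprintf truncates instead of overflowing; a return value >=
        // sizeof(buf) signals that truncation happened.
        int n = snprintf(buf, sizeof(buf), "%zuM", (size_t)123456789);
        if (n >= (int) sizeof(buf)) {
            // buf holds the truncated, still NUL-terminated prefix
        }
        printf("%s (needed %d chars)\n", buf, n);
        return 0;
    }
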
@@ -1895,7 +1895,6 @@ void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * t
 
     tensor->buffer = buffer;
     tensor->data = (char *)tensor->view_src->data + tensor->view_offs;
-    tensor->backend = tensor->view_src->backend;
     ggml_backend_buffer_init_tensor(buffer, tensor);
 }
 
ggml-cuda.cu CHANGED
@@ -4,7 +4,6 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/alibi.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"
@@ -2205,6 +2204,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_RELU:
             ggml_cuda_op_relu(ctx, dst);
             break;
+        case GGML_UNARY_OP_SIGMOID:
+            ggml_cuda_op_sigmoid(ctx, dst);
+            break;
         case GGML_UNARY_OP_HARDSIGMOID:
             ggml_cuda_op_hardsigmoid(ctx, dst);
             break;
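
The new case routes GGML_UNARY_OP_SIGMOID to a dedicated CUDA kernel. For reference, this is the elementwise function the op computes; a plain CPU sketch (sigmoid_f32 is an illustrative helper, not the kernel itself):

    #include <math.h>

    // Elementwise sigmoid: dst[i] = 1 / (1 + exp(-src[i])).
    // The CUDA kernel applies the same function across the tensor on the GPU.
    static void sigmoid_f32(const float * src, float * dst, int n) {
        for (int i = 0; i < n; ++i) {
            dst[i] = 1.0f / (1.0f + expf(-src[i]));
        }
    }
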
@@ -2277,9 +2279,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
-        case GGML_OP_ALIBI:
-            ggml_cuda_op_alibi(ctx, dst);
-            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -2559,7 +2558,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         }
 
         // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
-        if (cuda_graph_update_required) {
+        if (use_cuda_graph && cuda_graph_update_required) {
             cuda_ctx->cuda_graph->number_consecutive_updates++;
         } else {
             cuda_ctx->cuda_graph->number_consecutive_updates = 0;
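
The added use_cuda_graph check makes the consecutive-update counter advance only while graph capture is actually in use, so runs that already fell back to normal launches no longer keep incrementing it. A schematic sketch of the heuristic (the struct, function name, and threshold value are illustrative simplifications, not the source):

    #include <stdbool.h>

    #define MAX_CONSECUTIVE_UPDATES 4 // illustrative threshold

    struct cuda_graph_state {
        int  number_consecutive_updates;
        bool disable_due_to_too_many_updates;
    };

    // Count graph rebuilds only while graphs are in use; reset otherwise.
    // Too many rebuilds in a row means graphs cost more than they save,
    // so they get disabled from the next token on.
    static void track_graph_updates(struct cuda_graph_state * g,
                                    bool use_cuda_graph, bool update_required) {
        if (use_cuda_graph && update_required) {
            g->number_consecutive_updates++;
        } else {
            g->number_consecutive_updates = 0;
        }
        if (g->number_consecutive_updates >= MAX_CONSECUTIVE_UPDATES) {
            g->disable_due_to_too_many_updates = true;
        }
    }
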
@@ -2714,12 +2713,14 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
     }
 
 GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
@@ -2829,7 +2830,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
-        case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
@@ -2841,8 +2841,16 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_FLASH_ATTN_EXT:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+            return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128;
+#else
+            if (op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA;
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
     }
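
FLASH_ATTN_EXT is now admitted only for head sizes the kernels actually cover: 64 and 128 everywhere, and other sizes only on hardware at or above Volta. A standalone sketch of the same predicate (the function name and the device_cc/is_hip_amd parameters are illustrative; CC_VOLTA mirrors the backend's compute-capability constant):

    #include <stdbool.h>
    #include <stdint.h>

    #define CC_VOLTA 700 // compute capability 7.0, as in the CUDA backend headers

    // head_size corresponds to op->src[0]->ne[0]; device_cc stands in for
    // ggml_cuda_info().devices[device].cc.
    static bool supports_flash_attn_ext(int64_t head_size, int device_cc, bool is_hip_amd) {
        if (is_hip_amd) {
            // HIP/AMD builds only ship kernels for head sizes 64 and 128
            return head_size == 64 || head_size == 128;
        }
        if (head_size == 64 || head_size == 128) {
            return true;
        }
        // the remaining head sizes rely on tensor-core capable hardware
        return device_cc >= CC_VOLTA;
    }
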
ggml-impl.h CHANGED
@@ -120,9 +120,16 @@ extern "C" {
 #ifndef __F16C__
 #define __F16C__
 #endif
+#endif
+
+// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
+#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
 #ifndef __SSE3__
 #define __SSE3__
 #endif
+#ifndef __SSSE3__
+#define __SSSE3__
+#endif
 #endif
 
 // 16-bit float
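
MSVC never defines the GCC/Clang-style __SSE3__/__SSSE3__ macros, but compiling with /arch:AVX or higher guarantees those instruction sets, so the header synthesizes the macros itself and downstream SIMD code can gate on a single spelling. A condensed sketch of the pattern (the usage gate at the bottom is illustrative):

    // With MSVC, /arch:AVX (and above) defines __AVX__/__AVX2__/__AVX512F__,
    // and AVX-capable CPUs always support SSE3/SSSE3, so provide the
    // GCC/Clang macro spelling that the SIMD code paths test for.
    #if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
    #ifndef __SSSE3__
    #define __SSSE3__
    #endif
    #endif

    #if defined(__SSSE3__)
    // e.g. take the _mm_shuffle_epi8-based code paths
    #endif
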
ggml-kompute.cpp CHANGED
@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
         case GGML_OP_SOFT_MAX:
             {
                 float scale;
-                memcpy(&scale, dst->op_params, sizeof(float));
+                float max_bias;
 
-#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+                memcpy(&scale,    (float *)dst->op_params + 0, sizeof(float));
+                memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
 #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
                 GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-                GGML_ASSERT(src2 == nullptr);
+
+#pragma message("TODO: add ALiBi support")
+#pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+                GGML_ASSERT(max_bias == 0.0f);
 
                 ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
             } break;
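
ggml packs an op's scalar parameters into a small int32 array on the tensor, so the soft-max scale and the new max_bias are bit-copied out of slots 0 and 1 with memcpy, which sidesteps strict-aliasing problems. A self-contained sketch of that round trip (the op_params array here is a stand-in for the tensor field):

    #include <stdint.h>
    #include <string.h>

    int main(void) {
        int32_t op_params[16] = {0}; // stand-in for ggml_tensor::op_params

        // write two floats into consecutive 32-bit slots...
        float scale_in = 0.125f, max_bias_in = 8.0f;
        memcpy((float *) op_params + 0, &scale_in,    sizeof(float));
        memcpy((float *) op_params + 1, &max_bias_in, sizeof(float));

        // ...and read them back exactly as the backend does
        float scale, max_bias;
        memcpy(&scale,    (float *) op_params + 0, sizeof(float));
        memcpy(&max_bias, (float *) op_params + 1, sizeof(float));

        return !(scale == 0.125f && max_bias == 8.0f);
    }
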
ggml-metal.m CHANGED
@@ -40,6 +40,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CLAMP,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_4,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
@@ -169,7 +170,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
-    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -494,6 +494,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
@@ -623,7 +624,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -633,14 +633,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -732,6 +732,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
@@ -759,7 +760,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_GROUP_NORM:
             return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
-        case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;
@@ -772,8 +772,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
-        case GGML_OP_FLASH_ATTN_EXT:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -1194,24 +1195,24 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_CLAMP:
-        {
-            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CLAMP].pipeline;
 
-            float min;
-            float max;
-            memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
-            memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
+                float min;
+                float max;
+                memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
+                memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
 
-            [encoder setComputePipelineState:pipeline];
-            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-            [encoder setBytes:&min length:sizeof(min) atIndex:2];
-            [encoder setBytes:&max length:sizeof(max) atIndex:3];
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                [encoder setBytes:&min length:sizeof(min) atIndex:2];
+                [encoder setBytes:&max length:sizeof(max) atIndex:3];
 
-            const int64_t n = ggml_nelements(dst);
+                const int64_t n = ggml_nelements(dst);
 
-            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-        } break;
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(gf->nodes[i])) {
                 // we are not taking into account the strides, so for now require contiguous tensors
@@ -1239,6 +1240,18 @@ static enum ggml_status ggml_metal_graph_compute(
 
                 const int64_t n = ggml_nelements(dst);
 
+                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
+        case GGML_UNARY_OP_SIGMOID:
+            {
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+                const int64_t n = ggml_nelements(dst);
+
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_UNARY_OP_GELU:
@@ -1357,16 +1370,15 @@ static enum ggml_status ggml_metal_graph_compute(
         case GGML_OP_SOFT_MAX:
             {
                 GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-                GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
                 int nth = 32; // SIMD width
 
                 id<MTLComputePipelineState> pipeline = nil;
 
-                const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+                const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
                 if (ne00%4 == 0) {
-                    while (nth < ne00/4 && nth < 256) {
+                    while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
                         nth *= 2;
                     }
                     if (use_f16) {
@@ -1375,7 +1387,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4].pipeline;
                     }
                 } else {
-                    while (nth < ne00 && nth < 1024) {
+                    while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
                         nth *= 2;
                     }
                     if (use_f16) {
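
Both sizing loops now budget the whole dispatch rather than a single row: nth (threads per row) stops doubling once nth times the number of dispatched rows (ne01*ne02*ne03) reaches 256, so grids with many rows keep small threadgroups. An illustrative helper showing the rule (pick_nth is not a function in the source):

    // Grow threads-per-row by powers of two, starting from the 32-wide
    // SIMD group, but stop once threads-per-row times the number of
    // dispatched rows would exceed the 256-thread budget.
    static int pick_nth(int ne00, long long nrows) {
        int nth = 32; // SIMD width
        while (nth < ne00 && (long long) nth * nrows < 256) {
            nth *= 2;
        }
        return nth;
    }

With a single row this behaves much like the old per-row cap; with 64 or more rows the loop never grows nth past the initial 32.
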
@@ -1394,8 +1406,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 const int64_t nrows_x = ggml_nrows(src0);
                 const int64_t nrows_y = src0->ne[1];
 
-                const uint32_t n_head_kv = nrows_x/nrows_y;
-                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+                const uint32_t n_head      = nrows_x/nrows_y;
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
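
The rename from n_head_kv to n_head reflects what the row ratio actually is: the number of attention heads. m0 and m1 are the ALiBi base slopes; head h then gets slope m0^(h+1) for the first n_head_log2 heads and odd powers of m1 after that, matching ggml's CPU soft-max path. A worked sketch of that per-head slope (alibi_slope is an illustrative helper):

    #include <math.h>

    // Per-head ALiBi slope. The first n_head_log2 heads (largest power of
    // two <= n_head) use powers of m0 = 2^(-max_bias/n_head_log2); the
    // remaining heads use odd powers of m1 = 2^(-max_bias/2/n_head_log2).
    static float alibi_slope(int h, unsigned n_head, float max_bias) {
        const unsigned n_head_log2 = 1u << (unsigned) floorf(log2f((float) n_head));
        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return (unsigned) h < n_head_log2
            ? powf(m0, (float) (h + 1))
            : powf(m1, (float) (2*(h - (int) n_head_log2) + 1));
    }
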
@@ -1407,20 +1419,15 @@ static enum ggml_status ggml_metal_graph_compute(
                 } else {
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 }
-                if (id_src2) {
-                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
-                }
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
-                [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
-                [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
-                [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
-                [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
-                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:7];
+                [encoder setBytes:&m0 length:sizeof(m0) atIndex:8];
+                [encoder setBytes:&m1 length:sizeof(m1) atIndex:9];
+                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2225,49 +2232,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
-        case GGML_OP_ALIBI:
-            {
-                GGML_ASSERT((src0t == GGML_TYPE_F32));
-
-                const int nth = MIN(1024, ne00);
-
-                //const int n_past = ((int32_t *) dst->op_params)[0];
-                const int n_head = ((int32_t *) dst->op_params)[1];
-
-                float max_bias;
-                memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-                const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-                const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
-                [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
-                [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
-                [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
-                [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
-                [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
-                [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-            } break;
         case GGML_OP_ROPE:
             {
                 GGML_ASSERT(ne10 == ne02);
@@ -2389,7 +2353,10 @@ static enum ggml_status ggml_metal_graph_compute(
             {
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                const int sf = dst->op_params[0];
+                const float sf0 = (float)ne0/src0->ne[0];
+                const float sf1 = (float)ne1/src0->ne[1];
+                const float sf2 = (float)ne2/src0->ne[2];
+                const float sf3 = (float)ne3/src0->ne[3];
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
@@ -2412,7 +2379,10 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
                 [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
                 [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
-                [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+                [encoder setBytes:&sf0 length:sizeof(sf0) atIndex:18];
+                [encoder setBytes:&sf1 length:sizeof(sf1) atIndex:19];
+                [encoder setBytes:&sf2 length:sizeof(sf2) atIndex:20];
+                [encoder setBytes:&sf3 length:sizeof(sf3) atIndex:21];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
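UPSCALE now takes four per-axis float ratios (destination extent over source extent) instead of one integer factor, which permits non-integer and per-dimension scaling; each destination coordinate maps back to a source coordinate by dividing by the ratio. A 1-D nearest-neighbor sketch of that mapping (upscale_1d_f32 is illustrative, not the Metal kernel):

    #include <stdint.h>

    // With sf = n_dst / n_src, destination element i reads source element
    // floor(i / sf); the Metal kernel applies this independently per axis.
    static void upscale_1d_f32(const float * src, int64_t n_src,
                               float * dst, int64_t n_dst) {
        const float sf = (float) n_dst / (float) n_src;
        for (int64_t i = 0; i < n_dst; ++i) {
            int64_t j = (int64_t) ((float) i / sf);
            if (j > n_src - 1) j = n_src - 1; // guard float rounding at the edge
            dst[i] = src[j];
        }
    }
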
@@ -2548,13 +2518,14 @@ static enum ggml_status ggml_metal_graph_compute(
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                GGML_ASSERT(ne00 % 4 == 0);
+                GGML_ASSERT(ne00 % 4  == 0);
+                GGML_ASSERT(ne11 % 32 == 0);
+
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+                GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
-                GGML_ASSERT(ggml_are_same_shape(src1, src2));
-                GGML_ASSERT(src3);
+                struct ggml_tensor * src3 = gf->nodes[i]->src[3];
 
                 size_t offs_src3 = 0;
 
@@ -2564,8 +2535,13 @@ static enum ggml_status ggml_metal_graph_compute(
                 GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
                         "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
+                const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
+                const uint64_t nb21 = src2 ? src2->nb[1] : 0;
+                const uint64_t nb22 = src2 ? src2->nb[2] : 0;
+                const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+
                 const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-                const int64_t ne31 = src3 ? src3->ne[1] : 0;
+                //const int64_t ne31 = src3 ? src3->ne[1] : 0;
                 const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
                 const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
 
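The mask assert leans on GGML_PAD, ggml's round-up-to-a-multiple macro: the mask's second dimension must be at least the number of queries padded up to a multiple of 8. A sketch of the arithmetic (assuming, as in this use, that the alignment is a power of two):

    // Round x up to the next multiple of n (n a power of two):
    // GGML_PAD(5, 8) == 8, GGML_PAD(8, 8) == 8, GGML_PAD(9, 8) == 16,
    // so a mask sized with GGML_PAD(n_queries, 8) rows always satisfies
    // the kernel's padding requirement.
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
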
@@ -2577,7 +2553,16 @@ static enum ggml_status ggml_metal_graph_compute(
                 const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
 
                 float scale;
-                memcpy(&scale, dst->op_params, sizeof(float));
+                float max_bias;
+
+                memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                const uint32_t n_head      = src0->ne[2];
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                 id<MTLComputePipelineState> pipeline = nil;
 
@@ -2614,34 +2599,38 @@ static enum ggml_status ggml_metal_graph_compute(
                 }
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
-                [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
-                [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
-                [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
-                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
-                [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
-                [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
-                [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
-                [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
-                [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
-                [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
-                [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
-                [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
-                [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
-                [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
-                [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
-                [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
-                [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
-                [encoder setBytes:&scale length:sizeof( float) atIndex:27];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                if (id_src3) {
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
+                }
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+                [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+                [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+                [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+                [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+                [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+                [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+                [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+                [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+                [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+                [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+                [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+                [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+                [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+                [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+                [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+                [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+                [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+                [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+                [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+                [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel