llama_cpp 0.15.3 → 0.15.4

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: d0a9cdf86695522e27b1e8d3ed485dfa6ab3a4fc23d9bd9e44bf8c3cb483c347
- data.tar.gz: 5d97cec87f9b1df94f85f9e18dc46a1b8a4ec593c17d04e4bee0da3d28c34211
+ metadata.gz: 167132898a0cb63faaf4fd7583d9b988992ba7c5ec0f5602d5a158f04e0cdfa0
+ data.tar.gz: 8a65658eb93b9cf80d5ede554b15968c495f045c32e57cc96ed732c56330d25f
  SHA512:
- metadata.gz: 71f26009b872db64d0d0d416153b5fbd6afb598617b701cb6342d099542c962f410bccddf80b77928bfd8ab8f017a749fbc1d2ed488139d806ef0e3cf75a0e42
- data.tar.gz: 808c03f6664af65cadfea23071d0b55d459c119189346762ea9632156f7f35b8d1f0e594b356726fc26abdb1c81a3bce9d697b9ca2d6324c454a31f2a442f0d7
+ metadata.gz: 9625ac088c4d5c50cc51bbbcbc744cb7041766ccbb7a42a9cd1b80b29ebe64414d39875dea5d61a87025e239ad78be2a2ea4d3f85a187684321e409fc01a40fd
+ data.tar.gz: 6f68445f10765a4eb1124ed1cfd2afb7544d146823efad27b2b6955bb0ee822ae8b0f9cccb68777c8cb211f665a0e2531eba04a4240399af1101a5dbcd645ae9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+ ## [[0.15.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.3...v0.15.4)] - 2024-06-01
+
+ - Bump llama.cpp from b2988 to b3056.
+ - Add LLAMA_VOCAB_PRE_TYPE_SMAUG constant.
+ - Add `token_is_control?` method to `Model`.
+
  ## [[0.15.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.2...v0.15.3)] - 2024-05-25

  - Bump llama.cpp from b2917 to b2988.
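
The two user-facing additions above can be exercised from Ruby roughly as follows. This is a minimal sketch, not taken from the gem's documentation: the GGUF path is hypothetical, and it assumes the 0.15.x `Model.new(model_path:, params:)` constructor and the existing `token_bos` reader.

```ruby
require 'llama_cpp'

# Hypothetical model path; ModelParams defaults are assumed to be usable as-is.
model = LLaMACpp::Model.new(model_path: '/path/to/model.gguf', params: LLaMACpp::ModelParams.new)

# New in 0.15.4: ask the vocabulary whether a token is a control token
# (wraps llama_token_is_control from the bundled llama.cpp).
p model.token_is_control?(model.token_bos) # => true for typical vocabularies

# New pre-tokenizer type constant, exposed alongside the existing *_PRE_TYPE_* constants.
p LLaMACpp::LLAMA_VOCAB_PRE_TYPE_SMAUG
```
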
@@ -1536,6 +1536,7 @@ public:
   rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
   rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
   rb_define_method(rb_cLLaMAModel, "token_is_eog?", RUBY_METHOD_FUNC(_llama_model_token_is_eog), 1);
+  rb_define_method(rb_cLLaMAModel, "token_is_control?", RUBY_METHOD_FUNC(_llama_model_token_is_control), 1);
  }

  private:
@@ -1848,6 +1849,16 @@ private:
   LLaMAModelWrapper* ptr = get_llama_model(self);
   return llama_token_is_eog(ptr->model, token) ? Qtrue : Qfalse;
  }
+
+ static VALUE _llama_model_token_is_control(VALUE self, VALUE token_) {
+   if (!RB_INTEGER_TYPE_P(token_)) {
+     rb_raise(rb_eArgError, "token must be an integer");
+     return Qnil;
+   }
+   const llama_token token = NUM2INT(token_);
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return llama_token_is_control(ptr->model, token) ? Qtrue : Qfalse;
+ }
  };

  const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -3482,6 +3493,7 @@ extern "C" void Init_llama_cpp(void) {
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_QWEN2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_QWEN2));
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_OLMO", INT2NUM(LLAMA_VOCAB_PRE_TYPE_OLMO));
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DBRX", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DBRX));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_SMAUG", INT2NUM(LLAMA_VOCAB_PRE_TYPE_SMAUG));

   rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNDEFINED", INT2NUM(LLAMA_TOKEN_TYPE_UNDEFINED));
   rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_NORMAL", INT2NUM(LLAMA_TOKEN_TYPE_NORMAL));
@@ -3,8 +3,8 @@
  # llama_cpp.rb provides Ruby bindings for the llama.cpp.
  module LLaMACpp
    # The version of llama_cpp.rb you install.
-   VERSION = '0.15.3'
+   VERSION = '0.15.4'

    # The version of llama.cpp bundled with llama_cpp.rb.
-   LLAMA_CPP_VERSION = 'b2988'
+   LLAMA_CPP_VERSION = 'b3056'
  end
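
As a quick sanity check after upgrading, the two constants changed above can be inspected at runtime; a minimal sketch:

```ruby
require 'llama_cpp'

puts LLaMACpp::VERSION           # => "0.15.4"
puts LLaMACpp::LLAMA_CPP_VERSION # => "b3056"
```
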
data/sig/llama_cpp.rbs CHANGED
@@ -30,6 +30,7 @@ module LLaMACpp
   LLAMA_VOCAB_PRE_TYPE_QWEN2: Integer
   LLAMA_VOCAB_PRE_TYPE_OLMO: Integer
   LLAMA_VOCAB_PRE_TYPE_DBRX: Integer
+  LLAMA_VOCAB_PRE_TYPE_SMAUG: Integer

   LLAMA_FTYPE_ALL_F32: Integer
   LLAMA_FTYPE_MOSTLY_F16: Integer
@@ -159,6 +160,7 @@ module LLaMACpp
   def token_suffix: () -> Integer
   def token_eot: () -> Integer
   def token_is_eog?: (Integer) -> bool
+  def token_is_control?: (Integer) -> bool
  end

  class Timings
@@ -443,6 +443,9 @@ endif # JETSON_EOL_MODULE_DETECT
  ifdef LLAMA_DEBUG
    MK_NVCCFLAGS += -lineinfo
  endif # LLAMA_DEBUG
+ ifdef LLAMA_CUDA_DEBUG
+   MK_NVCCFLAGS += --device-debug
+ endif # LLAMA_CUDA_DEBUG
  ifdef LLAMA_CUDA_NVCC
    NVCC = $(CCACHE) $(LLAMA_CUDA_NVCC)
  else
@@ -749,7 +752,7 @@ lib: llama.o ggml.o $(OBJS)
   ar rcs libllama.a $^

  clean:
-  rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
+  rm -vrf *.o tests/*.o *.so *.a *.dll *.dylib benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
   rm -vrf ggml-cuda/*.o

  #
@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
   return id;
  }

+ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+   ggml_cuda_set_device(device);
+ #if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+   auto res = hipMallocManaged(ptr, size);
+   if (res == hipSuccess) {
+     // if error we "need" to know why...
+     CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+   }
+   return res;
+ #else
+   return cudaMalloc(ptr, size);
+ #endif
+ }
+
  static ggml_cuda_device_info ggml_cuda_init() {
  #ifdef __HIP_PLATFORM_AMD__
   // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
   size_t look_ahead_size = (size_t) (1.05 * size);
   look_ahead_size = 256 * ((look_ahead_size + 255)/256);
   ggml_cuda_set_device(device);
-  CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+  CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
   *actual_size = look_ahead_size;
   pool_size += look_ahead_size;
  #ifdef DEBUG_CUDA_MALLOC
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
   size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0

   void * dev_ptr;
-  cudaError_t err = cudaMalloc(&dev_ptr, size);
+  cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
   if (err != cudaSuccess) {
    // clear the error
    cudaGetLastError();
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
   // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
   ggml_cuda_set_device(id);
   char * buf;
-  CUDA_CHECK(cudaMalloc(&buf, size));
+  CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));

   // set padding to 0 to avoid possible NaN values
   if (size > original_size) {
@@ -1856,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
    }
   }
  #else
-  if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+  if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
   // there is no broadcast and src0, src1 are contiguous across dims 2, 3
   // use cublasGemmStridedBatchedEx
   CUBLAS_CHECK(
@@ -2510,9 +2524,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t

   bool use_cuda_graph = true;
   bool cuda_graph_update_required = false;
-  // pointer to CUDA cpy kernel, which is required to identify
+  // vector of pointers to CUDA cpy kernels, which are required to identify
   // kernel parameters which need updated in the graph for each token
-  void * ggml_cuda_cpy_fn_ptr = nullptr;
+  std::vector<void *> ggml_cuda_cpy_fn_ptrs;

   if (cuda_ctx->cuda_graph->graph == nullptr) {
    if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2588,9 +2602,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
   if (node->op == GGML_OP_CPY) {
    // store the copy op parameter which changes with each token.
    cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-   if (ggml_cuda_cpy_fn_ptr == nullptr) {
-     // store a pointer to the copy op CUDA kernel to identify it later
-     ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+   // store a pointer to each copy op CUDA kernel to identify it later
+   void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+   if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+     ggml_cuda_cpy_fn_ptrs.push_back(ptr);
    }
   }

@@ -2720,7 +2735,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
   if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
    int k = 0;
    for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-    if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+    if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
      char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
      cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
      CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
@@ -2871,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
   case GGML_OP_CONT:
   case GGML_OP_DIAG_MASK_INF:
   case GGML_OP_SOFT_MAX:
+   return true;
   case GGML_OP_ROPE:
+   return ggml_is_contiguous(op->src[0]);
   case GGML_OP_IM2COL:
   case GGML_OP_POOL_2D:
   case GGML_OP_SUM_ROWS:
@@ -144,6 +144,10 @@ extern "C" {
  #endif
  #endif

+ #if defined(__ARM_FEATURE_SVE)
+ #include <arm_sve.h>
+ #endif
+
  // 16-bit float
  // on Arm, we use __fp16
  // on x86, we use uint16_t
@@ -1597,7 +1597,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
  {
   GGML_ASSERT(ne00 == ne10);

-  // TODO: assert that dim2 and dim3 are contiguous
   GGML_ASSERT(ne12 % ne02 == 0);
   GGML_ASSERT(ne13 % ne03 == 0);

@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
   GGML_METAL_KERNEL_TYPE_MUL_ROW,
   GGML_METAL_KERNEL_TYPE_DIV,
   GGML_METAL_KERNEL_TYPE_DIV_ROW,
+  GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+  GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+  GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+  GGML_METAL_KERNEL_TYPE_REPEAT_I16,
   GGML_METAL_KERNEL_TYPE_SCALE,
   GGML_METAL_KERNEL_TYPE_SCALE_4,
   GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -184,9 +188,9 @@ enum ggml_metal_kernel_type {
   GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
   GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
   GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
   GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-  GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
   GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
   GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
   GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
+  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
+  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
+  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -634,9 +642,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
-  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
-  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+  //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
   GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
   case GGML_OP_ACC:
   case GGML_OP_MUL:
   case GGML_OP_DIV:
+  case GGML_OP_REPEAT:
   case GGML_OP_SCALE:
   case GGML_OP_CLAMP:
   case GGML_OP_SQR:
@@ -770,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
   case GGML_OP_LEAKY_RELU:
    return true;
   case GGML_OP_FLASH_ATTN_EXT:
+   if (op->src[0]->ne[0] == 256) {
+     return false;
+   }
    return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
   case GGML_OP_MUL_MAT:
   case GGML_OP_MUL_MAT_ID:
@@ -976,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute(
   switch (dst->op) {
    case GGML_OP_CONCAT:
     {
-     const int64_t nb = ne00;
-
      id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;

+     const int32_t dim = ((int32_t *) dst->op_params)[0];
+
      [encoder setComputePipelineState:pipeline];
      [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
      [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1008,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
      [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
      [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
      [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-     [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+     [encoder setBytes:&dim length:sizeof(dim) atIndex:27];

      const int nth = MIN(1024, ne0);

@@ -1018,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute(
    case GGML_OP_MUL:
    case GGML_OP_DIV:
     {
+     GGML_ASSERT(src0t == GGML_TYPE_F32);
+     GGML_ASSERT(src1t == GGML_TYPE_F32);
+
      const size_t offs = 0;

      bool bcast_row = false;

-     int64_t nb = ne00;
+     int64_t nb = ne00; // used by the "row" kernels

      id<MTLComputePipelineState> pipeline = nil;

@@ -1091,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute(
      [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
     }
    } break;
+   case GGML_OP_REPEAT:
+    {
+     id<MTLComputePipelineState> pipeline;
+
+     switch (src0t) {
+      case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+      case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+      case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+      case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+      default: GGML_ASSERT(false);
+     }
+
+     [encoder setComputePipelineState:pipeline];
+     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+     [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+     [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+     [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+     [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+     [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+     [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+     [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+     [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+     [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+     [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+     [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+     [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+     [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+     [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+     [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+     [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+     [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+    } break;
    case GGML_OP_ACC:
     {
      GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1468,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
     {
      GGML_ASSERT(ne00 == ne10);

-     // TODO: assert that dim2 and dim3 are contiguous
      GGML_ASSERT(ne12 % ne02 == 0);
      GGML_ASSERT(ne13 % ne03 == 0);

@@ -2136,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
    case GGML_OP_RMS_NORM:
     {
      GGML_ASSERT(ne00 % 4 == 0);
+     GGML_ASSERT(ggml_is_contiguous_1(src0));

      float eps;
      memcpy(&eps, dst->op_params, sizeof(float));
@@ -2163,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
    case GGML_OP_GROUP_NORM:
     {
      GGML_ASSERT(ne00 % 4 == 0);
+     GGML_ASSERT(ggml_is_contiguous(src0));

      //float eps;
      //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2196,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
    } break;
    case GGML_OP_NORM:
     {
+     GGML_ASSERT(ggml_is_contiguous_1(src0));
+
      float eps;
      memcpy(&eps, dst->op_params, sizeof(float));

@@ -2573,7 +2627,7 @@ static enum ggml_status ggml_metal_graph_compute(
      case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
      case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
      case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-     case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+     //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
      default:
       {
        GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2640,7 @@ static enum ggml_status ggml_metal_graph_compute(

      switch (ne00) {
      case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-     case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+     //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
      default:
       {
        GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -168,6 +168,53 @@ kernel void kernel_div(
   }
  }

+ template<typename T>
+ kernel void kernel_repeat(
+     device const char * src0,
+     device       char * dst,
+     constant  int64_t & ne00,
+     constant  int64_t & ne01,
+     constant  int64_t & ne02,
+     constant  int64_t & ne03,
+     constant uint64_t & nb00,
+     constant uint64_t & nb01,
+     constant uint64_t & nb02,
+     constant uint64_t & nb03,
+     constant  int64_t & ne0,
+     constant  int64_t & ne1,
+     constant  int64_t & ne2,
+     constant  int64_t & ne3,
+     constant uint64_t & nb0,
+     constant uint64_t & nb1,
+     constant uint64_t & nb2,
+     constant uint64_t & nb3,
+     uint3 tgpig[[threadgroup_position_in_grid]],
+     uint3 tpitg[[thread_position_in_threadgroup]],
+     uint3 ntg[[threads_per_threadgroup]]) {
+   const int64_t i3 = tgpig.z;
+   const int64_t i2 = tgpig.y;
+   const int64_t i1 = tgpig.x;
+
+   const int64_t i03 = i3 % ne03;
+   const int64_t i02 = i2 % ne02;
+   const int64_t i01 = i1 % ne01;
+
+   device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+   device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
+
+   for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+     const int i00 = i0 % ne00;
+     *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+   }
+ }
+
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
  // assumption: src1 is a row
  // broadcast src1 into src0
  kernel void kernel_add_row(
@@ -1720,13 +1767,13 @@ kernel void kernel_rope(

   const int64_t p = pos[i2];

-  const float theta_0 = (float)p;
+  const float theta_base = (float)p;
   const float inv_ndims = -1.f/n_dims;

   if (!is_neox) {
    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+     const float theta = theta_base * pow(freq_base, inv_ndims*i0);

-     const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
      float cos_theta, sin_theta;
      rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

@@ -1742,18 +1789,14 @@ kernel void kernel_rope(
   } else {
    for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
     if (ic < n_dims) {
-      const int64_t ib = 0;
+      const int64_t i0 = ic/2;

-      // simplified from `(ib * n_dims + ic) * inv_ndims`
-      const float cur_rot = inv_ndims*ic - ib;
-      const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+      const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;

-      const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
+      const float theta = theta_base * pow(freq_base, inv_ndims*ic);

       float cos_theta, sin_theta;
-      rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-      const int64_t i0 = ib*n_dims + ic/2;
+      rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

       device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
       device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -2418,7 +2461,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
  template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
  template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
  template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+ //template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;

  template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
  kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2739,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
  }

  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+ //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;

  kernel void kernel_cpy_f16_f16(
   device const half * src0,
@@ -3319,31 +3362,30 @@ kernel void kernel_concat(
   constant uint64_t & nb1,
   constant uint64_t & nb2,
   constant uint64_t & nb3,
+  constant  int32_t & dim,
   uint3 tgpig[[threadgroup_position_in_grid]],
   uint3 tpitg[[thread_position_in_threadgroup]],
   uint3 ntg[[threads_per_threadgroup]]) {

-  const int64_t i03 = tgpig.z;
-  const int64_t i02 = tgpig.y;
-  const int64_t i01 = tgpig.x;
+  const int64_t i3 = tgpig.z;
+  const int64_t i2 = tgpig.y;
+  const int64_t i1 = tgpig.x;

-  const int64_t i13 = i03 % ne13;
-  const int64_t i12 = i02 % ne12;
-  const int64_t i11 = i01 % ne11;
+  int64_t o[4] = {0, 0, 0, 0};
+  o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));

-  device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-  device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-  device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
+  device const float * x;

   for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-   if (i02 < ne02) {
-     ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
-     src0_ptr += ntg.x*nb00;
+   if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+     x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
    } else {
-     ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
-     src1_ptr += ntg.x*nb10;
+     x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
    }
-   dst_ptr += ntg.x*nb0;
+
+   device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+   *y = *x;
   }
  }