llama_cpp 0.15.3 → 0.15.4

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: d0a9cdf86695522e27b1e8d3ed485dfa6ab3a4fc23d9bd9e44bf8c3cb483c347
-   data.tar.gz: 5d97cec87f9b1df94f85f9e18dc46a1b8a4ec593c17d04e4bee0da3d28c34211
+   metadata.gz: 167132898a0cb63faaf4fd7583d9b988992ba7c5ec0f5602d5a158f04e0cdfa0
+   data.tar.gz: 8a65658eb93b9cf80d5ede554b15968c495f045c32e57cc96ed732c56330d25f
  SHA512:
-   metadata.gz: 71f26009b872db64d0d0d416153b5fbd6afb598617b701cb6342d099542c962f410bccddf80b77928bfd8ab8f017a749fbc1d2ed488139d806ef0e3cf75a0e42
-   data.tar.gz: 808c03f6664af65cadfea23071d0b55d459c119189346762ea9632156f7f35b8d1f0e594b356726fc26abdb1c81a3bce9d697b9ca2d6324c454a31f2a442f0d7
+   metadata.gz: 9625ac088c4d5c50cc51bbbcbc744cb7041766ccbb7a42a9cd1b80b29ebe64414d39875dea5d61a87025e239ad78be2a2ea4d3f85a187684321e409fc01a40fd
+   data.tar.gz: 6f68445f10765a4eb1124ed1cfd2afb7544d146823efad27b2b6955bb0ee822ae8b0f9cccb68777c8cb211f665a0e2531eba04a4240399af1101a5dbcd645ae9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+ ## [[0.15.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.3...v0.15.4)] - 2024-06-01
+
+ - Bump llama.cpp from b2988 to b3056.
+ - Add LLAMA_VOCAB_PRE_TYPE_SMAUG constant.
+ - Add `token_is_control?` method to `Model`.
+
  ## [[0.15.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.2...v0.15.3)] - 2024-05-25
 
  - Bump llama.cpp from b2917 to b2988.
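For orientation, a minimal sketch of the two Ruby-facing additions. The model path is a placeholder, and `Model.new`/`token_bos` come from the gem's existing API:

```ruby
require 'llama_cpp'

# Hypothetical GGUF file; substitute a real path.
model = LLaMACpp::Model.new(model_path: '/path/to/model.gguf',
                            params: LLaMACpp::ModelParams.new)

# New in 0.15.4: query whether a token id is a control token.
# BOS/EOS are typically control tokens, so this is usually true.
model.token_is_control?(model.token_bos)

# New pre-tokenizer type constant, an Integer like its siblings.
LLaMACpp::LLAMA_VOCAB_PRE_TYPE_SMAUG
```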
data/ext/llama_cpp/llama_cpp.cpp CHANGED
@@ -1536,6 +1536,7 @@ public:
    rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
    rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
    rb_define_method(rb_cLLaMAModel, "token_is_eog?", RUBY_METHOD_FUNC(_llama_model_token_is_eog), 1);
+   rb_define_method(rb_cLLaMAModel, "token_is_control?", RUBY_METHOD_FUNC(_llama_model_token_is_control), 1);
  }
 
  private:
@@ -1848,6 +1849,16 @@ private:
    LLaMAModelWrapper* ptr = get_llama_model(self);
    return llama_token_is_eog(ptr->model, token) ? Qtrue : Qfalse;
  }
+
+ static VALUE _llama_model_token_is_control(VALUE self, VALUE token_) {
+   if (!RB_INTEGER_TYPE_P(token_)) {
+     rb_raise(rb_eArgError, "token must be an integer");
+     return Qnil;
+   }
+   const llama_token token = NUM2INT(token_);
+   LLaMAModelWrapper* ptr = get_llama_model(self);
+   return llama_token_is_control(ptr->model, token) ? Qtrue : Qfalse;
+ }
  };
 
  const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -3482,6 +3493,7 @@ extern "C" void Init_llama_cpp(void) {
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_QWEN2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_QWEN2));
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_OLMO", INT2NUM(LLAMA_VOCAB_PRE_TYPE_OLMO));
  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DBRX", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DBRX));
+ rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_SMAUG", INT2NUM(LLAMA_VOCAB_PRE_TYPE_SMAUG));
 
  rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNDEFINED", INT2NUM(LLAMA_TOKEN_TYPE_UNDEFINED));
  rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_NORMAL", INT2NUM(LLAMA_TOKEN_TYPE_NORMAL));
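As the extension code above shows, the binding validates its argument before calling llama.cpp's `llama_token_is_control`, so a non-Integer raises instead of crashing. A short sketch, reusing the hypothetical `model` from the changelog example:

```ruby
# Only Integer token ids are accepted; anything else raises the
# ArgumentError produced by the C extension above.
model.token_is_control?(model.token_eos) # => true or false, per the vocabulary
model.token_is_control?('eos')           # raises ArgumentError ("token must be an integer")
```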
data/lib/llama_cpp/version.rb CHANGED
@@ -3,8 +3,8 @@
  # llama_cpp.rb provides Ruby bindings for the llama.cpp.
  module LLaMACpp
    # The version of llama_cpp.rb you install.
-   VERSION = '0.15.3'
+   VERSION = '0.15.4'
 
    # The version of llama.cpp bundled with llama_cpp.rb.
-   LLAMA_CPP_VERSION = 'b2988'
+   LLAMA_CPP_VERSION = 'b3056'
  end
data/sig/llama_cpp.rbs CHANGED
@@ -30,6 +30,7 @@ module LLaMACpp
  LLAMA_VOCAB_PRE_TYPE_QWEN2: Integer
  LLAMA_VOCAB_PRE_TYPE_OLMO: Integer
  LLAMA_VOCAB_PRE_TYPE_DBRX: Integer
+ LLAMA_VOCAB_PRE_TYPE_SMAUG: Integer
 
  LLAMA_FTYPE_ALL_F32: Integer
  LLAMA_FTYPE_MOSTLY_F16: Integer
@@ -159,6 +160,7 @@ module LLaMACpp
    def token_suffix: () -> Integer
    def token_eot: () -> Integer
    def token_is_eog?: (Integer) -> bool
+   def token_is_control?: (Integer) -> bool
  end
 
  class Timings
data/vendor/tmp/llama.cpp/Makefile CHANGED
@@ -443,6 +443,9 @@ endif # JETSON_EOL_MODULE_DETECT
  ifdef LLAMA_DEBUG
      MK_NVCCFLAGS += -lineinfo
  endif # LLAMA_DEBUG
+ ifdef LLAMA_CUDA_DEBUG
+     MK_NVCCFLAGS += --device-debug
+ endif # LLAMA_CUDA_DEBUG
  ifdef LLAMA_CUDA_NVCC
      NVCC = $(CCACHE) $(LLAMA_CUDA_NVCC)
  else
@@ -749,7 +752,7 @@ lib: llama.o ggml.o $(OBJS)
      ar rcs libllama.a $^
 
  clean:
-     rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
+     rm -vrf *.o tests/*.o *.so *.a *.dll *.dylib benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
      rm -vrf ggml-cuda/*.o
 
  #
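Two small build changes land in the bundled Makefile: `make LLAMA_CUDA_DEBUG=1` now appends `--device-debug` to the NVCC flags (device-side debug info, usable with cuda-gdb), and `make clean` additionally removes macOS `*.dylib` artifacts.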
data/vendor/tmp/llama.cpp/ggml-cuda.cu CHANGED
@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
      return id;
  }
 
+ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+     ggml_cuda_set_device(device);
+ #if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+     auto res = hipMallocManaged(ptr, size);
+     if (res == hipSuccess) {
+         // if error we "need" to know why...
+         CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+     }
+     return res;
+ #else
+     return cudaMalloc(ptr, size);
+ #endif
+ }
+
  static ggml_cuda_device_info ggml_cuda_init() {
  #ifdef __HIP_PLATFORM_AMD__
      // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
          size_t look_ahead_size = (size_t) (1.05 * size);
          look_ahead_size = 256 * ((look_ahead_size + 255)/256);
          ggml_cuda_set_device(device);
-         CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+         CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
          *actual_size = look_ahead_size;
          pool_size += look_ahead_size;
  #ifdef DEBUG_CUDA_MALLOC
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
      size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
      void * dev_ptr;
-     cudaError_t err = cudaMalloc(&dev_ptr, size);
+     cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
      if (err != cudaSuccess) {
          // clear the error
          cudaGetLastError();
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
      // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
      ggml_cuda_set_device(id);
      char * buf;
-     CUDA_CHECK(cudaMalloc(&buf, size));
+     CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
 
      // set padding to 0 to avoid possible NaN values
      if (size > original_size) {
@@ -1856,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
          }
      }
  #else
-     if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
          // there is no broadcast and src0, src1 are contiguous across dims 2, 3
          // use cublasGemmStridedBatchedEx
          CUBLAS_CHECK(
@@ -2510,9 +2524,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
      bool use_cuda_graph = true;
      bool cuda_graph_update_required = false;
-     // pointer to CUDA cpy kernel, which is required to identify
+     // vector of pointers to CUDA cpy kernels, which are required to identify
      // kernel parameters which need updated in the graph for each token
-     void * ggml_cuda_cpy_fn_ptr = nullptr;
+     std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
      if (cuda_ctx->cuda_graph->graph == nullptr) {
          if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2588,9 +2602,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
          if (node->op == GGML_OP_CPY) {
              // store the copy op parameter which changes with each token.
              cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-             if (ggml_cuda_cpy_fn_ptr == nullptr) {
-                 // store a pointer to the copy op CUDA kernel to identify it later
-                 ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+             // store a pointer to each copy op CUDA kernel to identify it later
+             void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+             if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                 ggml_cuda_cpy_fn_ptrs.push_back(ptr);
              }
          }
 
@@ -2720,7 +2735,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
      if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
          int k = 0;
          for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-             if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+             if (count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
                  char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
                  cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
                  CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
@@ -2871,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
          case GGML_OP_CONT:
          case GGML_OP_DIAG_MASK_INF:
          case GGML_OP_SOFT_MAX:
+             return true;
          case GGML_OP_ROPE:
+             return ggml_is_contiguous(op->src[0]);
          case GGML_OP_IM2COL:
          case GGML_OP_POOL_2D:
          case GGML_OP_SUM_ROWS:
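The common thread in these CUDA hunks: direct `cudaMalloc` calls are routed through the new `ggml_cuda_device_malloc` wrapper, which selects the target device and, in HIP builds with `GGML_HIP_UMA`, falls back to `hipMallocManaged` unified memory with a coarse-grain `hipMemAdvise` hint. The CUDA-graph path also tracks a vector of copy-kernel pointers instead of a single one, and `GGML_OP_ROPE` is now reported as supported only for contiguous inputs.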
@@ -144,6 +144,10 @@ extern "C" {
144
144
  #endif
145
145
  #endif
146
146
 
147
+ #if defined(__ARM_FEATURE_SVE)
148
+ #include <arm_sve.h>
149
+ #endif
150
+
147
151
  // 16-bit float
148
152
  // on Arm, we use __fp16
149
153
  // on x86, we use uint16_t
data/vendor/tmp/llama.cpp/ggml-kompute.cpp CHANGED
@@ -1597,7 +1597,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
          {
              GGML_ASSERT(ne00 == ne10);
 
-             // TODO: assert that dim2 and dim3 are contiguous
              GGML_ASSERT(ne12 % ne02 == 0);
              GGML_ASSERT(ne13 % ne03 == 0);
 
data/vendor/tmp/llama.cpp/ggml-metal.m CHANGED
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_MUL_ROW,
      GGML_METAL_KERNEL_TYPE_DIV,
      GGML_METAL_KERNEL_TYPE_DIV_ROW,
+     GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+     GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+     GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+     GGML_METAL_KERNEL_TYPE_REPEAT_I16,
      GGML_METAL_KERNEL_TYPE_SCALE,
      GGML_METAL_KERNEL_TYPE_SCALE_4,
      GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -184,9 +188,9 @@ enum ggml_metal_kernel_type {
      GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
      GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
      GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
      GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
      GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
      GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
      GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
+         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -634,9 +642,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
-         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
          GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
          case GGML_OP_ACC:
          case GGML_OP_MUL:
          case GGML_OP_DIV:
+         case GGML_OP_REPEAT:
          case GGML_OP_SCALE:
          case GGML_OP_CLAMP:
          case GGML_OP_SQR:
@@ -770,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
          case GGML_OP_LEAKY_RELU:
              return true;
          case GGML_OP_FLASH_ATTN_EXT:
+             if (op->src[0]->ne[0] == 256) {
+                 return false;
+             }
              return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
          case GGML_OP_MUL_MAT:
          case GGML_OP_MUL_MAT_ID:
@@ -976,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute(
              switch (dst->op) {
                  case GGML_OP_CONCAT:
                      {
-                         const int64_t nb = ne00;
-
                          id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
 
+                         const int32_t dim = ((int32_t *) dst->op_params)[0];
+
                          [encoder setComputePipelineState:pipeline];
                          [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                          [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1008,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
                          [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                          [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                          [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-                         [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+                         [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
 
                          const int nth = MIN(1024, ne0);
 
@@ -1018,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute(
                  case GGML_OP_MUL:
                  case GGML_OP_DIV:
                      {
+                         GGML_ASSERT(src0t == GGML_TYPE_F32);
+                         GGML_ASSERT(src1t == GGML_TYPE_F32);
+
                          const size_t offs = 0;
 
                          bool bcast_row = false;
 
-                         int64_t nb = ne00;
+                         int64_t nb = ne00; // used by the "row" kernels
 
                          id<MTLComputePipelineState> pipeline = nil;
 
@@ -1091,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute(
                              [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                          }
                      } break;
+                 case GGML_OP_REPEAT:
+                     {
+                         id<MTLComputePipelineState> pipeline;
+
+                         switch (src0t) {
+                             case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                             case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                             case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                             case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                             default: GGML_ASSERT(false);
+                         }
+
+                         [encoder setComputePipelineState:pipeline];
+                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                         [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                         [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                         [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                         [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                         [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                         [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                         [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                         [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                         [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                         [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+                         [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+                         [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+                         [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+                         [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+                         [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                         [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+                         [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+                         const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                         [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                     } break;
                  case GGML_OP_ACC:
                      {
                          GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1468,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
                      {
                          GGML_ASSERT(ne00 == ne10);
 
-                         // TODO: assert that dim2 and dim3 are contiguous
                          GGML_ASSERT(ne12 % ne02 == 0);
                          GGML_ASSERT(ne13 % ne03 == 0);
 
@@ -2136,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
                  case GGML_OP_RMS_NORM:
                      {
                          GGML_ASSERT(ne00 % 4 == 0);
+                         GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                          float eps;
                          memcpy(&eps, dst->op_params, sizeof(float));
@@ -2163,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
                  case GGML_OP_GROUP_NORM:
                      {
                          GGML_ASSERT(ne00 % 4 == 0);
+                         GGML_ASSERT(ggml_is_contiguous(src0));
 
                          //float eps;
                          //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2196,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
                      } break;
                  case GGML_OP_NORM:
                      {
+                         GGML_ASSERT(ggml_is_contiguous_1(src0));
+
                          float eps;
                          memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -2573,7 +2627,7 @@ static enum ggml_status ggml_metal_graph_compute(
                              case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                              case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                              case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                             case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                             //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                              default:
                                  {
                                      GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2640,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                          switch (ne00) {
                              case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                             case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                             //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
                              default:
                                  {
                                      GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
data/vendor/tmp/llama.cpp/ggml-metal.metal CHANGED
@@ -168,6 +168,53 @@ kernel void kernel_div(
      }
  }
 
+ template<typename T>
+ kernel void kernel_repeat(
+         device const char * src0,
+         device       char * dst,
+         constant  int64_t & ne00,
+         constant  int64_t & ne01,
+         constant  int64_t & ne02,
+         constant  int64_t & ne03,
+         constant uint64_t & nb00,
+         constant uint64_t & nb01,
+         constant uint64_t & nb02,
+         constant uint64_t & nb03,
+         constant  int64_t & ne0,
+         constant  int64_t & ne1,
+         constant  int64_t & ne2,
+         constant  int64_t & ne3,
+         constant uint64_t & nb0,
+         constant uint64_t & nb1,
+         constant uint64_t & nb2,
+         constant uint64_t & nb3,
+         uint3 tgpig[[threadgroup_position_in_grid]],
+         uint3 tpitg[[thread_position_in_threadgroup]],
+         uint3 ntg[[threads_per_threadgroup]]) {
+     const int64_t i3 = tgpig.z;
+     const int64_t i2 = tgpig.y;
+     const int64_t i1 = tgpig.x;
+
+     const int64_t i03 = i3 % ne03;
+     const int64_t i02 = i2 % ne02;
+     const int64_t i01 = i1 % ne01;
+
+     device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+     device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1;
+
+     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+         const int i00 = i0 % ne00;
+         *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+     }
+ }
+
+ typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+ template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+ template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
  // assumption: src1 is a row
  // broadcast src1 into src0
  kernel void kernel_add_row(
@@ -1720,13 +1767,13 @@ kernel void kernel_rope(
 
 
      const int64_t p = pos[i2];
-     const float theta_0 = (float)p;
+     const float theta_base = (float)p;
      const float inv_ndims = -1.f/n_dims;
 
      if (!is_neox) {
          for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+             const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-             const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
              float cos_theta, sin_theta;
              rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
@@ -1742,18 +1789,14 @@ kernel void kernel_rope(
      } else {
          for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
              if (ic < n_dims) {
-                 const int64_t ib = 0;
+                 const int64_t i0 = ic/2;
 
-                 // simplified from `(ib * n_dims + ic) * inv_ndims`
-                 const float cur_rot = inv_ndims*ic - ib;
-                 const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+                 const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
 
-                 const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
+                 const float theta = theta_base * pow(freq_base, inv_ndims*ic);
 
                  float cos_theta, sin_theta;
-                 rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-                 const int64_t i0 = ib*n_dims + ic/2;
+                 rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                  device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                  device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -2418,7 +2461,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
  template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
  template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
  template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
- template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+ //template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
 
  template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
  kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2739,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
  }
 
  template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
- template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+ //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
  kernel void kernel_cpy_f16_f16(
          device const half * src0,
@@ -3319,31 +3362,30 @@ kernel void kernel_concat(
          constant uint64_t & nb1,
          constant uint64_t & nb2,
          constant uint64_t & nb3,
+         constant  int32_t & dim,
          uint3 tgpig[[threadgroup_position_in_grid]],
          uint3 tpitg[[thread_position_in_threadgroup]],
          uint3 ntg[[threads_per_threadgroup]]) {
 
-     const int64_t i03 = tgpig.z;
-     const int64_t i02 = tgpig.y;
-     const int64_t i01 = tgpig.x;
+     const int64_t i3 = tgpig.z;
+     const int64_t i2 = tgpig.y;
+     const int64_t i1 = tgpig.x;
 
-     const int64_t i13 = i03 % ne13;
-     const int64_t i12 = i02 % ne12;
-     const int64_t i11 = i01 % ne11;
+     int64_t o[4] = {0, 0, 0, 0};
+     o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
 
-     device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-     device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-     device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+     device const float * x;
 
      for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-         if (i02 < ne02) {
-             ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
-             src0_ptr += ntg.x*nb00;
+         if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+             x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
          } else {
-             ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
-             src1_ptr += ntg.x*nb10;
+             x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
          }
-         dst_ptr += ntg.x*nb0;
+
+         device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+         *y = *x;
      }
  }
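Taken together, the Metal changes in this bump add a templated `kernel_repeat` so `GGML_OP_REPEAT` runs natively for f32/f16/i32/i16, teach `kernel_concat` to take a `dim` argument (the old kernel effectively concatenated along dim 2 only), and comment out the head-size-256 FlashAttention variants pending https://github.com/ggerganov/llama.cpp/issues/7261, with `ggml_metal_supports_op` rejecting `ne[0] == 256` so those cases fall back to other backends.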