llama_cpp 0.15.3 → 0.15.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +12 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- data/vendor/tmp/llama.cpp/Makefile +4 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +27 -10
- data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +65 -11
- data/vendor/tmp/llama.cpp/ggml-metal.metal +69 -27
- data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +338 -160
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +2 -0
- data/vendor/tmp/llama.cpp/ggml.c +145 -101
- data/vendor/tmp/llama.cpp/ggml.h +18 -3
- data/vendor/tmp/llama.cpp/llama.cpp +637 -249
- data/vendor/tmp/llama.cpp/llama.h +11 -5
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 167132898a0cb63faaf4fd7583d9b988992ba7c5ec0f5602d5a158f04e0cdfa0
+  data.tar.gz: 8a65658eb93b9cf80d5ede554b15968c495f045c32e57cc96ed732c56330d25f
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 9625ac088c4d5c50cc51bbbcbc744cb7041766ccbb7a42a9cd1b80b29ebe64414d39875dea5d61a87025e239ad78be2a2ea4d3f85a187684321e409fc01a40fd
+  data.tar.gz: 6f68445f10765a4eb1124ed1cfd2afb7544d146823efad27b2b6955bb0ee822ae8b0f9cccb68777c8cb211f665a0e2531eba04a4240399af1101a5dbcd645ae9
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
+## [[0.15.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.3...v0.15.4)] - 2024-06-01
+
+- Bump llama.cpp from b2988 to b3056.
+- Add LLAMA_VOCAB_PRE_TYPE_SMAUG constant.
+- Add `token_is_control?` method to `Model`.
+
 ## [[0.15.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.2...v0.15.3)] - 2024-05-25

 - Bump llama.cpp from b2917 to b2988.
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1536,6 +1536,7 @@ public:
     rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
     rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
     rb_define_method(rb_cLLaMAModel, "token_is_eog?", RUBY_METHOD_FUNC(_llama_model_token_is_eog), 1);
+    rb_define_method(rb_cLLaMAModel, "token_is_control?", RUBY_METHOD_FUNC(_llama_model_token_is_control), 1);
   }

 private:
@@ -1848,6 +1849,16 @@ private:
     LLaMAModelWrapper* ptr = get_llama_model(self);
     return llama_token_is_eog(ptr->model, token) ? Qtrue : Qfalse;
   }
+
+  static VALUE _llama_model_token_is_control(VALUE self, VALUE token_) {
+    if (!RB_INTEGER_TYPE_P(token_)) {
+      rb_raise(rb_eArgError, "token must be an integer");
+      return Qnil;
+    }
+    const llama_token token = NUM2INT(token_);
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return llama_token_is_control(ptr->model, token) ? Qtrue : Qfalse;
+  }
 };

 const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -3482,6 +3493,7 @@ extern "C" void Init_llama_cpp(void) {
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_QWEN2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_QWEN2));
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_OLMO", INT2NUM(LLAMA_VOCAB_PRE_TYPE_OLMO));
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DBRX", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DBRX));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_SMAUG", INT2NUM(LLAMA_VOCAB_PRE_TYPE_SMAUG));

   rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNDEFINED", INT2NUM(LLAMA_TOKEN_TYPE_UNDEFINED));
   rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_NORMAL", INT2NUM(LLAMA_TOKEN_TYPE_NORMAL));
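The new `token_is_control?` binding is a thin wrapper over llama.cpp's llama_token_is_control(). A minimal C++ sketch of the kind of check it exposes (the helper name and usage context are illustrative, not part of the gem):

    #include "llama.h"

    // Hypothetical helper: decide whether a sampled token should be shown.
    // Control tokens (e.g. chat-template markers) carry structure, not text,
    // and are usually suppressed in user-facing output.
    static bool should_render(const struct llama_model * model, llama_token tok) {
        return !llama_token_is_control(model, tok);
    }

From Ruby the same test is `model.token_is_control?(token_id)`; as the binding above shows, it raises ArgumentError when the argument is not an integer.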
data/lib/llama_cpp/version.rb
CHANGED
@@ -3,8 +3,8 @@
 # llama_cpp.rb provides Ruby bindings for the llama.cpp.
 module LLaMACpp
   # The version of llama_cpp.rb you install.
-  VERSION = '0.15.3'
+  VERSION = '0.15.4'

   # The version of llama.cpp bundled with llama_cpp.rb.
-  LLAMA_CPP_VERSION = 'b2988'
+  LLAMA_CPP_VERSION = 'b3056'
 end
data/sig/llama_cpp.rbs
CHANGED
@@ -30,6 +30,7 @@ module LLaMACpp
   LLAMA_VOCAB_PRE_TYPE_QWEN2: Integer
   LLAMA_VOCAB_PRE_TYPE_OLMO: Integer
   LLAMA_VOCAB_PRE_TYPE_DBRX: Integer
+  LLAMA_VOCAB_PRE_TYPE_SMAUG: Integer

   LLAMA_FTYPE_ALL_F32: Integer
   LLAMA_FTYPE_MOSTLY_F16: Integer
@@ -159,6 +160,7 @@ module LLaMACpp
     def token_suffix: () -> Integer
     def token_eot: () -> Integer
     def token_is_eog?: (Integer) -> bool
+    def token_is_control?: (Integer) -> bool
   end

   class Timings
data/vendor/tmp/llama.cpp/Makefile
CHANGED
@@ -443,6 +443,9 @@ endif # JETSON_EOL_MODULE_DETECT
 ifdef LLAMA_DEBUG
 	MK_NVCCFLAGS += -lineinfo
 endif # LLAMA_DEBUG
+ifdef LLAMA_CUDA_DEBUG
+	MK_NVCCFLAGS += --device-debug
+endif # LLAMA_CUDA_DEBUG
 ifdef LLAMA_CUDA_NVCC
 	NVCC = $(CCACHE) $(LLAMA_CUDA_NVCC)
 else
@@ -749,7 +752,7 @@ lib: llama.o ggml.o $(OBJS)
 	ar rcs libllama.a $^

 clean:
-	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
+	rm -vrf *.o tests/*.o *.so *.a *.dll *.dylib benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
 	rm -vrf ggml-cuda/*.o

 #
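With this hook, a device-debuggable CUDA build can be requested with something like `make LLAMA_CUDA=1 LLAMA_CUDA_DEBUG=1` (flag spelling per this vendored Makefile; `--device-debug` is nvcc's `-G`, which disables most device-code optimization so kernels can be stepped through in cuda-gdb).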
data/vendor/tmp/llama.cpp/ggml-cuda.cu
CHANGED
@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
     return id;
 }

+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+    ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+    return cudaMalloc(ptr, size);
+#endif
+}
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
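Every allocation site below funnels through this one wrapper, so the GGML_HIP_UMA build (APUs sharing system memory) is handled in a single place. A caller-side sketch, assuming the wrapper above (`n_bytes` and `device` are placeholder names):

    // Each former raw cudaMalloc call becomes a wrapper call that also
    // pins the allocation to an explicit device.
    void * ptr = nullptr;
    CUDA_CHECK(ggml_cuda_device_malloc(&ptr, n_bytes, device));

The hipMemAdviseSetCoarseGrain advice marks the managed range coarse-grained, relaxing fine-grained coherence in exchange for better GPU throughput.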
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
             size_t look_ahead_size = (size_t) (1.05 * size);
             look_ahead_size = 256 * ((look_ahead_size + 255)/256);
             ggml_cuda_set_device(device);
-            CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+            CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
             *actual_size = look_ahead_size;
             pool_size += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0

     void * dev_ptr;
-    cudaError_t err = cudaMalloc(&dev_ptr, size);
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
         // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
         ggml_cuda_set_device(id);
         char * buf;
-        CUDA_CHECK(cudaMalloc(&buf, size));
+        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));

         // set padding to 0 to avoid possible NaN values
         if (size > original_size) {
@@ -1856,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
@@ -2510,9 +2524,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t

     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // pointer to CUDA cpy kernel, which is required to identify
+    // vector of pointers to CUDA cpy kernels, which are required to identify
     // kernel parameters which need updated in the graph for each token
-    void * ggml_cuda_cpy_fn_ptr = nullptr;
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;

     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2588,9 +2602,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (node->op == GGML_OP_CPY) {
             // store the copy op parameter which changes with each token.
             cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-            if (ggml_cuda_cpy_fn_ptr == nullptr) {
-                // store a pointer to the copy op CUDA kernel to identify it later
-                ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            // store a pointer to each copy op CUDA kernel to identify it later
+            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+            if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                ggml_cuda_cpy_fn_ptrs.push_back(ptr);
             }
         }

@@ -2720,7 +2735,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
             int k = 0;
             for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
                     char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
                     cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
                     CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
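Taken together, the three graph hunks replace a single cached copy-kernel pointer with a deduplicated list: each distinct cpy kernel seen while building the graph is recorded once, and at replay time a node is patched only if its function pointer appears in that list. A self-contained sketch of the pattern (names are illustrative, not the ggml-cuda internals):

    #include <algorithm>
    #include <vector>

    // Record each distinct kernel function pointer once during capture.
    static void remember_fn(std::vector<void *> & fns, void * fn) {
        if (std::find(fns.begin(), fns.end(), fn) == fns.end()) {
            fns.push_back(fn);
        }
    }

    // During replay, only nodes running one of the recorded copy kernels
    // get their per-token argument slot rewritten.
    static bool needs_patch(const std::vector<void *> & fns, void * fn) {
        return std::count(fns.begin(), fns.end(), fn) > 0;
    }

A vector is needed because one graph can contain several distinct copy kernels (different src/dst type pairs), which a single cached pointer could not represent.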
@@ -2871,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
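Returning false from supports_op does not abort evaluation: when the graph runs through ggml-backend's scheduler, an unsupported node (here, a ROPE whose source is not contiguous) is assigned to another backend that can run it, typically the CPU backend.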
data/vendor/tmp/llama.cpp/ggml-kompute.cpp
CHANGED
@@ -1597,7 +1597,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                     {
                         GGML_ASSERT(ne00 == ne10);

-                        // TODO: assert that dim2 and dim3 are contiguous
                         GGML_ASSERT(ne12 % ne02 == 0);
                         GGML_ASSERT(ne13 % ne03 == 0);

data/vendor/tmp/llama.cpp/ggml-metal.m
CHANGED
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -184,9 +188,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -634,9 +642,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
         case GGML_OP_SQR:
@@ -770,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 256) {
+                return false;
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -976,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute(
                 switch (dst->op) {
                     case GGML_OP_CONCAT:
                         {
-                            const int64_t nb = ne00;
-
                             id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;

+                            const int32_t dim = ((int32_t *) dst->op_params)[0];
+
                             [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1008,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
                             [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                             [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                             [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-                            [encoder setBytes:&nb  length:sizeof(nb)  atIndex:27];
+                            [encoder setBytes:&dim length:sizeof(dim) atIndex:27];

                             const int nth = MIN(1024, ne0);

@@ -1018,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                     case GGML_OP_DIV:
                         {
+                            GGML_ASSERT(src0t == GGML_TYPE_F32);
+                            GGML_ASSERT(src1t == GGML_TYPE_F32);
+
                             const size_t offs = 0;

                             bool bcast_row = false;

-                            int64_t nb = ne00;
+                            int64_t nb = ne00; // used by the "row" kernels

                             id<MTLComputePipelineState> pipeline = nil;

@@ -1091,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
                         } break;
+                    case GGML_OP_REPEAT:
+                        {
+                            id<MTLComputePipelineState> pipeline;
+
+                            switch (src0t) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                                case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                                case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                                default: GGML_ASSERT(false);
+                            }
+
+                            [encoder setComputePipelineState:pipeline];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+                            [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+                            [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+                            [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+                            [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+                            [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                            [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+                            [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+                            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_ACC:
                         {
                             GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1468,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
                         {
                             GGML_ASSERT(ne00 == ne10);

-                            // TODO: assert that dim2 and dim3 are contiguous
                             GGML_ASSERT(ne12 % ne02 == 0);
                             GGML_ASSERT(ne13 % ne03 == 0);

@@ -2136,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     case GGML_OP_RMS_NORM:
                         {
                             GGML_ASSERT(ne00 % 4 == 0);
+                            GGML_ASSERT(ggml_is_contiguous_1(src0));

                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
@@ -2163,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     case GGML_OP_GROUP_NORM:
                         {
                             GGML_ASSERT(ne00 % 4 == 0);
+                            GGML_ASSERT(ggml_is_contiguous(src0));

                             //float eps;
                             //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2196,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_NORM:
                         {
+                            GGML_ASSERT(ggml_is_contiguous_1(src0));
+
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));

@@ -2573,7 +2627,7 @@ static enum ggml_status ggml_metal_graph_compute(
                           case  96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                           case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                           case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                          case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                           default:
                                     {
                                         GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2640,7 @@ static enum ggml_status ggml_metal_graph_compute(

                             switch (ne00) {
                                 case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
                                 default:
                                           {
                                               GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
data/vendor/tmp/llama.cpp/ggml-metal.metal
CHANGED
@@ -168,6 +168,53 @@ kernel void kernel_div(
     }
 }

+template<typename T>
+kernel void kernel_repeat(
+    device const char * src0,
+    device       char * dst,
+    constant  int64_t & ne00,
+    constant  int64_t & ne01,
+    constant  int64_t & ne02,
+    constant  int64_t & ne03,
+    constant uint64_t & nb00,
+    constant uint64_t & nb01,
+    constant uint64_t & nb02,
+    constant uint64_t & nb03,
+    constant  int64_t & ne0,
+    constant  int64_t & ne1,
+    constant  int64_t & ne2,
+    constant  int64_t & ne3,
+    constant uint64_t & nb0,
+    constant uint64_t & nb1,
+    constant uint64_t & nb2,
+    constant uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
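The kernel assigns one threadgroup per output coordinate (i1, i2, i3) and lets the threads of the group stride across i0; every destination index wraps back into the source with a modulo, which is ggml's repeat/broadcast rule. A plain C++ reference of the same indexing (a sketch for intuition only; element indexing instead of byte strides):

    #include <cstdint>
    #include <vector>

    // dst (ne0 x ne1 x ne2 x ne3) tiles src (ne00 x ne01 x ne02 x ne03):
    // each dst coordinate maps into src by taking it modulo src's extent.
    static void repeat_ref(const std::vector<float> & src, std::vector<float> & dst,
                           int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
                           int64_t ne0,  int64_t ne1,  int64_t ne2,  int64_t ne3) {
        for (int64_t i3 = 0; i3 < ne3; ++i3)
        for (int64_t i2 = 0; i2 < ne2; ++i2)
        for (int64_t i1 = 0; i1 < ne1; ++i1)
        for (int64_t i0 = 0; i0 < ne0; ++i0) {
            const int64_t s = (((i3 % ne03)*ne02 + (i2 % ne02))*ne01 + (i1 % ne01))*ne00 + (i0 % ne00);
            dst[((i3*ne2 + i2)*ne1 + i1)*ne0 + i0] = src[s];
        }
    }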
@@ -1720,13 +1767,13 @@ kernel void kernel_rope(

         const int64_t p = pos[i2];

-        const float theta_0 = (float)p;
+        const float theta_base = (float)p;
         const float inv_ndims = -1.f/n_dims;

         if (!is_neox) {
             for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+                const float theta = theta_base * pow(freq_base, inv_ndims*i0);

-                const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
                 float cos_theta, sin_theta;
                 rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

@@ -1742,18 +1789,14 @@ kernel void kernel_rope(
         } else {
             for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
                 if (ic < n_dims) {
-                    const int64_t ib = 0;
+                    const int64_t i0 = ic/2;

-                    // simplified from `(ib * n_dims + ic) * inv_ndims`
-                    const float cur_rot = inv_ndims*ic - ib;
-                    const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+                    const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;

-                    const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
+                    const float theta = theta_base * pow(freq_base, inv_ndims*ic);

                     float cos_theta, sin_theta;
-                    rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-                    const int64_t i0 = ib*n_dims + ic/2;
+                    rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

                     device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                     device       T * dst_data  = (device T *)((device char *) dst  +  i3*nb3 +  i2*nb2 +  i1*nb1 +  i0*nb0);
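The rewrite folds the old ib/cur_rot bookkeeping into a direct computation: for rotated pair index ic at position p, the angle is p * freq_base^(-ic/n_dims), divided by an optional per-frequency factor from src2 before being handed to rope_yarn. A CPU-side sketch mirroring the updated kernel (reference only, not shipped code):

    #include <cmath>

    // Angle fed to rope_yarn for pair index `ic` at sequence position `p`.
    static float rope_theta(int p, long ic, int n_dims, float freq_base, float freq_factor) {
        const float inv_ndims  = -1.0f / n_dims;
        const float theta_base = (float) p;
        const float theta      = theta_base * std::pow(freq_base, inv_ndims * (float) ic);
        return theta / freq_factor; // freq_factor is 1.0f when no src2 is given
    }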
@@ -2418,7 +2461,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
 template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;

 template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2739,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 }

 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;

 kernel void kernel_cpy_f16_f16(
     device const half * src0,
@@ -3319,31 +3362,30 @@ kernel void kernel_concat(
     constant uint64_t & nb1,
     constant uint64_t & nb2,
     constant uint64_t & nb3,
+    constant  int32_t & dim,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3 ntg[[threads_per_threadgroup]]) {

-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;

-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));

-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+    device const float * x;

     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        if (i02 < ne02) {
-            ((device float *)dst_ptr)[0] = ((device const float *)src0_ptr)[0];
-            src0_ptr += ntg.x*nb00;
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
         } else {
-            ((device float *)dst_ptr)[0] = ((device const float *)src1_ptr)[0];
-            src1_ptr += ntg.x*nb10;
+            x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
         }
-        dst_ptr += ntg.x*nb0;
+
+        device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
     }
 }
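The generalized concat works by offset: o[dim] holds src0's extent along the concat axis, so a destination coordinate inside src0's box reads from src0, and anything beyond it reads from src1 after shifting back by that offset. The routing logic in a compact C++ sketch (element coordinates instead of byte strides; names are illustrative):

    #include <cstdint>

    // For dst coordinate c[4] and concat axis `dim`, report which input
    // supplies the element and what that input's coordinate is.
    // ne0 = src0 extents; o[dim] = ne0[dim] and 0 elsewhere.
    static int concat_route(const int64_t c[4], const int64_t ne0[4],
                            const int64_t o[4], int64_t src_coord[4]) {
        const bool in0 = c[0] < ne0[0] && c[1] < ne0[1] && c[2] < ne0[2] && c[3] < ne0[3];
        for (int d = 0; d < 4; ++d) {
            src_coord[d] = in0 ? c[d] : c[d] - o[d];
        }
        return in0 ? 0 : 1; // 0 -> read src0, 1 -> read src1
    }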