llama_cpp 0.15.3 → 0.15.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/ext/llama_cpp/llama_cpp.cpp +12 -0
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +2 -0
- data/vendor/tmp/llama.cpp/Makefile +4 -1
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +27 -10
- data/vendor/tmp/llama.cpp/ggml-impl.h +4 -0
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +0 -1
- data/vendor/tmp/llama.cpp/ggml-metal.m +65 -11
- data/vendor/tmp/llama.cpp/ggml-metal.metal +69 -27
- data/vendor/tmp/llama.cpp/ggml-quants.c +101 -11
- data/vendor/tmp/llama.cpp/ggml-rpc.cpp +75 -58
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +338 -160
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +2 -0
- data/vendor/tmp/llama.cpp/ggml.c +145 -101
- data/vendor/tmp/llama.cpp/ggml.h +18 -3
- data/vendor/tmp/llama.cpp/llama.cpp +637 -249
- data/vendor/tmp/llama.cpp/llama.h +11 -5
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 167132898a0cb63faaf4fd7583d9b988992ba7c5ec0f5602d5a158f04e0cdfa0
+  data.tar.gz: 8a65658eb93b9cf80d5ede554b15968c495f045c32e57cc96ed732c56330d25f
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 9625ac088c4d5c50cc51bbbcbc744cb7041766ccbb7a42a9cd1b80b29ebe64414d39875dea5d61a87025e239ad78be2a2ea4d3f85a187684321e409fc01a40fd
+  data.tar.gz: 6f68445f10765a4eb1124ed1cfd2afb7544d146823efad27b2b6955bb0ee822ae8b0f9cccb68777c8cb211f665a0e2531eba04a4240399af1101a5dbcd645ae9
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,9 @@
+## [[0.15.4](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.3...v0.15.4)] - 2024-06-01
+
+- Bump llama.cpp from b2988 to b3056.
+- Add LLAMA_VOCAB_PRE_TYPE_SMAUG constant.
+- Add `token_is_control?` method to `Model`.
+
 ## [[0.15.3](https://github.com/yoshoku/llama_cpp.rb/compare/v0.15.2...v0.15.3)] - 2024-05-25
 
 - Bump llama.cpp from b2917 to b2988.
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -1536,6 +1536,7 @@ public:
     rb_define_method(rb_cLLaMAModel, "token_suffix", RUBY_METHOD_FUNC(_llama_model_token_suffix), 0);
     rb_define_method(rb_cLLaMAModel, "token_eot", RUBY_METHOD_FUNC(_llama_model_token_eot), 0);
     rb_define_method(rb_cLLaMAModel, "token_is_eog?", RUBY_METHOD_FUNC(_llama_model_token_is_eog), 1);
+    rb_define_method(rb_cLLaMAModel, "token_is_control?", RUBY_METHOD_FUNC(_llama_model_token_is_control), 1);
   }
 
 private:
@@ -1848,6 +1849,16 @@ private:
     LLaMAModelWrapper* ptr = get_llama_model(self);
     return llama_token_is_eog(ptr->model, token) ? Qtrue : Qfalse;
   }
+
+  static VALUE _llama_model_token_is_control(VALUE self, VALUE token_) {
+    if (!RB_INTEGER_TYPE_P(token_)) {
+      rb_raise(rb_eArgError, "token must be an integer");
+      return Qnil;
+    }
+    const llama_token token = NUM2INT(token_);
+    LLaMAModelWrapper* ptr = get_llama_model(self);
+    return llama_token_is_control(ptr->model, token) ? Qtrue : Qfalse;
+  }
 };
 
 const rb_data_type_t RbLLaMAModel::llama_model_type = {
@@ -3482,6 +3493,7 @@ extern "C" void Init_llama_cpp(void) {
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_QWEN2", INT2NUM(LLAMA_VOCAB_PRE_TYPE_QWEN2));
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_OLMO", INT2NUM(LLAMA_VOCAB_PRE_TYPE_OLMO));
   rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_DBRX", INT2NUM(LLAMA_VOCAB_PRE_TYPE_DBRX));
+  rb_define_const(rb_mLLaMACpp, "LLAMA_VOCAB_PRE_TYPE_SMAUG", INT2NUM(LLAMA_VOCAB_PRE_TYPE_SMAUG));
 
   rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_UNDEFINED", INT2NUM(LLAMA_TOKEN_TYPE_UNDEFINED));
   rb_define_const(rb_mLLaMACpp, "LLAMA_TOKEN_TYPE_NORMAL", INT2NUM(LLAMA_TOKEN_TYPE_NORMAL));
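Note: the new `token_is_control?` binding maps directly onto `llama_token_is_control` in the bundled llama.h. A minimal Ruby sketch of how it might be used; the model path, the `tokens` array, and the use of `token_to_piece` are illustrative assumptions, not part of this diff:

    require 'llama_cpp'

    model_params = LLaMACpp::ModelParams.new
    model = LLaMACpp::Model.new(model_path: 'path/to/model.gguf', params: model_params)

    # hypothetical token stream; skip control tokens (BOS/EOS and friends) when printing
    tokens.each do |token|
      next if model.token_is_control?(token)
      print model.token_to_piece(token)
    end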
data/lib/llama_cpp/version.rb
CHANGED
@@ -3,8 +3,8 @@
 # llama_cpp.rb provides Ruby bindings for the llama.cpp.
 module LLaMACpp
   # The version of llama_cpp.rb you install.
-  VERSION = '0.15.3'
+  VERSION = '0.15.4'
 
   # The version of llama.cpp bundled with llama_cpp.rb.
-  LLAMA_CPP_VERSION = 'b2988'
+  LLAMA_CPP_VERSION = 'b3056'
 end
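A quick check after upgrading, using the two constants from the file above:

    require 'llama_cpp'

    LLaMACpp::VERSION           # => '0.15.4'
    LLaMACpp::LLAMA_CPP_VERSION # => 'b3056'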
data/sig/llama_cpp.rbs
CHANGED
@@ -30,6 +30,7 @@ module LLaMACpp
   LLAMA_VOCAB_PRE_TYPE_QWEN2: Integer
   LLAMA_VOCAB_PRE_TYPE_OLMO: Integer
   LLAMA_VOCAB_PRE_TYPE_DBRX: Integer
+  LLAMA_VOCAB_PRE_TYPE_SMAUG: Integer
 
   LLAMA_FTYPE_ALL_F32: Integer
   LLAMA_FTYPE_MOSTLY_F16: Integer
@@ -159,6 +160,7 @@ module LLaMACpp
     def token_suffix: () -> Integer
     def token_eot: () -> Integer
     def token_is_eog?: (Integer) -> bool
+    def token_is_control?: (Integer) -> bool
   end
 
   class Timings
data/vendor/tmp/llama.cpp/Makefile
CHANGED
@@ -443,6 +443,9 @@ endif # JETSON_EOL_MODULE_DETECT
 ifdef LLAMA_DEBUG
 	MK_NVCCFLAGS += -lineinfo
 endif # LLAMA_DEBUG
+ifdef LLAMA_CUDA_DEBUG
+	MK_NVCCFLAGS += --device-debug
+endif # LLAMA_CUDA_DEBUG
 ifdef LLAMA_CUDA_NVCC
 	NVCC = $(CCACHE) $(LLAMA_CUDA_NVCC)
 else
@@ -749,7 +752,7 @@ lib: llama.o ggml.o $(OBJS)
 	ar rcs libllama.a $^
 
 clean:
-	rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
+	rm -vrf *.o tests/*.o *.so *.a *.dll *.dylib benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
 	rm -vrf ggml-cuda/*.o
 
 #
data/vendor/tmp/llama.cpp/ggml-cuda.cu
CHANGED
@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
     return id;
 }
 
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+    ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+    return cudaMalloc(ptr, size);
+#endif
+}
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
         size_t look_ahead_size = (size_t) (1.05 * size);
         look_ahead_size = 256 * ((look_ahead_size + 255)/256);
         ggml_cuda_set_device(device);
-        CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
        *actual_size = look_ahead_size;
        pool_size += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
 
     void * dev_ptr;
-    cudaError_t err = cudaMalloc(&dev_ptr, size);
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
         // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
         ggml_cuda_set_device(id);
         char * buf;
-        CUDA_CHECK(cudaMalloc(&buf, size));
+        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
 
         // set padding to 0 to avoid possible NaN values
         if (size > original_size) {
@@ -1856,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         }
     }
 #else
-    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
@@ -2510,9 +2524,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
 
     bool use_cuda_graph = true;
     bool cuda_graph_update_required = false;
-    // pointer to CUDA cpy kernel, which is required to identify
+    // vector of pointers to CUDA cpy kernels, which are required to identify
     // kernel parameters which need updated in the graph for each token
-    void * ggml_cuda_cpy_fn_ptr = nullptr;
+    std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
@@ -2588,9 +2602,10 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
             if (node->op == GGML_OP_CPY) {
                 // store the copy op parameter which changes with each token.
                 cuda_ctx->cuda_graph->updated_kernel_arg.push_back((char **) &(node->src[1]->data));
-                if (ggml_cuda_cpy_fn_ptr == nullptr) {
-                    // store a pointer to the copy op CUDA kernel to identify it later
-                    ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                // store a pointer to each copy op CUDA kernel to identify it later
+                void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
+                if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
+                    ggml_cuda_cpy_fn_ptrs.push_back(ptr);
                 }
             }
 
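Note: the graph-compute path now collects one pointer per distinct copy kernel and later tests each captured graph node for membership. The same dedup-then-membership pattern, sketched in Ruby for illustration only; `nodes`, `cpy_fn`, and `params` are hypothetical stand-ins for the CUDA-side objects:

    cpy_fn_ptrs = []
    nodes.each do |node|
      ptr = cpy_fn.call(node)                              # ggml_cuda_cpy_fn(...)
      cpy_fn_ptrs << ptr unless cpy_fn_ptrs.include?(ptr)  # std::find + push_back
    end

    # later: patch only the nodes whose kernel is one of the collected copy kernels
    params.each do |p|
      next unless cpy_fn_ptrs.include?(p.func)             # std::count(...) > 0
      # update p.kernelParams[1] with the per-token argument here
    end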
@@ -2720,7 +2735,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
         if (!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
             int k = 0;
             for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
-                if (cuda_ctx->cuda_graph->params[i].func == ggml_cuda_cpy_fn_ptr) {
+                if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
                     char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
                     cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
                     CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
@@ -2871,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
data/vendor/tmp/llama.cpp/ggml-kompute.cpp
CHANGED
@@ -1597,7 +1597,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 {
                     GGML_ASSERT(ne00 == ne10);
 
-                    // TODO: assert that dim2 and dim3 are contiguous
                     GGML_ASSERT(ne12 % ne02 == 0);
                     GGML_ASSERT(ne13 % ne03 == 0);
 
data/vendor/tmp/llama.cpp/ggml-metal.m
CHANGED
@@ -35,6 +35,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_ROW,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_F16,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I32,
+    GGML_METAL_KERNEL_TYPE_REPEAT_I16,
     GGML_METAL_KERNEL_TYPE_SCALE,
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_CLAMP,
@@ -184,9 +188,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -485,6 +489,10 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I16, repeat_i16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true);
@@ -634,9 +642,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+        //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
@@ -746,6 +754,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_ACC:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
         case GGML_OP_SQR:
@@ -770,6 +779,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 256) {
+                return false;
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -976,10 +988,10 @@ static enum ggml_status ggml_metal_graph_compute(
         switch (dst->op) {
             case GGML_OP_CONCAT:
                 {
-                    const int64_t nb = ne00;
-
                     id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
 
+                    const int32_t dim = ((int32_t *) dst->op_params)[0];
+
                     [encoder setComputePipelineState:pipeline];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1008,7 +1020,7 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
                     [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
                     [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
-                    [encoder setBytes:&nb  length:sizeof(nb)  atIndex:27];
+                    [encoder setBytes:&dim length:sizeof(dim) atIndex:27];
 
                     const int nth = MIN(1024, ne0);
 
@@ -1018,11 +1030,14 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_MUL:
             case GGML_OP_DIV:
                 {
+                    GGML_ASSERT(src0t == GGML_TYPE_F32);
+                    GGML_ASSERT(src1t == GGML_TYPE_F32);
+
                     const size_t offs = 0;
 
                     bool bcast_row = false;
 
-                    int64_t nb = ne00;
+                    int64_t nb = ne00; // used by the "row" kernels
 
                     id<MTLComputePipelineState> pipeline = nil;
 
@@ -1091,6 +1106,42 @@ static enum ggml_status ggml_metal_graph_compute(
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     }
                 } break;
+            case GGML_OP_REPEAT:
+                {
+                    id<MTLComputePipelineState> pipeline;
+
+                    switch (src0t) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_F16].pipeline; break;
+                        case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I32].pipeline; break;
+                        case GGML_TYPE_I16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REPEAT_I16].pipeline; break;
+                        default: GGML_ASSERT(false);
+                    }
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+                    [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+                    [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                    [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+                    [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+                    [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+                    [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+                    [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+                    [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+                    [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+                    [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+                    const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
             case GGML_OP_ACC:
                 {
                     GGML_ASSERT(src0t == GGML_TYPE_F32);
@@ -1468,7 +1519,6 @@ static enum ggml_status ggml_metal_graph_compute(
                 {
                     GGML_ASSERT(ne00 == ne10);
 
-                    // TODO: assert that dim2 and dim3 are contiguous
                     GGML_ASSERT(ne12 % ne02 == 0);
                     GGML_ASSERT(ne13 % ne03 == 0);
 
@@ -2136,6 +2186,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_RMS_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
@@ -2163,6 +2214,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_GROUP_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous(src0));
 
                     //float eps;
                     //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2196,6 +2248,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_NORM:
                 {
+                    GGML_ASSERT(ggml_is_contiguous_1(src0));
+
                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -2573,7 +2627,7 @@ static enum ggml_status ggml_metal_graph_compute(
                         case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                         case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                         case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                         default:
                             {
                                 GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2640,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     switch (ne00) {
                         case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
                         default:
                             {
                                 GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
data/vendor/tmp/llama.cpp/ggml-metal.metal
CHANGED
@@ -168,6 +168,53 @@ kernel void kernel_div(
     }
 }
 
+template<typename T>
+kernel void kernel_repeat(
+    device const char * src0,
+    device       char * dst,
+    constant  int64_t & ne00,
+    constant  int64_t & ne01,
+    constant  int64_t & ne02,
+    constant  int64_t & ne03,
+    constant uint64_t & nb00,
+    constant uint64_t & nb01,
+    constant uint64_t & nb02,
+    constant uint64_t & nb03,
+    constant  int64_t & ne0,
+    constant  int64_t & ne1,
+    constant  int64_t & ne2,
+    constant  int64_t & ne3,
+    constant uint64_t & nb0,
+    constant uint64_t & nb1,
+    constant uint64_t & nb2,
+    constant uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
+
+    const int64_t i03 = i3 % ne03;
+    const int64_t i02 = i2 % ne02;
+    const int64_t i01 = i1 % ne01;
+
+    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        const int i00 = i0 % ne00;
+        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    }
+}
+
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
+
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
+
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
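Note: kernel_repeat implements GGML_OP_REPEAT by reading each output index modulo the source extent, so the source tile wraps around in every dimension. A CPU reference of the same indexing, sketched in Ruby over nested arrays (shapes are illustrative):

    # dst[i3][i2][i1][i0] = src[i3 % ne03][i2 % ne02][i1 % ne01][i0 % ne00]
    def repeat(src, ne0, ne1, ne2, ne3)
      ne03, ne02, ne01, ne00 = src.size, src[0].size, src[0][0].size, src[0][0][0].size
      Array.new(ne3) do |i3|
        Array.new(ne2) do |i2|
          Array.new(ne1) do |i1|
            Array.new(ne0) { |i0| src[i3 % ne03][i2 % ne02][i1 % ne01][i0 % ne00] }
          end
        end
      end
    end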
@@ -1720,13 +1767,13 @@ kernel void kernel_rope(
 
     const int64_t p = pos[i2];
 
-    const float theta_0 = (float)p;
+    const float theta_base = (float)p;
     const float inv_ndims = -1.f/n_dims;
 
     if (!is_neox) {
         for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
 
-            const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
             float cos_theta, sin_theta;
             rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
@@ -1742,18 +1789,14 @@ kernel void kernel_rope(
     } else {
         for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
             if (ic < n_dims) {
-                const int64_t ib = 0;
+                const int64_t i0 = ic/2;
 
-                // simplified from `(ib * n_dims + ic) * inv_ndims`
-                const float cur_rot = inv_ndims*ic - ib;
-                const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
+                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
 
-                const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
+                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
 
                 float cos_theta, sin_theta;
-                rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
-
-                const int64_t i0 = ib*n_dims + ic/2;
+                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
 
                 device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                 device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
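Note: after this rewrite both RoPE branches derive the angle directly from the element index: with inv_ndims = -1/n_dims, theta(i) = p * freq_base^(-i/n_dims), and the NeoX branch divides by a per-frequency factor before calling rope_yarn. The same arithmetic in Ruby, for illustration:

    # theta shrinks geometrically with the element index i
    def rope_theta(p, freq_base, n_dims, i)
      p * freq_base**(-i.to_f / n_dims)
    end

    rope_theta(42, 10_000.0, 128, 0) # => 42.0 (the first pair rotates fastest)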
@@ -2418,7 +2461,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
 template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
 
 template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2739,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 }
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
 kernel void kernel_cpy_f16_f16(
     device const half * src0,
@@ -3319,31 +3362,30 @@ kernel void kernel_concat(
     constant uint64_t & nb1,
     constant uint64_t & nb2,
     constant uint64_t & nb3,
+    constant  int32_t & dim,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3 ntg[[threads_per_threadgroup]]) {
 
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+    const int64_t i3 = tgpig.z;
+    const int64_t i2 = tgpig.y;
+    const int64_t i1 = tgpig.x;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    int64_t o[4] = {0, 0, 0, 0};
+    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+    device const float * x;
 
     for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        if (i02 < ne02) {
-            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
-            src0_ptr += ntg.x*nb00;
+        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+            x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
         } else {
-            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
-            src1_ptr += ntg.x*nb10;
+            x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
         }
-        dst_ptr += ntg.x*nb0;
+
+        device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+        *y = *x;
     }
 }
 
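Note: the rewritten kernel_concat concatenates along an arbitrary dimension `dim`: output indices below the src0 extent read src0, the rest read src1 shifted back by that extent (the o[] offset). A CPU reference of the index math, sketched in Ruby for the 2-D case:

    def concat2d(src0, src1, dim)
      o = [dim == 0 ? src0.size : 0, dim == 1 ? src0[0].size : 0]
      rows = src0.size    + (dim == 0 ? src1.size    : 0)
      cols = src0[0].size + (dim == 1 ? src1[0].size : 0)
      Array.new(rows) do |i|
        Array.new(cols) do |j|
          if i < src0.size && j < src0[0].size
            src0[i][j]           # inside the src0 extent
          else
            src1[i - o[0]][j - o[1]]  # shifted back into src1
          end
        end
      end
    end

    concat2d([[1, 2]], [[3, 4]], 0) # => [[1, 2], [3, 4]]
    concat2d([[1, 2]], [[3, 4]], 1) # => [[1, 2, 3, 4]]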