llama_cpp 0.13.0 → 0.14.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +59 -26
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +6 -4
- data/vendor/tmp/llama.cpp/Makefile +2 -3
- data/vendor/tmp/llama.cpp/ggml-backend-impl.h +4 -3
- data/vendor/tmp/llama.cpp/ggml-backend.c +18 -21
- data/vendor/tmp/llama.cpp/ggml-backend.h +16 -15
- data/vendor/tmp/llama.cpp/ggml-cuda.cu +949 -168
- data/vendor/tmp/llama.cpp/ggml-kompute.cpp +2 -2
- data/vendor/tmp/llama.cpp/ggml-metal.m +63 -7
- data/vendor/tmp/llama.cpp/ggml-metal.metal +120 -75
- data/vendor/tmp/llama.cpp/ggml-opencl.cpp +2 -2
- data/vendor/tmp/llama.cpp/ggml-quants.c +178 -133
- data/vendor/tmp/llama.cpp/ggml-sycl.cpp +3432 -1118
- data/vendor/tmp/llama.cpp/ggml-sycl.h +5 -0
- data/vendor/tmp/llama.cpp/ggml-vulkan-shaders.hpp +39336 -43461
- data/vendor/tmp/llama.cpp/ggml-vulkan.cpp +1327 -773
- data/vendor/tmp/llama.cpp/ggml-vulkan.h +1 -0
- data/vendor/tmp/llama.cpp/ggml.c +227 -15
- data/vendor/tmp/llama.cpp/ggml.h +30 -4
- data/vendor/tmp/llama.cpp/llama.cpp +631 -211
- data/vendor/tmp/llama.cpp/llama.h +28 -10
- metadata +2 -2
|
@@ -1927,10 +1927,10 @@ static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(g
|
|
|
1927
1927
|
return ggml_backend_kompute_buffer_type(ctx->device);
|
|
1928
1928
|
}
|
|
1929
1929
|
|
|
1930
|
-
static
|
|
1930
|
+
static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
|
1931
1931
|
auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
|
|
1932
1932
|
ggml_vk_graph_compute(ctx, cgraph);
|
|
1933
|
-
return
|
|
1933
|
+
return GGML_STATUS_SUCCESS;
|
|
1934
1934
|
}
|
|
1935
1935
|
|
|
1936
1936
|
static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
|
|
@@ -163,6 +163,8 @@ enum ggml_metal_kernel_type {
|
|
|
163
163
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
|
164
164
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
165
165
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
166
|
+
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
|
167
|
+
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
|
166
168
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
167
169
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
|
|
168
170
|
GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
|
|
@@ -569,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
569
571
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
|
570
572
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
571
573
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
574
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
|
575
|
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
|
572
576
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
573
577
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
|
574
578
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
|
@@ -697,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
|
697
701
|
return false;
|
|
698
702
|
case GGML_OP_UPSCALE:
|
|
699
703
|
case GGML_OP_PAD:
|
|
704
|
+
case GGML_OP_ARANGE:
|
|
705
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
700
706
|
case GGML_OP_ARGSORT:
|
|
701
707
|
case GGML_OP_LEAKY_RELU:
|
|
702
708
|
return true;
|
|
@@ -742,7 +748,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
|
742
748
|
}
|
|
743
749
|
}
|
|
744
750
|
|
|
745
|
-
static
|
|
751
|
+
static enum ggml_status ggml_metal_graph_compute(
|
|
746
752
|
struct ggml_metal_context * ctx,
|
|
747
753
|
struct ggml_cgraph * gf) {
|
|
748
754
|
|
|
@@ -1091,7 +1097,8 @@ static bool ggml_metal_graph_compute(
|
|
|
1091
1097
|
{
|
|
1092
1098
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
1093
1099
|
|
|
1094
|
-
|
|
1100
|
+
float scale;
|
|
1101
|
+
memcpy(&scale, dst->op_params, sizeof(scale));
|
|
1095
1102
|
|
|
1096
1103
|
int64_t n = ggml_nelements(dst);
|
|
1097
1104
|
|
|
@@ -1250,11 +1257,15 @@ static bool ggml_metal_graph_compute(
|
|
|
1250
1257
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
|
1251
1258
|
}
|
|
1252
1259
|
|
|
1253
|
-
|
|
1254
|
-
|
|
1260
|
+
float scale;
|
|
1261
|
+
float max_bias;
|
|
1262
|
+
|
|
1263
|
+
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
|
1264
|
+
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
|
1255
1265
|
|
|
1256
1266
|
const int64_t nrows_x = ggml_nrows(src0);
|
|
1257
1267
|
const int64_t nrows_y = src0->ne[1];
|
|
1268
|
+
|
|
1258
1269
|
const uint32_t n_head_kv = nrows_x/nrows_y;
|
|
1259
1270
|
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
|
1260
1271
|
|
|
@@ -2086,6 +2097,7 @@ static bool ggml_metal_graph_compute(
|
|
|
2086
2097
|
|
|
2087
2098
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
|
2088
2099
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
|
2100
|
+
|
|
2089
2101
|
float max_bias;
|
|
2090
2102
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
|
2091
2103
|
|
|
@@ -2300,6 +2312,50 @@ static bool ggml_metal_graph_compute(
|
|
|
2300
2312
|
|
|
2301
2313
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2302
2314
|
} break;
|
|
2315
|
+
case GGML_OP_ARANGE:
|
|
2316
|
+
{
|
|
2317
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
2318
|
+
|
|
2319
|
+
float start;
|
|
2320
|
+
float step;
|
|
2321
|
+
|
|
2322
|
+
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
|
2323
|
+
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
|
|
2324
|
+
|
|
2325
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
|
2326
|
+
|
|
2327
|
+
[encoder setComputePipelineState:pipeline];
|
|
2328
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
|
|
2329
|
+
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
|
|
2330
|
+
[encoder setBytes:&start length:sizeof(start) atIndex:2];
|
|
2331
|
+
[encoder setBytes:&step length:sizeof(step) atIndex:3];
|
|
2332
|
+
|
|
2333
|
+
const int nth = MIN(1024, ne0);
|
|
2334
|
+
|
|
2335
|
+
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2336
|
+
} break;
|
|
2337
|
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
2338
|
+
{
|
|
2339
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
2340
|
+
|
|
2341
|
+
const int dim = dst->op_params[0];
|
|
2342
|
+
const int max_period = dst->op_params[1];
|
|
2343
|
+
|
|
2344
|
+
const int half = dim / 2;
|
|
2345
|
+
|
|
2346
|
+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
|
|
2347
|
+
|
|
2348
|
+
[encoder setComputePipelineState:pipeline];
|
|
2349
|
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
2350
|
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
2351
|
+
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
|
|
2352
|
+
[encoder setBytes:&dim length:sizeof(dim) atIndex:3];
|
|
2353
|
+
[encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
|
|
2354
|
+
|
|
2355
|
+
const int nth = MIN(1024, half);
|
|
2356
|
+
|
|
2357
|
+
[encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
2358
|
+
} break;
|
|
2303
2359
|
case GGML_OP_ARGSORT:
|
|
2304
2360
|
{
|
|
2305
2361
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
@@ -2428,7 +2484,7 @@ static bool ggml_metal_graph_compute(
|
|
|
2428
2484
|
MTLCommandBufferStatus status = [command_buffer status];
|
|
2429
2485
|
if (status != MTLCommandBufferStatusCompleted) {
|
|
2430
2486
|
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
|
|
2431
|
-
return
|
|
2487
|
+
return GGML_STATUS_FAILED;
|
|
2432
2488
|
}
|
|
2433
2489
|
}
|
|
2434
2490
|
|
|
@@ -2437,7 +2493,7 @@ static bool ggml_metal_graph_compute(
|
|
|
2437
2493
|
}
|
|
2438
2494
|
|
|
2439
2495
|
}
|
|
2440
|
-
return
|
|
2496
|
+
return GGML_STATUS_SUCCESS;
|
|
2441
2497
|
}
|
|
2442
2498
|
|
|
2443
2499
|
////////////////////////////////////////////////////////////////////////////////
|
|
@@ -2739,7 +2795,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
|
|
|
2739
2795
|
UNUSED(backend);
|
|
2740
2796
|
}
|
|
2741
2797
|
|
|
2742
|
-
GGML_CALL static
|
|
2798
|
+
GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
|
|
2743
2799
|
struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
|
|
2744
2800
|
|
|
2745
2801
|
return ggml_metal_graph_compute(metal_ctx, cgraph);
|
|
@@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
|
|
|
1959
1959
|
}
|
|
1960
1960
|
}
|
|
1961
1961
|
|
|
1962
|
+
kernel void kernel_arange_f32(
|
|
1963
|
+
device char * dst,
|
|
1964
|
+
constant int64_t & ne0,
|
|
1965
|
+
constant float & start,
|
|
1966
|
+
constant float & step,
|
|
1967
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1968
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1969
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1970
|
+
|
|
1971
|
+
device float * dst_ptr = (device float *) dst;
|
|
1972
|
+
|
|
1973
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
1974
|
+
dst_ptr[i0] = start + step * i0;
|
|
1975
|
+
}
|
|
1976
|
+
}
|
|
1977
|
+
|
|
1978
|
+
kernel void kernel_timestep_embedding_f32(
|
|
1979
|
+
device const char * src0,
|
|
1980
|
+
device char * dst,
|
|
1981
|
+
constant uint64_t & nb1,
|
|
1982
|
+
constant int & dim,
|
|
1983
|
+
constant int & max_period,
|
|
1984
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1985
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
1986
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
1987
|
+
|
|
1988
|
+
int i = tgpig.x;
|
|
1989
|
+
device float * embed_data = (device float *)(dst + i*nb1);
|
|
1990
|
+
|
|
1991
|
+
int half_ = dim / 2;
|
|
1992
|
+
for (int j = tpitg.x; j < half_; j += ntg.x) {
|
|
1993
|
+
float timestep = ((device float *)src0)[i];
|
|
1994
|
+
float freq = (float)exp(-log((float)max_period) * j / half_);
|
|
1995
|
+
float arg = timestep * freq;
|
|
1996
|
+
embed_data[j ] = cos(arg);
|
|
1997
|
+
embed_data[j + half_] = sin(arg);
|
|
1998
|
+
}
|
|
1999
|
+
|
|
2000
|
+
if (dim % 2 != 0 && tpitg.x == 0) {
|
|
2001
|
+
embed_data[dim] = 0.f;
|
|
2002
|
+
}
|
|
2003
|
+
}
|
|
2004
|
+
|
|
1962
2005
|
// bitonic sort implementation following the CUDA kernels as reference
|
|
1963
2006
|
typedef void (argsort_t)(
|
|
1964
2007
|
device const float * x,
|
|
@@ -4087,71 +4130,71 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
|
|
|
4087
4130
|
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
|
|
4088
4131
|
};
|
|
4089
4132
|
|
|
4090
|
-
constexpr constant static uint32_t
|
|
4091
|
-
|
|
4092
|
-
|
|
4093
|
-
|
|
4094
|
-
|
|
4095
|
-
|
|
4096
|
-
|
|
4097
|
-
|
|
4098
|
-
|
|
4099
|
-
|
|
4100
|
-
|
|
4101
|
-
|
|
4102
|
-
|
|
4103
|
-
|
|
4104
|
-
|
|
4105
|
-
|
|
4106
|
-
|
|
4107
|
-
|
|
4108
|
-
|
|
4109
|
-
|
|
4110
|
-
|
|
4111
|
-
|
|
4112
|
-
|
|
4113
|
-
|
|
4114
|
-
|
|
4115
|
-
|
|
4116
|
-
|
|
4117
|
-
|
|
4118
|
-
|
|
4119
|
-
|
|
4120
|
-
|
|
4121
|
-
|
|
4122
|
-
|
|
4123
|
-
|
|
4124
|
-
|
|
4125
|
-
|
|
4126
|
-
|
|
4127
|
-
|
|
4128
|
-
|
|
4129
|
-
|
|
4130
|
-
|
|
4131
|
-
|
|
4132
|
-
|
|
4133
|
-
|
|
4134
|
-
|
|
4135
|
-
|
|
4136
|
-
|
|
4137
|
-
|
|
4138
|
-
|
|
4139
|
-
|
|
4140
|
-
|
|
4141
|
-
|
|
4142
|
-
|
|
4143
|
-
|
|
4144
|
-
|
|
4145
|
-
|
|
4146
|
-
|
|
4147
|
-
|
|
4148
|
-
|
|
4149
|
-
|
|
4150
|
-
|
|
4151
|
-
|
|
4152
|
-
|
|
4153
|
-
|
|
4154
|
-
|
|
4133
|
+
constexpr constant static uint32_t iq3s_grid[512] = {
|
|
4134
|
+
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
|
4135
|
+
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
|
4136
|
+
0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
|
|
4137
|
+
0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
|
|
4138
|
+
0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
|
|
4139
|
+
0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
|
|
4140
|
+
0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
|
|
4141
|
+
0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
|
|
4142
|
+
0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
|
|
4143
|
+
0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
|
|
4144
|
+
0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
|
|
4145
|
+
0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
|
|
4146
|
+
0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
|
|
4147
|
+
0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
|
|
4148
|
+
0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
|
|
4149
|
+
0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
|
|
4150
|
+
0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
|
|
4151
|
+
0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
|
|
4152
|
+
0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
|
|
4153
|
+
0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
|
|
4154
|
+
0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
|
|
4155
|
+
0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
|
|
4156
|
+
0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
|
|
4157
|
+
0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
|
|
4158
|
+
0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
|
|
4159
|
+
0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
|
|
4160
|
+
0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
|
|
4161
|
+
0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
|
|
4162
|
+
0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
|
|
4163
|
+
0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
|
|
4164
|
+
0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
|
|
4165
|
+
0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
|
|
4166
|
+
0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
|
|
4167
|
+
0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
|
|
4168
|
+
0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
|
|
4169
|
+
0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
|
|
4170
|
+
0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
|
|
4171
|
+
0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
|
|
4172
|
+
0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
|
|
4173
|
+
0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
|
|
4174
|
+
0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
|
|
4175
|
+
0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
|
|
4176
|
+
0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
|
|
4177
|
+
0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
|
|
4178
|
+
0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
|
|
4179
|
+
0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
|
|
4180
|
+
0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
|
|
4181
|
+
0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
|
|
4182
|
+
0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
|
|
4183
|
+
0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
|
|
4184
|
+
0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
|
|
4185
|
+
0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
|
|
4186
|
+
0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
|
|
4187
|
+
0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
|
|
4188
|
+
0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
|
|
4189
|
+
0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
|
|
4190
|
+
0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
|
|
4191
|
+
0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
|
|
4192
|
+
0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
|
|
4193
|
+
0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
|
|
4194
|
+
0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
|
|
4195
|
+
0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
|
|
4196
|
+
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
|
4197
|
+
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
|
4155
4198
|
};
|
|
4156
4199
|
|
|
4157
4200
|
#define NGRID_IQ1S 512
|
|
@@ -4742,7 +4785,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
|
4742
4785
|
{
|
|
4743
4786
|
int nval = 8;
|
|
4744
4787
|
int pos = (32*sgitg + tiisg)*nval;
|
|
4745
|
-
for (int i = 0; i < nval; ++i) values[pos + i] =
|
|
4788
|
+
for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
|
|
4746
4789
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
4747
4790
|
}
|
|
4748
4791
|
|
|
@@ -4769,12 +4812,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
|
4769
4812
|
for (int row = 0; row < N_DST; row++) {
|
|
4770
4813
|
|
|
4771
4814
|
const float db = dh[0];
|
|
4772
|
-
const float d = db * (
|
|
4815
|
+
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
|
|
4773
4816
|
|
|
4774
4817
|
float2 sum = {0};
|
|
4775
4818
|
for (int l = 0; l < 4; ++l) {
|
|
4776
|
-
const threadgroup
|
|
4777
|
-
const threadgroup
|
|
4819
|
+
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
|
|
4820
|
+
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
|
|
4821
|
+
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
|
|
4822
|
+
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
|
|
4778
4823
|
for (int j = 0; j < 4; ++j) {
|
|
4779
4824
|
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
|
|
4780
4825
|
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
|
|
@@ -4795,7 +4840,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
|
4795
4840
|
for (int row = 0; row < N_DST; ++row) {
|
|
4796
4841
|
all_sum = simd_sum(sumf[row]);
|
|
4797
4842
|
if (tiisg == 0) {
|
|
4798
|
-
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum
|
|
4843
|
+
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
|
4799
4844
|
}
|
|
4800
4845
|
}
|
|
4801
4846
|
}
|
|
@@ -5685,15 +5730,15 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
|
|
|
5685
5730
|
device const uint8_t * qs = xb->qs + 8*ib32;
|
|
5686
5731
|
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
|
5687
5732
|
const uint8_t qh = xb->qh[ib32] >> 4*il;
|
|
5688
|
-
const float dl = d * (
|
|
5689
|
-
constant uint8_t * grid1 = (constant uint8_t *)(
|
|
5690
|
-
constant uint8_t * grid2 = (constant uint8_t *)(
|
|
5733
|
+
const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
|
|
5734
|
+
constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
|
|
5735
|
+
constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
|
|
5691
5736
|
for (int i = 0; i < 4; ++i) {
|
|
5692
5737
|
reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
|
|
5693
5738
|
reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
|
|
5694
5739
|
}
|
|
5695
|
-
grid1 = (constant uint8_t *)(
|
|
5696
|
-
grid2 = (constant uint8_t *)(
|
|
5740
|
+
grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
|
|
5741
|
+
grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
|
|
5697
5742
|
for (int i = 0; i < 4; ++i) {
|
|
5698
5743
|
reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
|
|
5699
5744
|
reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
|
|
@@ -2231,7 +2231,7 @@ static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(gg
|
|
|
2231
2231
|
GGML_UNUSED(backend);
|
|
2232
2232
|
}
|
|
2233
2233
|
|
|
2234
|
-
static
|
|
2234
|
+
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
|
|
2235
2235
|
for (int i = 0; i < graph->n_nodes; ++i) {
|
|
2236
2236
|
ggml_tensor * node = graph->nodes[i];
|
|
2237
2237
|
switch (node->op) {
|
|
@@ -2246,7 +2246,7 @@ static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgrap
|
|
|
2246
2246
|
}
|
|
2247
2247
|
}
|
|
2248
2248
|
|
|
2249
|
-
return
|
|
2249
|
+
return GGML_STATUS_SUCCESS;
|
|
2250
2250
|
|
|
2251
2251
|
GGML_UNUSED(backend);
|
|
2252
2252
|
}
|