llama_cpp 0.13.0 → 0.14.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1927,10 +1927,10 @@ static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(g
1927
1927
  return ggml_backend_kompute_buffer_type(ctx->device);
1928
1928
  }
1929
1929
 
1930
- static bool ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1930
+ static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
1931
1931
  auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
1932
1932
  ggml_vk_graph_compute(ctx, cgraph);
1933
- return true;
1933
+ return GGML_STATUS_SUCCESS;
1934
1934
  }
1935
1935
 
1936
1936
  static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -163,6 +163,8 @@ enum ggml_metal_kernel_type {
163
163
  GGML_METAL_KERNEL_TYPE_IM2COL_F32,
164
164
  GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
165
165
  GGML_METAL_KERNEL_TYPE_PAD_F32,
166
+ GGML_METAL_KERNEL_TYPE_ARANGE_F32,
167
+ GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
166
168
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
167
169
  GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
168
170
  GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
@@ -569,6 +571,8 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
569
571
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
570
572
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
571
573
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
574
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
575
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
572
576
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
573
577
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
574
578
  GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
@@ -697,6 +701,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
697
701
  return false;
698
702
  case GGML_OP_UPSCALE:
699
703
  case GGML_OP_PAD:
704
+ case GGML_OP_ARANGE:
705
+ case GGML_OP_TIMESTEP_EMBEDDING:
700
706
  case GGML_OP_ARGSORT:
701
707
  case GGML_OP_LEAKY_RELU:
702
708
  return true;
@@ -742,7 +748,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
742
748
  }
743
749
  }
744
750
 
745
- static bool ggml_metal_graph_compute(
751
+ static enum ggml_status ggml_metal_graph_compute(
746
752
  struct ggml_metal_context * ctx,
747
753
  struct ggml_cgraph * gf) {
748
754
 
@@ -1091,7 +1097,8 @@ static bool ggml_metal_graph_compute(
1091
1097
  {
1092
1098
  GGML_ASSERT(ggml_is_contiguous(src0));
1093
1099
 
1094
- const float scale = *(const float *) dst->op_params;
1100
+ float scale;
1101
+ memcpy(&scale, dst->op_params, sizeof(scale));
1095
1102
 
1096
1103
  int64_t n = ggml_nelements(dst);
1097
1104
 
@@ -1250,11 +1257,15 @@ static bool ggml_metal_graph_compute(
1250
1257
  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
1251
1258
  }
1252
1259
 
1253
- const float scale = ((float *) dst->op_params)[0];
1254
- const float max_bias = ((float *) dst->op_params)[1];
1260
+ float scale;
1261
+ float max_bias;
1262
+
1263
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
1264
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
1255
1265
 
1256
1266
  const int64_t nrows_x = ggml_nrows(src0);
1257
1267
  const int64_t nrows_y = src0->ne[1];
1268
+
1258
1269
  const uint32_t n_head_kv = nrows_x/nrows_y;
1259
1270
  const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
1260
1271
 
@@ -2086,6 +2097,7 @@ static bool ggml_metal_graph_compute(
2086
2097
 
2087
2098
  //const int n_past = ((int32_t *) dst->op_params)[0];
2088
2099
  const int n_head = ((int32_t *) dst->op_params)[1];
2100
+
2089
2101
  float max_bias;
2090
2102
  memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
2091
2103
 
@@ -2300,6 +2312,50 @@ static bool ggml_metal_graph_compute(
2300
2312
 
2301
2313
  [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2302
2314
  } break;
2315
+ case GGML_OP_ARANGE:
2316
+ {
2317
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
2318
+
2319
+ float start;
2320
+ float step;
2321
+
2322
+ memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
2323
+ memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
2324
+
2325
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
2326
+
2327
+ [encoder setComputePipelineState:pipeline];
2328
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:0];
2329
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1];
2330
+ [encoder setBytes:&start length:sizeof(start) atIndex:2];
2331
+ [encoder setBytes:&step length:sizeof(step) atIndex:3];
2332
+
2333
+ const int nth = MIN(1024, ne0);
2334
+
2335
+ [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2336
+ } break;
2337
+ case GGML_OP_TIMESTEP_EMBEDDING:
2338
+ {
2339
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
2340
+
2341
+ const int dim = dst->op_params[0];
2342
+ const int max_period = dst->op_params[1];
2343
+
2344
+ const int half = dim / 2;
2345
+
2346
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
2347
+
2348
+ [encoder setComputePipelineState:pipeline];
2349
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2350
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2351
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:2];
2352
+ [encoder setBytes:&dim length:sizeof(dim) atIndex:3];
2353
+ [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
2354
+
2355
+ const int nth = MIN(1024, half);
2356
+
2357
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
2358
+ } break;
2303
2359
  case GGML_OP_ARGSORT:
2304
2360
  {
2305
2361
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
@@ -2428,7 +2484,7 @@ static bool ggml_metal_graph_compute(
2428
2484
  MTLCommandBufferStatus status = [command_buffer status];
2429
2485
  if (status != MTLCommandBufferStatusCompleted) {
2430
2486
  GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
2431
- return false;
2487
+ return GGML_STATUS_FAILED;
2432
2488
  }
2433
2489
  }
2434
2490
 
@@ -2437,7 +2493,7 @@ static bool ggml_metal_graph_compute(
2437
2493
  }
2438
2494
 
2439
2495
  }
2440
- return true;
2496
+ return GGML_STATUS_SUCCESS;
2441
2497
  }
2442
2498
 
2443
2499
  ////////////////////////////////////////////////////////////////////////////////
@@ -2739,7 +2795,7 @@ GGML_CALL static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffe
2739
2795
  UNUSED(backend);
2740
2796
  }
2741
2797
 
2742
- GGML_CALL static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2798
+ GGML_CALL static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
2743
2799
  struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
2744
2800
 
2745
2801
  return ggml_metal_graph_compute(metal_ctx, cgraph);
@@ -1959,6 +1959,49 @@ kernel void kernel_pad_f32(
1959
1959
  }
1960
1960
  }
1961
1961
 
1962
+ kernel void kernel_arange_f32(
1963
+ device char * dst,
1964
+ constant int64_t & ne0,
1965
+ constant float & start,
1966
+ constant float & step,
1967
+ uint3 tgpig[[threadgroup_position_in_grid]],
1968
+ uint3 tpitg[[thread_position_in_threadgroup]],
1969
+ uint3 ntg[[threads_per_threadgroup]]) {
1970
+
1971
+ device float * dst_ptr = (device float *) dst;
1972
+
1973
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
1974
+ dst_ptr[i0] = start + step * i0;
1975
+ }
1976
+ }
1977
+
1978
+ kernel void kernel_timestep_embedding_f32(
1979
+ device const char * src0,
1980
+ device char * dst,
1981
+ constant uint64_t & nb1,
1982
+ constant int & dim,
1983
+ constant int & max_period,
1984
+ uint3 tgpig[[threadgroup_position_in_grid]],
1985
+ uint3 tpitg[[thread_position_in_threadgroup]],
1986
+ uint3 ntg[[threads_per_threadgroup]]) {
1987
+
1988
+ int i = tgpig.x;
1989
+ device float * embed_data = (device float *)(dst + i*nb1);
1990
+
1991
+ int half_ = dim / 2;
1992
+ for (int j = tpitg.x; j < half_; j += ntg.x) {
1993
+ float timestep = ((device float *)src0)[i];
1994
+ float freq = (float)exp(-log((float)max_period) * j / half_);
1995
+ float arg = timestep * freq;
1996
+ embed_data[j ] = cos(arg);
1997
+ embed_data[j + half_] = sin(arg);
1998
+ }
1999
+
2000
+ if (dim % 2 != 0 && tpitg.x == 0) {
2001
+ embed_data[dim] = 0.f;
2002
+ }
2003
+ }
2004
+
1962
2005
  // bitonic sort implementation following the CUDA kernels as reference
1963
2006
  typedef void (argsort_t)(
1964
2007
  device const float * x,
@@ -4087,71 +4130,71 @@ constexpr constant static uint32_t iq3xxs_grid[256] = {
4087
4130
  0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
4088
4131
  };
4089
4132
 
4090
- constexpr constant static uint32_t iq3xs_grid[512] = {
4091
- 0x04040404, 0x0404040c, 0x04040414, 0x0404042c, 0x0404043e, 0x04040c04, 0x04040c0c, 0x04040c14,
4092
- 0x04040c24, 0x04040c34, 0x04041404, 0x0404140c, 0x0404142c, 0x04041c1c, 0x04042404, 0x04042414,
4093
- 0x0404242c, 0x0404243e, 0x04042c0c, 0x04042c1c, 0x04043404, 0x04043414, 0x04043e0c, 0x04043e24,
4094
- 0x04043e3e, 0x040c0404, 0x040c040c, 0x040c0414, 0x040c0424, 0x040c0c04, 0x040c0c0c, 0x040c0c2c,
4095
- 0x040c1404, 0x040c141c, 0x040c143e, 0x040c1c0c, 0x040c1c2c, 0x040c2424, 0x040c340c, 0x040c342c,
4096
- 0x040c3e14, 0x04140404, 0x0414040c, 0x0414042c, 0x0414043e, 0x04140c04, 0x04140c1c, 0x04140c34,
4097
- 0x0414140c, 0x0414142c, 0x04141c04, 0x04141c24, 0x04142414, 0x0414242c, 0x0414243e, 0x04142c0c,
4098
- 0x04142c1c, 0x04143e04, 0x04143e1c, 0x041c041c, 0x041c0c0c, 0x041c0c2c, 0x041c1404, 0x041c1414,
4099
- 0x041c1c0c, 0x041c1c1c, 0x041c1c34, 0x041c2424, 0x041c2c04, 0x041c2c14, 0x041c343e, 0x041c3e0c,
4100
- 0x041c3e2c, 0x04240404, 0x04240c1c, 0x04240c3e, 0x0424140c, 0x04241424, 0x04241c14, 0x04242404,
4101
- 0x0424241c, 0x04242c0c, 0x04243e04, 0x042c0414, 0x042c0424, 0x042c1404, 0x042c1414, 0x042c1434,
4102
- 0x042c1c1c, 0x042c240c, 0x042c242c, 0x042c243e, 0x042c3434, 0x042c3e1c, 0x04340434, 0x04340c0c,
4103
- 0x04340c1c, 0x04341c0c, 0x04342c14, 0x04343e0c, 0x043e0404, 0x043e0414, 0x043e0424, 0x043e1404,
4104
- 0x043e1414, 0x043e1434, 0x043e1c1c, 0x043e2c04, 0x043e2c24, 0x0c040404, 0x0c04040c, 0x0c040414,
4105
- 0x0c040424, 0x0c040c04, 0x0c040c0c, 0x0c040c1c, 0x0c040c2c, 0x0c040c3e, 0x0c041404, 0x0c041414,
4106
- 0x0c041c0c, 0x0c041c24, 0x0c041c34, 0x0c042c24, 0x0c042c34, 0x0c04340c, 0x0c043e14, 0x0c0c0404,
4107
- 0x0c0c040c, 0x0c0c041c, 0x0c0c0434, 0x0c0c0c04, 0x0c0c0c24, 0x0c0c140c, 0x0c0c1c04, 0x0c0c1c1c,
4108
- 0x0c0c240c, 0x0c0c2c04, 0x0c0c2c14, 0x0c0c3e04, 0x0c0c3e34, 0x0c140404, 0x0c140c14, 0x0c140c2c,
4109
- 0x0c140c3e, 0x0c141404, 0x0c141424, 0x0c141c14, 0x0c142404, 0x0c14241c, 0x0c142c2c, 0x0c143404,
4110
- 0x0c143e14, 0x0c1c040c, 0x0c1c0424, 0x0c1c043e, 0x0c1c0c04, 0x0c1c0c1c, 0x0c1c140c, 0x0c1c143e,
4111
- 0x0c1c1c04, 0x0c1c1c24, 0x0c1c240c, 0x0c1c3414, 0x0c1c3e04, 0x0c24041c, 0x0c24042c, 0x0c240c14,
4112
- 0x0c240c24, 0x0c241c0c, 0x0c241c1c, 0x0c242414, 0x0c242434, 0x0c242c04, 0x0c242c24, 0x0c2c040c,
4113
- 0x0c2c0c04, 0x0c2c0c1c, 0x0c2c140c, 0x0c2c1c04, 0x0c2c1c14, 0x0c2c2c0c, 0x0c341404, 0x0c341424,
4114
- 0x0c34143e, 0x0c342424, 0x0c342434, 0x0c3e040c, 0x0c3e041c, 0x0c3e0c04, 0x0c3e0c14, 0x0c3e140c,
4115
- 0x0c3e1c2c, 0x0c3e240c, 0x0c3e3414, 0x0c3e3e04, 0x14040404, 0x1404040c, 0x1404041c, 0x1404042c,
4116
- 0x1404043e, 0x14040c04, 0x14040c14, 0x14040c24, 0x14040c34, 0x1404140c, 0x1404141c, 0x1404143e,
4117
- 0x14041c04, 0x14041c14, 0x1404240c, 0x1404241c, 0x1404242c, 0x14042c04, 0x14042c14, 0x1404343e,
4118
- 0x14043e04, 0x14043e1c, 0x14043e2c, 0x140c0404, 0x140c0414, 0x140c0c04, 0x140c0c1c, 0x140c0c3e,
4119
- 0x140c1414, 0x140c142c, 0x140c1c0c, 0x140c1c24, 0x140c2414, 0x140c2c0c, 0x1414040c, 0x14140424,
4120
- 0x1414043e, 0x1414140c, 0x1414141c, 0x14141c04, 0x14141c3e, 0x1414240c, 0x14142c1c, 0x14142c3e,
4121
- 0x14143e0c, 0x14143e24, 0x141c0404, 0x141c0414, 0x141c042c, 0x141c0c0c, 0x141c1414, 0x141c1424,
4122
- 0x141c1c0c, 0x141c1c1c, 0x141c2414, 0x141c2c04, 0x141c3434, 0x1424040c, 0x1424043e, 0x14241404,
4123
- 0x1424141c, 0x14241c14, 0x14241c2c, 0x1424240c, 0x14243e14, 0x14243e2c, 0x142c0424, 0x142c0c0c,
4124
- 0x142c1414, 0x142c1c3e, 0x142c2404, 0x142c2c1c, 0x142c3e04, 0x14340404, 0x14340414, 0x1434043e,
4125
- 0x1434140c, 0x14342c2c, 0x1434340c, 0x143e042c, 0x143e0c0c, 0x143e1434, 0x143e1c04, 0x143e241c,
4126
- 0x143e2c04, 0x1c040414, 0x1c040c0c, 0x1c040c1c, 0x1c040c2c, 0x1c040c3e, 0x1c041414, 0x1c041c0c,
4127
- 0x1c041c1c, 0x1c041c2c, 0x1c042414, 0x1c042424, 0x1c04243e, 0x1c042c0c, 0x1c04341c, 0x1c043e0c,
4128
- 0x1c0c040c, 0x1c0c041c, 0x1c0c042c, 0x1c0c0c24, 0x1c0c140c, 0x1c0c141c, 0x1c0c2404, 0x1c0c3404,
4129
- 0x1c0c3e14, 0x1c0c3e34, 0x1c140404, 0x1c140c14, 0x1c141404, 0x1c141c14, 0x1c141c24, 0x1c142c04,
4130
- 0x1c1c040c, 0x1c1c0c04, 0x1c1c0c24, 0x1c1c140c, 0x1c1c141c, 0x1c1c143e, 0x1c1c1c04, 0x1c1c240c,
4131
- 0x1c1c241c, 0x1c1c243e, 0x1c1c2c2c, 0x1c1c3e1c, 0x1c24041c, 0x1c240c0c, 0x1c240c34, 0x1c241414,
4132
- 0x1c241c0c, 0x1c242c14, 0x1c243404, 0x1c243424, 0x1c2c040c, 0x1c2c0c04, 0x1c2c0c14, 0x1c2c142c,
4133
- 0x1c2c1c14, 0x1c2c2424, 0x1c2c2c34, 0x1c2c3e1c, 0x1c340c34, 0x1c34240c, 0x1c3e040c, 0x1c3e041c,
4134
- 0x1c3e1404, 0x1c3e1414, 0x1c3e1c2c, 0x24040404, 0x24040424, 0x24040c14, 0x24041404, 0x24041424,
4135
- 0x2404143e, 0x24041c14, 0x2404240c, 0x24042c04, 0x24043e04, 0x240c0414, 0x240c043e, 0x240c0c0c,
4136
- 0x240c0c1c, 0x240c1414, 0x240c1c04, 0x240c1c2c, 0x240c241c, 0x240c2c0c, 0x240c2c2c, 0x2414040c,
4137
- 0x2414041c, 0x24140c04, 0x24140c2c, 0x2414140c, 0x24141c1c, 0x24142404, 0x24142c3e, 0x24143414,
4138
- 0x24143e04, 0x241c0424, 0x241c0c0c, 0x241c0c1c, 0x241c1404, 0x241c1414, 0x241c1c0c, 0x241c1c2c,
4139
- 0x24240404, 0x24240414, 0x24241424, 0x24241c3e, 0x24242404, 0x24243e0c, 0x242c042c, 0x242c043e,
4140
- 0x242c140c, 0x242c3414, 0x24340c1c, 0x24341c24, 0x24343404, 0x243e0c04, 0x243e0c2c, 0x243e1c04,
4141
- 0x243e241c, 0x243e2c0c, 0x2c040414, 0x2c040c04, 0x2c040c24, 0x2c041414, 0x2c042404, 0x2c042424,
4142
- 0x2c04243e, 0x2c042c14, 0x2c043434, 0x2c043e24, 0x2c0c040c, 0x2c0c041c, 0x2c0c042c, 0x2c0c0c14,
4143
- 0x2c0c140c, 0x2c0c1c14, 0x2c0c3e14, 0x2c140404, 0x2c140c0c, 0x2c14141c, 0x2c141c04, 0x2c141c34,
4144
- 0x2c142c1c, 0x2c1c0414, 0x2c1c043e, 0x2c1c0c04, 0x2c1c143e, 0x2c1c2424, 0x2c1c2c0c, 0x2c1c342c,
4145
- 0x2c1c3e1c, 0x2c24040c, 0x2c240424, 0x2c241404, 0x2c241c14, 0x2c242434, 0x2c2c0c14, 0x2c2c1434,
4146
- 0x2c2c2c0c, 0x2c2c2c1c, 0x2c342414, 0x2c3e0414, 0x2c3e0424, 0x2c3e1414, 0x34040c0c, 0x34040c1c,
4147
- 0x34040c2c, 0x34041c0c, 0x34041c1c, 0x34043404, 0x340c0404, 0x340c1404, 0x340c143e, 0x340c3424,
4148
- 0x34140c14, 0x34141c24, 0x34142414, 0x34142c2c, 0x34143414, 0x34143e04, 0x341c0404, 0x341c0c24,
4149
- 0x341c140c, 0x341c2404, 0x3424142c, 0x3424241c, 0x34243414, 0x342c0404, 0x342c041c, 0x342c1c24,
4150
- 0x342c3404, 0x3434042c, 0x34342404, 0x343e0c0c, 0x343e0c1c, 0x3e040404, 0x3e040424, 0x3e04043e,
4151
- 0x3e041404, 0x3e041414, 0x3e041c34, 0x3e042404, 0x3e042c24, 0x3e043414, 0x3e0c0414, 0x3e0c0c0c,
4152
- 0x3e0c1424, 0x3e0c241c, 0x3e0c242c, 0x3e14040c, 0x3e140424, 0x3e140c04, 0x3e140c34, 0x3e14140c,
4153
- 0x3e141c04, 0x3e142c0c, 0x3e1c0414, 0x3e1c1c14, 0x3e1c1c2c, 0x3e1c2c1c, 0x3e24040c, 0x3e24042c,
4154
- 0x3e240c1c, 0x3e241404, 0x3e242c04, 0x3e2c1414, 0x3e2c2414, 0x3e340414, 0x3e341c0c, 0x3e3e0404,
4133
+ constexpr constant static uint32_t iq3s_grid[512] = {
4134
+ 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
4135
+ 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
4136
+ 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09,
4137
+ 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b,
4138
+ 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b,
4139
+ 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d,
4140
+ 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03,
4141
+ 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505,
4142
+ 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03,
4143
+ 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901,
4144
+ 0x01090907, 0x01090b03, 0x01090f01, 0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d,
4145
+ 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303,
4146
+ 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501,
4147
+ 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105,
4148
+ 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505,
4149
+ 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101,
4150
+ 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707,
4151
+ 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b,
4152
+ 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01,
4153
+ 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f,
4154
+ 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305,
4155
+ 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103,
4156
+ 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509,
4157
+ 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503,
4158
+ 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b,
4159
+ 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f,
4160
+ 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f,
4161
+ 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f,
4162
+ 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109,
4163
+ 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f,
4164
+ 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509,
4165
+ 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501,
4166
+ 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303,
4167
+ 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f,
4168
+ 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907,
4169
+ 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703,
4170
+ 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03,
4171
+ 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01,
4172
+ 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01,
4173
+ 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903,
4174
+ 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505,
4175
+ 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b,
4176
+ 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107,
4177
+ 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509,
4178
+ 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303,
4179
+ 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103,
4180
+ 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 0x09050901, 0x09050b0f, 0x09050d05,
4181
+ 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b,
4182
+ 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f,
4183
+ 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701,
4184
+ 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909,
4185
+ 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305,
4186
+ 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d,
4187
+ 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b,
4188
+ 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d,
4189
+ 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307,
4190
+ 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09,
4191
+ 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309,
4192
+ 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709,
4193
+ 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f,
4194
+ 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303,
4195
+ 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503,
4196
+ 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
4197
+ 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
4155
4198
  };
4156
4199
 
4157
4200
  #define NGRID_IQ1S 512
@@ -4742,7 +4785,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
4742
4785
  {
4743
4786
  int nval = 8;
4744
4787
  int pos = (32*sgitg + tiisg)*nval;
4745
- for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
4788
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
4746
4789
  threadgroup_barrier(mem_flags::mem_threadgroup);
4747
4790
  }
4748
4791
 
@@ -4769,12 +4812,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
4769
4812
  for (int row = 0; row < N_DST; row++) {
4770
4813
 
4771
4814
  const float db = dh[0];
4772
- const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
4815
+ const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
4773
4816
 
4774
4817
  float2 sum = {0};
4775
4818
  for (int l = 0; l < 4; ++l) {
4776
- const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
4777
- const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
4819
+ const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
4820
+ const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
4821
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
4822
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
4778
4823
  for (int j = 0; j < 4; ++j) {
4779
4824
  sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
4780
4825
  sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
@@ -4795,7 +4840,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
4795
4840
  for (int row = 0; row < N_DST; ++row) {
4796
4841
  all_sum = simd_sum(sumf[row]);
4797
4842
  if (tiisg == 0) {
4798
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
4843
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
4799
4844
  }
4800
4845
  }
4801
4846
  }
@@ -5685,15 +5730,15 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
5685
5730
  device const uint8_t * qs = xb->qs + 8*ib32;
5686
5731
  device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
5687
5732
  const uint8_t qh = xb->qh[ib32] >> 4*il;
5688
- const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
5689
- constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh << 8) & 256)));
5690
- constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh << 7) & 256)));
5733
+ const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
5734
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
5735
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
5691
5736
  for (int i = 0; i < 4; ++i) {
5692
5737
  reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
5693
5738
  reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
5694
5739
  }
5695
- grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh << 6) & 256)));
5696
- grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh << 5) & 256)));
5740
+ grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
5741
+ grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
5697
5742
  for (int i = 0; i < 4; ++i) {
5698
5743
  reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
5699
5744
  reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
@@ -2231,7 +2231,7 @@ static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(gg
2231
2231
  GGML_UNUSED(backend);
2232
2232
  }
2233
2233
 
2234
- static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
2234
+ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
2235
2235
  for (int i = 0; i < graph->n_nodes; ++i) {
2236
2236
  ggml_tensor * node = graph->nodes[i];
2237
2237
  switch (node->op) {
@@ -2246,7 +2246,7 @@ static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgrap
2246
2246
  }
2247
2247
  }
2248
2248
 
2249
- return true;
2249
+ return GGML_STATUS_SUCCESS;
2250
2250
 
2251
2251
  GGML_UNUSED(backend);
2252
2252
  }