llama_cpp 0.15.1 → 0.15.2

This diff shows the changes between publicly released versions of this package as they appear in their respective public registries, and is provided for informational purposes only.
@@ -0,0 +1,24 @@
+ #pragma once
+
+ #include "ggml.h"
+ #include "ggml-backend.h"
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ #define GGML_RPC_MAX_SERVERS 16
+
+ // backend API
+ GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
+
+ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+
+ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+
+ GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+
+ #ifdef __cplusplus
+ }
+ #endif
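
The hunk above adds the public header of the new RPC backend (it matches ggml-rpc.h in upstream llama.cpp). As a minimal, hypothetical usage sketch — assuming an RPC server is already listening at the endpoint and that the endpoint is a "host:port" string, neither of which this diff states — a client might connect like this (ggml_backend_free comes from ggml-backend.h; error handling kept minimal):

    // Hypothetical client-side sketch for the RPC backend API above; not part of the diff.
    #include <cstdio>
    #include "ggml-rpc.h"

    int main() {
        const char * endpoint = "192.168.1.10:50052"; // assumed "host:port" format

        ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
        if (backend == nullptr) {
            std::fprintf(stderr, "failed to connect to RPC server at %s\n", endpoint);
            return 1;
        }

        // the returned backend should identify itself as RPC
        if (!ggml_backend_is_rpc(backend)) {
            std::fprintf(stderr, "unexpected backend type\n");
        }

        // query how much memory the remote device reports
        size_t free_mem = 0, total_mem = 0;
        ggml_backend_rpc_get_device_memory(endpoint, &free_mem, &total_mem);
        std::printf("remote device memory: %zu free / %zu total\n", free_mem, total_mem);

        ggml_backend_free(backend); // declared in ggml-backend.h
        return 0;
    }

On the server side, start_rpc_server(backend, endpoint, free_mem, total_mem) presumably wraps a local backend (CPU or GPU) and serves it at the given endpoint, with the free/total values being what the server reports back to clients.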
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
  #define SYCL_SCALE_BLOCK_SIZE 256
  #define SYCL_CLAMP_BLOCK_SIZE 256
  #define SYCL_ROPE_BLOCK_SIZE 256
- #define SYCL_ALIBI_BLOCK_SIZE 32
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
  #define SYCL_DEQUANTIZE_BLOCK_SIZE 256
@@ -9316,32 +9315,6 @@ static void rope_glm_f32(
  dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
  }
 
- static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
- const int n_heads_log2_floor, const float m0, const float m1,
- const sycl::nd_item<3> &item_ct1) {
- const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
- item_ct1.get_local_id(2);
-
- if (col >= ncols) {
- return;
- }
-
- const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
- item_ct1.get_local_id(1);
- const int i = row*ncols + col;
-
- const int k = row/k_rows;
-
- float m_k;
- if (k < n_heads_log2_floor) {
- m_k = dpct::pow(m0, k + 1);
- } else {
- m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1);
- }
-
- dst[i] = col * m_k + x[i];
- }
-
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
  const sycl::nd_item<3> &item_ct1) {
  const int row = item_ct1.get_group(1);
@@ -9443,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
 
 
  template <bool vals_smem, int ncols_template, int block_size_template>
- static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+ static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
  const int nrows_y, const float scale, const float max_bias, const float m0,
  const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
  const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -9457,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
  const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
  const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
 
- float slope = 0.0f;
+ float slope = 1.0f;
 
  // ALiBi
  if (max_bias > 0.0f) {
@@ -9482,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
  const int ix = rowx*ncols + col;
  const int iy = rowy*ncols + col;
 
- const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+ const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
 
  vals[col] = val;
  max_val = sycl::max(max_val, val);
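
These two hunks carry the substance of the change: the separate `pos` tensor is gone, and ALiBi is now applied by scaling the mask with a per-head `slope`. For context, here is a standalone sketch of how that slope follows from `max_bias`, reusing the m0/m1 formulas visible in the removed ggml_sycl_op_alibi further down (the function and variable names here are illustrative, not part of the diff):

    #include <cmath>

    // Illustrative sketch: per-head ALiBi slope, built from the same m0/m1
    // definitions as the removed alibi code. Not part of this diff.
    static float alibi_slope(int head, int n_head, float max_bias) {
        const int n_head_log2 = 1 << (int) std::floor(std::log2((float) n_head));
        const float m0 = std::pow(2.0f, -max_bias          / n_head_log2);
        const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return head < n_head_log2 ? std::pow(m0, head + 1)
                                  : std::pow(m1, 2 * (head - n_head_log2) + 1);
    }

With max_bias == 0.0f every slope evaluates to 1.0f, which explains why the kernel's default changed from 0.0f to 1.0f: the new term `mask ? slope*mask[iy] : 0.0f` must reduce to the plain mask in the non-ALiBi case.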
@@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
  });
  }
 
- static void alibi_f32_sycl(const float *x, float *dst, const int ncols,
- const int nrows, const int k_rows,
- const int n_heads_log2_floor, const float m0,
- const float m1, dpct::queue_ptr stream) {
- const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE);
- const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE);
- const sycl::range<3> block_nums(1, nrows, num_blocks_x);
- stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
- [=](sycl::nd_item<3> item_ct1) {
- alibi_f32(x, dst, ncols, k_rows,
- n_heads_log2_floor, m0, m1, item_ct1);
- });
- }
-
  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
  const int nrows, dpct::queue_ptr stream) {
  const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -13058,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
  }
 
  template <bool vals_smem, int ncols_template, int block_size_template>
- static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+ static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
  const int nrows_y, const float scale, const float max_bias, const float m0,
  const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
  const size_t n_local_scratch, dpct::queue_ptr stream) {
@@ -13068,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
  cgh.parallel_for(
  sycl::nd_range<3>(block_nums * block_dims, block_dims),
  [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
- soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
+ soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
  nrows_y, scale, max_bias, m0,
  m1, n_head_log2, item_ct1,
  local_buf_acc.get_pointer());
@@ -13076,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
  });
  }
 
- static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
+ static void soft_max_f32_sycl(const float * x, const float * mask,
  float * dst, const int ncols_x, const int nrows_x,
  const int nrows_y, const float scale, const float max_bias,
  dpct::queue_ptr stream) {
@@ -13098,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
  const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
  if (n_local_scratch*sizeof(float) < local_mem_size) {
  if (ncols_x > max_block_size) {
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  return;
  }
  switch (ncols_x) {
  case 32:
- soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  case 64:
- soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  case 128:
- soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  case 256:
- soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  case 512:
- soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  case 1024:
- soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  case 2048:
- soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  case 4096:
- soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  default:
- soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, n_local_scratch, stream);
  break;
  }
  } else {
- soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+ soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
  max_bias, m0, m1, n_head_log2, block_nums,
  block_dims, WARP_SIZE, stream);
  }
@@ -14028,6 +13987,10 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,
  GGML_ASSERT(dst->type == GGML_TYPE_F32);
  GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
+ #pragma message("TODO: generalize upscale operator")
+ #pragma message(" https://github.com/ggerganov/ggml/pull/814")
+ GGML_ASSERT(false && "TODO: generalize upscale operator");
+
  const int scale_factor = dst->op_params[0];
 
  upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
@@ -14562,36 +14525,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
  (void) src1_dd;
  }
 
- inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
- ggml_tensor *dst, const float *src0_dd,
- const float *src1_dd, float *dst_dd,
- const dpct::queue_ptr &main_stream) {
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
- GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne);
- const int64_t nrows = ggml_nrows(src0);
-
- //const int n_past = ((int32_t *) dst->op_params)[0];
- const int n_head = ((int32_t *) dst->op_params)[1];
- float max_bias;
- memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
- //GGML_ASSERT(ne01 + n_past == ne00);
- GGML_ASSERT(n_head == ne02);
-
- const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
- const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
- alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
-
- (void) src1;
- (void) src1_dd;
- }
-
  static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
  const ggml_tensor *src1, ggml_tensor *dst,
  const float *src0_dd, const float *src1_dd,
@@ -14746,12 +14679,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
  GGML_ASSERT(src0->type == GGML_TYPE_F32);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
- const ggml_tensor * src2 = dst->src[2];
-
- #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
  const int64_t ne00 = src0->ne[0];
  const int64_t nrows_x = ggml_nrows(src0);
@@ -14763,25 +14693,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
  memcpy(&scale, dst->op_params + 0, sizeof(float));
  memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
- // positions tensor
- float * src2_dd = nullptr;
- sycl_pool_alloc<float> src2_f;
-
- const bool use_src2 = src2 != nullptr;
-
- if (use_src2) {
- const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
-
- if (src2_on_device) {
- ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
- src2_dd = (float *) src2_extra->data_device[g_main_device];
- } else {
- src2_dd = src2_f.alloc(ggml_nelements(src2));
- SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
- }
- }
-
- soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
+ soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
  nrows_x, nrows_y, scale, max_bias, main_stream);
  }
 
@@ -15656,26 +15568,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
  const int64_t r2 = ne12/ne02;
  const int64_t r3 = ne13/ne03;
 
- #if 0
- // use syclGemmEx
- {
- for (int i13 = 0; i13 < ne13; ++i13) {
- for (int i12 = 0; i12 < ne12; ++i12) {
- int i03 = i13 / r3;
- int i02 = i12 / r2;
-
- SYCL_CHECK(
- syclGemmEx(g_sycl_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , SYCL_R_16F, nb01/sizeof(half),
- (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, SYCL_R_16F, nb11/sizeof(float),
- beta, ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
- cu_compute_type,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
- }
- }
- }
- #else
  if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
  // there is no broadcast and src0, src1 are contiguous across dims 2, 3
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
@@ -15687,7 +15579,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
  nb11 / nb10, nb12 / nb10, beta,
  (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
  ne12 * ne13, cu_compute_type)));
- g_sycl_handles[g_main_device]->wait();
  } else {
  const int ne23 = ne12*ne13;
 
@@ -15718,7 +15609,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
  nb02, nb03, nb12_scaled, nb13_scaled,
  nbd2, nbd3, r2, r3, item_ct1);
  });
- }).wait();
+ });
  }
  SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
  *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
@@ -15729,9 +15620,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
  dpct::library_data_t::real_half, nb11 / nb10, beta,
  (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
  cu_compute_type)));
- g_sycl_handles[g_main_device]->wait();
  }
- #endif
 
  if (no_mixed_dtypes) {
  const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
@@ -16232,10 +16121,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g
  ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
  }
 
- static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi);
- }
-
  static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
  }
@@ -16612,9 +16497,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
  case GGML_OP_ROPE:
  func = ggml_sycl_rope;
  break;
- case GGML_OP_ALIBI:
- func = ggml_sycl_alibi;
- break;
  case GGML_OP_IM2COL:
  func = ggml_sycl_im2col;
  break;
@@ -17744,7 +17626,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
  case GGML_OP_ROPE:
- case GGML_OP_ALIBI:
  case GGML_OP_IM2COL:
  case GGML_OP_POOL_2D:
  case GGML_OP_SUM_ROWS:
@@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return nullptr;
  case GGML_OP_SOFT_MAX:
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
- GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16);
 
- if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
+ if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_soft_max_f32;
  }
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
@@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+ #pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+
  ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
  ncols,
  src1 != nullptr ? nrows_y : (uint32_t)0,