llama_cpp 0.15.1 → 0.15.2

@@ -0,0 +1,24 @@
+ #pragma once
+
+ #include "ggml.h"
+ #include "ggml-backend.h"
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ #define GGML_RPC_MAX_SERVERS 16
+
+ // backend API
+ GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+ GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
+
+ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+
+ GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+
+ GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+
+ #ifdef __cplusplus
+ }
+ #endif
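This release vendors a new ggml-rpc.h header that exposes ggml's RPC backend. As orientation only (not part of the diff), here is a minimal client sketch that uses nothing beyond the functions declared above plus ggml_backend_free from ggml-backend.h; the endpoint string and the NULL-on-failure error handling are assumptions, not guarantees from the header.

// hypothetical client sketch for the RPC backend API declared in ggml-rpc.h
#include <stdio.h>
#include "ggml-rpc.h"

int main(void) {
    const char * endpoint = "127.0.0.1:50052";   // assumed "host:port" format

    // connect to a remote ggml RPC server
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    if (backend == NULL || !ggml_backend_is_rpc(backend)) {
        fprintf(stderr, "failed to connect to RPC server at %s\n", endpoint);
        return 1;
    }

    // query how much memory the remote device reports
    size_t free_mem = 0, total_mem = 0;
    ggml_backend_rpc_get_device_memory(endpoint, &free_mem, &total_mem);
    printf("remote device memory: %zu free / %zu total bytes\n", free_mem, total_mem);

    ggml_backend_free(backend);
    return 0;
}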
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
  #define SYCL_SCALE_BLOCK_SIZE 256
  #define SYCL_CLAMP_BLOCK_SIZE 256
  #define SYCL_ROPE_BLOCK_SIZE 256
- #define SYCL_ALIBI_BLOCK_SIZE 32
  #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
  #define SYCL_QUANTIZE_BLOCK_SIZE 256
  #define SYCL_DEQUANTIZE_BLOCK_SIZE 256
@@ -9316,32 +9315,6 @@ static void rope_glm_f32(
      dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
  }

- static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
-                       const int n_heads_log2_floor, const float m0, const float m1,
-                       const sycl::nd_item<3> &item_ct1) {
-     const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                     item_ct1.get_local_id(2);
-
-     if (col >= ncols) {
-         return;
-     }
-
-     const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                     item_ct1.get_local_id(1);
-     const int i = row*ncols + col;
-
-     const int k = row/k_rows;
-
-     float m_k;
-     if (k < n_heads_log2_floor) {
-         m_k = dpct::pow(m0, k + 1);
-     } else {
-         m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1);
-     }
-
-     dst[i] = col * m_k + x[i];
- }
-
  static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                             const sycl::nd_item<3> &item_ct1) {
      const int row = item_ct1.get_group(1);
@@ -9443,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con


  template <bool vals_smem, int ncols_template, int block_size_template>
- static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+ static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
                           const int nrows_y, const float scale, const float max_bias, const float m0,
                           const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
      const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -9457,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
      const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
      const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;

-     float slope = 0.0f;
+     float slope = 1.0f;

      // ALiBi
      if (max_bias > 0.0f) {
@@ -9482,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
          const int ix = rowx*ncols + col;
          const int iy = rowy*ncols + col;

-         const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+         const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);

          vals[col] = val;
          max_val = sycl::max(max_val, val);
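The two soft_max hunks above drop the separate positions tensor (`pos`) and fold the ALiBi bias into the mask instead: the slope now scales `mask[iy]` directly and defaults to 1.0f, so a plain mask passes through unchanged when `max_bias == 0` (the Vulkan hunk at the end of this diff cites https://github.com/ggerganov/llama.cpp/pull/7192 for this change). A scalar sketch of the before/after softmax input, for illustration only and not part of the diff:

// illustrative scalar form of the value fed into the softmax kernel
//   before: val = x*scale + mask + slope*pos
//   after:  val = x*scale + slope*mask        (slope == 1.0f unless max_bias > 0)
static inline float soft_max_input(float x, float scale, float mask, float slope) {
    return x * scale + slope * mask;
}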
@@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
          });
  }

- static void alibi_f32_sycl(const float *x, float *dst, const int ncols,
-                            const int nrows, const int k_rows,
-                            const int n_heads_log2_floor, const float m0,
-                            const float m1, dpct::queue_ptr stream) {
-     const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE);
-     const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE);
-     const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-     stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                          [=](sycl::nd_item<3> item_ct1) {
-                              alibi_f32(x, dst, ncols, k_rows,
-                                        n_heads_log2_floor, m0, m1, item_ct1);
-                          });
- }
-
  static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                                const int nrows, dpct::queue_ptr stream) {
      const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -13058,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
  }

  template <bool vals_smem, int ncols_template, int block_size_template>
- static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+ static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
                                     const int nrows_y, const float scale, const float max_bias, const float m0,
                                     const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                     const size_t n_local_scratch, dpct::queue_ptr stream) {
@@ -13068,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
          cgh.parallel_for(
              sycl::nd_range<3>(block_nums * block_dims, block_dims),
              [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
+                 soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                               nrows_y, scale, max_bias, m0,
                                                                               m1, n_head_log2, item_ct1,
                                                                               local_buf_acc.get_pointer());
@@ -13076,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
          });
  }

- static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
+ static void soft_max_f32_sycl(const float * x, const float * mask,
                                float * dst, const int ncols_x, const int nrows_x,
                                const int nrows_y, const float scale, const float max_bias,
                                dpct::queue_ptr stream) {
@@ -13098,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
      const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
      if (n_local_scratch*sizeof(float) < local_mem_size) {
          if (ncols_x > max_block_size) {
-             soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+             soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                                 max_bias, m0, m1, n_head_log2, block_nums,
                                                 block_dims, n_local_scratch, stream);
              return;
          }
          switch (ncols_x) {
              case 32:
-                 soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
                                                       max_bias, m0, m1, n_head_log2, block_nums,
                                                       block_dims, n_local_scratch, stream);
                  break;
              case 64:
-                 soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
                                                       max_bias, m0, m1, n_head_log2, block_nums,
                                                       block_dims, n_local_scratch, stream);
                  break;
              case 128:
-                 soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
                                                         max_bias, m0, m1, n_head_log2, block_nums,
                                                         block_dims, n_local_scratch, stream);
                  break;
              case 256:
-                 soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
                                                         max_bias, m0, m1, n_head_log2, block_nums,
                                                         block_dims, n_local_scratch, stream);
                  break;
              case 512:
-                 soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
                                                         max_bias, m0, m1, n_head_log2, block_nums,
                                                         block_dims, n_local_scratch, stream);
                  break;
              case 1024:
-                 soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                           max_bias, m0, m1, n_head_log2, block_nums,
                                                           block_dims, n_local_scratch, stream);
                  break;
              case 2048:
-                 soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                           max_bias, m0, m1, n_head_log2, block_nums,
                                                           block_dims, n_local_scratch, stream);
                  break;
              case 4096:
-                 soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                           max_bias, m0, m1, n_head_log2, block_nums,
                                                           block_dims, n_local_scratch, stream);
                  break;
              default:
-                 soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                 soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                                     max_bias, m0, m1, n_head_log2, block_nums,
                                                     block_dims, n_local_scratch, stream);
                  break;
          }
      } else {
-         soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+         soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                              max_bias, m0, m1, n_head_log2, block_nums,
                                              block_dims, WARP_SIZE, stream);
      }
@@ -14028,6 +13987,10 @@ inline void ggml_sycl_op_upscale(const ggml_tensor *src0,
      GGML_ASSERT(dst->type == GGML_TYPE_F32);
      GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors

+ #pragma message("TODO: generalize upscale operator")
+ #pragma message(" https://github.com/ggerganov/ggml/pull/814")
+     GGML_ASSERT(false && "TODO: generalize upscale operator");
+
      const int scale_factor = dst->op_params[0];

      upscale_f32_sycl(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
@@ -14562,36 +14525,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
      (void) src1_dd;
  }

- inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
-                                ggml_tensor *dst, const float *src0_dd,
-                                const float *src1_dd, float *dst_dd,
-                                const dpct::queue_ptr &main_stream) {
-
-     GGML_ASSERT(src0->type == GGML_TYPE_F32);
-     GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-     GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne);
-     const int64_t nrows = ggml_nrows(src0);
-
-     //const int n_past = ((int32_t *) dst->op_params)[0];
-     const int n_head = ((int32_t *) dst->op_params)[1];
-     float max_bias;
-     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-     //GGML_ASSERT(ne01 + n_past == ne00);
-     GGML_ASSERT(n_head == ne02);
-
-     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
-     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-     alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
-
-     (void) src1;
-     (void) src1_dd;
- }
-
  static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
                                  const ggml_tensor *src1, ggml_tensor *dst,
                                  const float *src0_dd, const float *src1_dd,
@@ -14746,12 +14679,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
      GGML_ASSERT(src0->type == GGML_TYPE_F32);
      GGML_ASSERT( dst->type == GGML_TYPE_F32);

-     const ggml_tensor * src2 = dst->src[2];
-
- #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+ #pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
  #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
      GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-     GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional

      const int64_t ne00 = src0->ne[0];
      const int64_t nrows_x = ggml_nrows(src0);
@@ -14763,25 +14693,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
      memcpy(&scale, dst->op_params + 0, sizeof(float));
      memcpy(&max_bias, dst->op_params + 1, sizeof(float));

-     // positions tensor
-     float * src2_dd = nullptr;
-     sycl_pool_alloc<float> src2_f;
-
-     const bool use_src2 = src2 != nullptr;
-
-     if (use_src2) {
-         const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
-
-         if (src2_on_device) {
-             ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
-             src2_dd = (float *) src2_extra->data_device[g_main_device];
-         } else {
-             src2_dd = src2_f.alloc(ggml_nelements(src2));
-             SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
-         }
-     }
-
-     soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
+     soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
                        nrows_x, nrows_y, scale, max_bias, main_stream);
  }

@@ -15656,26 +15568,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
      const int64_t r2 = ne12/ne02;
      const int64_t r3 = ne13/ne03;

- #if 0
-     // use syclGemmEx
-     {
-         for (int i13 = 0; i13 < ne13; ++i13) {
-             for (int i12 = 0; i12 < ne12; ++i12) {
-                 int i03 = i13 / r3;
-                 int i02 = i12 / r2;
-
-                 SYCL_CHECK(
-                     syclGemmEx(g_sycl_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-                                ne01, ne11, ne10,
-                                alpha, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , SYCL_R_16F, nb01/sizeof(half),
-                                       (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, SYCL_R_16F, nb11/sizeof(float),
-                                beta,  ( char *) dst_t + i12*nbd2 + i13*nbd3, cu_data_type, ne01,
-                                cu_compute_type,
-                                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-             }
-         }
-     }
- #else
      if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
          // there is no broadcast and src0, src1 are contiguous across dims 2, 3
          SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
@@ -15687,7 +15579,6 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
              nb11 / nb10, nb12 / nb10, beta,
              (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
              ne12 * ne13, cu_compute_type)));
-         g_sycl_handles[g_main_device]->wait();
      } else {
          const int ne23 = ne12*ne13;

@@ -15718,7 +15609,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
                              nb02, nb03, nb12_scaled, nb13_scaled,
                              nbd2, nbd3, r2, r3, item_ct1);
              });
-         }).wait();
+         });
      }
      SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
          *g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
@@ -15729,9 +15620,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
          dpct::library_data_t::real_half, nb11 / nb10, beta,
          (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
          cu_compute_type)));
-         g_sycl_handles[g_main_device]->wait();
      }
- #endif

      if (no_mixed_dtypes) {
          const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
@@ -16232,10 +16121,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g
      ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
  }

- static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-     ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi);
- }
-
  static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
  }
@@ -16612,9 +16497,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
          case GGML_OP_ROPE:
              func = ggml_sycl_rope;
              break;
-         case GGML_OP_ALIBI:
-             func = ggml_sycl_alibi;
-             break;
          case GGML_OP_IM2COL:
              func = ggml_sycl_im2col;
              break;
@@ -17744,7 +17626,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
          case GGML_OP_DIAG_MASK_INF:
          case GGML_OP_SOFT_MAX:
          case GGML_OP_ROPE:
-         case GGML_OP_ALIBI:
          case GGML_OP_IM2COL:
          case GGML_OP_POOL_2D:
          case GGML_OP_SUM_ROWS:
@@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
          return nullptr;
      case GGML_OP_SOFT_MAX:
          GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
-         GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16);

-         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
+         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
              return ctx->device->pipeline_soft_max_f32;
          }
          if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
@@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
      const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
      const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

+ #pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
+ #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/7192")
+
      ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
          ncols,
          src1 != nullptr ? nrows_y : (uint32_t)0,