whisper.rn 0.4.0-rc.5 → 0.4.0-rc.7

package/cpp/ggml.c CHANGED
@@ -1,4 +1,4 @@
- #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
  #define _USE_MATH_DEFINES // For M_PI on MSVC

  #include "ggml-impl.h"
@@ -33,7 +33,7 @@
  // we should just be careful :)
  #pragma warning(disable: 4244 4267)

- // disable POSIX deprecation warnigns
+ // disable POSIX deprecation warnings
  // these functions are never going away, anyway
  #pragma warning(disable: 4996)
  #endif
@@ -1395,7 +1395,7 @@ inline static void wsp_ggml_vec_step_f32 (const int n, float * y, const float *
  inline static void wsp_ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
  inline static void wsp_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
  inline static void wsp_ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
- inline static void wsp_ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
+ inline static void wsp_ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }

  static const float GELU_COEF_A = 0.044715f;
  static const float GELU_QUICK_COEF = -1.702f;
@@ -1623,7 +1623,9 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
  "POOL_1D",
  "POOL_2D",
  "UPSCALE",
+ "PAD",
  "ARGSORT",
+ "LEAKY_RELU",

  "FLASH_ATTN",
  "FLASH_FF",
@@ -1650,7 +1652,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };

- static_assert(WSP_GGML_OP_COUNT == 70, "WSP_GGML_OP_COUNT != 70");
+ static_assert(WSP_GGML_OP_COUNT == 72, "WSP_GGML_OP_COUNT != 72");

  static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
  "none",
@@ -1707,7 +1709,9 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
  "pool_1d(x)",
  "pool_2d(x)",
  "upscale(x)",
+ "pad(x)",
  "argsort(x)",
+ "leaky_relu(x)",

  "flash_attn(x)",
  "flash_ff(x)",
@@ -1734,7 +1738,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };

- static_assert(WSP_GGML_OP_COUNT == 70, "WSP_GGML_OP_COUNT != 70");
+ static_assert(WSP_GGML_OP_COUNT == 72, "WSP_GGML_OP_COUNT != 72");

  static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2");

@@ -1750,17 +1754,16 @@ static const char * WSP_GGML_UNARY_OP_NAME[WSP_GGML_UNARY_OP_COUNT] = {
  "GELU",
  "GELU_QUICK",
  "SILU",
- "LEAKY",
  };

- static_assert(WSP_GGML_UNARY_OP_COUNT == 11, "WSP_GGML_UNARY_OP_COUNT != 11");
+ static_assert(WSP_GGML_UNARY_OP_COUNT == 10, "WSP_GGML_UNARY_OP_COUNT != 10");


  static_assert(sizeof(struct wsp_ggml_object)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_object size must be a multiple of WSP_GGML_MEM_ALIGN");
  static_assert(sizeof(struct wsp_ggml_tensor)%WSP_GGML_MEM_ALIGN == 0, "wsp_ggml_tensor size must be a multiple of WSP_GGML_MEM_ALIGN");

  // WARN:
- // Mis-confguration can lead to problem that's hard to reason about:
+ // Mis-configuration can lead to problem that's hard to reason about:
  // * At best it crash or talks nosense.
  // * At worst it talks slightly difference but hard to perceive.
  //
@@ -3830,12 +3833,25 @@ struct wsp_ggml_tensor * wsp_ggml_relu_inplace(
  return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_RELU);
  }

- // wsp_ggml_leaky
+ // wsp_ggml_leaky_relu

- struct wsp_ggml_tensor * wsp_ggml_leaky(
+ struct wsp_ggml_tensor * wsp_ggml_leaky_relu(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * a) {
- return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_LEAKY);
+ struct wsp_ggml_tensor * a, float negative_slope, bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);
+ wsp_ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
+
+ result->op = WSP_GGML_OP_LEAKY_RELU;
+ result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
  }

  // wsp_ggml_gelu
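
For context: this hunk changes the public API. wsp_ggml_leaky(ctx, a), which routed through the generic unary op with a hard-coded 0.1f slope, is replaced by the dedicated wsp_ggml_leaky_relu op whose slope travels in op_params. A minimal calling sketch, assuming only a valid wsp_ggml_context * ctx (the shape and slope values are illustrative, not taken from this diff):

// build a leaky-relu graph node with the new signature
struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 64, 4);

// before (0.4.0-rc.5): slope fixed at 0.1f
//   struct wsp_ggml_tensor * y = wsp_ggml_leaky(ctx, a);
// after (0.4.0-rc.7): slope is explicit, in-place is opt-in
struct wsp_ggml_tensor * y = wsp_ggml_leaky_relu(ctx, a, 0.1f, /*inplace=*/false);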
@@ -4022,8 +4038,9 @@ static struct wsp_ggml_tensor * wsp_ggml_group_norm_impl(

  struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a);

- result->op = WSP_GGML_OP_GROUP_NORM;
  result->op_params[0] = n_groups;
+
+ result->op = WSP_GGML_OP_GROUP_NORM;
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = a;
  result->src[1] = NULL; // TODO: maybe store epsilon here?
@@ -4075,17 +4092,18 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat(

  struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
  struct wsp_ggml_context * ctx,
- struct wsp_ggml_tensor * as[],
+ struct wsp_ggml_tensor * const as[],
+ int n_as,
  struct wsp_ggml_tensor * ids,
  int id,
  struct wsp_ggml_tensor * b) {

- int64_t n_as = ids->ne[0];
-
  WSP_GGML_ASSERT(ids->type == WSP_GGML_TYPE_I32);
- WSP_GGML_ASSERT(wsp_ggml_is_vector(ids));
+ WSP_GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
+ WSP_GGML_ASSERT(ids->ne[1] == b->ne[1]);
+ WSP_GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
  WSP_GGML_ASSERT(n_as > 0 && n_as <= WSP_GGML_MAX_SRC - 2);
- WSP_GGML_ASSERT(id >= 0 && id < n_as);
+ WSP_GGML_ASSERT(id >= 0 && id < ids->ne[0]);

  bool is_node = false;

@@ -4097,13 +4115,14 @@ struct wsp_ggml_tensor * wsp_ggml_mul_mat_id(
  struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);

  wsp_ggml_set_op_params_i32(result, 0, id);
+ wsp_ggml_set_op_params_i32(result, 1, n_as);

  result->op = WSP_GGML_OP_MUL_MAT_ID;
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = ids;
  result->src[1] = b;

- for (int64_t i = 0; i < n_as; i++) {
+ for (int i = 0; i < n_as; i++) {
  struct wsp_ggml_tensor * a = as[i];
  WSP_GGML_ASSERT(wsp_ggml_are_same_shape(as[0], a));
  WSP_GGML_ASSERT(wsp_ggml_can_mul_mat(a, b));
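
For context: wsp_ggml_mul_mat_id changes in two ways here. The expert count n_as is now passed explicitly (and stored in op_params) instead of being inferred from ids->ne[0], and ids may now hold one row of top-k expert indices per column of b rather than a single vector. A hedged sketch of the new call, where w0/w1, the 2-expert setup, and the token count are illustrative assumptions:

// route each token column of b through one of two experts
struct wsp_ggml_tensor * experts[2] = { w0, w1 };  // equally-shaped weight matrices
struct wsp_ggml_tensor * ids = wsp_ggml_new_tensor_2d(
    ctx, WSP_GGML_TYPE_I32, /*top-k ids per token*/ 1, /*tokens =*/ b->ne[1]);

// before: wsp_ggml_mul_mat_id(ctx, experts, ids, 0, b);  n_as came from ids->ne[0]
struct wsp_ggml_tensor * out = wsp_ggml_mul_mat_id(ctx, experts, 2, ids, /*id=*/0, b);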
@@ -4731,7 +4750,9 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
  struct wsp_ggml_tensor * b) {
- WSP_GGML_ASSERT(wsp_ggml_is_matrix(a) && wsp_ggml_is_vector(b) && b->type == WSP_GGML_TYPE_I32);
+ WSP_GGML_ASSERT(a->ne[2] == b->ne[1]);
+ WSP_GGML_ASSERT(b->ne[3] == 1);
+ WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32);

  bool is_node = false;

@@ -4741,7 +4762,7 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows(

  // TODO: implement non F32 return
  //struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
- struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, a->ne[0], b->ne[0]);
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, WSP_GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);

  result->op = WSP_GGML_OP_GET_ROWS;
  result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
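
wsp_ggml_get_rows thus generalizes from a (matrix a, vector b) pair with a 2-D result to batched lookups with a 4-D result. A shape sketch under assumed, illustrative dimension names:

// a:   [n_embd, n_rows, n_batch, 1]  -- a->ne[2] must match b->ne[1]
// b:   [n_tok,  n_batch, 1, 1]       -- WSP_GGML_TYPE_I32 row indices
// out: [n_embd, n_tok,   n_batch, 1] -- still F32-only (see TODO above)
struct wsp_ggml_tensor * rows = wsp_ggml_get_rows(ctx, a, b);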
@@ -5519,6 +5540,30 @@ static struct wsp_ggml_tensor * wsp_ggml_upscale_impl(
  return result;
  }

+ struct wsp_ggml_tensor * wsp_ggml_pad(
+ struct wsp_ggml_context * ctx,
+ struct wsp_ggml_tensor * a,
+ int p0, int p1, int p2, int p3) {
+ bool is_node = false;
+
+ if (a->grad) {
+ WSP_GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type,
+ a->ne[0] + p0,
+ a->ne[1] + p1,
+ a->ne[2] + p2,
+ a->ne[3] + p3);
+
+ result->op = WSP_GGML_OP_PAD;
+ result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+ }
+
  struct wsp_ggml_tensor * wsp_ggml_upscale(
  struct wsp_ggml_context * ctx,
  struct wsp_ggml_tensor * a,
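
The new WSP_GGML_OP_PAD op grows each dimension by p0..p3 trailing elements; the forward kernel added further down zero-fills the new region. A usage sketch with illustrative sizes:

// pad a 4x4 F32 tensor with two extra zero columns along dim 0
struct wsp_ggml_tensor * a = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, 4, 4);
struct wsp_ggml_tensor * p = wsp_ggml_pad(ctx, a, 2, 0, 0, 0);
// p->ne == {6, 4, 1, 1}; elements outside a's extent read back as 0.0f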
@@ -7520,7 +7565,7 @@ static void wsp_ggml_compute_forward_acc_f32(
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));

  // view src0 and dst with these strides and data offset inbytes during acc
- // nb0 is implicitely element_size because src0 and dst are contiguous
+ // nb0 is implicitly element_size because src0 and dst are contiguous
  size_t nb1 = ((int32_t *) dst->op_params)[0];
  size_t nb2 = ((int32_t *) dst->op_params)[1];
  size_t nb3 = ((int32_t *) dst->op_params)[2];
@@ -7714,8 +7759,10 @@ static void wsp_ggml_compute_forward_mul_f32(
  const int ith = params->ith;
  const int nth = params->nth;

+ // TODO: OpenCL kernel support broadcast
  #ifdef WSP_GGML_USE_CLBLAST
  if (src1->backend == WSP_GGML_BACKEND_GPU) {
+ WSP_GGML_ASSERT(wsp_ggml_are_same_shape(src0, src1));
  if (ith == 0) {
  wsp_ggml_cl_mul(src0, src1, dst);
  }
@@ -8981,10 +9028,9 @@ static void wsp_ggml_compute_forward_silu(
  } break;
  }
  }
+ // wsp_ggml_compute_forward_leaky_relu

- // wsp_ggml_compute_forward_leaky
-
- static void wsp_ggml_compute_forward_leaky_f32(
+ static void wsp_ggml_compute_forward_leaky_relu_f32(
  const struct wsp_ggml_compute_params * params,
  const struct wsp_ggml_tensor * src0,
  struct wsp_ggml_tensor * dst) {
@@ -8998,24 +9044,27 @@ static void wsp_ggml_compute_forward_leaky_f32(
  const int n = wsp_ggml_nrows(src0);
  const int nc = src0->ne[0];

+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
  assert(dst->nb[0] == sizeof(float));
  assert(src0->nb[0] == sizeof(float));

  for (int i = 0; i < n; i++) {
- wsp_ggml_vec_leaky_f32(nc,
+ wsp_ggml_vec_leaky_relu_f32(nc,
  (float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
+ (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
  }
  }

- static void wsp_ggml_compute_forward_leaky(
+ static void wsp_ggml_compute_forward_leaky_relu(
  const struct wsp_ggml_compute_params * params,
  const struct wsp_ggml_tensor * src0,
  struct wsp_ggml_tensor * dst) {
  switch (src0->type) {
  case WSP_GGML_TYPE_F32:
  {
- wsp_ggml_compute_forward_leaky_f32(params, src0, dst);
+ wsp_ggml_compute_forward_leaky_relu_f32(params, src0, dst);
  } break;
  default:
  {
@@ -9504,8 +9553,11 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
  const int64_t ne0 = dst->ne[0];
  const int64_t ne1 = dst->ne[1];

+ // NOTE: with WSP_GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
+ // all the experts for each batch element and the processing would become incredibly slow
  // TODO: find the optimal values for these
- if (wsp_ggml_is_contiguous(src0) &&
+ if (dst->op != WSP_GGML_OP_MUL_MAT_ID &&
+ wsp_ggml_is_contiguous(src0) &&
  wsp_ggml_is_contiguous(src1) &&
  //src0->type == WSP_GGML_TYPE_F32 &&
  src1->type == WSP_GGML_TYPE_F32 &&
@@ -9519,11 +9571,16 @@ static bool wsp_ggml_compute_forward_mul_mat_use_blas(
  }
  #endif

+ // off1 = offset in i11 and i1
+ // cne1 = ne11 and ne1
+ // in a normal matrix multiplication, off1 = 0 and cne1 = ne1
+ // during WSP_GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
  static void wsp_ggml_compute_forward_mul_mat(
  const struct wsp_ggml_compute_params * params,
  const struct wsp_ggml_tensor * src0,
  const struct wsp_ggml_tensor * src1,
- struct wsp_ggml_tensor * dst) {
+ struct wsp_ggml_tensor * dst,
+ int64_t off1, int64_t cne1) {
  int64_t t0 = wsp_ggml_perf_time_us();
  UNUSED(t0);

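For context: the new (off1, cne1) arguments let a single invocation multiply against a contiguous column slice of src1 (and write the matching rows of dst) instead of the whole tensor; this is the hook the rewritten mul_mat_id below relies on. An illustration of the two call patterns, with made-up parameter values:

// full multiply, all ne11 == 8 columns of src1 (what WSP_GGML_OP_MUL_MAT does):
//   wsp_ggml_compute_forward_mul_mat(params, src0, src1, dst, 0, 8);
// same kernel restricted to the single column i11 == 3, as used per routed token:
//   wsp_ggml_compute_forward_mul_mat(params, src0, src1, dst, 3, 1);
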
@@ -9591,10 +9648,9 @@ static void wsp_ggml_compute_forward_mul_mat(
  const int64_t i03 = i13/r3;
  const int64_t i02 = i12/r2;

- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
- const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+ const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
+ float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);

  if (type != WSP_GGML_TYPE_F32) {
  float * const wdata = params->wdata;
@@ -9611,10 +9667,10 @@ static void wsp_ggml_compute_forward_mul_mat(
  }

  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- ne11, ne01, ne10,
- 1.0f, y, ne10,
- x, ne00,
- 0.0f, d, ne01);
+ cne1, ne01, ne10,
+ 1.0f, y, ne10,
+ x, ne00,
+ 0.0f, d, ne01);
  }
  }

@@ -9630,6 +9686,7 @@ static void wsp_ggml_compute_forward_mul_mat(
  const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);

  assert(params->wsize >= ne11*ne12*ne13*row_size);
+ assert(src1->type == WSP_GGML_TYPE_F32);

  for (int64_t i13 = 0; i13 < ne13; ++i13) {
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
@@ -9652,7 +9709,7 @@ static void wsp_ggml_compute_forward_mul_mat(
  const size_t row_size = ne10*wsp_ggml_type_size(vec_dot_type)/wsp_ggml_blck_size(vec_dot_type);

  const int64_t nr0 = ne01; // src0 rows
- const int64_t nr1 = ne11*ne12*ne13; // src1 rows
+ const int64_t nr1 = cne1*ne12*ne13; // src1 rows

  //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);

@@ -9694,9 +9751,9 @@ static void wsp_ggml_compute_forward_mul_mat(
  for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
  for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
  for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
- const int64_t i13 = (ir1/(ne12*ne11));
- const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
- const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+ const int64_t i13 = (ir1/(ne12*cne1));
+ const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+ const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;

  // broadcast src0 into src1
  const int64_t i03 = i13/r3;
@@ -9736,20 +9793,28 @@ static void wsp_ggml_compute_forward_mul_mat(

  static void wsp_ggml_compute_forward_mul_mat_id(
  const struct wsp_ggml_compute_params * params,
+ const struct wsp_ggml_tensor * src0,
+ const struct wsp_ggml_tensor * src1,
  struct wsp_ggml_tensor * dst) {

- const struct wsp_ggml_tensor * ids = dst->src[0];
- const struct wsp_ggml_tensor * src1 = dst->src[1];
-
- const int id = wsp_ggml_get_op_params_i32(dst, 0);
+ if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
+ // during WSP_GGML_TASK_INIT the entire src1 is converted to vec_dot_type
+ wsp_ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
+ return;
+ }

- const int a_id = ((int32_t *)ids->data)[id];
+ const struct wsp_ggml_tensor * ids = src0;
+ const int id = wsp_ggml_get_op_params_i32(dst, 0);
+ const int n_as = wsp_ggml_get_op_params_i32(dst, 1);

- WSP_GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+ const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);

- const struct wsp_ggml_tensor * src0 = dst->src[a_id + 2];
+ WSP_GGML_ASSERT(row_id >= 0 && row_id < n_as);

- wsp_ggml_compute_forward_mul_mat(params, src0, src1, dst);
+ const struct wsp_ggml_tensor * src0_row = dst->src[row_id + 2];
+ wsp_ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+ }
  }

  // wsp_ggml_compute_forward_out_prod
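
For context: the old code picked a single expert for the whole batch (ids was read once at index id); the rewrite routes every token column independently, calling the sliced mul_mat with off1 = i01 and cne1 = 1. A standalone reference of the equivalent computation on plain arrays; every name and shape here is illustrative, not part of the library:

#include <stdint.h>
#include <stddef.h>

// dst[t] = experts[ids[t]] * src1[t] for each token column t
// (row-major; experts are m x k, token columns have length k)
static void mul_mat_id_ref(const float * const experts[], const int32_t * ids,
                           const float * src1, float * dst,
                           size_t n_tok, size_t k, size_t m) {
    for (size_t t = 0; t < n_tok; t++) {
        const float * a = experts[ids[t]];     // per-token expert selection
        for (size_t r = 0; r < m; r++) {
            float acc = 0.f;
            for (size_t c = 0; c < k; c++) {
                acc += a[r*k + c] * src1[t*k + c];
            }
            dst[t*m + r] = acc;
        }
    }
}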
@@ -10161,7 +10226,7 @@ static void wsp_ggml_compute_forward_set_f32(
  WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst) && wsp_ggml_is_contiguous(src0));

  // view src0 and dst with these strides and data offset inbytes during set
- // nb0 is implicitely element_size because src0 and dst are contiguous
+ // nb0 is implicitly element_size because src0 and dst are contiguous
  size_t nb1 = ((int32_t *) dst->op_params)[0];
  size_t nb2 = ((int32_t *) dst->op_params)[1];
  size_t nb3 = ((int32_t *) dst->op_params)[2];
@@ -10325,21 +10390,30 @@ static void wsp_ggml_compute_forward_get_rows_q(
  return;
  }

- const int nc = src0->ne[0];
- const int nr = wsp_ggml_nelements(src1);
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);
+
  const enum wsp_ggml_type type = src0->type;
  wsp_ggml_to_float_t const wsp_dewsp_quantize_row_q = type_traits[type].to_float;

- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == wsp_ggml_type_size(type));
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == wsp_ggml_type_size(type));
+ assert(wsp_ggml_nrows(dst) == nr);

- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);

- wsp_dewsp_quantize_row_q(
- (const void *) ((char *) src0->data + r*src0->nb[1]),
- (float *) ((char *) dst->data + i*dst->nb[1]), nc);
+ wsp_dewsp_quantize_row_q(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
+ }
  }
  }

@@ -10354,19 +10428,26 @@ static void wsp_ggml_compute_forward_get_rows_f16(
  return;
  }

- const int nc = src0->ne[0];
- const int nr = wsp_ggml_nelements(src1);
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS

- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == sizeof(wsp_ggml_fp16_t));
+ const int64_t nc = ne00;
+ const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);

- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(wsp_ggml_fp16_t));
+ assert(wsp_ggml_nrows(dst) == nr);

- for (int j = 0; j < nc; ++j) {
- wsp_ggml_fp16_t v = ((wsp_ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
- ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = WSP_GGML_FP16_TO_FP32(v);
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ wsp_ggml_fp16_to_fp32_row(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
  }
  }
  }
@@ -10382,19 +10463,27 @@ static void wsp_ggml_compute_forward_get_rows_f32(
  return;
  }

- const int nc = src0->ne[0];
- const int nr = wsp_ggml_nelements(src1);
+ WSP_GGML_TENSOR_BINARY_OP_LOCALS

- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == sizeof(float));
+ const int64_t nc = ne00;
+ const int64_t nr = wsp_ggml_nelements(src1); WSP_GGML_UNUSED(nr);

- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(float));
+ assert(wsp_ggml_nrows(dst) == nr);

- wsp_ggml_vec_cpy_f32(nc,
- (float *) ((char *) dst->data + i*dst->nb[1]),
- (float *) ((char *) src0->data + r*src0->nb[1]));
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ wsp_ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+ }
+ }
  }
  }

@@ -12114,6 +12203,7 @@ static void wsp_ggml_compute_forward_upscale_f32(
  WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));

  const int ith = params->ith;
+ const int nth = params->nth;

  WSP_GGML_TENSOR_UNARY_OP_LOCALS

@@ -12121,16 +12211,17 @@

  // TODO: optimize

- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = ith; i02 < ne02; i02++) {
- for (int m = 0; m < dst->ne[1]; m++) {
- int i01 = m / scale_factor;
- for (int n = 0; n < dst->ne[0]; n++) {
- int i00 = n / scale_factor;
-
- const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ const int64_t i03 = i3;
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+ const int64_t i02 = i2;
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ const int64_t i01 = i1 / scale_factor;
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
+ const int64_t i00 = i0 / scale_factor;

- float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
+ const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);

  *y = *x;
  }
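
Beyond the int64_t index types, this also appears to fix the thread partitioning: the old loop started at ith and stepped by 1, so each thread re-processed the tail of the channel range; the new loop strides by nth so threads cover disjoint channels. In sketch form:

// before: for (int i02 = ith; i02 < ne02; i02++)      -- overlapping ranges
// after:  for (int64_t i2 = ith; i2 < ne2; i2 += nth) -- thread ith takes
//         channels ith, ith+nth, ith+2*nth, ... (disjoint across threads)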
@@ -12155,6 +12246,64 @@ static void wsp_ggml_compute_forward_upscale(
  }
  }

+ // wsp_ggml_compute_forward_pad
+
+ static void wsp_ggml_compute_forward_pad_f32(
+ const struct wsp_ggml_compute_params * params,
+ const struct wsp_ggml_tensor * src0,
+ struct wsp_ggml_tensor * dst) {
+
+ if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ WSP_GGML_ASSERT(src0->nb[0] == sizeof(float));
+ WSP_GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ WSP_GGML_TENSOR_UNARY_OP_LOCALS
+
+ float * dst_ptr = (float *) dst->data;
+
+ // TODO: optimize
+
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
+ for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+ const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ dst_ptr[dst_idx] = *src_ptr;
+ } else {
+ dst_ptr[dst_idx] = 0;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ static void wsp_ggml_compute_forward_pad(
+ const struct wsp_ggml_compute_params * params,
+ const struct wsp_ggml_tensor * src0,
+ struct wsp_ggml_tensor * dst) {
+ switch (src0->type) {
+ case WSP_GGML_TYPE_F32:
+ {
+ wsp_ggml_compute_forward_pad_f32(params, src0, dst);
+ } break;
+ default:
+ {
+ WSP_GGML_ASSERT(false);
+ } break;
+ }
+ }
+
  // wsp_ggml_compute_forward_argsort

  static void wsp_ggml_compute_forward_argsort_f32(
@@ -13362,10 +13511,6 @@ static void wsp_ggml_compute_forward_unary(
  {
  wsp_ggml_compute_forward_silu(params, src0, dst);
  } break;
- case WSP_GGML_UNARY_OP_LEAKY:
- {
- wsp_ggml_compute_forward_leaky(params, src0, dst);
- } break;
  default:
  {
  WSP_GGML_ASSERT(false);
@@ -14037,11 +14182,11 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
  } break;
  case WSP_GGML_OP_MUL_MAT:
  {
- wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
+ wsp_ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
  } break;
  case WSP_GGML_OP_MUL_MAT_ID:
  {
- wsp_ggml_compute_forward_mul_mat_id(params, tensor);
+ wsp_ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
  } break;
  case WSP_GGML_OP_OUT_PROD:
  {
@@ -14147,10 +14292,18 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st
  {
  wsp_ggml_compute_forward_upscale(params, tensor->src[0], tensor);
  } break;
+ case WSP_GGML_OP_PAD:
+ {
+ wsp_ggml_compute_forward_pad(params, tensor->src[0], tensor);
+ } break;
  case WSP_GGML_OP_ARGSORT:
  {
  wsp_ggml_compute_forward_argsort(params, tensor->src[0], tensor);
  } break;
+ case WSP_GGML_OP_LEAKY_RELU:
+ {
+ wsp_ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
+ } break;
  case WSP_GGML_OP_FLASH_ATTN:
  {
  const int32_t t = wsp_ggml_get_op_params_i32(tensor, 0);
@@ -14475,7 +14628,7 @@ void wsp_ggml_build_backward_gradient_checkpointing(
  // insert new tensors recomputing src, reusing already made replacements,
  // remember replacements: remember new tensors with mapping from corresponding gf nodes
  // recurse for input tensors,
- // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
  node->src[k] = wsp_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
  }
  // insert rewritten backward node with replacements made into resulting backward graph gb
@@ -15143,10 +15296,18 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_
  {
  WSP_GGML_ASSERT(false); // TODO: not implemented
  } break;
+ case WSP_GGML_OP_PAD:
+ {
+ WSP_GGML_ASSERT(false); // TODO: not implemented
+ } break;
  case WSP_GGML_OP_ARGSORT:
  {
  WSP_GGML_ASSERT(false); // TODO: not implemented
  } break;
+ case WSP_GGML_OP_LEAKY_RELU:
+ {
+ WSP_GGML_ASSERT(false); // TODO: not implemented
+ } break;
  case WSP_GGML_OP_FLASH_ATTN:
  {
  struct wsp_ggml_tensor * flash_grad = NULL;
@@ -15752,6 +15913,7 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
  case WSP_GGML_OP_ARGMAX:
  case WSP_GGML_OP_REPEAT:
  case WSP_GGML_OP_REPEAT_BACK:
+ case WSP_GGML_OP_LEAKY_RELU:
  {
  n_tasks = 1;
  } break;
@@ -15764,7 +15926,6 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
  case WSP_GGML_UNARY_OP_TANH:
  case WSP_GGML_UNARY_OP_ELU:
  case WSP_GGML_UNARY_OP_RELU:
- case WSP_GGML_UNARY_OP_LEAKY:
  {
  n_tasks = 1;
  } break;
@@ -15883,6 +16044,10 @@ static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) {
  {
  n_tasks = n_threads;
  } break;
+ case WSP_GGML_OP_PAD:
+ {
+ n_tasks = n_threads;
+ } break;
  case WSP_GGML_OP_ARGSORT:
  {
  n_tasks = n_threads;