llama_cpp 0.8.0 → 0.9.0

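What follows is the diff of the bundled ggml source (ggml.c) between the two releases. The headline change is that 2-D convolution is decomposed into two explicit graph ops, GGML_OP_CONV_2D_STAGE_0 (an im2col transform) and GGML_OP_CONV_2D_STAGE_1 (a GEMM over the unrolled patches), mirroring the existing CONV_1D stages; the conv-transpose kernels additionally gain destination zeroing and an indexing fix, and a number of stray blank lines are removed.
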
@@ -571,7 +571,6 @@ int64_t ggml_cycles_per_ms(void) {
  #define ggml_perf_cycles_per_ms() 0
  #endif

-
  //
  // cache line
  //
@@ -1828,7 +1827,6 @@ ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
  return type_traits[type];
  }

-
  //
  // simd mappings
  //
@@ -4057,16 +4055,17 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "ALIBI",
  "CLAMP",
  "CONV_1D",
+ "CONV_1D_STAGE_0",
+ "CONV_1D_STAGE_1",
  "CONV_TRANSPOSE_1D",
  "CONV_2D",
+ "CONV_2D_STAGE_0",
+ "CONV_2D_STAGE_1",
  "CONV_TRANSPOSE_2D",
  "POOL_1D",
  "POOL_2D",
  "UPSCALE",

- "CONV_1D_STAGE_0",
- "CONV_1D_STAGE_1",
-
  "FLASH_ATTN",
  "FLASH_FF",
  "FLASH_ATTN_BACK",
@@ -4092,7 +4091,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };

- static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
+ static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");

  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -4143,16 +4142,17 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "alibi(x)",
  "clamp(x)",
  "conv_1d(x)",
+ "conv_1d_stage_0(x)",
+ "conv_1d_stage_1(x)",
  "conv_transpose_1d(x)",
  "conv_2d(x)",
+ "conv_2d_stage_0(x)",
+ "conv_2d_stage_1(x)",
  "conv_transpose_2d(x)",
  "pool_1d(x)",
  "pool_2d(x)",
  "upscale(x)",

- "conv_1d_stage_0(x)",
- "conv_1d_stage_1(x)",
-
  "flash_attn(x)",
  "flash_ff(x)",
  "flash_attn_back(x)",
@@ -4178,7 +4178,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };

- static_assert(GGML_OP_COUNT == 71, "GGML_OP_COUNT != 71");
+ static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");

  static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -4209,8 +4209,10 @@ static void ggml_setup_op_has_task_pass(void) {
  p[GGML_OP_CONV_1D            ] = true;
  p[GGML_OP_CONV_1D_STAGE_0    ] = true;
  p[GGML_OP_CONV_1D_STAGE_1    ] = true;
- p[GGML_OP_CONV_2D            ] = true;
  p[GGML_OP_CONV_TRANSPOSE_1D  ] = true;
+ p[GGML_OP_CONV_2D            ] = true;
+ p[GGML_OP_CONV_2D_STAGE_0    ] = true;
+ p[GGML_OP_CONV_2D_STAGE_1    ] = true;
  p[GGML_OP_CONV_TRANSPOSE_2D  ] = true;
  p[GGML_OP_FLASH_ATTN_BACK    ] = true;
  p[GGML_OP_CROSS_ENTROPY_LOSS ] = true;
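
Both new CONV_2D stages are flagged here as ops with an extra task pass, matching the existing CONV_1D stages: stage_0 uses its GGML_TASK_INIT pass to zero the destination before the threaded im2col fill (see the compute kernel further down).
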
@@ -5954,7 +5956,6 @@ struct ggml_tensor * ggml_sqrt_inplace(
  return ggml_sqrt_impl(ctx, a, true);
  }

-
  // ggml_log

  static struct ggml_tensor * ggml_log_impl(
@@ -6008,7 +6009,6 @@ struct ggml_tensor * ggml_sum(
  return result;
  }

-
  // ggml_sum_rows

  struct ggml_tensor * ggml_sum_rows(
@@ -6640,7 +6640,6 @@ struct ggml_tensor * ggml_set_2d_inplace(
  return ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
  }

-
  // ggml_cpy

  static struct ggml_tensor * ggml_cpy_impl(
@@ -6720,7 +6719,6 @@ struct ggml_tensor * ggml_cont_inplace(
  return ggml_cont_impl(ctx, a, true);
  }

-
  // make contiguous, with new shape
  GGML_API struct ggml_tensor * ggml_cont_1d(
  struct ggml_context * ctx,
@@ -7173,7 +7171,6 @@ struct ggml_tensor * ggml_diag(
  return result;
  }

-
  // ggml_diag_mask_inf

  static struct ggml_tensor * ggml_diag_mask_inf_impl(
@@ -7285,7 +7282,6 @@ struct ggml_tensor * ggml_soft_max_inplace(
  return ggml_soft_max_impl(ctx, a, true);
  }

-
  // ggml_soft_max_back

  static struct ggml_tensor * ggml_soft_max_back_impl(
@@ -7702,7 +7698,11 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(

  // ggml_conv_2d

- struct ggml_tensor * ggml_conv_2d(
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ // a: [OC, IC, KH, KW]
+ // b: [N, IC, IH, IW]
+ // result: [N, OH, OW, IC*KH*KW]
+ static struct ggml_tensor * ggml_conv_2d_stage_0(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
  struct ggml_tensor * b,
@@ -7721,17 +7721,21 @@ struct ggml_tensor * ggml_conv_2d(
  is_node = true;
  }

+ const int64_t OH = ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1);
+ const int64_t OW = ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
+
  const int64_t ne[4] = {
- ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
- ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
- a->ne[3], b->ne[3],
+ a->ne[2] * a->ne[1] * a->ne[0],
+ OW,
+ OH,
+ b->ne[3],
  };
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F16, 4, ne);

  int32_t params[] = { s0, s1, p0, p1, d0, d1 };
  ggml_set_op_params(result, params, sizeof(params));

- result->op = GGML_OP_CONV_2D;
+ result->op = GGML_OP_CONV_2D_STAGE_0;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
  result->src[0] = a;
  result->src[1] = b;
@@ -7740,8 +7744,61 @@ struct ggml_tensor * ggml_conv_2d(

  }

- // ggml_conv_2d_sk_p0
+ // gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
+ // a: [OC, IC, KH, KW]
+ // b: [N, OH, OW, IC * KH * KW]
+ // result: [N, OC, OH, OW]
+ static struct ggml_tensor * ggml_conv_2d_stage_1(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b) {
+
+ bool is_node = false;
+
+ if (a->grad || b->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ const int64_t ne[4] = {
+ b->ne[1],
+ b->ne[2],
+ a->ne[3],
+ b->ne[3],
+ };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+ result->op = GGML_OP_CONV_2D_STAGE_1;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+
+ }
+
+ // a: [OC, IC, KH, KW]
+ // b: [N, IC, IH, IW]
+ // result: [N, OC, OH, OW]
+ struct ggml_tensor * ggml_conv_2d(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ int s0,
+ int s1,
+ int p0,
+ int p1,
+ int d0,
+ int d1) {

+ struct ggml_tensor * result = ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW]
+ result = ggml_conv_2d_stage_1(ctx, a, result);
+
+ return result;
+
+ }
+
+ // ggml_conv_2d_sk_p0
  struct ggml_tensor * ggml_conv_2d_sk_p0(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
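
For reference, stage_0 computes OH/OW with the same output-size arithmetic the old single-op path used, and only then reshapes the problem into a matrix product. A minimal standalone sketch of that arithmetic (the helper name conv_out_size is illustrative, not ggml API; the formula mirrors how ggml_calc_conv_output_size is called above):

    #include <inttypes.h>
    #include <stdio.h>

    // output length of a dilated convolution along one axis
    static int64_t conv_out_size(int64_t ins, int64_t ks, int s, int p, int d) {
        return (ins + 2*p - d*(ks - 1) - 1)/s + 1;
    }

    int main(void) {
        // 3x3 kernel over a 224x224 image, stride 1, padding 1, dilation 1
        const int64_t OW = conv_out_size(224, 3, 1, 1, 1); // 224
        const int64_t OH = conv_out_size(224, 3, 1, 1, 1); // 224
        // stage_0 then materializes an F16 [N, OH, OW, IC*KH*KW] tensor,
        // which stage_1 multiplies against the [OC, IC*KH*KW] kernel matrix
        printf("OW=%" PRId64 " OH=%" PRId64 "\n", OW, OH);
        return 0;
    }
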
@@ -8180,7 +8237,6 @@ static struct ggml_tensor * ggml_add_rel_pos_impl(
  return result;
  }

-
  struct ggml_tensor * ggml_add_rel_pos(
  struct ggml_context * ctx,
  struct ggml_tensor * a,
@@ -8625,8 +8681,6 @@ struct ggml_tensor * ggml_map_custom3_inplace(
  return ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
  }

-
-
  // ggml_cross_entropy_loss

  struct ggml_tensor * ggml_cross_entropy_loss(
@@ -9828,7 +9882,6 @@ static void ggml_compute_forward_add1(
  }
  }

-
  // ggml_compute_forward_acc

  static void ggml_compute_forward_acc_f32(
@@ -9968,7 +10021,6 @@ static void ggml_compute_forward_sub_f32(
  const int i2 = (ir - i3*ne2*ne1)/ne1;
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

-
  #ifdef GGML_USE_ACCELERATE
  vDSP_vsub(
  (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
@@ -10149,7 +10201,6 @@ static void ggml_compute_forward_div_f32(
  const int i2 = (ir - i3*ne2*ne1)/ne1;
  const int i1 = (ir - i3*ne2*ne1 - i2*ne1);

-
  #ifdef GGML_USE_ACCELERATE
  UNUSED(ggml_vec_div_f32);

@@ -10287,7 +10338,6 @@ static void ggml_compute_forward_sqrt(
  }
  }

-
  // ggml_compute_forward_log

  static void ggml_compute_forward_log_f32(
@@ -12120,7 +12170,6 @@ static void ggml_compute_forward_out_prod_f32(
  }
  }

-
  //int64_t t1 = ggml_perf_time_us();
  //static int64_t acc = 0;
  //acc += t1 - t0;
@@ -12316,7 +12365,6 @@ static void ggml_compute_forward_scale_f32(

  const size_t nb1 = dst->nb[1];

-
  for (int i1 = ir0; i1 < ir1; i1++) {
  if (dst->data != src0->data) {
  // src0 is same shape as dst => same indices
@@ -12714,7 +12762,6 @@ static void ggml_compute_forward_get_rows_back_f32(
  }
  }

-
  static void ggml_compute_forward_get_rows_back(
  const struct ggml_compute_params * params,
  const struct ggml_tensor * src0,
@@ -13997,6 +14044,7 @@ static void ggml_compute_forward_conv_1d_f32(
  }
  }

+ // TODO: reuse ggml_mul_mat or implement ggml_im2col and remove stage_0 and stage_1
  static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k,
  ggml_fp16_t * A,
  ggml_fp16_t * B,
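
gemm_f16_out_f32 is the single matrix kernel both conv paths now funnel into. A scalar sketch of the contraction it performs, assuming the row-major shapes annotated in the diff (A: [m, k] kernel, B: [n, k] im2col rows, C: [m, n], i.e. C = A * B^T); plain float stands in for ggml_fp16_t to keep the sketch self-contained, and the ith/nth thread partitioning is omitted:

    #include <stddef.h>

    static void gemm_ref_f32(size_t m, size_t n, size_t k,
                             const float *A, const float *B, float *C) {
        for (size_t i = 0; i < m; i++) {         // output channel
            for (size_t j = 0; j < n; j++) {     // output pixel (oh*OW + ow)
                float sum = 0.0f;
                for (size_t l = 0; l < k; l++) { // unrolled ic*KH*KW patch
                    sum += A[i*k + l]*B[j*k + l];
                }
                C[i*n + j] = sum;
            }
        }
    }
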
@@ -14298,6 +14346,9 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
  }
  }

+ // need to zero dst since we are accumulating into it
+ memset(dst->data, 0, ggml_nbytes(dst));
+
  return;
  }

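The transposed-convolution kernel accumulates partial products into dst rather than overwriting it, so the INIT pass must start from a zeroed buffer; previously dst could contain stale memory.
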
@@ -14370,7 +14421,7 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
  const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
  float * dst_data = wdata + i01*ne00*ne02;
  for (int64_t i00 = 0; i00 < ne00; i00++) {
- dst_data[i01*ne00*ne02 + i00*ne02 + i02] = src[i00];
+ dst_data[i00*ne02 + i02] = src[i00];
  }
  }
  }
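
This is an indexing fix: dst_data already points i01*ne00*ne02 elements into wdata, so the old subscript applied that offset a second time, scattering rows into the wrong (potentially out-of-bounds) slots of the permuted kernel buffer.
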
@@ -14389,6 +14440,9 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
  }
  }

+ // need to zero dst since we are accumulating into it
+ memset(dst->data, 0, ggml_nbytes(dst));
+
  return;
  }

@@ -14450,28 +14504,190 @@ static void ggml_compute_forward_conv_transpose_1d(

  // ggml_compute_forward_conv_2d

- static void ggml_compute_forward_conv_2d_f16_f32(
+ // src0: kernel [OC, IC, KH, KW]
+ // src1: image [N, IC, IH, IW]
+ // dst: result [N, OH, OW, IC*KH*KW]
+ static void ggml_compute_forward_conv_2d_stage_0_f32(
  const struct ggml_compute_params * params,
  const struct ggml_tensor * src0,
  const struct ggml_tensor * src1,
  struct ggml_tensor * dst) {
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const int64_t N = ne13;
+ const int64_t IC = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ // const int64_t OC = ne03;
+ // const int64_t IC = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ if (params->type == GGML_TASK_INIT) {
+ memset(dst->data, 0, ggml_nbytes(dst));
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
+ {
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
+
+ for (int64_t in = 0; in < N; in++) {
+ for (int64_t ioh = 0; ioh < OH; ioh++) {
+ for (int64_t iow = 0; iow < OW; iow++) {
+ for (int64_t iic = ith; iic < IC; iic+=nth) {
+
+ // micro kernel
+ ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
+
+ for (int64_t ikh = 0; ikh < KH; ikh++) {
+ for (int64_t ikw = 0; ikw < KW; ikw++) {
+ const int64_t iiw = iow*s0 + ikw*d0 - p0;
+ const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+ if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
+ // src0: [OC, IC, KH, KW]
+ // src1: [N, OH, OW, IC * KH * KW]
+ // result: [N, OC, OH, OW]
+ static void ggml_compute_forward_conv_2d_stage_1_f16(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F16);
  GGML_ASSERT( dst->type == GGML_TYPE_F32);

  int64_t t0 = ggml_perf_time_us();
  UNUSED(t0);

+ if (params->type == GGML_TASK_INIT) {
+ return;
+ }
+
+ if (params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
  GGML_TENSOR_BINARY_OP_LOCALS;

+ GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb10 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(nb0 == sizeof(float));
+
+ const int N = ne13;
+ const int OH = ne12;
+ const int OW = ne11;
+
+ const int OC = ne03;
+ const int IC = ne02;
+ const int KH = ne01;
+ const int KW = ne00;
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ int64_t m = OC;
+ int64_t n = OH * OW;
+ int64_t k = IC * KH * KW;
+
+ // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
+ for (int i = 0; i < N; i++) {
+ ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
+ ggml_fp16_t * B = (ggml_fp16_t *)src1->data + i * m * k; // [n, k]
+ float * C = (float *)dst->data + i * m * n; // [m, n]
+
+ gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
+ }
+ }
+
+ static void ggml_compute_forward_conv_2d_f16_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int64_t t0 = ggml_perf_time_us();
+ UNUSED(t0);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ // src1: image [N, IC, IH, IW]
+ // src0: kernel [OC, IC, KH, KW]
+ // dst: result [N, OC, OH, OW]
+ // ne12: IC
+ // ne0: OW
+ // ne1: OH
+ // nk0: KW
+ // nk1: KH
+ // ne13: N
+
+ const int N = ne13;
+ const int IC = ne12;
+ const int IH = ne11;
+ const int IW = ne10;
+
+ const int OC = ne03;
+ // const int IC = ne02;
+ const int KH = ne01;
+ const int KW = ne00;
+
+ const int OH = ne1;
+ const int OW = ne0;
+
  const int ith = params->ith;
  const int nth = params->nth;

- const int nk0 = ne00;
- const int nk1 = ne01;
+ // const int nk0 = ne00;
+ // const int nk1 = ne01;

  // size of the convolution row - the kernel size unrolled across all channels
- const int ew0 = nk0*nk1*ne02;
+ // const int ew0 = nk0*nk1*ne02;
+ // ew0: IC*KH*KW

  const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
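
The stage_0 micro kernel above is plain im2col. A minimal single-image, single-channel sketch with the same index arithmetic (iiw = iow*s + ikw*d - p), using illustrative names and float throughout rather than ggml's F16 destination; out-of-bounds taps stay zero, matching the memset in the INIT pass:

    #include <string.h>

    // img: [IH, IW] -> cols: [OH*OW, KH*KW], one patch row per output pixel
    static void im2col_1ch(const float *img, int IW, int IH,
                           int KW, int KH, int s, int p, int d,
                           int OW, int OH, float *cols) {
        memset(cols, 0, (size_t)OH*OW*KH*KW*sizeof(float)); // zero padding taps
        for (int oh = 0; oh < OH; oh++) {
            for (int ow = 0; ow < OW; ow++) {
                float *row = cols + ((size_t)oh*OW + ow)*KH*KW;
                for (int kh = 0; kh < KH; kh++) {
                    for (int kw = 0; kw < KW; kw++) {
                        const int iw = ow*s + kw*d - p;
                        const int ih = oh*s + kh*d - p;
                        if (ih >= 0 && ih < IH && iw >= 0 && iw < IW) {
                            row[kh*KW + kw] = img[ih*IW + iw];
                        }
                    }
                }
            }
        }
    }
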
@@ -14487,24 +14703,27 @@ static void ggml_compute_forward_conv_2d_f16_f32(
  memset(params->wdata, 0, params->wsize);

  // prepare source data (src1)
+ // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW]
+
  {
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;

- for (int i13 = 0; i13 < ne13; i13++) {
- for (int i12 = 0; i12 < ne12; i12++) {
- const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
- ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
-
- for (int i1 = 0; i1 < ne1; i1++) {
- for (int i0 = 0; i0 < ne0; i0++) {
- for (int ik1 = 0; ik1 < nk1; ik1++) {
- for (int ik0 = 0; ik0 < nk0; ik0++) {
- const int idx0 = i0*s0 + ik0*d0 - p0;
- const int idx1 = i1*s1 + ik1*d1 - p1;
-
- if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
- GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
+ for (int in = 0; in < N; in++) {
+ for (int iic = 0; iic < IC; iic++) {
+ for (int ioh = 0; ioh < OH; ioh++) {
+ for (int iow = 0; iow < OW; iow++) {
+
+ // micro kernel
+ ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
+ const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW]
+
+ for (int ikh = 0; ikh < KH; ikh++) {
+ for (int ikw = 0; ikw < KW; ikw++) {
+ const int iiw = iow*s0 + ikw*d0 - p0;
+ const int iih = ioh*s1 + ikh*d1 - p1;
+
+ if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) {
+ dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_FP32_TO_FP16(src_data[iih*IW + iiw]);
  }
  }
  }
@@ -14521,30 +14740,22 @@ static void ggml_compute_forward_conv_2d_f16_f32(
  return;
  }

- // total patches in dst
- const int np = ne2;
-
- // patches per thread
- const int dp = (np + nth - 1)/nth;
+ ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+ // wdata: [N*OH*OW, IC*KH*KW]
+ // dst: result [N, OC, OH, OW]
+ // src0: kernel [OC, IC, KH, KW]

- // patch range for this thread
- const int ip0 = dp*ith;
- const int ip1 = MIN(ip0 + dp, np);
+ int64_t m = OC;
+ int64_t n = OH * OW;
+ int64_t k = IC * KH * KW;

- ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
+ // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW]
+ for (int i = 0; i < N; i++) {
+ ggml_fp16_t * A = (ggml_fp16_t *)src0->data; // [m, k]
+ ggml_fp16_t * B = (ggml_fp16_t *)wdata + i * m * k; // [n, k]
+ float * C = (float *)dst->data + i * m * n; // [m, n]

- for (int i3 = 0; i3 < ne3; i3++) {
- for (int i2 = ip0; i2 < ip1; i2++) {
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
-
- for (int i1 = 0; i1 < ne1; ++i1) {
- for (int i0 = 0; i0 < ne0; ++i0) {
- ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
- (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
- (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
- }
- }
- }
+ gemm_f16_out_f32(m, n, k, A, B, C, ith, nth);
  }
  }

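The per-thread patch partitioning (np/dp/ip0/ip1) disappears here because gemm_f16_out_f32 splits its output rows across threads internally via the ith/nth arguments, so each batch element is handed to the GEMM whole.
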
@@ -14570,6 +14781,48 @@ static void ggml_compute_forward_conv_2d(
  }
  }

+ static void ggml_compute_forward_conv_2d_stage_0(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(false);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
+ static void ggml_compute_forward_conv_2d_stage_1(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F16:
+ {
+ ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst);
+ } break;
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(false);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+ }
+
  // ggml_compute_forward_conv_transpose_2d

  static void ggml_compute_forward_conv_transpose_2d(
@@ -14628,6 +14881,8 @@ static void ggml_compute_forward_conv_transpose_2d(
  }
  }

+ memset(dst->data, 0, ggml_nbytes(dst));
+
  return;
  }

@@ -16126,7 +16381,6 @@ static void ggml_compute_forward_add_rel_pos_f32(
  const int ip0 = dp*ith;
  const int ip1 = MIN(ip0 + dp, np);

-
  for (int64_t i13 = ip0; i13 < ip1; ++i13) {
  for (int64_t i12 = 0; i12 < ne12; ++i12) {
  for (int64_t i11 = 0; i11 < ne11; ++i11) {
@@ -16193,7 +16447,6 @@ static void ggml_compute_forward_map_unary_f32(
  }
  }

-
  static void ggml_compute_forward_map_unary(
  const struct ggml_compute_params * params,
  const struct ggml_tensor * src0,
@@ -16241,7 +16494,6 @@ static void ggml_compute_forward_map_binary_f32(
  }
  }

-
  static void ggml_compute_forward_map_binary(
  const struct ggml_compute_params * params,
  const struct ggml_tensor * src0,
@@ -16293,7 +16545,6 @@ static void ggml_compute_forward_map_custom2_f32(
  fun(dst, a, b);
  }

-
  // ggml_compute_forward_map_custom3

  static void ggml_compute_forward_map_custom3_f32(
@@ -16568,7 +16819,6 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
  ggml_vec_sub_f32(nc, ds0, ds0, s1);
  ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);

-
  #ifndef NDEBUG
  for (int i = 0; i < nc; ++i) {
  assert(!isnan(ds0[i]));
@@ -16596,12 +16846,15 @@ static void ggml_compute_forward_cross_entropy_loss_back(
  }
  }

-
  /////////////////////////////////

  static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
  GGML_ASSERT(params);

+ if (tensor->op == GGML_OP_NONE) {
+ return;
+ }
+
  #ifdef GGML_USE_CUBLAS
  bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
  if (skip_cpu) {
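
Leaf tensors (weights and inputs) carry GGML_OP_NONE, so the new early return skips them before the CUDA dispatch and the op switch below.
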
@@ -16804,6 +17057,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
  {
  ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor);
  } break;
+ case GGML_OP_CONV_2D_STAGE_0:
+ {
+ ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor);
+ } break;
+ case GGML_OP_CONV_2D_STAGE_1:
+ {
+ ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor);
+ } break;
  case GGML_OP_CONV_TRANSPOSE_2D:
  {
  ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor);
@@ -17733,11 +17994,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
  {
  GGML_ASSERT(false); // TODO: not implemented
  } break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
  case GGML_OP_CONV_2D:
  {
  GGML_ASSERT(false); // TODO: not implemented
  } break;
- case GGML_OP_CONV_TRANSPOSE_1D:
+ case GGML_OP_CONV_2D_STAGE_0:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
+ case GGML_OP_CONV_2D_STAGE_1:
  {
  GGML_ASSERT(false); // TODO: not implemented
  } break;
@@ -18666,6 +18935,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
  const int64_t ne0 = node->ne[0];
  const int64_t ne1 = node->ne[1];
  const int64_t ne2 = node->ne[2];
+ const int64_t ne3 = node->ne[3];
  const int64_t nk = ne00*ne01;
  const int64_t ew0 = nk * ne02;

@@ -18676,7 +18946,8 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {

  if (node->src[0]->type == GGML_TYPE_F16 &&
  node->src[1]->type == GGML_TYPE_F32) {
- cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
+ // im2col: [N*OH*OW, IC*KH*KW]
+ cur = sizeof(ggml_fp16_t)*(ne3*ne0*ne1*ew0);
  } else if (node->src[0]->type == GGML_TYPE_F32 &&
  node->src[1]->type == GGML_TYPE_F32) {
  cur = sizeof(float)* (ne10*ne11*ne12);
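
The extra ne3 factor sizes the im2col scratch buffer for the whole batch. As a quick check, assuming N = ne3 = 1, a 3-channel input with 64x64 output (ne0 = ne1 = 64) and a 3x3 kernel gives ew0 = 3*3*3 = 27, so cur = 2 * 1*64*64*27 = 221184 bytes (about 216 KiB); the old expression omitted ne3 and undersized the buffer for batched inputs.
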
@@ -18686,6 +18957,14 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {

  work_size = MAX(work_size, cur);
  } break;
+ case GGML_OP_CONV_2D_STAGE_0:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_CONV_2D_STAGE_1:
+ {
+ n_tasks = n_threads;
+ } break;
  case GGML_OP_CONV_TRANSPOSE_2D:
  {
  n_tasks = n_threads;
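
Both stages run with n_tasks = n_threads: stage_0 stripes the im2col fill across input channels (iic starts at ith and advances by nth), and stage_1 lets gemm_f16_out_f32 partition the GEMM's output rows the same way.
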
@@ -19874,7 +20153,6 @@ static enum ggml_opt_result ggml_opt_adam(

  opt->loss_after = fx;

-
  // check convergence
  if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
  GGML_PRINT_DEBUG("converged\n");