llama_cpp 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,6 +25,7 @@
25
25
  #include <float.h>
26
26
  #include <limits.h>
27
27
  #include <stdarg.h>
28
+ #include <signal.h>
28
29
 
29
30
  #ifdef GGML_USE_METAL
30
31
  #include <unistd.h>
@@ -49,23 +50,23 @@
49
50
  typedef volatile LONG atomic_int;
50
51
  typedef atomic_int atomic_bool;
51
52
 
52
- static void atomic_store(atomic_int* ptr, LONG val) {
53
+ static void atomic_store(atomic_int * ptr, LONG val) {
53
54
  InterlockedExchange(ptr, val);
54
55
  }
55
- static LONG atomic_load(atomic_int* ptr) {
56
+ static LONG atomic_load(atomic_int * ptr) {
56
57
  return InterlockedCompareExchange(ptr, 0, 0);
57
58
  }
58
- static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
59
+ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
59
60
  return InterlockedExchangeAdd(ptr, inc);
60
61
  }
61
- static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
62
+ static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
62
63
  return atomic_fetch_add(ptr, -(dec));
63
64
  }
64
65
 
65
66
  typedef HANDLE pthread_t;
66
67
 
67
68
  typedef DWORD thread_ret_t;
68
- static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
69
+ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
69
70
  (void) unused;
70
71
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
71
72
  if (handle == NULL)
@@ -77,7 +78,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
77
78
  return 0;
78
79
  }
79
80
 
80
- static int pthread_join(pthread_t thread, void* unused) {
81
+ static int pthread_join(pthread_t thread, void * unused) {
81
82
  (void) unused;
82
83
  return (int) WaitForSingleObject(thread, INFINITE);
83
84
  }
@@ -90,7 +91,7 @@ static int sched_yield (void) {
90
91
  #include <pthread.h>
91
92
  #include <stdatomic.h>
92
93
 
93
- typedef void* thread_ret_t;
94
+ typedef void * thread_ret_t;
94
95
 
95
96
  #include <sys/types.h>
96
97
  #include <sys/stat.h>
@@ -247,7 +248,11 @@ inline static void* ggml_aligned_malloc(size_t size) {
247
248
  #include "ggml-opencl.h"
248
249
  #endif
249
250
  #elif defined(GGML_USE_OPENBLAS)
251
+ #if defined(GGML_BLAS_USE_MKL)
252
+ #include <mkl.h>
253
+ #else
250
254
  #include <cblas.h>
255
+ #endif
251
256
  #elif defined(GGML_USE_CUBLAS)
252
257
  #include "ggml-cuda.h"
253
258
  #elif defined(GGML_USE_CLBLAST)
@@ -3782,6 +3787,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3782
3787
  "CLAMP",
3783
3788
  "CONV_1D",
3784
3789
  "CONV_2D",
3790
+ "POOL_1D",
3791
+ "POOL_2D",
3785
3792
 
3786
3793
  "FLASH_ATTN",
3787
3794
  "FLASH_FF",
@@ -3800,7 +3807,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3800
3807
  "CROSS_ENTROPY_LOSS_BACK",
3801
3808
  };
3802
3809
 
3803
- static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
3810
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
3804
3811
 
3805
3812
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3806
3813
  "none",
@@ -3860,6 +3867,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3860
3867
  "clamp(x)",
3861
3868
  "conv_1d(x)",
3862
3869
  "conv_2d(x)",
3870
+ "pool_1d(x)",
3871
+ "pool_2d(x)",
3863
3872
 
3864
3873
  "flash_attn(x)",
3865
3874
  "flash_ff(x)",
@@ -3878,7 +3887,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3878
3887
  "cross_entropy_loss_back(x,y)",
3879
3888
  };
3880
3889
 
3881
- static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
3890
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
3891
+
3892
+ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
3882
3893
 
3883
3894
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3884
3895
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4157,10 +4168,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
4157
4168
  static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
4158
4169
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
4159
4170
 
4160
- return
4161
- (t0->ne[0] == t1->ne[0]) &&
4162
- (t0->ne[2] == t1->ne[2]) &&
4163
- (t0->ne[3] == t1->ne[3]);
4171
+ return (t0->ne[0] == t1->ne[0]) &&
4172
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
4173
+ (t1->ne[3]%t0->ne[3] == 0);
4164
4174
  }
4165
4175
 
4166
4176
  static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
@@ -4580,17 +4590,14 @@ struct ggml_tensor * ggml_new_tensor_impl(
4580
4590
  /*.op =*/ GGML_OP_NONE,
4581
4591
  /*.is_param =*/ false,
4582
4592
  /*.grad =*/ NULL,
4583
- /*.src0 =*/ NULL,
4584
- /*.src1 =*/ NULL,
4585
- /*.opt =*/ { NULL },
4586
- /*.n_tasks =*/ 0,
4593
+ /*.src =*/ { NULL },
4587
4594
  /*.perf_runs =*/ 0,
4588
4595
  /*.perf_cycles =*/ 0,
4589
4596
  /*.perf_time_us =*/ 0,
4590
4597
  /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
4591
4598
  /*.name =*/ { 0 },
4592
4599
  /*.extra =*/ NULL,
4593
- /*.pad =*/ { 0 },
4600
+ /*.padding =*/ { 0 },
4594
4601
  };
4595
4602
 
4596
4603
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
@@ -4722,7 +4729,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
4722
4729
  {
4723
4730
  assert(tensor->nb[0] == sizeof(ggml_fp16_t));
4724
4731
  for (int i = 0; i < n; i++) {
4725
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
4732
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
4726
4733
  }
4727
4734
  } break;
4728
4735
  case GGML_TYPE_F32:
@@ -4774,7 +4781,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
4774
4781
  {
4775
4782
  assert(tensor->nb[0] == sizeof(ggml_fp16_t));
4776
4783
  for (int i = 0; i < n; i++) {
4777
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
4784
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
4778
4785
  }
4779
4786
  } break;
4780
4787
  case GGML_TYPE_F32:
@@ -5009,8 +5016,8 @@ struct ggml_tensor * ggml_dup_impl(
5009
5016
 
5010
5017
  result->op = GGML_OP_DUP;
5011
5018
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5012
- result->src0 = a;
5013
- result->src1 = NULL;
5019
+ result->src[0] = a;
5020
+ result->src[1] = NULL;
5014
5021
 
5015
5022
  return result;
5016
5023
  }
@@ -5034,11 +5041,15 @@ struct ggml_tensor * ggml_add_impl(
5034
5041
  struct ggml_tensor * a,
5035
5042
  struct ggml_tensor * b,
5036
5043
  bool inplace) {
5037
- GGML_ASSERT(ggml_are_same_shape(a, b));
5044
+ // TODO: support less-strict constraint
5045
+ // GGML_ASSERT(ggml_can_repeat(b, a));
5046
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
5038
5047
 
5039
5048
  bool is_node = false;
5040
5049
 
5041
- if (a->grad || b->grad) {
5050
+ if (!inplace && (a->grad || b->grad)) {
5051
+ // TODO: support backward pass for broadcasting
5052
+ GGML_ASSERT(ggml_are_same_shape(a, b));
5042
5053
  is_node = true;
5043
5054
  }
5044
5055
 
@@ -5046,8 +5057,8 @@ struct ggml_tensor * ggml_add_impl(
5046
5057
 
5047
5058
  result->op = GGML_OP_ADD;
5048
5059
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5049
- result->src0 = a;
5050
- result->src1 = b;
5060
+ result->src[0] = a;
5061
+ result->src[1] = b;
5051
5062
 
5052
5063
  return result;
5053
5064
  }
@@ -5086,8 +5097,8 @@ struct ggml_tensor * ggml_add1_impl(
5086
5097
 
5087
5098
  result->op = GGML_OP_ADD1;
5088
5099
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5089
- result->src0 = a;
5090
- result->src1 = b;
5100
+ result->src[0] = a;
5101
+ result->src[1] = b;
5091
5102
 
5092
5103
  return result;
5093
5104
  }
@@ -5144,9 +5155,9 @@ struct ggml_tensor * ggml_acc_impl(
5144
5155
 
5145
5156
  result->op = GGML_OP_ACC;
5146
5157
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5147
- result->src0 = a;
5148
- result->src1 = b;
5149
- result->opt[0] = c;
5158
+ result->src[0] = a;
5159
+ result->src[1] = b;
5160
+ result->src[2] = c;
5150
5161
 
5151
5162
  return result;
5152
5163
  }
@@ -5192,8 +5203,8 @@ struct ggml_tensor * ggml_sub_impl(
5192
5203
 
5193
5204
  result->op = GGML_OP_SUB;
5194
5205
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5195
- result->src0 = a;
5196
- result->src1 = b;
5206
+ result->src[0] = a;
5207
+ result->src[1] = b;
5197
5208
 
5198
5209
  return result;
5199
5210
  }
@@ -5239,8 +5250,8 @@ struct ggml_tensor * ggml_mul_impl(
5239
5250
 
5240
5251
  result->op = GGML_OP_MUL;
5241
5252
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5242
- result->src0 = a;
5243
- result->src1 = b;
5253
+ result->src[0] = a;
5254
+ result->src[1] = b;
5244
5255
 
5245
5256
  return result;
5246
5257
  }
@@ -5282,8 +5293,8 @@ struct ggml_tensor * ggml_div_impl(
5282
5293
 
5283
5294
  result->op = GGML_OP_DIV;
5284
5295
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5285
- result->src0 = a;
5286
- result->src1 = b;
5296
+ result->src[0] = a;
5297
+ result->src[1] = b;
5287
5298
 
5288
5299
  return result;
5289
5300
  }
@@ -5318,8 +5329,8 @@ struct ggml_tensor * ggml_sqr_impl(
5318
5329
 
5319
5330
  result->op = GGML_OP_SQR;
5320
5331
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5321
- result->src0 = a;
5322
- result->src1 = NULL;
5332
+ result->src[0] = a;
5333
+ result->src[1] = NULL;
5323
5334
 
5324
5335
  return result;
5325
5336
  }
@@ -5352,8 +5363,8 @@ struct ggml_tensor * ggml_sqrt_impl(
5352
5363
 
5353
5364
  result->op = GGML_OP_SQRT;
5354
5365
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5355
- result->src0 = a;
5356
- result->src1 = NULL;
5366
+ result->src[0] = a;
5367
+ result->src[1] = NULL;
5357
5368
 
5358
5369
  return result;
5359
5370
  }
@@ -5387,8 +5398,8 @@ struct ggml_tensor * ggml_log_impl(
5387
5398
 
5388
5399
  result->op = GGML_OP_LOG;
5389
5400
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5390
- result->src0 = a;
5391
- result->src1 = NULL;
5401
+ result->src[0] = a;
5402
+ result->src[1] = NULL;
5392
5403
 
5393
5404
  return result;
5394
5405
  }
@@ -5420,8 +5431,8 @@ struct ggml_tensor * ggml_sum(
5420
5431
 
5421
5432
  result->op = GGML_OP_SUM;
5422
5433
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5423
- result->src0 = a;
5424
- result->src1 = NULL;
5434
+ result->src[0] = a;
5435
+ result->src[1] = NULL;
5425
5436
 
5426
5437
  return result;
5427
5438
  }
@@ -5447,8 +5458,8 @@ struct ggml_tensor * ggml_sum_rows(
5447
5458
 
5448
5459
  result->op = GGML_OP_SUM_ROWS;
5449
5460
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5450
- result->src0 = a;
5451
- result->src1 = NULL;
5461
+ result->src[0] = a;
5462
+ result->src[1] = NULL;
5452
5463
 
5453
5464
  return result;
5454
5465
  }
@@ -5470,8 +5481,8 @@ struct ggml_tensor * ggml_mean(
5470
5481
 
5471
5482
  result->op = GGML_OP_MEAN;
5472
5483
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5473
- result->src0 = a;
5474
- result->src1 = NULL;
5484
+ result->src[0] = a;
5485
+ result->src[1] = NULL;
5475
5486
 
5476
5487
  return result;
5477
5488
  }
@@ -5494,8 +5505,8 @@ struct ggml_tensor * ggml_argmax(
5494
5505
 
5495
5506
  result->op = GGML_OP_ARGMAX;
5496
5507
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5497
- result->src0 = a;
5498
- result->src1 = NULL;
5508
+ result->src[0] = a;
5509
+ result->src[1] = NULL;
5499
5510
 
5500
5511
  return result;
5501
5512
  }
@@ -5522,8 +5533,8 @@ struct ggml_tensor * ggml_repeat(
5522
5533
 
5523
5534
  result->op = GGML_OP_REPEAT;
5524
5535
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5525
- result->src0 = a;
5526
- result->src1 = b;
5536
+ result->src[0] = a;
5537
+ result->src[1] = b;
5527
5538
 
5528
5539
  return result;
5529
5540
  }
@@ -5550,8 +5561,8 @@ struct ggml_tensor * ggml_repeat_back(
5550
5561
 
5551
5562
  result->op = GGML_OP_REPEAT_BACK;
5552
5563
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5553
- result->src0 = a;
5554
- result->src1 = b;
5564
+ result->src[0] = a;
5565
+ result->src[1] = b;
5555
5566
 
5556
5567
  return result;
5557
5568
  }
@@ -5572,8 +5583,8 @@ struct ggml_tensor * ggml_abs_impl(
5572
5583
 
5573
5584
  result->op = GGML_OP_ABS;
5574
5585
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5575
- result->src0 = a;
5576
- result->src1 = NULL;
5586
+ result->src[0] = a;
5587
+ result->src[1] = NULL;
5577
5588
 
5578
5589
  return result;
5579
5590
  }
@@ -5607,8 +5618,8 @@ struct ggml_tensor * ggml_sgn_impl(
5607
5618
 
5608
5619
  result->op = GGML_OP_SGN;
5609
5620
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5610
- result->src0 = a;
5611
- result->src1 = NULL;
5621
+ result->src[0] = a;
5622
+ result->src[1] = NULL;
5612
5623
 
5613
5624
  return result;
5614
5625
  }
@@ -5641,8 +5652,8 @@ struct ggml_tensor * ggml_neg_impl(
5641
5652
 
5642
5653
  result->op = GGML_OP_NEG;
5643
5654
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5644
- result->src0 = a;
5645
- result->src1 = NULL;
5655
+ result->src[0] = a;
5656
+ result->src[1] = NULL;
5646
5657
 
5647
5658
  return result;
5648
5659
  }
@@ -5675,8 +5686,8 @@ struct ggml_tensor * ggml_step_impl(
5675
5686
 
5676
5687
  result->op = GGML_OP_STEP;
5677
5688
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5678
- result->src0 = a;
5679
- result->src1 = NULL;
5689
+ result->src[0] = a;
5690
+ result->src[1] = NULL;
5680
5691
 
5681
5692
  return result;
5682
5693
  }
@@ -5709,8 +5720,8 @@ struct ggml_tensor * ggml_tanh_impl(
5709
5720
 
5710
5721
  result->op = GGML_OP_TANH;
5711
5722
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5712
- result->src0 = a;
5713
- result->src1 = NULL;
5723
+ result->src[0] = a;
5724
+ result->src[1] = NULL;
5714
5725
 
5715
5726
  return result;
5716
5727
  }
@@ -5743,8 +5754,8 @@ struct ggml_tensor * ggml_elu_impl(
5743
5754
 
5744
5755
  result->op = GGML_OP_ELU;
5745
5756
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5746
- result->src0 = a;
5747
- result->src1 = NULL;
5757
+ result->src[0] = a;
5758
+ result->src[1] = NULL;
5748
5759
 
5749
5760
  return result;
5750
5761
  }
@@ -5777,8 +5788,8 @@ struct ggml_tensor * ggml_relu_impl(
5777
5788
 
5778
5789
  result->op = GGML_OP_RELU;
5779
5790
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5780
- result->src0 = a;
5781
- result->src1 = NULL;
5791
+ result->src[0] = a;
5792
+ result->src[1] = NULL;
5782
5793
 
5783
5794
  return result;
5784
5795
  }
@@ -5811,8 +5822,8 @@ struct ggml_tensor * ggml_gelu_impl(
5811
5822
 
5812
5823
  result->op = GGML_OP_GELU;
5813
5824
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5814
- result->src0 = a;
5815
- result->src1 = NULL;
5825
+ result->src[0] = a;
5826
+ result->src[1] = NULL;
5816
5827
 
5817
5828
  return result;
5818
5829
  }
@@ -5845,8 +5856,8 @@ struct ggml_tensor * ggml_gelu_quick_impl(
5845
5856
 
5846
5857
  result->op = GGML_OP_GELU_QUICK;
5847
5858
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5848
- result->src0 = a;
5849
- result->src1 = NULL;
5859
+ result->src[0] = a;
5860
+ result->src[1] = NULL;
5850
5861
 
5851
5862
  return result;
5852
5863
  }
@@ -5879,8 +5890,8 @@ struct ggml_tensor * ggml_silu_impl(
5879
5890
 
5880
5891
  result->op = GGML_OP_SILU;
5881
5892
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5882
- result->src0 = a;
5883
- result->src1 = NULL;
5893
+ result->src[0] = a;
5894
+ result->src[1] = NULL;
5884
5895
 
5885
5896
  return result;
5886
5897
  }
@@ -5914,8 +5925,8 @@ struct ggml_tensor * ggml_silu_back(
5914
5925
 
5915
5926
  result->op = GGML_OP_SILU_BACK;
5916
5927
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5917
- result->src0 = a;
5918
- result->src1 = b;
5928
+ result->src[0] = a;
5929
+ result->src[1] = b;
5919
5930
 
5920
5931
  return result;
5921
5932
  }
@@ -5937,8 +5948,8 @@ struct ggml_tensor * ggml_norm_impl(
5937
5948
 
5938
5949
  result->op = GGML_OP_NORM;
5939
5950
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5940
- result->src0 = a;
5941
- result->src1 = NULL; // TODO: maybe store epsilon here?
5951
+ result->src[0] = a;
5952
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
5942
5953
 
5943
5954
  return result;
5944
5955
  }
@@ -5969,8 +5980,8 @@ struct ggml_tensor * ggml_rms_norm_impl(
5969
5980
 
5970
5981
  result->op = GGML_OP_RMS_NORM;
5971
5982
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5972
- result->src0 = a;
5973
- result->src1 = NULL; // TODO: maybe store epsilon here?
5983
+ result->src[0] = a;
5984
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
5974
5985
 
5975
5986
  return result;
5976
5987
  }
@@ -6002,8 +6013,8 @@ struct ggml_tensor * ggml_rms_norm_back(
6002
6013
 
6003
6014
  result->op = GGML_OP_RMS_NORM_BACK;
6004
6015
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6005
- result->src0 = a;
6006
- result->src1 = b;
6016
+ result->src[0] = a;
6017
+ result->src[1] = b;
6007
6018
 
6008
6019
  return result;
6009
6020
  }
@@ -6024,13 +6035,13 @@ struct ggml_tensor * ggml_mul_mat(
6024
6035
  is_node = true;
6025
6036
  }
6026
6037
 
6027
- const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
6028
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
6038
+ const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
6039
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
6029
6040
 
6030
6041
  result->op = GGML_OP_MUL_MAT;
6031
6042
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6032
- result->src0 = a;
6033
- result->src1 = b;
6043
+ result->src[0] = a;
6044
+ result->src[1] = b;
6034
6045
 
6035
6046
  return result;
6036
6047
  }
@@ -6055,8 +6066,8 @@ struct ggml_tensor * ggml_out_prod(
6055
6066
 
6056
6067
  result->op = GGML_OP_OUT_PROD;
6057
6068
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6058
- result->src0 = a;
6059
- result->src1 = b;
6069
+ result->src[0] = a;
6070
+ result->src[1] = b;
6060
6071
 
6061
6072
  return result;
6062
6073
  }
@@ -6081,8 +6092,8 @@ struct ggml_tensor * ggml_scale_impl(
6081
6092
 
6082
6093
  result->op = GGML_OP_SCALE;
6083
6094
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6084
- result->src0 = a;
6085
- result->src1 = b;
6095
+ result->src[0] = a;
6096
+ result->src[1] = b;
6086
6097
 
6087
6098
  return result;
6088
6099
  }
@@ -6137,9 +6148,9 @@ struct ggml_tensor * ggml_set_impl(
6137
6148
 
6138
6149
  result->op = GGML_OP_SET;
6139
6150
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6140
- result->src0 = a;
6141
- result->src1 = b;
6142
- result->opt[0] = c;
6151
+ result->src[0] = a;
6152
+ result->src[1] = b;
6153
+ result->src[2] = c;
6143
6154
 
6144
6155
  return result;
6145
6156
  }
@@ -6226,8 +6237,8 @@ struct ggml_tensor * ggml_cpy_impl(
6226
6237
 
6227
6238
  result->op = GGML_OP_CPY;
6228
6239
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6229
- result->src0 = a;
6230
- result->src1 = b;
6240
+ result->src[0] = a;
6241
+ result->src[1] = b;
6231
6242
 
6232
6243
  return result;
6233
6244
  }
@@ -6263,8 +6274,8 @@ struct ggml_tensor * ggml_cont_impl(
6263
6274
 
6264
6275
  result->op = GGML_OP_CONT;
6265
6276
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6266
- result->src0 = a;
6267
- result->src1 = NULL;
6277
+ result->src[0] = a;
6278
+ result->src[1] = NULL;
6268
6279
 
6269
6280
  return result;
6270
6281
  }
@@ -6307,8 +6318,8 @@ struct ggml_tensor * ggml_reshape(
6307
6318
 
6308
6319
  result->op = GGML_OP_RESHAPE;
6309
6320
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6310
- result->src0 = a;
6311
- result->src1 = NULL;
6321
+ result->src[0] = a;
6322
+ result->src[1] = NULL;
6312
6323
 
6313
6324
  return result;
6314
6325
  }
@@ -6332,8 +6343,8 @@ struct ggml_tensor * ggml_reshape_1d(
6332
6343
 
6333
6344
  result->op = GGML_OP_RESHAPE;
6334
6345
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6335
- result->src0 = a;
6336
- result->src1 = NULL;
6346
+ result->src[0] = a;
6347
+ result->src[1] = NULL;
6337
6348
 
6338
6349
  return result;
6339
6350
  }
@@ -6358,8 +6369,8 @@ struct ggml_tensor * ggml_reshape_2d(
6358
6369
 
6359
6370
  result->op = GGML_OP_RESHAPE;
6360
6371
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6361
- result->src0 = a;
6362
- result->src1 = NULL;
6372
+ result->src[0] = a;
6373
+ result->src[1] = NULL;
6363
6374
 
6364
6375
  return result;
6365
6376
  }
@@ -6385,8 +6396,8 @@ struct ggml_tensor * ggml_reshape_3d(
6385
6396
 
6386
6397
  result->op = GGML_OP_RESHAPE;
6387
6398
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6388
- result->src0 = a;
6389
- result->src1 = NULL;
6399
+ result->src[0] = a;
6400
+ result->src[1] = NULL;
6390
6401
 
6391
6402
  return result;
6392
6403
  }
@@ -6414,8 +6425,8 @@ struct ggml_tensor * ggml_reshape_4d(
6414
6425
 
6415
6426
  result->op = GGML_OP_RESHAPE;
6416
6427
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6417
- result->src0 = a;
6418
- result->src1 = NULL;
6428
+ result->src[0] = a;
6429
+ result->src[1] = NULL;
6419
6430
 
6420
6431
  return result;
6421
6432
  }
@@ -6447,9 +6458,9 @@ struct ggml_tensor * ggml_view_1d(
6447
6458
 
6448
6459
  result->op = GGML_OP_VIEW;
6449
6460
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6450
- result->src0 = a;
6451
- result->src1 = NULL;
6452
- result->opt[0] = offs;
6461
+ result->src[0] = a;
6462
+ result->src[1] = NULL;
6463
+ result->src[2] = offs;
6453
6464
 
6454
6465
  return result;
6455
6466
  }
@@ -6489,9 +6500,9 @@ struct ggml_tensor * ggml_view_2d(
6489
6500
 
6490
6501
  result->op = GGML_OP_VIEW;
6491
6502
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6492
- result->src0 = a;
6493
- result->src1 = NULL;
6494
- result->opt[0] = offs;
6503
+ result->src[0] = a;
6504
+ result->src[1] = NULL;
6505
+ result->src[2] = offs;
6495
6506
 
6496
6507
  return result;
6497
6508
  }
@@ -6533,9 +6544,9 @@ struct ggml_tensor * ggml_view_3d(
6533
6544
 
6534
6545
  result->op = GGML_OP_VIEW;
6535
6546
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6536
- result->src0 = a;
6537
- result->src1 = NULL;
6538
- result->opt[0] = offs;
6547
+ result->src[0] = a;
6548
+ result->src[1] = NULL;
6549
+ result->src[2] = offs;
6539
6550
 
6540
6551
  return result;
6541
6552
  }
@@ -6579,9 +6590,9 @@ struct ggml_tensor * ggml_view_4d(
6579
6590
 
6580
6591
  result->op = GGML_OP_VIEW;
6581
6592
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6582
- result->src0 = a;
6583
- result->src1 = NULL;
6584
- result->opt[0] = offs;
6593
+ result->src[0] = a;
6594
+ result->src[1] = NULL;
6595
+ result->src[2] = offs;
6585
6596
 
6586
6597
  return result;
6587
6598
  }
@@ -6641,8 +6652,8 @@ struct ggml_tensor * ggml_permute(
6641
6652
 
6642
6653
  result->op = GGML_OP_PERMUTE;
6643
6654
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6644
- result->src0 = a;
6645
- result->src1 = NULL;
6655
+ result->src[0] = a;
6656
+ result->src[1] = NULL;
6646
6657
 
6647
6658
  if (is_node) {
6648
6659
  ggml_scratch_save(ctx);
@@ -6656,7 +6667,7 @@ struct ggml_tensor * ggml_permute(
6656
6667
 
6657
6668
  ggml_scratch_load(ctx);
6658
6669
 
6659
- result->opt[0] = b;
6670
+ result->src[2] = b;
6660
6671
  }
6661
6672
 
6662
6673
  return result;
@@ -6684,8 +6695,8 @@ struct ggml_tensor * ggml_transpose(
6684
6695
 
6685
6696
  result->op = GGML_OP_TRANSPOSE;
6686
6697
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6687
- result->src0 = a;
6688
- result->src1 = NULL;
6698
+ result->src[0] = a;
6699
+ result->src[1] = NULL;
6689
6700
 
6690
6701
  return result;
6691
6702
  }
@@ -6710,8 +6721,8 @@ struct ggml_tensor * ggml_get_rows(
6710
6721
 
6711
6722
  result->op = GGML_OP_GET_ROWS;
6712
6723
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6713
- result->src0 = a;
6714
- result->src1 = b;
6724
+ result->src[0] = a;
6725
+ result->src[1] = b;
6715
6726
 
6716
6727
  return result;
6717
6728
  }
@@ -6738,9 +6749,9 @@ struct ggml_tensor * ggml_get_rows_back(
6738
6749
 
6739
6750
  result->op = GGML_OP_GET_ROWS_BACK;
6740
6751
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6741
- result->src0 = a;
6742
- result->src1 = b;
6743
- result->opt[0] = c;
6752
+ result->src[0] = a;
6753
+ result->src[1] = b;
6754
+ result->src[2] = c;
6744
6755
 
6745
6756
  return result;
6746
6757
  }
@@ -6762,8 +6773,8 @@ struct ggml_tensor * ggml_diag(
6762
6773
 
6763
6774
  result->op = GGML_OP_DIAG;
6764
6775
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6765
- result->src0 = a;
6766
- result->src1 = NULL;
6776
+ result->src[0] = a;
6777
+ result->src[1] = NULL;
6767
6778
 
6768
6779
  return result;
6769
6780
  }
@@ -6795,8 +6806,8 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
6795
6806
 
6796
6807
  result->op = GGML_OP_DIAG_MASK_INF;
6797
6808
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6798
- result->src0 = a;
6799
- result->src1 = b;
6809
+ result->src[0] = a;
6810
+ result->src[1] = b;
6800
6811
 
6801
6812
  return result;
6802
6813
  }
@@ -6843,8 +6854,8 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
6843
6854
 
6844
6855
  result->op = GGML_OP_DIAG_MASK_ZERO;
6845
6856
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6846
- result->src0 = a;
6847
- result->src1 = b;
6857
+ result->src[0] = a;
6858
+ result->src[1] = b;
6848
6859
 
6849
6860
  return result;
6850
6861
  }
@@ -6879,8 +6890,8 @@ struct ggml_tensor * ggml_soft_max_impl(
6879
6890
 
6880
6891
  result->op = GGML_OP_SOFT_MAX;
6881
6892
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6882
- result->src0 = a;
6883
- result->src1 = NULL;
6893
+ result->src[0] = a;
6894
+ result->src[1] = NULL;
6884
6895
 
6885
6896
  return result;
6886
6897
  }
@@ -6915,8 +6926,8 @@ struct ggml_tensor * ggml_soft_max_back_impl(
6915
6926
 
6916
6927
  result->op = GGML_OP_SOFT_MAX_BACK;
6917
6928
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6918
- result->src0 = a;
6919
- result->src1 = b;
6929
+ result->src[0] = a;
6930
+ result->src[1] = b;
6920
6931
 
6921
6932
  return result;
6922
6933
  }
@@ -6967,8 +6978,8 @@ struct ggml_tensor * ggml_rope_impl(
6967
6978
 
6968
6979
  result->op = GGML_OP_ROPE;
6969
6980
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6970
- result->src0 = a;
6971
- result->src1 = b;
6981
+ result->src[0] = a;
6982
+ result->src[1] = b;
6972
6983
 
6973
6984
  return result;
6974
6985
  }
@@ -7025,8 +7036,8 @@ struct ggml_tensor * ggml_rope_back(
7025
7036
 
7026
7037
  result->op = GGML_OP_ROPE_BACK;
7027
7038
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7028
- result->src0 = a;
7029
- result->src1 = b;
7039
+ result->src[0] = a;
7040
+ result->src[1] = b;
7030
7041
 
7031
7042
  return result;
7032
7043
  }
@@ -7064,8 +7075,8 @@ struct ggml_tensor * ggml_alibi(
7064
7075
 
7065
7076
  result->op = GGML_OP_ALIBI;
7066
7077
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7067
- result->src0 = a;
7068
- result->src1 = b;
7078
+ result->src[0] = a;
7079
+ result->src[1] = b;
7069
7080
 
7070
7081
  return result;
7071
7082
  }
@@ -7098,8 +7109,8 @@ struct ggml_tensor * ggml_clamp(
7098
7109
 
7099
7110
  result->op = GGML_OP_CLAMP;
7100
7111
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7101
- result->src0 = a;
7102
- result->src1 = b;
7112
+ result->src[0] = a;
7113
+ result->src[1] = b;
7103
7114
 
7104
7115
  return result;
7105
7116
  }
@@ -7141,9 +7152,9 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7141
7152
 
7142
7153
  result->op = GGML_OP_CONV_1D;
7143
7154
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7144
- result->src0 = a;
7145
- result->src1 = b;
7146
- result->opt[0] = c;
7155
+ result->src[0] = a;
7156
+ result->src[1] = b;
7157
+ result->src[2] = c;
7147
7158
 
7148
7159
  return result;
7149
7160
  }
@@ -7161,7 +7172,6 @@ struct ggml_tensor* ggml_conv_2d(
7161
7172
  int d0,
7162
7173
  int d1) {
7163
7174
 
7164
- GGML_ASSERT(b->ne[3] == 1);
7165
7175
  GGML_ASSERT(a->ne[2] == b->ne[2]);
7166
7176
  bool is_node = false;
7167
7177
 
@@ -7173,7 +7183,7 @@ struct ggml_tensor* ggml_conv_2d(
7173
7183
  const int64_t ne[4] = {
7174
7184
  ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7175
7185
  ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
7176
- a->ne[3], 1,
7186
+ a->ne[3], b->ne[3],
7177
7187
  };
7178
7188
  struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7179
7189
 
@@ -7189,9 +7199,9 @@ struct ggml_tensor* ggml_conv_2d(
7189
7199
 
7190
7200
  result->op = GGML_OP_CONV_2D;
7191
7201
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7192
- result->src0 = a;
7193
- result->src1 = b;
7194
- result->opt[0] = c;
7202
+ result->src[0] = a;
7203
+ result->src[1] = b;
7204
+ result->src[2] = c;
7195
7205
 
7196
7206
  return result;
7197
7207
 
@@ -7208,6 +7218,98 @@ struct ggml_tensor* ggml_conv_1d_ph(
7208
7218
  return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
7209
7219
  }
7210
7220
 
7221
+
7222
+ // ggml_pool_*
7223
+
7224
+ static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
7225
+ return (ins + 2 * p - ks) / s + 1;
7226
+ }
7227
+
7228
+ // ggml_pool_1d
7229
+
7230
+ struct ggml_tensor* ggml_pool_1d(
7231
+ struct ggml_context * ctx,
7232
+ struct ggml_tensor * a,
7233
+ enum ggml_op_pool op,
7234
+ int k0,
7235
+ int s0,
7236
+ int p0) {
7237
+
7238
+ bool is_node = false;
7239
+
7240
+ if (a->grad) {
7241
+ GGML_ASSERT(false); // TODO: implement backward
7242
+ is_node = true;
7243
+ }
7244
+
7245
+ const int64_t ne[3] = {
7246
+ ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
7247
+ a->ne[1],
7248
+ };
7249
+ struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7250
+
7251
+ ggml_scratch_save(ctx);
7252
+ struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
7253
+ ((int32_t*)c->data)[0] = op;
7254
+ ((int32_t*)c->data)[1] = k0;
7255
+ ((int32_t*)c->data)[2] = s0;
7256
+ ((int32_t*)c->data)[3] = p0;
7257
+ ggml_scratch_load(ctx);
7258
+
7259
+ result->op = GGML_OP_POOL_1D;
7260
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7261
+ result->src[0] = a;
7262
+ result->src[1] = c;
7263
+
7264
+ return result;
7265
+ }
7266
+
7267
+ // ggml_pool_2d
7268
+
7269
+ struct ggml_tensor* ggml_pool_2d(
7270
+ struct ggml_context * ctx,
7271
+ struct ggml_tensor * a,
7272
+ enum ggml_op_pool op,
7273
+ int k0,
7274
+ int k1,
7275
+ int s0,
7276
+ int s1,
7277
+ int p0,
7278
+ int p1) {
7279
+
7280
+ bool is_node = false;
7281
+
7282
+ if (a->grad) {
7283
+ GGML_ASSERT(false); // TODO: implement backward
7284
+ is_node = true;
7285
+ }
7286
+
7287
+ const int64_t ne[3] = {
7288
+ ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
7289
+ ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
7290
+ a->ne[2],
7291
+ };
7292
+ struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
7293
+
7294
+ ggml_scratch_save(ctx);
7295
+ struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7);
7296
+ ((int32_t*)c->data)[0] = op;
7297
+ ((int32_t*)c->data)[1] = k0;
7298
+ ((int32_t*)c->data)[2] = k1;
7299
+ ((int32_t*)c->data)[3] = s0;
7300
+ ((int32_t*)c->data)[4] = s1;
7301
+ ((int32_t*)c->data)[5] = p0;
7302
+ ((int32_t*)c->data)[6] = p1;
7303
+ ggml_scratch_load(ctx);
7304
+
7305
+ result->op = GGML_OP_POOL_2D;
7306
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7307
+ result->src[0] = a;
7308
+ result->src[1] = c;
7309
+
7310
+ return result;
7311
+ }
7312
+
7211
7313
  // ggml_flash_attn
7212
7314
 
7213
7315
  struct ggml_tensor * ggml_flash_attn(
@@ -7230,10 +7332,10 @@ struct ggml_tensor * ggml_flash_attn(
7230
7332
 
7231
7333
  result->op = GGML_OP_FLASH_ATTN;
7232
7334
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7233
- result->src0 = q;
7234
- result->src1 = k;
7235
- result->opt[0] = v;
7236
- result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
7335
+ result->src[0] = q;
7336
+ result->src[1] = k;
7337
+ result->src[2] = v;
7338
+ result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
7237
7339
 
7238
7340
  return result;
7239
7341
  }
@@ -7261,11 +7363,11 @@ struct ggml_tensor * ggml_flash_ff(
7261
7363
 
7262
7364
  result->op = GGML_OP_FLASH_FF;
7263
7365
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7264
- result->src0 = a;
7265
- result->src1 = b0;
7266
- result->opt[0] = b1;
7267
- result->opt[1] = c0;
7268
- result->opt[2] = c1;
7366
+ result->src[0] = a;
7367
+ result->src[1] = b0;
7368
+ result->src[2] = b1;
7369
+ result->src[3] = c0;
7370
+ result->src[4] = c1;
7269
7371
 
7270
7372
  return result;
7271
7373
  }
@@ -7325,11 +7427,11 @@ struct ggml_tensor * ggml_flash_attn_back(
7325
7427
 
7326
7428
  result->op = GGML_OP_FLASH_ATTN_BACK;
7327
7429
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7328
- result->src0 = q;
7329
- result->src1 = k;
7330
- result->opt[0] = v;
7331
- result->opt[1] = d;
7332
- result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
7430
+ result->src[0] = q;
7431
+ result->src[1] = k;
7432
+ result->src[2] = v;
7433
+ result->src[3] = d;
7434
+ result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
7333
7435
 
7334
7436
  return result;
7335
7437
  }
@@ -7374,9 +7476,9 @@ struct ggml_tensor * ggml_win_part(
7374
7476
 
7375
7477
  result->op = GGML_OP_WIN_PART;
7376
7478
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7377
- result->src0 = a;
7378
- result->src1 = NULL;
7379
- result->opt[0] = b;
7479
+ result->src[0] = a;
7480
+ result->src[1] = NULL;
7481
+ result->src[2] = b;
7380
7482
 
7381
7483
  return result;
7382
7484
  }
@@ -7411,9 +7513,9 @@ struct ggml_tensor * ggml_win_unpart(
7411
7513
 
7412
7514
  result->op = GGML_OP_WIN_UNPART;
7413
7515
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7414
- result->src0 = a;
7415
- result->src1 = NULL;
7416
- result->opt[0] = b;
7516
+ result->src[0] = a;
7517
+ result->src[1] = NULL;
7518
+ result->src[2] = b;
7417
7519
 
7418
7520
  return result;
7419
7521
  }
@@ -7442,8 +7544,8 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
7442
7544
 
7443
7545
  result->op = GGML_OP_MAP_UNARY;
7444
7546
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7445
- result->src0 = a;
7446
- result->opt[0] = addr_tensor;
7547
+ result->src[0] = a;
7548
+ result->src[2] = addr_tensor;
7447
7549
 
7448
7550
  return result;
7449
7551
  }
@@ -7489,9 +7591,9 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
7489
7591
 
7490
7592
  result->op = GGML_OP_MAP_BINARY;
7491
7593
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7492
- result->src0 = a;
7493
- result->src1 = b;
7494
- result->opt[0] = addr_tensor;
7594
+ result->src[0] = a;
7595
+ result->src[1] = b;
7596
+ result->src[2] = addr_tensor;
7495
7597
 
7496
7598
  return result;
7497
7599
  }
@@ -7536,8 +7638,8 @@ struct ggml_tensor * ggml_map_custom1_impl_f32(
7536
7638
 
7537
7639
  result->op = GGML_OP_MAP_CUSTOM1;
7538
7640
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7539
- result->src0 = a;
7540
- result->opt[0] = addr_tensor;
7641
+ result->src[0] = a;
7642
+ result->src[2] = addr_tensor;
7541
7643
 
7542
7644
  return result;
7543
7645
  }
@@ -7581,9 +7683,9 @@ struct ggml_tensor * ggml_map_custom2_impl_f32(
7581
7683
 
7582
7684
  result->op = GGML_OP_MAP_CUSTOM2;
7583
7685
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7584
- result->src0 = a;
7585
- result->src1 = b;
7586
- result->opt[0] = addr_tensor;
7686
+ result->src[0] = a;
7687
+ result->src[1] = b;
7688
+ result->src[2] = addr_tensor;
7587
7689
 
7588
7690
  return result;
7589
7691
  }
@@ -7630,10 +7732,10 @@ struct ggml_tensor * ggml_map_custom3_impl_f32(
7630
7732
 
7631
7733
  result->op = GGML_OP_MAP_CUSTOM3;
7632
7734
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7633
- result->src0 = a;
7634
- result->src1 = b;
7635
- result->opt[0] = addr_tensor;
7636
- result->opt[1] = c;
7735
+ result->src[0] = a;
7736
+ result->src[1] = b;
7737
+ result->src[2] = addr_tensor;
7738
+ result->src[3] = c;
7637
7739
 
7638
7740
  return result;
7639
7741
  }
@@ -7673,8 +7775,8 @@ struct ggml_tensor * ggml_cross_entropy_loss(
7673
7775
 
7674
7776
  result->op = GGML_OP_CROSS_ENTROPY_LOSS;
7675
7777
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7676
- result->src0 = a;
7677
- result->src1 = b;
7778
+ result->src[0] = a;
7779
+ result->src[1] = b;
7678
7780
 
7679
7781
  return result;
7680
7782
  }
@@ -7693,9 +7795,9 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
7693
7795
 
7694
7796
  result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
7695
7797
  result->grad = NULL;
7696
- result->src0 = a;
7697
- result->src1 = b;
7698
- result->opt[0] = c;
7798
+ result->src[0] = a;
7799
+ result->src[1] = b;
7800
+ result->src[2] = c;
7699
7801
 
7700
7802
  return result;
7701
7803
  }
@@ -8296,7 +8398,7 @@ static void ggml_compute_forward_add_f32(
8296
8398
  const struct ggml_tensor * src0,
8297
8399
  const struct ggml_tensor * src1,
8298
8400
  struct ggml_tensor * dst) {
8299
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8401
+ GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
8300
8402
 
8301
8403
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8302
8404
  return;
@@ -8321,23 +8423,23 @@ static void ggml_compute_forward_add_f32(
8321
8423
 
8322
8424
  if (nb10 == sizeof(float)) {
8323
8425
  for (int ir = ir0; ir < ir1; ++ir) {
8324
- // src0, src1 and dst are same shape => same indices
8325
- const int i3 = ir/(ne2*ne1);
8326
- const int i2 = (ir - i3*ne2*ne1)/ne1;
8327
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8426
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
8427
+ const int64_t i03 = ir/(ne02*ne01);
8428
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
8429
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
8430
+
8431
+ const int64_t i13 = i03 % ne13;
8432
+ const int64_t i12 = i02 % ne12;
8433
+ const int64_t i11 = i01 % ne11;
8328
8434
 
8435
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
8436
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
8437
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
8329
8438
 
8330
8439
  #ifdef GGML_USE_ACCELERATE
8331
- vDSP_vadd(
8332
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
8333
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
8334
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
8335
- ne0);
8440
+ vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
8336
8441
  #else
8337
- ggml_vec_add_f32(ne0,
8338
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
8339
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
8340
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
8442
+ ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
8341
8443
  #endif
8342
8444
  // }
8343
8445
  // }
@@ -8345,15 +8447,20 @@ static void ggml_compute_forward_add_f32(
8345
8447
  } else {
8346
8448
  // src1 is not contiguous
8347
8449
  for (int ir = ir0; ir < ir1; ++ir) {
8348
- // src0, src1 and dst are same shape => same indices
8349
- const int i3 = ir/(ne2*ne1);
8350
- const int i2 = (ir - i3*ne2*ne1)/ne1;
8351
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8450
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
8451
+ const int64_t i03 = ir/(ne02*ne01);
8452
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
8453
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
8454
+
8455
+ const int64_t i13 = i03 % ne13;
8456
+ const int64_t i12 = i02 % ne12;
8457
+ const int64_t i11 = i01 % ne11;
8458
+
8459
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
8460
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
8352
8461
 
8353
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
8354
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8355
8462
  for (int i0 = 0; i0 < ne0; i0++) {
8356
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
8463
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
8357
8464
 
8358
8465
  dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
8359
8466
  }
@@ -10532,7 +10639,6 @@ static void ggml_compute_forward_rms_norm_back(
10532
10639
  }
10533
10640
  }
10534
10641
 
10535
-
10536
10642
  // ggml_compute_forward_mul_mat
10537
10643
 
10538
10644
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -10576,17 +10682,17 @@ static void ggml_compute_forward_mul_mat(
10576
10682
  const int ith = params->ith;
10577
10683
  const int nth = params->nth;
10578
10684
 
10579
- GGML_ASSERT(ne02 == ne12);
10580
- GGML_ASSERT(ne03 == ne13);
10581
- GGML_ASSERT(ne2 == ne12);
10582
- GGML_ASSERT(ne3 == ne13);
10583
-
10584
10685
  const enum ggml_type type = src0->type;
10585
10686
 
10586
10687
  ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
10587
10688
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
10588
10689
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
10589
10690
 
10691
+ GGML_ASSERT(ne0 == ne01);
10692
+ GGML_ASSERT(ne1 == ne11);
10693
+ GGML_ASSERT(ne2 == ne12);
10694
+ GGML_ASSERT(ne3 == ne13);
10695
+
10590
10696
  // we don't support permuted src0 or src1
10591
10697
  GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
10592
10698
  GGML_ASSERT(nb10 == sizeof(float));
@@ -10597,16 +10703,16 @@ static void ggml_compute_forward_mul_mat(
10597
10703
  GGML_ASSERT(nb1 <= nb2);
10598
10704
  GGML_ASSERT(nb2 <= nb3);
10599
10705
 
10600
- GGML_ASSERT(ne0 == ne01);
10601
- GGML_ASSERT(ne1 == ne11);
10602
- GGML_ASSERT(ne2 == ne02);
10603
- GGML_ASSERT(ne3 == ne03);
10604
-
10605
10706
  // nb01 >= nb00 - src0 is not transposed
10606
10707
  // compute by src0 rows
10607
10708
 
10608
10709
  #if defined(GGML_USE_CLBLAST)
10609
10710
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
10711
+ // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
10712
+ // ref: https://github.com/ggerganov/ggml/pull/224
10713
+ GGML_ASSERT(ne02 == ne12);
10714
+ GGML_ASSERT(ne03 == ne13);
10715
+
10610
10716
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
10611
10717
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
10612
10718
  }
@@ -10616,6 +10722,11 @@ static void ggml_compute_forward_mul_mat(
10616
10722
 
10617
10723
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10618
10724
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
10725
+ // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
10726
+ // ref: https://github.com/ggerganov/ggml/pull/224
10727
+ GGML_ASSERT(ne02 == ne12);
10728
+ GGML_ASSERT(ne03 == ne13);
10729
+
10619
10730
  if (params->ith != 0) {
10620
10731
  return;
10621
10732
  }
@@ -10685,43 +10796,44 @@ static void ggml_compute_forward_mul_mat(
10685
10796
  return;
10686
10797
  }
10687
10798
 
10688
- // parallelize by src0 rows using ggml_vec_dot_q
10689
-
10690
- // total rows in src0
10691
- const int nr = ne01*ne02*ne03;
10799
+ // parallelize by src0 rows
10800
+ const int64_t dr = (ne01 + nth - 1)/nth;
10692
10801
 
10693
- // rows per thread
10694
- const int dr = (nr + nth - 1)/nth;
10802
+ const int64_t ir10 = dr*ith;
10803
+ const int64_t ir11 = MIN(ir10 + dr, ne01);
10695
10804
 
10696
- // row range for this thread
10697
- const int ir0 = dr*ith;
10698
- const int ir1 = MIN(ir0 + dr, nr);
10805
+ // src1 rows
10806
+ const int64_t nr1 = ne11*ne12*ne13;
10699
10807
 
10700
10808
  void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10701
- const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
10702
-
10703
- for (int ir = ir0; ir < ir1; ++ir) {
10704
- // src0 indices
10705
- const int i03 = ir/(ne02*ne01);
10706
- const int i02 = (ir - i03*ne02*ne01)/ne01;
10707
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
10708
-
10709
- const int i13 = i03;
10710
- const int i12 = i02;
10711
-
10712
- const int i0 = i01;
10713
- const int i2 = i02;
10714
- const int i3 = i03;
10715
-
10716
- void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
10717
- char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
10718
-
10719
- float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
10720
-
10721
- assert(ne00 % 32 == 0);
10722
-
10723
- for (int64_t ic = 0; ic < ne11; ++ic) {
10724
- vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
10809
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
10810
+
10811
+ for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
10812
+ const int64_t i13 = (ir1/(ne12*ne11));
10813
+ const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
10814
+ const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
10815
+
10816
+ const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
10817
+ const int64_t i03 = (ir0/(ne02));
10818
+ // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
10819
+ // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
10820
+ // GG: this is likely the correct way to broadcast, though need some more thought
10821
+ // therefore leaving the comments to remind us for now
10822
+ const int64_t i02 = (i12 / (ne12 / ne02));
10823
+ // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
10824
+ // const int64_t i02 = (ir0 - i03*ne02);
10825
+
10826
+ const int64_t i1 = i11;
10827
+ const int64_t i2 = i12;
10828
+ const int64_t i3 = i13;
10829
+
10830
+ const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
10831
+ const char * src1_col = (const char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
10832
+
10833
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10834
+
10835
+ for (int64_t ir = ir10; ir < ir11; ++ir) {
10836
+ vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
10725
10837
  }
10726
10838
  }
10727
10839
 
@@ -11718,7 +11830,7 @@ static void ggml_compute_forward_alibi_f32(
11718
11830
 
11719
11831
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
11720
11832
  const int ne1 = src0->ne[1]; // seq_len_without_past
11721
- //const int ne2 = src0->ne[2]; // n_head -> this is k
11833
+ const int ne2 = src0->ne[2]; // n_head -> this is k
11722
11834
  //const int ne3 = src0->ne[3]; // 1 -> bsz
11723
11835
 
11724
11836
  const int n = ggml_nrows(src0);
@@ -11729,8 +11841,9 @@ static void ggml_compute_forward_alibi_f32(
11729
11841
  const int nb2 = src0->nb[2];
11730
11842
  //const int nb3 = src0->nb[3];
11731
11843
 
11732
- assert(nb0 == sizeof(float));
11733
- assert(ne1 + n_past == ne0); (void) n_past;
11844
+ GGML_ASSERT(nb0 == sizeof(float));
11845
+ GGML_ASSERT(ne1 + n_past == ne0);
11846
+ GGML_ASSERT(n_head == ne2);
11734
11847
 
11735
11848
  // add alibi to src0 (KQ_scaled)
11736
11849
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11754,7 +11867,7 @@ static void ggml_compute_forward_alibi_f32(
11754
11867
  m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
11755
11868
  }
11756
11869
 
11757
- pdst[0] = (i-ne0+1) * m_k + src[0];
11870
+ pdst[0] = i * m_k + src[0];
11758
11871
 
11759
11872
  }
11760
11873
  }
@@ -11783,7 +11896,7 @@ static void ggml_compute_forward_alibi_f16(
11783
11896
 
11784
11897
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
11785
11898
  const int ne1 = src0->ne[1]; // seq_len_without_past
11786
- //const int ne2 = src0->ne[2]; // n_head -> this is k
11899
+ const int ne2 = src0->ne[2]; // n_head -> this is k
11787
11900
  //const int ne3 = src0->ne[3]; // 1 -> bsz
11788
11901
 
11789
11902
  const int n = ggml_nrows(src0);
@@ -11794,8 +11907,9 @@ static void ggml_compute_forward_alibi_f16(
11794
11907
  const int nb2 = src0->nb[2];
11795
11908
  //const int nb3 = src0->nb[3];
11796
11909
 
11797
- assert(nb0 == sizeof(ggml_fp16_t));
11798
- assert(ne1 + n_past == ne0); (void) n_past;
11910
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
11911
+ GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
11912
+ GGML_ASSERT(n_head == ne2);
11799
11913
 
11800
11914
  // add alibi to src0 (KQ_scaled)
11801
11915
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11820,7 +11934,7 @@ static void ggml_compute_forward_alibi_f16(
11820
11934
  }
11821
11935
 
11822
11936
  // we return F32
11823
- pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
11937
+ pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
11824
11938
  }
11825
11939
  }
11826
11940
  }
@@ -12904,16 +13018,18 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12904
13018
  {
12905
13019
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12906
13020
 
12907
- for (int i12 = 0; i12 < ne12; i12++) {
12908
- const float * const src = (float *)((char *) src1->data + i12*nb12);
12909
- ggml_fp16_t * dst_data = wdata;
12910
-
12911
- for (int i1 = 0; i1 < ne1; i1++) {
12912
- for (int i0 = 0; i0 < ne0; i0++) {
12913
- for (int ik1 = 0; ik1 < nk1; ik1++) {
12914
- for (int ik0 = 0; ik0 < nk0; ik0++) {
12915
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
12916
- GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
13021
+ for (int i13 = 0; i13 < ne13; i13++) {
13022
+ for (int i12 = 0; i12 < ne12; i12++) {
13023
+ const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
13024
+ ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
13025
+
13026
+ for (int i1 = 0; i1 < ne1; i1++) {
13027
+ for (int i0 = 0; i0 < ne0; i0++) {
13028
+ for (int ik1 = 0; ik1 < nk1; ik1++) {
13029
+ for (int ik0 = 0; ik0 < nk0; ik0++) {
13030
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
13031
+ GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
13032
+ }
12917
13033
  }
12918
13034
  }
12919
13035
  }
@@ -12940,14 +13056,16 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12940
13056
 
12941
13057
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12942
13058
 
12943
- for (int i2 = ip0; i2 < ip1; i2++) {
12944
- float * dst_data = (float *)((char *) dst->data + i2*nb2);
12945
-
12946
- for (int i1 = 0; i1 < ne1; ++i1) {
12947
- for (int i0 = 0; i0 < ne0; ++i0) {
12948
- ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
12949
- (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
12950
- (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
13059
+ for (int i3 = 0; i3 < ne3; i3++) {
13060
+ for (int i2 = ip0; i2 < ip1; i2++) {
13061
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
13062
+
13063
+ for (int i1 = 0; i1 < ne1; ++i1) {
13064
+ for (int i0 = 0; i0 < ne0; ++i0) {
13065
+ ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
13066
+ (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
13067
+ (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
13068
+ }
12951
13069
  }
12952
13070
  }
12953
13071
  }
@@ -12996,10 +13114,169 @@ static void ggml_compute_forward_conv_2d(
12996
13114
 
12997
13115
  if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
12998
13116
  ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
12999
- }
13000
- else {
13117
+ } else {
13001
13118
  GGML_ASSERT(false); // only stride equal to kernel size is supported
13002
- };
13119
+ }
13120
+ }
13121
+
13122
+ // ggml_compute_forward_pool_1d_sk_p0
13123
+
13124
+ static void ggml_compute_forward_pool_1d_sk_p0(
13125
+ const struct ggml_compute_params * params,
13126
+ const enum ggml_op_pool op,
13127
+ const struct ggml_tensor * src,
13128
+ const int k,
13129
+ struct ggml_tensor * dst) {
13130
+ assert(src->type == GGML_TYPE_F32);
13131
+ assert(params->ith == 0);
13132
+
13133
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13134
+ return;
13135
+ }
13136
+
13137
+ const char * cdata = (const char *)src->data;
13138
+ const char * const data_end = cdata + ggml_nbytes(src);
13139
+ float * drow = (float *)dst->data;
13140
+
13141
+ const int64_t rs = dst->ne[0];
13142
+
13143
+ while (cdata < data_end) {
13144
+ const float * const srow = (const float *)cdata;
13145
+
13146
+ int j = 0;
13147
+
13148
+ for (int64_t i = 0; i < rs; ++i) {
13149
+ switch (op) {
13150
+ case GGML_OP_POOL_AVG: drow[i] = 0; break;
13151
+ case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
13152
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13153
+ }
13154
+ for (int ki = 0; ki < k; ++ki) {
13155
+ switch (op) {
13156
+ case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
13157
+ case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
13158
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13159
+ }
13160
+ ++j;
13161
+ }
13162
+ switch (op) {
13163
+ case GGML_OP_POOL_AVG: drow[i] /= k; break;
13164
+ case GGML_OP_POOL_MAX: break;
13165
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13166
+ }
13167
+ }
13168
+
13169
+ cdata += src->nb[1];
13170
+ drow += rs;
13171
+ }
13172
+ }
13173
+
13174
+ // ggml_compute_forward_pool_1d
13175
+
13176
+ static void ggml_compute_forward_pool_1d(
13177
+ const struct ggml_compute_params* params,
13178
+ const struct ggml_tensor* src0,
13179
+ const struct ggml_tensor* opt0,
13180
+ struct ggml_tensor* dst) {
13181
+ GGML_ASSERT(opt0->ne[0] == 4);
13182
+ const int* opts = (const int*)opt0->data;
13183
+ enum ggml_op_pool op = opts[0];
13184
+ const int k0 = opts[1];
13185
+ const int s0 = opts[2];
13186
+ const int p0 = opts[3];
13187
+ GGML_ASSERT(p0 == 0); // padding not supported
13188
+ GGML_ASSERT(k0 == s0); // only s = k supported
13189
+
13190
+ ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
13191
+ }
13192
+
13193
+ // ggml_compute_forward_pool_2d_sk_p0
13194
+
13195
+ static void ggml_compute_forward_pool_2d_sk_p0(
13196
+ const struct ggml_compute_params * params,
13197
+ const enum ggml_op_pool op,
13198
+ const struct ggml_tensor * src,
13199
+ const int k0,
13200
+ const int k1,
13201
+ struct ggml_tensor * dst) {
13202
+ assert(src->type == GGML_TYPE_F32);
13203
+ assert(params->ith == 0);
13204
+
13205
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13206
+ return;
13207
+ }
13208
+
13209
+ const char * cdata = (const char*)src->data;
13210
+ const char * const data_end = cdata + ggml_nbytes(src);
13211
+
13212
+ const int64_t px = dst->ne[0];
13213
+ const int64_t py = dst->ne[1];
13214
+ const int64_t pa = px * py;
13215
+
13216
+ float * dplane = (float *)dst->data;
13217
+
13218
+ const int ka = k0 * k1;
13219
+
13220
+ while (cdata < data_end) {
13221
+ for (int oy = 0; oy < py; ++oy) {
13222
+ float * const drow = dplane + oy * px;
13223
+ for (int ox = 0; ox < px; ++ox) {
13224
+ float * const out = drow + ox;
13225
+ switch (op) {
13226
+ case GGML_OP_POOL_AVG: *out = 0; break;
13227
+ case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
13228
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13229
+ }
13230
+
13231
+ const int ix = ox * k0;
13232
+ const int iy = oy * k1;
13233
+
13234
+ for (int ky = 0; ky < k1; ++ky) {
13235
+ const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
13236
+ for (int kx = 0; kx < k0; ++kx) {
13237
+ int j = ix + kx;
13238
+ switch (op) {
13239
+ case GGML_OP_POOL_AVG: *out += srow[j]; break;
13240
+ case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
13241
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13242
+ }
13243
+ }
13244
+ }
13245
+ switch (op) {
13246
+ case GGML_OP_POOL_AVG: *out /= ka; break;
13247
+ case GGML_OP_POOL_MAX: break;
13248
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13249
+ }
13250
+ }
13251
+ }
13252
+
13253
+ cdata += src->nb[2];
13254
+ dplane += pa;
13255
+ }
13256
+ }
13257
+
13258
+ // ggml_compute_forward_pool_2d
13259
+
13260
+ static void ggml_compute_forward_pool_2d(
13261
+ const struct ggml_compute_params * params,
13262
+ const struct ggml_tensor * src0,
13263
+ const struct ggml_tensor * opt0,
13264
+ struct ggml_tensor * dst) {
13265
+ GGML_ASSERT(opt0->ne[0] == 7);
13266
+ const int* opts = (const int*)opt0->data;
13267
+ enum ggml_op_pool op = opts[0];
13268
+ const int k0 = opts[1];
13269
+ const int k1 = opts[2];
13270
+ const int s0 = opts[3];
13271
+ const int s1 = opts[4];
13272
+ const int p0 = opts[5];
13273
+ const int p1 = opts[6];
13274
+ GGML_ASSERT(p0 == 0);
13275
+ GGML_ASSERT(p1 == 0); // padding not supported
13276
+ GGML_ASSERT(k0 == s0);
13277
+ GGML_ASSERT(k1 == s1); // only s = k supported
13278
+
13279
+ ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
13003
13280
  }
13004
13281
 
13005
13282
 
@@ -14566,287 +14843,295 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14566
14843
  if (skip_cpu) {
14567
14844
  return;
14568
14845
  }
14569
- GGML_ASSERT(tensor->src0 == NULL || tensor->src0->backend == GGML_BACKEND_CPU);
14570
- GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
14846
+ GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU);
14847
+ GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU);
14571
14848
  #endif // GGML_USE_CUBLAS
14572
14849
 
14573
14850
  switch (tensor->op) {
14574
14851
  case GGML_OP_DUP:
14575
14852
  {
14576
- ggml_compute_forward_dup(params, tensor->src0, tensor);
14853
+ ggml_compute_forward_dup(params, tensor->src[0], tensor);
14577
14854
  } break;
14578
14855
  case GGML_OP_ADD:
14579
14856
  {
14580
- ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
14857
+ ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor);
14581
14858
  } break;
14582
14859
  case GGML_OP_ADD1:
14583
14860
  {
14584
- ggml_compute_forward_add1(params, tensor->src0, tensor->src1, tensor);
14861
+ ggml_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor);
14585
14862
  } break;
14586
14863
  case GGML_OP_ACC:
14587
14864
  {
14588
- ggml_compute_forward_acc(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14865
+ ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14589
14866
  } break;
14590
14867
  case GGML_OP_SUB:
14591
14868
  {
14592
- ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
14869
+ ggml_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor);
14593
14870
  } break;
14594
14871
  case GGML_OP_MUL:
14595
14872
  {
14596
- ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
14873
+ ggml_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor);
14597
14874
  } break;
14598
14875
  case GGML_OP_DIV:
14599
14876
  {
14600
- ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
14877
+ ggml_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor);
14601
14878
  } break;
14602
14879
  case GGML_OP_SQR:
14603
14880
  {
14604
- ggml_compute_forward_sqr(params, tensor->src0, tensor);
14881
+ ggml_compute_forward_sqr(params, tensor->src[0], tensor);
14605
14882
  } break;
14606
14883
  case GGML_OP_SQRT:
14607
14884
  {
14608
- ggml_compute_forward_sqrt(params, tensor->src0, tensor);
14885
+ ggml_compute_forward_sqrt(params, tensor->src[0], tensor);
14609
14886
  } break;
14610
14887
  case GGML_OP_LOG:
14611
14888
  {
14612
- ggml_compute_forward_log(params, tensor->src0, tensor);
14889
+ ggml_compute_forward_log(params, tensor->src[0], tensor);
14613
14890
  } break;
14614
14891
  case GGML_OP_SUM:
14615
14892
  {
14616
- ggml_compute_forward_sum(params, tensor->src0, tensor);
14893
+ ggml_compute_forward_sum(params, tensor->src[0], tensor);
14617
14894
  } break;
14618
14895
  case GGML_OP_SUM_ROWS:
14619
14896
  {
14620
- ggml_compute_forward_sum_rows(params, tensor->src0, tensor);
14897
+ ggml_compute_forward_sum_rows(params, tensor->src[0], tensor);
14621
14898
  } break;
14622
14899
  case GGML_OP_MEAN:
14623
14900
  {
14624
- ggml_compute_forward_mean(params, tensor->src0, tensor);
14901
+ ggml_compute_forward_mean(params, tensor->src[0], tensor);
14625
14902
  } break;
14626
14903
  case GGML_OP_ARGMAX:
14627
14904
  {
14628
- ggml_compute_forward_argmax(params, tensor->src0, tensor);
14905
+ ggml_compute_forward_argmax(params, tensor->src[0], tensor);
14629
14906
  } break;
14630
14907
  case GGML_OP_REPEAT:
14631
14908
  {
14632
- ggml_compute_forward_repeat(params, tensor->src0, tensor);
14909
+ ggml_compute_forward_repeat(params, tensor->src[0], tensor);
14633
14910
  } break;
14634
14911
  case GGML_OP_REPEAT_BACK:
14635
14912
  {
14636
- ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14913
+ ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
14637
14914
  } break;
14638
14915
  case GGML_OP_ABS:
14639
14916
  {
14640
- ggml_compute_forward_abs(params, tensor->src0, tensor);
14917
+ ggml_compute_forward_abs(params, tensor->src[0], tensor);
14641
14918
  } break;
14642
14919
  case GGML_OP_SGN:
14643
14920
  {
14644
- ggml_compute_forward_sgn(params, tensor->src0, tensor);
14921
+ ggml_compute_forward_sgn(params, tensor->src[0], tensor);
14645
14922
  } break;
14646
14923
  case GGML_OP_NEG:
14647
14924
  {
14648
- ggml_compute_forward_neg(params, tensor->src0, tensor);
14925
+ ggml_compute_forward_neg(params, tensor->src[0], tensor);
14649
14926
  } break;
14650
14927
  case GGML_OP_STEP:
14651
14928
  {
14652
- ggml_compute_forward_step(params, tensor->src0, tensor);
14929
+ ggml_compute_forward_step(params, tensor->src[0], tensor);
14653
14930
  } break;
14654
14931
  case GGML_OP_TANH:
14655
14932
  {
14656
- ggml_compute_forward_tanh(params, tensor->src0, tensor);
14933
+ ggml_compute_forward_tanh(params, tensor->src[0], tensor);
14657
14934
  } break;
14658
14935
  case GGML_OP_ELU:
14659
14936
  {
14660
- ggml_compute_forward_elu(params, tensor->src0, tensor);
14937
+ ggml_compute_forward_elu(params, tensor->src[0], tensor);
14661
14938
  } break;
14662
14939
  case GGML_OP_RELU:
14663
14940
  {
14664
- ggml_compute_forward_relu(params, tensor->src0, tensor);
14941
+ ggml_compute_forward_relu(params, tensor->src[0], tensor);
14665
14942
  } break;
14666
14943
  case GGML_OP_GELU:
14667
14944
  {
14668
- ggml_compute_forward_gelu(params, tensor->src0, tensor);
14945
+ ggml_compute_forward_gelu(params, tensor->src[0], tensor);
14669
14946
  } break;
14670
14947
  case GGML_OP_GELU_QUICK:
14671
14948
  {
14672
- ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
14949
+ ggml_compute_forward_gelu_quick(params, tensor->src[0], tensor);
14673
14950
  } break;
14674
14951
  case GGML_OP_SILU:
14675
14952
  {
14676
- ggml_compute_forward_silu(params, tensor->src0, tensor);
14953
+ ggml_compute_forward_silu(params, tensor->src[0], tensor);
14677
14954
  } break;
14678
14955
  case GGML_OP_SILU_BACK:
14679
14956
  {
14680
- ggml_compute_forward_silu_back(params, tensor->src0, tensor->src1, tensor);
14957
+ ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
14681
14958
  } break;
14682
14959
  case GGML_OP_NORM:
14683
14960
  {
14684
- ggml_compute_forward_norm(params, tensor->src0, tensor);
14961
+ ggml_compute_forward_norm(params, tensor->src[0], tensor);
14685
14962
  } break;
14686
14963
  case GGML_OP_RMS_NORM:
14687
14964
  {
14688
- ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
14965
+ ggml_compute_forward_rms_norm(params, tensor->src[0], tensor);
14689
14966
  } break;
14690
14967
  case GGML_OP_RMS_NORM_BACK:
14691
14968
  {
14692
- ggml_compute_forward_rms_norm_back(params, tensor->src0, tensor->src1, tensor);
14969
+ ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor);
14693
14970
  } break;
14694
14971
  case GGML_OP_MUL_MAT:
14695
14972
  {
14696
- ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
14973
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14697
14974
  } break;
14698
14975
  case GGML_OP_OUT_PROD:
14699
14976
  {
14700
- ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
14977
+ ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
14701
14978
  } break;
14702
14979
  case GGML_OP_SCALE:
14703
14980
  {
14704
- ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
14981
+ ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
14705
14982
  } break;
14706
14983
  case GGML_OP_SET:
14707
14984
  {
14708
- ggml_compute_forward_set(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14985
+ ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14709
14986
  } break;
14710
14987
  case GGML_OP_CPY:
14711
14988
  {
14712
- ggml_compute_forward_cpy(params, tensor->src0, tensor);
14989
+ ggml_compute_forward_cpy(params, tensor->src[0], tensor);
14713
14990
  } break;
14714
14991
  case GGML_OP_CONT:
14715
14992
  {
14716
- ggml_compute_forward_cont(params, tensor->src0, tensor);
14993
+ ggml_compute_forward_cont(params, tensor->src[0], tensor);
14717
14994
  } break;
14718
14995
  case GGML_OP_RESHAPE:
14719
14996
  {
14720
- ggml_compute_forward_reshape(params, tensor->src0, tensor);
14997
+ ggml_compute_forward_reshape(params, tensor->src[0], tensor);
14721
14998
  } break;
14722
14999
  case GGML_OP_VIEW:
14723
15000
  {
14724
- ggml_compute_forward_view(params, tensor->src0);
15001
+ ggml_compute_forward_view(params, tensor->src[0]);
14725
15002
  } break;
14726
15003
  case GGML_OP_PERMUTE:
14727
15004
  {
14728
- ggml_compute_forward_permute(params, tensor->src0);
15005
+ ggml_compute_forward_permute(params, tensor->src[0]);
14729
15006
  } break;
14730
15007
  case GGML_OP_TRANSPOSE:
14731
15008
  {
14732
- ggml_compute_forward_transpose(params, tensor->src0);
15009
+ ggml_compute_forward_transpose(params, tensor->src[0]);
14733
15010
  } break;
14734
15011
  case GGML_OP_GET_ROWS:
14735
15012
  {
14736
- ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
15013
+ ggml_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor);
14737
15014
  } break;
14738
15015
  case GGML_OP_GET_ROWS_BACK:
14739
15016
  {
14740
- ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15017
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14741
15018
  } break;
14742
15019
  case GGML_OP_DIAG:
14743
15020
  {
14744
- ggml_compute_forward_diag(params, tensor->src0, tensor);
15021
+ ggml_compute_forward_diag(params, tensor->src[0], tensor);
14745
15022
  } break;
14746
15023
  case GGML_OP_DIAG_MASK_INF:
14747
15024
  {
14748
- ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
15025
+ ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor);
14749
15026
  } break;
14750
15027
  case GGML_OP_DIAG_MASK_ZERO:
14751
15028
  {
14752
- ggml_compute_forward_diag_mask_zero(params, tensor->src0, tensor->src1, tensor);
15029
+ ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor);
14753
15030
  } break;
14754
15031
  case GGML_OP_SOFT_MAX:
14755
15032
  {
14756
- ggml_compute_forward_soft_max(params, tensor->src0, tensor);
15033
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
14757
15034
  } break;
14758
15035
  case GGML_OP_SOFT_MAX_BACK:
14759
15036
  {
14760
- ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
15037
+ ggml_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor);
14761
15038
  } break;
14762
15039
  case GGML_OP_ROPE:
14763
15040
  {
14764
- ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
15041
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
14765
15042
  } break;
14766
15043
  case GGML_OP_ROPE_BACK:
14767
15044
  {
14768
- ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor);
15045
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
14769
15046
  } break;
14770
15047
  case GGML_OP_ALIBI:
14771
15048
  {
14772
- ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
15049
+ ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor);
14773
15050
  } break;
14774
15051
  case GGML_OP_CLAMP:
14775
15052
  {
14776
- ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
15053
+ ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor);
14777
15054
  } break;
14778
15055
  case GGML_OP_CONV_1D:
14779
15056
  {
14780
- ggml_compute_forward_conv_1d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15057
+ ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14781
15058
  } break;
14782
15059
  case GGML_OP_CONV_2D:
14783
15060
  {
14784
- ggml_compute_forward_conv_2d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15061
+ ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
15062
+ } break;
15063
+ case GGML_OP_POOL_1D:
15064
+ {
15065
+ ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor);
15066
+ } break;
15067
+ case GGML_OP_POOL_2D:
15068
+ {
15069
+ ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor);
14785
15070
  } break;
14786
15071
  case GGML_OP_FLASH_ATTN:
14787
15072
  {
14788
- const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15073
+ const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
14789
15074
  GGML_ASSERT(t == 0 || t == 1);
14790
15075
  const bool masked = t != 0;
14791
- ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
15076
+ ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
14792
15077
  } break;
14793
15078
  case GGML_OP_FLASH_FF:
14794
15079
  {
14795
- ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
15080
+ ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
14796
15081
  } break;
14797
15082
  case GGML_OP_FLASH_ATTN_BACK:
14798
15083
  {
14799
- int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
15084
+ int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
14800
15085
  GGML_ASSERT(t == 0 || t == 1);
14801
15086
  bool masked = t != 0;
14802
- ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
15087
+ ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
14803
15088
  } break;
14804
15089
  case GGML_OP_WIN_PART:
14805
15090
  {
14806
- ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
15091
+ ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor);
14807
15092
  } break;
14808
15093
  case GGML_OP_WIN_UNPART:
14809
15094
  {
14810
- ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
15095
+ ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor);
14811
15096
  } break;
14812
15097
  case GGML_OP_MAP_UNARY:
14813
15098
  {
14814
- const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
14815
- ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
15099
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data);
15100
+ ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun);
14816
15101
  }
14817
15102
  break;
14818
15103
  case GGML_OP_MAP_BINARY:
14819
15104
  {
14820
- const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
14821
- ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
15105
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data);
15106
+ ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun);
14822
15107
  }
14823
15108
  break;
14824
15109
  case GGML_OP_MAP_CUSTOM1:
14825
15110
  {
14826
- const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->opt[0]->data);
14827
- ggml_compute_forward_map_custom1(params, tensor->src0, tensor, fun);
15111
+ const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data);
15112
+ ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun);
14828
15113
  }
14829
15114
  break;
14830
15115
  case GGML_OP_MAP_CUSTOM2:
14831
15116
  {
14832
- const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->opt[0]->data);
14833
- ggml_compute_forward_map_custom2(params, tensor->src0, tensor->src1, tensor, fun);
15117
+ const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data);
15118
+ ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun);
14834
15119
  }
14835
15120
  break;
14836
15121
  case GGML_OP_MAP_CUSTOM3:
14837
15122
  {
14838
- const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->opt[0]->data);
14839
- ggml_compute_forward_map_custom3(params, tensor->src0, tensor->src1, tensor->opt[1], tensor, fun);
15123
+ const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data);
15124
+ ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun);
14840
15125
  }
14841
15126
  break;
14842
15127
  case GGML_OP_CROSS_ENTROPY_LOSS:
14843
15128
  {
14844
- ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
15129
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor);
14845
15130
  }
14846
15131
  break;
14847
15132
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
14848
15133
  {
14849
- ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15134
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14850
15135
  }
14851
15136
  break;
14852
15137
  case GGML_OP_NONE:
@@ -14863,8 +15148,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14863
15148
  ////////////////////////////////////////////////////////////////////////////////
14864
15149
 
14865
15150
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
14866
- struct ggml_tensor * src0 = tensor->src0;
14867
- struct ggml_tensor * src1 = tensor->src1;
15151
+ struct ggml_tensor * src0 = tensor->src[0];
15152
+ struct ggml_tensor * src1 = tensor->src[1];
14868
15153
 
14869
15154
  switch (tensor->op) {
14870
15155
  case GGML_OP_DUP:
@@ -14900,12 +15185,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
14900
15185
  src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
14901
15186
  }
14902
15187
  if (src1->grad) {
14903
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
14904
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
14905
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
14906
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
14907
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
14908
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
15188
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
15189
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
15190
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
15191
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
15192
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
15193
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
14909
15194
 
14910
15195
  struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
14911
15196
  tensor->grad,
@@ -15213,12 +15498,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15213
15498
  } break;
15214
15499
  case GGML_OP_SET:
15215
15500
  {
15216
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
15217
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
15218
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
15219
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
15220
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
15221
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
15501
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
15502
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
15503
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
15504
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
15505
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
15506
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
15222
15507
 
15223
15508
  struct ggml_tensor * tensor_grad_view = NULL;
15224
15509
 
@@ -15295,8 +15580,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15295
15580
  if (src0->grad) {
15296
15581
  size_t offset;
15297
15582
 
15298
- GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
15299
- memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
15583
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2]));
15584
+ memcpy(&offset, tensor->src[2]->data, sizeof(offset));
15300
15585
 
15301
15586
  size_t nb1 = tensor->nb[1];
15302
15587
  size_t nb2 = tensor->nb[2];
@@ -15323,7 +15608,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15323
15608
  {
15324
15609
  // necessary for llama
15325
15610
  if (src0->grad) {
15326
- int32_t * axes = (int32_t *) tensor->opt[0]->data;
15611
+ int32_t * axes = (int32_t *) tensor->src[2]->data;
15327
15612
  int axis0 = axes[0] & 0x3;
15328
15613
  int axis1 = axes[1] & 0x3;
15329
15614
  int axis2 = axes[2] & 0x3;
@@ -15483,18 +15768,26 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15483
15768
  {
15484
15769
  GGML_ASSERT(false); // TODO: not implemented
15485
15770
  } break;
15771
+ case GGML_OP_POOL_1D:
15772
+ {
15773
+ GGML_ASSERT(false); // TODO: not implemented
15774
+ } break;
15775
+ case GGML_OP_POOL_2D:
15776
+ {
15777
+ GGML_ASSERT(false); // TODO: not implemented
15778
+ } break;
15486
15779
  case GGML_OP_FLASH_ATTN:
15487
15780
  {
15488
15781
  struct ggml_tensor * flash_grad = NULL;
15489
- if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15490
- int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15782
+ if (src0->grad || src1->grad || tensor->src[2]->grad) {
15783
+ int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
15491
15784
  GGML_ASSERT(t == 0 || t == 1);
15492
15785
  bool masked = t != 0;
15493
15786
  flash_grad =
15494
15787
  ggml_flash_attn_back(ctx,
15495
15788
  src0,
15496
15789
  src1,
15497
- tensor->opt[0],
15790
+ tensor->src[2],
15498
15791
  tensor->grad,
15499
15792
  masked);
15500
15793
  }
@@ -15591,7 +15884,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15591
15884
  inplace);
15592
15885
  }
15593
15886
 
15594
- struct ggml_tensor * opt0 = tensor->opt[0];
15887
+ struct ggml_tensor * opt0 = tensor->src[2];
15595
15888
 
15596
15889
  if (opt0->grad) {
15597
15890
  struct ggml_tensor * grad_v = NULL;
@@ -15707,17 +16000,9 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
15707
16000
  }
15708
16001
  }
15709
16002
 
15710
- if (node->src0) {
15711
- ggml_visit_parents(cgraph, node->src0);
15712
- }
15713
-
15714
- if (node->src1) {
15715
- ggml_visit_parents(cgraph, node->src1);
15716
- }
15717
-
15718
- for (int i = 0; i < GGML_MAX_OPT; ++i) {
15719
- if (node->opt[i]) {
15720
- ggml_visit_parents(cgraph, node->opt[i]);
16003
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
16004
+ if (node->src[i]) {
16005
+ ggml_visit_parents(cgraph, node->src[i]);
15721
16006
  }
15722
16007
  }
15723
16008
 
@@ -15772,9 +16057,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
15772
16057
  struct ggml_cgraph result = {
15773
16058
  /*.n_nodes =*/ 0,
15774
16059
  /*.n_leafs =*/ 0,
15775
- /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
15776
- /*.work_size =*/ 0,
15777
- /*.work =*/ NULL,
15778
16060
  /*.nodes =*/ { NULL },
15779
16061
  /*.grads =*/ { NULL },
15780
16062
  /*.leafs =*/ { NULL },
@@ -15945,16 +16227,20 @@ void clear_numa_thread_affinity(void) {}
15945
16227
  #endif
15946
16228
 
15947
16229
  struct ggml_compute_state_shared {
15948
- struct ggml_cgraph * cgraph;
16230
+ const struct ggml_cgraph * cgraph; // graph whose nodes the worker threads execute; held through a read-only pointer
16231
+ const struct ggml_cplan * cplan; // plan the workers read for n_tasks, work_size/work_data and the abort callback
15949
16232
 
15950
16233
  int64_t perf_node_start_cycles; // set from ggml_perf_cycles() by the thread that starts a node
15951
16234
  int64_t perf_node_start_time_us; // set from ggml_perf_time_us() by the thread that starts a node
15952
16235
 
15953
- int n_threads;
16236
+ const int n_threads; // worker thread count; n_active is reset to this value between nodes
15954
16237
 
15955
16238
  // synchronization primitives
15956
16239
  atomic_int n_active; // num active threads; decremented with atomic_fetch_sub, the thread that observes 1 advances the graph
15957
16240
  atomic_int node_n; // active graph node; waiting threads poll it with atomic_load until it changes
16241
+
16242
+ bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
16243
+ void * abort_callback_data; // NOTE(review): presumably the `data` argument for abort_callback; the visible worker code reads both through cplan instead, confirm these fields are used
15958
16244
  };
15959
16245
 
15960
16246
  struct ggml_compute_state {
@@ -15974,14 +16260,22 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
15974
16260
 
15975
16261
  static thread_ret_t ggml_graph_compute_thread(void * data) {
15976
16262
  struct ggml_compute_state * state = (struct ggml_compute_state *) data;
15977
- struct ggml_cgraph * cgraph = state->shared->cgraph;
15978
16263
 
15979
- const int n_threads = state->shared->n_threads;
16264
+ const struct ggml_cgraph * cgraph = state->shared->cgraph;
16265
+ const struct ggml_cplan * cplan = state->shared->cplan;
16266
+
16267
+ const int * n_tasks_arr = cplan->n_tasks;
16268
+ const int n_threads = state->shared->n_threads;
16269
+
15980
16270
  set_numa_thread_affinity(state->ith, n_threads);
15981
16271
 
15982
16272
  int node_n = -1;
15983
16273
 
15984
16274
  while (true) {
16275
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
16276
+ state->shared->node_n += 1;
16277
+ return (thread_ret_t) GGML_EXIT_ABORTED;
16278
+ }
15985
16279
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
15986
16280
  // all other threads are finished and spinning
15987
16281
  // do finalize and init here so we don't have synchronize again
@@ -15989,15 +16283,15 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
15989
16283
  /*.type =*/ GGML_TASK_FINALIZE,
15990
16284
  /*.ith =*/ 0,
15991
16285
  /*.nth =*/ 0,
15992
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
15993
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
16286
+ /*.wsize =*/ cplan->work_size,
16287
+ /*.wdata =*/ cplan->work_data,
15994
16288
  };
15995
16289
 
15996
16290
  if (node_n != -1) {
15997
16291
  /* FINALIZE */
15998
16292
  struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
15999
16293
  if (GGML_OP_HAS_FINALIZE[node->op]) {
16000
- params.nth = node->n_tasks;
16294
+ params.nth = n_tasks_arr[node_n];
16001
16295
  ggml_compute_forward(&params, node);
16002
16296
  ggml_graph_compute_perf_stats_node(node, state->shared);
16003
16297
  }
@@ -16008,11 +16302,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16008
16302
  GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
16009
16303
 
16010
16304
  struct ggml_tensor * node = cgraph->nodes[node_n];
16305
+ const int n_tasks = n_tasks_arr[node_n];
16011
16306
 
16012
16307
  state->shared->perf_node_start_cycles = ggml_perf_cycles();
16013
16308
  state->shared->perf_node_start_time_us = ggml_perf_time_us();
16014
16309
 
16015
- params.nth = node->n_tasks;
16310
+ params.nth = n_tasks;
16016
16311
 
16017
16312
  /* INIT */
16018
16313
  if (GGML_OP_HAS_INIT[node->op]) {
@@ -16020,7 +16315,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16020
16315
  ggml_compute_forward(&params, node);
16021
16316
  }
16022
16317
 
16023
- if (node->n_tasks == 1) {
16318
+ if (n_tasks == 1) {
16024
16319
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
16025
16320
  // they do something more efficient than spinning (?)
16026
16321
  params.type = GGML_TASK_COMPUTE;
@@ -16034,6 +16329,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16034
16329
  } else {
16035
16330
  break;
16036
16331
  }
16332
+
16333
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
16334
+ break;
16335
+ }
16037
16336
  }
16038
16337
 
16039
16338
  atomic_store(&state->shared->n_active, n_threads);
@@ -16042,7 +16341,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16042
16341
  // wait for other threads to finish
16043
16342
  const int last = node_n;
16044
16343
  do {
16045
- sched_yield();
16344
+ //sched_yield();
16046
16345
  node_n = atomic_load(&state->shared->node_n);
16047
16346
  } while (node_n == last);
16048
16347
  }
@@ -16052,366 +16351,395 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16052
16351
 
16053
16352
  /* COMPUTE */
16054
16353
  struct ggml_tensor * node = cgraph->nodes[node_n];
16354
+ const int n_tasks = n_tasks_arr[node_n];
16055
16355
 
16056
16356
  struct ggml_compute_params params = {
16057
16357
  /*.type =*/ GGML_TASK_COMPUTE,
16058
16358
  /*.ith =*/ state->ith,
16059
- /*.nth =*/ node->n_tasks,
16060
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
16061
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
16359
+ /*.nth =*/ n_tasks,
16360
+ /*.wsize =*/ cplan->work_size,
16361
+ /*.wdata =*/ cplan->work_data,
16062
16362
  };
16063
16363
 
16064
- if (state->ith < node->n_tasks) {
16364
+ if (state->ith < n_tasks) {
16065
16365
  ggml_compute_forward(&params, node);
16066
16366
  }
16067
16367
  }
16068
16368
 
16069
- return 0;
16369
+ return GGML_EXIT_SUCCESS;
16070
16370
  }
16071
16371
 
16072
- void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
16073
- const int n_threads = cgraph->n_threads;
16372
+ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
16373
+ if (n_threads <= 0) {
16374
+ n_threads = GGML_DEFAULT_N_THREADS;
16375
+ }
16074
16376
 
16075
- struct ggml_compute_state_shared state_shared = {
16076
- /*.cgraph =*/ cgraph,
16077
- /*.perf_node_start_cycles =*/ 0,
16078
- /*.perf_node_start_time_us =*/ 0,
16079
- /*.n_threads =*/ n_threads,
16080
- /*.n_active =*/ n_threads,
16081
- /*.node_n =*/ -1,
16082
- };
16083
- struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
16377
+ size_t work_size = 0;
16084
16378
 
16085
- // initialize tasks + work buffer
16086
- {
16087
- size_t work_size = 0;
16379
+ struct ggml_cplan cplan;
16380
+ memset(&cplan, 0, sizeof(struct ggml_cplan));
16088
16381
 
16089
- // thread scheduling for the different operations
16090
- for (int i = 0; i < cgraph->n_nodes; i++) {
16091
- struct ggml_tensor * node = cgraph->nodes[i];
16382
+ // thread scheduling for the different operations + work buffer size estimation
16383
+ for (int i = 0; i < cgraph->n_nodes; i++) {
16384
+ int n_tasks = 1;
16092
16385
 
16093
- switch (node->op) {
16094
- case GGML_OP_CPY:
16095
- case GGML_OP_DUP:
16096
- {
16097
- node->n_tasks = n_threads;
16386
+ struct ggml_tensor * node = cgraph->nodes[i];
16098
16387
 
16099
- size_t cur = 0;
16100
- if (ggml_is_quantized(node->type)) {
16101
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
16102
- }
16388
+ switch (node->op) {
16389
+ case GGML_OP_CPY:
16390
+ case GGML_OP_DUP:
16391
+ {
16392
+ n_tasks = n_threads;
16103
16393
 
16104
- work_size = MAX(work_size, cur);
16105
- } break;
16106
- case GGML_OP_ADD:
16107
- case GGML_OP_ADD1:
16108
- {
16109
- node->n_tasks = n_threads;
16394
+ size_t cur = 0;
16395
+ if (ggml_is_quantized(node->type)) {
16396
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
16397
+ }
16110
16398
 
16111
- size_t cur = 0;
16399
+ work_size = MAX(work_size, cur);
16400
+ } break;
16401
+ case GGML_OP_ADD:
16402
+ case GGML_OP_ADD1:
16403
+ {
16404
+ n_tasks = n_threads;
16112
16405
 
16113
- if (ggml_is_quantized(node->src0->type)) {
16114
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
16115
- }
16406
+ size_t cur = 0;
16116
16407
 
16117
- work_size = MAX(work_size, cur);
16118
- } break;
16119
- case GGML_OP_ACC:
16120
- {
16121
- node->n_tasks = n_threads;
16408
+ if (ggml_is_quantized(node->src[0]->type)) {
16409
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
16410
+ }
16122
16411
 
16123
- size_t cur = 0;
16412
+ work_size = MAX(work_size, cur);
16413
+ } break;
16414
+ case GGML_OP_ACC:
16415
+ {
16416
+ n_tasks = n_threads;
16124
16417
 
16125
- if (ggml_is_quantized(node->src0->type)) {
16126
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
16127
- }
16418
+ size_t cur = 0;
16128
16419
 
16129
- work_size = MAX(work_size, cur);
16130
- } break;
16131
- case GGML_OP_SUB:
16132
- case GGML_OP_DIV:
16133
- case GGML_OP_SQR:
16134
- case GGML_OP_SQRT:
16135
- case GGML_OP_LOG:
16136
- case GGML_OP_SUM:
16137
- case GGML_OP_SUM_ROWS:
16138
- case GGML_OP_MEAN:
16139
- case GGML_OP_ARGMAX:
16140
- case GGML_OP_REPEAT:
16141
- case GGML_OP_REPEAT_BACK:
16142
- case GGML_OP_ABS:
16143
- case GGML_OP_SGN:
16144
- case GGML_OP_NEG:
16145
- case GGML_OP_STEP:
16146
- case GGML_OP_TANH:
16147
- case GGML_OP_ELU:
16148
- case GGML_OP_RELU:
16149
- {
16150
- node->n_tasks = 1;
16151
- } break;
16152
- case GGML_OP_MUL:
16153
- case GGML_OP_GELU:
16154
- case GGML_OP_GELU_QUICK:
16155
- case GGML_OP_SILU:
16156
- case GGML_OP_SILU_BACK:
16157
- case GGML_OP_NORM:
16158
- case GGML_OP_RMS_NORM:
16159
- case GGML_OP_RMS_NORM_BACK:
16160
- {
16161
- node->n_tasks = n_threads;
16162
- } break;
16163
- case GGML_OP_MUL_MAT:
16164
- case GGML_OP_OUT_PROD:
16165
- {
16166
- node->n_tasks = n_threads;
16167
-
16168
- // TODO: use different scheduling for different matrix sizes
16169
- //const int nr0 = ggml_nrows(node->src0);
16170
- //const int nr1 = ggml_nrows(node->src1);
16171
-
16172
- //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
16173
- //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
16174
-
16175
- size_t cur = 0;
16176
- const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
16420
+ if (ggml_is_quantized(node->src[0]->type)) {
16421
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[1]->ne[0] * n_tasks;
16422
+ }
16423
+
16424
+ work_size = MAX(work_size, cur);
16425
+ } break;
16426
+ case GGML_OP_SUB:
16427
+ case GGML_OP_DIV:
16428
+ case GGML_OP_SQR:
16429
+ case GGML_OP_SQRT:
16430
+ case GGML_OP_LOG:
16431
+ case GGML_OP_SUM:
16432
+ case GGML_OP_SUM_ROWS:
16433
+ case GGML_OP_MEAN:
16434
+ case GGML_OP_ARGMAX:
16435
+ case GGML_OP_REPEAT:
16436
+ case GGML_OP_REPEAT_BACK:
16437
+ case GGML_OP_ABS:
16438
+ case GGML_OP_SGN:
16439
+ case GGML_OP_NEG:
16440
+ case GGML_OP_STEP:
16441
+ case GGML_OP_TANH:
16442
+ case GGML_OP_ELU:
16443
+ case GGML_OP_RELU:
16444
+ {
16445
+ n_tasks = 1;
16446
+ } break;
16447
+ case GGML_OP_MUL:
16448
+ case GGML_OP_GELU:
16449
+ case GGML_OP_GELU_QUICK:
16450
+ case GGML_OP_SILU:
16451
+ case GGML_OP_SILU_BACK:
16452
+ case GGML_OP_NORM:
16453
+ case GGML_OP_RMS_NORM:
16454
+ case GGML_OP_RMS_NORM_BACK:
16455
+ {
16456
+ n_tasks = n_threads;
16457
+ } break;
16458
+ case GGML_OP_MUL_MAT:
16459
+ case GGML_OP_OUT_PROD:
16460
+ {
16461
+ n_tasks = n_threads;
16462
+
16463
+ // TODO: use different scheduling for different matrix sizes
16464
+ //const int nr0 = ggml_nrows(node->src[0]);
16465
+ //const int nr1 = ggml_nrows(node->src[1]);
16466
+
16467
+ //n_tasks = MIN(n_threads, MAX(1, nr0/128));
16468
+ //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
16469
+
16470
+ size_t cur = 0;
16471
+ const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
16177
16472
 
16178
16473
  #if defined(GGML_USE_CUBLAS)
16179
- if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
16180
- node->n_tasks = 1; // TODO: this actually is doing nothing
16181
- // the threads are still spinning
16182
- }
16183
- else
16474
+ if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
16475
+ n_tasks = 1; // TODO: this actually is doing nothing
16476
+ // the threads are still spinning
16477
+ } else
16184
16478
  #elif defined(GGML_USE_CLBLAST)
16185
- if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
16186
- node->n_tasks = 1; // TODO: this actually is doing nothing
16187
- // the threads are still spinning
16188
- cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
16189
- }
16190
- else
16479
+ if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
16480
+ n_tasks = 1; // TODO: this actually is doing nothing
16481
+ // the threads are still spinning
16482
+ cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
16483
+ } else
16191
16484
  #endif
16192
16485
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
16193
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
16194
- node->n_tasks = 1; // TODO: this actually is doing nothing
16195
- // the threads are still spinning
16196
- if (node->src0->type != GGML_TYPE_F32) {
16197
- // here we need memory just for single 2D matrix from src0
16198
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
16199
- }
16200
- } else
16201
- #endif
16202
- if (node->src1->type != vec_dot_type) {
16203
- cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
16204
- } else {
16205
- cur = 0;
16486
+ if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
16487
+ n_tasks = 1; // TODO: this actually is doing nothing
16488
+ // the threads are still spinning
16489
+ if (node->src[0]->type != GGML_TYPE_F32) {
16490
+ // here we need memory just for single 2D matrix from src0
16491
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src[0]->ne[0]*node->src[0]->ne[1]);
16206
16492
  }
16493
+ } else
16494
+ #endif
16495
+ if (node->src[1]->type != vec_dot_type) {
16496
+ cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src[1])/GGML_BLCK_SIZE[vec_dot_type];
16497
+ } else {
16498
+ cur = 0;
16499
+ }
16207
16500
 
16208
- work_size = MAX(work_size, cur);
16209
- } break;
16210
- case GGML_OP_SCALE:
16211
- {
16212
- node->n_tasks = 1;
16213
- } break;
16214
- case GGML_OP_SET:
16215
- case GGML_OP_CONT:
16216
- case GGML_OP_RESHAPE:
16217
- case GGML_OP_VIEW:
16218
- case GGML_OP_PERMUTE:
16219
- case GGML_OP_TRANSPOSE:
16220
- case GGML_OP_GET_ROWS:
16221
- case GGML_OP_GET_ROWS_BACK:
16222
- case GGML_OP_DIAG:
16223
- case GGML_OP_DIAG_MASK_ZERO:
16224
- {
16225
- node->n_tasks = 1;
16226
- } break;
16227
- case GGML_OP_DIAG_MASK_INF:
16228
- case GGML_OP_SOFT_MAX:
16229
- case GGML_OP_SOFT_MAX_BACK:
16230
- case GGML_OP_ROPE:
16231
- case GGML_OP_ROPE_BACK:
16232
- {
16233
- node->n_tasks = n_threads;
16234
- } break;
16235
- case GGML_OP_ALIBI:
16236
- {
16237
- node->n_tasks = 1; //TODO
16238
- } break;
16239
- case GGML_OP_CLAMP:
16240
- {
16241
- node->n_tasks = 1; //TODO
16242
- } break;
16243
- case GGML_OP_CONV_1D:
16244
- {
16245
- node->n_tasks = n_threads;
16246
-
16247
- GGML_ASSERT(node->src0->ne[3] == 1);
16248
- GGML_ASSERT(node->src1->ne[2] == 1);
16249
- GGML_ASSERT(node->src1->ne[3] == 1);
16250
-
16251
- size_t cur = 0;
16252
- const int nk = node->src0->ne[0];
16253
-
16254
- if (node->src0->type == GGML_TYPE_F16 &&
16255
- node->src1->type == GGML_TYPE_F32) {
16256
- cur = sizeof(ggml_fp16_t)*(
16257
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
16258
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
16259
- );
16260
- } else if (node->src0->type == GGML_TYPE_F32 &&
16261
- node->src1->type == GGML_TYPE_F32) {
16262
- cur = sizeof(float)*(
16263
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
16264
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
16265
- );
16266
- } else {
16267
- GGML_ASSERT(false);
16268
- }
16501
+ work_size = MAX(work_size, cur);
16502
+ } break;
16503
+ case GGML_OP_SCALE:
16504
+ {
16505
+ n_tasks = 1;
16506
+ } break;
16507
+ case GGML_OP_SET:
16508
+ case GGML_OP_CONT:
16509
+ case GGML_OP_RESHAPE:
16510
+ case GGML_OP_VIEW:
16511
+ case GGML_OP_PERMUTE:
16512
+ case GGML_OP_TRANSPOSE:
16513
+ case GGML_OP_GET_ROWS:
16514
+ case GGML_OP_GET_ROWS_BACK:
16515
+ case GGML_OP_DIAG:
16516
+ case GGML_OP_DIAG_MASK_ZERO:
16517
+ {
16518
+ n_tasks = 1;
16519
+ } break;
16520
+ case GGML_OP_DIAG_MASK_INF:
16521
+ case GGML_OP_SOFT_MAX:
16522
+ case GGML_OP_SOFT_MAX_BACK:
16523
+ case GGML_OP_ROPE:
16524
+ case GGML_OP_ROPE_BACK:
16525
+ {
16526
+ n_tasks = n_threads;
16527
+ } break;
16528
+ case GGML_OP_ALIBI:
16529
+ {
16530
+ n_tasks = 1; //TODO
16531
+ } break;
16532
+ case GGML_OP_CLAMP:
16533
+ {
16534
+ n_tasks = 1; //TODO
16535
+ } break;
16536
+ case GGML_OP_CONV_1D:
16537
+ {
16538
+ n_tasks = n_threads;
16539
+
16540
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
16541
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
16542
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
16543
+
16544
+ size_t cur = 0;
16545
+ const int nk = node->src[0]->ne[0];
16546
+
16547
+ if (node->src[0]->type == GGML_TYPE_F16 &&
16548
+ node->src[1]->type == GGML_TYPE_F32) {
16549
+ cur = sizeof(ggml_fp16_t)*(
16550
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
16551
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
16552
+ );
16553
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
16554
+ node->src[1]->type == GGML_TYPE_F32) {
16555
+ cur = sizeof(float)*(
16556
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
16557
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
16558
+ );
16559
+ } else {
16560
+ GGML_ASSERT(false);
16561
+ }
16269
16562
 
16270
- work_size = MAX(work_size, cur);
16271
- } break;
16272
- case GGML_OP_CONV_2D:
16273
- {
16274
- node->n_tasks = n_threads;
16563
+ work_size = MAX(work_size, cur);
16564
+ } break;
16565
+ case GGML_OP_CONV_2D:
16566
+ {
16567
+ n_tasks = n_threads;
16275
16568
 
16276
- GGML_ASSERT(node->src1->ne[3] == 1);
16569
+ const int64_t ne00 = node->src[0]->ne[0]; // W
16570
+ const int64_t ne01 = node->src[0]->ne[1]; // H
16571
+ const int64_t ne02 = node->src[0]->ne[2]; // C
16572
+ const int64_t ne03 = node->src[0]->ne[3]; // N
16277
16573
 
16278
- const int64_t ne00 = node->src0->ne[0]; // W
16279
- const int64_t ne01 = node->src0->ne[1]; // H
16280
- const int64_t ne02 = node->src0->ne[2]; // C
16281
- const int64_t ne03 = node->src0->ne[3]; // N
16574
+ const int64_t ne10 = node->src[1]->ne[0]; // W
16575
+ const int64_t ne11 = node->src[1]->ne[1]; // H
16576
+ const int64_t ne12 = node->src[1]->ne[2]; // C
16282
16577
 
16283
- const int64_t ne10 = node->src1->ne[0]; // W
16284
- const int64_t ne11 = node->src1->ne[1]; // H
16285
- const int64_t ne12 = node->src1->ne[2]; // C
16578
+ const int64_t nk = ne00*ne01;
16286
16579
 
16287
- const int64_t nk = ne00*ne01;
16580
+ UNUSED(ne02);
16581
+ UNUSED(ne03);
16582
+ UNUSED(nk);
16288
16583
 
16289
- UNUSED(ne02);
16290
- UNUSED(ne03);
16291
- UNUSED(nk);
16584
+ size_t cur = 0;
16292
16585
 
16293
- size_t cur = 0;
16586
+ if (node->src[0]->type == GGML_TYPE_F16 &&
16587
+ node->src[1]->type == GGML_TYPE_F32) {
16588
+ cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
16589
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
16590
+ node->src[1]->type == GGML_TYPE_F32) {
16591
+ cur = sizeof(float)* (ne10*ne11*ne12);
16592
+ } else {
16593
+ GGML_ASSERT(false);
16594
+ }
16294
16595
 
16295
- if (node->src0->type == GGML_TYPE_F16 &&
16296
- node->src1->type == GGML_TYPE_F32) {
16297
- cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
16298
- } else if (node->src0->type == GGML_TYPE_F32 &&
16299
- node->src1->type == GGML_TYPE_F32) {
16300
- cur = sizeof(float)* (ne10*ne11*ne12);
16301
- } else {
16302
- GGML_ASSERT(false);
16303
- }
16596
+ work_size = MAX(work_size, cur);
16597
+ } break;
16598
+ case GGML_OP_POOL_1D:
16599
+ case GGML_OP_POOL_2D:
16600
+ {
16601
+ n_tasks = 1;
16602
+ } break;
16603
+ case GGML_OP_FLASH_ATTN:
16604
+ {
16605
+ n_tasks = n_threads;
16304
16606
 
16305
- work_size = MAX(work_size, cur);
16306
- } break;
16307
- case GGML_OP_FLASH_ATTN:
16308
- {
16309
- node->n_tasks = n_threads;
16607
+ size_t cur = 0;
16310
16608
 
16311
- size_t cur = 0;
16609
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
16312
16610
 
16313
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
16611
+ if (node->src[1]->type == GGML_TYPE_F32) {
16612
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
16613
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
16614
+ }
16314
16615
 
16315
- if (node->src1->type == GGML_TYPE_F32) {
16316
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
16317
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
16318
- }
16616
+ if (node->src[1]->type == GGML_TYPE_F16) {
16617
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
16618
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
16619
+ }
16319
16620
 
16320
- if (node->src1->type == GGML_TYPE_F16) {
16321
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
16322
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
16323
- }
16621
+ work_size = MAX(work_size, cur);
16622
+ } break;
16623
+ case GGML_OP_FLASH_FF:
16624
+ {
16625
+ n_tasks = n_threads;
16324
16626
 
16325
- work_size = MAX(work_size, cur);
16326
- } break;
16327
- case GGML_OP_FLASH_FF:
16328
- {
16329
- node->n_tasks = n_threads;
16627
+ size_t cur = 0;
16330
16628
 
16331
- size_t cur = 0;
16629
+ if (node->src[1]->type == GGML_TYPE_F32) {
16630
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16631
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
16632
+ }
16332
16633
 
16333
- if (node->src1->type == GGML_TYPE_F32) {
16334
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
16335
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
16336
- }
16634
+ if (node->src[1]->type == GGML_TYPE_F16) {
16635
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16636
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
16637
+ }
16337
16638
 
16338
- if (node->src1->type == GGML_TYPE_F16) {
16339
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
16340
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
16341
- }
16639
+ work_size = MAX(work_size, cur);
16640
+ } break;
16641
+ case GGML_OP_FLASH_ATTN_BACK:
16642
+ {
16643
+ n_tasks = n_threads;
16342
16644
 
16343
- work_size = MAX(work_size, cur);
16344
- } break;
16345
- case GGML_OP_FLASH_ATTN_BACK:
16346
- {
16347
- node->n_tasks = n_threads;
16645
+ size_t cur = 0;
16348
16646
 
16349
- size_t cur = 0;
16647
+ const int64_t D = node->src[0]->ne[0];
16648
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
16649
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
16650
+ if (node->src[1]->type == GGML_TYPE_F32) {
16651
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
16652
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
16653
+ }
16350
16654
 
16351
- const int64_t D = node->src0->ne[0];
16352
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
16353
- const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
16354
- if (node->src1->type == GGML_TYPE_F32) {
16355
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16356
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16357
- }
16655
+ if (node->src[1]->type == GGML_TYPE_F16) {
16656
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
16657
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
16658
+ }
16358
16659
 
16359
- if (node->src1->type == GGML_TYPE_F16) {
16360
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16361
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16362
- }
16660
+ work_size = MAX(work_size, cur);
16661
+ } break;
16662
+ case GGML_OP_WIN_PART:
16663
+ case GGML_OP_WIN_UNPART:
16664
+ case GGML_OP_MAP_UNARY:
16665
+ case GGML_OP_MAP_BINARY:
16666
+ case GGML_OP_MAP_CUSTOM1:
16667
+ case GGML_OP_MAP_CUSTOM2:
16668
+ case GGML_OP_MAP_CUSTOM3:
16669
+ {
16670
+ n_tasks = 1;
16671
+ } break;
16672
+ case GGML_OP_CROSS_ENTROPY_LOSS:
16673
+ {
16674
+ n_tasks = n_threads;
16363
16675
 
16364
- work_size = MAX(work_size, cur);
16365
- } break;
16366
- case GGML_OP_WIN_PART:
16367
- case GGML_OP_WIN_UNPART:
16368
- case GGML_OP_MAP_UNARY:
16369
- case GGML_OP_MAP_BINARY:
16370
- case GGML_OP_MAP_CUSTOM1:
16371
- case GGML_OP_MAP_CUSTOM2:
16372
- case GGML_OP_MAP_CUSTOM3:
16373
- {
16374
- node->n_tasks = 1;
16375
- } break;
16376
- case GGML_OP_CROSS_ENTROPY_LOSS:
16377
- {
16378
- node->n_tasks = n_threads;
16379
-
16380
- size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
16381
-
16382
- work_size = MAX(work_size, cur);
16383
- } break;
16384
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
16385
- {
16386
- node->n_tasks = n_threads;
16387
-
16388
- size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
16389
-
16390
- work_size = MAX(work_size, cur);
16391
- } break;
16392
- case GGML_OP_NONE:
16393
- {
16394
- node->n_tasks = 1;
16395
- } break;
16396
- case GGML_OP_COUNT:
16397
- {
16398
- GGML_ASSERT(false);
16399
- } break;
16400
- }
16401
- }
16676
+ size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
16677
+
16678
+ work_size = MAX(work_size, cur);
16679
+ } break;
16680
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
16681
+ {
16682
+ n_tasks = n_threads;
16683
+
16684
+ size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;
16402
16685
 
16403
- if (cgraph->work != NULL && work_size > cgraph->work_size) {
16404
- GGML_ASSERT(false); // TODO: better handling
16686
+ work_size = MAX(work_size, cur);
16687
+ } break;
16688
+ case GGML_OP_NONE:
16689
+ {
16690
+ n_tasks = 1;
16691
+ } break;
16692
+ case GGML_OP_COUNT:
16693
+ {
16694
+ GGML_ASSERT(false);
16695
+ } break;
16405
16696
  }
16406
16697
 
16407
- if (work_size > 0 && cgraph->work == NULL) {
16408
- cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
16698
+ cplan.n_tasks[i] = n_tasks;
16699
+ }
16700
+
16701
+ if (work_size > 0) {
16702
+ work_size += CACHE_LINE_SIZE*(n_threads - 1);
16703
+ }
16704
+
16705
+ cplan.n_threads = n_threads;
16706
+ cplan.work_size = work_size;
16707
+ cplan.work_data = NULL;
16708
+
16709
+ return cplan;
16710
+ }
16711
+
16712
+ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
16713
+ {
16714
+ GGML_ASSERT(cplan);
16715
+ GGML_ASSERT(cplan->n_threads > 0);
16409
16716
 
16410
- GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
16411
- cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
16717
+ if (cplan->work_size > 0) {
16718
+ GGML_ASSERT(cplan->work_data);
16719
+ }
16720
+
16721
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
16722
+ if (cgraph->nodes[i]->op != GGML_OP_NONE) {
16723
+ GGML_ASSERT(cplan->n_tasks[i] > 0);
16724
+ }
16412
16725
  }
16413
16726
  }
16414
16727
 
16728
+ const int n_threads = cplan->n_threads;
16729
+
16730
+ struct ggml_compute_state_shared state_shared = {
16731
+ /*.cgraph =*/ cgraph,
16732
+ /*.cgraph_plan =*/ cplan,
16733
+ /*.perf_node_start_cycles =*/ 0,
16734
+ /*.perf_node_start_time_us =*/ 0,
16735
+ /*.n_threads =*/ n_threads,
16736
+ /*.n_active =*/ n_threads,
16737
+ /*.node_n =*/ -1,
16738
+ /*.abort_callback =*/ NULL,
16739
+ /*.abort_callback_data =*/ NULL,
16740
+ };
16741
+ struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
16742
+
16415
16743
  // create thread pool
16416
16744
  if (n_threads > 1) {
16417
16745
  for (int j = 1; j < n_threads; ++j) {
@@ -16432,12 +16760,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
16432
16760
  const int64_t perf_start_time_us = ggml_perf_time_us();
16433
16761
 
16434
16762
  // this is a work thread too
16435
- ggml_graph_compute_thread(&workers[0]);
16763
+ int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
16436
16764
 
16437
16765
  // don't leave affinity set on the main thread
16438
16766
  clear_numa_thread_affinity();
16439
16767
 
16440
- // join thread pool
16768
+ // join or kill thread pool
16441
16769
  if (n_threads > 1) {
16442
16770
  for (int j = 1; j < n_threads; j++) {
16443
16771
  const int rc = ggml_thread_join(workers[j].thrd, NULL);
@@ -16461,6 +16789,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
16461
16789
  (double) perf_time_us_cur / 1000.0,
16462
16790
  (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
16463
16791
  }
16792
+
16793
+ return compute_status;
16464
16794
  }
16465
16795
 
16466
16796
  void ggml_graph_reset(struct ggml_cgraph * cgraph) {
@@ -16473,6 +16803,17 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
16473
16803
  }
16474
16804
  }
16475
16805
 
16806
+ void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
16807
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
16808
+
16809
+ struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
16810
+ GGML_ASSERT(buf);
16811
+
16812
+ cplan.work_data = buf->data;
16813
+
16814
+ ggml_graph_compute(cgraph, &cplan);
16815
+ }
16816
+
16476
16817
  struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
16477
16818
  for (int i = 0; i < cgraph->n_leafs; i++) {
16478
16819
  struct ggml_tensor * leaf = cgraph->leafs[i];
@@ -16511,14 +16852,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
16511
16852
  const int64_t * ne = tensor->ne;
16512
16853
  const size_t * nb = tensor->nb;
16513
16854
 
16514
- fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
16855
+ fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
16515
16856
  arg,
16516
16857
  ggml_type_name(tensor->type),
16517
16858
  ggml_op_name (tensor->op),
16518
16859
  tensor->n_dims,
16519
16860
  ne[0], ne[1], ne[2], ne[3],
16520
16861
  nb[0], nb[1], nb[2], nb[3],
16521
- tensor->n_tasks,
16522
16862
  tensor->data,
16523
16863
  tensor->name);
16524
16864
  }
@@ -16555,8 +16895,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
16555
16895
  ggml_graph_export_leaf(cgraph->leafs[i], fout);
16556
16896
 
16557
16897
  GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE);
16558
- GGML_ASSERT(cgraph->leafs[i]->src0 == NULL);
16559
- GGML_ASSERT(cgraph->leafs[i]->src1 == NULL);
16898
+ GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL);
16899
+ GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL);
16560
16900
  }
16561
16901
 
16562
16902
  // header
@@ -16567,17 +16907,9 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
16567
16907
  for (int i = 0; i < cgraph->n_nodes; ++i) {
16568
16908
  ggml_graph_export_node(cgraph->nodes[i], "DST", fout);
16569
16909
 
16570
- if (cgraph->nodes[i]->src0) {
16571
- ggml_graph_export_node(cgraph->nodes[i]->src0, "SRC0", fout);
16572
- }
16573
-
16574
- if (cgraph->nodes[i]->src1) {
16575
- ggml_graph_export_node(cgraph->nodes[i]->src1, "SRC1", fout);
16576
- }
16577
-
16578
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
16579
- if (cgraph->nodes[i]->opt[j]) {
16580
- ggml_graph_export_node(cgraph->nodes[i]->opt[j], "OPT", fout);
16910
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
16911
+ if (cgraph->nodes[i]->src[j]) {
16912
+ ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout);
16581
16913
  }
16582
16914
  }
16583
16915
 
@@ -16668,16 +17000,13 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
16668
17000
 
16669
17001
  // output the op arguments
16670
17002
  {
16671
- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
16672
-
16673
- args[0] = tensor->src0;
16674
- args[1] = tensor->src1;
17003
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };
16675
17004
 
16676
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
16677
- args[2 + j] = tensor->opt[j];
17005
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
17006
+ args[j] = tensor->src[j];
16678
17007
  }
16679
17008
 
16680
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
17009
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
16681
17010
  if (args[j]) {
16682
17011
  int32_t idx = -1;
16683
17012
 
@@ -16895,12 +17224,12 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
16895
17224
 
16896
17225
  const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
16897
17226
 
16898
- const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
17227
+ const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t);
16899
17228
 
16900
- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
17229
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };
16901
17230
 
16902
17231
  // parse args
16903
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
17232
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
16904
17233
  const int32_t arg_idx = ptr_arg_idx[j];
16905
17234
 
16906
17235
  if (arg_idx == -1) {
@@ -16957,11 +17286,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
16957
17286
  tensor->nb[j] = nb[j];
16958
17287
  }
16959
17288
 
16960
- tensor->src0 = args[0];
16961
- tensor->src1 = args[1];
16962
-
16963
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
16964
- tensor->opt[j] = args[2 + j];
17289
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
17290
+ tensor->src[j] = args[j];
16965
17291
  }
16966
17292
 
16967
17293
  result.nodes[i] = tensor;
@@ -17160,19 +17486,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
17160
17486
  for (int i = 0; i < gb->n_nodes; i++) {
17161
17487
  struct ggml_tensor * node = gb->nodes[i];
17162
17488
 
17163
- if (node->src0) {
17164
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src0, "x");
17165
- }
17166
-
17167
- if (node->src1) {
17168
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src1, "y");
17169
- }
17170
-
17171
- for (int j = 0; j < GGML_MAX_OPT; j++) {
17172
- if (node->opt[j]) {
17489
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
17490
+ if (node->src[j]) {
17173
17491
  char label[16];
17174
- snprintf(label, sizeof(label), "opt %d", j);
17175
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->opt[j], label);
17492
+ snprintf(label, sizeof(label), "src %d", j);
17493
+ ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
17176
17494
  }
17177
17495
  }
17178
17496
  }
@@ -17180,19 +17498,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
17180
17498
  for (int i = 0; i < gb->n_leafs; i++) {
17181
17499
  struct ggml_tensor * node = gb->leafs[i];
17182
17500
 
17183
- if (node->src0) {
17184
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src0, "x");
17185
- }
17186
-
17187
- if (node->src1) {
17188
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src1, "y");
17189
- }
17190
-
17191
- for (int j = 0; j < GGML_MAX_OPT; j++) {
17192
- if (node->opt[j]) {
17501
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
17502
+ if (node->src[j]) {
17193
17503
  char label[16];
17194
- snprintf(label, sizeof(label), "opt %d", j);
17195
- ggml_graph_dump_dot_leaf_edge(fp, node, node->opt[j], label);
17504
+ snprintf(label, sizeof(label), "src %d", j);
17505
+ ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
17196
17506
  }
17197
17507
  }
17198
17508
  }
@@ -17254,9 +17564,6 @@ static enum ggml_opt_result ggml_opt_adam(
17254
17564
  struct ggml_cgraph * gb) {
17255
17565
  GGML_ASSERT(ggml_is_scalar(f));
17256
17566
 
17257
- gf->n_threads = params.n_threads;
17258
- gb->n_threads = params.n_threads;
17259
-
17260
17567
  // these will store the parameters we want to optimize
17261
17568
  struct ggml_tensor * ps[GGML_MAX_PARAMS];
17262
17569
 
@@ -17303,7 +17610,8 @@ static enum ggml_opt_result ggml_opt_adam(
17303
17610
  // compute the function value
17304
17611
  ggml_graph_reset (gf);
17305
17612
  ggml_set_f32 (f->grad, 1.0f);
17306
- ggml_graph_compute(ctx, gb);
17613
+
17614
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
17307
17615
 
17308
17616
  opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
17309
17617
  opt->adam.fx_best = opt->adam.fx_prev;
@@ -17383,7 +17691,8 @@ static enum ggml_opt_result ggml_opt_adam(
17383
17691
 
17384
17692
  ggml_graph_reset (gf);
17385
17693
  ggml_set_f32 (f->grad, 1.0f);
17386
- ggml_graph_compute(ctx, gb);
17694
+
17695
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
17387
17696
 
17388
17697
  const float fx = ggml_get_f32_1d(f, 0);
17389
17698
 
@@ -17505,7 +17814,8 @@ static enum ggml_opt_result linesearch_backtracking(
17505
17814
 
17506
17815
  ggml_graph_reset (gf);
17507
17816
  ggml_set_f32 (f->grad, 1.0f);
17508
- ggml_graph_compute(ctx, gb);
17817
+
17818
+ ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
17509
17819
 
17510
17820
  ggml_opt_get_grad(np, ps, g);
17511
17821
 
@@ -17573,9 +17883,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
17573
17883
  }
17574
17884
  }
17575
17885
 
17576
- gf->n_threads = params.n_threads;
17577
- gb->n_threads = params.n_threads;
17578
-
17579
17886
  const int m = params.lbfgs.m;
17580
17887
 
17581
17888
  // these will store the parameters we want to optimize
@@ -17627,7 +17934,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
17627
17934
 
17628
17935
  ggml_graph_reset (gf);
17629
17936
  ggml_set_f32 (f->grad, 1.0f);
17630
- ggml_graph_compute(ctx, gb);
17937
+
17938
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
17631
17939
 
17632
17940
  ggml_opt_get_grad(np, ps, g);
17633
17941