llama_cpp 0.3.2 → 0.3.3: diff of the bundled ggml.c

@@ -25,6 +25,7 @@
  #include <float.h>
  #include <limits.h>
  #include <stdarg.h>
+ #include <signal.h>

  #ifdef GGML_USE_METAL
  #include <unistd.h>
@@ -49,23 +50,23 @@
  typedef volatile LONG atomic_int;
  typedef atomic_int atomic_bool;

- static void atomic_store(atomic_int* ptr, LONG val) {
+ static void atomic_store(atomic_int * ptr, LONG val) {
      InterlockedExchange(ptr, val);
  }
- static LONG atomic_load(atomic_int* ptr) {
+ static LONG atomic_load(atomic_int * ptr) {
      return InterlockedCompareExchange(ptr, 0, 0);
  }
- static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
      return InterlockedExchangeAdd(ptr, inc);
  }
- static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+ static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
      return atomic_fetch_add(ptr, -(dec));
  }

  typedef HANDLE pthread_t;

  typedef DWORD thread_ret_t;
- static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
      (void) unused;
      HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
      if (handle == NULL)
@@ -77,7 +78,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
      return 0;
  }

- static int pthread_join(pthread_t thread, void* unused) {
+ static int pthread_join(pthread_t thread, void * unused) {
      (void) unused;
      return (int) WaitForSingleObject(thread, INFINITE);
  }
@@ -90,7 +91,7 @@ static int sched_yield (void) {
  #include <pthread.h>
  #include <stdatomic.h>

- typedef void* thread_ret_t;
+ typedef void * thread_ret_t;

  #include <sys/types.h>
  #include <sys/stat.h>
@@ -247,7 +248,11 @@ inline static void* ggml_aligned_malloc(size_t size) {
  #include "ggml-opencl.h"
  #endif
  #elif defined(GGML_USE_OPENBLAS)
+ #if defined(GGML_BLAS_USE_MKL)
+ #include <mkl.h>
+ #else
  #include <cblas.h>
+ #endif
  #elif defined(GGML_USE_CUBLAS)
  #include "ggml-cuda.h"
  #elif defined(GGML_USE_CLBLAST)
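The new guard lets an MKL build reuse the existing OpenBLAS code path by swapping only the header. A minimal sketch of the selection, assuming both flags are passed at compile time (the pragma messages are illustrative, not part of the source):

  #if defined(GGML_USE_OPENBLAS)
  #  if defined(GGML_BLAS_USE_MKL)
  #    pragma message("ggml: BLAS path uses Intel MKL (<mkl.h>)")
  #  else
  #    pragma message("ggml: BLAS path uses OpenBLAS / generic CBLAS (<cblas.h>)")
  #  endif
  #endif

MKL exposes a CBLAS-compatible interface, so the existing cblas_* calls in ggml should resolve against it without further source changes.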
@@ -3782,6 +3787,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "CLAMP",
      "CONV_1D",
      "CONV_2D",
+     "POOL_1D",
+     "POOL_2D",

      "FLASH_ATTN",
      "FLASH_FF",
@@ -3800,7 +3807,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
      "CROSS_ENTROPY_LOSS_BACK",
  };

- static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");

  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "none",
@@ -3860,6 +3867,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "clamp(x)",
      "conv_1d(x)",
      "conv_2d(x)",
+     "pool_1d(x)",
+     "pool_2d(x)",

      "flash_attn(x)",
      "flash_ff(x)",
@@ -3878,7 +3887,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
      "cross_entropy_loss_back(x,y)",
  };

- static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+
+ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4157,10 +4168,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
  static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
      static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

-     return
-         (t0->ne[0] == t1->ne[0]) &&
-         (t0->ne[2] == t1->ne[2]) &&
-         (t0->ne[3] == t1->ne[3]);
+     return (t0->ne[0] == t1->ne[0])   &&
+            (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+            (t1->ne[3]%t0->ne[3] == 0);
  }

  static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
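The relaxed check above is what allows the first operand to be broadcast along the two outer dimensions of the second. A small standalone sketch of the rule with concrete shapes (names and sizes invented for illustration):

  #include <stdbool.h>
  #include <stdint.h>

  // Same condition as the patched ggml_can_mul_mat, written for two shape arrays:
  // ne0[] belongs to the first operand t0, ne1[] to the second operand t1.
  static bool can_mul_mat(const int64_t ne0[4], const int64_t ne1[4]) {
      return (ne0[0] == ne1[0])     &&  // shared inner (reduction) dimension
             (ne1[2] % ne0[2] == 0) &&  // t0 repeats a whole number of times along dim 2
             (ne1[3] % ne0[3] == 0);    // ... and along dim 3
  }

  // Example: t0 = {64, 32, 1, 1} (one K/V head) against t1 = {64, 32, 16, 1}
  // (16 query heads) is now accepted; the old equality check on dims 2 and 3 rejected it.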
@@ -4580,17 +4590,14 @@ struct ggml_tensor * ggml_new_tensor_impl(
          /*.op           =*/ GGML_OP_NONE,
          /*.is_param     =*/ false,
          /*.grad         =*/ NULL,
-         /*.src0         =*/ NULL,
-         /*.src1         =*/ NULL,
-         /*.opt          =*/ { NULL },
-         /*.n_tasks      =*/ 0,
+         /*.src          =*/ { NULL },
          /*.perf_runs    =*/ 0,
          /*.perf_cycles  =*/ 0,
          /*.perf_time_us =*/ 0,
          /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
          /*.name         =*/ { 0 },
          /*.extra        =*/ NULL,
-         /*.pad          =*/ { 0 },
+         /*.padding      =*/ { 0 },
      };

      // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
@@ -4722,7 +4729,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
          {
              assert(tensor->nb[0] == sizeof(ggml_fp16_t));
              for (int i = 0; i < n; i++) {
-                 ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+                 ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
              }
          } break;
      case GGML_TYPE_F32:
@@ -4774,7 +4781,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
          {
              assert(tensor->nb[0] == sizeof(ggml_fp16_t));
              for (int i = 0; i < n; i++) {
-                 ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+                 ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
              }
          } break;
      case GGML_TYPE_F32:
@@ -5009,8 +5016,8 @@ struct ggml_tensor * ggml_dup_impl(

      result->op   = GGML_OP_DUP;
      result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-     result->src0 = a;
-     result->src1 = NULL;
+     result->src[0] = a;
+     result->src[1] = NULL;

      return result;
  }
@@ -5034,11 +5041,15 @@ struct ggml_tensor * ggml_add_impl(
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          bool inplace) {
-     GGML_ASSERT(ggml_are_same_shape(a, b));
+     // TODO: support less-strict constraint
+     //       GGML_ASSERT(ggml_can_repeat(b, a));
+     GGML_ASSERT(ggml_can_repeat_rows(b, a));

      bool is_node = false;

-     if (a->grad || b->grad) {
+     if (!inplace && (a->grad || b->grad)) {
+         // TODO: support backward pass for broadcasting
+         GGML_ASSERT(ggml_are_same_shape(a, b));
          is_node = true;
      }

@@ -5046,8 +5057,8 @@ struct ggml_tensor * ggml_add_impl(

      result->op   = GGML_OP_ADD;
      result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-     result->src0 = a;
-     result->src1 = b;
+     result->src[0] = a;
+     result->src[1] = b;

      return result;
  }
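With ggml_can_repeat_rows in place of the strict same-shape assert, ggml_add now accepts a second operand that is repeated row-wise over the first. A hedged usage sketch (ctx, n_embd and n_tokens are assumed to exist in the caller; they are not from the diff):

  // bias: one row of n_embd floats; cur: n_tokens rows of n_embd floats.
  struct ggml_tensor * bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
  struct ggml_tensor * cur  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
  // The single bias row is added to every row of cur; the result has cur's shape.
  struct ggml_tensor * out  = ggml_add(ctx, cur, bias);
  // Note: per the assert added above, the backward pass still requires equal shapes.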
@@ -5086,8 +5097,8 @@ struct ggml_tensor * ggml_add1_impl(
5086
5097
 
5087
5098
  result->op = GGML_OP_ADD1;
5088
5099
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5089
- result->src0 = a;
5090
- result->src1 = b;
5100
+ result->src[0] = a;
5101
+ result->src[1] = b;
5091
5102
 
5092
5103
  return result;
5093
5104
  }
@@ -5144,9 +5155,9 @@ struct ggml_tensor * ggml_acc_impl(
5144
5155
 
5145
5156
  result->op = GGML_OP_ACC;
5146
5157
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5147
- result->src0 = a;
5148
- result->src1 = b;
5149
- result->opt[0] = c;
5158
+ result->src[0] = a;
5159
+ result->src[1] = b;
5160
+ result->src[2] = c;
5150
5161
 
5151
5162
  return result;
5152
5163
  }
@@ -5192,8 +5203,8 @@ struct ggml_tensor * ggml_sub_impl(
5192
5203
 
5193
5204
  result->op = GGML_OP_SUB;
5194
5205
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5195
- result->src0 = a;
5196
- result->src1 = b;
5206
+ result->src[0] = a;
5207
+ result->src[1] = b;
5197
5208
 
5198
5209
  return result;
5199
5210
  }
@@ -5239,8 +5250,8 @@ struct ggml_tensor * ggml_mul_impl(
5239
5250
 
5240
5251
  result->op = GGML_OP_MUL;
5241
5252
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5242
- result->src0 = a;
5243
- result->src1 = b;
5253
+ result->src[0] = a;
5254
+ result->src[1] = b;
5244
5255
 
5245
5256
  return result;
5246
5257
  }
@@ -5282,8 +5293,8 @@ struct ggml_tensor * ggml_div_impl(
5282
5293
 
5283
5294
  result->op = GGML_OP_DIV;
5284
5295
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5285
- result->src0 = a;
5286
- result->src1 = b;
5296
+ result->src[0] = a;
5297
+ result->src[1] = b;
5287
5298
 
5288
5299
  return result;
5289
5300
  }
@@ -5318,8 +5329,8 @@ struct ggml_tensor * ggml_sqr_impl(
5318
5329
 
5319
5330
  result->op = GGML_OP_SQR;
5320
5331
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5321
- result->src0 = a;
5322
- result->src1 = NULL;
5332
+ result->src[0] = a;
5333
+ result->src[1] = NULL;
5323
5334
 
5324
5335
  return result;
5325
5336
  }
@@ -5352,8 +5363,8 @@ struct ggml_tensor * ggml_sqrt_impl(
5352
5363
 
5353
5364
  result->op = GGML_OP_SQRT;
5354
5365
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5355
- result->src0 = a;
5356
- result->src1 = NULL;
5366
+ result->src[0] = a;
5367
+ result->src[1] = NULL;
5357
5368
 
5358
5369
  return result;
5359
5370
  }
@@ -5387,8 +5398,8 @@ struct ggml_tensor * ggml_log_impl(
5387
5398
 
5388
5399
  result->op = GGML_OP_LOG;
5389
5400
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5390
- result->src0 = a;
5391
- result->src1 = NULL;
5401
+ result->src[0] = a;
5402
+ result->src[1] = NULL;
5392
5403
 
5393
5404
  return result;
5394
5405
  }
@@ -5420,8 +5431,8 @@ struct ggml_tensor * ggml_sum(
5420
5431
 
5421
5432
  result->op = GGML_OP_SUM;
5422
5433
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5423
- result->src0 = a;
5424
- result->src1 = NULL;
5434
+ result->src[0] = a;
5435
+ result->src[1] = NULL;
5425
5436
 
5426
5437
  return result;
5427
5438
  }
@@ -5447,8 +5458,8 @@ struct ggml_tensor * ggml_sum_rows(
5447
5458
 
5448
5459
  result->op = GGML_OP_SUM_ROWS;
5449
5460
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5450
- result->src0 = a;
5451
- result->src1 = NULL;
5461
+ result->src[0] = a;
5462
+ result->src[1] = NULL;
5452
5463
 
5453
5464
  return result;
5454
5465
  }
@@ -5470,8 +5481,8 @@ struct ggml_tensor * ggml_mean(
5470
5481
 
5471
5482
  result->op = GGML_OP_MEAN;
5472
5483
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5473
- result->src0 = a;
5474
- result->src1 = NULL;
5484
+ result->src[0] = a;
5485
+ result->src[1] = NULL;
5475
5486
 
5476
5487
  return result;
5477
5488
  }
@@ -5494,8 +5505,8 @@ struct ggml_tensor * ggml_argmax(
5494
5505
 
5495
5506
  result->op = GGML_OP_ARGMAX;
5496
5507
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5497
- result->src0 = a;
5498
- result->src1 = NULL;
5508
+ result->src[0] = a;
5509
+ result->src[1] = NULL;
5499
5510
 
5500
5511
  return result;
5501
5512
  }
@@ -5522,8 +5533,8 @@ struct ggml_tensor * ggml_repeat(
5522
5533
 
5523
5534
  result->op = GGML_OP_REPEAT;
5524
5535
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5525
- result->src0 = a;
5526
- result->src1 = b;
5536
+ result->src[0] = a;
5537
+ result->src[1] = b;
5527
5538
 
5528
5539
  return result;
5529
5540
  }
@@ -5550,8 +5561,8 @@ struct ggml_tensor * ggml_repeat_back(
5550
5561
 
5551
5562
  result->op = GGML_OP_REPEAT_BACK;
5552
5563
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5553
- result->src0 = a;
5554
- result->src1 = b;
5564
+ result->src[0] = a;
5565
+ result->src[1] = b;
5555
5566
 
5556
5567
  return result;
5557
5568
  }
@@ -5572,8 +5583,8 @@ struct ggml_tensor * ggml_abs_impl(
5572
5583
 
5573
5584
  result->op = GGML_OP_ABS;
5574
5585
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5575
- result->src0 = a;
5576
- result->src1 = NULL;
5586
+ result->src[0] = a;
5587
+ result->src[1] = NULL;
5577
5588
 
5578
5589
  return result;
5579
5590
  }
@@ -5607,8 +5618,8 @@ struct ggml_tensor * ggml_sgn_impl(
5607
5618
 
5608
5619
  result->op = GGML_OP_SGN;
5609
5620
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5610
- result->src0 = a;
5611
- result->src1 = NULL;
5621
+ result->src[0] = a;
5622
+ result->src[1] = NULL;
5612
5623
 
5613
5624
  return result;
5614
5625
  }
@@ -5641,8 +5652,8 @@ struct ggml_tensor * ggml_neg_impl(
5641
5652
 
5642
5653
  result->op = GGML_OP_NEG;
5643
5654
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5644
- result->src0 = a;
5645
- result->src1 = NULL;
5655
+ result->src[0] = a;
5656
+ result->src[1] = NULL;
5646
5657
 
5647
5658
  return result;
5648
5659
  }
@@ -5675,8 +5686,8 @@ struct ggml_tensor * ggml_step_impl(
5675
5686
 
5676
5687
  result->op = GGML_OP_STEP;
5677
5688
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5678
- result->src0 = a;
5679
- result->src1 = NULL;
5689
+ result->src[0] = a;
5690
+ result->src[1] = NULL;
5680
5691
 
5681
5692
  return result;
5682
5693
  }
@@ -5709,8 +5720,8 @@ struct ggml_tensor * ggml_tanh_impl(
5709
5720
 
5710
5721
  result->op = GGML_OP_TANH;
5711
5722
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5712
- result->src0 = a;
5713
- result->src1 = NULL;
5723
+ result->src[0] = a;
5724
+ result->src[1] = NULL;
5714
5725
 
5715
5726
  return result;
5716
5727
  }
@@ -5743,8 +5754,8 @@ struct ggml_tensor * ggml_elu_impl(
5743
5754
 
5744
5755
  result->op = GGML_OP_ELU;
5745
5756
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5746
- result->src0 = a;
5747
- result->src1 = NULL;
5757
+ result->src[0] = a;
5758
+ result->src[1] = NULL;
5748
5759
 
5749
5760
  return result;
5750
5761
  }
@@ -5777,8 +5788,8 @@ struct ggml_tensor * ggml_relu_impl(
5777
5788
 
5778
5789
  result->op = GGML_OP_RELU;
5779
5790
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5780
- result->src0 = a;
5781
- result->src1 = NULL;
5791
+ result->src[0] = a;
5792
+ result->src[1] = NULL;
5782
5793
 
5783
5794
  return result;
5784
5795
  }
@@ -5811,8 +5822,8 @@ struct ggml_tensor * ggml_gelu_impl(
5811
5822
 
5812
5823
  result->op = GGML_OP_GELU;
5813
5824
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5814
- result->src0 = a;
5815
- result->src1 = NULL;
5825
+ result->src[0] = a;
5826
+ result->src[1] = NULL;
5816
5827
 
5817
5828
  return result;
5818
5829
  }
@@ -5845,8 +5856,8 @@ struct ggml_tensor * ggml_gelu_quick_impl(
5845
5856
 
5846
5857
  result->op = GGML_OP_GELU_QUICK;
5847
5858
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5848
- result->src0 = a;
5849
- result->src1 = NULL;
5859
+ result->src[0] = a;
5860
+ result->src[1] = NULL;
5850
5861
 
5851
5862
  return result;
5852
5863
  }
@@ -5879,8 +5890,8 @@ struct ggml_tensor * ggml_silu_impl(
5879
5890
 
5880
5891
  result->op = GGML_OP_SILU;
5881
5892
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5882
- result->src0 = a;
5883
- result->src1 = NULL;
5893
+ result->src[0] = a;
5894
+ result->src[1] = NULL;
5884
5895
 
5885
5896
  return result;
5886
5897
  }
@@ -5914,8 +5925,8 @@ struct ggml_tensor * ggml_silu_back(
5914
5925
 
5915
5926
  result->op = GGML_OP_SILU_BACK;
5916
5927
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5917
- result->src0 = a;
5918
- result->src1 = b;
5928
+ result->src[0] = a;
5929
+ result->src[1] = b;
5919
5930
 
5920
5931
  return result;
5921
5932
  }
@@ -5937,8 +5948,8 @@ struct ggml_tensor * ggml_norm_impl(
5937
5948
 
5938
5949
  result->op = GGML_OP_NORM;
5939
5950
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5940
- result->src0 = a;
5941
- result->src1 = NULL; // TODO: maybe store epsilon here?
5951
+ result->src[0] = a;
5952
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
5942
5953
 
5943
5954
  return result;
5944
5955
  }
@@ -5969,8 +5980,8 @@ struct ggml_tensor * ggml_rms_norm_impl(
5969
5980
 
5970
5981
  result->op = GGML_OP_RMS_NORM;
5971
5982
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5972
- result->src0 = a;
5973
- result->src1 = NULL; // TODO: maybe store epsilon here?
5983
+ result->src[0] = a;
5984
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
5974
5985
 
5975
5986
  return result;
5976
5987
  }
@@ -6002,8 +6013,8 @@ struct ggml_tensor * ggml_rms_norm_back(
6002
6013
 
6003
6014
  result->op = GGML_OP_RMS_NORM_BACK;
6004
6015
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6005
- result->src0 = a;
6006
- result->src1 = b;
6016
+ result->src[0] = a;
6017
+ result->src[1] = b;
6007
6018
 
6008
6019
  return result;
6009
6020
  }
@@ -6024,13 +6035,13 @@ struct ggml_tensor * ggml_mul_mat(
          is_node = true;
      }

-     const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
-     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+     const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);

      result->op   = GGML_OP_MUL_MAT;
      result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-     result->src0 = a;
-     result->src1 = b;
+     result->src[0] = a;
+     result->src[1] = b;

      return result;
  }
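Because the result now takes its outer dimensions from b and the broadcast check only requires divisibility, a 2-D weight can be applied to a batched activation in one call. A rough shape illustration (ctx and the sizes are assumed for the example):

  // a (weight): [n_in, n_out, 1, 1]    b (input): [n_in, n_tok, n_batch, 1]
  struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_in, n_out);
  struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_in, n_tok, n_batch);
  struct ggml_tensor * y = ggml_mul_mat(ctx, w, x);
  // y->ne == { n_out, n_tok, n_batch, 1 }; previously dim 2 was taken from a (i.e. 1).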
@@ -6055,8 +6066,8 @@ struct ggml_tensor * ggml_out_prod(
6055
6066
 
6056
6067
  result->op = GGML_OP_OUT_PROD;
6057
6068
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6058
- result->src0 = a;
6059
- result->src1 = b;
6069
+ result->src[0] = a;
6070
+ result->src[1] = b;
6060
6071
 
6061
6072
  return result;
6062
6073
  }
@@ -6081,8 +6092,8 @@ struct ggml_tensor * ggml_scale_impl(
6081
6092
 
6082
6093
  result->op = GGML_OP_SCALE;
6083
6094
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6084
- result->src0 = a;
6085
- result->src1 = b;
6095
+ result->src[0] = a;
6096
+ result->src[1] = b;
6086
6097
 
6087
6098
  return result;
6088
6099
  }
@@ -6137,9 +6148,9 @@ struct ggml_tensor * ggml_set_impl(
6137
6148
 
6138
6149
  result->op = GGML_OP_SET;
6139
6150
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6140
- result->src0 = a;
6141
- result->src1 = b;
6142
- result->opt[0] = c;
6151
+ result->src[0] = a;
6152
+ result->src[1] = b;
6153
+ result->src[2] = c;
6143
6154
 
6144
6155
  return result;
6145
6156
  }
@@ -6226,8 +6237,8 @@ struct ggml_tensor * ggml_cpy_impl(
6226
6237
 
6227
6238
  result->op = GGML_OP_CPY;
6228
6239
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6229
- result->src0 = a;
6230
- result->src1 = b;
6240
+ result->src[0] = a;
6241
+ result->src[1] = b;
6231
6242
 
6232
6243
  return result;
6233
6244
  }
@@ -6263,8 +6274,8 @@ struct ggml_tensor * ggml_cont_impl(
6263
6274
 
6264
6275
  result->op = GGML_OP_CONT;
6265
6276
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6266
- result->src0 = a;
6267
- result->src1 = NULL;
6277
+ result->src[0] = a;
6278
+ result->src[1] = NULL;
6268
6279
 
6269
6280
  return result;
6270
6281
  }
@@ -6307,8 +6318,8 @@ struct ggml_tensor * ggml_reshape(
6307
6318
 
6308
6319
  result->op = GGML_OP_RESHAPE;
6309
6320
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6310
- result->src0 = a;
6311
- result->src1 = NULL;
6321
+ result->src[0] = a;
6322
+ result->src[1] = NULL;
6312
6323
 
6313
6324
  return result;
6314
6325
  }
@@ -6332,8 +6343,8 @@ struct ggml_tensor * ggml_reshape_1d(
6332
6343
 
6333
6344
  result->op = GGML_OP_RESHAPE;
6334
6345
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6335
- result->src0 = a;
6336
- result->src1 = NULL;
6346
+ result->src[0] = a;
6347
+ result->src[1] = NULL;
6337
6348
 
6338
6349
  return result;
6339
6350
  }
@@ -6358,8 +6369,8 @@ struct ggml_tensor * ggml_reshape_2d(
6358
6369
 
6359
6370
  result->op = GGML_OP_RESHAPE;
6360
6371
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6361
- result->src0 = a;
6362
- result->src1 = NULL;
6372
+ result->src[0] = a;
6373
+ result->src[1] = NULL;
6363
6374
 
6364
6375
  return result;
6365
6376
  }
@@ -6385,8 +6396,8 @@ struct ggml_tensor * ggml_reshape_3d(
6385
6396
 
6386
6397
  result->op = GGML_OP_RESHAPE;
6387
6398
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6388
- result->src0 = a;
6389
- result->src1 = NULL;
6399
+ result->src[0] = a;
6400
+ result->src[1] = NULL;
6390
6401
 
6391
6402
  return result;
6392
6403
  }
@@ -6414,8 +6425,8 @@ struct ggml_tensor * ggml_reshape_4d(
6414
6425
 
6415
6426
  result->op = GGML_OP_RESHAPE;
6416
6427
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6417
- result->src0 = a;
6418
- result->src1 = NULL;
6428
+ result->src[0] = a;
6429
+ result->src[1] = NULL;
6419
6430
 
6420
6431
  return result;
6421
6432
  }
@@ -6447,9 +6458,9 @@ struct ggml_tensor * ggml_view_1d(
6447
6458
 
6448
6459
  result->op = GGML_OP_VIEW;
6449
6460
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6450
- result->src0 = a;
6451
- result->src1 = NULL;
6452
- result->opt[0] = offs;
6461
+ result->src[0] = a;
6462
+ result->src[1] = NULL;
6463
+ result->src[2] = offs;
6453
6464
 
6454
6465
  return result;
6455
6466
  }
@@ -6489,9 +6500,9 @@ struct ggml_tensor * ggml_view_2d(
6489
6500
 
6490
6501
  result->op = GGML_OP_VIEW;
6491
6502
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6492
- result->src0 = a;
6493
- result->src1 = NULL;
6494
- result->opt[0] = offs;
6503
+ result->src[0] = a;
6504
+ result->src[1] = NULL;
6505
+ result->src[2] = offs;
6495
6506
 
6496
6507
  return result;
6497
6508
  }
@@ -6533,9 +6544,9 @@ struct ggml_tensor * ggml_view_3d(
6533
6544
 
6534
6545
  result->op = GGML_OP_VIEW;
6535
6546
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6536
- result->src0 = a;
6537
- result->src1 = NULL;
6538
- result->opt[0] = offs;
6547
+ result->src[0] = a;
6548
+ result->src[1] = NULL;
6549
+ result->src[2] = offs;
6539
6550
 
6540
6551
  return result;
6541
6552
  }
@@ -6579,9 +6590,9 @@ struct ggml_tensor * ggml_view_4d(
6579
6590
 
6580
6591
  result->op = GGML_OP_VIEW;
6581
6592
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6582
- result->src0 = a;
6583
- result->src1 = NULL;
6584
- result->opt[0] = offs;
6593
+ result->src[0] = a;
6594
+ result->src[1] = NULL;
6595
+ result->src[2] = offs;
6585
6596
 
6586
6597
  return result;
6587
6598
  }
@@ -6641,8 +6652,8 @@ struct ggml_tensor * ggml_permute(
6641
6652
 
6642
6653
  result->op = GGML_OP_PERMUTE;
6643
6654
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6644
- result->src0 = a;
6645
- result->src1 = NULL;
6655
+ result->src[0] = a;
6656
+ result->src[1] = NULL;
6646
6657
 
6647
6658
  if (is_node) {
6648
6659
  ggml_scratch_save(ctx);
@@ -6656,7 +6667,7 @@ struct ggml_tensor * ggml_permute(
6656
6667
 
6657
6668
  ggml_scratch_load(ctx);
6658
6669
 
6659
- result->opt[0] = b;
6670
+ result->src[2] = b;
6660
6671
  }
6661
6672
 
6662
6673
  return result;
@@ -6684,8 +6695,8 @@ struct ggml_tensor * ggml_transpose(
6684
6695
 
6685
6696
  result->op = GGML_OP_TRANSPOSE;
6686
6697
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6687
- result->src0 = a;
6688
- result->src1 = NULL;
6698
+ result->src[0] = a;
6699
+ result->src[1] = NULL;
6689
6700
 
6690
6701
  return result;
6691
6702
  }
@@ -6710,8 +6721,8 @@ struct ggml_tensor * ggml_get_rows(
6710
6721
 
6711
6722
  result->op = GGML_OP_GET_ROWS;
6712
6723
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6713
- result->src0 = a;
6714
- result->src1 = b;
6724
+ result->src[0] = a;
6725
+ result->src[1] = b;
6715
6726
 
6716
6727
  return result;
6717
6728
  }
@@ -6738,9 +6749,9 @@ struct ggml_tensor * ggml_get_rows_back(
6738
6749
 
6739
6750
  result->op = GGML_OP_GET_ROWS_BACK;
6740
6751
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6741
- result->src0 = a;
6742
- result->src1 = b;
6743
- result->opt[0] = c;
6752
+ result->src[0] = a;
6753
+ result->src[1] = b;
6754
+ result->src[2] = c;
6744
6755
 
6745
6756
  return result;
6746
6757
  }
@@ -6762,8 +6773,8 @@ struct ggml_tensor * ggml_diag(
6762
6773
 
6763
6774
  result->op = GGML_OP_DIAG;
6764
6775
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6765
- result->src0 = a;
6766
- result->src1 = NULL;
6776
+ result->src[0] = a;
6777
+ result->src[1] = NULL;
6767
6778
 
6768
6779
  return result;
6769
6780
  }
@@ -6795,8 +6806,8 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
6795
6806
 
6796
6807
  result->op = GGML_OP_DIAG_MASK_INF;
6797
6808
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6798
- result->src0 = a;
6799
- result->src1 = b;
6809
+ result->src[0] = a;
6810
+ result->src[1] = b;
6800
6811
 
6801
6812
  return result;
6802
6813
  }
@@ -6843,8 +6854,8 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
6843
6854
 
6844
6855
  result->op = GGML_OP_DIAG_MASK_ZERO;
6845
6856
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6846
- result->src0 = a;
6847
- result->src1 = b;
6857
+ result->src[0] = a;
6858
+ result->src[1] = b;
6848
6859
 
6849
6860
  return result;
6850
6861
  }
@@ -6879,8 +6890,8 @@ struct ggml_tensor * ggml_soft_max_impl(
6879
6890
 
6880
6891
  result->op = GGML_OP_SOFT_MAX;
6881
6892
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6882
- result->src0 = a;
6883
- result->src1 = NULL;
6893
+ result->src[0] = a;
6894
+ result->src[1] = NULL;
6884
6895
 
6885
6896
  return result;
6886
6897
  }
@@ -6915,8 +6926,8 @@ struct ggml_tensor * ggml_soft_max_back_impl(
6915
6926
 
6916
6927
  result->op = GGML_OP_SOFT_MAX_BACK;
6917
6928
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6918
- result->src0 = a;
6919
- result->src1 = b;
6929
+ result->src[0] = a;
6930
+ result->src[1] = b;
6920
6931
 
6921
6932
  return result;
6922
6933
  }
@@ -6967,8 +6978,8 @@ struct ggml_tensor * ggml_rope_impl(
6967
6978
 
6968
6979
  result->op = GGML_OP_ROPE;
6969
6980
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6970
- result->src0 = a;
6971
- result->src1 = b;
6981
+ result->src[0] = a;
6982
+ result->src[1] = b;
6972
6983
 
6973
6984
  return result;
6974
6985
  }
@@ -7025,8 +7036,8 @@ struct ggml_tensor * ggml_rope_back(
7025
7036
 
7026
7037
  result->op = GGML_OP_ROPE_BACK;
7027
7038
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7028
- result->src0 = a;
7029
- result->src1 = b;
7039
+ result->src[0] = a;
7040
+ result->src[1] = b;
7030
7041
 
7031
7042
  return result;
7032
7043
  }
@@ -7064,8 +7075,8 @@ struct ggml_tensor * ggml_alibi(
7064
7075
 
7065
7076
  result->op = GGML_OP_ALIBI;
7066
7077
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7067
- result->src0 = a;
7068
- result->src1 = b;
7078
+ result->src[0] = a;
7079
+ result->src[1] = b;
7069
7080
 
7070
7081
  return result;
7071
7082
  }
@@ -7098,8 +7109,8 @@ struct ggml_tensor * ggml_clamp(
7098
7109
 
7099
7110
  result->op = GGML_OP_CLAMP;
7100
7111
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7101
- result->src0 = a;
7102
- result->src1 = b;
7112
+ result->src[0] = a;
7113
+ result->src[1] = b;
7103
7114
 
7104
7115
  return result;
7105
7116
  }
@@ -7141,9 +7152,9 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7141
7152
 
7142
7153
  result->op = GGML_OP_CONV_1D;
7143
7154
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7144
- result->src0 = a;
7145
- result->src1 = b;
7146
- result->opt[0] = c;
7155
+ result->src[0] = a;
7156
+ result->src[1] = b;
7157
+ result->src[2] = c;
7147
7158
 
7148
7159
  return result;
7149
7160
  }
@@ -7161,7 +7172,6 @@ struct ggml_tensor* ggml_conv_2d(
      int d0,
      int d1) {

-     GGML_ASSERT(b->ne[3] == 1);
      GGML_ASSERT(a->ne[2] == b->ne[2]);
      bool is_node = false;

@@ -7173,7 +7183,7 @@ struct ggml_tensor* ggml_conv_2d(
      const int64_t ne[4] = {
          ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
          ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
-         a->ne[3], 1,
+         a->ne[3], b->ne[3],
      };
      struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);

@@ -7189,9 +7199,9 @@ struct ggml_tensor* ggml_conv_2d(

      result->op = GGML_OP_CONV_2D;
      result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-     result->src0 = a;
-     result->src1 = b;
-     result->opt[0] = c;
+     result->src[0] = a;
+     result->src[1] = b;
+     result->src[2] = c;

      return result;

@@ -7208,6 +7218,98 @@ struct ggml_tensor* ggml_conv_1d_ph(
      return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
  }

+
+ // ggml_pool_*
+
+ static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
+     return (ins + 2 * p - ks) / s + 1;
+ }
+
+ // ggml_pool_2d
+
+ struct ggml_tensor* ggml_pool_1d(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         enum ggml_op_pool op,
+         int k0,
+         int s0,
+         int p0) {
+
+     bool is_node = false;
+
+     if (a->grad) {
+         GGML_ASSERT(false); // TODO: implement backward
+         is_node = true;
+     }
+
+     const int64_t ne[3] = {
+         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
+         a->ne[1],
+     };
+     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
+
+     ggml_scratch_save(ctx);
+     struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+     ((int32_t*)c->data)[0] = op;
+     ((int32_t*)c->data)[1] = k0;
+     ((int32_t*)c->data)[2] = s0;
+     ((int32_t*)c->data)[3] = p0;
+     ggml_scratch_load(ctx);
+
+     result->op = GGML_OP_POOL_1D;
+     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+     result->src[0] = a;
+     result->src[1] = c;
+
+     return result;
+ }
+
+ // ggml_pool_2d
+
+ struct ggml_tensor* ggml_pool_2d(
+         struct ggml_context * ctx,
+         struct ggml_tensor * a,
+         enum ggml_op_pool op,
+         int k0,
+         int k1,
+         int s0,
+         int s1,
+         int p0,
+         int p1) {
+
+     bool is_node = false;
+
+     if (a->grad) {
+         GGML_ASSERT(false); // TODO: implement backward
+         is_node = true;
+     }
+
+     const int64_t ne[3] = {
+         ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
+         ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
+         a->ne[2],
+     };
+     struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
+
+     ggml_scratch_save(ctx);
+     struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7);
+     ((int32_t*)c->data)[0] = op;
+     ((int32_t*)c->data)[1] = k0;
+     ((int32_t*)c->data)[2] = k1;
+     ((int32_t*)c->data)[3] = s0;
+     ((int32_t*)c->data)[4] = s1;
+     ((int32_t*)c->data)[5] = p0;
+     ((int32_t*)c->data)[6] = p1;
+     ggml_scratch_load(ctx);
+
+     result->op = GGML_OP_POOL_2D;
+     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+     result->src[0] = a;
+     result->src[1] = c;
+
+     return result;
+ }
+

  // ggml_flash_attn

  struct ggml_tensor * ggml_flash_attn(
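The new pooling constructors pack (op, kernel, stride, padding) into a small I32 tensor stored in src[1]; callers only pass plain int arguments. A minimal usage sketch, assuming an F32 input and the constraints enforced by the forward pass in this version (stride equal to kernel size, zero padding); ctx, W, H, C and T are assumed to exist:

  // 2x2 average pooling over a [W, H, C] tensor:
  struct ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, W, H, C);
  struct ggml_tensor * p = ggml_pool_2d(ctx, x, GGML_OP_POOL_AVG,
                                        /*k0*/ 2, /*k1*/ 2, /*s0*/ 2, /*s1*/ 2, /*p0*/ 0, /*p1*/ 0);
  // p->ne[0] == (W + 2*0 - 2)/2 + 1 == W/2 for even W, likewise for H; channels are kept.

  // 1-D max pooling of each row of a [T, C] tensor with kernel == stride == 3:
  struct ggml_tensor * seq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, T, C);
  struct ggml_tensor * q   = ggml_pool_1d(ctx, seq, GGML_OP_POOL_MAX, 3, 3, 0);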
@@ -7230,10 +7332,10 @@ struct ggml_tensor * ggml_flash_attn(
7230
7332
 
7231
7333
  result->op = GGML_OP_FLASH_ATTN;
7232
7334
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7233
- result->src0 = q;
7234
- result->src1 = k;
7235
- result->opt[0] = v;
7236
- result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
7335
+ result->src[0] = q;
7336
+ result->src[1] = k;
7337
+ result->src[2] = v;
7338
+ result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
7237
7339
 
7238
7340
  return result;
7239
7341
  }
@@ -7261,11 +7363,11 @@ struct ggml_tensor * ggml_flash_ff(
7261
7363
 
7262
7364
  result->op = GGML_OP_FLASH_FF;
7263
7365
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7264
- result->src0 = a;
7265
- result->src1 = b0;
7266
- result->opt[0] = b1;
7267
- result->opt[1] = c0;
7268
- result->opt[2] = c1;
7366
+ result->src[0] = a;
7367
+ result->src[1] = b0;
7368
+ result->src[2] = b1;
7369
+ result->src[3] = c0;
7370
+ result->src[4] = c1;
7269
7371
 
7270
7372
  return result;
7271
7373
  }
@@ -7325,11 +7427,11 @@ struct ggml_tensor * ggml_flash_attn_back(
7325
7427
 
7326
7428
  result->op = GGML_OP_FLASH_ATTN_BACK;
7327
7429
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7328
- result->src0 = q;
7329
- result->src1 = k;
7330
- result->opt[0] = v;
7331
- result->opt[1] = d;
7332
- result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
7430
+ result->src[0] = q;
7431
+ result->src[1] = k;
7432
+ result->src[2] = v;
7433
+ result->src[3] = d;
7434
+ result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
7333
7435
 
7334
7436
  return result;
7335
7437
  }
@@ -7374,9 +7476,9 @@ struct ggml_tensor * ggml_win_part(
7374
7476
 
7375
7477
  result->op = GGML_OP_WIN_PART;
7376
7478
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7377
- result->src0 = a;
7378
- result->src1 = NULL;
7379
- result->opt[0] = b;
7479
+ result->src[0] = a;
7480
+ result->src[1] = NULL;
7481
+ result->src[2] = b;
7380
7482
 
7381
7483
  return result;
7382
7484
  }
@@ -7411,9 +7513,9 @@ struct ggml_tensor * ggml_win_unpart(
7411
7513
 
7412
7514
  result->op = GGML_OP_WIN_UNPART;
7413
7515
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7414
- result->src0 = a;
7415
- result->src1 = NULL;
7416
- result->opt[0] = b;
7516
+ result->src[0] = a;
7517
+ result->src[1] = NULL;
7518
+ result->src[2] = b;
7417
7519
 
7418
7520
  return result;
7419
7521
  }
@@ -7442,8 +7544,8 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
7442
7544
 
7443
7545
  result->op = GGML_OP_MAP_UNARY;
7444
7546
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7445
- result->src0 = a;
7446
- result->opt[0] = addr_tensor;
7547
+ result->src[0] = a;
7548
+ result->src[2] = addr_tensor;
7447
7549
 
7448
7550
  return result;
7449
7551
  }
@@ -7489,9 +7591,9 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
7489
7591
 
7490
7592
  result->op = GGML_OP_MAP_BINARY;
7491
7593
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7492
- result->src0 = a;
7493
- result->src1 = b;
7494
- result->opt[0] = addr_tensor;
7594
+ result->src[0] = a;
7595
+ result->src[1] = b;
7596
+ result->src[2] = addr_tensor;
7495
7597
 
7496
7598
  return result;
7497
7599
  }
@@ -7536,8 +7638,8 @@ struct ggml_tensor * ggml_map_custom1_impl_f32(
7536
7638
 
7537
7639
  result->op = GGML_OP_MAP_CUSTOM1;
7538
7640
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7539
- result->src0 = a;
7540
- result->opt[0] = addr_tensor;
7641
+ result->src[0] = a;
7642
+ result->src[2] = addr_tensor;
7541
7643
 
7542
7644
  return result;
7543
7645
  }
@@ -7581,9 +7683,9 @@ struct ggml_tensor * ggml_map_custom2_impl_f32(
7581
7683
 
7582
7684
  result->op = GGML_OP_MAP_CUSTOM2;
7583
7685
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7584
- result->src0 = a;
7585
- result->src1 = b;
7586
- result->opt[0] = addr_tensor;
7686
+ result->src[0] = a;
7687
+ result->src[1] = b;
7688
+ result->src[2] = addr_tensor;
7587
7689
 
7588
7690
  return result;
7589
7691
  }
@@ -7630,10 +7732,10 @@ struct ggml_tensor * ggml_map_custom3_impl_f32(
7630
7732
 
7631
7733
  result->op = GGML_OP_MAP_CUSTOM3;
7632
7734
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7633
- result->src0 = a;
7634
- result->src1 = b;
7635
- result->opt[0] = addr_tensor;
7636
- result->opt[1] = c;
7735
+ result->src[0] = a;
7736
+ result->src[1] = b;
7737
+ result->src[2] = addr_tensor;
7738
+ result->src[3] = c;
7637
7739
 
7638
7740
  return result;
7639
7741
  }
@@ -7673,8 +7775,8 @@ struct ggml_tensor * ggml_cross_entropy_loss(
7673
7775
 
7674
7776
  result->op = GGML_OP_CROSS_ENTROPY_LOSS;
7675
7777
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7676
- result->src0 = a;
7677
- result->src1 = b;
7778
+ result->src[0] = a;
7779
+ result->src[1] = b;
7678
7780
 
7679
7781
  return result;
7680
7782
  }
@@ -7693,9 +7795,9 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
7693
7795
 
7694
7796
  result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
7695
7797
  result->grad = NULL;
7696
- result->src0 = a;
7697
- result->src1 = b;
7698
- result->opt[0] = c;
7798
+ result->src[0] = a;
7799
+ result->src[1] = b;
7800
+ result->src[2] = c;
7699
7801
 
7700
7802
  return result;
7701
7803
  }
@@ -8296,7 +8398,7 @@ static void ggml_compute_forward_add_f32(
          const struct ggml_tensor * src0,
          const struct ggml_tensor * src1,
          struct ggml_tensor * dst) {
-     GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+     GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));

      if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
          return;
@@ -8321,23 +8423,23 @@

      if (nb10 == sizeof(float)) {
          for (int ir = ir0; ir < ir1; ++ir) {
-             // src0, src1 and dst are same shape => same indices
-             const int i3 = ir/(ne2*ne1);
-             const int i2 = (ir - i3*ne2*ne1)/ne1;
-             const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+             // src1 is broadcastable across src0 and dst in i1, i2, i3
+             const int64_t i03 = ir/(ne02*ne01);
+             const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+             const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+             const int64_t i13 = i03 % ne13;
+             const int64_t i12 = i02 % ne12;
+             const int64_t i11 = i01 % ne11;

+             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+             float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);

  #ifdef GGML_USE_ACCELERATE
-             vDSP_vadd(
-                     (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
-                     (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
-                     (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
-                     ne0);
+             vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
  #else
-             ggml_vec_add_f32(ne0,
-                     (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
-                     (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
-                     (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+             ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
  #endif
              // }
              // }
@@ -8345,15 +8447,20 @@
      } else {
          // src1 is not contiguous
          for (int ir = ir0; ir < ir1; ++ir) {
-             // src0, src1 and dst are same shape => same indices
-             const int i3 = ir/(ne2*ne1);
-             const int i2 = (ir - i3*ne2*ne1)/ne1;
-             const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+             // src1 is broadcastable across src0 and dst in i1, i2, i3
+             const int64_t i03 = ir/(ne02*ne01);
+             const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+             const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+             const int64_t i13 = i03 % ne13;
+             const int64_t i12 = i02 % ne12;
+             const int64_t i11 = i01 % ne11;
+
+             float * dst_ptr  = (float *) ((char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+             float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);

-             float * dst_ptr  = (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
-             float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
              for (int i0 = 0; i0 < ne0; i0++) {
-                 float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+                 float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);

                  dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
              }
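The i % ne1x indexing above is what realizes the broadcast: every output row maps back to a src1 row modulo src1's (possibly size-1) dimensions. A tiny standalone illustration of that mapping (sizes are invented; this is not part of ggml):

  #include <stdio.h>

  int main(void) {
      // dst/src0 rows span ne01 = 4, ne02 = 2 (ne03 = 1 implied);
      // src1 is a single row: ne11 = ne12 = ne13 = 1.
      const int ne01 = 4, ne02 = 2, ne11 = 1, ne12 = 1, ne13 = 1;
      for (int ir = 0; ir < ne01*ne02; ++ir) {
          const int i03 = ir/(ne02*ne01);
          const int i02 = (ir - i03*ne02*ne01)/ne01;
          const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
          printf("dst row (%d,%d,%d) <- src1 row (%d,%d,%d)\n",
                 i01, i02, i03, i01 % ne11, i02 % ne12, i03 % ne13);
      }
      return 0; // every dst row reads src1 row (0,0,0): the single bias row is repeated
  }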
@@ -10532,7 +10639,6 @@ static void ggml_compute_forward_rms_norm_back(
10532
10639
  }
10533
10640
  }
10534
10641
 
10535
-
10536
10642
  // ggml_compute_forward_mul_mat
10537
10643
 
10538
10644
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -10576,17 +10682,17 @@ static void ggml_compute_forward_mul_mat(
10576
10682
  const int ith = params->ith;
10577
10683
  const int nth = params->nth;
10578
10684
 
10579
- GGML_ASSERT(ne02 == ne12);
10580
- GGML_ASSERT(ne03 == ne13);
10581
- GGML_ASSERT(ne2 == ne12);
10582
- GGML_ASSERT(ne3 == ne13);
10583
-
10584
10685
  const enum ggml_type type = src0->type;
10585
10686
 
10586
10687
  ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
10587
10688
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
10588
10689
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
10589
10690
 
10691
+ GGML_ASSERT(ne0 == ne01);
10692
+ GGML_ASSERT(ne1 == ne11);
10693
+ GGML_ASSERT(ne2 == ne12);
10694
+ GGML_ASSERT(ne3 == ne13);
10695
+
10590
10696
  // we don't support permuted src0 or src1
10591
10697
  GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
10592
10698
  GGML_ASSERT(nb10 == sizeof(float));
@@ -10597,16 +10703,16 @@ static void ggml_compute_forward_mul_mat(
10597
10703
  GGML_ASSERT(nb1 <= nb2);
10598
10704
  GGML_ASSERT(nb2 <= nb3);
10599
10705
 
10600
- GGML_ASSERT(ne0 == ne01);
10601
- GGML_ASSERT(ne1 == ne11);
10602
- GGML_ASSERT(ne2 == ne02);
10603
- GGML_ASSERT(ne3 == ne03);
10604
-
10605
10706
  // nb01 >= nb00 - src0 is not transposed
10606
10707
  // compute by src0 rows
10607
10708
 
10608
10709
  #if defined(GGML_USE_CLBLAST)
10609
10710
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
10711
+ // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
10712
+ // ref: https://github.com/ggerganov/ggml/pull/224
10713
+ GGML_ASSERT(ne02 == ne12);
10714
+ GGML_ASSERT(ne03 == ne13);
10715
+
10610
10716
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
10611
10717
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
10612
10718
  }
@@ -10616,6 +10722,11 @@ static void ggml_compute_forward_mul_mat(
10616
10722
 
10617
10723
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10618
10724
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
10725
+ // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
10726
+ // ref: https://github.com/ggerganov/ggml/pull/224
10727
+ GGML_ASSERT(ne02 == ne12);
10728
+ GGML_ASSERT(ne03 == ne13);
10729
+
10619
10730
  if (params->ith != 0) {
10620
10731
  return;
10621
10732
  }
@@ -10685,43 +10796,44 @@ static void ggml_compute_forward_mul_mat(
          return;
      }

-     // parallelize by src0 rows using ggml_vec_dot_q
-
-     // total rows in src0
-     const int nr = ne01*ne02*ne03;
+     // parallelize by src0 rows
+     const int64_t dr = (ne01 + nth - 1)/nth;

-     // rows per thread
-     const int dr = (nr + nth - 1)/nth;
+     const int64_t ir10 = dr*ith;
+     const int64_t ir11 = MIN(ir10 + dr, ne01);

-     // row range for this thread
-     const int ir0 = dr*ith;
-     const int ir1 = MIN(ir0 + dr, nr);
+     // src1 rows
+     const int64_t nr1 = ne11*ne12*ne13;

      void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-     const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
-
-     for (int ir = ir0; ir < ir1; ++ir) {
-         // src0 indices
-         const int i03 = ir/(ne02*ne01);
-         const int i02 = (ir - i03*ne02*ne01)/ne01;
-         const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
-
-         const int i13 = i03;
-         const int i12 = i02;
-
-         const int i0 = i01;
-         const int i2 = i02;
-         const int i3 = i03;
-
-         void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-         char * src1_col = ((char *) wdata + ((0 + i12*ne11 + i13*ne12*ne11)*row_size));
-
-         float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
-
-         assert(ne00 % 32 == 0);
-
-         for (int64_t ic = 0; ic < ne11; ++ic) {
-             vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+     const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+
+     for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+         const int64_t i13 = (ir1/(ne12*ne11));
+         const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+         const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+
+         const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+         const int64_t i03 = (ir0/(ne02));
+         // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
+         // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
+         // GG: this is likely the correct way to broadcast, though need some more thought
+         //     therefore leaving the comments to remind us for now
+         const int64_t i02 = (i12 / (ne12 / ne02));
+         // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
+         // const int64_t i02 = (ir0 - i03*ne02);
+
+         const int64_t i1 = i11;
+         const int64_t i2 = i12;
+         const int64_t i3 = i13;
+
+         const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
+         const char * src1_col = (const char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
+
+         float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+
+         for (int64_t ir = ir10; ir < ir11; ++ir) {
+             vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
          }
      }
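The i02 line commented as the Falcon multi-query-attention hack maps each src1 (query) head onto a src0 (K/V) head by dividing by the replication factor ne12/ne02. A small sketch of that mapping (head counts are examples only, not from the diff):

  #include <stdint.h>

  // ne12 query heads share ne02 K/V heads; ne12 is assumed to be a multiple of ne02.
  static int64_t kv_head_for_query_head(int64_t i12, int64_t ne12, int64_t ne02) {
      return i12 / (ne12 / ne02); // same expression as i02 in the loop above
  }
  // ne02 = 1, ne12 = 32: every query head uses K/V head 0 (multi-query attention).
  // ne02 = 8, ne12 = 32: query heads 0-3 -> 0, 4-7 -> 1, ..., 28-31 -> 7 (grouped-query).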

@@ -11718,7 +11830,7 @@ static void ggml_compute_forward_alibi_f32(

      const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
      const int ne1 = src0->ne[1]; // seq_len_without_past
-     //const int ne2 = src0->ne[2]; // n_head -> this is k
+     const int ne2 = src0->ne[2]; // n_head -> this is k
      //const int ne3 = src0->ne[3]; // 1 -> bsz

      const int n = ggml_nrows(src0);
@@ -11729,8 +11841,9 @@
      const int nb2 = src0->nb[2];
      //const int nb3 = src0->nb[3];

-     assert(nb0 == sizeof(float));
-     assert(ne1 + n_past == ne0); (void) n_past;
+     GGML_ASSERT(nb0 == sizeof(float));
+     GGML_ASSERT(ne1 + n_past == ne0);
+     GGML_ASSERT(n_head == ne2);

      // add alibi to src0 (KQ_scaled)
      const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11754,7 +11867,7 @@
              m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
          }

-         pdst[0] = (i-ne0+1) * m_k + src[0];
+         pdst[0] = i * m_k + src[0];

      }
  }
@@ -11783,7 +11896,7 @@ static void ggml_compute_forward_alibi_f16(

      const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
      const int ne1 = src0->ne[1]; // seq_len_without_past
-     //const int ne2 = src0->ne[2]; // n_head -> this is k
+     const int ne2 = src0->ne[2]; // n_head -> this is k
      //const int ne3 = src0->ne[3]; // 1 -> bsz

      const int n = ggml_nrows(src0);
@@ -11794,8 +11907,9 @@
      const int nb2 = src0->nb[2];
      //const int nb3 = src0->nb[3];

-     assert(nb0 == sizeof(ggml_fp16_t));
-     assert(ne1 + n_past == ne0); (void) n_past;
+     GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+     GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+     GGML_ASSERT(n_head == ne2);

      // add alibi to src0 (KQ_scaled)
      const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11820,7 +11934,7 @@
          }

          // we return F32
-         pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+         pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
      }
  }
  }
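For reference, the ALiBi change replaces the old per-position offset with one anchored at key position 0; with m_k the per-head slope computed just above and i the key index within a row:

  bias_old(i) = (i - ne0 + 1) * m_k        bias_new(i) = i * m_k

The two differ by the constant (ne0 - 1) * m_k within a row, and softmax is invariant to adding the same constant to every logit in a row, so the attention weights produced downstream are unchanged; only the raw bias values differ.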
@@ -12904,16 +13018,18 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12904
13018
  {
12905
13019
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12906
13020
 
12907
- for (int i12 = 0; i12 < ne12; i12++) {
12908
- const float * const src = (float *)((char *) src1->data + i12*nb12);
12909
- ggml_fp16_t * dst_data = wdata;
12910
-
12911
- for (int i1 = 0; i1 < ne1; i1++) {
12912
- for (int i0 = 0; i0 < ne0; i0++) {
12913
- for (int ik1 = 0; ik1 < nk1; ik1++) {
12914
- for (int ik0 = 0; ik0 < nk0; ik0++) {
12915
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
12916
- GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
13021
+ for (int i13 = 0; i13 < ne13; i13++) {
13022
+ for (int i12 = 0; i12 < ne12; i12++) {
13023
+ const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
13024
+ ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
13025
+
13026
+ for (int i1 = 0; i1 < ne1; i1++) {
13027
+ for (int i0 = 0; i0 < ne0; i0++) {
13028
+ for (int ik1 = 0; ik1 < nk1; ik1++) {
13029
+ for (int ik0 = 0; ik0 < nk0; ik0++) {
13030
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
13031
+ GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
13032
+ }
12917
13033
  }
12918
13034
  }
12919
13035
  }
@@ -12940,14 +13056,16 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12940
13056
 
12941
13057
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12942
13058
 
12943
- for (int i2 = ip0; i2 < ip1; i2++) {
12944
- float * dst_data = (float *)((char *) dst->data + i2*nb2);
12945
-
12946
- for (int i1 = 0; i1 < ne1; ++i1) {
12947
- for (int i0 = 0; i0 < ne0; ++i0) {
12948
- ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
12949
- (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
12950
- (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
13059
+ for (int i3 = 0; i3 < ne3; i3++) {
13060
+ for (int i2 = ip0; i2 < ip1; i2++) {
13061
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
13062
+
13063
+ for (int i1 = 0; i1 < ne1; ++i1) {
13064
+ for (int i0 = 0; i0 < ne0; ++i0) {
13065
+ ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
13066
+ (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
13067
+ (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
13068
+ }
12951
13069
  }
12952
13070
  }
12953
13071
  }
@@ -12996,10 +13114,169 @@ static void ggml_compute_forward_conv_2d(
12996
13114
 
12997
13115
  if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
12998
13116
  ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
12999
- }
13000
- else {
13117
+ } else {
13001
13118
  GGML_ASSERT(false); // only stride equal to kernel size is supported
13002
- };
13119
+ }
13120
+ }
13121
+
13122
+ // ggml_compute_forward_pool_1d_sk_p0
13123
+
13124
+ static void ggml_compute_forward_pool_1d_sk_p0(
13125
+ const struct ggml_compute_params * params,
13126
+ const enum ggml_op_pool op,
13127
+ const struct ggml_tensor * src,
13128
+ const int k,
13129
+ struct ggml_tensor * dst) {
13130
+ assert(src->type == GGML_TYPE_F32);
13131
+ assert(params->ith == 0);
13132
+
13133
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13134
+ return;
13135
+ }
13136
+
13137
+ const char * cdata = (const char *)src->data;
13138
+ const char * const data_end = cdata + ggml_nbytes(src);
13139
+ float * drow = (float *)dst->data;
13140
+
13141
+ const int64_t rs = dst->ne[0];
13142
+
13143
+ while (cdata < data_end) {
13144
+ const float * const srow = (const float *)cdata;
13145
+
13146
+ int j = 0;
13147
+
13148
+ for (int64_t i = 0; i < rs; ++i) {
13149
+ switch (op) {
13150
+ case GGML_OP_POOL_AVG: drow[i] = 0; break;
13151
+ case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
13152
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13153
+ }
13154
+ for (int ki = 0; ki < k; ++ki) {
13155
+ switch (op) {
13156
+ case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
13157
+ case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
13158
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13159
+ }
13160
+ ++j;
13161
+ }
13162
+ switch (op) {
13163
+ case GGML_OP_POOL_AVG: drow[i] /= k; break;
13164
+ case GGML_OP_POOL_MAX: break;
13165
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13166
+ }
13167
+ }
13168
+
13169
+ cdata += src->nb[1];
13170
+ drow += rs;
13171
+ }
13172
+ }
13173
+
13174
+ // ggml_compute_forward_pool_1d
13175
+
13176
+ static void ggml_compute_forward_pool_1d(
13177
+ const struct ggml_compute_params* params,
13178
+ const struct ggml_tensor* src0,
13179
+ const struct ggml_tensor* opt0,
13180
+ struct ggml_tensor* dst) {
13181
+ GGML_ASSERT(opt0->ne[0] == 4);
13182
+ const int* opts = (const int*)opt0->data;
13183
+ enum ggml_op_pool op = opts[0];
13184
+ const int k0 = opts[1];
13185
+ const int s0 = opts[2];
13186
+ const int p0 = opts[3];
13187
+ GGML_ASSERT(p0 == 0); // padding not supported
13188
+ GGML_ASSERT(k0 == s0); // only s = k supported
13189
+
13190
+ ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
13191
+ }
13192
+
13193
+ // ggml_compute_forward_pool_2d_sk_p0
13194
+
13195
+ static void ggml_compute_forward_pool_2d_sk_p0(
13196
+ const struct ggml_compute_params * params,
13197
+ const enum ggml_op_pool op,
13198
+ const struct ggml_tensor * src,
13199
+ const int k0,
13200
+ const int k1,
13201
+ struct ggml_tensor * dst) {
13202
+ assert(src->type == GGML_TYPE_F32);
13203
+ assert(params->ith == 0);
13204
+
13205
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13206
+ return;
13207
+ }
13208
+
13209
+ const char * cdata = (const char*)src->data;
13210
+ const char * const data_end = cdata + ggml_nbytes(src);
13211
+
13212
+ const int64_t px = dst->ne[0];
13213
+ const int64_t py = dst->ne[1];
13214
+ const int64_t pa = px * py;
13215
+
13216
+ float * dplane = (float *)dst->data;
13217
+
13218
+ const int ka = k0 * k1;
13219
+
13220
+ while (cdata < data_end) {
13221
+ for (int oy = 0; oy < py; ++oy) {
13222
+ float * const drow = dplane + oy * px;
13223
+ for (int ox = 0; ox < px; ++ox) {
13224
+ float * const out = drow + ox;
13225
+ switch (op) {
13226
+ case GGML_OP_POOL_AVG: *out = 0; break;
13227
+ case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
13228
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13229
+ }
13230
+
13231
+ const int ix = ox * k0;
13232
+ const int iy = oy * k1;
13233
+
13234
+ for (int ky = 0; ky < k1; ++ky) {
13235
+ const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
13236
+ for (int kx = 0; kx < k0; ++kx) {
13237
+ int j = ix + kx;
13238
+ switch (op) {
13239
+ case GGML_OP_POOL_AVG: *out += srow[j]; break;
13240
+ case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
13241
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13242
+ }
13243
+ }
13244
+ }
13245
+ switch (op) {
13246
+ case GGML_OP_POOL_AVG: *out /= ka; break;
13247
+ case GGML_OP_POOL_MAX: break;
13248
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13249
+ }
13250
+ }
13251
+ }
13252
+
13253
+ cdata += src->nb[2];
13254
+ dplane += pa;
13255
+ }
13256
+ }
13257
+
13258
+ // ggml_compute_forward_pool_2d
13259
+
13260
+ static void ggml_compute_forward_pool_2d(
13261
+ const struct ggml_compute_params * params,
13262
+ const struct ggml_tensor * src0,
13263
+ const struct ggml_tensor * opt0,
13264
+ struct ggml_tensor * dst) {
13265
+ GGML_ASSERT(opt0->ne[0] == 7);
13266
+ const int* opts = (const int*)opt0->data;
13267
+ enum ggml_op_pool op = opts[0];
13268
+ const int k0 = opts[1];
13269
+ const int k1 = opts[2];
13270
+ const int s0 = opts[3];
13271
+ const int s1 = opts[4];
13272
+ const int p0 = opts[5];
13273
+ const int p1 = opts[6];
13274
+ GGML_ASSERT(p0 == 0);
13275
+ GGML_ASSERT(p1 == 0); // padding not supported
13276
+ GGML_ASSERT(k0 == s0);
13277
+ GGML_ASSERT(k1 == s1); // only s = k supported
13278
+
13279
+ ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
13003
13280
  }
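The 2-D wrapper unpacks seven ints from opt0 (op, k0, k1, s0, s1, p0, p1) and, with s == k and p == 0, each output element (ox, oy) reduces the k0 x k1 window starting at (ox*k0, oy*k1), so one W x H input plane shrinks to (W/k0) x (H/k1). A small self-contained illustration of that window arithmetic (plain C; values chosen only for the example, not part of the diff):

    #include <stdio.h>

    int main(void) {
        const int k0 = 2, k1 = 2;                    /* 2x2 average pooling of a 4x4 plane */
        const float src[4][4] = {
            { 1,  2,  3,  4},
            { 5,  6,  7,  8},
            { 9, 10, 11, 12},
            {13, 14, 15, 16},
        };
        const int px = 4 / k0, py = 4 / k1;          /* output plane: 2x2 */
        for (int oy = 0; oy < py; ++oy) {
            for (int ox = 0; ox < px; ++ox) {
                float sum = 0.0f;
                for (int ky = 0; ky < k1; ++ky) {
                    for (int kx = 0; kx < k0; ++kx) {
                        sum += src[oy*k1 + ky][ox*k0 + kx];
                    }
                }
                printf("%6.2f", sum / (k0*k1));      /* 3.50 5.50 / 11.50 13.50 */
            }
            printf("\n");
        }
        return 0;
    }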
13004
13281
 
13005
13282
 
@@ -14566,287 +14843,295 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14566
14843
  if (skip_cpu) {
14567
14844
  return;
14568
14845
  }
14569
- GGML_ASSERT(tensor->src0 == NULL || tensor->src0->backend == GGML_BACKEND_CPU);
14570
- GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
14846
+ GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU);
14847
+ GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU);
14571
14848
  #endif // GGML_USE_CUBLAS
14572
14849
 
14573
14850
  switch (tensor->op) {
14574
14851
  case GGML_OP_DUP:
14575
14852
  {
14576
- ggml_compute_forward_dup(params, tensor->src0, tensor);
14853
+ ggml_compute_forward_dup(params, tensor->src[0], tensor);
14577
14854
  } break;
14578
14855
  case GGML_OP_ADD:
14579
14856
  {
14580
- ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
14857
+ ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor);
14581
14858
  } break;
14582
14859
  case GGML_OP_ADD1:
14583
14860
  {
14584
- ggml_compute_forward_add1(params, tensor->src0, tensor->src1, tensor);
14861
+ ggml_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor);
14585
14862
  } break;
14586
14863
  case GGML_OP_ACC:
14587
14864
  {
14588
- ggml_compute_forward_acc(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14865
+ ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14589
14866
  } break;
14590
14867
  case GGML_OP_SUB:
14591
14868
  {
14592
- ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
14869
+ ggml_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor);
14593
14870
  } break;
14594
14871
  case GGML_OP_MUL:
14595
14872
  {
14596
- ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
14873
+ ggml_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor);
14597
14874
  } break;
14598
14875
  case GGML_OP_DIV:
14599
14876
  {
14600
- ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
14877
+ ggml_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor);
14601
14878
  } break;
14602
14879
  case GGML_OP_SQR:
14603
14880
  {
14604
- ggml_compute_forward_sqr(params, tensor->src0, tensor);
14881
+ ggml_compute_forward_sqr(params, tensor->src[0], tensor);
14605
14882
  } break;
14606
14883
  case GGML_OP_SQRT:
14607
14884
  {
14608
- ggml_compute_forward_sqrt(params, tensor->src0, tensor);
14885
+ ggml_compute_forward_sqrt(params, tensor->src[0], tensor);
14609
14886
  } break;
14610
14887
  case GGML_OP_LOG:
14611
14888
  {
14612
- ggml_compute_forward_log(params, tensor->src0, tensor);
14889
+ ggml_compute_forward_log(params, tensor->src[0], tensor);
14613
14890
  } break;
14614
14891
  case GGML_OP_SUM:
14615
14892
  {
14616
- ggml_compute_forward_sum(params, tensor->src0, tensor);
14893
+ ggml_compute_forward_sum(params, tensor->src[0], tensor);
14617
14894
  } break;
14618
14895
  case GGML_OP_SUM_ROWS:
14619
14896
  {
14620
- ggml_compute_forward_sum_rows(params, tensor->src0, tensor);
14897
+ ggml_compute_forward_sum_rows(params, tensor->src[0], tensor);
14621
14898
  } break;
14622
14899
  case GGML_OP_MEAN:
14623
14900
  {
14624
- ggml_compute_forward_mean(params, tensor->src0, tensor);
14901
+ ggml_compute_forward_mean(params, tensor->src[0], tensor);
14625
14902
  } break;
14626
14903
  case GGML_OP_ARGMAX:
14627
14904
  {
14628
- ggml_compute_forward_argmax(params, tensor->src0, tensor);
14905
+ ggml_compute_forward_argmax(params, tensor->src[0], tensor);
14629
14906
  } break;
14630
14907
  case GGML_OP_REPEAT:
14631
14908
  {
14632
- ggml_compute_forward_repeat(params, tensor->src0, tensor);
14909
+ ggml_compute_forward_repeat(params, tensor->src[0], tensor);
14633
14910
  } break;
14634
14911
  case GGML_OP_REPEAT_BACK:
14635
14912
  {
14636
- ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14913
+ ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
14637
14914
  } break;
14638
14915
  case GGML_OP_ABS:
14639
14916
  {
14640
- ggml_compute_forward_abs(params, tensor->src0, tensor);
14917
+ ggml_compute_forward_abs(params, tensor->src[0], tensor);
14641
14918
  } break;
14642
14919
  case GGML_OP_SGN:
14643
14920
  {
14644
- ggml_compute_forward_sgn(params, tensor->src0, tensor);
14921
+ ggml_compute_forward_sgn(params, tensor->src[0], tensor);
14645
14922
  } break;
14646
14923
  case GGML_OP_NEG:
14647
14924
  {
14648
- ggml_compute_forward_neg(params, tensor->src0, tensor);
14925
+ ggml_compute_forward_neg(params, tensor->src[0], tensor);
14649
14926
  } break;
14650
14927
  case GGML_OP_STEP:
14651
14928
  {
14652
- ggml_compute_forward_step(params, tensor->src0, tensor);
14929
+ ggml_compute_forward_step(params, tensor->src[0], tensor);
14653
14930
  } break;
14654
14931
  case GGML_OP_TANH:
14655
14932
  {
14656
- ggml_compute_forward_tanh(params, tensor->src0, tensor);
14933
+ ggml_compute_forward_tanh(params, tensor->src[0], tensor);
14657
14934
  } break;
14658
14935
  case GGML_OP_ELU:
14659
14936
  {
14660
- ggml_compute_forward_elu(params, tensor->src0, tensor);
14937
+ ggml_compute_forward_elu(params, tensor->src[0], tensor);
14661
14938
  } break;
14662
14939
  case GGML_OP_RELU:
14663
14940
  {
14664
- ggml_compute_forward_relu(params, tensor->src0, tensor);
14941
+ ggml_compute_forward_relu(params, tensor->src[0], tensor);
14665
14942
  } break;
14666
14943
  case GGML_OP_GELU:
14667
14944
  {
14668
- ggml_compute_forward_gelu(params, tensor->src0, tensor);
14945
+ ggml_compute_forward_gelu(params, tensor->src[0], tensor);
14669
14946
  } break;
14670
14947
  case GGML_OP_GELU_QUICK:
14671
14948
  {
14672
- ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
14949
+ ggml_compute_forward_gelu_quick(params, tensor->src[0], tensor);
14673
14950
  } break;
14674
14951
  case GGML_OP_SILU:
14675
14952
  {
14676
- ggml_compute_forward_silu(params, tensor->src0, tensor);
14953
+ ggml_compute_forward_silu(params, tensor->src[0], tensor);
14677
14954
  } break;
14678
14955
  case GGML_OP_SILU_BACK:
14679
14956
  {
14680
- ggml_compute_forward_silu_back(params, tensor->src0, tensor->src1, tensor);
14957
+ ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
14681
14958
  } break;
14682
14959
  case GGML_OP_NORM:
14683
14960
  {
14684
- ggml_compute_forward_norm(params, tensor->src0, tensor);
14961
+ ggml_compute_forward_norm(params, tensor->src[0], tensor);
14685
14962
  } break;
14686
14963
  case GGML_OP_RMS_NORM:
14687
14964
  {
14688
- ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
14965
+ ggml_compute_forward_rms_norm(params, tensor->src[0], tensor);
14689
14966
  } break;
14690
14967
  case GGML_OP_RMS_NORM_BACK:
14691
14968
  {
14692
- ggml_compute_forward_rms_norm_back(params, tensor->src0, tensor->src1, tensor);
14969
+ ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor);
14693
14970
  } break;
14694
14971
  case GGML_OP_MUL_MAT:
14695
14972
  {
14696
- ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
14973
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14697
14974
  } break;
14698
14975
  case GGML_OP_OUT_PROD:
14699
14976
  {
14700
- ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
14977
+ ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
14701
14978
  } break;
14702
14979
  case GGML_OP_SCALE:
14703
14980
  {
14704
- ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
14981
+ ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
14705
14982
  } break;
14706
14983
  case GGML_OP_SET:
14707
14984
  {
14708
- ggml_compute_forward_set(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14985
+ ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14709
14986
  } break;
14710
14987
  case GGML_OP_CPY:
14711
14988
  {
14712
- ggml_compute_forward_cpy(params, tensor->src0, tensor);
14989
+ ggml_compute_forward_cpy(params, tensor->src[0], tensor);
14713
14990
  } break;
14714
14991
  case GGML_OP_CONT:
14715
14992
  {
14716
- ggml_compute_forward_cont(params, tensor->src0, tensor);
14993
+ ggml_compute_forward_cont(params, tensor->src[0], tensor);
14717
14994
  } break;
14718
14995
  case GGML_OP_RESHAPE:
14719
14996
  {
14720
- ggml_compute_forward_reshape(params, tensor->src0, tensor);
14997
+ ggml_compute_forward_reshape(params, tensor->src[0], tensor);
14721
14998
  } break;
14722
14999
  case GGML_OP_VIEW:
14723
15000
  {
14724
- ggml_compute_forward_view(params, tensor->src0);
15001
+ ggml_compute_forward_view(params, tensor->src[0]);
14725
15002
  } break;
14726
15003
  case GGML_OP_PERMUTE:
14727
15004
  {
14728
- ggml_compute_forward_permute(params, tensor->src0);
15005
+ ggml_compute_forward_permute(params, tensor->src[0]);
14729
15006
  } break;
14730
15007
  case GGML_OP_TRANSPOSE:
14731
15008
  {
14732
- ggml_compute_forward_transpose(params, tensor->src0);
15009
+ ggml_compute_forward_transpose(params, tensor->src[0]);
14733
15010
  } break;
14734
15011
  case GGML_OP_GET_ROWS:
14735
15012
  {
14736
- ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
15013
+ ggml_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor);
14737
15014
  } break;
14738
15015
  case GGML_OP_GET_ROWS_BACK:
14739
15016
  {
14740
- ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15017
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14741
15018
  } break;
14742
15019
  case GGML_OP_DIAG:
14743
15020
  {
14744
- ggml_compute_forward_diag(params, tensor->src0, tensor);
15021
+ ggml_compute_forward_diag(params, tensor->src[0], tensor);
14745
15022
  } break;
14746
15023
  case GGML_OP_DIAG_MASK_INF:
14747
15024
  {
14748
- ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
15025
+ ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor);
14749
15026
  } break;
14750
15027
  case GGML_OP_DIAG_MASK_ZERO:
14751
15028
  {
14752
- ggml_compute_forward_diag_mask_zero(params, tensor->src0, tensor->src1, tensor);
15029
+ ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor);
14753
15030
  } break;
14754
15031
  case GGML_OP_SOFT_MAX:
14755
15032
  {
14756
- ggml_compute_forward_soft_max(params, tensor->src0, tensor);
15033
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
14757
15034
  } break;
14758
15035
  case GGML_OP_SOFT_MAX_BACK:
14759
15036
  {
14760
- ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
15037
+ ggml_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor);
14761
15038
  } break;
14762
15039
  case GGML_OP_ROPE:
14763
15040
  {
14764
- ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
15041
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
14765
15042
  } break;
14766
15043
  case GGML_OP_ROPE_BACK:
14767
15044
  {
14768
- ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor);
15045
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
14769
15046
  } break;
14770
15047
  case GGML_OP_ALIBI:
14771
15048
  {
14772
- ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
15049
+ ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor);
14773
15050
  } break;
14774
15051
  case GGML_OP_CLAMP:
14775
15052
  {
14776
- ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
15053
+ ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor);
14777
15054
  } break;
14778
15055
  case GGML_OP_CONV_1D:
14779
15056
  {
14780
- ggml_compute_forward_conv_1d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15057
+ ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14781
15058
  } break;
14782
15059
  case GGML_OP_CONV_2D:
14783
15060
  {
14784
- ggml_compute_forward_conv_2d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15061
+ ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
15062
+ } break;
15063
+ case GGML_OP_POOL_1D:
15064
+ {
15065
+ ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor);
15066
+ } break;
15067
+ case GGML_OP_POOL_2D:
15068
+ {
15069
+ ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor);
14785
15070
  } break;
14786
15071
  case GGML_OP_FLASH_ATTN:
14787
15072
  {
14788
- const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15073
+ const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
14789
15074
  GGML_ASSERT(t == 0 || t == 1);
14790
15075
  const bool masked = t != 0;
14791
- ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
15076
+ ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
14792
15077
  } break;
14793
15078
  case GGML_OP_FLASH_FF:
14794
15079
  {
14795
- ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
15080
+ ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
14796
15081
  } break;
14797
15082
  case GGML_OP_FLASH_ATTN_BACK:
14798
15083
  {
14799
- int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
15084
+ int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
14800
15085
  GGML_ASSERT(t == 0 || t == 1);
14801
15086
  bool masked = t != 0;
14802
- ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
15087
+ ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
14803
15088
  } break;
14804
15089
  case GGML_OP_WIN_PART:
14805
15090
  {
14806
- ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
15091
+ ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor);
14807
15092
  } break;
14808
15093
  case GGML_OP_WIN_UNPART:
14809
15094
  {
14810
- ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
15095
+ ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor);
14811
15096
  } break;
14812
15097
  case GGML_OP_MAP_UNARY:
14813
15098
  {
14814
- const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
14815
- ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
15099
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data);
15100
+ ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun);
14816
15101
  }
14817
15102
  break;
14818
15103
  case GGML_OP_MAP_BINARY:
14819
15104
  {
14820
- const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
14821
- ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
15105
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data);
15106
+ ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun);
14822
15107
  }
14823
15108
  break;
14824
15109
  case GGML_OP_MAP_CUSTOM1:
14825
15110
  {
14826
- const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->opt[0]->data);
14827
- ggml_compute_forward_map_custom1(params, tensor->src0, tensor, fun);
15111
+ const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data);
15112
+ ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun);
14828
15113
  }
14829
15114
  break;
14830
15115
  case GGML_OP_MAP_CUSTOM2:
14831
15116
  {
14832
- const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->opt[0]->data);
14833
- ggml_compute_forward_map_custom2(params, tensor->src0, tensor->src1, tensor, fun);
15117
+ const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data);
15118
+ ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun);
14834
15119
  }
14835
15120
  break;
14836
15121
  case GGML_OP_MAP_CUSTOM3:
14837
15122
  {
14838
- const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->opt[0]->data);
14839
- ggml_compute_forward_map_custom3(params, tensor->src0, tensor->src1, tensor->opt[1], tensor, fun);
15123
+ const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data);
15124
+ ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun);
14840
15125
  }
14841
15126
  break;
14842
15127
  case GGML_OP_CROSS_ENTROPY_LOSS:
14843
15128
  {
14844
- ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
15129
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor);
14845
15130
  }
14846
15131
  break;
14847
15132
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
14848
15133
  {
14849
- ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15134
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14850
15135
  }
14851
15136
  break;
14852
15137
  case GGML_OP_NONE:
@@ -14863,8 +15148,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14863
15148
  ////////////////////////////////////////////////////////////////////////////////
14864
15149
 
14865
15150
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
14866
- struct ggml_tensor * src0 = tensor->src0;
14867
- struct ggml_tensor * src1 = tensor->src1;
15151
+ struct ggml_tensor * src0 = tensor->src[0];
15152
+ struct ggml_tensor * src1 = tensor->src[1];
14868
15153
 
14869
15154
  switch (tensor->op) {
14870
15155
  case GGML_OP_DUP:
@@ -14900,12 +15185,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
14900
15185
  src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
14901
15186
  }
14902
15187
  if (src1->grad) {
14903
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
14904
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
14905
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
14906
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
14907
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
14908
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
15188
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
15189
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
15190
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
15191
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
15192
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
15193
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
14909
15194
 
14910
15195
  struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
14911
15196
  tensor->grad,
@@ -15213,12 +15498,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15213
15498
  } break;
15214
15499
  case GGML_OP_SET:
15215
15500
  {
15216
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
15217
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
15218
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
15219
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
15220
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
15221
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
15501
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
15502
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
15503
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
15504
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
15505
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
15506
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
15222
15507
 
15223
15508
  struct ggml_tensor * tensor_grad_view = NULL;
15224
15509
 
@@ -15295,8 +15580,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15295
15580
  if (src0->grad) {
15296
15581
  size_t offset;
15297
15582
 
15298
- GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
15299
- memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
15583
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2]));
15584
+ memcpy(&offset, tensor->src[2]->data, sizeof(offset));
15300
15585
 
15301
15586
  size_t nb1 = tensor->nb[1];
15302
15587
  size_t nb2 = tensor->nb[2];
@@ -15323,7 +15608,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15323
15608
  {
15324
15609
  // necessary for llama
15325
15610
  if (src0->grad) {
15326
- int32_t * axes = (int32_t *) tensor->opt[0]->data;
15611
+ int32_t * axes = (int32_t *) tensor->src[2]->data;
15327
15612
  int axis0 = axes[0] & 0x3;
15328
15613
  int axis1 = axes[1] & 0x3;
15329
15614
  int axis2 = axes[2] & 0x3;
@@ -15483,18 +15768,26 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15483
15768
  {
15484
15769
  GGML_ASSERT(false); // TODO: not implemented
15485
15770
  } break;
15771
+ case GGML_OP_POOL_1D:
15772
+ {
15773
+ GGML_ASSERT(false); // TODO: not implemented
15774
+ } break;
15775
+ case GGML_OP_POOL_2D:
15776
+ {
15777
+ GGML_ASSERT(false); // TODO: not implemented
15778
+ } break;
15486
15779
  case GGML_OP_FLASH_ATTN:
15487
15780
  {
15488
15781
  struct ggml_tensor * flash_grad = NULL;
15489
- if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15490
- int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15782
+ if (src0->grad || src1->grad || tensor->src[2]->grad) {
15783
+ int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
15491
15784
  GGML_ASSERT(t == 0 || t == 1);
15492
15785
  bool masked = t != 0;
15493
15786
  flash_grad =
15494
15787
  ggml_flash_attn_back(ctx,
15495
15788
  src0,
15496
15789
  src1,
15497
- tensor->opt[0],
15790
+ tensor->src[2],
15498
15791
  tensor->grad,
15499
15792
  masked);
15500
15793
  }
@@ -15591,7 +15884,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15591
15884
  inplace);
15592
15885
  }
15593
15886
 
15594
- struct ggml_tensor * opt0 = tensor->opt[0];
15887
+ struct ggml_tensor * opt0 = tensor->src[2];
15595
15888
 
15596
15889
  if (opt0->grad) {
15597
15890
  struct ggml_tensor * grad_v = NULL;
@@ -15707,17 +16000,9 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
15707
16000
  }
15708
16001
  }
15709
16002
 
15710
- if (node->src0) {
15711
- ggml_visit_parents(cgraph, node->src0);
15712
- }
15713
-
15714
- if (node->src1) {
15715
- ggml_visit_parents(cgraph, node->src1);
15716
- }
15717
-
15718
- for (int i = 0; i < GGML_MAX_OPT; ++i) {
15719
- if (node->opt[i]) {
15720
- ggml_visit_parents(cgraph, node->opt[i]);
16003
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
16004
+ if (node->src[i]) {
16005
+ ggml_visit_parents(cgraph, node->src[i]);
15721
16006
  }
15722
16007
  }
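The hunks throughout this file follow the same operand refactor: the per-tensor fields src0, src1 and the opt[GGML_MAX_OPT] array are folded into a single src[GGML_MAX_SRC] array, which is why the graph traversal above collapses into one loop. A sketch of the mapping (not ggml's real struct; only the operand array is shown, and GGML_MAX_SRC = 6 is assumed for illustration):

    #define GGML_MAX_SRC 6                 /* assumed value for illustration */

    struct toy_tensor {
        struct toy_tensor * src[GGML_MAX_SRC];
    };

    /* old field            new slot
     * tensor->src0    ->   tensor->src[0]
     * tensor->src1    ->   tensor->src[1]
     * tensor->opt[k]  ->   tensor->src[2 + k]   (e.g. FLASH_FF's opt[0..2] become src[2..4]) */
    static struct toy_tensor * second_operand(struct toy_tensor * t) {
        return t->src[1];                  /* was t->src1 */
    }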
15723
16008
 
@@ -15772,9 +16057,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
15772
16057
  struct ggml_cgraph result = {
15773
16058
  /*.n_nodes =*/ 0,
15774
16059
  /*.n_leafs =*/ 0,
15775
- /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
15776
- /*.work_size =*/ 0,
15777
- /*.work =*/ NULL,
15778
16060
  /*.nodes =*/ { NULL },
15779
16061
  /*.grads =*/ { NULL },
15780
16062
  /*.leafs =*/ { NULL },
@@ -15945,16 +16227,20 @@ void clear_numa_thread_affinity(void) {}
15945
16227
  #endif
15946
16228
 
15947
16229
  struct ggml_compute_state_shared {
15948
- struct ggml_cgraph * cgraph;
16230
+ const struct ggml_cgraph * cgraph;
16231
+ const struct ggml_cplan * cplan;
15949
16232
 
15950
16233
  int64_t perf_node_start_cycles;
15951
16234
  int64_t perf_node_start_time_us;
15952
16235
 
15953
- int n_threads;
16236
+ const int n_threads;
15954
16237
 
15955
16238
  // synchronization primitives
15956
16239
  atomic_int n_active; // num active threads
15957
16240
  atomic_int node_n; // active graph node
16241
+
16242
+ bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
16243
+ void * abort_callback_data;
15958
16244
  };
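The shared compute state now carries an abort callback, and the thread loop below checks it between nodes: when the callback returns true, graph computation stops and reports GGML_EXIT_ABORTED. A minimal sketch of such a callback, assuming a SIGINT flag as the abort condition (the flag and handler are illustrative, not part of the diff):

    #include <signal.h>
    #include <stdbool.h>

    static volatile sig_atomic_t g_interrupted = 0;

    static void on_sigint(int sig) {
        (void) sig;
        g_interrupted = 1;
    }

    /* matches bool (*abort_callback)(void * data); data is abort_callback_data */
    static bool abort_on_sigint(void * data) {
        (void) data;
        return g_interrupted != 0;
    }

    /* usage sketch, assuming a cplan obtained from ggml_graph_plan:
     *   signal(SIGINT, on_sigint);
     *   cplan.abort_callback      = abort_on_sigint;
     *   cplan.abort_callback_data = NULL;
     */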
15959
16245
 
15960
16246
  struct ggml_compute_state {
@@ -15974,14 +16260,22 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
15974
16260
 
15975
16261
  static thread_ret_t ggml_graph_compute_thread(void * data) {
15976
16262
  struct ggml_compute_state * state = (struct ggml_compute_state *) data;
15977
- struct ggml_cgraph * cgraph = state->shared->cgraph;
15978
16263
 
15979
- const int n_threads = state->shared->n_threads;
16264
+ const struct ggml_cgraph * cgraph = state->shared->cgraph;
16265
+ const struct ggml_cplan * cplan = state->shared->cplan;
16266
+
16267
+ const int * n_tasks_arr = cplan->n_tasks;
16268
+ const int n_threads = state->shared->n_threads;
16269
+
15980
16270
  set_numa_thread_affinity(state->ith, n_threads);
15981
16271
 
15982
16272
  int node_n = -1;
15983
16273
 
15984
16274
  while (true) {
16275
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
16276
+ state->shared->node_n += 1;
16277
+ return (thread_ret_t) GGML_EXIT_ABORTED;
16278
+ }
15985
16279
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
15986
16280
  // all other threads are finished and spinning
15987
16281
  // do finalize and init here so we don't have synchronize again
@@ -15989,15 +16283,15 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
15989
16283
  /*.type =*/ GGML_TASK_FINALIZE,
15990
16284
  /*.ith =*/ 0,
15991
16285
  /*.nth =*/ 0,
15992
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
15993
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
16286
+ /*.wsize =*/ cplan->work_size,
16287
+ /*.wdata =*/ cplan->work_data,
15994
16288
  };
15995
16289
 
15996
16290
  if (node_n != -1) {
15997
16291
  /* FINALIZE */
15998
16292
  struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
15999
16293
  if (GGML_OP_HAS_FINALIZE[node->op]) {
16000
- params.nth = node->n_tasks;
16294
+ params.nth = n_tasks_arr[node_n];
16001
16295
  ggml_compute_forward(&params, node);
16002
16296
  ggml_graph_compute_perf_stats_node(node, state->shared);
16003
16297
  }
@@ -16008,11 +16302,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16008
16302
  GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
16009
16303
 
16010
16304
  struct ggml_tensor * node = cgraph->nodes[node_n];
16305
+ const int n_tasks = n_tasks_arr[node_n];
16011
16306
 
16012
16307
  state->shared->perf_node_start_cycles = ggml_perf_cycles();
16013
16308
  state->shared->perf_node_start_time_us = ggml_perf_time_us();
16014
16309
 
16015
- params.nth = node->n_tasks;
16310
+ params.nth = n_tasks;
16016
16311
 
16017
16312
  /* INIT */
16018
16313
  if (GGML_OP_HAS_INIT[node->op]) {
@@ -16020,7 +16315,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16020
16315
  ggml_compute_forward(&params, node);
16021
16316
  }
16022
16317
 
16023
- if (node->n_tasks == 1) {
16318
+ if (n_tasks == 1) {
16024
16319
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
16025
16320
  // they do something more efficient than spinning (?)
16026
16321
  params.type = GGML_TASK_COMPUTE;
@@ -16034,6 +16329,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16034
16329
  } else {
16035
16330
  break;
16036
16331
  }
16332
+
16333
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
16334
+ break;
16335
+ }
16037
16336
  }
16038
16337
 
16039
16338
  atomic_store(&state->shared->n_active, n_threads);
@@ -16042,7 +16341,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16042
16341
  // wait for other threads to finish
16043
16342
  const int last = node_n;
16044
16343
  do {
16045
- sched_yield();
16344
+ //sched_yield();
16046
16345
  node_n = atomic_load(&state->shared->node_n);
16047
16346
  } while (node_n == last);
16048
16347
  }
@@ -16052,366 +16351,395 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16052
16351
 
16053
16352
  /* COMPUTE */
16054
16353
  struct ggml_tensor * node = cgraph->nodes[node_n];
16354
+ const int n_tasks = n_tasks_arr[node_n];
16055
16355
 
16056
16356
  struct ggml_compute_params params = {
16057
16357
  /*.type =*/ GGML_TASK_COMPUTE,
16058
16358
  /*.ith =*/ state->ith,
16059
- /*.nth =*/ node->n_tasks,
16060
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
16061
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
16359
+ /*.nth =*/ n_tasks,
16360
+ /*.wsize =*/ cplan->work_size,
16361
+ /*.wdata =*/ cplan->work_data,
16062
16362
  };
16063
16363
 
16064
- if (state->ith < node->n_tasks) {
16364
+ if (state->ith < n_tasks) {
16065
16365
  ggml_compute_forward(&params, node);
16066
16366
  }
16067
16367
  }
16068
16368
 
16069
- return 0;
16369
+ return GGML_EXIT_SUCCESS;
16070
16370
  }
16071
16371
 
16072
- void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
16073
- const int n_threads = cgraph->n_threads;
16372
+ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
16373
+ if (n_threads <= 0) {
16374
+ n_threads = GGML_DEFAULT_N_THREADS;
16375
+ }
16074
16376
 
16075
- struct ggml_compute_state_shared state_shared = {
16076
- /*.cgraph =*/ cgraph,
16077
- /*.perf_node_start_cycles =*/ 0,
16078
- /*.perf_node_start_time_us =*/ 0,
16079
- /*.n_threads =*/ n_threads,
16080
- /*.n_active =*/ n_threads,
16081
- /*.node_n =*/ -1,
16082
- };
16083
- struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
16377
+ size_t work_size = 0;
16084
16378
 
16085
- // initialize tasks + work buffer
16086
- {
16087
- size_t work_size = 0;
16379
+ struct ggml_cplan cplan;
16380
+ memset(&cplan, 0, sizeof(struct ggml_cplan));
16088
16381
 
16089
- // thread scheduling for the different operations
16090
- for (int i = 0; i < cgraph->n_nodes; i++) {
16091
- struct ggml_tensor * node = cgraph->nodes[i];
16382
+ // thread scheduling for the different operations + work buffer size estimation
16383
+ for (int i = 0; i < cgraph->n_nodes; i++) {
16384
+ int n_tasks = 1;
16092
16385
 
16093
- switch (node->op) {
16094
- case GGML_OP_CPY:
16095
- case GGML_OP_DUP:
16096
- {
16097
- node->n_tasks = n_threads;
16386
+ struct ggml_tensor * node = cgraph->nodes[i];
16098
16387
 
16099
- size_t cur = 0;
16100
- if (ggml_is_quantized(node->type)) {
16101
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
16102
- }
16388
+ switch (node->op) {
16389
+ case GGML_OP_CPY:
16390
+ case GGML_OP_DUP:
16391
+ {
16392
+ n_tasks = n_threads;
16103
16393
 
16104
- work_size = MAX(work_size, cur);
16105
- } break;
16106
- case GGML_OP_ADD:
16107
- case GGML_OP_ADD1:
16108
- {
16109
- node->n_tasks = n_threads;
16394
+ size_t cur = 0;
16395
+ if (ggml_is_quantized(node->type)) {
16396
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
16397
+ }
16110
16398
 
16111
- size_t cur = 0;
16399
+ work_size = MAX(work_size, cur);
16400
+ } break;
16401
+ case GGML_OP_ADD:
16402
+ case GGML_OP_ADD1:
16403
+ {
16404
+ n_tasks = n_threads;
16112
16405
 
16113
- if (ggml_is_quantized(node->src0->type)) {
16114
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
16115
- }
16406
+ size_t cur = 0;
16116
16407
 
16117
- work_size = MAX(work_size, cur);
16118
- } break;
16119
- case GGML_OP_ACC:
16120
- {
16121
- node->n_tasks = n_threads;
16408
+ if (ggml_is_quantized(node->src[0]->type)) {
16409
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
16410
+ }
16122
16411
 
16123
- size_t cur = 0;
16412
+ work_size = MAX(work_size, cur);
16413
+ } break;
16414
+ case GGML_OP_ACC:
16415
+ {
16416
+ n_tasks = n_threads;
16124
16417
 
16125
- if (ggml_is_quantized(node->src0->type)) {
16126
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
16127
- }
16418
+ size_t cur = 0;
16128
16419
 
16129
- work_size = MAX(work_size, cur);
16130
- } break;
16131
- case GGML_OP_SUB:
16132
- case GGML_OP_DIV:
16133
- case GGML_OP_SQR:
16134
- case GGML_OP_SQRT:
16135
- case GGML_OP_LOG:
16136
- case GGML_OP_SUM:
16137
- case GGML_OP_SUM_ROWS:
16138
- case GGML_OP_MEAN:
16139
- case GGML_OP_ARGMAX:
16140
- case GGML_OP_REPEAT:
16141
- case GGML_OP_REPEAT_BACK:
16142
- case GGML_OP_ABS:
16143
- case GGML_OP_SGN:
16144
- case GGML_OP_NEG:
16145
- case GGML_OP_STEP:
16146
- case GGML_OP_TANH:
16147
- case GGML_OP_ELU:
16148
- case GGML_OP_RELU:
16149
- {
16150
- node->n_tasks = 1;
16151
- } break;
16152
- case GGML_OP_MUL:
16153
- case GGML_OP_GELU:
16154
- case GGML_OP_GELU_QUICK:
16155
- case GGML_OP_SILU:
16156
- case GGML_OP_SILU_BACK:
16157
- case GGML_OP_NORM:
16158
- case GGML_OP_RMS_NORM:
16159
- case GGML_OP_RMS_NORM_BACK:
16160
- {
16161
- node->n_tasks = n_threads;
16162
- } break;
16163
- case GGML_OP_MUL_MAT:
16164
- case GGML_OP_OUT_PROD:
16165
- {
16166
- node->n_tasks = n_threads;
16167
-
16168
- // TODO: use different scheduling for different matrix sizes
16169
- //const int nr0 = ggml_nrows(node->src0);
16170
- //const int nr1 = ggml_nrows(node->src1);
16171
-
16172
- //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
16173
- //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
16174
-
16175
- size_t cur = 0;
16176
- const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
16420
+ if (ggml_is_quantized(node->src[0]->type)) {
16421
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[1]->ne[0] * n_tasks;
16422
+ }
16423
+
16424
+ work_size = MAX(work_size, cur);
16425
+ } break;
16426
+ case GGML_OP_SUB:
16427
+ case GGML_OP_DIV:
16428
+ case GGML_OP_SQR:
16429
+ case GGML_OP_SQRT:
16430
+ case GGML_OP_LOG:
16431
+ case GGML_OP_SUM:
16432
+ case GGML_OP_SUM_ROWS:
16433
+ case GGML_OP_MEAN:
16434
+ case GGML_OP_ARGMAX:
16435
+ case GGML_OP_REPEAT:
16436
+ case GGML_OP_REPEAT_BACK:
16437
+ case GGML_OP_ABS:
16438
+ case GGML_OP_SGN:
16439
+ case GGML_OP_NEG:
16440
+ case GGML_OP_STEP:
16441
+ case GGML_OP_TANH:
16442
+ case GGML_OP_ELU:
16443
+ case GGML_OP_RELU:
16444
+ {
16445
+ n_tasks = 1;
16446
+ } break;
16447
+ case GGML_OP_MUL:
16448
+ case GGML_OP_GELU:
16449
+ case GGML_OP_GELU_QUICK:
16450
+ case GGML_OP_SILU:
16451
+ case GGML_OP_SILU_BACK:
16452
+ case GGML_OP_NORM:
16453
+ case GGML_OP_RMS_NORM:
16454
+ case GGML_OP_RMS_NORM_BACK:
16455
+ {
16456
+ n_tasks = n_threads;
16457
+ } break;
16458
+ case GGML_OP_MUL_MAT:
16459
+ case GGML_OP_OUT_PROD:
16460
+ {
16461
+ n_tasks = n_threads;
16462
+
16463
+ // TODO: use different scheduling for different matrix sizes
16464
+ //const int nr0 = ggml_nrows(node->src[0]);
16465
+ //const int nr1 = ggml_nrows(node->src[1]);
16466
+
16467
+ //n_tasks = MIN(n_threads, MAX(1, nr0/128));
16468
+ //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
16469
+
16470
+ size_t cur = 0;
16471
+ const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
16177
16472
 
16178
16473
  #if defined(GGML_USE_CUBLAS)
16179
- if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
16180
- node->n_tasks = 1; // TODO: this actually is doing nothing
16181
- // the threads are still spinning
16182
- }
16183
- else
16474
+ if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
16475
+ n_tasks = 1; // TODO: this actually is doing nothing
16476
+ // the threads are still spinning
16477
+ } else
16184
16478
  #elif defined(GGML_USE_CLBLAST)
16185
- if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
16186
- node->n_tasks = 1; // TODO: this actually is doing nothing
16187
- // the threads are still spinning
16188
- cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
16189
- }
16190
- else
16479
+ if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
16480
+ n_tasks = 1; // TODO: this actually is doing nothing
16481
+ // the threads are still spinning
16482
+ cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
16483
+ } else
16191
16484
  #endif
16192
16485
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
16193
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
16194
- node->n_tasks = 1; // TODO: this actually is doing nothing
16195
- // the threads are still spinning
16196
- if (node->src0->type != GGML_TYPE_F32) {
16197
- // here we need memory just for single 2D matrix from src0
16198
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
16199
- }
16200
- } else
16201
- #endif
16202
- if (node->src1->type != vec_dot_type) {
16203
- cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
16204
- } else {
16205
- cur = 0;
16486
+ if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
16487
+ n_tasks = 1; // TODO: this actually is doing nothing
16488
+ // the threads are still spinning
16489
+ if (node->src[0]->type != GGML_TYPE_F32) {
16490
+ // here we need memory just for single 2D matrix from src0
16491
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src[0]->ne[0]*node->src[0]->ne[1]);
16206
16492
  }
16493
+ } else
16494
+ #endif
16495
+ if (node->src[1]->type != vec_dot_type) {
16496
+ cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src[1])/GGML_BLCK_SIZE[vec_dot_type];
16497
+ } else {
16498
+ cur = 0;
16499
+ }
16207
16500
 
16208
- work_size = MAX(work_size, cur);
16209
- } break;
16210
- case GGML_OP_SCALE:
16211
- {
16212
- node->n_tasks = 1;
16213
- } break;
16214
- case GGML_OP_SET:
16215
- case GGML_OP_CONT:
16216
- case GGML_OP_RESHAPE:
16217
- case GGML_OP_VIEW:
16218
- case GGML_OP_PERMUTE:
16219
- case GGML_OP_TRANSPOSE:
16220
- case GGML_OP_GET_ROWS:
16221
- case GGML_OP_GET_ROWS_BACK:
16222
- case GGML_OP_DIAG:
16223
- case GGML_OP_DIAG_MASK_ZERO:
16224
- {
16225
- node->n_tasks = 1;
16226
- } break;
16227
- case GGML_OP_DIAG_MASK_INF:
16228
- case GGML_OP_SOFT_MAX:
16229
- case GGML_OP_SOFT_MAX_BACK:
16230
- case GGML_OP_ROPE:
16231
- case GGML_OP_ROPE_BACK:
16232
- {
16233
- node->n_tasks = n_threads;
16234
- } break;
16235
- case GGML_OP_ALIBI:
16236
- {
16237
- node->n_tasks = 1; //TODO
16238
- } break;
16239
- case GGML_OP_CLAMP:
16240
- {
16241
- node->n_tasks = 1; //TODO
16242
- } break;
16243
- case GGML_OP_CONV_1D:
16244
- {
16245
- node->n_tasks = n_threads;
16246
-
16247
- GGML_ASSERT(node->src0->ne[3] == 1);
16248
- GGML_ASSERT(node->src1->ne[2] == 1);
16249
- GGML_ASSERT(node->src1->ne[3] == 1);
16250
-
16251
- size_t cur = 0;
16252
- const int nk = node->src0->ne[0];
16253
-
16254
- if (node->src0->type == GGML_TYPE_F16 &&
16255
- node->src1->type == GGML_TYPE_F32) {
16256
- cur = sizeof(ggml_fp16_t)*(
16257
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
16258
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
16259
- );
16260
- } else if (node->src0->type == GGML_TYPE_F32 &&
16261
- node->src1->type == GGML_TYPE_F32) {
16262
- cur = sizeof(float)*(
16263
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
16264
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
16265
- );
16266
- } else {
16267
- GGML_ASSERT(false);
16268
- }
16501
+ work_size = MAX(work_size, cur);
16502
+ } break;
16503
+ case GGML_OP_SCALE:
16504
+ {
16505
+ n_tasks = 1;
16506
+ } break;
16507
+ case GGML_OP_SET:
16508
+ case GGML_OP_CONT:
16509
+ case GGML_OP_RESHAPE:
16510
+ case GGML_OP_VIEW:
16511
+ case GGML_OP_PERMUTE:
16512
+ case GGML_OP_TRANSPOSE:
16513
+ case GGML_OP_GET_ROWS:
16514
+ case GGML_OP_GET_ROWS_BACK:
16515
+ case GGML_OP_DIAG:
16516
+ case GGML_OP_DIAG_MASK_ZERO:
16517
+ {
16518
+ n_tasks = 1;
16519
+ } break;
16520
+ case GGML_OP_DIAG_MASK_INF:
16521
+ case GGML_OP_SOFT_MAX:
16522
+ case GGML_OP_SOFT_MAX_BACK:
16523
+ case GGML_OP_ROPE:
16524
+ case GGML_OP_ROPE_BACK:
16525
+ {
16526
+ n_tasks = n_threads;
16527
+ } break;
16528
+ case GGML_OP_ALIBI:
16529
+ {
16530
+ n_tasks = 1; //TODO
16531
+ } break;
16532
+ case GGML_OP_CLAMP:
16533
+ {
16534
+ n_tasks = 1; //TODO
16535
+ } break;
16536
+ case GGML_OP_CONV_1D:
16537
+ {
16538
+ n_tasks = n_threads;
16539
+
16540
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
16541
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
16542
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
16543
+
16544
+ size_t cur = 0;
16545
+ const int nk = node->src[0]->ne[0];
16546
+
16547
+ if (node->src[0]->type == GGML_TYPE_F16 &&
16548
+ node->src[1]->type == GGML_TYPE_F32) {
16549
+ cur = sizeof(ggml_fp16_t)*(
16550
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
16551
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
16552
+ );
16553
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
16554
+ node->src[1]->type == GGML_TYPE_F32) {
16555
+ cur = sizeof(float)*(
16556
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
16557
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
16558
+ );
16559
+ } else {
16560
+ GGML_ASSERT(false);
16561
+ }
16269
16562
 
16270
- work_size = MAX(work_size, cur);
16271
- } break;
16272
- case GGML_OP_CONV_2D:
16273
- {
16274
- node->n_tasks = n_threads;
16563
+ work_size = MAX(work_size, cur);
16564
+ } break;
16565
+ case GGML_OP_CONV_2D:
16566
+ {
16567
+ n_tasks = n_threads;
16275
16568
 
16276
- GGML_ASSERT(node->src1->ne[3] == 1);
16569
+ const int64_t ne00 = node->src[0]->ne[0]; // W
16570
+ const int64_t ne01 = node->src[0]->ne[1]; // H
16571
+ const int64_t ne02 = node->src[0]->ne[2]; // C
16572
+ const int64_t ne03 = node->src[0]->ne[3]; // N
16277
16573
 
16278
- const int64_t ne00 = node->src0->ne[0]; // W
16279
- const int64_t ne01 = node->src0->ne[1]; // H
16280
- const int64_t ne02 = node->src0->ne[2]; // C
16281
- const int64_t ne03 = node->src0->ne[3]; // N
16574
+ const int64_t ne10 = node->src[1]->ne[0]; // W
16575
+ const int64_t ne11 = node->src[1]->ne[1]; // H
16576
+ const int64_t ne12 = node->src[1]->ne[2]; // C
16282
16577
 
16283
- const int64_t ne10 = node->src1->ne[0]; // W
16284
- const int64_t ne11 = node->src1->ne[1]; // H
16285
- const int64_t ne12 = node->src1->ne[2]; // C
16578
+ const int64_t nk = ne00*ne01;
16286
16579
 
16287
- const int64_t nk = ne00*ne01;
16580
+ UNUSED(ne02);
16581
+ UNUSED(ne03);
16582
+ UNUSED(nk);
16288
16583
 
16289
- UNUSED(ne02);
16290
- UNUSED(ne03);
16291
- UNUSED(nk);
16584
+ size_t cur = 0;
16292
16585
 
16293
- size_t cur = 0;
16586
+ if (node->src[0]->type == GGML_TYPE_F16 &&
16587
+ node->src[1]->type == GGML_TYPE_F32) {
16588
+ cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
16589
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
16590
+ node->src[1]->type == GGML_TYPE_F32) {
16591
+ cur = sizeof(float)* (ne10*ne11*ne12);
16592
+ } else {
16593
+ GGML_ASSERT(false);
16594
+ }
16294
16595
 
16295
- if (node->src0->type == GGML_TYPE_F16 &&
16296
- node->src1->type == GGML_TYPE_F32) {
16297
- cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
16298
- } else if (node->src0->type == GGML_TYPE_F32 &&
16299
- node->src1->type == GGML_TYPE_F32) {
16300
- cur = sizeof(float)* (ne10*ne11*ne12);
16301
- } else {
16302
- GGML_ASSERT(false);
16303
- }
16596
+ work_size = MAX(work_size, cur);
16597
+ } break;
16598
+ case GGML_OP_POOL_1D:
16599
+ case GGML_OP_POOL_2D:
16600
+ {
16601
+ n_tasks = 1;
16602
+ } break;
16603
+ case GGML_OP_FLASH_ATTN:
16604
+ {
16605
+ n_tasks = n_threads;
16304
16606
 
16305
- work_size = MAX(work_size, cur);
16306
- } break;
16307
- case GGML_OP_FLASH_ATTN:
16308
- {
16309
- node->n_tasks = n_threads;
16607
+ size_t cur = 0;
16310
16608
 
16311
- size_t cur = 0;
16609
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
16312
16610
 
16313
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
16611
+ if (node->src[1]->type == GGML_TYPE_F32) {
16612
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
16613
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
16614
+ }
16314
16615
 
16315
- if (node->src1->type == GGML_TYPE_F32) {
16316
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
16317
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
16318
- }
16616
+ if (node->src[1]->type == GGML_TYPE_F16) {
16617
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
16618
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
16619
+ }
16319
16620
 
16320
- if (node->src1->type == GGML_TYPE_F16) {
16321
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
16322
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
16323
- }
16621
+ work_size = MAX(work_size, cur);
16622
+ } break;
16623
+ case GGML_OP_FLASH_FF:
16624
+ {
16625
+ n_tasks = n_threads;
16324
16626
 
16325
- work_size = MAX(work_size, cur);
16326
- } break;
16327
- case GGML_OP_FLASH_FF:
16328
- {
16329
- node->n_tasks = n_threads;
16627
+ size_t cur = 0;
16330
16628
 
16331
- size_t cur = 0;
16629
+ if (node->src[1]->type == GGML_TYPE_F32) {
16630
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16631
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
16632
+ }
16332
16633
 
16333
- if (node->src1->type == GGML_TYPE_F32) {
16334
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
16335
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
16336
- }
16634
+ if (node->src[1]->type == GGML_TYPE_F16) {
16635
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16636
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
16637
+ }
16337
16638
 
16338
- if (node->src1->type == GGML_TYPE_F16) {
16339
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
16340
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
16341
- }
16639
+ work_size = MAX(work_size, cur);
16640
+ } break;
16641
+ case GGML_OP_FLASH_ATTN_BACK:
16642
+ {
16643
+ n_tasks = n_threads;
16342
16644
 
16343
- work_size = MAX(work_size, cur);
16344
- } break;
16345
- case GGML_OP_FLASH_ATTN_BACK:
16346
- {
16347
- node->n_tasks = n_threads;
16645
+ size_t cur = 0;
16348
16646
 
16349
- size_t cur = 0;
16647
+ const int64_t D = node->src[0]->ne[0];
16648
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
16649
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
16650
+ if (node->src[1]->type == GGML_TYPE_F32) {
16651
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
16652
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
16653
+ }
16350
16654
 
16351
- const int64_t D = node->src0->ne[0];
16352
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
16353
- const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
16354
- if (node->src1->type == GGML_TYPE_F32) {
16355
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16356
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16357
- }
16655
+ if (node->src[1]->type == GGML_TYPE_F16) {
16656
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
16657
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
16658
+ }
16358
16659
 
16359
- if (node->src1->type == GGML_TYPE_F16) {
16360
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16361
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16362
- }
16660
+ work_size = MAX(work_size, cur);
16661
+ } break;
16662
+ case GGML_OP_WIN_PART:
16663
+ case GGML_OP_WIN_UNPART:
16664
+ case GGML_OP_MAP_UNARY:
16665
+ case GGML_OP_MAP_BINARY:
16666
+ case GGML_OP_MAP_CUSTOM1:
16667
+ case GGML_OP_MAP_CUSTOM2:
16668
+ case GGML_OP_MAP_CUSTOM3:
16669
+ {
16670
+ n_tasks = 1;
16671
+ } break;
16672
+ case GGML_OP_CROSS_ENTROPY_LOSS:
16673
+ {
16674
+ n_tasks = n_threads;
16363
16675
 
16364
- work_size = MAX(work_size, cur);
16365
- } break;
16366
- case GGML_OP_WIN_PART:
16367
- case GGML_OP_WIN_UNPART:
16368
- case GGML_OP_MAP_UNARY:
16369
- case GGML_OP_MAP_BINARY:
16370
- case GGML_OP_MAP_CUSTOM1:
16371
- case GGML_OP_MAP_CUSTOM2:
16372
- case GGML_OP_MAP_CUSTOM3:
16373
- {
16374
- node->n_tasks = 1;
16375
- } break;
16376
- case GGML_OP_CROSS_ENTROPY_LOSS:
16377
- {
16378
- node->n_tasks = n_threads;
16379
-
16380
- size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
16381
-
16382
- work_size = MAX(work_size, cur);
16383
- } break;
16384
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
16385
- {
16386
- node->n_tasks = n_threads;
16387
-
16388
- size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
16389
-
16390
- work_size = MAX(work_size, cur);
16391
- } break;
16392
- case GGML_OP_NONE:
16393
- {
16394
- node->n_tasks = 1;
16395
- } break;
16396
- case GGML_OP_COUNT:
16397
- {
16398
- GGML_ASSERT(false);
16399
- } break;
16400
- }
16401
- }
16676
+ size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
16677
+
16678
+ work_size = MAX(work_size, cur);
16679
+ } break;
16680
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
16681
+ {
16682
+ n_tasks = n_threads;
16683
+
16684
+ size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;
16402
16685
 
16403
- if (cgraph->work != NULL && work_size > cgraph->work_size) {
16404
- GGML_ASSERT(false); // TODO: better handling
16686
+ work_size = MAX(work_size, cur);
16687
+ } break;
16688
+ case GGML_OP_NONE:
16689
+ {
16690
+ n_tasks = 1;
16691
+ } break;
16692
+ case GGML_OP_COUNT:
16693
+ {
16694
+ GGML_ASSERT(false);
16695
+ } break;
16405
16696
  }
16406
16697
 
16407
- if (work_size > 0 && cgraph->work == NULL) {
16408
- cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
16698
+ cplan.n_tasks[i] = n_tasks;
16699
+ }
16700
+
16701
+ if (work_size > 0) {
16702
+ work_size += CACHE_LINE_SIZE*(n_threads - 1);
16703
+ }
16704
+
16705
+ cplan.n_threads = n_threads;
16706
+ cplan.work_size = work_size;
16707
+ cplan.work_data = NULL;
16708
+
16709
+ return cplan;
16710
+ }
16711
+
16712
+ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
16713
+ {
16714
+ GGML_ASSERT(cplan);
16715
+ GGML_ASSERT(cplan->n_threads > 0);
16409
16716
 
16410
- GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
16411
- cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
16717
+ if (cplan->work_size > 0) {
16718
+ GGML_ASSERT(cplan->work_data);
16719
+ }
16720
+
16721
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
16722
+ if (cgraph->nodes[i]->op != GGML_OP_NONE) {
16723
+ GGML_ASSERT(cplan->n_tasks[i] > 0);
16724
+ }
16412
16725
  }
16413
16726
  }
16414
16727
 
16728
+ const int n_threads = cplan->n_threads;
16729
+
16730
+ struct ggml_compute_state_shared state_shared = {
16731
+ /*.cgraph =*/ cgraph,
16732
+ /*.cgraph_plan =*/ cplan,
16733
+ /*.perf_node_start_cycles =*/ 0,
16734
+ /*.perf_node_start_time_us =*/ 0,
16735
+ /*.n_threads =*/ n_threads,
16736
+ /*.n_active =*/ n_threads,
16737
+ /*.node_n =*/ -1,
16738
+ /*.abort_callback =*/ NULL,
16739
+ /*.abort_callback_data =*/ NULL,
16740
+ };
16741
+ struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
16742
+
16415
16743
  // create thread pool
16416
16744
  if (n_threads > 1) {
16417
16745
  for (int j = 1; j < n_threads; ++j) {
@@ -16432,12 +16760,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
16432
16760
  const int64_t perf_start_time_us = ggml_perf_time_us();
16433
16761
 
16434
16762
  // this is a work thread too
16435
- ggml_graph_compute_thread(&workers[0]);
16763
+ int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
16436
16764
 
16437
16765
  // don't leave affinity set on the main thread
16438
16766
  clear_numa_thread_affinity();
16439
16767
 
16440
- // join thread pool
16768
+ // join or kill thread pool
16441
16769
  if (n_threads > 1) {
16442
16770
  for (int j = 1; j < n_threads; j++) {
16443
16771
  const int rc = ggml_thread_join(workers[j].thrd, NULL);
@@ -16461,6 +16789,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
16461
16789
  (double) perf_time_us_cur / 1000.0,
16462
16790
  (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
16463
16791
  }
16792
+
16793
+ return compute_status;
16464
16794
  }
16465
16795
 
16466
16796
  void ggml_graph_reset(struct ggml_cgraph * cgraph) {
@@ -16473,6 +16803,17 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
16473
16803
  }
16474
16804
  }
16475
16805
 
16806
+ void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
16807
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
16808
+
16809
+ struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
16810
+ GGML_ASSERT(buf);
16811
+
16812
+ cplan.work_data = buf->data;
16813
+
16814
+ ggml_graph_compute(cgraph, &cplan);
16815
+ }
16816
+
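With ggml_graph_compute_with_ctx covering the old convenience path, the explicit flow introduced here is: build a plan, provide a work buffer of cplan.work_size bytes, then run the graph and check the returned status. A minimal sketch of that flow (assuming a graph gf already built, e.g. via ggml_build_forward_expand, and that a plain heap buffer is acceptable for work_data):

    #include <stdlib.h>
    #include "ggml.h"

    static int compute_graph(struct ggml_cgraph * gf, int n_threads) {
        struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);

        void * work = NULL;
        if (cplan.work_size > 0) {
            work = malloc(cplan.work_size);       /* must stay alive during compute */
            if (work == NULL) {
                return -1;
            }
            cplan.work_data = work;
        }

        const int status = ggml_graph_compute(gf, &cplan);   /* GGML_EXIT_SUCCESS or GGML_EXIT_ABORTED */

        free(work);
        return status;
    }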
16476
16817
  struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
16477
16818
  for (int i = 0; i < cgraph->n_leafs; i++) {
16478
16819
  struct ggml_tensor * leaf = cgraph->leafs[i];
@@ -16511,14 +16852,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
16511
16852
  const int64_t * ne = tensor->ne;
16512
16853
  const size_t * nb = tensor->nb;
16513
16854
 
16514
- fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
16855
+ fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
16515
16856
  arg,
16516
16857
  ggml_type_name(tensor->type),
16517
16858
  ggml_op_name (tensor->op),
16518
16859
  tensor->n_dims,
16519
16860
  ne[0], ne[1], ne[2], ne[3],
16520
16861
  nb[0], nb[1], nb[2], nb[3],
16521
- tensor->n_tasks,
16522
16862
  tensor->data,
16523
16863
  tensor->name);
16524
16864
  }
@@ -16555,8 +16895,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
16555
16895
  ggml_graph_export_leaf(cgraph->leafs[i], fout);
16556
16896
 
16557
16897
  GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE);
16558
- GGML_ASSERT(cgraph->leafs[i]->src0 == NULL);
16559
- GGML_ASSERT(cgraph->leafs[i]->src1 == NULL);
16898
+ GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL);
16899
+ GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL);
16560
16900
  }
16561
16901
 
16562
16902
  // header
@@ -16567,17 +16907,9 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
  for (int i = 0; i < cgraph->n_nodes; ++i) {
  ggml_graph_export_node(cgraph->nodes[i], "DST", fout);

- if (cgraph->nodes[i]->src0) {
- ggml_graph_export_node(cgraph->nodes[i]->src0, "SRC0", fout);
- }
-
- if (cgraph->nodes[i]->src1) {
- ggml_graph_export_node(cgraph->nodes[i]->src1, "SRC1", fout);
- }
-
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- if (cgraph->nodes[i]->opt[j]) {
- ggml_graph_export_node(cgraph->nodes[i]->opt[j], "OPT", fout);
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ if (cgraph->nodes[i]->src[j]) {
+ ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout);
  }
  }

@@ -16668,16 +17000,13 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {

  // output the op arguments
  {
- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
-
- args[0] = tensor->src0;
- args[1] = tensor->src1;
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };

- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- args[2 + j] = tensor->opt[j];
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ args[j] = tensor->src[j];
  }

- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
  if (args[j]) {
  int32_t idx = -1;

@@ -16895,12 +17224,12 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **

  const char * ptr_name = ptr; ptr += GGML_MAX_NAME;

- const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
+ const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t);

- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };

  // parse args
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
  const int32_t arg_idx = ptr_arg_idx[j];

  if (arg_idx == -1) {
@@ -16957,11 +17286,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
  tensor->nb[j] = nb[j];
  }

- tensor->src0 = args[0];
- tensor->src1 = args[1];
-
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- tensor->opt[j] = args[2 + j];
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ tensor->src[j] = args[j];
  }

  result.nodes[i] = tensor;
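With both the exporter and the importer switched to a GGML_MAX_SRC-entry argument-index table, serialized graphs round-trip through the same src[] layout. Below is a hedged usage sketch, not code from this diff: the file name is arbitrary, and the names of the two context out-parameters (ctx_data, ctx_eval) are assumptions based on the truncated ggml_graph_import signature visible in the hunk headers.

    #include "ggml.h"

    // Export a graph and read it back; both sides now serialize one
    // argument index per src[] slot.
    static struct ggml_cgraph export_import_roundtrip(const struct ggml_cgraph * gf,
                                                      struct ggml_context ** ctx_data,
                                                      struct ggml_context ** ctx_eval) {
        ggml_graph_export(gf, "graph.ggml");

        // the contexts returned through the out-parameters own the imported
        // tensors and must outlive any use of the returned graph
        return ggml_graph_import("graph.ggml", ctx_data, ctx_eval);
    }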
@@ -17160,19 +17486,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
  for (int i = 0; i < gb->n_nodes; i++) {
  struct ggml_tensor * node = gb->nodes[i];

- if (node->src0) {
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src0, "x");
- }
-
- if (node->src1) {
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src1, "y");
- }
-
- for (int j = 0; j < GGML_MAX_OPT; j++) {
- if (node->opt[j]) {
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j]) {
  char label[16];
- snprintf(label, sizeof(label), "opt %d", j);
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->opt[j], label);
+ snprintf(label, sizeof(label), "src %d", j);
+ ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
  }
  }
  }
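The export, import, and DOT-dump hunks above all converge on the same pattern: instead of special-casing src0, src1, and opt[], consumers walk a single src[GGML_MAX_SRC] array. A minimal sketch of that loop in isolation follows; the printf formatting is illustrative and not taken from this diff.

    #include <stdio.h>
    #include "ggml.h"

    // Walk every input of a node through the unified src[] array.
    static void print_node_inputs(const struct ggml_tensor * node) {
        for (int j = 0; j < GGML_MAX_SRC; ++j) {
            const struct ggml_tensor * src = node->src[j];
            if (src == NULL) {
                continue; // unused source slot
            }
            printf("src %d: %s (%s)\n", j, src->name, ggml_op_name(src->op));
        }
    }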
@@ -17180,19 +17498,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
  for (int i = 0; i < gb->n_leafs; i++) {
  struct ggml_tensor * node = gb->leafs[i];

- if (node->src0) {
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src0, "x");
- }
-
- if (node->src1) {
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src1, "y");
- }
-
- for (int j = 0; j < GGML_MAX_OPT; j++) {
- if (node->opt[j]) {
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j]) {
  char label[16];
- snprintf(label, sizeof(label), "opt %d", j);
- ggml_graph_dump_dot_leaf_edge(fp, node, node->opt[j], label);
+ snprintf(label, sizeof(label), "src %d", j);
+ ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
  }
  }
  }
@@ -17254,9 +17564,6 @@ static enum ggml_opt_result ggml_opt_adam(
  struct ggml_cgraph * gb) {
  GGML_ASSERT(ggml_is_scalar(f));

- gf->n_threads = params.n_threads;
- gb->n_threads = params.n_threads;
-
  // these will store the parameters we want to optimize
  struct ggml_tensor * ps[GGML_MAX_PARAMS];

@@ -17303,7 +17610,8 @@ static enum ggml_opt_result ggml_opt_adam(
  // compute the function value
  ggml_graph_reset (gf);
  ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

  opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
  opt->adam.fx_best = opt->adam.fx_prev;
@@ -17383,7 +17691,8 @@ static enum ggml_opt_result ggml_opt_adam(

  ggml_graph_reset (gf);
  ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

  const float fx = ggml_get_f32_1d(f, 0);

@@ -17505,7 +17814,8 @@ static enum ggml_opt_result linesearch_backtracking(

  ggml_graph_reset (gf);
  ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);

  ggml_opt_get_grad(np, ps, g);

@@ -17573,9 +17883,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
  }
  }

- gf->n_threads = params.n_threads;
- gb->n_threads = params.n_threads;
-
  const int m = params.lbfgs.m;

  // these will store the parameters we want to optimize
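Both optimizers stop writing params.n_threads into gf->n_threads / gb->n_threads; the thread count now travels with each compute call through the plan built inside ggml_graph_compute_with_ctx. One practical consequence, sketched below with arbitrary thread counts, is that the same graph can be evaluated with different thread counts without mutating the graph object.

    #include "ggml.h"

    // Evaluate the same graph twice with different thread counts; the
    // thread count is a per-call argument rather than graph state.
    static void evaluate_twice(struct ggml_context * ctx, struct ggml_cgraph * gb) {
        ggml_graph_compute_with_ctx(ctx, gb, 1); // single-threaded pass
        ggml_graph_compute_with_ctx(ctx, gb, 8); // eight worker threads
    }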
@@ -17627,7 +17934,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(

  ggml_graph_reset (gf);
  ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

  ggml_opt_get_grad(np, ps, g);