llama_cpp 0.3.2 → 0.3.4

@@ -25,16 +25,23 @@
  #include <float.h>
  #include <limits.h>
  #include <stdarg.h>
+ #include <signal.h>
 
  #ifdef GGML_USE_METAL
  #include <unistd.h>
  #endif
 
+ // static_assert should be a #define, but if it's not,
+ // fall back to the _Static_assert C11 keyword.
  // if C99 - static_assert is noop
  // ref: https://stackoverflow.com/a/53923785/4039976
  #ifndef static_assert
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+ #define static_assert(cond, msg) _Static_assert(cond, msg)
+ #else
  #define static_assert(cond, msg) struct global_scope_noop_trick
  #endif
+ #endif
 
  #if defined(_MSC_VER)
  // disable "possible loss of data" to avoid hundreds of casts
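Note: the new guard prefers a real compile-time check before degrading to the old no-op. A minimal standalone sketch of that fallback chain (illustrative only, not part of the gem):

/* prefer the toolchain's static_assert macro, then the C11 keyword, then a no-op */
#include <assert.h>

#ifndef static_assert
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
#define static_assert(cond, msg) _Static_assert(cond, msg)
#else
#define static_assert(cond, msg) struct global_scope_noop_trick
#endif
#endif

static_assert(sizeof(int) >= 2, "int must be at least 16 bits");

int main(void) { return 0; }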
@@ -49,23 +56,23 @@
49
56
  typedef volatile LONG atomic_int;
50
57
  typedef atomic_int atomic_bool;
51
58
 
52
- static void atomic_store(atomic_int* ptr, LONG val) {
59
+ static void atomic_store(atomic_int * ptr, LONG val) {
53
60
  InterlockedExchange(ptr, val);
54
61
  }
55
- static LONG atomic_load(atomic_int* ptr) {
62
+ static LONG atomic_load(atomic_int * ptr) {
56
63
  return InterlockedCompareExchange(ptr, 0, 0);
57
64
  }
58
- static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
65
+ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
59
66
  return InterlockedExchangeAdd(ptr, inc);
60
67
  }
61
- static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
68
+ static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
62
69
  return atomic_fetch_add(ptr, -(dec));
63
70
  }
64
71
 
65
72
  typedef HANDLE pthread_t;
66
73
 
67
74
  typedef DWORD thread_ret_t;
68
- static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
75
+ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
69
76
  (void) unused;
70
77
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
71
78
  if (handle == NULL)
@@ -77,7 +84,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
77
84
  return 0;
78
85
  }
79
86
 
80
- static int pthread_join(pthread_t thread, void* unused) {
87
+ static int pthread_join(pthread_t thread, void * unused) {
81
88
  (void) unused;
82
89
  return (int) WaitForSingleObject(thread, INFINITE);
83
90
  }
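Note: the changes above are style-only (pointer spacing) in the Win32 shims that stand in for pthreads and C11 atomics. As a rough orientation, this is the portable pattern they emulate (a sketch assuming a POSIX toolchain, not part of the gem):

#include <pthread.h>
#include <stdatomic.h>
#include <stdio.h>

static atomic_int counter;

static void * worker(void * arg) {
    (void) arg;
    atomic_fetch_add(&counter, 1);   // InterlockedExchangeAdd on the Win32 path
    return NULL;
}

int main(void) {
    pthread_t t;
    pthread_create(&t, NULL, worker, NULL);   // CreateThread on the Win32 path
    pthread_join(t, NULL);                    // WaitForSingleObject on the Win32 path
    printf("counter = %d\n", atomic_load(&counter));
    return 0;
}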
@@ -90,7 +97,7 @@ static int sched_yield (void) {
90
97
  #include <pthread.h>
91
98
  #include <stdatomic.h>
92
99
 
93
- typedef void* thread_ret_t;
100
+ typedef void * thread_ret_t;
94
101
 
95
102
  #include <sys/types.h>
96
103
  #include <sys/stat.h>
@@ -111,10 +118,6 @@ typedef void* thread_ret_t;
111
118
  #endif
112
119
  #endif
113
120
 
114
- #ifdef __HAIKU__
115
- #define static_assert(cond, msg) _Static_assert(cond, msg)
116
- #endif
117
-
118
121
  /*#define GGML_PERF*/
119
122
  #define GGML_DEBUG 0
120
123
  #define GGML_GELU_FP16
@@ -247,7 +250,11 @@ inline static void* ggml_aligned_malloc(size_t size) {
  #include "ggml-opencl.h"
  #endif
  #elif defined(GGML_USE_OPENBLAS)
+ #if defined(GGML_BLAS_USE_MKL)
+ #include <mkl.h>
+ #else
  #include <cblas.h>
+ #endif
  #elif defined(GGML_USE_CUBLAS)
  #include "ggml-cuda.h"
  #elif defined(GGML_USE_CLBLAST)
@@ -3782,6 +3789,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3782
3789
  "CLAMP",
3783
3790
  "CONV_1D",
3784
3791
  "CONV_2D",
3792
+ "POOL_1D",
3793
+ "POOL_2D",
3785
3794
 
3786
3795
  "FLASH_ATTN",
3787
3796
  "FLASH_FF",
@@ -3800,7 +3809,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
3800
3809
  "CROSS_ENTROPY_LOSS_BACK",
3801
3810
  };
3802
3811
 
3803
- static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
3812
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
3804
3813
 
3805
3814
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3806
3815
  "none",
@@ -3860,6 +3869,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3860
3869
  "clamp(x)",
3861
3870
  "conv_1d(x)",
3862
3871
  "conv_2d(x)",
3872
+ "pool_1d(x)",
3873
+ "pool_2d(x)",
3863
3874
 
3864
3875
  "flash_attn(x)",
3865
3876
  "flash_ff(x)",
@@ -3878,7 +3889,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
3878
3889
  "cross_entropy_loss_back(x,y)",
3879
3890
  };
3880
3891
 
3881
- static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
3892
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
3893
+
3894
+ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
3882
3895
 
3883
3896
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
3884
3897
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4157,10 +4170,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
  static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
- return
- (t0->ne[0] == t1->ne[0]) &&
- (t0->ne[2] == t1->ne[2]) &&
- (t0->ne[3] == t1->ne[3]);
+ return (t0->ne[0] == t1->ne[0]) &&
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+ (t1->ne[3]%t0->ne[3] == 0);
  }
 
  static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
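Note: ggml_can_mul_mat now only requires that src1's 2nd and 3rd dimensions be multiples of src0's, so one weight matrix can be broadcast over a batched input. A standalone sketch of the relaxed rule (illustrative shapes, not from the gem):

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

static bool can_mul_mat(const int64_t t0_ne[4], const int64_t t1_ne[4]) {
    return (t0_ne[0] == t1_ne[0]) &&
           (t1_ne[2] % t0_ne[2] == 0) &&   // t0 repeats along dim 2
           (t1_ne[3] % t0_ne[3] == 0);     // t0 repeats along dim 3
}

int main(void) {
    const int64_t w[4] = { 64, 64, 1, 1 };   // single weight matrix
    const int64_t x[4] = { 64, 32, 8, 1 };   // 8 batch slices share the weight
    printf("%s\n", can_mul_mat(w, x) ? "compatible" : "incompatible");
    return 0;
}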
@@ -4400,8 +4412,8 @@ void ggml_free(struct ggml_context * ctx) {
4400
4412
  if (&g_state.contexts[i].context == ctx) {
4401
4413
  g_state.contexts[i].used = false;
4402
4414
 
4403
- GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
4404
- __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
4415
+ GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n",
4416
+ __func__, i, ggml_used_mem(ctx));
4405
4417
 
4406
4418
  if (ctx->mem_buffer_owned) {
4407
4419
  GGML_ALIGNED_FREE(ctx->mem_buffer);
@@ -4580,17 +4592,14 @@ struct ggml_tensor * ggml_new_tensor_impl(
4580
4592
  /*.op =*/ GGML_OP_NONE,
4581
4593
  /*.is_param =*/ false,
4582
4594
  /*.grad =*/ NULL,
4583
- /*.src0 =*/ NULL,
4584
- /*.src1 =*/ NULL,
4585
- /*.opt =*/ { NULL },
4586
- /*.n_tasks =*/ 0,
4595
+ /*.src =*/ { NULL },
4587
4596
  /*.perf_runs =*/ 0,
4588
4597
  /*.perf_cycles =*/ 0,
4589
4598
  /*.perf_time_us =*/ 0,
4590
4599
  /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
4591
4600
  /*.name =*/ { 0 },
4592
4601
  /*.extra =*/ NULL,
4593
- /*.pad =*/ { 0 },
4602
+ /*.padding =*/ { 0 },
4594
4603
  };
4595
4604
 
4596
4605
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
@@ -4722,7 +4731,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
4722
4731
  {
4723
4732
  assert(tensor->nb[0] == sizeof(ggml_fp16_t));
4724
4733
  for (int i = 0; i < n; i++) {
4725
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
4734
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
4726
4735
  }
4727
4736
  } break;
4728
4737
  case GGML_TYPE_F32:
@@ -4774,7 +4783,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
4774
4783
  {
4775
4784
  assert(tensor->nb[0] == sizeof(ggml_fp16_t));
4776
4785
  for (int i = 0; i < n; i++) {
4777
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
4786
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
4778
4787
  }
4779
4788
  } break;
4780
4789
  case GGML_TYPE_F32:
@@ -5009,8 +5018,8 @@ struct ggml_tensor * ggml_dup_impl(
5009
5018
 
5010
5019
  result->op = GGML_OP_DUP;
5011
5020
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5012
- result->src0 = a;
5013
- result->src1 = NULL;
5021
+ result->src[0] = a;
5022
+ result->src[1] = NULL;
5014
5023
 
5015
5024
  return result;
5016
5025
  }
@@ -5034,11 +5043,15 @@ struct ggml_tensor * ggml_add_impl(
5034
5043
  struct ggml_tensor * a,
5035
5044
  struct ggml_tensor * b,
5036
5045
  bool inplace) {
5037
- GGML_ASSERT(ggml_are_same_shape(a, b));
5046
+ // TODO: support less-strict constraint
5047
+ // GGML_ASSERT(ggml_can_repeat(b, a));
5048
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
5038
5049
 
5039
5050
  bool is_node = false;
5040
5051
 
5041
- if (a->grad || b->grad) {
5052
+ if (!inplace && (a->grad || b->grad)) {
5053
+ // TODO: support backward pass for broadcasting
5054
+ GGML_ASSERT(ggml_are_same_shape(a, b));
5042
5055
  is_node = true;
5043
5056
  }
5044
5057
 
@@ -5046,8 +5059,8 @@ struct ggml_tensor * ggml_add_impl(
5046
5059
 
5047
5060
  result->op = GGML_OP_ADD;
5048
5061
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5049
- result->src0 = a;
5050
- result->src1 = b;
5062
+ result->src[0] = a;
5063
+ result->src[1] = b;
5051
5064
 
5052
5065
  return result;
5053
5066
  }
@@ -5086,8 +5099,8 @@ struct ggml_tensor * ggml_add1_impl(
5086
5099
 
5087
5100
  result->op = GGML_OP_ADD1;
5088
5101
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5089
- result->src0 = a;
5090
- result->src1 = b;
5102
+ result->src[0] = a;
5103
+ result->src[1] = b;
5091
5104
 
5092
5105
  return result;
5093
5106
  }
@@ -5144,9 +5157,9 @@ struct ggml_tensor * ggml_acc_impl(
5144
5157
 
5145
5158
  result->op = GGML_OP_ACC;
5146
5159
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5147
- result->src0 = a;
5148
- result->src1 = b;
5149
- result->opt[0] = c;
5160
+ result->src[0] = a;
5161
+ result->src[1] = b;
5162
+ result->src[2] = c;
5150
5163
 
5151
5164
  return result;
5152
5165
  }
@@ -5192,8 +5205,8 @@ struct ggml_tensor * ggml_sub_impl(
5192
5205
 
5193
5206
  result->op = GGML_OP_SUB;
5194
5207
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5195
- result->src0 = a;
5196
- result->src1 = b;
5208
+ result->src[0] = a;
5209
+ result->src[1] = b;
5197
5210
 
5198
5211
  return result;
5199
5212
  }
@@ -5239,8 +5252,8 @@ struct ggml_tensor * ggml_mul_impl(
5239
5252
 
5240
5253
  result->op = GGML_OP_MUL;
5241
5254
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5242
- result->src0 = a;
5243
- result->src1 = b;
5255
+ result->src[0] = a;
5256
+ result->src[1] = b;
5244
5257
 
5245
5258
  return result;
5246
5259
  }
@@ -5282,8 +5295,8 @@ struct ggml_tensor * ggml_div_impl(
5282
5295
 
5283
5296
  result->op = GGML_OP_DIV;
5284
5297
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5285
- result->src0 = a;
5286
- result->src1 = b;
5298
+ result->src[0] = a;
5299
+ result->src[1] = b;
5287
5300
 
5288
5301
  return result;
5289
5302
  }
@@ -5318,8 +5331,8 @@ struct ggml_tensor * ggml_sqr_impl(
5318
5331
 
5319
5332
  result->op = GGML_OP_SQR;
5320
5333
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5321
- result->src0 = a;
5322
- result->src1 = NULL;
5334
+ result->src[0] = a;
5335
+ result->src[1] = NULL;
5323
5336
 
5324
5337
  return result;
5325
5338
  }
@@ -5352,8 +5365,8 @@ struct ggml_tensor * ggml_sqrt_impl(
5352
5365
 
5353
5366
  result->op = GGML_OP_SQRT;
5354
5367
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5355
- result->src0 = a;
5356
- result->src1 = NULL;
5368
+ result->src[0] = a;
5369
+ result->src[1] = NULL;
5357
5370
 
5358
5371
  return result;
5359
5372
  }
@@ -5387,8 +5400,8 @@ struct ggml_tensor * ggml_log_impl(
5387
5400
 
5388
5401
  result->op = GGML_OP_LOG;
5389
5402
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5390
- result->src0 = a;
5391
- result->src1 = NULL;
5403
+ result->src[0] = a;
5404
+ result->src[1] = NULL;
5392
5405
 
5393
5406
  return result;
5394
5407
  }
@@ -5420,8 +5433,8 @@ struct ggml_tensor * ggml_sum(
5420
5433
 
5421
5434
  result->op = GGML_OP_SUM;
5422
5435
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5423
- result->src0 = a;
5424
- result->src1 = NULL;
5436
+ result->src[0] = a;
5437
+ result->src[1] = NULL;
5425
5438
 
5426
5439
  return result;
5427
5440
  }
@@ -5447,8 +5460,8 @@ struct ggml_tensor * ggml_sum_rows(
5447
5460
 
5448
5461
  result->op = GGML_OP_SUM_ROWS;
5449
5462
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5450
- result->src0 = a;
5451
- result->src1 = NULL;
5463
+ result->src[0] = a;
5464
+ result->src[1] = NULL;
5452
5465
 
5453
5466
  return result;
5454
5467
  }
@@ -5470,8 +5483,8 @@ struct ggml_tensor * ggml_mean(
5470
5483
 
5471
5484
  result->op = GGML_OP_MEAN;
5472
5485
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5473
- result->src0 = a;
5474
- result->src1 = NULL;
5486
+ result->src[0] = a;
5487
+ result->src[1] = NULL;
5475
5488
 
5476
5489
  return result;
5477
5490
  }
@@ -5494,8 +5507,8 @@ struct ggml_tensor * ggml_argmax(
5494
5507
 
5495
5508
  result->op = GGML_OP_ARGMAX;
5496
5509
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5497
- result->src0 = a;
5498
- result->src1 = NULL;
5510
+ result->src[0] = a;
5511
+ result->src[1] = NULL;
5499
5512
 
5500
5513
  return result;
5501
5514
  }
@@ -5522,8 +5535,8 @@ struct ggml_tensor * ggml_repeat(
5522
5535
 
5523
5536
  result->op = GGML_OP_REPEAT;
5524
5537
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5525
- result->src0 = a;
5526
- result->src1 = b;
5538
+ result->src[0] = a;
5539
+ result->src[1] = b;
5527
5540
 
5528
5541
  return result;
5529
5542
  }
@@ -5550,8 +5563,8 @@ struct ggml_tensor * ggml_repeat_back(
5550
5563
 
5551
5564
  result->op = GGML_OP_REPEAT_BACK;
5552
5565
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5553
- result->src0 = a;
5554
- result->src1 = b;
5566
+ result->src[0] = a;
5567
+ result->src[1] = b;
5555
5568
 
5556
5569
  return result;
5557
5570
  }
@@ -5572,8 +5585,8 @@ struct ggml_tensor * ggml_abs_impl(
5572
5585
 
5573
5586
  result->op = GGML_OP_ABS;
5574
5587
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5575
- result->src0 = a;
5576
- result->src1 = NULL;
5588
+ result->src[0] = a;
5589
+ result->src[1] = NULL;
5577
5590
 
5578
5591
  return result;
5579
5592
  }
@@ -5607,8 +5620,8 @@ struct ggml_tensor * ggml_sgn_impl(
5607
5620
 
5608
5621
  result->op = GGML_OP_SGN;
5609
5622
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5610
- result->src0 = a;
5611
- result->src1 = NULL;
5623
+ result->src[0] = a;
5624
+ result->src[1] = NULL;
5612
5625
 
5613
5626
  return result;
5614
5627
  }
@@ -5641,8 +5654,8 @@ struct ggml_tensor * ggml_neg_impl(
5641
5654
 
5642
5655
  result->op = GGML_OP_NEG;
5643
5656
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5644
- result->src0 = a;
5645
- result->src1 = NULL;
5657
+ result->src[0] = a;
5658
+ result->src[1] = NULL;
5646
5659
 
5647
5660
  return result;
5648
5661
  }
@@ -5675,8 +5688,8 @@ struct ggml_tensor * ggml_step_impl(
5675
5688
 
5676
5689
  result->op = GGML_OP_STEP;
5677
5690
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5678
- result->src0 = a;
5679
- result->src1 = NULL;
5691
+ result->src[0] = a;
5692
+ result->src[1] = NULL;
5680
5693
 
5681
5694
  return result;
5682
5695
  }
@@ -5709,8 +5722,8 @@ struct ggml_tensor * ggml_tanh_impl(
5709
5722
 
5710
5723
  result->op = GGML_OP_TANH;
5711
5724
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5712
- result->src0 = a;
5713
- result->src1 = NULL;
5725
+ result->src[0] = a;
5726
+ result->src[1] = NULL;
5714
5727
 
5715
5728
  return result;
5716
5729
  }
@@ -5743,8 +5756,8 @@ struct ggml_tensor * ggml_elu_impl(
5743
5756
 
5744
5757
  result->op = GGML_OP_ELU;
5745
5758
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5746
- result->src0 = a;
5747
- result->src1 = NULL;
5759
+ result->src[0] = a;
5760
+ result->src[1] = NULL;
5748
5761
 
5749
5762
  return result;
5750
5763
  }
@@ -5777,8 +5790,8 @@ struct ggml_tensor * ggml_relu_impl(
5777
5790
 
5778
5791
  result->op = GGML_OP_RELU;
5779
5792
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5780
- result->src0 = a;
5781
- result->src1 = NULL;
5793
+ result->src[0] = a;
5794
+ result->src[1] = NULL;
5782
5795
 
5783
5796
  return result;
5784
5797
  }
@@ -5811,8 +5824,8 @@ struct ggml_tensor * ggml_gelu_impl(
5811
5824
 
5812
5825
  result->op = GGML_OP_GELU;
5813
5826
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5814
- result->src0 = a;
5815
- result->src1 = NULL;
5827
+ result->src[0] = a;
5828
+ result->src[1] = NULL;
5816
5829
 
5817
5830
  return result;
5818
5831
  }
@@ -5845,8 +5858,8 @@ struct ggml_tensor * ggml_gelu_quick_impl(
5845
5858
 
5846
5859
  result->op = GGML_OP_GELU_QUICK;
5847
5860
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5848
- result->src0 = a;
5849
- result->src1 = NULL;
5861
+ result->src[0] = a;
5862
+ result->src[1] = NULL;
5850
5863
 
5851
5864
  return result;
5852
5865
  }
@@ -5879,8 +5892,8 @@ struct ggml_tensor * ggml_silu_impl(
5879
5892
 
5880
5893
  result->op = GGML_OP_SILU;
5881
5894
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5882
- result->src0 = a;
5883
- result->src1 = NULL;
5895
+ result->src[0] = a;
5896
+ result->src[1] = NULL;
5884
5897
 
5885
5898
  return result;
5886
5899
  }
@@ -5914,8 +5927,8 @@ struct ggml_tensor * ggml_silu_back(
5914
5927
 
5915
5928
  result->op = GGML_OP_SILU_BACK;
5916
5929
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5917
- result->src0 = a;
5918
- result->src1 = b;
5930
+ result->src[0] = a;
5931
+ result->src[1] = b;
5919
5932
 
5920
5933
  return result;
5921
5934
  }
@@ -5937,8 +5950,8 @@ struct ggml_tensor * ggml_norm_impl(
5937
5950
 
5938
5951
  result->op = GGML_OP_NORM;
5939
5952
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5940
- result->src0 = a;
5941
- result->src1 = NULL; // TODO: maybe store epsilon here?
5953
+ result->src[0] = a;
5954
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
5942
5955
 
5943
5956
  return result;
5944
5957
  }
@@ -5969,8 +5982,8 @@ struct ggml_tensor * ggml_rms_norm_impl(
5969
5982
 
5970
5983
  result->op = GGML_OP_RMS_NORM;
5971
5984
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
5972
- result->src0 = a;
5973
- result->src1 = NULL; // TODO: maybe store epsilon here?
5985
+ result->src[0] = a;
5986
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
5974
5987
 
5975
5988
  return result;
5976
5989
  }
@@ -6002,8 +6015,8 @@ struct ggml_tensor * ggml_rms_norm_back(
6002
6015
 
6003
6016
  result->op = GGML_OP_RMS_NORM_BACK;
6004
6017
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6005
- result->src0 = a;
6006
- result->src1 = b;
6018
+ result->src[0] = a;
6019
+ result->src[1] = b;
6007
6020
 
6008
6021
  return result;
6009
6022
  }
@@ -6024,13 +6037,13 @@ struct ggml_tensor * ggml_mul_mat(
6024
6037
  is_node = true;
6025
6038
  }
6026
6039
 
6027
- const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
6028
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
6040
+ const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
6041
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
6029
6042
 
6030
6043
  result->op = GGML_OP_MUL_MAT;
6031
6044
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6032
- result->src0 = a;
6033
- result->src1 = b;
6045
+ result->src[0] = a;
6046
+ result->src[1] = b;
6034
6047
 
6035
6048
  return result;
6036
6049
  }
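Note: the mul_mat result now takes its 3rd and 4th dimensions from b (src1) and uses MAX of the two ranks, so a 2-D weight times a batched activation yields a batched result. A small illustration of the new shape rule (assumed sizes, not from the gem):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t a_ne[4] = { 64, 128, 1, 1 };  // weight: 64 x 128
    const int64_t b_ne[4] = { 64,  32, 8, 1 };  // input:  64 x 32, 8 slices
    const int64_t ne[4]   = { a_ne[1], b_ne[1], b_ne[2], b_ne[3] };
    printf("result: %lld x %lld x %lld x %lld\n",
           (long long) ne[0], (long long) ne[1], (long long) ne[2], (long long) ne[3]);
    return 0;
}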
@@ -6055,8 +6068,8 @@ struct ggml_tensor * ggml_out_prod(
6055
6068
 
6056
6069
  result->op = GGML_OP_OUT_PROD;
6057
6070
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6058
- result->src0 = a;
6059
- result->src1 = b;
6071
+ result->src[0] = a;
6072
+ result->src[1] = b;
6060
6073
 
6061
6074
  return result;
6062
6075
  }
@@ -6081,8 +6094,8 @@ struct ggml_tensor * ggml_scale_impl(
6081
6094
 
6082
6095
  result->op = GGML_OP_SCALE;
6083
6096
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6084
- result->src0 = a;
6085
- result->src1 = b;
6097
+ result->src[0] = a;
6098
+ result->src[1] = b;
6086
6099
 
6087
6100
  return result;
6088
6101
  }
@@ -6137,9 +6150,9 @@ struct ggml_tensor * ggml_set_impl(
6137
6150
 
6138
6151
  result->op = GGML_OP_SET;
6139
6152
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6140
- result->src0 = a;
6141
- result->src1 = b;
6142
- result->opt[0] = c;
6153
+ result->src[0] = a;
6154
+ result->src[1] = b;
6155
+ result->src[2] = c;
6143
6156
 
6144
6157
  return result;
6145
6158
  }
@@ -6226,8 +6239,8 @@ struct ggml_tensor * ggml_cpy_impl(
6226
6239
 
6227
6240
  result->op = GGML_OP_CPY;
6228
6241
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6229
- result->src0 = a;
6230
- result->src1 = b;
6242
+ result->src[0] = a;
6243
+ result->src[1] = b;
6231
6244
 
6232
6245
  return result;
6233
6246
  }
@@ -6263,8 +6276,8 @@ struct ggml_tensor * ggml_cont_impl(
6263
6276
 
6264
6277
  result->op = GGML_OP_CONT;
6265
6278
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6266
- result->src0 = a;
6267
- result->src1 = NULL;
6279
+ result->src[0] = a;
6280
+ result->src[1] = NULL;
6268
6281
 
6269
6282
  return result;
6270
6283
  }
@@ -6307,8 +6320,8 @@ struct ggml_tensor * ggml_reshape(
6307
6320
 
6308
6321
  result->op = GGML_OP_RESHAPE;
6309
6322
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6310
- result->src0 = a;
6311
- result->src1 = NULL;
6323
+ result->src[0] = a;
6324
+ result->src[1] = NULL;
6312
6325
 
6313
6326
  return result;
6314
6327
  }
@@ -6332,8 +6345,8 @@ struct ggml_tensor * ggml_reshape_1d(
6332
6345
 
6333
6346
  result->op = GGML_OP_RESHAPE;
6334
6347
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6335
- result->src0 = a;
6336
- result->src1 = NULL;
6348
+ result->src[0] = a;
6349
+ result->src[1] = NULL;
6337
6350
 
6338
6351
  return result;
6339
6352
  }
@@ -6358,8 +6371,8 @@ struct ggml_tensor * ggml_reshape_2d(
6358
6371
 
6359
6372
  result->op = GGML_OP_RESHAPE;
6360
6373
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6361
- result->src0 = a;
6362
- result->src1 = NULL;
6374
+ result->src[0] = a;
6375
+ result->src[1] = NULL;
6363
6376
 
6364
6377
  return result;
6365
6378
  }
@@ -6385,8 +6398,8 @@ struct ggml_tensor * ggml_reshape_3d(
6385
6398
 
6386
6399
  result->op = GGML_OP_RESHAPE;
6387
6400
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6388
- result->src0 = a;
6389
- result->src1 = NULL;
6401
+ result->src[0] = a;
6402
+ result->src[1] = NULL;
6390
6403
 
6391
6404
  return result;
6392
6405
  }
@@ -6414,8 +6427,8 @@ struct ggml_tensor * ggml_reshape_4d(
6414
6427
 
6415
6428
  result->op = GGML_OP_RESHAPE;
6416
6429
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6417
- result->src0 = a;
6418
- result->src1 = NULL;
6430
+ result->src[0] = a;
6431
+ result->src[1] = NULL;
6419
6432
 
6420
6433
  return result;
6421
6434
  }
@@ -6447,9 +6460,9 @@ struct ggml_tensor * ggml_view_1d(
6447
6460
 
6448
6461
  result->op = GGML_OP_VIEW;
6449
6462
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6450
- result->src0 = a;
6451
- result->src1 = NULL;
6452
- result->opt[0] = offs;
6463
+ result->src[0] = a;
6464
+ result->src[1] = NULL;
6465
+ result->src[2] = offs;
6453
6466
 
6454
6467
  return result;
6455
6468
  }
@@ -6489,9 +6502,9 @@ struct ggml_tensor * ggml_view_2d(
6489
6502
 
6490
6503
  result->op = GGML_OP_VIEW;
6491
6504
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6492
- result->src0 = a;
6493
- result->src1 = NULL;
6494
- result->opt[0] = offs;
6505
+ result->src[0] = a;
6506
+ result->src[1] = NULL;
6507
+ result->src[2] = offs;
6495
6508
 
6496
6509
  return result;
6497
6510
  }
@@ -6533,9 +6546,9 @@ struct ggml_tensor * ggml_view_3d(
6533
6546
 
6534
6547
  result->op = GGML_OP_VIEW;
6535
6548
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6536
- result->src0 = a;
6537
- result->src1 = NULL;
6538
- result->opt[0] = offs;
6549
+ result->src[0] = a;
6550
+ result->src[1] = NULL;
6551
+ result->src[2] = offs;
6539
6552
 
6540
6553
  return result;
6541
6554
  }
@@ -6579,9 +6592,9 @@ struct ggml_tensor * ggml_view_4d(
6579
6592
 
6580
6593
  result->op = GGML_OP_VIEW;
6581
6594
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6582
- result->src0 = a;
6583
- result->src1 = NULL;
6584
- result->opt[0] = offs;
6595
+ result->src[0] = a;
6596
+ result->src[1] = NULL;
6597
+ result->src[2] = offs;
6585
6598
 
6586
6599
  return result;
6587
6600
  }
@@ -6641,8 +6654,8 @@ struct ggml_tensor * ggml_permute(
6641
6654
 
6642
6655
  result->op = GGML_OP_PERMUTE;
6643
6656
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6644
- result->src0 = a;
6645
- result->src1 = NULL;
6657
+ result->src[0] = a;
6658
+ result->src[1] = NULL;
6646
6659
 
6647
6660
  if (is_node) {
6648
6661
  ggml_scratch_save(ctx);
@@ -6656,7 +6669,7 @@ struct ggml_tensor * ggml_permute(
6656
6669
 
6657
6670
  ggml_scratch_load(ctx);
6658
6671
 
6659
- result->opt[0] = b;
6672
+ result->src[2] = b;
6660
6673
  }
6661
6674
 
6662
6675
  return result;
@@ -6684,8 +6697,8 @@ struct ggml_tensor * ggml_transpose(
6684
6697
 
6685
6698
  result->op = GGML_OP_TRANSPOSE;
6686
6699
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6687
- result->src0 = a;
6688
- result->src1 = NULL;
6700
+ result->src[0] = a;
6701
+ result->src[1] = NULL;
6689
6702
 
6690
6703
  return result;
6691
6704
  }
@@ -6710,8 +6723,8 @@ struct ggml_tensor * ggml_get_rows(
6710
6723
 
6711
6724
  result->op = GGML_OP_GET_ROWS;
6712
6725
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6713
- result->src0 = a;
6714
- result->src1 = b;
6726
+ result->src[0] = a;
6727
+ result->src[1] = b;
6715
6728
 
6716
6729
  return result;
6717
6730
  }
@@ -6738,9 +6751,9 @@ struct ggml_tensor * ggml_get_rows_back(
6738
6751
 
6739
6752
  result->op = GGML_OP_GET_ROWS_BACK;
6740
6753
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6741
- result->src0 = a;
6742
- result->src1 = b;
6743
- result->opt[0] = c;
6754
+ result->src[0] = a;
6755
+ result->src[1] = b;
6756
+ result->src[2] = c;
6744
6757
 
6745
6758
  return result;
6746
6759
  }
@@ -6762,8 +6775,8 @@ struct ggml_tensor * ggml_diag(
6762
6775
 
6763
6776
  result->op = GGML_OP_DIAG;
6764
6777
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6765
- result->src0 = a;
6766
- result->src1 = NULL;
6778
+ result->src[0] = a;
6779
+ result->src[1] = NULL;
6767
6780
 
6768
6781
  return result;
6769
6782
  }
@@ -6795,8 +6808,8 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
6795
6808
 
6796
6809
  result->op = GGML_OP_DIAG_MASK_INF;
6797
6810
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6798
- result->src0 = a;
6799
- result->src1 = b;
6811
+ result->src[0] = a;
6812
+ result->src[1] = b;
6800
6813
 
6801
6814
  return result;
6802
6815
  }
@@ -6843,8 +6856,8 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
6843
6856
 
6844
6857
  result->op = GGML_OP_DIAG_MASK_ZERO;
6845
6858
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6846
- result->src0 = a;
6847
- result->src1 = b;
6859
+ result->src[0] = a;
6860
+ result->src[1] = b;
6848
6861
 
6849
6862
  return result;
6850
6863
  }
@@ -6879,8 +6892,8 @@ struct ggml_tensor * ggml_soft_max_impl(
6879
6892
 
6880
6893
  result->op = GGML_OP_SOFT_MAX;
6881
6894
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6882
- result->src0 = a;
6883
- result->src1 = NULL;
6895
+ result->src[0] = a;
6896
+ result->src[1] = NULL;
6884
6897
 
6885
6898
  return result;
6886
6899
  }
@@ -6915,8 +6928,8 @@ struct ggml_tensor * ggml_soft_max_back_impl(
6915
6928
 
6916
6929
  result->op = GGML_OP_SOFT_MAX_BACK;
6917
6930
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6918
- result->src0 = a;
6919
- result->src1 = b;
6931
+ result->src[0] = a;
6932
+ result->src[1] = b;
6920
6933
 
6921
6934
  return result;
6922
6935
  }
@@ -6944,6 +6957,8 @@ struct ggml_tensor * ggml_rope_impl(
6944
6957
  int n_dims,
6945
6958
  int mode,
6946
6959
  int n_ctx,
6960
+ float freq_base,
6961
+ float freq_scale,
6947
6962
  bool inplace) {
6948
6963
  GGML_ASSERT(n_past >= 0);
6949
6964
  bool is_node = false;
@@ -6956,19 +6971,21 @@ struct ggml_tensor * ggml_rope_impl(
6956
6971
 
6957
6972
  ggml_scratch_save(ctx);
6958
6973
 
6959
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
6974
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
6960
6975
 
6961
6976
  ((int32_t *) b->data)[0] = n_past;
6962
6977
  ((int32_t *) b->data)[1] = n_dims;
6963
6978
  ((int32_t *) b->data)[2] = mode;
6964
6979
  ((int32_t *) b->data)[3] = n_ctx;
6980
+ memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float));
6981
+ memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));
6965
6982
 
6966
6983
  ggml_scratch_load(ctx);
6967
6984
 
6968
6985
  result->op = GGML_OP_ROPE;
6969
6986
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
6970
- result->src0 = a;
6971
- result->src1 = b;
6987
+ result->src[0] = a;
6988
+ result->src[1] = b;
6972
6989
 
6973
6990
  return result;
6974
6991
  }
@@ -6980,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
6980
6997
  int n_dims,
6981
6998
  int mode,
6982
6999
  int n_ctx) {
6983
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
7000
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
6984
7001
  }
6985
7002
 
6986
7003
  struct ggml_tensor * ggml_rope_inplace(
@@ -6990,7 +7007,19 @@ struct ggml_tensor * ggml_rope_inplace(
  int n_dims,
  int mode,
  int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
+ }
+
+ struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ float freq_base,
+ float freq_scale) {
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
  }
 
  // ggml_rope_back
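Note: ggml_rope_custom_inplace exposes the two new parameters; freq_base keeps the classic 10000 default, while freq_scale below 1.0 gives linear position scaling. A hypothetical usage sketch (assumes ggml.h from this release; the tensor sizes are made up):

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = { /*.mem_size   =*/ 16*1024*1024,
                                       /*.mem_buffer =*/ NULL,
                                       /*.no_alloc   =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // a dummy activation: 128-dim head, 8 tokens
    struct ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 8);

    // same as ggml_rope_inplace(ctx, cur, 0, 128, 0, 0) except for the scaling
    cur = ggml_rope_custom_inplace(ctx, cur, /*n_past=*/0, /*n_dims=*/128,
                                   /*mode=*/0, /*n_ctx=*/0,
                                   /*freq_base=*/10000.0f, /*freq_scale=*/0.5f);

    ggml_free(ctx);
    return 0;
}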
@@ -7000,7 +7029,8 @@ struct ggml_tensor * ggml_rope_back(
7000
7029
  struct ggml_tensor * a,
7001
7030
  int n_past,
7002
7031
  int n_dims,
7003
- int mode) {
7032
+ int mode,
7033
+ int n_ctx) {
7004
7034
  GGML_ASSERT(n_past >= 0);
7005
7035
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
7006
7036
 
@@ -7014,19 +7044,20 @@ struct ggml_tensor * ggml_rope_back(
7014
7044
 
7015
7045
  ggml_scratch_save(ctx);
7016
7046
 
7017
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
7047
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
7018
7048
  ggml_set_name(b, "n_past, n_dims, mode");
7019
7049
 
7020
7050
  ((int32_t *) b->data)[0] = n_past;
7021
7051
  ((int32_t *) b->data)[1] = n_dims;
7022
7052
  ((int32_t *) b->data)[2] = mode;
7053
+ ((int32_t *) b->data)[3] = n_ctx;
7023
7054
 
7024
7055
  ggml_scratch_load(ctx);
7025
7056
 
7026
7057
  result->op = GGML_OP_ROPE_BACK;
7027
7058
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7028
- result->src0 = a;
7029
- result->src1 = b;
7059
+ result->src[0] = a;
7060
+ result->src[1] = b;
7030
7061
 
7031
7062
  return result;
7032
7063
  }
@@ -7064,8 +7095,8 @@ struct ggml_tensor * ggml_alibi(
7064
7095
 
7065
7096
  result->op = GGML_OP_ALIBI;
7066
7097
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7067
- result->src0 = a;
7068
- result->src1 = b;
7098
+ result->src[0] = a;
7099
+ result->src[1] = b;
7069
7100
 
7070
7101
  return result;
7071
7102
  }
@@ -7098,8 +7129,8 @@ struct ggml_tensor * ggml_clamp(
7098
7129
 
7099
7130
  result->op = GGML_OP_CLAMP;
7100
7131
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7101
- result->src0 = a;
7102
- result->src1 = b;
7132
+ result->src[0] = a;
7133
+ result->src[1] = b;
7103
7134
 
7104
7135
  return result;
7105
7136
  }
@@ -7141,9 +7172,9 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7141
7172
 
7142
7173
  result->op = GGML_OP_CONV_1D;
7143
7174
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7144
- result->src0 = a;
7145
- result->src1 = b;
7146
- result->opt[0] = c;
7175
+ result->src[0] = a;
7176
+ result->src[1] = b;
7177
+ result->src[2] = c;
7147
7178
 
7148
7179
  return result;
7149
7180
  }
@@ -7161,7 +7192,6 @@ struct ggml_tensor* ggml_conv_2d(
7161
7192
  int d0,
7162
7193
  int d1) {
7163
7194
 
7164
- GGML_ASSERT(b->ne[3] == 1);
7165
7195
  GGML_ASSERT(a->ne[2] == b->ne[2]);
7166
7196
  bool is_node = false;
7167
7197
 
@@ -7173,7 +7203,7 @@ struct ggml_tensor* ggml_conv_2d(
7173
7203
  const int64_t ne[4] = {
7174
7204
  ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7175
7205
  ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
7176
- a->ne[3], 1,
7206
+ a->ne[3], b->ne[3],
7177
7207
  };
7178
7208
  struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7179
7209
 
@@ -7189,9 +7219,9 @@ struct ggml_tensor* ggml_conv_2d(
7189
7219
 
7190
7220
  result->op = GGML_OP_CONV_2D;
7191
7221
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7192
- result->src0 = a;
7193
- result->src1 = b;
7194
- result->opt[0] = c;
7222
+ result->src[0] = a;
7223
+ result->src[1] = b;
7224
+ result->src[2] = c;
7195
7225
 
7196
7226
  return result;
7197
7227
 
@@ -7208,6 +7238,98 @@ struct ggml_tensor* ggml_conv_1d_ph(
7208
7238
  return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
7209
7239
  }
7210
7240
 
7241
+
7242
+ // ggml_pool_*
7243
+
7244
+ static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
7245
+ return (ins + 2 * p - ks) / s + 1;
7246
+ }
7247
+
7248
+ // ggml_pool_1d
7249
+
7250
+ struct ggml_tensor* ggml_pool_1d(
7251
+ struct ggml_context * ctx,
7252
+ struct ggml_tensor * a,
7253
+ enum ggml_op_pool op,
7254
+ int k0,
7255
+ int s0,
7256
+ int p0) {
7257
+
7258
+ bool is_node = false;
7259
+
7260
+ if (a->grad) {
7261
+ GGML_ASSERT(false); // TODO: implement backward
7262
+ is_node = true;
7263
+ }
7264
+
7265
+ const int64_t ne[3] = {
7266
+ ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
7267
+ a->ne[1],
7268
+ };
7269
+ struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7270
+
7271
+ ggml_scratch_save(ctx);
7272
+ struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
7273
+ ((int32_t*)c->data)[0] = op;
7274
+ ((int32_t*)c->data)[1] = k0;
7275
+ ((int32_t*)c->data)[2] = s0;
7276
+ ((int32_t*)c->data)[3] = p0;
7277
+ ggml_scratch_load(ctx);
7278
+
7279
+ result->op = GGML_OP_POOL_1D;
7280
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7281
+ result->src[0] = a;
7282
+ result->src[1] = c;
7283
+
7284
+ return result;
7285
+ }
7286
+
7287
+ // ggml_pool_2d
7288
+
7289
+ struct ggml_tensor* ggml_pool_2d(
7290
+ struct ggml_context * ctx,
7291
+ struct ggml_tensor * a,
7292
+ enum ggml_op_pool op,
7293
+ int k0,
7294
+ int k1,
7295
+ int s0,
7296
+ int s1,
7297
+ int p0,
7298
+ int p1) {
7299
+
7300
+ bool is_node = false;
7301
+
7302
+ if (a->grad) {
7303
+ GGML_ASSERT(false); // TODO: implement backward
7304
+ is_node = true;
7305
+ }
7306
+
7307
+ const int64_t ne[3] = {
7308
+ ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
7309
+ ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
7310
+ a->ne[2],
7311
+ };
7312
+ struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
7313
+
7314
+ ggml_scratch_save(ctx);
7315
+ struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7);
7316
+ ((int32_t*)c->data)[0] = op;
7317
+ ((int32_t*)c->data)[1] = k0;
7318
+ ((int32_t*)c->data)[2] = k1;
7319
+ ((int32_t*)c->data)[3] = s0;
7320
+ ((int32_t*)c->data)[4] = s1;
7321
+ ((int32_t*)c->data)[5] = p0;
7322
+ ((int32_t*)c->data)[6] = p1;
7323
+ ggml_scratch_load(ctx);
7324
+
7325
+ result->op = GGML_OP_POOL_2D;
7326
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7327
+ result->src[0] = a;
7328
+ result->src[1] = c;
7329
+
7330
+ return result;
7331
+ }
7332
+
7211
7333
  // ggml_flash_attn
7212
7334
 
7213
7335
  struct ggml_tensor * ggml_flash_attn(
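Note: the new ggml_pool_1d/ggml_pool_2d shrink each pooled dimension with the (ins + 2*p - ks)/s + 1 rule from ggml_calc_pool_output_size above. A standalone check of that arithmetic (illustrative values only):

#include <stdint.h>
#include <stdio.h>

static int64_t pool_out_size(int64_t ins, int ks, int s, int p) {
    return (ins + 2*p - ks)/s + 1;
}

int main(void) {
    // 224-wide row, 2-wide kernel, stride 2, no padding -> 112 outputs
    printf("%lld\n", (long long) pool_out_size(224, 2, 2, 0));
    // 7-wide row, 3-wide kernel, stride 2, padding 1 -> 4 outputs
    printf("%lld\n", (long long) pool_out_size(7, 3, 2, 1));
    return 0;
}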
@@ -7230,10 +7352,10 @@ struct ggml_tensor * ggml_flash_attn(
7230
7352
 
7231
7353
  result->op = GGML_OP_FLASH_ATTN;
7232
7354
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7233
- result->src0 = q;
7234
- result->src1 = k;
7235
- result->opt[0] = v;
7236
- result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
7355
+ result->src[0] = q;
7356
+ result->src[1] = k;
7357
+ result->src[2] = v;
7358
+ result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
7237
7359
 
7238
7360
  return result;
7239
7361
  }
@@ -7261,11 +7383,11 @@ struct ggml_tensor * ggml_flash_ff(
7261
7383
 
7262
7384
  result->op = GGML_OP_FLASH_FF;
7263
7385
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7264
- result->src0 = a;
7265
- result->src1 = b0;
7266
- result->opt[0] = b1;
7267
- result->opt[1] = c0;
7268
- result->opt[2] = c1;
7386
+ result->src[0] = a;
7387
+ result->src[1] = b0;
7388
+ result->src[2] = b1;
7389
+ result->src[3] = c0;
7390
+ result->src[4] = c1;
7269
7391
 
7270
7392
  return result;
7271
7393
  }
@@ -7325,11 +7447,11 @@ struct ggml_tensor * ggml_flash_attn_back(
7325
7447
 
7326
7448
  result->op = GGML_OP_FLASH_ATTN_BACK;
7327
7449
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7328
- result->src0 = q;
7329
- result->src1 = k;
7330
- result->opt[0] = v;
7331
- result->opt[1] = d;
7332
- result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
7450
+ result->src[0] = q;
7451
+ result->src[1] = k;
7452
+ result->src[2] = v;
7453
+ result->src[3] = d;
7454
+ result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
7333
7455
 
7334
7456
  return result;
7335
7457
  }
@@ -7374,9 +7496,9 @@ struct ggml_tensor * ggml_win_part(
7374
7496
 
7375
7497
  result->op = GGML_OP_WIN_PART;
7376
7498
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7377
- result->src0 = a;
7378
- result->src1 = NULL;
7379
- result->opt[0] = b;
7499
+ result->src[0] = a;
7500
+ result->src[1] = NULL;
7501
+ result->src[2] = b;
7380
7502
 
7381
7503
  return result;
7382
7504
  }
@@ -7411,9 +7533,9 @@ struct ggml_tensor * ggml_win_unpart(
7411
7533
 
7412
7534
  result->op = GGML_OP_WIN_UNPART;
7413
7535
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7414
- result->src0 = a;
7415
- result->src1 = NULL;
7416
- result->opt[0] = b;
7536
+ result->src[0] = a;
7537
+ result->src[1] = NULL;
7538
+ result->src[2] = b;
7417
7539
 
7418
7540
  return result;
7419
7541
  }
@@ -7442,8 +7564,8 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
7442
7564
 
7443
7565
  result->op = GGML_OP_MAP_UNARY;
7444
7566
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7445
- result->src0 = a;
7446
- result->opt[0] = addr_tensor;
7567
+ result->src[0] = a;
7568
+ result->src[2] = addr_tensor;
7447
7569
 
7448
7570
  return result;
7449
7571
  }
@@ -7489,9 +7611,9 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
7489
7611
 
7490
7612
  result->op = GGML_OP_MAP_BINARY;
7491
7613
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7492
- result->src0 = a;
7493
- result->src1 = b;
7494
- result->opt[0] = addr_tensor;
7614
+ result->src[0] = a;
7615
+ result->src[1] = b;
7616
+ result->src[2] = addr_tensor;
7495
7617
 
7496
7618
  return result;
7497
7619
  }
@@ -7536,8 +7658,8 @@ struct ggml_tensor * ggml_map_custom1_impl_f32(
7536
7658
 
7537
7659
  result->op = GGML_OP_MAP_CUSTOM1;
7538
7660
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7539
- result->src0 = a;
7540
- result->opt[0] = addr_tensor;
7661
+ result->src[0] = a;
7662
+ result->src[2] = addr_tensor;
7541
7663
 
7542
7664
  return result;
7543
7665
  }
@@ -7581,9 +7703,9 @@ struct ggml_tensor * ggml_map_custom2_impl_f32(
7581
7703
 
7582
7704
  result->op = GGML_OP_MAP_CUSTOM2;
7583
7705
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7584
- result->src0 = a;
7585
- result->src1 = b;
7586
- result->opt[0] = addr_tensor;
7706
+ result->src[0] = a;
7707
+ result->src[1] = b;
7708
+ result->src[2] = addr_tensor;
7587
7709
 
7588
7710
  return result;
7589
7711
  }
@@ -7630,10 +7752,10 @@ struct ggml_tensor * ggml_map_custom3_impl_f32(
7630
7752
 
7631
7753
  result->op = GGML_OP_MAP_CUSTOM3;
7632
7754
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7633
- result->src0 = a;
7634
- result->src1 = b;
7635
- result->opt[0] = addr_tensor;
7636
- result->opt[1] = c;
7755
+ result->src[0] = a;
7756
+ result->src[1] = b;
7757
+ result->src[2] = addr_tensor;
7758
+ result->src[3] = c;
7637
7759
 
7638
7760
  return result;
7639
7761
  }
@@ -7673,8 +7795,8 @@ struct ggml_tensor * ggml_cross_entropy_loss(
7673
7795
 
7674
7796
  result->op = GGML_OP_CROSS_ENTROPY_LOSS;
7675
7797
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7676
- result->src0 = a;
7677
- result->src1 = b;
7798
+ result->src[0] = a;
7799
+ result->src[1] = b;
7678
7800
 
7679
7801
  return result;
7680
7802
  }
@@ -7693,9 +7815,9 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
7693
7815
 
7694
7816
  result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
7695
7817
  result->grad = NULL;
7696
- result->src0 = a;
7697
- result->src1 = b;
7698
- result->opt[0] = c;
7818
+ result->src[0] = a;
7819
+ result->src[1] = b;
7820
+ result->src[2] = c;
7699
7821
 
7700
7822
  return result;
7701
7823
  }
@@ -8296,7 +8418,7 @@ static void ggml_compute_forward_add_f32(
8296
8418
  const struct ggml_tensor * src0,
8297
8419
  const struct ggml_tensor * src1,
8298
8420
  struct ggml_tensor * dst) {
8299
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
8421
+ GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
8300
8422
 
8301
8423
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
8302
8424
  return;
@@ -8321,23 +8443,23 @@ static void ggml_compute_forward_add_f32(
8321
8443
 
8322
8444
  if (nb10 == sizeof(float)) {
8323
8445
  for (int ir = ir0; ir < ir1; ++ir) {
8324
- // src0, src1 and dst are same shape => same indices
8325
- const int i3 = ir/(ne2*ne1);
8326
- const int i2 = (ir - i3*ne2*ne1)/ne1;
8327
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8446
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
8447
+ const int64_t i03 = ir/(ne02*ne01);
8448
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
8449
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
8450
+
8451
+ const int64_t i13 = i03 % ne13;
8452
+ const int64_t i12 = i02 % ne12;
8453
+ const int64_t i11 = i01 % ne11;
8328
8454
 
8455
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
8456
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
8457
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
8329
8458
 
8330
8459
  #ifdef GGML_USE_ACCELERATE
8331
- vDSP_vadd(
8332
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
8333
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
8334
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
8335
- ne0);
8460
+ vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
8336
8461
  #else
8337
- ggml_vec_add_f32(ne0,
8338
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
8339
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
8340
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
8462
+ ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
8341
8463
  #endif
8342
8464
  // }
8343
8465
  // }
@@ -8345,15 +8467,20 @@ static void ggml_compute_forward_add_f32(
8345
8467
  } else {
8346
8468
  // src1 is not contiguous
8347
8469
  for (int ir = ir0; ir < ir1; ++ir) {
8348
- // src0, src1 and dst are same shape => same indices
8349
- const int i3 = ir/(ne2*ne1);
8350
- const int i2 = (ir - i3*ne2*ne1)/ne1;
8351
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
8470
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
8471
+ const int64_t i03 = ir/(ne02*ne01);
8472
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
8473
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
8474
+
8475
+ const int64_t i13 = i03 % ne13;
8476
+ const int64_t i12 = i02 % ne12;
8477
+ const int64_t i11 = i01 % ne11;
8478
+
8479
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
8480
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
8352
8481
 
8353
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
8354
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
8355
8482
  for (int i0 = 0; i0 < ne0; i0++) {
8356
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
8483
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
8357
8484
 
8358
8485
  dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
8359
8486
  }
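Note: the add kernels now wrap src1's indices with a modulo, so a smaller src1 is re-read for every repeated row and slice of src0. A standalone sketch of that index mapping (made-up shapes, not from the gem):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne01 = 4, ne02 = 2, ne03 = 1;   // src0/dst rows per dim
    const int64_t ne11 = 4, ne12 = 1, ne13 = 1;   // src1 has a single dim-2 slice

    for (int64_t ir = 0; ir < ne01*ne02*ne03; ++ir) {
        const int64_t i03 = ir/(ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);

        // src1 indices wrap around, so its single slice is re-used (broadcast)
        printf("dst(%lld,%lld,%lld) <- src1(%lld,%lld,%lld)\n",
               (long long) i01, (long long) i02, (long long) i03,
               (long long) (i01 % ne11), (long long) (i02 % ne12), (long long) (i03 % ne13));
    }
    return 0;
}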
@@ -10532,7 +10659,6 @@ static void ggml_compute_forward_rms_norm_back(
10532
10659
  }
10533
10660
  }
10534
10661
 
10535
-
10536
10662
  // ggml_compute_forward_mul_mat
10537
10663
 
10538
10664
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -10576,17 +10702,19 @@ static void ggml_compute_forward_mul_mat(
10576
10702
  const int ith = params->ith;
10577
10703
  const int nth = params->nth;
10578
10704
 
10579
- GGML_ASSERT(ne02 == ne12);
10580
- GGML_ASSERT(ne03 == ne13);
10581
- GGML_ASSERT(ne2 == ne12);
10582
- GGML_ASSERT(ne3 == ne13);
10583
-
10584
10705
  const enum ggml_type type = src0->type;
10585
10706
 
10707
+ const bool src1_cont = ggml_is_contiguous(src1);
10708
+
10586
10709
  ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
10587
10710
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
10588
10711
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
10589
10712
 
10713
+ GGML_ASSERT(ne0 == ne01);
10714
+ GGML_ASSERT(ne1 == ne11);
10715
+ GGML_ASSERT(ne2 == ne12);
10716
+ GGML_ASSERT(ne3 == ne13);
10717
+
10590
10718
  // we don't support permuted src0 or src1
10591
10719
  GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
10592
10720
  GGML_ASSERT(nb10 == sizeof(float));
@@ -10597,16 +10725,16 @@ static void ggml_compute_forward_mul_mat(
10597
10725
  GGML_ASSERT(nb1 <= nb2);
10598
10726
  GGML_ASSERT(nb2 <= nb3);
10599
10727
 
10600
- GGML_ASSERT(ne0 == ne01);
10601
- GGML_ASSERT(ne1 == ne11);
10602
- GGML_ASSERT(ne2 == ne02);
10603
- GGML_ASSERT(ne3 == ne03);
10604
-
10605
10728
  // nb01 >= nb00 - src0 is not transposed
10606
10729
  // compute by src0 rows
10607
10730
 
10608
10731
  #if defined(GGML_USE_CLBLAST)
10609
10732
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
10733
+ // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
10734
+ // ref: https://github.com/ggerganov/ggml/pull/224
10735
+ GGML_ASSERT(ne02 == ne12);
10736
+ GGML_ASSERT(ne03 == ne13);
10737
+
10610
10738
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
10611
10739
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
10612
10740
  }
@@ -10616,6 +10744,11 @@ static void ggml_compute_forward_mul_mat(
10616
10744
 
10617
10745
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
10618
10746
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
10747
+ // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
10748
+ // ref: https://github.com/ggerganov/ggml/pull/224
10749
+ GGML_ASSERT(ne02 == ne12);
10750
+ GGML_ASSERT(ne03 == ne13);
10751
+
10619
10752
  if (params->ith != 0) {
10620
10753
  return;
10621
10754
  }
@@ -10636,7 +10769,7 @@ static void ggml_compute_forward_mul_mat(
10636
10769
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
10637
10770
 
10638
10771
  if (type != GGML_TYPE_F32) {
10639
- float * const wdata = params->wdata;
10772
+ float * const wdata = params->wdata;
10640
10773
  ggml_to_float_t const to_float = type_traits[type].to_float;
10641
10774
 
10642
10775
  size_t id = 0;
@@ -10685,43 +10818,52 @@ static void ggml_compute_forward_mul_mat(
10685
10818
  return;
10686
10819
  }
10687
10820
 
10688
- // parallelize by src0 rows using ggml_vec_dot_q
10821
+ // parallelize by src0 rows
10822
+ const int64_t dr = (ne01 + nth - 1)/nth;
10689
10823
 
10690
- // total rows in src0
10691
- const int nr = ne01*ne02*ne03;
10824
+ const int64_t ir10 = dr*ith;
10825
+ const int64_t ir11 = MIN(ir10 + dr, ne01);
10692
10826
 
10693
- // rows per thread
10694
- const int dr = (nr + nth - 1)/nth;
10827
+ // src1 rows
10828
+ const int64_t nr1 = ne11*ne12*ne13;
10695
10829
 
10696
- // row range for this thread
10697
- const int ir0 = dr*ith;
10698
- const int ir1 = MIN(ir0 + dr, nr);
10830
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10831
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
10699
10832
 
10700
- void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
10701
- const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
10833
+ for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
10834
+ const int64_t i13 = (ir1/(ne12*ne11));
10835
+ const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
10836
+ const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
10702
10837
 
10703
- for (int ir = ir0; ir < ir1; ++ir) {
10704
- // src0 indices
10705
- const int i03 = ir/(ne02*ne01);
10706
- const int i02 = (ir - i03*ne02*ne01)/ne01;
10707
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
10838
+ const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
10839
+ const int64_t i03 = (ir0/(ne02));
10840
+ // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
10841
+ // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
10842
+ // GG: this is likely the correct way to broadcast, though need some more thought
10843
+ // therefore leaving the comments to remind us for now
10844
+ const int64_t i02 = (i12 / (ne12 / ne02));
10845
+ // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
10846
+ // const int64_t i02 = (ir0 - i03*ne02);
10708
10847
 
10709
- const int i13 = i03;
10710
- const int i12 = i02;
10848
+ const int64_t i1 = i11;
10849
+ const int64_t i2 = i12;
10850
+ const int64_t i3 = i13;
10711
10851
 
10712
- const int i0 = i01;
10713
- const int i2 = i02;
10714
- const int i3 = i03;
10852
+ const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
10715
10853
 
10716
- void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
10717
- char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
10854
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
10855
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
10856
+ // the original src1 data pointer, so we should index using the indices directly
10857
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
10858
+ const char * src1_col = (const char *) wdata +
10859
+ (src1_cont || src1->type != vec_dot_type
10860
+ ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
10861
+ : (i11*nb11 + i12*nb12 + i13*nb13));
10718
10862
 
10719
- float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
10863
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
10720
10864
 
10721
- assert(ne00 % 32 == 0);
10722
-
10723
- for (int64_t ic = 0; ic < ne11; ++ic) {
10724
- vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
10865
+ for (int64_t ir = ir10; ir < ir11; ++ir) {
10866
+ vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
10725
10867
  }
10726
10868
  }
10727
10869
 
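Note: the rewritten loop iterates over src1 rows, decomposes the flat index into (i11, i12, i13), and maps each src1 slice onto a src0 slice with i12/(ne12/ne02), which is what lets several query heads share one K/V head (the Falcon multi-query case mentioned in the comments). A standalone sketch of that mapping (made-up shapes; assumes ne12 is a multiple of ne02):

#include <stdint.h>
#include <stdio.h>

int main(void) {
    const int64_t ne02 = 2;                       // src0 slices
    const int64_t ne11 = 3, ne12 = 4, ne13 = 1;   // src1 rows and slices

    for (int64_t ir1 = 0; ir1 < ne11*ne12*ne13; ++ir1) {
        const int64_t i13 = ir1/(ne12*ne11);
        const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
        const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);

        const int64_t i02 = i12/(ne12/ne02);      // two src1 slices share each src0 slice

        printf("src1 row (%lld,%lld,%lld) -> src0 slice %lld\n",
               (long long) i11, (long long) i12, (long long) i13, (long long) i02);
    }
    return 0;
}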
@@ -11718,7 +11860,7 @@ static void ggml_compute_forward_alibi_f32(
11718
11860
 
11719
11861
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
11720
11862
  const int ne1 = src0->ne[1]; // seq_len_without_past
11721
- //const int ne2 = src0->ne[2]; // n_head -> this is k
11863
+ const int ne2 = src0->ne[2]; // n_head -> this is k
11722
11864
  //const int ne3 = src0->ne[3]; // 1 -> bsz
11723
11865
 
11724
11866
  const int n = ggml_nrows(src0);
@@ -11729,8 +11871,9 @@ static void ggml_compute_forward_alibi_f32(
11729
11871
  const int nb2 = src0->nb[2];
11730
11872
  //const int nb3 = src0->nb[3];
11731
11873
 
11732
- assert(nb0 == sizeof(float));
11733
- assert(ne1 + n_past == ne0); (void) n_past;
11874
+ GGML_ASSERT(nb0 == sizeof(float));
11875
+ GGML_ASSERT(ne1 + n_past == ne0);
11876
+ GGML_ASSERT(n_head == ne2);
11734
11877
 
11735
11878
  // add alibi to src0 (KQ_scaled)
11736
11879
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11754,7 +11897,7 @@ static void ggml_compute_forward_alibi_f32(
11754
11897
  m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
11755
11898
  }
11756
11899
 
11757
- pdst[0] = (i-ne0+1) * m_k + src[0];
11900
+ pdst[0] = i * m_k + src[0];
11758
11901
 
11759
11902
  }
11760
11903
  }
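Note: the ALiBi bias is now anchored at the start of the row (i * m_k) instead of at the last position ((i - ne0 + 1) * m_k); the two differ only by a constant per row, so relative scores within a row, and hence the softmax, are unchanged. A small numeric comparison (m_k chosen arbitrarily for illustration):

#include <stdio.h>

int main(void) {
    const float m_k = 0.0625f;   // example slope for one head
    const int   ne0 = 8;         // row length (n_past + seq_len)

    for (int i = 0; i < ne0; ++i) {
        const float old_bias = (i - ne0 + 1) * m_k;  // previous behaviour
        const float new_bias = i * m_k;              // behaviour after this change
        printf("i=%d  old=% .4f  new=% .4f\n", i, old_bias, new_bias);
    }
    return 0;
}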
@@ -11783,7 +11926,7 @@ static void ggml_compute_forward_alibi_f16(
11783
11926
 
11784
11927
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
11785
11928
  const int ne1 = src0->ne[1]; // seq_len_without_past
11786
- //const int ne2 = src0->ne[2]; // n_head -> this is k
11929
+ const int ne2 = src0->ne[2]; // n_head -> this is k
11787
11930
  //const int ne3 = src0->ne[3]; // 1 -> bsz
11788
11931
 
11789
11932
  const int n = ggml_nrows(src0);
@@ -11794,8 +11937,9 @@ static void ggml_compute_forward_alibi_f16(
11794
11937
  const int nb2 = src0->nb[2];
11795
11938
  //const int nb3 = src0->nb[3];
11796
11939
 
11797
- assert(nb0 == sizeof(ggml_fp16_t));
11798
- assert(ne1 + n_past == ne0); (void) n_past;
11940
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
11941
+ GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
11942
+ GGML_ASSERT(n_head == ne2);
11799
11943
 
11800
11944
  // add alibi to src0 (KQ_scaled)
11801
11945
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11820,7 +11964,7 @@ static void ggml_compute_forward_alibi_f16(
11820
11964
  }
11821
11965
 
11822
11966
  // we return F32
11823
- pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
11967
+ pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
11824
11968
  }
11825
11969
  }
11826
11970
  }
@@ -11948,16 +12092,21 @@ static void ggml_compute_forward_rope_f32(
11948
12092
  const struct ggml_tensor * src1,
11949
12093
  struct ggml_tensor * dst) {
11950
12094
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
11951
- GGML_ASSERT(ggml_nelements(src1) == 4);
12095
+ GGML_ASSERT(ggml_nelements(src1) == 6);
11952
12096
 
11953
12097
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
11954
12098
  return;
11955
12099
  }
11956
12100
 
12101
+ float freq_base;
12102
+ float freq_scale;
12103
+
11957
12104
  const int n_past = ((int32_t *) src1->data)[0];
11958
12105
  const int n_dims = ((int32_t *) src1->data)[1];
11959
12106
  const int mode = ((int32_t *) src1->data)[2];
11960
12107
  const int n_ctx = ((int32_t *) src1->data)[3];
12108
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
12109
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
11961
12110
 
11962
12111
  assert(n_past >= 0);
11963
12112
 
@@ -11986,7 +12135,7 @@ static void ggml_compute_forward_rope_f32(
11986
12135
  // row index used to determine which thread to use
11987
12136
  int ir = 0;
11988
12137
 
11989
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
12138
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
11990
12139
 
11991
12140
  const bool is_neox = mode & 2;
11992
12141
  const bool is_glm = mode & 4;
@@ -11998,7 +12147,7 @@ static void ggml_compute_forward_rope_f32(
11998
12147
  if (ir++ < ir0) continue;
11999
12148
  if (ir > ir1) break;
12000
12149
 
12001
- float theta = (float)p;
12150
+ float theta = freq_scale * (float)p;
12002
12151
 
12003
12152
  if (is_glm) {
12004
12153
  theta = MIN(p, n_ctx - 2);
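The two new floats decoded above travel in the same I32 parameter tensor as the integer arguments, which is why the assert now expects 6 elements and why they are recovered with memcpy rather than a cast. Below is a small self-contained sketch of that packing and of the resulting theta terms, assuming the layout shown in this hunk ([n_past, n_dims, mode, n_ctx, freq_base, freq_scale]); the builder side that fills src1 is not part of this diff.

#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Sketch: pack and unpack the 6-element RoPE parameter block decoded above.
// Layout as seen in this hunk: [n_past, n_dims, mode, n_ctx, freq_base, freq_scale],
// with the two floats bit-copied into the int32 slots via memcpy.
int main(void) {
    int32_t params[6] = { 0 };
    params[0] = 32;     // n_past
    params[1] = 128;    // n_dims
    params[2] = 0;      // mode
    params[3] = 2048;   // n_ctx

    const float freq_base  = 10000.0f;  // the old hard-coded base
    const float freq_scale = 1.0f;      // 1.0 reproduces the previous behaviour
    memcpy(&params[4], &freq_base,  sizeof(float));
    memcpy(&params[5], &freq_scale, sizeof(float));

    // decode exactly the way the compute kernel does
    float fb, fs;
    memcpy(&fb, &params[4], sizeof(float));
    memcpy(&fs, &params[5], sizeof(float));

    const float theta_scale = powf(fb, -2.0f/params[1]);
    const float theta       = fs * (float) params[0];  // theta at position p = n_past

    printf("theta_scale = %f, theta(p = n_past) = %f\n", theta_scale, theta);
    return 0;
}

Bit-copying through memcpy keeps the float payload intact inside the int32 storage without violating strict aliasing.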
@@ -12075,16 +12224,21 @@ static void ggml_compute_forward_rope_f16(
12075
12224
  const struct ggml_tensor * src1,
12076
12225
  struct ggml_tensor * dst) {
12077
12226
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
12078
- GGML_ASSERT(ggml_nelements(src1) == 4);
12227
+ GGML_ASSERT(ggml_nelements(src1) == 6);
12079
12228
 
12080
12229
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12081
12230
  return;
12082
12231
  }
12083
12232
 
12233
+ float freq_base;
12234
+ float freq_scale;
12235
+
12084
12236
  const int n_past = ((int32_t *) src1->data)[0];
12085
12237
  const int n_dims = ((int32_t *) src1->data)[1];
12086
12238
  const int mode = ((int32_t *) src1->data)[2];
12087
12239
  const int n_ctx = ((int32_t *) src1->data)[3];
12240
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
12241
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
12088
12242
 
12089
12243
  assert(n_past >= 0);
12090
12244
 
@@ -12113,7 +12267,7 @@ static void ggml_compute_forward_rope_f16(
12113
12267
  // row index used to determine which thread to use
12114
12268
  int ir = 0;
12115
12269
 
12116
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
12270
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
12117
12271
 
12118
12272
  const bool is_neox = mode & 2;
12119
12273
  const bool is_glm = mode & 4;
@@ -12125,7 +12279,7 @@ static void ggml_compute_forward_rope_f16(
12125
12279
  if (ir++ < ir0) continue;
12126
12280
  if (ir > ir1) break;
12127
12281
 
12128
- float theta = (float)p;
12282
+ float theta = freq_scale * (float)p;
12129
12283
 
12130
12284
  if (is_glm) {
12131
12285
  theta = MIN(p, n_ctx - 2);
@@ -12186,7 +12340,7 @@ static void ggml_compute_forward_rope_f16(
12186
12340
  const float x0 = GGML_FP16_TO_FP32(src[0]);
12187
12341
  const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
12188
12342
 
12189
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
12343
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
12190
12344
  dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
12191
12345
  }
12192
12346
  }
@@ -12225,7 +12379,7 @@ static void ggml_compute_forward_rope_back_f32(
12225
12379
  const struct ggml_tensor * src1,
12226
12380
  struct ggml_tensor * dst) {
12227
12381
  assert(src1->type == GGML_TYPE_I32);
12228
- assert(ggml_nelements(src1) == 3);
12382
+ assert(ggml_nelements(src1) == 4);
12229
12383
 
12230
12384
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12231
12385
  return;
@@ -12868,12 +13022,13 @@ static void ggml_compute_forward_conv_1d(
12868
13022
  };
12869
13023
  }
12870
13024
 
12871
- // ggml_compute_forward_conv_2d_sk_p0
13025
+ // ggml_compute_forward_conv_2d
12872
13026
 
12873
- static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
13027
+ static void ggml_compute_forward_conv_2d_f16_f32(
12874
13028
  const struct ggml_compute_params * params,
12875
13029
  const struct ggml_tensor * src0,
12876
13030
  const struct ggml_tensor * src1,
13031
+ const struct ggml_tensor * opt0,
12877
13032
  struct ggml_tensor * dst) {
12878
13033
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
12879
13034
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -12893,11 +13048,17 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12893
13048
  // size of the convolution row - the kernel size unrolled across all channels
12894
13049
  const int ew0 = nk0*nk1*ne02;
12895
13050
 
13051
+ const int32_t s0 = ((const int32_t*)(opt0->data))[0];
13052
+ const int32_t s1 = ((const int32_t*)(opt0->data))[1];
13053
+ const int32_t p0 = ((const int32_t*)(opt0->data))[2];
13054
+ const int32_t p1 = ((const int32_t*)(opt0->data))[3];
13055
+ const int32_t d0 = ((const int32_t*)(opt0->data))[4];
13056
+ const int32_t d1 = ((const int32_t*)(opt0->data))[5];
13057
+
12896
13058
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
12897
13059
  GGML_ASSERT(nb10 == sizeof(float));
12898
13060
 
12899
13061
  if (params->type == GGML_TASK_INIT) {
12900
- // TODO: fix this memset (wsize is overestimated)
12901
13062
  memset(params->wdata, 0, params->wsize);
12902
13063
 
12903
13064
  // prepare source data (src1)
@@ -12912,8 +13073,13 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12912
13073
  for (int i0 = 0; i0 < ne0; i0++) {
12913
13074
  for (int ik1 = 0; ik1 < nk1; ik1++) {
12914
13075
  for (int ik0 = 0; ik0 < nk0; ik0++) {
12915
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
12916
- GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
13076
+ const int idx0 = i0*s0 + ik0*d0 - p0;
13077
+ const int idx1 = i1*s1 + ik1*d1 - p1;
13078
+
13079
+ if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
13080
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
13081
+ GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
13082
+ }
12917
13083
  }
12918
13084
  }
12919
13085
  }
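The im2col gather above now maps every kernel tap through `idx = i*s + ik*d - p` and skips out-of-range taps, so the zeros left by the initial memset act as implicit zero padding. Here is a small sketch of that index mapping under arbitrary stride, padding and dilation; the output-size formula is the standard convolution one and is an assumption here, since the shape computation happens elsewhere and is not part of this hunk.

#include <stdio.h>

// Sketch of the tap mapping used by the rewritten im2col above: out-of-range
// taps are skipped, so the zeros left by the initial memset of wdata act as
// implicit zero padding.
static int conv_out_size(int in, int k, int s, int p, int d) {
    // standard convolution output size (assumption: the builder computes dst
    // shapes this way; the shape code is not part of this hunk)
    return (in + 2*p - d*(k - 1) - 1)/s + 1;
}

int main(void) {
    const int ne10 = 8;                        // input width
    const int nk0 = 3, s0 = 2, p0 = 1, d0 = 1; // kernel, stride, padding, dilation
    const int ne0 = conv_out_size(ne10, nk0, s0, p0, d0);

    for (int i0 = 0; i0 < ne0; ++i0) {
        printf("output %d reads input:", i0);
        for (int ik0 = 0; ik0 < nk0; ++ik0) {
            const int idx0 = i0*s0 + ik0*d0 - p0;
            if (idx0 < 0 || idx0 >= ne10) {
                printf("  pad");               // left as zero in wdata
            } else {
                printf("  %d", idx0);
            }
        }
        printf("\n");
    }
    return 0;
}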
@@ -12940,32 +13106,36 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12940
13106
 
12941
13107
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12942
13108
 
12943
- for (int i2 = ip0; i2 < ip1; i2++) {
12944
- float * dst_data = (float *)((char *) dst->data + i2*nb2);
12945
-
12946
- for (int i1 = 0; i1 < ne1; ++i1) {
12947
- for (int i0 = 0; i0 < ne0; ++i0) {
12948
- ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
12949
- (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
12950
- (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
13109
+ for (int i3 = 0; i3 < ne3; i3++) {
13110
+ for (int i2 = ip0; i2 < ip1; i2++) {
13111
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
13112
+
13113
+ for (int i1 = 0; i1 < ne1; ++i1) {
13114
+ for (int i0 = 0; i0 < ne0; ++i0) {
13115
+ ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
13116
+ (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
13117
+ (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
13118
+ }
12951
13119
  }
12952
13120
  }
12953
13121
  }
12954
13122
  }
12955
13123
 
12956
- static void ggml_compute_forward_conv_2d_sk_p0(
13124
+ static void ggml_compute_forward_conv_2d(
12957
13125
  const struct ggml_compute_params * params,
12958
13126
  const struct ggml_tensor * src0,
12959
13127
  const struct ggml_tensor * src1,
12960
- struct ggml_tensor * dst) {
13128
+ const struct ggml_tensor * opt0,
13129
+ struct ggml_tensor * dst
13130
+ ) {
12961
13131
  switch (src0->type) {
12962
13132
  case GGML_TYPE_F16:
12963
13133
  {
12964
- ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
13134
+ ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst);
12965
13135
  } break;
12966
13136
  case GGML_TYPE_F32:
12967
13137
  {
12968
- //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
13138
+ //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst);
12969
13139
  GGML_ASSERT(false);
12970
13140
  } break;
12971
13141
  default:
@@ -12975,31 +13145,164 @@ static void ggml_compute_forward_conv_2d_sk_p0(
12975
13145
  }
12976
13146
  }
12977
13147
 
12978
- // ggml_compute_forward_conv_2d
13148
+ // ggml_compute_forward_pool_1d_sk_p0
12979
13149
 
12980
- static void ggml_compute_forward_conv_2d(
13150
+ static void ggml_compute_forward_pool_1d_sk_p0(
13151
+ const struct ggml_compute_params * params,
13152
+ const enum ggml_op_pool op,
13153
+ const struct ggml_tensor * src,
13154
+ const int k,
13155
+ struct ggml_tensor * dst) {
13156
+ assert(src->type == GGML_TYPE_F32);
13157
+ assert(params->ith == 0);
13158
+
13159
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13160
+ return;
13161
+ }
13162
+
13163
+ const char * cdata = (const char *)src->data;
13164
+ const char * const data_end = cdata + ggml_nbytes(src);
13165
+ float * drow = (float *)dst->data;
13166
+
13167
+ const int64_t rs = dst->ne[0];
13168
+
13169
+ while (cdata < data_end) {
13170
+ const float * const srow = (const float *)cdata;
13171
+
13172
+ int j = 0;
13173
+
13174
+ for (int64_t i = 0; i < rs; ++i) {
13175
+ switch (op) {
13176
+ case GGML_OP_POOL_AVG: drow[i] = 0; break;
13177
+ case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
13178
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13179
+ }
13180
+ for (int ki = 0; ki < k; ++ki) {
13181
+ switch (op) {
13182
+ case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
13183
+ case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
13184
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13185
+ }
13186
+ ++j;
13187
+ }
13188
+ switch (op) {
13189
+ case GGML_OP_POOL_AVG: drow[i] /= k; break;
13190
+ case GGML_OP_POOL_MAX: break;
13191
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13192
+ }
13193
+ }
13194
+
13195
+ cdata += src->nb[1];
13196
+ drow += rs;
13197
+ }
13198
+ }
13199
+
13200
+ // ggml_compute_forward_pool_1d
13201
+
13202
+ static void ggml_compute_forward_pool_1d(
12981
13203
  const struct ggml_compute_params* params,
12982
13204
  const struct ggml_tensor* src0,
12983
- const struct ggml_tensor* src1,
12984
13205
  const struct ggml_tensor* opt0,
12985
13206
  struct ggml_tensor* dst) {
12986
- const int32_t s0 = ((const int32_t*)(opt0->data))[0];
12987
- const int32_t s1 = ((const int32_t*)(opt0->data))[1];
12988
- const int32_t p0 = ((const int32_t*)(opt0->data))[2];
12989
- const int32_t p1 = ((const int32_t*)(opt0->data))[3];
12990
- const int32_t d0 = ((const int32_t*)(opt0->data))[4];
12991
- const int32_t d1 = ((const int32_t*)(opt0->data))[5];
12992
- GGML_ASSERT(d0 == 1); // dilation not supported
12993
- GGML_ASSERT(d1 == 1);
13207
+ GGML_ASSERT(opt0->ne[0] == 4);
13208
+ const int* opts = (const int*)opt0->data;
13209
+ enum ggml_op_pool op = opts[0];
13210
+ const int k0 = opts[1];
13211
+ const int s0 = opts[2];
13212
+ const int p0 = opts[3];
12994
13213
  GGML_ASSERT(p0 == 0); // padding not supported
12995
- GGML_ASSERT(p1 == 0);
13214
+ GGML_ASSERT(k0 == s0); // only s = k supported
13215
+
13216
+ ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
13217
+ }
12996
13218
 
12997
- if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
12998
- ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
13219
+ // ggml_compute_forward_pool_2d_sk_p0
13220
+
13221
+ static void ggml_compute_forward_pool_2d_sk_p0(
13222
+ const struct ggml_compute_params * params,
13223
+ const enum ggml_op_pool op,
13224
+ const struct ggml_tensor * src,
13225
+ const int k0,
13226
+ const int k1,
13227
+ struct ggml_tensor * dst) {
13228
+ assert(src->type == GGML_TYPE_F32);
13229
+ assert(params->ith == 0);
13230
+
13231
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13232
+ return;
13233
+ }
13234
+
13235
+ const char * cdata = (const char*)src->data;
13236
+ const char * const data_end = cdata + ggml_nbytes(src);
13237
+
13238
+ const int64_t px = dst->ne[0];
13239
+ const int64_t py = dst->ne[1];
13240
+ const int64_t pa = px * py;
13241
+
13242
+ float * dplane = (float *)dst->data;
13243
+
13244
+ const int ka = k0 * k1;
13245
+
13246
+ while (cdata < data_end) {
13247
+ for (int oy = 0; oy < py; ++oy) {
13248
+ float * const drow = dplane + oy * px;
13249
+ for (int ox = 0; ox < px; ++ox) {
13250
+ float * const out = drow + ox;
13251
+ switch (op) {
13252
+ case GGML_OP_POOL_AVG: *out = 0; break;
13253
+ case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
13254
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13255
+ }
13256
+
13257
+ const int ix = ox * k0;
13258
+ const int iy = oy * k1;
13259
+
13260
+ for (int ky = 0; ky < k1; ++ky) {
13261
+ const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
13262
+ for (int kx = 0; kx < k0; ++kx) {
13263
+ int j = ix + kx;
13264
+ switch (op) {
13265
+ case GGML_OP_POOL_AVG: *out += srow[j]; break;
13266
+ case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
13267
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13268
+ }
13269
+ }
13270
+ }
13271
+ switch (op) {
13272
+ case GGML_OP_POOL_AVG: *out /= ka; break;
13273
+ case GGML_OP_POOL_MAX: break;
13274
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13275
+ }
13276
+ }
13277
+ }
13278
+
13279
+ cdata += src->nb[2];
13280
+ dplane += pa;
12999
13281
  }
13000
- else {
13001
- GGML_ASSERT(false); // only stride equal to kernel size is supported
13002
- };
13282
+ }
13283
+
13284
+ // ggml_compute_forward_pool_2d
13285
+
13286
+ static void ggml_compute_forward_pool_2d(
13287
+ const struct ggml_compute_params * params,
13288
+ const struct ggml_tensor * src0,
13289
+ const struct ggml_tensor * opt0,
13290
+ struct ggml_tensor * dst) {
13291
+ GGML_ASSERT(opt0->ne[0] == 7);
13292
+ const int* opts = (const int*)opt0->data;
13293
+ enum ggml_op_pool op = opts[0];
13294
+ const int k0 = opts[1];
13295
+ const int k1 = opts[2];
13296
+ const int s0 = opts[3];
13297
+ const int s1 = opts[4];
13298
+ const int p0 = opts[5];
13299
+ const int p1 = opts[6];
13300
+ GGML_ASSERT(p0 == 0);
13301
+ GGML_ASSERT(p1 == 0); // padding not supported
13302
+ GGML_ASSERT(k0 == s0);
13303
+ GGML_ASSERT(k1 == s1); // only s = k supported
13304
+
13305
+ ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
13003
13306
  }
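Both new pooling paths only support stride equal to kernel size and zero padding, and the operator plus its geometry arrive packed in opt0 (4 ints for 1D: op, k0, s0, p0; 7 ints for 2D: op, k0, k1, s0, s1, p0, p1). A minimal standalone sketch of the non-overlapping 1D case the code above implements, for both average and max pooling:

#include <float.h>
#include <stdio.h>

// Minimal standalone sketch of the k == s, p == 0 pooling the code above
// implements: average or max over non-overlapping windows of k elements.
static void pool_1d_sk_p0(const float * src, int n, int k, float * dst, int is_max) {
    const int rs = n / k;                    // output length, as in dst->ne[0]
    for (int i = 0; i < rs; ++i) {
        float acc = is_max ? -FLT_MAX : 0.0f;
        for (int ki = 0; ki < k; ++ki) {
            const float v = src[i*k + ki];
            if (is_max) {
                if (v > acc) acc = v;
            } else {
                acc += v;
            }
        }
        dst[i] = is_max ? acc : acc / k;
    }
}

int main(void) {
    const float x[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
    float avg[4], mx[4];
    pool_1d_sk_p0(x, 8, 2, avg, 0);
    pool_1d_sk_p0(x, 8, 2, mx,  1);
    for (int i = 0; i < 4; ++i) {
        printf("window %d: avg %.1f  max %.1f\n", i, avg[i], mx[i]);
    }
    return 0;
}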
13004
13307
 
13005
13308
 
@@ -14566,287 +14869,295 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14566
14869
  if (skip_cpu) {
14567
14870
  return;
14568
14871
  }
14569
- GGML_ASSERT(tensor->src0 == NULL || tensor->src0->backend == GGML_BACKEND_CPU);
14570
- GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
14872
+ GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU);
14873
+ GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU);
14571
14874
  #endif // GGML_USE_CUBLAS
14572
14875
 
14573
14876
  switch (tensor->op) {
14574
14877
  case GGML_OP_DUP:
14575
14878
  {
14576
- ggml_compute_forward_dup(params, tensor->src0, tensor);
14879
+ ggml_compute_forward_dup(params, tensor->src[0], tensor);
14577
14880
  } break;
14578
14881
  case GGML_OP_ADD:
14579
14882
  {
14580
- ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
14883
+ ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor);
14581
14884
  } break;
14582
14885
  case GGML_OP_ADD1:
14583
14886
  {
14584
- ggml_compute_forward_add1(params, tensor->src0, tensor->src1, tensor);
14887
+ ggml_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor);
14585
14888
  } break;
14586
14889
  case GGML_OP_ACC:
14587
14890
  {
14588
- ggml_compute_forward_acc(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14891
+ ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14589
14892
  } break;
14590
14893
  case GGML_OP_SUB:
14591
14894
  {
14592
- ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
14895
+ ggml_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor);
14593
14896
  } break;
14594
14897
  case GGML_OP_MUL:
14595
14898
  {
14596
- ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
14899
+ ggml_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor);
14597
14900
  } break;
14598
14901
  case GGML_OP_DIV:
14599
14902
  {
14600
- ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
14903
+ ggml_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor);
14601
14904
  } break;
14602
14905
  case GGML_OP_SQR:
14603
14906
  {
14604
- ggml_compute_forward_sqr(params, tensor->src0, tensor);
14907
+ ggml_compute_forward_sqr(params, tensor->src[0], tensor);
14605
14908
  } break;
14606
14909
  case GGML_OP_SQRT:
14607
14910
  {
14608
- ggml_compute_forward_sqrt(params, tensor->src0, tensor);
14911
+ ggml_compute_forward_sqrt(params, tensor->src[0], tensor);
14609
14912
  } break;
14610
14913
  case GGML_OP_LOG:
14611
14914
  {
14612
- ggml_compute_forward_log(params, tensor->src0, tensor);
14915
+ ggml_compute_forward_log(params, tensor->src[0], tensor);
14613
14916
  } break;
14614
14917
  case GGML_OP_SUM:
14615
14918
  {
14616
- ggml_compute_forward_sum(params, tensor->src0, tensor);
14919
+ ggml_compute_forward_sum(params, tensor->src[0], tensor);
14617
14920
  } break;
14618
14921
  case GGML_OP_SUM_ROWS:
14619
14922
  {
14620
- ggml_compute_forward_sum_rows(params, tensor->src0, tensor);
14923
+ ggml_compute_forward_sum_rows(params, tensor->src[0], tensor);
14621
14924
  } break;
14622
14925
  case GGML_OP_MEAN:
14623
14926
  {
14624
- ggml_compute_forward_mean(params, tensor->src0, tensor);
14927
+ ggml_compute_forward_mean(params, tensor->src[0], tensor);
14625
14928
  } break;
14626
14929
  case GGML_OP_ARGMAX:
14627
14930
  {
14628
- ggml_compute_forward_argmax(params, tensor->src0, tensor);
14931
+ ggml_compute_forward_argmax(params, tensor->src[0], tensor);
14629
14932
  } break;
14630
14933
  case GGML_OP_REPEAT:
14631
14934
  {
14632
- ggml_compute_forward_repeat(params, tensor->src0, tensor);
14935
+ ggml_compute_forward_repeat(params, tensor->src[0], tensor);
14633
14936
  } break;
14634
14937
  case GGML_OP_REPEAT_BACK:
14635
14938
  {
14636
- ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14939
+ ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
14637
14940
  } break;
14638
14941
  case GGML_OP_ABS:
14639
14942
  {
14640
- ggml_compute_forward_abs(params, tensor->src0, tensor);
14943
+ ggml_compute_forward_abs(params, tensor->src[0], tensor);
14641
14944
  } break;
14642
14945
  case GGML_OP_SGN:
14643
14946
  {
14644
- ggml_compute_forward_sgn(params, tensor->src0, tensor);
14947
+ ggml_compute_forward_sgn(params, tensor->src[0], tensor);
14645
14948
  } break;
14646
14949
  case GGML_OP_NEG:
14647
14950
  {
14648
- ggml_compute_forward_neg(params, tensor->src0, tensor);
14951
+ ggml_compute_forward_neg(params, tensor->src[0], tensor);
14649
14952
  } break;
14650
14953
  case GGML_OP_STEP:
14651
14954
  {
14652
- ggml_compute_forward_step(params, tensor->src0, tensor);
14955
+ ggml_compute_forward_step(params, tensor->src[0], tensor);
14653
14956
  } break;
14654
14957
  case GGML_OP_TANH:
14655
14958
  {
14656
- ggml_compute_forward_tanh(params, tensor->src0, tensor);
14959
+ ggml_compute_forward_tanh(params, tensor->src[0], tensor);
14657
14960
  } break;
14658
14961
  case GGML_OP_ELU:
14659
14962
  {
14660
- ggml_compute_forward_elu(params, tensor->src0, tensor);
14963
+ ggml_compute_forward_elu(params, tensor->src[0], tensor);
14661
14964
  } break;
14662
14965
  case GGML_OP_RELU:
14663
14966
  {
14664
- ggml_compute_forward_relu(params, tensor->src0, tensor);
14967
+ ggml_compute_forward_relu(params, tensor->src[0], tensor);
14665
14968
  } break;
14666
14969
  case GGML_OP_GELU:
14667
14970
  {
14668
- ggml_compute_forward_gelu(params, tensor->src0, tensor);
14971
+ ggml_compute_forward_gelu(params, tensor->src[0], tensor);
14669
14972
  } break;
14670
14973
  case GGML_OP_GELU_QUICK:
14671
14974
  {
14672
- ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
14975
+ ggml_compute_forward_gelu_quick(params, tensor->src[0], tensor);
14673
14976
  } break;
14674
14977
  case GGML_OP_SILU:
14675
14978
  {
14676
- ggml_compute_forward_silu(params, tensor->src0, tensor);
14979
+ ggml_compute_forward_silu(params, tensor->src[0], tensor);
14677
14980
  } break;
14678
14981
  case GGML_OP_SILU_BACK:
14679
14982
  {
14680
- ggml_compute_forward_silu_back(params, tensor->src0, tensor->src1, tensor);
14983
+ ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
14681
14984
  } break;
14682
14985
  case GGML_OP_NORM:
14683
14986
  {
14684
- ggml_compute_forward_norm(params, tensor->src0, tensor);
14987
+ ggml_compute_forward_norm(params, tensor->src[0], tensor);
14685
14988
  } break;
14686
14989
  case GGML_OP_RMS_NORM:
14687
14990
  {
14688
- ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
14991
+ ggml_compute_forward_rms_norm(params, tensor->src[0], tensor);
14689
14992
  } break;
14690
14993
  case GGML_OP_RMS_NORM_BACK:
14691
14994
  {
14692
- ggml_compute_forward_rms_norm_back(params, tensor->src0, tensor->src1, tensor);
14995
+ ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor);
14693
14996
  } break;
14694
14997
  case GGML_OP_MUL_MAT:
14695
14998
  {
14696
- ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
14999
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14697
15000
  } break;
14698
15001
  case GGML_OP_OUT_PROD:
14699
15002
  {
14700
- ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
15003
+ ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
14701
15004
  } break;
14702
15005
  case GGML_OP_SCALE:
14703
15006
  {
14704
- ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
15007
+ ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
14705
15008
  } break;
14706
15009
  case GGML_OP_SET:
14707
15010
  {
14708
- ggml_compute_forward_set(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15011
+ ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14709
15012
  } break;
14710
15013
  case GGML_OP_CPY:
14711
15014
  {
14712
- ggml_compute_forward_cpy(params, tensor->src0, tensor);
15015
+ ggml_compute_forward_cpy(params, tensor->src[0], tensor);
14713
15016
  } break;
14714
15017
  case GGML_OP_CONT:
14715
15018
  {
14716
- ggml_compute_forward_cont(params, tensor->src0, tensor);
15019
+ ggml_compute_forward_cont(params, tensor->src[0], tensor);
14717
15020
  } break;
14718
15021
  case GGML_OP_RESHAPE:
14719
15022
  {
14720
- ggml_compute_forward_reshape(params, tensor->src0, tensor);
15023
+ ggml_compute_forward_reshape(params, tensor->src[0], tensor);
14721
15024
  } break;
14722
15025
  case GGML_OP_VIEW:
14723
15026
  {
14724
- ggml_compute_forward_view(params, tensor->src0);
15027
+ ggml_compute_forward_view(params, tensor->src[0]);
14725
15028
  } break;
14726
15029
  case GGML_OP_PERMUTE:
14727
15030
  {
14728
- ggml_compute_forward_permute(params, tensor->src0);
15031
+ ggml_compute_forward_permute(params, tensor->src[0]);
14729
15032
  } break;
14730
15033
  case GGML_OP_TRANSPOSE:
14731
15034
  {
14732
- ggml_compute_forward_transpose(params, tensor->src0);
15035
+ ggml_compute_forward_transpose(params, tensor->src[0]);
14733
15036
  } break;
14734
15037
  case GGML_OP_GET_ROWS:
14735
15038
  {
14736
- ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
15039
+ ggml_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor);
14737
15040
  } break;
14738
15041
  case GGML_OP_GET_ROWS_BACK:
14739
15042
  {
14740
- ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15043
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14741
15044
  } break;
14742
15045
  case GGML_OP_DIAG:
14743
15046
  {
14744
- ggml_compute_forward_diag(params, tensor->src0, tensor);
15047
+ ggml_compute_forward_diag(params, tensor->src[0], tensor);
14745
15048
  } break;
14746
15049
  case GGML_OP_DIAG_MASK_INF:
14747
15050
  {
14748
- ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
15051
+ ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor);
14749
15052
  } break;
14750
15053
  case GGML_OP_DIAG_MASK_ZERO:
14751
15054
  {
14752
- ggml_compute_forward_diag_mask_zero(params, tensor->src0, tensor->src1, tensor);
15055
+ ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor);
14753
15056
  } break;
14754
15057
  case GGML_OP_SOFT_MAX:
14755
15058
  {
14756
- ggml_compute_forward_soft_max(params, tensor->src0, tensor);
15059
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
14757
15060
  } break;
14758
15061
  case GGML_OP_SOFT_MAX_BACK:
14759
15062
  {
14760
- ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
15063
+ ggml_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor);
14761
15064
  } break;
14762
15065
  case GGML_OP_ROPE:
14763
15066
  {
14764
- ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
15067
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
14765
15068
  } break;
14766
15069
  case GGML_OP_ROPE_BACK:
14767
15070
  {
14768
- ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor);
15071
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
14769
15072
  } break;
14770
15073
  case GGML_OP_ALIBI:
14771
15074
  {
14772
- ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
15075
+ ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor);
14773
15076
  } break;
14774
15077
  case GGML_OP_CLAMP:
14775
15078
  {
14776
- ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
15079
+ ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor);
14777
15080
  } break;
14778
15081
  case GGML_OP_CONV_1D:
14779
15082
  {
14780
- ggml_compute_forward_conv_1d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15083
+ ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14781
15084
  } break;
14782
15085
  case GGML_OP_CONV_2D:
14783
15086
  {
14784
- ggml_compute_forward_conv_2d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15087
+ ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
15088
+ } break;
15089
+ case GGML_OP_POOL_1D:
15090
+ {
15091
+ ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor);
15092
+ } break;
15093
+ case GGML_OP_POOL_2D:
15094
+ {
15095
+ ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor);
14785
15096
  } break;
14786
15097
  case GGML_OP_FLASH_ATTN:
14787
15098
  {
14788
- const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15099
+ const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
14789
15100
  GGML_ASSERT(t == 0 || t == 1);
14790
15101
  const bool masked = t != 0;
14791
- ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
15102
+ ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
14792
15103
  } break;
14793
15104
  case GGML_OP_FLASH_FF:
14794
15105
  {
14795
- ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
15106
+ ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
14796
15107
  } break;
14797
15108
  case GGML_OP_FLASH_ATTN_BACK:
14798
15109
  {
14799
- int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
15110
+ int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
14800
15111
  GGML_ASSERT(t == 0 || t == 1);
14801
15112
  bool masked = t != 0;
14802
- ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
15113
+ ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
14803
15114
  } break;
14804
15115
  case GGML_OP_WIN_PART:
14805
15116
  {
14806
- ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
15117
+ ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor);
14807
15118
  } break;
14808
15119
  case GGML_OP_WIN_UNPART:
14809
15120
  {
14810
- ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
15121
+ ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor);
14811
15122
  } break;
14812
15123
  case GGML_OP_MAP_UNARY:
14813
15124
  {
14814
- const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
14815
- ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
15125
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data);
15126
+ ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun);
14816
15127
  }
14817
15128
  break;
14818
15129
  case GGML_OP_MAP_BINARY:
14819
15130
  {
14820
- const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
14821
- ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
15131
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data);
15132
+ ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun);
14822
15133
  }
14823
15134
  break;
14824
15135
  case GGML_OP_MAP_CUSTOM1:
14825
15136
  {
14826
- const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->opt[0]->data);
14827
- ggml_compute_forward_map_custom1(params, tensor->src0, tensor, fun);
15137
+ const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data);
15138
+ ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun);
14828
15139
  }
14829
15140
  break;
14830
15141
  case GGML_OP_MAP_CUSTOM2:
14831
15142
  {
14832
- const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->opt[0]->data);
14833
- ggml_compute_forward_map_custom2(params, tensor->src0, tensor->src1, tensor, fun);
15143
+ const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data);
15144
+ ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun);
14834
15145
  }
14835
15146
  break;
14836
15147
  case GGML_OP_MAP_CUSTOM3:
14837
15148
  {
14838
- const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->opt[0]->data);
14839
- ggml_compute_forward_map_custom3(params, tensor->src0, tensor->src1, tensor->opt[1], tensor, fun);
15149
+ const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data);
15150
+ ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun);
14840
15151
  }
14841
15152
  break;
14842
15153
  case GGML_OP_CROSS_ENTROPY_LOSS:
14843
15154
  {
14844
- ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
15155
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor);
14845
15156
  }
14846
15157
  break;
14847
15158
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
14848
15159
  {
14849
- ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15160
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14850
15161
  }
14851
15162
  break;
14852
15163
  case GGML_OP_NONE:
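The pattern running through this whole dispatch hunk is a single operand array replacing the old src0/src1/opt[] trio: src0 becomes src[0], src1 becomes src[1], and opt[k] becomes src[k + 2] (see the FLASH_ATTN, FLASH_FF and MAP_* cases above). A tiny illustration of the unified walk this enables, with GGML_MAX_SRC assumed to be 6 since its definition is not shown in this section:

#include <stddef.h>
#include <stdio.h>

// Illustration of the operand refactor visible throughout this hunk: the old
// src0 / src1 / opt[] pointers collapse into one array, with
//   src0 -> src[0], src1 -> src[1], opt[0] -> src[2], opt[1] -> src[3], opt[2] -> src[4].
// GGML_MAX_SRC itself is not shown in this section; 6 is assumed here.
#define SKETCH_MAX_SRC 6

struct node_sketch { const void * src[SKETCH_MAX_SRC]; };

static int count_operands(const struct node_sketch * n) {
    int c = 0;
    for (int i = 0; i < SKETCH_MAX_SRC; ++i) {
        if (n->src[i] != NULL) {
            c++;                             // same single walk ggml_visit_parents now does
        }
    }
    return c;
}

int main(void) {
    int a = 0, b = 0, opt = 0;
    struct node_sketch n = { { &a, &b, &opt, NULL, NULL, NULL } };
    printf("operands: %d\n", count_operands(&n));
    return 0;
}

The same loop appears verbatim in ggml_visit_parents further down, which previously needed separate handling for src0, src1 and each opt slot.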
@@ -14863,8 +15174,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14863
15174
  ////////////////////////////////////////////////////////////////////////////////
14864
15175
 
14865
15176
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
14866
- struct ggml_tensor * src0 = tensor->src0;
14867
- struct ggml_tensor * src1 = tensor->src1;
15177
+ struct ggml_tensor * src0 = tensor->src[0];
15178
+ struct ggml_tensor * src1 = tensor->src[1];
14868
15179
 
14869
15180
  switch (tensor->op) {
14870
15181
  case GGML_OP_DUP:
@@ -14900,12 +15211,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
14900
15211
  src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
14901
15212
  }
14902
15213
  if (src1->grad) {
14903
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
14904
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
14905
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
14906
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
14907
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
14908
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
15214
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
15215
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
15216
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
15217
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
15218
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
15219
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
14909
15220
 
14910
15221
  struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
14911
15222
  tensor->grad,
@@ -15213,12 +15524,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15213
15524
  } break;
15214
15525
  case GGML_OP_SET:
15215
15526
  {
15216
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
15217
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
15218
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
15219
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
15220
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
15221
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
15527
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
15528
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
15529
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
15530
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
15531
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
15532
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
15222
15533
 
15223
15534
  struct ggml_tensor * tensor_grad_view = NULL;
15224
15535
 
@@ -15295,8 +15606,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15295
15606
  if (src0->grad) {
15296
15607
  size_t offset;
15297
15608
 
15298
- GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
15299
- memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
15609
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2]));
15610
+ memcpy(&offset, tensor->src[2]->data, sizeof(offset));
15300
15611
 
15301
15612
  size_t nb1 = tensor->nb[1];
15302
15613
  size_t nb2 = tensor->nb[2];
@@ -15323,7 +15634,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15323
15634
  {
15324
15635
  // necessary for llama
15325
15636
  if (src0->grad) {
15326
- int32_t * axes = (int32_t *) tensor->opt[0]->data;
15637
+ int32_t * axes = (int32_t *) tensor->src[2]->data;
15327
15638
  int axis0 = axes[0] & 0x3;
15328
15639
  int axis1 = axes[1] & 0x3;
15329
15640
  int axis2 = axes[2] & 0x3;
@@ -15427,17 +15738,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15427
15738
  // necessary for llama
15428
15739
  if (src0->grad) {
15429
15740
  assert(src1->type == GGML_TYPE_I32);
15430
- assert(ggml_nelements(src1) == 4);
15741
+ assert(ggml_nelements(src1) == 6);
15431
15742
  const int n_past = ((int32_t *) src1->data)[0];
15432
15743
  const int n_dims = ((int32_t *) src1->data)[1];
15433
15744
  const int mode = ((int32_t *) src1->data)[2];
15745
+ const int n_ctx = ((int32_t *) src1->data)[3];
15434
15746
  src0->grad = ggml_add_impl(ctx,
15435
15747
  src0->grad,
15436
15748
  ggml_rope_back(ctx,
15437
15749
  tensor->grad,
15438
15750
  n_past,
15439
15751
  n_dims,
15440
- mode),
15752
+ mode,
15753
+ n_ctx),
15441
15754
  inplace);
15442
15755
  }
15443
15756
  if (src1->grad) {
@@ -15483,18 +15796,26 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15483
15796
  {
15484
15797
  GGML_ASSERT(false); // TODO: not implemented
15485
15798
  } break;
15799
+ case GGML_OP_POOL_1D:
15800
+ {
15801
+ GGML_ASSERT(false); // TODO: not implemented
15802
+ } break;
15803
+ case GGML_OP_POOL_2D:
15804
+ {
15805
+ GGML_ASSERT(false); // TODO: not implemented
15806
+ } break;
15486
15807
  case GGML_OP_FLASH_ATTN:
15487
15808
  {
15488
15809
  struct ggml_tensor * flash_grad = NULL;
15489
- if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15490
- int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15810
+ if (src0->grad || src1->grad || tensor->src[2]->grad) {
15811
+ int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
15491
15812
  GGML_ASSERT(t == 0 || t == 1);
15492
15813
  bool masked = t != 0;
15493
15814
  flash_grad =
15494
15815
  ggml_flash_attn_back(ctx,
15495
15816
  src0,
15496
15817
  src1,
15497
- tensor->opt[0],
15818
+ tensor->src[2],
15498
15819
  tensor->grad,
15499
15820
  masked);
15500
15821
  }
@@ -15591,7 +15912,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15591
15912
  inplace);
15592
15913
  }
15593
15914
 
15594
- struct ggml_tensor * opt0 = tensor->opt[0];
15915
+ struct ggml_tensor * opt0 = tensor->src[2];
15595
15916
 
15596
15917
  if (opt0->grad) {
15597
15918
  struct ggml_tensor * grad_v = NULL;
@@ -15707,17 +16028,9 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
15707
16028
  }
15708
16029
  }
15709
16030
 
15710
- if (node->src0) {
15711
- ggml_visit_parents(cgraph, node->src0);
15712
- }
15713
-
15714
- if (node->src1) {
15715
- ggml_visit_parents(cgraph, node->src1);
15716
- }
15717
-
15718
- for (int i = 0; i < GGML_MAX_OPT; ++i) {
15719
- if (node->opt[i]) {
15720
- ggml_visit_parents(cgraph, node->opt[i]);
16031
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
16032
+ if (node->src[i]) {
16033
+ ggml_visit_parents(cgraph, node->src[i]);
15721
16034
  }
15722
16035
  }
15723
16036
 
@@ -15772,9 +16085,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
15772
16085
  struct ggml_cgraph result = {
15773
16086
  /*.n_nodes =*/ 0,
15774
16087
  /*.n_leafs =*/ 0,
15775
- /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
15776
- /*.work_size =*/ 0,
15777
- /*.work =*/ NULL,
15778
16088
  /*.nodes =*/ { NULL },
15779
16089
  /*.grads =*/ { NULL },
15780
16090
  /*.leafs =*/ { NULL },
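The three fields dropped from ggml_cgraph here (n_threads, work_size, work) reappear further down as a separate execution plan. As a reading aid, a hedged sketch of that plan's shape, with field names taken from the cplan usages later in this diff; the authoritative definition and the real GGML_MAX_NODES bound live in ggml.h of this release.

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

// Hedged sketch of the per-call plan that takes over the removed cgraph fields.
// Field names follow the cplan usages visible later in this diff; the bound
// below is a placeholder, not the real GGML_MAX_NODES.
#define SKETCH_MAX_NODES 4096

struct cplan_sketch {
    size_t    work_size;                  // scratch size, computed by ggml_graph_plan()
    uint8_t * work_data;                  // scratch buffer, now allocated by the caller
    int       n_threads;
    int       n_tasks[SKETCH_MAX_NODES];  // per-node task count, no longer stored on each tensor

    bool (*abort_callback)(void * data);  // polled between nodes; true aborts the compute
    void    * abort_callback_data;
};

int main(void) {
    printf("sketch plan size: %zu bytes\n", sizeof(struct cplan_sketch));
    return 0;
}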
@@ -15945,16 +16255,20 @@ void clear_numa_thread_affinity(void) {}
15945
16255
  #endif
15946
16256
 
15947
16257
  struct ggml_compute_state_shared {
15948
- struct ggml_cgraph * cgraph;
16258
+ const struct ggml_cgraph * cgraph;
16259
+ const struct ggml_cplan * cplan;
15949
16260
 
15950
16261
  int64_t perf_node_start_cycles;
15951
16262
  int64_t perf_node_start_time_us;
15952
16263
 
15953
- int n_threads;
16264
+ const int n_threads;
15954
16265
 
15955
16266
  // synchronization primitives
15956
16267
  atomic_int n_active; // num active threads
15957
16268
  atomic_int node_n; // active graph node
16269
+
16270
+ bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
16271
+ void * abort_callback_data;
15958
16272
  };
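The abort hook added to the shared state above is polled between graph nodes (see the check at the top of the worker loop below), and returning true makes the compute bail out with GGML_EXIT_ABORTED. A small sketch of one plausible callback, a wall-clock deadline; the callback signature is taken from this hunk, everything else is illustrative.

#include <stdbool.h>
#include <stdio.h>
#include <time.h>

// Hedged sketch of an abort callback matching the hook added above:
// it is polled between graph nodes and returning true stops the compute
// (the worker then returns GGML_EXIT_ABORTED). Here: a wall-clock deadline.
struct deadline { time_t end; };

static bool abort_after_deadline(void * data) {
    const struct deadline * d = (const struct deadline *) data;
    return time(NULL) >= d->end;
}

int main(void) {
    struct deadline d = { time(NULL) + 5 };   // give up roughly 5 seconds from now
    printf("abort yet? %s\n", abort_after_deadline(&d) ? "yes" : "no");
    return 0;
}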
15959
16273
 
15960
16274
  struct ggml_compute_state {
@@ -15974,14 +16288,22 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
15974
16288
 
15975
16289
  static thread_ret_t ggml_graph_compute_thread(void * data) {
15976
16290
  struct ggml_compute_state * state = (struct ggml_compute_state *) data;
15977
- struct ggml_cgraph * cgraph = state->shared->cgraph;
15978
16291
 
15979
- const int n_threads = state->shared->n_threads;
16292
+ const struct ggml_cgraph * cgraph = state->shared->cgraph;
16293
+ const struct ggml_cplan * cplan = state->shared->cplan;
16294
+
16295
+ const int * n_tasks_arr = cplan->n_tasks;
16296
+ const int n_threads = state->shared->n_threads;
16297
+
15980
16298
  set_numa_thread_affinity(state->ith, n_threads);
15981
16299
 
15982
16300
  int node_n = -1;
15983
16301
 
15984
16302
  while (true) {
16303
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
16304
+ state->shared->node_n += 1;
16305
+ return (thread_ret_t) GGML_EXIT_ABORTED;
16306
+ }
15985
16307
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
15986
16308
  // all other threads are finished and spinning
15987
16309
  // do finalize and init here so we don't have synchronize again
@@ -15989,18 +16311,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
15989
16311
  /*.type =*/ GGML_TASK_FINALIZE,
15990
16312
  /*.ith =*/ 0,
15991
16313
  /*.nth =*/ 0,
15992
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
15993
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
16314
+ /*.wsize =*/ cplan->work_size,
16315
+ /*.wdata =*/ cplan->work_data,
15994
16316
  };
15995
16317
 
15996
16318
  if (node_n != -1) {
15997
16319
  /* FINALIZE */
15998
16320
  struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
15999
16321
  if (GGML_OP_HAS_FINALIZE[node->op]) {
16000
- params.nth = node->n_tasks;
16322
+ params.nth = n_tasks_arr[node_n];
16001
16323
  ggml_compute_forward(&params, node);
16002
- ggml_graph_compute_perf_stats_node(node, state->shared);
16003
16324
  }
16325
+ ggml_graph_compute_perf_stats_node(node, state->shared);
16004
16326
  }
16005
16327
 
16006
16328
  // distribute new work or execute it direct if 1T
@@ -16008,11 +16330,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16008
16330
  GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
16009
16331
 
16010
16332
  struct ggml_tensor * node = cgraph->nodes[node_n];
16333
+ const int n_tasks = n_tasks_arr[node_n];
16011
16334
 
16012
16335
  state->shared->perf_node_start_cycles = ggml_perf_cycles();
16013
16336
  state->shared->perf_node_start_time_us = ggml_perf_time_us();
16014
16337
 
16015
- params.nth = node->n_tasks;
16338
+ params.nth = n_tasks;
16016
16339
 
16017
16340
  /* INIT */
16018
16341
  if (GGML_OP_HAS_INIT[node->op]) {
@@ -16020,7 +16343,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16020
16343
  ggml_compute_forward(&params, node);
16021
16344
  }
16022
16345
 
16023
- if (node->n_tasks == 1) {
16346
+ if (n_tasks == 1) {
16024
16347
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
16025
16348
  // they do something more efficient than spinning (?)
16026
16349
  params.type = GGML_TASK_COMPUTE;
@@ -16029,11 +16352,16 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16029
16352
  if (GGML_OP_HAS_FINALIZE[node->op]) {
16030
16353
  params.type = GGML_TASK_FINALIZE;
16031
16354
  ggml_compute_forward(&params, node);
16032
- ggml_graph_compute_perf_stats_node(node, state->shared);
16033
16355
  }
16356
+
16357
+ ggml_graph_compute_perf_stats_node(node, state->shared);
16034
16358
  } else {
16035
16359
  break;
16036
16360
  }
16361
+
16362
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
16363
+ break;
16364
+ }
16037
16365
  }
16038
16366
 
16039
16367
  atomic_store(&state->shared->n_active, n_threads);
@@ -16042,7 +16370,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16042
16370
  // wait for other threads to finish
16043
16371
  const int last = node_n;
16044
16372
  do {
16045
- sched_yield();
16373
+ //sched_yield();
16046
16374
  node_n = atomic_load(&state->shared->node_n);
16047
16375
  } while (node_n == last);
16048
16376
  }
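Taken together, the changes in this function imply a new call pattern: the caller obtains a plan, supplies the scratch buffer that used to live on the graph, and only then runs the compute. A hedged sketch of that pattern follows; ggml_graph_plan is defined in the next hunk, while the exact ggml_graph_compute signature should be checked against ggml.h of this release.

#include <stdint.h>
#include <stdlib.h>

#include "ggml.h"

// Hedged sketch: run a graph with the plan-based API implied by this diff.
// ggml_graph_plan() fills in work_size and the per-node n_tasks; the work
// buffer itself is now provided by the caller instead of cgraph->work.
int compute_graph(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan cplan = ggml_graph_plan(gf, n_threads);

    uint8_t * work = NULL;
    if (cplan.work_size > 0) {
        work = malloc(cplan.work_size);
        if (work == NULL) {
            return -1;
        }
        cplan.work_data = work;
    }

    // returns GGML_EXIT_SUCCESS, or GGML_EXIT_ABORTED if cplan.abort_callback fired
    const int rc = ggml_graph_compute(gf, &cplan);

    free(work);
    return rc;
}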
@@ -16052,366 +16380,398 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16052
16380
 
16053
16381
  /* COMPUTE */
16054
16382
  struct ggml_tensor * node = cgraph->nodes[node_n];
16383
+ const int n_tasks = n_tasks_arr[node_n];
16055
16384
 
16056
16385
  struct ggml_compute_params params = {
16057
16386
  /*.type =*/ GGML_TASK_COMPUTE,
16058
16387
  /*.ith =*/ state->ith,
16059
- /*.nth =*/ node->n_tasks,
16060
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
16061
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
16388
+ /*.nth =*/ n_tasks,
16389
+ /*.wsize =*/ cplan->work_size,
16390
+ /*.wdata =*/ cplan->work_data,
16062
16391
  };
16063
16392
 
16064
- if (state->ith < node->n_tasks) {
16393
+ if (state->ith < n_tasks) {
16065
16394
  ggml_compute_forward(&params, node);
16066
16395
  }
16067
16396
  }
16068
16397
 
16069
- return 0;
16398
+ return GGML_EXIT_SUCCESS;
16070
16399
  }
16071
16400
 
16072
- void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
16073
- const int n_threads = cgraph->n_threads;
16401
+ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
16402
+ if (n_threads <= 0) {
16403
+ n_threads = GGML_DEFAULT_N_THREADS;
16404
+ }
16074
16405
 
16075
- struct ggml_compute_state_shared state_shared = {
16076
- /*.cgraph =*/ cgraph,
16077
- /*.perf_node_start_cycles =*/ 0,
16078
- /*.perf_node_start_time_us =*/ 0,
16079
- /*.n_threads =*/ n_threads,
16080
- /*.n_active =*/ n_threads,
16081
- /*.node_n =*/ -1,
16082
- };
16083
- struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
16406
+ size_t work_size = 0;
16084
16407
 
16085
- // initialize tasks + work buffer
16086
- {
16087
- size_t work_size = 0;
16408
+ struct ggml_cplan cplan;
16409
+ memset(&cplan, 0, sizeof(struct ggml_cplan));
16088
16410
 
16089
- // thread scheduling for the different operations
16090
- for (int i = 0; i < cgraph->n_nodes; i++) {
16091
- struct ggml_tensor * node = cgraph->nodes[i];
16411
+ // thread scheduling for the different operations + work buffer size estimation
16412
+ for (int i = 0; i < cgraph->n_nodes; i++) {
16413
+ int n_tasks = 1;
16092
16414
 
16093
- switch (node->op) {
16094
- case GGML_OP_CPY:
16095
- case GGML_OP_DUP:
16096
- {
16097
- node->n_tasks = n_threads;
16415
+ struct ggml_tensor * node = cgraph->nodes[i];
16098
16416
 
16099
- size_t cur = 0;
16100
- if (ggml_is_quantized(node->type)) {
16101
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
16102
- }
16417
+ switch (node->op) {
16418
+ case GGML_OP_CPY:
16419
+ case GGML_OP_DUP:
16420
+ {
16421
+ n_tasks = n_threads;
16422
+
16423
+ size_t cur = 0;
16424
+ if (ggml_is_quantized(node->type)) {
16425
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
16426
+ }
16103
16427
 
16104
- work_size = MAX(work_size, cur);
16105
- } break;
16106
- case GGML_OP_ADD:
16107
- case GGML_OP_ADD1:
16108
- {
16109
- node->n_tasks = n_threads;
16428
+ work_size = MAX(work_size, cur);
16429
+ } break;
16430
+ case GGML_OP_ADD:
16431
+ case GGML_OP_ADD1:
16432
+ {
16433
+ n_tasks = n_threads;
16110
16434
 
16111
- size_t cur = 0;
16435
+ size_t cur = 0;
16112
16436
 
16113
- if (ggml_is_quantized(node->src0->type)) {
16114
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
16115
- }
16437
+ if (ggml_is_quantized(node->src[0]->type)) {
16438
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
16439
+ }
16116
16440
 
16117
- work_size = MAX(work_size, cur);
16118
- } break;
16119
- case GGML_OP_ACC:
16120
- {
16121
- node->n_tasks = n_threads;
16441
+ work_size = MAX(work_size, cur);
16442
+ } break;
16443
+ case GGML_OP_ACC:
16444
+ {
16445
+ n_tasks = n_threads;
16122
16446
 
16123
- size_t cur = 0;
16447
+ size_t cur = 0;
16124
16448
 
16125
- if (ggml_is_quantized(node->src0->type)) {
16126
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
16127
- }
16449
+ if (ggml_is_quantized(node->src[0]->type)) {
16450
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[1]->ne[0] * n_tasks;
16451
+ }
16452
+
16453
+ work_size = MAX(work_size, cur);
16454
+ } break;
16455
+ case GGML_OP_SUB:
16456
+ case GGML_OP_DIV:
16457
+ case GGML_OP_SQR:
16458
+ case GGML_OP_SQRT:
16459
+ case GGML_OP_LOG:
16460
+ case GGML_OP_SUM:
16461
+ case GGML_OP_SUM_ROWS:
16462
+ case GGML_OP_MEAN:
16463
+ case GGML_OP_ARGMAX:
16464
+ case GGML_OP_REPEAT:
16465
+ case GGML_OP_REPEAT_BACK:
16466
+ case GGML_OP_ABS:
16467
+ case GGML_OP_SGN:
16468
+ case GGML_OP_NEG:
16469
+ case GGML_OP_STEP:
16470
+ case GGML_OP_TANH:
16471
+ case GGML_OP_ELU:
16472
+ case GGML_OP_RELU:
16473
+ {
16474
+ n_tasks = 1;
16475
+ } break;
16476
+ case GGML_OP_MUL:
16477
+ case GGML_OP_GELU:
16478
+ case GGML_OP_GELU_QUICK:
16479
+ case GGML_OP_SILU:
16480
+ case GGML_OP_SILU_BACK:
16481
+ case GGML_OP_NORM:
16482
+ case GGML_OP_RMS_NORM:
16483
+ case GGML_OP_RMS_NORM_BACK:
16484
+ {
16485
+ n_tasks = n_threads;
16486
+ } break;
16487
+ case GGML_OP_MUL_MAT:
16488
+ case GGML_OP_OUT_PROD:
16489
+ {
16490
+ n_tasks = n_threads;
16491
+
16492
+ // TODO: use different scheduling for different matrix sizes
16493
+ //const int nr0 = ggml_nrows(node->src[0]);
16494
+ //const int nr1 = ggml_nrows(node->src[1]);
16495
+
16496
+ //n_tasks = MIN(n_threads, MAX(1, nr0/128));
16497
+ //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
16128
16498
 
16129
- work_size = MAX(work_size, cur);
16130
- } break;
16131
- case GGML_OP_SUB:
16132
- case GGML_OP_DIV:
16133
- case GGML_OP_SQR:
16134
- case GGML_OP_SQRT:
16135
- case GGML_OP_LOG:
16136
- case GGML_OP_SUM:
16137
- case GGML_OP_SUM_ROWS:
16138
- case GGML_OP_MEAN:
16139
- case GGML_OP_ARGMAX:
16140
- case GGML_OP_REPEAT:
16141
- case GGML_OP_REPEAT_BACK:
16142
- case GGML_OP_ABS:
16143
- case GGML_OP_SGN:
16144
- case GGML_OP_NEG:
16145
- case GGML_OP_STEP:
16146
- case GGML_OP_TANH:
16147
- case GGML_OP_ELU:
16148
- case GGML_OP_RELU:
16149
- {
16150
- node->n_tasks = 1;
16151
- } break;
16152
- case GGML_OP_MUL:
16153
- case GGML_OP_GELU:
16154
- case GGML_OP_GELU_QUICK:
16155
- case GGML_OP_SILU:
16156
- case GGML_OP_SILU_BACK:
16157
- case GGML_OP_NORM:
16158
- case GGML_OP_RMS_NORM:
16159
- case GGML_OP_RMS_NORM_BACK:
16160
- {
16161
- node->n_tasks = n_threads;
16162
- } break;
16163
- case GGML_OP_MUL_MAT:
16164
- case GGML_OP_OUT_PROD:
16165
- {
16166
- node->n_tasks = n_threads;
16167
-
16168
- // TODO: use different scheduling for different matrix sizes
16169
- //const int nr0 = ggml_nrows(node->src0);
16170
- //const int nr1 = ggml_nrows(node->src1);
16171
-
16172
- //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
16173
- //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
16174
-
16175
- size_t cur = 0;
16176
- const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
16499
+ size_t cur = 0;
16500
+ const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
16177
16501
 
16178
16502
  #if defined(GGML_USE_CUBLAS)
16179
- if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
16180
- node->n_tasks = 1; // TODO: this actually is doing nothing
16181
- // the threads are still spinning
16182
- }
16183
- else
16503
+ if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
16504
+ n_tasks = 1; // TODO: this actually is doing nothing
16505
+ // the threads are still spinning
16506
+ } else
16184
16507
  #elif defined(GGML_USE_CLBLAST)
16185
- if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
16186
- node->n_tasks = 1; // TODO: this actually is doing nothing
16187
- // the threads are still spinning
16188
- cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
16189
- }
16190
- else
16508
+ if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
16509
+ n_tasks = 1; // TODO: this actually is doing nothing
16510
+ // the threads are still spinning
16511
+ cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
16512
+ } else
16191
16513
  #endif
16192
16514
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
16193
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
16194
- node->n_tasks = 1; // TODO: this actually is doing nothing
16195
- // the threads are still spinning
16196
- if (node->src0->type != GGML_TYPE_F32) {
16197
- // here we need memory just for single 2D matrix from src0
16198
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
16199
- }
16200
- } else
16201
- #endif
16202
- if (node->src1->type != vec_dot_type) {
16203
- cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
16204
- } else {
16205
- cur = 0;
16515
+ if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
16516
+ n_tasks = 1; // TODO: this actually is doing nothing
16517
+ // the threads are still spinning
16518
+ if (node->src[0]->type != GGML_TYPE_F32) {
16519
+ // here we need memory just for single 2D matrix from src0
16520
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src[0]->ne[0]*node->src[0]->ne[1]);
16206
16521
  }
16522
+ } else
16523
+ #endif
16524
+ if (node->src[1]->type != vec_dot_type) {
16525
+ cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src[1])/GGML_BLCK_SIZE[vec_dot_type];
16526
+ } else {
16527
+ cur = 0;
16528
+ }
16207
16529
 
16208
- work_size = MAX(work_size, cur);
16209
- } break;
16210
- case GGML_OP_SCALE:
16211
- {
16212
- node->n_tasks = 1;
16213
- } break;
16214
- case GGML_OP_SET:
16215
- case GGML_OP_CONT:
16216
- case GGML_OP_RESHAPE:
16217
- case GGML_OP_VIEW:
16218
- case GGML_OP_PERMUTE:
16219
- case GGML_OP_TRANSPOSE:
16220
- case GGML_OP_GET_ROWS:
16221
- case GGML_OP_GET_ROWS_BACK:
16222
- case GGML_OP_DIAG:
16223
- case GGML_OP_DIAG_MASK_ZERO:
16224
- {
16225
- node->n_tasks = 1;
16226
- } break;
16227
- case GGML_OP_DIAG_MASK_INF:
16228
- case GGML_OP_SOFT_MAX:
16229
- case GGML_OP_SOFT_MAX_BACK:
16230
- case GGML_OP_ROPE:
16231
- case GGML_OP_ROPE_BACK:
16232
- {
16233
- node->n_tasks = n_threads;
16234
- } break;
16235
- case GGML_OP_ALIBI:
16236
- {
16237
- node->n_tasks = 1; //TODO
16238
- } break;
16239
- case GGML_OP_CLAMP:
16240
- {
16241
- node->n_tasks = 1; //TODO
16242
- } break;
16243
- case GGML_OP_CONV_1D:
16244
- {
16245
- node->n_tasks = n_threads;
16246
-
16247
- GGML_ASSERT(node->src0->ne[3] == 1);
16248
- GGML_ASSERT(node->src1->ne[2] == 1);
16249
- GGML_ASSERT(node->src1->ne[3] == 1);
16250
-
16251
- size_t cur = 0;
16252
- const int nk = node->src0->ne[0];
16253
-
16254
- if (node->src0->type == GGML_TYPE_F16 &&
16255
- node->src1->type == GGML_TYPE_F32) {
16256
- cur = sizeof(ggml_fp16_t)*(
16257
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
16258
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
16259
- );
16260
- } else if (node->src0->type == GGML_TYPE_F32 &&
16261
- node->src1->type == GGML_TYPE_F32) {
16262
- cur = sizeof(float)*(
16263
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
16264
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
16265
- );
16266
- } else {
16267
- GGML_ASSERT(false);
16268
- }
16530
+ work_size = MAX(work_size, cur);
16531
+ } break;
16532
+ case GGML_OP_SCALE:
16533
+ {
16534
+ n_tasks = 1;
16535
+ } break;
16536
+ case GGML_OP_SET:
16537
+ case GGML_OP_CONT:
16538
+ case GGML_OP_RESHAPE:
16539
+ case GGML_OP_VIEW:
16540
+ case GGML_OP_PERMUTE:
16541
+ case GGML_OP_TRANSPOSE:
16542
+ case GGML_OP_GET_ROWS:
16543
+ case GGML_OP_GET_ROWS_BACK:
16544
+ case GGML_OP_DIAG:
16545
+ case GGML_OP_DIAG_MASK_ZERO:
16546
+ {
16547
+ n_tasks = 1;
16548
+ } break;
16549
+ case GGML_OP_DIAG_MASK_INF:
16550
+ case GGML_OP_SOFT_MAX:
16551
+ case GGML_OP_SOFT_MAX_BACK:
16552
+ case GGML_OP_ROPE:
16553
+ case GGML_OP_ROPE_BACK:
16554
+ {
16555
+ n_tasks = n_threads;
16556
+ } break;
16557
+ case GGML_OP_ALIBI:
16558
+ {
16559
+ n_tasks = 1; //TODO
16560
+ } break;
16561
+ case GGML_OP_CLAMP:
16562
+ {
16563
+ n_tasks = 1; //TODO
16564
+ } break;
16565
+ case GGML_OP_CONV_1D:
16566
+ {
16567
+ n_tasks = n_threads;
16568
+
16569
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
16570
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
16571
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
16572
+
16573
+ size_t cur = 0;
16574
+ const int nk = node->src[0]->ne[0];
16575
+
16576
+ if (node->src[0]->type == GGML_TYPE_F16 &&
16577
+ node->src[1]->type == GGML_TYPE_F32) {
16578
+ cur = sizeof(ggml_fp16_t)*(
16579
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
16580
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
16581
+ );
16582
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
16583
+ node->src[1]->type == GGML_TYPE_F32) {
16584
+ cur = sizeof(float)*(
16585
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
16586
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
16587
+ );
16588
+ } else {
16589
+ GGML_ASSERT(false);
16590
+ }
16269
16591
 
16270
- work_size = MAX(work_size, cur);
16271
- } break;
16272
- case GGML_OP_CONV_2D:
16273
- {
16274
- node->n_tasks = n_threads;
16592
+ work_size = MAX(work_size, cur);
16593
+ } break;
16594
+ case GGML_OP_CONV_2D:
16595
+ {
16596
+ n_tasks = n_threads;
16597
+
16598
+ const int64_t ne00 = node->src[0]->ne[0]; // W
16599
+ const int64_t ne01 = node->src[0]->ne[1]; // H
16600
+ const int64_t ne02 = node->src[0]->ne[2]; // C
16601
+ const int64_t ne03 = node->src[0]->ne[3]; // N
16602
+
16603
+ const int64_t ne10 = node->src[1]->ne[0]; // W
16604
+ const int64_t ne11 = node->src[1]->ne[1]; // H
16605
+ const int64_t ne12 = node->src[1]->ne[2]; // C
16606
+
16607
+ const int64_t ne0 = node->ne[0];
16608
+ const int64_t ne1 = node->ne[1];
16609
+ const int64_t ne2 = node->ne[2];
16610
+ const int64_t nk = ne00*ne01;
16611
+ const int64_t ew0 = nk * ne02;
16612
+
16613
+ UNUSED(ne03);
16614
+ UNUSED(ne2);
16615
+
16616
+ size_t cur = 0;
16617
+
16618
+ if (node->src[0]->type == GGML_TYPE_F16 &&
16619
+ node->src[1]->type == GGML_TYPE_F32) {
16620
+ cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
16621
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
16622
+ node->src[1]->type == GGML_TYPE_F32) {
16623
+ cur = sizeof(float)* (ne10*ne11*ne12);
16624
+ } else {
16625
+ GGML_ASSERT(false);
16626
+ }
16275
16627
 
16276
- GGML_ASSERT(node->src1->ne[3] == 1);
16628
+ work_size = MAX(work_size, cur);
16629
+ } break;
16630
+ case GGML_OP_POOL_1D:
16631
+ case GGML_OP_POOL_2D:
16632
+ {
16633
+ n_tasks = 1;
16634
+ } break;
16635
+ case GGML_OP_FLASH_ATTN:
16636
+ {
16637
+ n_tasks = n_threads;
16277
16638
 
16278
- const int64_t ne00 = node->src0->ne[0]; // W
16279
- const int64_t ne01 = node->src0->ne[1]; // H
16280
- const int64_t ne02 = node->src0->ne[2]; // C
16281
- const int64_t ne03 = node->src0->ne[3]; // N
16639
+ size_t cur = 0;
16282
16640
 
16283
- const int64_t ne10 = node->src1->ne[0]; // W
16284
- const int64_t ne11 = node->src1->ne[1]; // H
16285
- const int64_t ne12 = node->src1->ne[2]; // C
16641
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
16286
16642
 
16287
- const int64_t nk = ne00*ne01;
16643
+ if (node->src[1]->type == GGML_TYPE_F32) {
16644
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
16645
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
16646
+ }
16288
16647
 
16289
- UNUSED(ne02);
16290
- UNUSED(ne03);
16291
- UNUSED(nk);
16648
+ if (node->src[1]->type == GGML_TYPE_F16) {
16649
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
16650
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
16651
+ }
16292
16652
 
16293
- size_t cur = 0;
16653
+ work_size = MAX(work_size, cur);
16654
+ } break;
16655
+ case GGML_OP_FLASH_FF:
16656
+ {
16657
+ n_tasks = n_threads;
16294
16658
 
16295
- if (node->src0->type == GGML_TYPE_F16 &&
16296
- node->src1->type == GGML_TYPE_F32) {
16297
- cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
16298
- } else if (node->src0->type == GGML_TYPE_F32 &&
16299
- node->src1->type == GGML_TYPE_F32) {
16300
- cur = sizeof(float)* (ne10*ne11*ne12);
16301
- } else {
16302
- GGML_ASSERT(false);
16303
- }
16659
+ size_t cur = 0;
16304
16660
 
16305
- work_size = MAX(work_size, cur);
16306
- } break;
16307
- case GGML_OP_FLASH_ATTN:
16308
- {
16309
- node->n_tasks = n_threads;
16661
+ if (node->src[1]->type == GGML_TYPE_F32) {
16662
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16663
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
16664
+ }
16310
16665
 
16311
- size_t cur = 0;
16666
+ if (node->src[1]->type == GGML_TYPE_F16) {
16667
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
16668
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
16669
+ }
16312
16670
 
16313
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
16671
+ work_size = MAX(work_size, cur);
16672
+ } break;
16673
+ case GGML_OP_FLASH_ATTN_BACK:
16674
+ {
16675
+ n_tasks = n_threads;
16314
16676
 
16315
- if (node->src1->type == GGML_TYPE_F32) {
16316
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
16317
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
16318
- }
16677
+ size_t cur = 0;
16319
16678
 
16320
- if (node->src1->type == GGML_TYPE_F16) {
16321
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
16322
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
16323
- }
16679
+ const int64_t D = node->src[0]->ne[0];
16680
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
16681
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
16682
+ if (node->src[1]->type == GGML_TYPE_F32) {
16683
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
16684
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
16685
+ }
16324
16686
 
16325
- work_size = MAX(work_size, cur);
16326
- } break;
16327
- case GGML_OP_FLASH_FF:
16328
- {
16329
- node->n_tasks = n_threads;
16687
+ if (node->src[1]->type == GGML_TYPE_F16) {
16688
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
16689
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
16690
+ }
16330
16691
 
16331
- size_t cur = 0;
16692
+ work_size = MAX(work_size, cur);
16693
+ } break;
16694
+ case GGML_OP_WIN_PART:
16695
+ case GGML_OP_WIN_UNPART:
16696
+ case GGML_OP_MAP_UNARY:
16697
+ case GGML_OP_MAP_BINARY:
16698
+ case GGML_OP_MAP_CUSTOM1:
16699
+ case GGML_OP_MAP_CUSTOM2:
16700
+ case GGML_OP_MAP_CUSTOM3:
16701
+ {
16702
+ n_tasks = 1;
16703
+ } break;
16704
+ case GGML_OP_CROSS_ENTROPY_LOSS:
16705
+ {
16706
+ n_tasks = n_threads;
16332
16707
 
16333
- if (node->src1->type == GGML_TYPE_F32) {
16334
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
16335
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
16336
- }
16708
+ size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);
16337
16709
 
16338
- if (node->src1->type == GGML_TYPE_F16) {
16339
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
16340
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
16341
- }
16710
+ work_size = MAX(work_size, cur);
16711
+ } break;
16712
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
16713
+ {
16714
+ n_tasks = n_threads;
16342
16715
 
16343
- work_size = MAX(work_size, cur);
16344
- } break;
16345
- case GGML_OP_FLASH_ATTN_BACK:
16346
- {
16347
- node->n_tasks = n_threads;
16716
+ size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;
16348
16717
 
16349
- size_t cur = 0;
16718
+ work_size = MAX(work_size, cur);
16719
+ } break;
16720
+ case GGML_OP_NONE:
16721
+ {
16722
+ n_tasks = 1;
16723
+ } break;
16724
+ case GGML_OP_COUNT:
16725
+ {
16726
+ GGML_ASSERT(false);
16727
+ } break;
16728
+ }
16350
16729
 
16351
- const int64_t D = node->src0->ne[0];
16352
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
16353
- const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
16354
- if (node->src1->type == GGML_TYPE_F32) {
16355
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16356
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16357
- }
16730
+ cplan.n_tasks[i] = n_tasks;
16731
+ }
16358
16732
 
16359
- if (node->src1->type == GGML_TYPE_F16) {
16360
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
16361
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
16362
- }
16733
+ if (work_size > 0) {
16734
+ work_size += CACHE_LINE_SIZE*(n_threads - 1);
16735
+ }
16363
16736
 
16364
- work_size = MAX(work_size, cur);
16365
- } break;
16366
- case GGML_OP_WIN_PART:
16367
- case GGML_OP_WIN_UNPART:
16368
- case GGML_OP_MAP_UNARY:
16369
- case GGML_OP_MAP_BINARY:
16370
- case GGML_OP_MAP_CUSTOM1:
16371
- case GGML_OP_MAP_CUSTOM2:
16372
- case GGML_OP_MAP_CUSTOM3:
16373
- {
16374
- node->n_tasks = 1;
16375
- } break;
16376
- case GGML_OP_CROSS_ENTROPY_LOSS:
16377
- {
16378
- node->n_tasks = n_threads;
16379
-
16380
- size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
16381
-
16382
- work_size = MAX(work_size, cur);
16383
- } break;
16384
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
16385
- {
16386
- node->n_tasks = n_threads;
16387
-
16388
- size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
16389
-
16390
- work_size = MAX(work_size, cur);
16391
- } break;
16392
- case GGML_OP_NONE:
16393
- {
16394
- node->n_tasks = 1;
16395
- } break;
16396
- case GGML_OP_COUNT:
16397
- {
16398
- GGML_ASSERT(false);
16399
- } break;
16400
- }
16401
- }
16737
+ cplan.n_threads = n_threads;
16738
+ cplan.work_size = work_size;
16739
+ cplan.work_data = NULL;
16402
16740
 
16403
- if (cgraph->work != NULL && work_size > cgraph->work_size) {
16404
- GGML_ASSERT(false); // TODO: better handling
16405
- }
16741
+ return cplan;
16742
+ }
16406
16743
 
16407
- if (work_size > 0 && cgraph->work == NULL) {
16408
- cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
16744
+ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
16745
+ {
16746
+ GGML_ASSERT(cplan);
16747
+ GGML_ASSERT(cplan->n_threads > 0);
16748
+
16749
+ if (cplan->work_size > 0) {
16750
+ GGML_ASSERT(cplan->work_data);
16751
+ }
16409
16752
 
16410
- GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
16411
- cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
16753
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
16754
+ if (cgraph->nodes[i]->op != GGML_OP_NONE) {
16755
+ GGML_ASSERT(cplan->n_tasks[i] > 0);
16756
+ }
16412
16757
  }
16413
16758
  }
16414
16759
 
16760
+ const int n_threads = cplan->n_threads;
16761
+
16762
+ struct ggml_compute_state_shared state_shared = {
16763
+ /*.cgraph =*/ cgraph,
16764
+ /*.cgraph_plan =*/ cplan,
16765
+ /*.perf_node_start_cycles =*/ 0,
16766
+ /*.perf_node_start_time_us =*/ 0,
16767
+ /*.n_threads =*/ n_threads,
16768
+ /*.n_active =*/ n_threads,
16769
+ /*.node_n =*/ -1,
16770
+ /*.abort_callback =*/ NULL,
16771
+ /*.abort_callback_data =*/ NULL,
16772
+ };
16773
+ struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
16774
+
16415
16775
  // create thread pool
16416
16776
  if (n_threads > 1) {
16417
16777
  for (int j = 1; j < n_threads; ++j) {
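
Note: the hunk above replaces the per-node `node->n_tasks` field and the graph-owned work tensor with an explicit `ggml_cplan`. The per-op switch now writes into `cplan.n_tasks[i]`, the largest per-op scratch requirement (padded by `CACHE_LINE_SIZE*(n_threads - 1)`) becomes `cplan.work_size`, and `cplan.work_data` is left NULL for the caller to fill. A minimal sketch of driving the new two-step API; `run_graph` is a hypothetical helper and the plain `malloc` buffer is for illustration only:

```c
#include <stdlib.h>
#include "ggml.h"

// Hypothetical helper: plan the graph, hand it a scratch buffer, then compute.
// Assumes `gf` has already been built (e.g. with ggml_build_forward_expand).
static int run_graph(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);

    void * work = NULL;
    if (plan.work_size > 0) {
        work = malloc(plan.work_size); // the caller owns the work buffer now
        plan.work_data = work;
    }

    const int status = ggml_graph_compute(gf, &plan);

    free(work);
    return status;
}
```
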
@@ -16432,12 +16792,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
16432
16792
  const int64_t perf_start_time_us = ggml_perf_time_us();
16433
16793
 
16434
16794
  // this is a work thread too
16435
- ggml_graph_compute_thread(&workers[0]);
16795
+ int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
16436
16796
 
16437
16797
  // don't leave affinity set on the main thread
16438
16798
  clear_numa_thread_affinity();
16439
16799
 
16440
- // join thread pool
16800
+ // join or kill thread pool
16441
16801
  if (n_threads > 1) {
16442
16802
  for (int j = 1; j < n_threads; j++) {
16443
16803
  const int rc = ggml_thread_join(workers[j].thrd, NULL);
@@ -16461,6 +16821,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
16461
16821
  (double) perf_time_us_cur / 1000.0,
16462
16822
  (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
16463
16823
  }
16824
+
16825
+ return compute_status;
16464
16826
  }
16465
16827
 
16466
16828
  void ggml_graph_reset(struct ggml_cgraph * cgraph) {
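
Note: `ggml_graph_compute()` is no longer `void` — the main thread's own `ggml_graph_compute_thread()` result is captured (via a `size_t` cast, since the thread entry point returns a pointer-sized value) and returned to the caller, and the pool is now "join or kill" rather than a plain join. A hedged sketch of checking that status; the concrete status constants are not visible in these hunks, so 0 is assumed to mean success:

```c
#include <stdio.h>
#include "ggml.h"

// Sketch only: report a non-zero status from the now int-returning compute call.
// `gf` and `plan` are assumed to come from the run_graph pattern sketched above.
static void check_compute(struct ggml_cgraph * gf, struct ggml_cplan * plan) {
    const int status = ggml_graph_compute(gf, plan);
    if (status != 0) {
        fprintf(stderr, "ggml_graph_compute returned %d (aborted or failed)\n", status);
    }
}
```
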
@@ -16473,6 +16835,17 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
16473
16835
  }
16474
16836
  }
16475
16837
 
16838
+ void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
16839
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
16840
+
16841
+ struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
16842
+ GGML_ASSERT(buf);
16843
+
16844
+ cplan.work_data = buf->data;
16845
+
16846
+ ggml_graph_compute(cgraph, &cplan);
16847
+ }
16848
+
16476
16849
  struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
16477
16850
  for (int i = 0; i < cgraph->n_leafs; i++) {
16478
16851
  struct ggml_tensor * leaf = cgraph->leafs[i];
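
Note: for callers that do not want to manage the scratch buffer themselves, the hunk above adds `ggml_graph_compute_with_ctx()`, which plans the graph, allocates the work buffer as a `GGML_TYPE_I8` tensor inside the given context, and forwards to `ggml_graph_compute()`. A minimal, self-contained sketch of that convenience path; tensor sizes and the context memory budget are illustrative only:

```c
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // must also cover the work buffer
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8), 1.0f);
    struct ggml_tensor * b = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8), 2.0f);
    struct ggml_tensor * c = ggml_add(ctx, a, b);

    struct ggml_cgraph gf = ggml_build_forward(c);

    // plans, allocates the work buffer inside `ctx`, and computes in one call
    ggml_graph_compute_with_ctx(ctx, &gf, /*n_threads =*/ 2);

    ggml_free(ctx);
    return 0;
}
```
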
@@ -16511,22 +16884,18 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
16511
16884
  const int64_t * ne = tensor->ne;
16512
16885
  const size_t * nb = tensor->nb;
16513
16886
 
16514
- fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
16887
+ fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
16515
16888
  arg,
16516
16889
  ggml_type_name(tensor->type),
16517
16890
  ggml_op_name (tensor->op),
16518
16891
  tensor->n_dims,
16519
16892
  ne[0], ne[1], ne[2], ne[3],
16520
16893
  nb[0], nb[1], nb[2], nb[3],
16521
- tensor->n_tasks,
16522
16894
  tensor->data,
16523
16895
  tensor->name);
16524
16896
  }
16525
16897
 
16526
16898
  void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
16527
- //assert(cgraph->work == NULL);
16528
- //assert(cgraph->work_size == 0);
16529
-
16530
16899
  uint64_t size_eval = 0;
16531
16900
 
16532
16901
  // compute size of intermediate results
@@ -16555,8 +16924,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
16555
16924
  ggml_graph_export_leaf(cgraph->leafs[i], fout);
16556
16925
 
16557
16926
  GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE);
16558
- GGML_ASSERT(cgraph->leafs[i]->src0 == NULL);
16559
- GGML_ASSERT(cgraph->leafs[i]->src1 == NULL);
16927
+ GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL);
16928
+ GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL);
16560
16929
  }
16561
16930
 
16562
16931
  // header
@@ -16567,17 +16936,9 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
16567
16936
  for (int i = 0; i < cgraph->n_nodes; ++i) {
16568
16937
  ggml_graph_export_node(cgraph->nodes[i], "DST", fout);
16569
16938
 
16570
- if (cgraph->nodes[i]->src0) {
16571
- ggml_graph_export_node(cgraph->nodes[i]->src0, "SRC0", fout);
16572
- }
16573
-
16574
- if (cgraph->nodes[i]->src1) {
16575
- ggml_graph_export_node(cgraph->nodes[i]->src1, "SRC1", fout);
16576
- }
16577
-
16578
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
16579
- if (cgraph->nodes[i]->opt[j]) {
16580
- ggml_graph_export_node(cgraph->nodes[i]->opt[j], "OPT", fout);
16939
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
16940
+ if (cgraph->nodes[i]->src[j]) {
16941
+ ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout);
16581
16942
  }
16582
16943
  }
16583
16944
 
@@ -16668,16 +17029,13 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
16668
17029
 
16669
17030
  // output the op arguments
16670
17031
  {
16671
- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
16672
-
16673
- args[0] = tensor->src0;
16674
- args[1] = tensor->src1;
17032
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };
16675
17033
 
16676
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
16677
- args[2 + j] = tensor->opt[j];
17034
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
17035
+ args[j] = tensor->src[j];
16678
17036
  }
16679
17037
 
16680
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
17038
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
16681
17039
  if (args[j]) {
16682
17040
  int32_t idx = -1;
16683
17041
 
@@ -16895,12 +17253,12 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
16895
17253
 
16896
17254
  const char * ptr_name = ptr; ptr += GGML_MAX_NAME;
16897
17255
 
16898
- const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
17256
+ const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t);
16899
17257
 
16900
- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
17258
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };
16901
17259
 
16902
17260
  // parse args
16903
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
17261
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
16904
17262
  const int32_t arg_idx = ptr_arg_idx[j];
16905
17263
 
16906
17264
  if (arg_idx == -1) {
@@ -16957,11 +17315,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
16957
17315
  tensor->nb[j] = nb[j];
16958
17316
  }
16959
17317
 
16960
- tensor->src0 = args[0];
16961
- tensor->src1 = args[1];
16962
-
16963
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
16964
- tensor->opt[j] = args[2 + j];
17318
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
17319
+ tensor->src[j] = args[j];
16965
17320
  }
16966
17321
 
16967
17322
  result.nodes[i] = tensor;
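
Note: the export and import hunks above show the other headline change in this release — the separate `src0`, `src1` and `opt[]` parent fields are folded into a single `src[GGML_MAX_SRC]` array, and the serialized per-node argument table is sized to match. A small hypothetical helper written against the new layout (not part of the library):

```c
#include <stdio.h>
#include "ggml.h"

// Hypothetical helper: list every parent of a node via the unified src[] array.
// Before this change the same walk needed src0, src1 and the opt[] array separately.
static void print_node_parents(const struct ggml_tensor * node) {
    for (int j = 0; j < GGML_MAX_SRC; ++j) {
        if (node->src[j] != NULL) {
            printf("src %d: %-12s %s\n", j,
                   ggml_op_name(node->src[j]->op),
                   node->src[j]->name);
        }
    }
}
```
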
@@ -16979,9 +17334,6 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
16979
17334
 
16980
17335
  GGML_PRINT("=== GRAPH ===\n");
16981
17336
 
16982
- GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
16983
- GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
16984
-
16985
17337
  GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
16986
17338
  for (int i = 0; i < cgraph->n_nodes; i++) {
16987
17339
  struct ggml_tensor * node = cgraph->nodes[i];
@@ -17160,19 +17512,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
17160
17512
  for (int i = 0; i < gb->n_nodes; i++) {
17161
17513
  struct ggml_tensor * node = gb->nodes[i];
17162
17514
 
17163
- if (node->src0) {
17164
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src0, "x");
17165
- }
17166
-
17167
- if (node->src1) {
17168
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src1, "y");
17169
- }
17170
-
17171
- for (int j = 0; j < GGML_MAX_OPT; j++) {
17172
- if (node->opt[j]) {
17515
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
17516
+ if (node->src[j]) {
17173
17517
  char label[16];
17174
- snprintf(label, sizeof(label), "opt %d", j);
17175
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->opt[j], label);
17518
+ snprintf(label, sizeof(label), "src %d", j);
17519
+ ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
17176
17520
  }
17177
17521
  }
17178
17522
  }
@@ -17180,19 +17524,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
17180
17524
  for (int i = 0; i < gb->n_leafs; i++) {
17181
17525
  struct ggml_tensor * node = gb->leafs[i];
17182
17526
 
17183
- if (node->src0) {
17184
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src0, "x");
17185
- }
17186
-
17187
- if (node->src1) {
17188
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src1, "y");
17189
- }
17190
-
17191
- for (int j = 0; j < GGML_MAX_OPT; j++) {
17192
- if (node->opt[j]) {
17527
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
17528
+ if (node->src[j]) {
17193
17529
  char label[16];
17194
- snprintf(label, sizeof(label), "opt %d", j);
17195
- ggml_graph_dump_dot_leaf_edge(fp, node, node->opt[j], label);
17530
+ snprintf(label, sizeof(label), "src %d", j);
17531
+ ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
17196
17532
  }
17197
17533
  }
17198
17534
  }
@@ -17254,9 +17590,6 @@ static enum ggml_opt_result ggml_opt_adam(
17254
17590
  struct ggml_cgraph * gb) {
17255
17591
  GGML_ASSERT(ggml_is_scalar(f));
17256
17592
 
17257
- gf->n_threads = params.n_threads;
17258
- gb->n_threads = params.n_threads;
17259
-
17260
17593
  // these will store the parameters we want to optimize
17261
17594
  struct ggml_tensor * ps[GGML_MAX_PARAMS];
17262
17595
 
@@ -17303,7 +17636,8 @@ static enum ggml_opt_result ggml_opt_adam(
17303
17636
  // compute the function value
17304
17637
  ggml_graph_reset (gf);
17305
17638
  ggml_set_f32 (f->grad, 1.0f);
17306
- ggml_graph_compute(ctx, gb);
17639
+
17640
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
17307
17641
 
17308
17642
  opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
17309
17643
  opt->adam.fx_best = opt->adam.fx_prev;
@@ -17383,7 +17717,8 @@ static enum ggml_opt_result ggml_opt_adam(
17383
17717
 
17384
17718
  ggml_graph_reset (gf);
17385
17719
  ggml_set_f32 (f->grad, 1.0f);
17386
- ggml_graph_compute(ctx, gb);
17720
+
17721
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
17387
17722
 
17388
17723
  const float fx = ggml_get_f32_1d(f, 0);
17389
17724
 
@@ -17505,7 +17840,8 @@ static enum ggml_opt_result linesearch_backtracking(
17505
17840
 
17506
17841
  ggml_graph_reset (gf);
17507
17842
  ggml_set_f32 (f->grad, 1.0f);
17508
- ggml_graph_compute(ctx, gb);
17843
+
17844
+ ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);
17509
17845
 
17510
17846
  ggml_opt_get_grad(np, ps, g);
17511
17847
 
@@ -17573,9 +17909,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
17573
17909
  }
17574
17910
  }
17575
17911
 
17576
- gf->n_threads = params.n_threads;
17577
- gb->n_threads = params.n_threads;
17578
-
17579
17912
  const int m = params.lbfgs.m;
17580
17913
 
17581
17914
  // these will store the parameters we want to optimize
@@ -17627,7 +17960,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
17627
17960
 
17628
17961
  ggml_graph_reset (gf);
17629
17962
  ggml_set_f32 (f->grad, 1.0f);
17630
- ggml_graph_compute(ctx, gb);
17963
+
17964
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
17631
17965
 
17632
17966
  ggml_opt_get_grad(np, ps, g);
17633
17967
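
Note: the optimizer hunks (ggml_opt_adam, linesearch_backtracking, ggml_opt_lbfgs) follow from the API change above — `gf->n_threads`/`gb->n_threads` no longer exist on the graph, and every `ggml_graph_compute(ctx, gb)` call becomes `ggml_graph_compute_with_ctx(ctx, gb, params.n_threads)`. A hedged migration sketch for downstream code that used the old pattern; `opt_step_compute` is a hypothetical wrapper:

```c
#include "ggml.h"

// Hedged migration sketch: how a step that used the old
// `gb->n_threads = ...; ggml_graph_compute(ctx, gb);` pattern maps onto the new API.
static void opt_step_compute(struct ggml_context * ctx, struct ggml_cgraph * gb, int n_threads) {
    // old (removed):  gb->n_threads = n_threads;
    //                 ggml_graph_compute(ctx, gb);
    // new (this diff):
    ggml_graph_compute_with_ctx(ctx, gb, n_threads);
}
```
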