llama_cpp 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,16 +25,23 @@
  #include <float.h>
  #include <limits.h>
  #include <stdarg.h>
+ #include <signal.h>
 
  #ifdef GGML_USE_METAL
  #include <unistd.h>
  #endif
 
+ // static_assert should be a #define, but if it's not,
+ // fall back to the _Static_assert C11 keyword.
  // if C99 - static_assert is noop
  // ref: https://stackoverflow.com/a/53923785/4039976
  #ifndef static_assert
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+ #define static_assert(cond, msg) _Static_assert(cond, msg)
+ #else
  #define static_assert(cond, msg) struct global_scope_noop_trick
  #endif
+ #endif
 
  #if defined(_MSC_VER)
  // disable "possible loss of data" to avoid hundreds of casts
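The new guard prefers the C11 `_Static_assert` keyword whenever the compiler advertises C11 support, and only keeps the no-op struct trick for older compilers; this also makes the Haiku-specific workaround removed further below unnecessary. A minimal standalone sketch of the same pattern (the `my_range` type is hypothetical, for illustration only):

    #include <assert.h>  // may already #define static_assert

    #ifndef static_assert
    #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
    #define static_assert(cond, msg) _Static_assert(cond, msg)        // C11 keyword
    #else
    #define static_assert(cond, msg) struct global_scope_noop_trick   // pre-C11: no-op
    #endif
    #endif

    typedef struct { int lo, hi; } my_range;  // hypothetical example type
    static_assert(sizeof(my_range) == 2 * sizeof(int), "my_range must pack two ints");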
@@ -49,23 +56,23 @@
  typedef volatile LONG atomic_int;
  typedef atomic_int atomic_bool;
 
- static void atomic_store(atomic_int* ptr, LONG val) {
+ static void atomic_store(atomic_int * ptr, LONG val) {
  InterlockedExchange(ptr, val);
  }
- static LONG atomic_load(atomic_int* ptr) {
+ static LONG atomic_load(atomic_int * ptr) {
  return InterlockedCompareExchange(ptr, 0, 0);
  }
- static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) {
+ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
  return InterlockedExchangeAdd(ptr, inc);
  }
- static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) {
+ static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
  return atomic_fetch_add(ptr, -(dec));
  }
 
  typedef HANDLE pthread_t;
 
  typedef DWORD thread_ret_t;
- static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) {
+ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
  (void) unused;
  HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
  if (handle == NULL)
@@ -77,7 +84,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
  return 0;
  }
 
- static int pthread_join(pthread_t thread, void* unused) {
+ static int pthread_join(pthread_t thread, void * unused) {
  (void) unused;
  return (int) WaitForSingleObject(thread, INFINITE);
  }
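These hunks only normalize pointer spacing (`atomic_int* ptr` becomes `atomic_int * ptr`); behavior is unchanged. The shims themselves map the small subset of C11 atomics and pthreads that ggml needs onto Win32 primitives. A minimal usage sketch under that assumption, showing a counter shared by worker threads:

    // sketch: a shared counter that works with either backend
    static atomic_int g_n_done;             // C11 atomic_int, or volatile LONG on Windows

    static thread_ret_t worker(void * arg) {
        (void) arg;
        atomic_fetch_add(&g_n_done, 1);     // InterlockedExchangeAdd under MSVC
        return 0;
    }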
@@ -90,7 +97,7 @@ static int sched_yield (void) {
  #include <pthread.h>
  #include <stdatomic.h>
 
- typedef void* thread_ret_t;
+ typedef void * thread_ret_t;
 
  #include <sys/types.h>
  #include <sys/stat.h>
@@ -111,10 +118,6 @@ typedef void* thread_ret_t;
  #endif
  #endif
 
- #ifdef __HAIKU__
- #define static_assert(cond, msg) _Static_assert(cond, msg)
- #endif
-
  /*#define GGML_PERF*/
  #define GGML_DEBUG 0
  #define GGML_GELU_FP16
@@ -247,7 +250,11 @@ inline static void* ggml_aligned_malloc(size_t size) {
  #include "ggml-opencl.h"
  #endif
  #elif defined(GGML_USE_OPENBLAS)
+ #if defined(GGML_BLAS_USE_MKL)
+ #include <mkl.h>
+ #else
  #include <cblas.h>
+ #endif
  #elif defined(GGML_USE_CUBLAS)
  #include "ggml-cuda.h"
  #elif defined(GGML_USE_CLBLAST)
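With `GGML_BLAS_USE_MKL` defined, the OpenBLAS code path takes its CBLAS declarations from Intel MKL's `mkl.h` instead of `cblas.h`. A hedged sketch of how a build might opt in (the `MKLROOT` location and the link line depend on your MKL installation):

    cc -c ggml.c -DGGML_USE_OPENBLAS -DGGML_BLAS_USE_MKL -I"${MKLROOT}/include"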
@@ -3782,6 +3789,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "CLAMP",
  "CONV_1D",
  "CONV_2D",
+ "POOL_1D",
+ "POOL_2D",
 
  "FLASH_ATTN",
  "FLASH_FF",
@@ -3800,7 +3809,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
  "CROSS_ENTROPY_LOSS_BACK",
  };
 
- static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
 
  static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "none",
@@ -3860,6 +3869,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "clamp(x)",
  "conv_1d(x)",
  "conv_2d(x)",
+ "pool_1d(x)",
+ "pool_2d(x)",
 
  "flash_attn(x)",
  "flash_ff(x)",
@@ -3878,7 +3889,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
  "cross_entropy_loss_back(x,y)",
  };
 
- static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66");
+ static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
+
+ static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
  static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
  static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4157,10 +4170,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
  static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
  static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
- return
- (t0->ne[0] == t1->ne[0]) &&
- (t0->ne[2] == t1->ne[2]) &&
- (t0->ne[3] == t1->ne[3]);
+ return (t0->ne[0] == t1->ne[0]) &&
+ (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
+ (t1->ne[3]%t0->ne[3] == 0);
  }
 
  static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
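`ggml_can_mul_mat` now only requires that t1's higher dimensions be whole multiples of t0's, so a single weight matrix can be broadcast over a batched second operand. A worked shape check in ggml's `ne[] = {cols, rows, dim2, dim3}` order:

    // t0 (weights): ne = {4096, 4096,  1, 1}
    // t1 (batch):   ne = {4096,    8, 32, 1}
    // 0.3.2: t0->ne[2] == t1->ne[2]      ->  1 == 32, rejected
    // 0.3.4: t1->ne[2] % t0->ne[2] == 0  ->  32 % 1 == 0, t0 broadcasts across the batch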
@@ -4400,8 +4412,8 @@ void ggml_free(struct ggml_context * ctx) {
  if (&g_state.contexts[i].context == ctx) {
  g_state.contexts[i].used = false;
 
- GGML_PRINT_DEBUG("%s: context %d with %d objects has been freed. memory used = %zu\n",
- __func__, i, ctx->n_objects, ctx->objects_end->offs + ctx->objects_end->size);
+ GGML_PRINT_DEBUG("%s: context %d has been freed. memory used = %zu\n",
+ __func__, i, ggml_used_mem(ctx));
 
  if (ctx->mem_buffer_owned) {
  GGML_ALIGNED_FREE(ctx->mem_buffer);
@@ -4580,17 +4592,14 @@ struct ggml_tensor * ggml_new_tensor_impl(
  /*.op =*/ GGML_OP_NONE,
  /*.is_param =*/ false,
  /*.grad =*/ NULL,
- /*.src0 =*/ NULL,
- /*.src1 =*/ NULL,
- /*.opt =*/ { NULL },
- /*.n_tasks =*/ 0,
+ /*.src =*/ { NULL },
  /*.perf_runs =*/ 0,
  /*.perf_cycles =*/ 0,
  /*.perf_time_us =*/ 0,
  /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
  /*.name =*/ { 0 },
  /*.extra =*/ NULL,
- /*.pad =*/ { 0 },
+ /*.padding =*/ { 0 },
  };
 
  // TODO: this should not be needed as long as we don't rely on aligned SIMD loads
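This is the structural change behind most of the remaining hunks: the separate `src0`/`src1` pointers, the `opt[]` array, and the `n_tasks` field of `struct ggml_tensor` are replaced by a single `src[]` array, so every operator records its operands uniformly:

    // before (0.3.2)           after (0.3.4)
    // result->src0   = a;      result->src[0] = a;
    // result->src1   = b;      result->src[1] = b;
    // result->opt[0] = c;      result->src[2] = c;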
@@ -4722,7 +4731,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
  {
  assert(tensor->nb[0] == sizeof(ggml_fp16_t));
  for (int i = 0; i < n; i++) {
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
  }
  } break;
  case GGML_TYPE_F32:
@@ -4774,7 +4783,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
  {
  assert(tensor->nb[0] == sizeof(ggml_fp16_t));
  for (int i = 0; i < n; i++) {
- ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value);
+ ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
  }
  } break;
  case GGML_TYPE_F32:
@@ -5009,8 +5018,8 @@ struct ggml_tensor * ggml_dup_impl(
 
  result->op = GGML_OP_DUP;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5034,11 +5043,15 @@ struct ggml_tensor * ggml_add_impl(
  struct ggml_tensor * a,
  struct ggml_tensor * b,
  bool inplace) {
- GGML_ASSERT(ggml_are_same_shape(a, b));
+ // TODO: support less-strict constraint
+ // GGML_ASSERT(ggml_can_repeat(b, a));
+ GGML_ASSERT(ggml_can_repeat_rows(b, a));
 
  bool is_node = false;
 
- if (a->grad || b->grad) {
+ if (!inplace && (a->grad || b->grad)) {
+ // TODO: support backward pass for broadcasting
+ GGML_ASSERT(ggml_are_same_shape(a, b));
  is_node = true;
  }
 
@@ -5046,8 +5059,8 @@ struct ggml_tensor * ggml_add_impl(
 
  result->op = GGML_OP_ADD;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
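`ggml_add` now accepts a `b` whose rows repeat across `a` (`ggml_can_repeat_rows`) instead of demanding identical shapes, and the fp16 fill helpers above gain a real fix: the fp32 value is converted with `GGML_FP32_TO_FP16` rather than passed through raw. A sketch of what the relaxed assert enables, assuming an existing context `ctx`:

    struct ggml_tensor * x    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32); // 32 rows
    struct ggml_tensor * bias = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4096);     // 1 row
    // 0.3.2 required: ggml_add(ctx, x, ggml_repeat(ctx, bias, x));
    // 0.3.4 broadcasts the bias row directly:
    struct ggml_tensor * y = ggml_add(ctx, x, bias);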
@@ -5086,8 +5099,8 @@ struct ggml_tensor * ggml_add1_impl(
 
  result->op = GGML_OP_ADD1;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -5144,9 +5157,9 @@ struct ggml_tensor * ggml_acc_impl(
 
  result->op = GGML_OP_ACC;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
 
  return result;
  }
@@ -5192,8 +5205,8 @@ struct ggml_tensor * ggml_sub_impl(
 
  result->op = GGML_OP_SUB;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -5239,8 +5252,8 @@ struct ggml_tensor * ggml_mul_impl(
 
  result->op = GGML_OP_MUL;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -5282,8 +5295,8 @@ struct ggml_tensor * ggml_div_impl(
 
  result->op = GGML_OP_DIV;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -5318,8 +5331,8 @@ struct ggml_tensor * ggml_sqr_impl(
 
  result->op = GGML_OP_SQR;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5352,8 +5365,8 @@ struct ggml_tensor * ggml_sqrt_impl(
 
  result->op = GGML_OP_SQRT;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5387,8 +5400,8 @@ struct ggml_tensor * ggml_log_impl(
 
  result->op = GGML_OP_LOG;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5420,8 +5433,8 @@ struct ggml_tensor * ggml_sum(
 
  result->op = GGML_OP_SUM;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5447,8 +5460,8 @@ struct ggml_tensor * ggml_sum_rows(
 
  result->op = GGML_OP_SUM_ROWS;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5470,8 +5483,8 @@ struct ggml_tensor * ggml_mean(
 
  result->op = GGML_OP_MEAN;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5494,8 +5507,8 @@ struct ggml_tensor * ggml_argmax(
 
  result->op = GGML_OP_ARGMAX;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5522,8 +5535,8 @@ struct ggml_tensor * ggml_repeat(
 
  result->op = GGML_OP_REPEAT;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -5550,8 +5563,8 @@ struct ggml_tensor * ggml_repeat_back(
 
  result->op = GGML_OP_REPEAT_BACK;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -5572,8 +5585,8 @@ struct ggml_tensor * ggml_abs_impl(
 
  result->op = GGML_OP_ABS;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5607,8 +5620,8 @@ struct ggml_tensor * ggml_sgn_impl(
 
  result->op = GGML_OP_SGN;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5641,8 +5654,8 @@ struct ggml_tensor * ggml_neg_impl(
 
  result->op = GGML_OP_NEG;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5675,8 +5688,8 @@ struct ggml_tensor * ggml_step_impl(
 
  result->op = GGML_OP_STEP;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5709,8 +5722,8 @@ struct ggml_tensor * ggml_tanh_impl(
 
  result->op = GGML_OP_TANH;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5743,8 +5756,8 @@ struct ggml_tensor * ggml_elu_impl(
 
  result->op = GGML_OP_ELU;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5777,8 +5790,8 @@ struct ggml_tensor * ggml_relu_impl(
 
  result->op = GGML_OP_RELU;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5811,8 +5824,8 @@ struct ggml_tensor * ggml_gelu_impl(
 
  result->op = GGML_OP_GELU;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5845,8 +5858,8 @@ struct ggml_tensor * ggml_gelu_quick_impl(
 
  result->op = GGML_OP_GELU_QUICK;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5879,8 +5892,8 @@ struct ggml_tensor * ggml_silu_impl(
 
  result->op = GGML_OP_SILU;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -5914,8 +5927,8 @@ struct ggml_tensor * ggml_silu_back(
 
  result->op = GGML_OP_SILU_BACK;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -5937,8 +5950,8 @@ struct ggml_tensor * ggml_norm_impl(
 
  result->op = GGML_OP_NORM;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL; // TODO: maybe store epsilon here?
+ result->src[0] = a;
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
 
  return result;
  }
@@ -5969,8 +5982,8 @@ struct ggml_tensor * ggml_rms_norm_impl(
 
  result->op = GGML_OP_RMS_NORM;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL; // TODO: maybe store epsilon here?
+ result->src[0] = a;
+ result->src[1] = NULL; // TODO: maybe store epsilon here?
 
  return result;
  }
@@ -6002,8 +6015,8 @@ struct ggml_tensor * ggml_rms_norm_back(
 
  result->op = GGML_OP_RMS_NORM_BACK;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6024,13 +6037,13 @@ struct ggml_tensor * ggml_mul_mat(
  is_node = true;
  }
 
- const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
- struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
+ const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+ struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
 
  result->op = GGML_OP_MUL_MAT;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6055,8 +6068,8 @@ struct ggml_tensor * ggml_out_prod(
 
  result->op = GGML_OP_OUT_PROD;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6081,8 +6094,8 @@ struct ggml_tensor * ggml_scale_impl(
 
  result->op = GGML_OP_SCALE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6137,9 +6150,9 @@ struct ggml_tensor * ggml_set_impl(
 
  result->op = GGML_OP_SET;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
 
  return result;
  }
@@ -6226,8 +6239,8 @@ struct ggml_tensor * ggml_cpy_impl(
 
  result->op = GGML_OP_CPY;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6263,8 +6276,8 @@ struct ggml_tensor * ggml_cont_impl(
 
  result->op = GGML_OP_CONT;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6307,8 +6320,8 @@ struct ggml_tensor * ggml_reshape(
 
  result->op = GGML_OP_RESHAPE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6332,8 +6345,8 @@ struct ggml_tensor * ggml_reshape_1d(
 
  result->op = GGML_OP_RESHAPE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6358,8 +6371,8 @@ struct ggml_tensor * ggml_reshape_2d(
 
  result->op = GGML_OP_RESHAPE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6385,8 +6398,8 @@ struct ggml_tensor * ggml_reshape_3d(
 
  result->op = GGML_OP_RESHAPE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6414,8 +6427,8 @@ struct ggml_tensor * ggml_reshape_4d(
 
  result->op = GGML_OP_RESHAPE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6447,9 +6460,9 @@ struct ggml_tensor * ggml_view_1d(
 
  result->op = GGML_OP_VIEW;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = offs;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = offs;
 
  return result;
  }
@@ -6489,9 +6502,9 @@ struct ggml_tensor * ggml_view_2d(
 
  result->op = GGML_OP_VIEW;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = offs;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = offs;
 
  return result;
  }
@@ -6533,9 +6546,9 @@ struct ggml_tensor * ggml_view_3d(
 
  result->op = GGML_OP_VIEW;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = offs;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = offs;
 
  return result;
  }
@@ -6579,9 +6592,9 @@ struct ggml_tensor * ggml_view_4d(
 
  result->op = GGML_OP_VIEW;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = offs;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = offs;
 
  return result;
  }
@@ -6641,8 +6654,8 @@ struct ggml_tensor * ggml_permute(
 
  result->op = GGML_OP_PERMUTE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  if (is_node) {
  ggml_scratch_save(ctx);
@@ -6656,7 +6669,7 @@ struct ggml_tensor * ggml_permute(
 
  ggml_scratch_load(ctx);
 
- result->opt[0] = b;
+ result->src[2] = b;
  }
 
  return result;
@@ -6684,8 +6697,8 @@ struct ggml_tensor * ggml_transpose(
 
  result->op = GGML_OP_TRANSPOSE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6710,8 +6723,8 @@ struct ggml_tensor * ggml_get_rows(
 
  result->op = GGML_OP_GET_ROWS;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6738,9 +6751,9 @@ struct ggml_tensor * ggml_get_rows_back(
 
  result->op = GGML_OP_GET_ROWS_BACK;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
 
  return result;
  }
@@ -6762,8 +6775,8 @@ struct ggml_tensor * ggml_diag(
 
  result->op = GGML_OP_DIAG;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6795,8 +6808,8 @@ struct ggml_tensor * ggml_diag_mask_inf_impl(
 
  result->op = GGML_OP_DIAG_MASK_INF;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6843,8 +6856,8 @@ struct ggml_tensor * ggml_diag_mask_zero_impl(
 
  result->op = GGML_OP_DIAG_MASK_ZERO;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6879,8 +6892,8 @@ struct ggml_tensor * ggml_soft_max_impl(
 
  result->op = GGML_OP_SOFT_MAX;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
+ result->src[0] = a;
+ result->src[1] = NULL;
 
  return result;
  }
@@ -6915,8 +6928,8 @@ struct ggml_tensor * ggml_soft_max_back_impl(
 
  result->op = GGML_OP_SOFT_MAX_BACK;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6944,6 +6957,8 @@ struct ggml_tensor * ggml_rope_impl(
  int n_dims,
  int mode,
  int n_ctx,
+ float freq_base,
+ float freq_scale,
  bool inplace) {
  GGML_ASSERT(n_past >= 0);
  bool is_node = false;
@@ -6956,19 +6971,21 @@ struct ggml_tensor * ggml_rope_impl(
 
  ggml_scratch_save(ctx);
 
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);
 
  ((int32_t *) b->data)[0] = n_past;
  ((int32_t *) b->data)[1] = n_dims;
  ((int32_t *) b->data)[2] = mode;
  ((int32_t *) b->data)[3] = n_ctx;
+ memcpy((int32_t *) b->data + 4, &freq_base, sizeof(float));
+ memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));
 
  ggml_scratch_load(ctx);
 
  result->op = GGML_OP_ROPE;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -6980,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
  int n_dims,
  int mode,
  int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
  }
 
  struct ggml_tensor * ggml_rope_inplace(
@@ -6990,7 +7007,19 @@ struct ggml_tensor * ggml_rope_inplace(
  int n_dims,
  int mode,
  int n_ctx) {
- return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
+ }
+
+ struct ggml_tensor * ggml_rope_custom_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int n_past,
+ int n_dims,
+ int mode,
+ int n_ctx,
+ float freq_base,
+ float freq_scale) {
+ return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
  }
 
  // ggml_rope_back
7000
7029
  struct ggml_tensor * a,
7001
7030
  int n_past,
7002
7031
  int n_dims,
7003
- int mode) {
7032
+ int mode,
7033
+ int n_ctx) {
7004
7034
  GGML_ASSERT(n_past >= 0);
7005
7035
  GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
7006
7036
 
@@ -7014,19 +7044,20 @@ struct ggml_tensor * ggml_rope_back(
7014
7044
 
7015
7045
  ggml_scratch_save(ctx);
7016
7046
 
7017
- struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
7047
+ struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
7018
7048
  ggml_set_name(b, "n_past, n_dims, mode");
7019
7049
 
7020
7050
  ((int32_t *) b->data)[0] = n_past;
7021
7051
  ((int32_t *) b->data)[1] = n_dims;
7022
7052
  ((int32_t *) b->data)[2] = mode;
7053
+ ((int32_t *) b->data)[3] = n_ctx;
7023
7054
 
7024
7055
  ggml_scratch_load(ctx);
7025
7056
 
7026
7057
  result->op = GGML_OP_ROPE_BACK;
7027
7058
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7028
- result->src0 = a;
7029
- result->src1 = b;
7059
+ result->src[0] = a;
7060
+ result->src[1] = b;
7030
7061
 
7031
7062
  return result;
7032
7063
  }
@@ -7064,8 +7095,8 @@ struct ggml_tensor * ggml_alibi(
7064
7095
 
7065
7096
  result->op = GGML_OP_ALIBI;
7066
7097
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7067
- result->src0 = a;
7068
- result->src1 = b;
7098
+ result->src[0] = a;
7099
+ result->src[1] = b;
7069
7100
 
7070
7101
  return result;
7071
7102
  }
@@ -7098,8 +7129,8 @@ struct ggml_tensor * ggml_clamp(
7098
7129
 
7099
7130
  result->op = GGML_OP_CLAMP;
7100
7131
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7101
- result->src0 = a;
7102
- result->src1 = b;
7132
+ result->src[0] = a;
7133
+ result->src[1] = b;
7103
7134
 
7104
7135
  return result;
7105
7136
  }
@@ -7141,9 +7172,9 @@ GGML_API struct ggml_tensor * ggml_conv_1d(
7141
7172
 
7142
7173
  result->op = GGML_OP_CONV_1D;
7143
7174
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7144
- result->src0 = a;
7145
- result->src1 = b;
7146
- result->opt[0] = c;
7175
+ result->src[0] = a;
7176
+ result->src[1] = b;
7177
+ result->src[2] = c;
7147
7178
 
7148
7179
  return result;
7149
7180
  }
@@ -7161,7 +7192,6 @@ struct ggml_tensor* ggml_conv_2d(
7161
7192
  int d0,
7162
7193
  int d1) {
7163
7194
 
7164
- GGML_ASSERT(b->ne[3] == 1);
7165
7195
  GGML_ASSERT(a->ne[2] == b->ne[2]);
7166
7196
  bool is_node = false;
7167
7197
 
@@ -7173,7 +7203,7 @@ struct ggml_tensor* ggml_conv_2d(
7173
7203
  const int64_t ne[4] = {
7174
7204
  ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
7175
7205
  ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
7176
- a->ne[3], 1,
7206
+ a->ne[3], b->ne[3],
7177
7207
  };
7178
7208
  struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
7179
7209
 
@@ -7189,9 +7219,9 @@ struct ggml_tensor* ggml_conv_2d(
7189
7219
 
7190
7220
  result->op = GGML_OP_CONV_2D;
7191
7221
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7192
- result->src0 = a;
7193
- result->src1 = b;
7194
- result->opt[0] = c;
7222
+ result->src[0] = a;
7223
+ result->src[1] = b;
7224
+ result->src[2] = c;
7195
7225
 
7196
7226
  return result;
7197
7227
 
@@ -7208,6 +7238,98 @@ struct ggml_tensor* ggml_conv_1d_ph(
7208
7238
  return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
7209
7239
  }
7210
7240
 
7241
+
7242
+ // ggml_pool_*
7243
+
7244
+ static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
7245
+ return (ins + 2 * p - ks) / s + 1;
7246
+ }
7247
+
7248
+ // ggml_pool_2d
7249
+
7250
+ struct ggml_tensor* ggml_pool_1d(
7251
+ struct ggml_context * ctx,
7252
+ struct ggml_tensor * a,
7253
+ enum ggml_op_pool op,
7254
+ int k0,
7255
+ int s0,
7256
+ int p0) {
7257
+
7258
+ bool is_node = false;
7259
+
7260
+ if (a->grad) {
7261
+ GGML_ASSERT(false); // TODO: implement backward
7262
+ is_node = true;
7263
+ }
7264
+
7265
+ const int64_t ne[3] = {
7266
+ ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
7267
+ a->ne[1],
7268
+ };
7269
+ struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
7270
+
7271
+ ggml_scratch_save(ctx);
7272
+ struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
7273
+ ((int32_t*)c->data)[0] = op;
7274
+ ((int32_t*)c->data)[1] = k0;
7275
+ ((int32_t*)c->data)[2] = s0;
7276
+ ((int32_t*)c->data)[3] = p0;
7277
+ ggml_scratch_load(ctx);
7278
+
7279
+ result->op = GGML_OP_POOL_1D;
7280
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7281
+ result->src[0] = a;
7282
+ result->src[1] = c;
7283
+
7284
+ return result;
7285
+ }
7286
+
7287
+ // ggml_pool_2d
7288
+
7289
+ struct ggml_tensor* ggml_pool_2d(
7290
+ struct ggml_context * ctx,
7291
+ struct ggml_tensor * a,
7292
+ enum ggml_op_pool op,
7293
+ int k0,
7294
+ int k1,
7295
+ int s0,
7296
+ int s1,
7297
+ int p0,
7298
+ int p1) {
7299
+
7300
+ bool is_node = false;
7301
+
7302
+ if (a->grad) {
7303
+ GGML_ASSERT(false); // TODO: implement backward
7304
+ is_node = true;
7305
+ }
7306
+
7307
+ const int64_t ne[3] = {
7308
+ ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
7309
+ ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
7310
+ a->ne[2],
7311
+ };
7312
+ struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
7313
+
7314
+ ggml_scratch_save(ctx);
7315
+ struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7);
7316
+ ((int32_t*)c->data)[0] = op;
7317
+ ((int32_t*)c->data)[1] = k0;
7318
+ ((int32_t*)c->data)[2] = k1;
7319
+ ((int32_t*)c->data)[3] = s0;
7320
+ ((int32_t*)c->data)[4] = s1;
7321
+ ((int32_t*)c->data)[5] = p0;
7322
+ ((int32_t*)c->data)[6] = p1;
7323
+ ggml_scratch_load(ctx);
7324
+
7325
+ result->op = GGML_OP_POOL_2D;
7326
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
7327
+ result->src[0] = a;
7328
+ result->src[1] = c;
7329
+
7330
+ return result;
7331
+ }
7332
+
7211
7333
  // ggml_flash_attn
7212
7334
 
7213
7335
  struct ggml_tensor * ggml_flash_attn(
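`ggml_calc_pool_output_size` is the standard pooling formula, out = (in + 2p - k)/s + 1. For example, average-pooling a length-10 row with k0 = 2, s0 = 2, p0 = 0 yields (10 + 0 - 2)/2 + 1 = 5 outputs; a sketch assuming an existing context `ctx`:

    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, 1);
    struct ggml_tensor * y = ggml_pool_1d(ctx, x, GGML_OP_POOL_AVG,
                                          /*k0 =*/ 2, /*s0 =*/ 2, /*p0 =*/ 0);
    // y->ne[0] == (10 + 2*0 - 2)/2 + 1 == 5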
@@ -7230,10 +7352,10 @@ struct ggml_tensor * ggml_flash_attn(
 
  result->op = GGML_OP_FLASH_ATTN;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = q;
- result->src1 = k;
- result->opt[0] = v;
- result->opt[1] = ggml_new_i32(ctx, masked ? 1 : 0);
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = ggml_new_i32(ctx, masked ? 1 : 0);
 
  return result;
  }
@@ -7261,11 +7383,11 @@ struct ggml_tensor * ggml_flash_ff(
 
  result->op = GGML_OP_FLASH_FF;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b0;
- result->opt[0] = b1;
- result->opt[1] = c0;
- result->opt[2] = c1;
+ result->src[0] = a;
+ result->src[1] = b0;
+ result->src[2] = b1;
+ result->src[3] = c0;
+ result->src[4] = c1;
 
  return result;
  }
@@ -7325,11 +7447,11 @@ struct ggml_tensor * ggml_flash_attn_back(
 
  result->op = GGML_OP_FLASH_ATTN_BACK;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = q;
- result->src1 = k;
- result->opt[0] = v;
- result->opt[1] = d;
- result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
+ result->src[0] = q;
+ result->src[1] = k;
+ result->src[2] = v;
+ result->src[3] = d;
+ result->src[4] = ggml_new_i32(ctx, masked ? 1 : 0);
 
  return result;
  }
@@ -7374,9 +7496,9 @@ struct ggml_tensor * ggml_win_part(
 
  result->op = GGML_OP_WIN_PART;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = b;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = b;
 
  return result;
  }
@@ -7411,9 +7533,9 @@ struct ggml_tensor * ggml_win_unpart(
 
  result->op = GGML_OP_WIN_UNPART;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = NULL;
- result->opt[0] = b;
+ result->src[0] = a;
+ result->src[1] = NULL;
+ result->src[2] = b;
 
  return result;
  }
@@ -7442,8 +7564,8 @@ struct ggml_tensor * ggml_map_unary_impl_f32(
 
  result->op = GGML_OP_MAP_UNARY;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->opt[0] = addr_tensor;
+ result->src[0] = a;
+ result->src[2] = addr_tensor;
 
  return result;
  }
@@ -7489,9 +7611,9 @@ struct ggml_tensor * ggml_map_binary_impl_f32(
 
  result->op = GGML_OP_MAP_BINARY;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = addr_tensor;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = addr_tensor;
 
  return result;
  }
@@ -7536,8 +7658,8 @@ struct ggml_tensor * ggml_map_custom1_impl_f32(
 
  result->op = GGML_OP_MAP_CUSTOM1;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->opt[0] = addr_tensor;
+ result->src[0] = a;
+ result->src[2] = addr_tensor;
 
  return result;
  }
@@ -7581,9 +7703,9 @@ struct ggml_tensor * ggml_map_custom2_impl_f32(
 
  result->op = GGML_OP_MAP_CUSTOM2;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = addr_tensor;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = addr_tensor;
 
  return result;
  }
@@ -7630,10 +7752,10 @@ struct ggml_tensor * ggml_map_custom3_impl_f32(
 
  result->op = GGML_OP_MAP_CUSTOM3;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = addr_tensor;
- result->opt[1] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = addr_tensor;
+ result->src[3] = c;
 
  return result;
  }
@@ -7673,8 +7795,8 @@ struct ggml_tensor * ggml_cross_entropy_loss(
 
  result->op = GGML_OP_CROSS_ENTROPY_LOSS;
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
- result->src0 = a;
- result->src1 = b;
+ result->src[0] = a;
+ result->src[1] = b;
 
  return result;
  }
@@ -7693,9 +7815,9 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
 
  result->op = GGML_OP_CROSS_ENTROPY_LOSS_BACK;
  result->grad = NULL;
- result->src0 = a;
- result->src1 = b;
- result->opt[0] = c;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = c;
 
  return result;
  }
@@ -8296,7 +8418,7 @@ static void ggml_compute_forward_add_f32(
  const struct ggml_tensor * src0,
  const struct ggml_tensor * src1,
  struct ggml_tensor * dst) {
- GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+ GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
 
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
  return;
@@ -8321,23 +8443,23 @@ static void ggml_compute_forward_add_f32(
 
  if (nb10 == sizeof(float)) {
  for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
 
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
 
  #ifdef GGML_USE_ACCELERATE
- vDSP_vadd(
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
- ne0);
+ vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
  #else
- ggml_vec_add_f32(ne0,
- (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
- (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
- (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+ ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
  #endif
  // }
  // }
@@ -8345,15 +8467,20 @@ static void ggml_compute_forward_add_f32(
  } else {
  // src1 is not contiguous
  for (int ir = ir0; ir < ir1; ++ir) {
- // src0, src1 and dst are same shape => same indices
- const int i3 = ir/(ne2*ne1);
- const int i2 = (ir - i3*ne2*ne1)/ne1;
- const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+ // src1 is broadcastable across src0 and dst in i1, i2, i3
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
+
+ float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 
- float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
- float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  for (int i0 = 0; i0 < ne0; i0++) {
- float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10);
+ float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
 
  dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
  }
@@ -10532,7 +10659,6 @@ static void ggml_compute_forward_rms_norm_back(
  }
  }
 
-
  // ggml_compute_forward_mul_mat
 
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@@ -10576,17 +10702,19 @@ static void ggml_compute_forward_mul_mat(
  const int ith = params->ith;
  const int nth = params->nth;
 
- GGML_ASSERT(ne02 == ne12);
- GGML_ASSERT(ne03 == ne13);
- GGML_ASSERT(ne2 == ne12);
- GGML_ASSERT(ne3 == ne13);
-
  const enum ggml_type type = src0->type;
 
+ const bool src1_cont = ggml_is_contiguous(src1);
+
  ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
+ GGML_ASSERT(ne0 == ne01);
+ GGML_ASSERT(ne1 == ne11);
+ GGML_ASSERT(ne2 == ne12);
+ GGML_ASSERT(ne3 == ne13);
+
  // we don't support permuted src0 or src1
  GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
  GGML_ASSERT(nb10 == sizeof(float));
@@ -10597,16 +10725,16 @@ static void ggml_compute_forward_mul_mat(
  GGML_ASSERT(nb1 <= nb2);
  GGML_ASSERT(nb2 <= nb3);
 
- GGML_ASSERT(ne0 == ne01);
- GGML_ASSERT(ne1 == ne11);
- GGML_ASSERT(ne2 == ne02);
- GGML_ASSERT(ne3 == ne03);
-
  // nb01 >= nb00 - src0 is not transposed
  // compute by src0 rows
 
  #if defined(GGML_USE_CLBLAST)
  if (ggml_cl_can_mul_mat(src0, src1, dst)) {
+ // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+ // ref: https://github.com/ggerganov/ggml/pull/224
+ GGML_ASSERT(ne02 == ne12);
+ GGML_ASSERT(ne03 == ne13);
+
  if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
  ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
  }
@@ -10616,6 +10744,11 @@ static void ggml_compute_forward_mul_mat(
 
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
  if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
+ // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
+ // ref: https://github.com/ggerganov/ggml/pull/224
+ GGML_ASSERT(ne02 == ne12);
+ GGML_ASSERT(ne03 == ne13);
+
  if (params->ith != 0) {
  return;
  }
@@ -10636,7 +10769,7 @@ static void ggml_compute_forward_mul_mat(
  float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
  if (type != GGML_TYPE_F32) {
- float * const wdata = params->wdata;
+ float * const wdata = params->wdata;
  ggml_to_float_t const to_float = type_traits[type].to_float;
 
  size_t id = 0;
@@ -10685,43 +10818,52 @@ static void ggml_compute_forward_mul_mat(
  return;
  }
 
- // parallelize by src0 rows using ggml_vec_dot_q
+ // parallelize by src0 rows
+ const int64_t dr = (ne01 + nth - 1)/nth;
 
- // total rows in src0
- const int nr = ne01*ne02*ne03;
+ const int64_t ir10 = dr*ith;
+ const int64_t ir11 = MIN(ir10 + dr, ne01);
 
- // rows per thread
- const int dr = (nr + nth - 1)/nth;
+ // src1 rows
+ const int64_t nr1 = ne11*ne12*ne13;
 
- // row range for this thread
- const int ir0 = dr*ith;
- const int ir1 = MIN(ir0 + dr, nr);
+ const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+ const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
- void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
- const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+ for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+ const int64_t i13 = (ir1/(ne12*ne11));
+ const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
+ const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
 
- for (int ir = ir0; ir < ir1; ++ir) {
- // src0 indices
- const int i03 = ir/(ne02*ne01);
- const int i02 = (ir - i03*ne02*ne01)/ne01;
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+ const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
+ const int64_t i03 = (ir0/(ne02));
+ // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
+ // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
+ // GG: this is likely the correct way to broadcast, though need some more thought
+ // therefore leaving the comments to remind us for now
+ const int64_t i02 = (i12 / (ne12 / ne02));
+ // Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
+ // const int64_t i02 = (ir0 - i03*ne02);
 
- const int i13 = i03;
- const int i12 = i02;
+ const int64_t i1 = i11;
+ const int64_t i2 = i12;
+ const int64_t i3 = i13;
 
- const int i0 = i01;
- const int i2 = i02;
- const int i3 = i03;
+ const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
 
- void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
- char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
+ // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+ // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+ // the original src1 data pointer, so we should index using the indices directly
+ // TODO: this is a bit of a hack, we should probably have a better way to handle this
+ const char * src1_col = (const char *) wdata +
+ (src1_cont || src1->type != vec_dot_type
+ ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
+ : (i11*nb11 + i12*nb12 + i13*nb13));
 
- float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
- assert(ne00 % 32 == 0);
-
- for (int64_t ic = 0; ic < ne11; ++ic) {
- vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+ for (int64_t ir = ir10; ir < ir11; ++ir) {
+ vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
  }
  }
 
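The rewritten matmul kernel iterates over src1 rows in the outer loop and splits the src0 rows across threads, with division-based index math so one src0 slice can serve several src1 batch slices (the Falcon multi-query-attention case referenced in the comments). The broadcast mapping in isolation, as a sketch:

    // sketch of the mapping used above (ne12 assumed a whole multiple of ne02)
    // for each src1 batch slice i12, the src0 slice that multiplies it is:
    //     i02 = i12 / (ne12 / ne02);
    // e.g. ne02 = 8 KV heads, ne12 = 32 query heads:
    //     queries 0..3 use KV head 0, queries 4..7 use KV head 1, and so on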
@@ -11718,7 +11860,7 @@ static void ggml_compute_forward_alibi_f32(
 
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
  const int ne1 = src0->ne[1]; // seq_len_without_past
- //const int ne2 = src0->ne[2]; // n_head -> this is k
+ const int ne2 = src0->ne[2]; // n_head -> this is k
  //const int ne3 = src0->ne[3]; // 1 -> bsz
 
  const int n = ggml_nrows(src0);
@@ -11729,8 +11871,9 @@ static void ggml_compute_forward_alibi_f32(
  const int nb2 = src0->nb[2];
  //const int nb3 = src0->nb[3];
 
- assert(nb0 == sizeof(float));
- assert(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(nb0 == sizeof(float));
+ GGML_ASSERT(ne1 + n_past == ne0);
+ GGML_ASSERT(n_head == ne2);
 
  // add alibi to src0 (KQ_scaled)
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11754,7 +11897,7 @@ static void ggml_compute_forward_alibi_f32(
  m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
  }
 
- pdst[0] = (i-ne0+1) * m_k + src[0];
+ pdst[0] = i * m_k + src[0];
 
  }
  }
@@ -11783,7 +11926,7 @@ static void ggml_compute_forward_alibi_f16(
 
  const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
  const int ne1 = src0->ne[1]; // seq_len_without_past
- //const int ne2 = src0->ne[2]; // n_head -> this is k
+ const int ne2 = src0->ne[2]; // n_head -> this is k
  //const int ne3 = src0->ne[3]; // 1 -> bsz
 
  const int n = ggml_nrows(src0);
@@ -11794,8 +11937,9 @@ static void ggml_compute_forward_alibi_f16(
  const int nb2 = src0->nb[2];
  //const int nb3 = src0->nb[3];
 
- assert(nb0 == sizeof(ggml_fp16_t));
- assert(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+ GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
+ GGML_ASSERT(n_head == ne2);
 
  // add alibi to src0 (KQ_scaled)
  const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -11820,7 +11964,7 @@ static void ggml_compute_forward_alibi_f16(
  }
 
  // we return F32
- pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]);
+ pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
  }
  }
  }
@@ -11948,16 +12092,21 @@ static void ggml_compute_forward_rope_f32(
  const struct ggml_tensor * src1,
  struct ggml_tensor * dst) {
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
- GGML_ASSERT(ggml_nelements(src1) == 4);
+ GGML_ASSERT(ggml_nelements(src1) == 6);
 
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
  return;
  }
 
+ float freq_base;
+ float freq_scale;
+
  const int n_past = ((int32_t *) src1->data)[0];
  const int n_dims = ((int32_t *) src1->data)[1];
  const int mode = ((int32_t *) src1->data)[2];
  const int n_ctx = ((int32_t *) src1->data)[3];
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
 
  assert(n_past >= 0);
 
@@ -11986,7 +12135,7 @@ static void ggml_compute_forward_rope_f32(
  // row index used to determine which thread to use
  int ir = 0;
 
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
  const bool is_neox = mode & 2;
  const bool is_glm = mode & 4;
@@ -11998,7 +12147,7 @@ static void ggml_compute_forward_rope_f32(
  if (ir++ < ir0) continue;
  if (ir > ir1) break;
 
- float theta = (float)p;
+ float theta = freq_scale * (float)p;
 
  if (is_glm) {
  theta = MIN(p, n_ctx - 2);
@@ -12075,16 +12224,21 @@ static void ggml_compute_forward_rope_f16(
12075
12224
  const struct ggml_tensor * src1,
12076
12225
  struct ggml_tensor * dst) {
12077
12226
  GGML_ASSERT(src1->type == GGML_TYPE_I32);
12078
- GGML_ASSERT(ggml_nelements(src1) == 4);
12227
+ GGML_ASSERT(ggml_nelements(src1) == 6);
12079
12228
 
12080
12229
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12081
12230
  return;
12082
12231
  }
12083
12232
 
12233
+ float freq_base;
12234
+ float freq_scale;
12235
+
12084
12236
  const int n_past = ((int32_t *) src1->data)[0];
12085
12237
  const int n_dims = ((int32_t *) src1->data)[1];
12086
12238
  const int mode = ((int32_t *) src1->data)[2];
12087
12239
  const int n_ctx = ((int32_t *) src1->data)[3];
12240
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
12241
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
12088
12242
 
12089
12243
  assert(n_past >= 0);
12090
12244
 
@@ -12113,7 +12267,7 @@ static void ggml_compute_forward_rope_f16(
12113
12267
  // row index used to determine which thread to use
12114
12268
  int ir = 0;
12115
12269
 
12116
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
12270
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
12117
12271
 
12118
12272
  const bool is_neox = mode & 2;
12119
12273
  const bool is_glm = mode & 4;
@@ -12125,7 +12279,7 @@ static void ggml_compute_forward_rope_f16(
12125
12279
  if (ir++ < ir0) continue;
12126
12280
  if (ir > ir1) break;
12127
12281
 
12128
- float theta = (float)p;
12282
+ float theta = freq_scale * (float)p;
12129
12283
 
12130
12284
  if (is_glm) {
12131
12285
  theta = MIN(p, n_ctx - 2);
@@ -12186,7 +12340,7 @@ static void ggml_compute_forward_rope_f16(
12186
12340
  const float x0 = GGML_FP16_TO_FP32(src[0]);
12187
12341
  const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
12188
12342
 
12189
- dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
12343
+ dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
12190
12344
  dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
12191
12345
  }
12192
12346
  }
@@ -12225,7 +12379,7 @@ static void ggml_compute_forward_rope_back_f32(
12225
12379
  const struct ggml_tensor * src1,
12226
12380
  struct ggml_tensor * dst) {
12227
12381
  assert(src1->type == GGML_TYPE_I32);
12228
- assert(ggml_nelements(src1) == 3);
12382
+ assert(ggml_nelements(src1) == 4);
12229
12383
 
12230
12384
  if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
12231
12385
  return;
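
The ROPE hunks grow the parameter tensor from 4 to 6 int32 slots: slots 0..3 stay n_past/n_dims/mode/n_ctx, while slots 4 and 5 carry the raw bits of two new floats, freq_base (previously hard-coded as 10000.0 in theta_scale) and freq_scale (a multiplier applied to the position p before the rotation). The kernels read the floats back with memcpy, which avoids type-punning through an int pointer. A sketch of the matching caller-side packing; the function name and the bare int32_t buffer are illustrative, in ggml the slots live in a GGML_TYPE_I32 tensor:

    #include <stdint.h>
    #include <string.h>

    static void pack_rope_params(int32_t dst[6], int n_past, int n_dims,
                                 int mode, int n_ctx,
                                 float freq_base, float freq_scale) {
        dst[0] = n_past;
        dst[1] = n_dims;
        dst[2] = mode;
        dst[3] = n_ctx;
        // store the floats bit-for-bit; memcpy keeps this well-defined
        // and matches the memcpy reads in the kernels above
        memcpy(&dst[4], &freq_base,  sizeof(float));
        memcpy(&dst[5], &freq_scale, sizeof(float));
    }
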
@@ -12868,12 +13022,13 @@ static void ggml_compute_forward_conv_1d(
12868
13022
  };
12869
13023
  }
12870
13024
 
12871
- // ggml_compute_forward_conv_2d_sk_p0
13025
+ // ggml_compute_forward_conv_2d
12872
13026
 
12873
- static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
13027
+ static void ggml_compute_forward_conv_2d_f16_f32(
12874
13028
  const struct ggml_compute_params * params,
12875
13029
  const struct ggml_tensor * src0,
12876
13030
  const struct ggml_tensor * src1,
13031
+ const struct ggml_tensor * opt0,
12877
13032
  struct ggml_tensor * dst) {
12878
13033
  GGML_ASSERT(src0->type == GGML_TYPE_F16);
12879
13034
  GGML_ASSERT(src1->type == GGML_TYPE_F32);
@@ -12893,11 +13048,17 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12893
13048
  // size of the convolution row - the kernel size unrolled across all channels
12894
13049
  const int ew0 = nk0*nk1*ne02;
12895
13050
 
13051
+ const int32_t s0 = ((const int32_t*)(opt0->data))[0];
13052
+ const int32_t s1 = ((const int32_t*)(opt0->data))[1];
13053
+ const int32_t p0 = ((const int32_t*)(opt0->data))[2];
13054
+ const int32_t p1 = ((const int32_t*)(opt0->data))[3];
13055
+ const int32_t d0 = ((const int32_t*)(opt0->data))[4];
13056
+ const int32_t d1 = ((const int32_t*)(opt0->data))[5];
13057
+
12896
13058
  GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
12897
13059
  GGML_ASSERT(nb10 == sizeof(float));
12898
13060
 
12899
13061
  if (params->type == GGML_TASK_INIT) {
12900
- // TODO: fix this memset (wsize is overestimated)
12901
13062
  memset(params->wdata, 0, params->wsize);
12902
13063
 
12903
13064
  // prepare source data (src1)
@@ -12912,8 +13073,13 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12912
13073
  for (int i0 = 0; i0 < ne0; i0++) {
12913
13074
  for (int ik1 = 0; ik1 < nk1; ik1++) {
12914
13075
  for (int ik0 = 0; ik0 < nk0; ik0++) {
12915
- dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
12916
- GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]);
13076
+ const int idx0 = i0*s0 + ik0*d0 - p0;
13077
+ const int idx1 = i1*s1 + ik1*d1 - p1;
13078
+
13079
+ if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
13080
+ dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
13081
+ GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
13082
+ }
12917
13083
  }
12918
13084
  }
12919
13085
  }
@@ -12940,32 +13106,36 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
12940
13106
 
12941
13107
  ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
12942
13108
 
12943
- for (int i2 = ip0; i2 < ip1; i2++) {
12944
- float * dst_data = (float *)((char *) dst->data + i2*nb2);
12945
-
12946
- for (int i1 = 0; i1 < ne1; ++i1) {
12947
- for (int i0 = 0; i0 < ne0; ++i0) {
12948
- ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
12949
- (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
12950
- (ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
13109
+ for (int i3 = 0; i3 < ne3; i3++) {
13110
+ for (int i2 = ip0; i2 < ip1; i2++) {
13111
+ float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
13112
+
13113
+ for (int i1 = 0; i1 < ne1; ++i1) {
13114
+ for (int i0 = 0; i0 < ne0; ++i0) {
13115
+ ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
13116
+ (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
13117
+ (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
13118
+ }
12951
13119
  }
12952
13120
  }
12953
13121
  }
12954
13122
  }
12955
13123
 
12956
- static void ggml_compute_forward_conv_2d_sk_p0(
13124
+ static void ggml_compute_forward_conv_2d(
12957
13125
  const struct ggml_compute_params * params,
12958
13126
  const struct ggml_tensor * src0,
12959
13127
  const struct ggml_tensor * src1,
12960
- struct ggml_tensor * dst) {
13128
+ const struct ggml_tensor * opt0,
13129
+ struct ggml_tensor * dst
13130
+ ) {
12961
13131
  switch (src0->type) {
12962
13132
  case GGML_TYPE_F16:
12963
13133
  {
12964
- ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst);
13134
+ ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst);
12965
13135
  } break;
12966
13136
  case GGML_TYPE_F32:
12967
13137
  {
12968
- //ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst);
13138
+ //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst);
12969
13139
  GGML_ASSERT(false);
12970
13140
  } break;
12971
13141
  default:
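
With these hunks, ggml_compute_forward_conv_2d_sk_p0 is generalized into ggml_compute_forward_conv_2d: stride (s0, s1), padding (p0, p1) and dilation (d0, d1) now arrive through the opt0 tensor, and the im2col step copies only taps that fall inside the source image, leaving out-of-bounds (padding) taps at the zero written by the memset in GGML_TASK_INIT. A sketch of just the index math from the hunk above; the helper is illustrative, not ggml API:

    // Source offset for output position (i0, i1) and kernel tap (ik0, ik1);
    // returns -1 for a padding tap, whose contribution stays zero.
    static int conv2d_src_index(int i0, int i1, int ik0, int ik1,
                                int s0, int s1, int p0, int p1,
                                int d0, int d1,
                                int ne10, int ne11 /* src width, height */) {
        const int idx0 = i0*s0 + ik0*d0 - p0;
        const int idx1 = i1*s1 + ik1*d1 - p1;
        if (idx0 < 0 || idx0 >= ne10 || idx1 < 0 || idx1 >= ne11) {
            return -1;
        }
        return idx1*ne10 + idx0;
    }
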
@@ -12975,31 +13145,164 @@ static void ggml_compute_forward_conv_2d_sk_p0(
12975
13145
  }
12976
13146
  }
12977
13147
 
12978
- // ggml_compute_forward_conv_2d
13148
+ // ggml_compute_forward_pool_1d_sk_p0
12979
13149
 
12980
- static void ggml_compute_forward_conv_2d(
13150
+ static void ggml_compute_forward_pool_1d_sk_p0(
13151
+ const struct ggml_compute_params * params,
13152
+ const enum ggml_op_pool op,
13153
+ const struct ggml_tensor * src,
13154
+ const int k,
13155
+ struct ggml_tensor * dst) {
13156
+ assert(src->type == GGML_TYPE_F32);
13157
+ assert(params->ith == 0);
13158
+
13159
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13160
+ return;
13161
+ }
13162
+
13163
+ const char * cdata = (const char *)src->data;
13164
+ const char * const data_end = cdata + ggml_nbytes(src);
13165
+ float * drow = (float *)dst->data;
13166
+
13167
+ const int64_t rs = dst->ne[0];
13168
+
13169
+ while (cdata < data_end) {
13170
+ const float * const srow = (const float *)cdata;
13171
+
13172
+ int j = 0;
13173
+
13174
+ for (int64_t i = 0; i < rs; ++i) {
13175
+ switch (op) {
13176
+ case GGML_OP_POOL_AVG: drow[i] = 0; break;
13177
+ case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
13178
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13179
+ }
13180
+ for (int ki = 0; ki < k; ++ki) {
13181
+ switch (op) {
13182
+ case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
13183
+ case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
13184
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13185
+ }
13186
+ ++j;
13187
+ }
13188
+ switch (op) {
13189
+ case GGML_OP_POOL_AVG: drow[i] /= k; break;
13190
+ case GGML_OP_POOL_MAX: break;
13191
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13192
+ }
13193
+ }
13194
+
13195
+ cdata += src->nb[1];
13196
+ drow += rs;
13197
+ }
13198
+ }
13199
+
13200
+ // ggml_compute_forward_pool_1d
13201
+
13202
+ static void ggml_compute_forward_pool_1d(
12981
13203
  const struct ggml_compute_params* params,
12982
13204
  const struct ggml_tensor* src0,
12983
- const struct ggml_tensor* src1,
12984
13205
  const struct ggml_tensor* opt0,
12985
13206
  struct ggml_tensor* dst) {
12986
- const int32_t s0 = ((const int32_t*)(opt0->data))[0];
12987
- const int32_t s1 = ((const int32_t*)(opt0->data))[1];
12988
- const int32_t p0 = ((const int32_t*)(opt0->data))[2];
12989
- const int32_t p1 = ((const int32_t*)(opt0->data))[3];
12990
- const int32_t d0 = ((const int32_t*)(opt0->data))[4];
12991
- const int32_t d1 = ((const int32_t*)(opt0->data))[5];
12992
- GGML_ASSERT(d0 == 1); // dilation not supported
12993
- GGML_ASSERT(d1 == 1);
13207
+ GGML_ASSERT(opt0->ne[0] == 4);
13208
+ const int* opts = (const int*)opt0->data;
13209
+ enum ggml_op_pool op = opts[0];
13210
+ const int k0 = opts[1];
13211
+ const int s0 = opts[2];
13212
+ const int p0 = opts[3];
12994
13213
  GGML_ASSERT(p0 == 0); // padding not supported
12995
- GGML_ASSERT(p1 == 0);
13214
+ GGML_ASSERT(k0 == s0); // only s = k supported
13215
+
13216
+ ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
13217
+ }
12996
13218
 
12997
- if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
12998
- ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
13219
+ // ggml_compute_forward_pool_2d_sk_p0
13220
+
13221
+ static void ggml_compute_forward_pool_2d_sk_p0(
13222
+ const struct ggml_compute_params * params,
13223
+ const enum ggml_op_pool op,
13224
+ const struct ggml_tensor * src,
13225
+ const int k0,
13226
+ const int k1,
13227
+ struct ggml_tensor * dst) {
13228
+ assert(src->type == GGML_TYPE_F32);
13229
+ assert(params->ith == 0);
13230
+
13231
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
13232
+ return;
13233
+ }
13234
+
13235
+ const char * cdata = (const char*)src->data;
13236
+ const char * const data_end = cdata + ggml_nbytes(src);
13237
+
13238
+ const int64_t px = dst->ne[0];
13239
+ const int64_t py = dst->ne[1];
13240
+ const int64_t pa = px * py;
13241
+
13242
+ float * dplane = (float *)dst->data;
13243
+
13244
+ const int ka = k0 * k1;
13245
+
13246
+ while (cdata < data_end) {
13247
+ for (int oy = 0; oy < py; ++oy) {
13248
+ float * const drow = dplane + oy * px;
13249
+ for (int ox = 0; ox < px; ++ox) {
13250
+ float * const out = drow + ox;
13251
+ switch (op) {
13252
+ case GGML_OP_POOL_AVG: *out = 0; break;
13253
+ case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
13254
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13255
+ }
13256
+
13257
+ const int ix = ox * k0;
13258
+ const int iy = oy * k1;
13259
+
13260
+ for (int ky = 0; ky < k1; ++ky) {
13261
+ const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
13262
+ for (int kx = 0; kx < k0; ++kx) {
13263
+ int j = ix + kx;
13264
+ switch (op) {
13265
+ case GGML_OP_POOL_AVG: *out += srow[j]; break;
13266
+ case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
13267
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13268
+ }
13269
+ }
13270
+ }
13271
+ switch (op) {
13272
+ case GGML_OP_POOL_AVG: *out /= ka; break;
13273
+ case GGML_OP_POOL_MAX: break;
13274
+ case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
13275
+ }
13276
+ }
13277
+ }
13278
+
13279
+ cdata += src->nb[2];
13280
+ dplane += pa;
12999
13281
  }
13000
- else {
13001
- GGML_ASSERT(false); // only stride equal to kernel size is supported
13002
- };
13282
+ }
13283
+
13284
+ // ggml_compute_forward_pool_2d
13285
+
13286
+ static void ggml_compute_forward_pool_2d(
13287
+ const struct ggml_compute_params * params,
13288
+ const struct ggml_tensor * src0,
13289
+ const struct ggml_tensor * opt0,
13290
+ struct ggml_tensor * dst) {
13291
+ GGML_ASSERT(opt0->ne[0] == 7);
13292
+ const int* opts = (const int*)opt0->data;
13293
+ enum ggml_op_pool op = opts[0];
13294
+ const int k0 = opts[1];
13295
+ const int k1 = opts[2];
13296
+ const int s0 = opts[3];
13297
+ const int s1 = opts[4];
13298
+ const int p0 = opts[5];
13299
+ const int p1 = opts[6];
13300
+ GGML_ASSERT(p0 == 0);
13301
+ GGML_ASSERT(p1 == 0); // padding not supported
13302
+ GGML_ASSERT(k0 == s0);
13303
+ GGML_ASSERT(k1 == s1); // only s = k supported
13304
+
13305
+ ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
13003
13306
  }
13004
13307
 
13005
13308
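
The hunk above also replaces the old stride-equals-kernel conv dispatch with two new pooling operators; both ggml_compute_forward_pool_1d and ggml_compute_forward_pool_2d still assert s == k and p == 0, so only non-overlapping, unpadded windows are supported. A self-contained sketch of the same avg/max accumulation over a single float row, with the op selection mirroring GGML_OP_POOL_AVG/GGML_OP_POOL_MAX:

    #include <float.h>

    enum pool_op { POOL_AVG, POOL_MAX };

    // Pool n_out windows of k contiguous elements each (stride == k, no padding).
    static void pool_1d_row(const float * src, float * dst,
                            int n_out, int k, enum pool_op op) {
        for (int i = 0; i < n_out; ++i) {
            float acc = (op == POOL_MAX) ? -FLT_MAX : 0.0f;
            for (int j = 0; j < k; ++j) {
                const float v = src[i*k + j];
                acc = (op == POOL_MAX) ? (v > acc ? v : acc) : acc + v;
            }
            dst[i] = (op == POOL_AVG) ? acc / k : acc;
        }
    }
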
 
@@ -14566,287 +14869,295 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14566
14869
  if (skip_cpu) {
14567
14870
  return;
14568
14871
  }
14569
- GGML_ASSERT(tensor->src0 == NULL || tensor->src0->backend == GGML_BACKEND_CPU);
14570
- GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
14872
+ GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU);
14873
+ GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU);
14571
14874
  #endif // GGML_USE_CUBLAS
14572
14875
 
14573
14876
  switch (tensor->op) {
14574
14877
  case GGML_OP_DUP:
14575
14878
  {
14576
- ggml_compute_forward_dup(params, tensor->src0, tensor);
14879
+ ggml_compute_forward_dup(params, tensor->src[0], tensor);
14577
14880
  } break;
14578
14881
  case GGML_OP_ADD:
14579
14882
  {
14580
- ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
14883
+ ggml_compute_forward_add(params, tensor->src[0], tensor->src[1], tensor);
14581
14884
  } break;
14582
14885
  case GGML_OP_ADD1:
14583
14886
  {
14584
- ggml_compute_forward_add1(params, tensor->src0, tensor->src1, tensor);
14887
+ ggml_compute_forward_add1(params, tensor->src[0], tensor->src[1], tensor);
14585
14888
  } break;
14586
14889
  case GGML_OP_ACC:
14587
14890
  {
14588
- ggml_compute_forward_acc(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
14891
+ ggml_compute_forward_acc(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14589
14892
  } break;
14590
14893
  case GGML_OP_SUB:
14591
14894
  {
14592
- ggml_compute_forward_sub(params, tensor->src0, tensor->src1, tensor);
14895
+ ggml_compute_forward_sub(params, tensor->src[0], tensor->src[1], tensor);
14593
14896
  } break;
14594
14897
  case GGML_OP_MUL:
14595
14898
  {
14596
- ggml_compute_forward_mul(params, tensor->src0, tensor->src1, tensor);
14899
+ ggml_compute_forward_mul(params, tensor->src[0], tensor->src[1], tensor);
14597
14900
  } break;
14598
14901
  case GGML_OP_DIV:
14599
14902
  {
14600
- ggml_compute_forward_div(params, tensor->src0, tensor->src1, tensor);
14903
+ ggml_compute_forward_div(params, tensor->src[0], tensor->src[1], tensor);
14601
14904
  } break;
14602
14905
  case GGML_OP_SQR:
14603
14906
  {
14604
- ggml_compute_forward_sqr(params, tensor->src0, tensor);
14907
+ ggml_compute_forward_sqr(params, tensor->src[0], tensor);
14605
14908
  } break;
14606
14909
  case GGML_OP_SQRT:
14607
14910
  {
14608
- ggml_compute_forward_sqrt(params, tensor->src0, tensor);
14911
+ ggml_compute_forward_sqrt(params, tensor->src[0], tensor);
14609
14912
  } break;
14610
14913
  case GGML_OP_LOG:
14611
14914
  {
14612
- ggml_compute_forward_log(params, tensor->src0, tensor);
14915
+ ggml_compute_forward_log(params, tensor->src[0], tensor);
14613
14916
  } break;
14614
14917
  case GGML_OP_SUM:
14615
14918
  {
14616
- ggml_compute_forward_sum(params, tensor->src0, tensor);
14919
+ ggml_compute_forward_sum(params, tensor->src[0], tensor);
14617
14920
  } break;
14618
14921
  case GGML_OP_SUM_ROWS:
14619
14922
  {
14620
- ggml_compute_forward_sum_rows(params, tensor->src0, tensor);
14923
+ ggml_compute_forward_sum_rows(params, tensor->src[0], tensor);
14621
14924
  } break;
14622
14925
  case GGML_OP_MEAN:
14623
14926
  {
14624
- ggml_compute_forward_mean(params, tensor->src0, tensor);
14927
+ ggml_compute_forward_mean(params, tensor->src[0], tensor);
14625
14928
  } break;
14626
14929
  case GGML_OP_ARGMAX:
14627
14930
  {
14628
- ggml_compute_forward_argmax(params, tensor->src0, tensor);
14931
+ ggml_compute_forward_argmax(params, tensor->src[0], tensor);
14629
14932
  } break;
14630
14933
  case GGML_OP_REPEAT:
14631
14934
  {
14632
- ggml_compute_forward_repeat(params, tensor->src0, tensor);
14935
+ ggml_compute_forward_repeat(params, tensor->src[0], tensor);
14633
14936
  } break;
14634
14937
  case GGML_OP_REPEAT_BACK:
14635
14938
  {
14636
- ggml_compute_forward_repeat_back(params, tensor->src0, tensor);
14939
+ ggml_compute_forward_repeat_back(params, tensor->src[0], tensor);
14637
14940
  } break;
14638
14941
  case GGML_OP_ABS:
14639
14942
  {
14640
- ggml_compute_forward_abs(params, tensor->src0, tensor);
14943
+ ggml_compute_forward_abs(params, tensor->src[0], tensor);
14641
14944
  } break;
14642
14945
  case GGML_OP_SGN:
14643
14946
  {
14644
- ggml_compute_forward_sgn(params, tensor->src0, tensor);
14947
+ ggml_compute_forward_sgn(params, tensor->src[0], tensor);
14645
14948
  } break;
14646
14949
  case GGML_OP_NEG:
14647
14950
  {
14648
- ggml_compute_forward_neg(params, tensor->src0, tensor);
14951
+ ggml_compute_forward_neg(params, tensor->src[0], tensor);
14649
14952
  } break;
14650
14953
  case GGML_OP_STEP:
14651
14954
  {
14652
- ggml_compute_forward_step(params, tensor->src0, tensor);
14955
+ ggml_compute_forward_step(params, tensor->src[0], tensor);
14653
14956
  } break;
14654
14957
  case GGML_OP_TANH:
14655
14958
  {
14656
- ggml_compute_forward_tanh(params, tensor->src0, tensor);
14959
+ ggml_compute_forward_tanh(params, tensor->src[0], tensor);
14657
14960
  } break;
14658
14961
  case GGML_OP_ELU:
14659
14962
  {
14660
- ggml_compute_forward_elu(params, tensor->src0, tensor);
14963
+ ggml_compute_forward_elu(params, tensor->src[0], tensor);
14661
14964
  } break;
14662
14965
  case GGML_OP_RELU:
14663
14966
  {
14664
- ggml_compute_forward_relu(params, tensor->src0, tensor);
14967
+ ggml_compute_forward_relu(params, tensor->src[0], tensor);
14665
14968
  } break;
14666
14969
  case GGML_OP_GELU:
14667
14970
  {
14668
- ggml_compute_forward_gelu(params, tensor->src0, tensor);
14971
+ ggml_compute_forward_gelu(params, tensor->src[0], tensor);
14669
14972
  } break;
14670
14973
  case GGML_OP_GELU_QUICK:
14671
14974
  {
14672
- ggml_compute_forward_gelu_quick(params, tensor->src0, tensor);
14975
+ ggml_compute_forward_gelu_quick(params, tensor->src[0], tensor);
14673
14976
  } break;
14674
14977
  case GGML_OP_SILU:
14675
14978
  {
14676
- ggml_compute_forward_silu(params, tensor->src0, tensor);
14979
+ ggml_compute_forward_silu(params, tensor->src[0], tensor);
14677
14980
  } break;
14678
14981
  case GGML_OP_SILU_BACK:
14679
14982
  {
14680
- ggml_compute_forward_silu_back(params, tensor->src0, tensor->src1, tensor);
14983
+ ggml_compute_forward_silu_back(params, tensor->src[0], tensor->src[1], tensor);
14681
14984
  } break;
14682
14985
  case GGML_OP_NORM:
14683
14986
  {
14684
- ggml_compute_forward_norm(params, tensor->src0, tensor);
14987
+ ggml_compute_forward_norm(params, tensor->src[0], tensor);
14685
14988
  } break;
14686
14989
  case GGML_OP_RMS_NORM:
14687
14990
  {
14688
- ggml_compute_forward_rms_norm(params, tensor->src0, tensor);
14991
+ ggml_compute_forward_rms_norm(params, tensor->src[0], tensor);
14689
14992
  } break;
14690
14993
  case GGML_OP_RMS_NORM_BACK:
14691
14994
  {
14692
- ggml_compute_forward_rms_norm_back(params, tensor->src0, tensor->src1, tensor);
14995
+ ggml_compute_forward_rms_norm_back(params, tensor->src[0], tensor->src[1], tensor);
14693
14996
  } break;
14694
14997
  case GGML_OP_MUL_MAT:
14695
14998
  {
14696
- ggml_compute_forward_mul_mat(params, tensor->src0, tensor->src1, tensor);
14999
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
14697
15000
  } break;
14698
15001
  case GGML_OP_OUT_PROD:
14699
15002
  {
14700
- ggml_compute_forward_out_prod(params, tensor->src0, tensor->src1, tensor);
15003
+ ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor);
14701
15004
  } break;
14702
15005
  case GGML_OP_SCALE:
14703
15006
  {
14704
- ggml_compute_forward_scale(params, tensor->src0, tensor->src1, tensor);
15007
+ ggml_compute_forward_scale(params, tensor->src[0], tensor->src[1], tensor);
14705
15008
  } break;
14706
15009
  case GGML_OP_SET:
14707
15010
  {
14708
- ggml_compute_forward_set(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15011
+ ggml_compute_forward_set(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14709
15012
  } break;
14710
15013
  case GGML_OP_CPY:
14711
15014
  {
14712
- ggml_compute_forward_cpy(params, tensor->src0, tensor);
15015
+ ggml_compute_forward_cpy(params, tensor->src[0], tensor);
14713
15016
  } break;
14714
15017
  case GGML_OP_CONT:
14715
15018
  {
14716
- ggml_compute_forward_cont(params, tensor->src0, tensor);
15019
+ ggml_compute_forward_cont(params, tensor->src[0], tensor);
14717
15020
  } break;
14718
15021
  case GGML_OP_RESHAPE:
14719
15022
  {
14720
- ggml_compute_forward_reshape(params, tensor->src0, tensor);
15023
+ ggml_compute_forward_reshape(params, tensor->src[0], tensor);
14721
15024
  } break;
14722
15025
  case GGML_OP_VIEW:
14723
15026
  {
14724
- ggml_compute_forward_view(params, tensor->src0);
15027
+ ggml_compute_forward_view(params, tensor->src[0]);
14725
15028
  } break;
14726
15029
  case GGML_OP_PERMUTE:
14727
15030
  {
14728
- ggml_compute_forward_permute(params, tensor->src0);
15031
+ ggml_compute_forward_permute(params, tensor->src[0]);
14729
15032
  } break;
14730
15033
  case GGML_OP_TRANSPOSE:
14731
15034
  {
14732
- ggml_compute_forward_transpose(params, tensor->src0);
15035
+ ggml_compute_forward_transpose(params, tensor->src[0]);
14733
15036
  } break;
14734
15037
  case GGML_OP_GET_ROWS:
14735
15038
  {
14736
- ggml_compute_forward_get_rows(params, tensor->src0, tensor->src1, tensor);
15039
+ ggml_compute_forward_get_rows(params, tensor->src[0], tensor->src[1], tensor);
14737
15040
  } break;
14738
15041
  case GGML_OP_GET_ROWS_BACK:
14739
15042
  {
14740
- ggml_compute_forward_get_rows_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15043
+ ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14741
15044
  } break;
14742
15045
  case GGML_OP_DIAG:
14743
15046
  {
14744
- ggml_compute_forward_diag(params, tensor->src0, tensor);
15047
+ ggml_compute_forward_diag(params, tensor->src[0], tensor);
14745
15048
  } break;
14746
15049
  case GGML_OP_DIAG_MASK_INF:
14747
15050
  {
14748
- ggml_compute_forward_diag_mask_inf(params, tensor->src0, tensor->src1, tensor);
15051
+ ggml_compute_forward_diag_mask_inf(params, tensor->src[0], tensor->src[1], tensor);
14749
15052
  } break;
14750
15053
  case GGML_OP_DIAG_MASK_ZERO:
14751
15054
  {
14752
- ggml_compute_forward_diag_mask_zero(params, tensor->src0, tensor->src1, tensor);
15055
+ ggml_compute_forward_diag_mask_zero(params, tensor->src[0], tensor->src[1], tensor);
14753
15056
  } break;
14754
15057
  case GGML_OP_SOFT_MAX:
14755
15058
  {
14756
- ggml_compute_forward_soft_max(params, tensor->src0, tensor);
15059
+ ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
14757
15060
  } break;
14758
15061
  case GGML_OP_SOFT_MAX_BACK:
14759
15062
  {
14760
- ggml_compute_forward_soft_max_back(params, tensor->src0, tensor->src1, tensor);
15063
+ ggml_compute_forward_soft_max_back(params, tensor->src[0], tensor->src[1], tensor);
14761
15064
  } break;
14762
15065
  case GGML_OP_ROPE:
14763
15066
  {
14764
- ggml_compute_forward_rope(params, tensor->src0, tensor->src1, tensor);
15067
+ ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
14765
15068
  } break;
14766
15069
  case GGML_OP_ROPE_BACK:
14767
15070
  {
14768
- ggml_compute_forward_rope_back(params, tensor->src0, tensor->src1, tensor);
15071
+ ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor);
14769
15072
  } break;
14770
15073
  case GGML_OP_ALIBI:
14771
15074
  {
14772
- ggml_compute_forward_alibi(params, tensor->src0, tensor->src1, tensor);
15075
+ ggml_compute_forward_alibi(params, tensor->src[0], tensor->src[1], tensor);
14773
15076
  } break;
14774
15077
  case GGML_OP_CLAMP:
14775
15078
  {
14776
- ggml_compute_forward_clamp(params, tensor->src0, tensor->src1, tensor);
15079
+ ggml_compute_forward_clamp(params, tensor->src[0], tensor->src[1], tensor);
14777
15080
  } break;
14778
15081
  case GGML_OP_CONV_1D:
14779
15082
  {
14780
- ggml_compute_forward_conv_1d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15083
+ ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14781
15084
  } break;
14782
15085
  case GGML_OP_CONV_2D:
14783
15086
  {
14784
- ggml_compute_forward_conv_2d(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15087
+ ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
15088
+ } break;
15089
+ case GGML_OP_POOL_1D:
15090
+ {
15091
+ ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor);
15092
+ } break;
15093
+ case GGML_OP_POOL_2D:
15094
+ {
15095
+ ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor);
14785
15096
  } break;
14786
15097
  case GGML_OP_FLASH_ATTN:
14787
15098
  {
14788
- const int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15099
+ const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
14789
15100
  GGML_ASSERT(t == 0 || t == 1);
14790
15101
  const bool masked = t != 0;
14791
- ggml_compute_forward_flash_attn(params, tensor->src0, tensor->src1, tensor->opt[0], masked, tensor);
15102
+ ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
14792
15103
  } break;
14793
15104
  case GGML_OP_FLASH_FF:
14794
15105
  {
14795
- ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
15106
+ ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
14796
15107
  } break;
14797
15108
  case GGML_OP_FLASH_ATTN_BACK:
14798
15109
  {
14799
- int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
15110
+ int32_t t = ggml_get_i32_1d(tensor->src[4], 0);
14800
15111
  GGML_ASSERT(t == 0 || t == 1);
14801
15112
  bool masked = t != 0;
14802
- ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
15113
+ ggml_compute_forward_flash_attn_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], masked, tensor);
14803
15114
  } break;
14804
15115
  case GGML_OP_WIN_PART:
14805
15116
  {
14806
- ggml_compute_forward_win_part(params, tensor->src0, tensor->opt[0], tensor);
15117
+ ggml_compute_forward_win_part(params, tensor->src[0], tensor->src[2], tensor);
14807
15118
  } break;
14808
15119
  case GGML_OP_WIN_UNPART:
14809
15120
  {
14810
- ggml_compute_forward_win_unpart(params, tensor->src0, tensor->opt[0], tensor);
15121
+ ggml_compute_forward_win_unpart(params, tensor->src[0], tensor->src[2], tensor);
14811
15122
  } break;
14812
15123
  case GGML_OP_MAP_UNARY:
14813
15124
  {
14814
- const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
14815
- ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
15125
+ const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->src[2]->data);
15126
+ ggml_compute_forward_map_unary(params, tensor->src[0], tensor, fun);
14816
15127
  }
14817
15128
  break;
14818
15129
  case GGML_OP_MAP_BINARY:
14819
15130
  {
14820
- const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
14821
- ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
15131
+ const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->src[2]->data);
15132
+ ggml_compute_forward_map_binary(params, tensor->src[0], tensor->src[1], tensor, fun);
14822
15133
  }
14823
15134
  break;
14824
15135
  case GGML_OP_MAP_CUSTOM1:
14825
15136
  {
14826
- const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->opt[0]->data);
14827
- ggml_compute_forward_map_custom1(params, tensor->src0, tensor, fun);
15137
+ const ggml_custom1_op_f32_t fun = *((ggml_custom1_op_f32_t *)tensor->src[2]->data);
15138
+ ggml_compute_forward_map_custom1(params, tensor->src[0], tensor, fun);
14828
15139
  }
14829
15140
  break;
14830
15141
  case GGML_OP_MAP_CUSTOM2:
14831
15142
  {
14832
- const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->opt[0]->data);
14833
- ggml_compute_forward_map_custom2(params, tensor->src0, tensor->src1, tensor, fun);
15143
+ const ggml_custom2_op_f32_t fun = *((ggml_custom2_op_f32_t *)tensor->src[2]->data);
15144
+ ggml_compute_forward_map_custom2(params, tensor->src[0], tensor->src[1], tensor, fun);
14834
15145
  }
14835
15146
  break;
14836
15147
  case GGML_OP_MAP_CUSTOM3:
14837
15148
  {
14838
- const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->opt[0]->data);
14839
- ggml_compute_forward_map_custom3(params, tensor->src0, tensor->src1, tensor->opt[1], tensor, fun);
15149
+ const ggml_custom3_op_f32_t fun = *((ggml_custom3_op_f32_t *)tensor->src[2]->data);
15150
+ ggml_compute_forward_map_custom3(params, tensor->src[0], tensor->src[1], tensor->src[3], tensor, fun);
14840
15151
  }
14841
15152
  break;
14842
15153
  case GGML_OP_CROSS_ENTROPY_LOSS:
14843
15154
  {
14844
- ggml_compute_forward_cross_entropy_loss(params, tensor->src0, tensor->src1, tensor);
15155
+ ggml_compute_forward_cross_entropy_loss(params, tensor->src[0], tensor->src[1], tensor);
14845
15156
  }
14846
15157
  break;
14847
15158
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
14848
15159
  {
14849
- ggml_compute_forward_cross_entropy_loss_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor);
15160
+ ggml_compute_forward_cross_entropy_loss_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
14850
15161
  }
14851
15162
  break;
14852
15163
  case GGML_OP_NONE:
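
Throughout ggml_compute_forward, the separate tensor->src0, tensor->src1 and tensor->opt[] operand fields are folded into a single tensor->src[] array; every former opt[i] access becomes src[2 + i] (for example, FLASH_ATTN's mask flag moves from opt[1] to src[3]). A small sketch of the mapping; node_opt is a hypothetical accessor, not part of ggml's API:

    // Old field         New slot
    // tensor->src0   -> tensor->src[0]
    // tensor->src1   -> tensor->src[1]
    // tensor->opt[i] -> tensor->src[2 + i]
    static struct ggml_tensor * node_opt(struct ggml_tensor * t, int i) {
        return t->src[2 + i]; // hypothetical helper for the old opt[i]
    }
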
@@ -14863,8 +15174,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
14863
15174
  ////////////////////////////////////////////////////////////////////////////////
14864
15175
 
14865
15176
  static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
14866
- struct ggml_tensor * src0 = tensor->src0;
14867
- struct ggml_tensor * src1 = tensor->src1;
15177
+ struct ggml_tensor * src0 = tensor->src[0];
15178
+ struct ggml_tensor * src1 = tensor->src[1];
14868
15179
 
14869
15180
  switch (tensor->op) {
14870
15181
  case GGML_OP_DUP:
@@ -14900,12 +15211,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
14900
15211
  src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
14901
15212
  }
14902
15213
  if (src1->grad) {
14903
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
14904
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
14905
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
14906
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
14907
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
14908
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
15214
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
15215
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
15216
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
15217
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
15218
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
15219
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
14909
15220
 
14910
15221
  struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx,
14911
15222
  tensor->grad,
@@ -15213,12 +15524,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15213
15524
  } break;
15214
15525
  case GGML_OP_SET:
15215
15526
  {
15216
- GGML_ASSERT(ggml_nelements(tensor->opt[0]) == 5);
15217
- GGML_ASSERT(tensor->opt[0]->type == GGML_TYPE_I32);
15218
- const size_t nb1 = (( int32_t * ) tensor->opt[0]->data)[0];
15219
- const size_t nb2 = (( int32_t * ) tensor->opt[0]->data)[1];
15220
- const size_t nb3 = (( int32_t * ) tensor->opt[0]->data)[2];
15221
- const size_t offset = (( int32_t * ) tensor->opt[0]->data)[3];
15527
+ GGML_ASSERT(ggml_nelements(tensor->src[2]) == 5);
15528
+ GGML_ASSERT(tensor->src[2]->type == GGML_TYPE_I32);
15529
+ const size_t nb1 = (( int32_t * ) tensor->src[2]->data)[0];
15530
+ const size_t nb2 = (( int32_t * ) tensor->src[2]->data)[1];
15531
+ const size_t nb3 = (( int32_t * ) tensor->src[2]->data)[2];
15532
+ const size_t offset = (( int32_t * ) tensor->src[2]->data)[3];
15222
15533
 
15223
15534
  struct ggml_tensor * tensor_grad_view = NULL;
15224
15535
 
@@ -15295,8 +15606,8 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15295
15606
  if (src0->grad) {
15296
15607
  size_t offset;
15297
15608
 
15298
- GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->opt[0]));
15299
- memcpy(&offset, tensor->opt[0]->data, sizeof(offset));
15609
+ GGML_ASSERT(sizeof(offset) <= ggml_nbytes(tensor->src[2]));
15610
+ memcpy(&offset, tensor->src[2]->data, sizeof(offset));
15300
15611
 
15301
15612
  size_t nb1 = tensor->nb[1];
15302
15613
  size_t nb2 = tensor->nb[2];
@@ -15323,7 +15634,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15323
15634
  {
15324
15635
  // necessary for llama
15325
15636
  if (src0->grad) {
15326
- int32_t * axes = (int32_t *) tensor->opt[0]->data;
15637
+ int32_t * axes = (int32_t *) tensor->src[2]->data;
15327
15638
  int axis0 = axes[0] & 0x3;
15328
15639
  int axis1 = axes[1] & 0x3;
15329
15640
  int axis2 = axes[2] & 0x3;
@@ -15427,17 +15738,19 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15427
15738
  // necessary for llama
15428
15739
  if (src0->grad) {
15429
15740
  assert(src1->type == GGML_TYPE_I32);
15430
- assert(ggml_nelements(src1) == 4);
15741
+ assert(ggml_nelements(src1) == 6);
15431
15742
  const int n_past = ((int32_t *) src1->data)[0];
15432
15743
  const int n_dims = ((int32_t *) src1->data)[1];
15433
15744
  const int mode = ((int32_t *) src1->data)[2];
15745
+ const int n_ctx = ((int32_t *) src1->data)[3];
15434
15746
  src0->grad = ggml_add_impl(ctx,
15435
15747
  src0->grad,
15436
15748
  ggml_rope_back(ctx,
15437
15749
  tensor->grad,
15438
15750
  n_past,
15439
15751
  n_dims,
15440
- mode),
15752
+ mode,
15753
+ n_ctx),
15441
15754
  inplace);
15442
15755
  }
15443
15756
  if (src1->grad) {
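
Note the asymmetry these asserts now encode: the forward ROPE parameter tensor holds 6 values and the backward pass unpacks and forwards n_ctx, but ggml_rope_back is still driven by only the 4 integer parameters, so the custom frequency settings are not yet threaded through the backward pass. The layouts, as implied by the asserts in this diff:

    // ROPE      src1: { n_past, n_dims, mode, n_ctx, bits(freq_base), bits(freq_scale) }
    // ROPE_BACK src1: { n_past, n_dims, mode, n_ctx }
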
@@ -15483,18 +15796,26 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15483
15796
  {
15484
15797
  GGML_ASSERT(false); // TODO: not implemented
15485
15798
  } break;
15799
+ case GGML_OP_POOL_1D:
15800
+ {
15801
+ GGML_ASSERT(false); // TODO: not implemented
15802
+ } break;
15803
+ case GGML_OP_POOL_2D:
15804
+ {
15805
+ GGML_ASSERT(false); // TODO: not implemented
15806
+ } break;
15486
15807
  case GGML_OP_FLASH_ATTN:
15487
15808
  {
15488
15809
  struct ggml_tensor * flash_grad = NULL;
15489
- if (src0->grad || src1->grad || tensor->opt[0]->grad) {
15490
- int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
15810
+ if (src0->grad || src1->grad || tensor->src[2]->grad) {
15811
+ int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
15491
15812
  GGML_ASSERT(t == 0 || t == 1);
15492
15813
  bool masked = t != 0;
15493
15814
  flash_grad =
15494
15815
  ggml_flash_attn_back(ctx,
15495
15816
  src0,
15496
15817
  src1,
15497
- tensor->opt[0],
15818
+ tensor->src[2],
15498
15819
  tensor->grad,
15499
15820
  masked);
15500
15821
  }
@@ -15591,7 +15912,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
15591
15912
  inplace);
15592
15913
  }
15593
15914
 
15594
- struct ggml_tensor * opt0 = tensor->opt[0];
15915
+ struct ggml_tensor * opt0 = tensor->src[2];
15595
15916
 
15596
15917
  if (opt0->grad) {
15597
15918
  struct ggml_tensor * grad_v = NULL;
@@ -15707,17 +16028,9 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
15707
16028
  }
15708
16029
  }
15709
16030
 
15710
- if (node->src0) {
15711
- ggml_visit_parents(cgraph, node->src0);
15712
- }
15713
-
15714
- if (node->src1) {
15715
- ggml_visit_parents(cgraph, node->src1);
15716
- }
15717
-
15718
- for (int i = 0; i < GGML_MAX_OPT; ++i) {
15719
- if (node->opt[i]) {
15720
- ggml_visit_parents(cgraph, node->opt[i]);
16031
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
16032
+ if (node->src[i]) {
16033
+ ggml_visit_parents(cgraph, node->src[i]);
15721
16034
  }
15722
16035
  }
15723
16036
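
ggml_visit_parents collapses the three separate traversals (src0, src1, the opt[] loop) into one loop bounded by the new GGML_MAX_SRC, so graph construction no longer needs to care which slot an operand came from. A sketch of the same pattern used to count a node's live operands; count_operands is illustrative, not ggml API:

    static int count_operands(const struct ggml_tensor * node) {
        int n = 0;
        for (int i = 0; i < GGML_MAX_SRC; ++i) {
            if (node->src[i] != NULL) {
                n++;
            }
        }
        return n;
    }
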
 
@@ -15772,9 +16085,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
15772
16085
  struct ggml_cgraph result = {
15773
16086
  /*.n_nodes =*/ 0,
15774
16087
  /*.n_leafs =*/ 0,
15775
- /*.n_threads =*/ GGML_DEFAULT_N_THREADS,
15776
- /*.work_size =*/ 0,
15777
- /*.work =*/ NULL,
15778
16088
  /*.nodes =*/ { NULL },
15779
16089
  /*.grads =*/ { NULL },
15780
16090
  /*.leafs =*/ { NULL },
@@ -15945,16 +16255,20 @@ void clear_numa_thread_affinity(void) {}
15945
16255
  #endif
15946
16256
 
15947
16257
  struct ggml_compute_state_shared {
15948
- struct ggml_cgraph * cgraph;
16258
+ const struct ggml_cgraph * cgraph;
16259
+ const struct ggml_cplan * cplan;
15949
16260
 
15950
16261
  int64_t perf_node_start_cycles;
15951
16262
  int64_t perf_node_start_time_us;
15952
16263
 
15953
- int n_threads;
16264
+ const int n_threads;
15954
16265
 
15955
16266
  // synchronization primitives
15956
16267
  atomic_int n_active; // num active threads
15957
16268
  atomic_int node_n; // active graph node
16269
+
16270
+ bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
16271
+ void * abort_callback_data;
15958
16272
  };
15959
16273
 
15960
16274
  struct ggml_compute_state {
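
ggml_compute_state_shared now holds a const cgraph plus the new ggml_cplan, and gains an abort hook: when abort_callback returns true, the worker threads stop walking the graph (see the GGML_EXIT_ABORTED path below). A minimal sketch of a callback a caller might install; the volatile flag and its signal-handler origin are assumptions, not part of this diff:

    #include <stdbool.h>

    static volatile bool g_should_abort = false; // e.g. flipped by a SIGINT handler

    // matches the bool (*)(void *) field added above
    static bool my_abort_callback(void * data) {
        (void) data;
        return g_should_abort;
    }
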
@@ -15974,14 +16288,22 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
15974
16288
 
15975
16289
  static thread_ret_t ggml_graph_compute_thread(void * data) {
15976
16290
  struct ggml_compute_state * state = (struct ggml_compute_state *) data;
15977
- struct ggml_cgraph * cgraph = state->shared->cgraph;
15978
16291
 
15979
- const int n_threads = state->shared->n_threads;
16292
+ const struct ggml_cgraph * cgraph = state->shared->cgraph;
16293
+ const struct ggml_cplan * cplan = state->shared->cplan;
16294
+
16295
+ const int * n_tasks_arr = cplan->n_tasks;
16296
+ const int n_threads = state->shared->n_threads;
16297
+
15980
16298
  set_numa_thread_affinity(state->ith, n_threads);
15981
16299
 
15982
16300
  int node_n = -1;
15983
16301
 
15984
16302
  while (true) {
16303
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
16304
+ state->shared->node_n += 1;
16305
+ return (thread_ret_t) GGML_EXIT_ABORTED;
16306
+ }
15985
16307
  if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
15986
16308
  // all other threads are finished and spinning
15987
16309
  // do finalize and init here so we don't have synchronize again
@@ -15989,18 +16311,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
15989
16311
  /*.type =*/ GGML_TASK_FINALIZE,
15990
16312
  /*.ith =*/ 0,
15991
16313
  /*.nth =*/ 0,
15992
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
15993
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
16314
+ /*.wsize =*/ cplan->work_size,
16315
+ /*.wdata =*/ cplan->work_data,
15994
16316
  };
15995
16317
 
15996
16318
  if (node_n != -1) {
15997
16319
  /* FINALIZE */
15998
16320
  struct ggml_tensor * node = state->shared->cgraph->nodes[node_n];
15999
16321
  if (GGML_OP_HAS_FINALIZE[node->op]) {
16000
- params.nth = node->n_tasks;
16322
+ params.nth = n_tasks_arr[node_n];
16001
16323
  ggml_compute_forward(&params, node);
16002
- ggml_graph_compute_perf_stats_node(node, state->shared);
16003
16324
  }
16325
+ ggml_graph_compute_perf_stats_node(node, state->shared);
16004
16326
  }
16005
16327
 
16006
16328
  // distribute new work or execute it direct if 1T
@@ -16008,11 +16330,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16008
16330
  GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes);
16009
16331
 
16010
16332
  struct ggml_tensor * node = cgraph->nodes[node_n];
16333
+ const int n_tasks = n_tasks_arr[node_n];
16011
16334
 
16012
16335
  state->shared->perf_node_start_cycles = ggml_perf_cycles();
16013
16336
  state->shared->perf_node_start_time_us = ggml_perf_time_us();
16014
16337
 
16015
- params.nth = node->n_tasks;
16338
+ params.nth = n_tasks;
16016
16339
 
16017
16340
  /* INIT */
16018
16341
  if (GGML_OP_HAS_INIT[node->op]) {
@@ -16020,7 +16343,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16020
16343
  ggml_compute_forward(&params, node);
16021
16344
  }
16022
16345
 
16023
- if (node->n_tasks == 1) {
16346
+ if (n_tasks == 1) {
16024
16347
  // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
16025
16348
  // they do something more efficient than spinning (?)
16026
16349
  params.type = GGML_TASK_COMPUTE;
@@ -16029,11 +16352,16 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16029
16352
  if (GGML_OP_HAS_FINALIZE[node->op]) {
16030
16353
  params.type = GGML_TASK_FINALIZE;
16031
16354
  ggml_compute_forward(&params, node);
16032
- ggml_graph_compute_perf_stats_node(node, state->shared);
16033
16355
  }
16356
+
16357
+ ggml_graph_compute_perf_stats_node(node, state->shared);
16034
16358
  } else {
16035
16359
  break;
16036
16360
  }
16361
+
16362
+ if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
16363
+ break;
16364
+ }
16037
16365
  }
16038
16366
 
16039
16367
  atomic_store(&state->shared->n_active, n_threads);
@@ -16042,7 +16370,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16042
16370
  // wait for other threads to finish
16043
16371
  const int last = node_n;
16044
16372
  do {
16045
- sched_yield();
16373
+ //sched_yield();
16046
16374
  node_n = atomic_load(&state->shared->node_n);
16047
16375
  } while (node_n == last);
16048
16376
  }
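
Below, the old single-call ggml_graph_compute(ctx, cgraph) gives way to a two-step flow: ggml_graph_plan walks the graph once to pick the per-node task counts (cplan->n_tasks[]) and to size the scratch buffer, and the caller then supplies work memory before computing. A hedged usage sketch; the compute entry point taking (cgraph, cplan) and the malloc'd buffer are assumptions based on these hunks, not shown verbatim in this diff:

    #include <stdlib.h>

    struct ggml_cplan cplan = ggml_graph_plan(&graph, /*n_threads=*/4);

    if (cplan.work_size > 0) {
        cplan.work_data = malloc(cplan.work_size); // caller-owned scratch
    }
    cplan.abort_callback      = my_abort_callback; // from the sketch above
    cplan.abort_callback_data = NULL;

    // assumed entry point consuming the plan in this version:
    // ggml_graph_compute(&graph, &cplan);

    free(cplan.work_data);
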
@@ -16052,366 +16380,398 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
16052
16380
 
16053
16381
  /* COMPUTE */
16054
16382
  struct ggml_tensor * node = cgraph->nodes[node_n];
16383
+ const int n_tasks = n_tasks_arr[node_n];
16055
16384
 
16056
16385
  struct ggml_compute_params params = {
16057
16386
  /*.type =*/ GGML_TASK_COMPUTE,
16058
16387
  /*.ith =*/ state->ith,
16059
- /*.nth =*/ node->n_tasks,
16060
- /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
16061
- /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
16388
+ /*.nth =*/ n_tasks,
16389
+ /*.wsize =*/ cplan->work_size,
16390
+ /*.wdata =*/ cplan->work_data,
16062
16391
  };
16063
16392
 
16064
- if (state->ith < node->n_tasks) {
16393
+ if (state->ith < n_tasks) {
16065
16394
  ggml_compute_forward(&params, node);
16066
16395
  }
16067
16396
  }
16068
16397
 
16069
- return 0;
16398
+ return GGML_EXIT_SUCCESS;
16070
16399
  }
16071
16400
 
16072
- void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
16073
- const int n_threads = cgraph->n_threads;
16401
+ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
16402
+ if (n_threads <= 0) {
16403
+ n_threads = GGML_DEFAULT_N_THREADS;
16404
+ }
16074
16405
 
16075
- struct ggml_compute_state_shared state_shared = {
16076
- /*.cgraph =*/ cgraph,
16077
- /*.perf_node_start_cycles =*/ 0,
16078
- /*.perf_node_start_time_us =*/ 0,
16079
- /*.n_threads =*/ n_threads,
16080
- /*.n_active =*/ n_threads,
16081
- /*.node_n =*/ -1,
16082
- };
16083
- struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
16406
+ size_t work_size = 0;
16084
16407
 
16085
- // initialize tasks + work buffer
16086
- {
16087
- size_t work_size = 0;
16408
+ struct ggml_cplan cplan;
16409
+ memset(&cplan, 0, sizeof(struct ggml_cplan));
16088
16410
 
16089
- // thread scheduling for the different operations
16090
- for (int i = 0; i < cgraph->n_nodes; i++) {
16091
- struct ggml_tensor * node = cgraph->nodes[i];
16411
+ // thread scheduling for the different operations + work buffer size estimation
16412
+ for (int i = 0; i < cgraph->n_nodes; i++) {
16413
+ int n_tasks = 1;
16092
16414
 
16093
- switch (node->op) {
16094
- case GGML_OP_CPY:
16095
- case GGML_OP_DUP:
16096
- {
16097
- node->n_tasks = n_threads;
16415
+ struct ggml_tensor * node = cgraph->nodes[i];
16098
16416
 
16099
- size_t cur = 0;
16100
- if (ggml_is_quantized(node->type)) {
16101
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads;
16102
- }
16417
+ switch (node->op) {
16418
+ case GGML_OP_CPY:
16419
+ case GGML_OP_DUP:
16420
+ {
16421
+ n_tasks = n_threads;
16422
+
16423
+ size_t cur = 0;
16424
+ if (ggml_is_quantized(node->type)) {
16425
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks;
16426
+ }
16103
16427
 
16104
- work_size = MAX(work_size, cur);
16105
- } break;
16106
- case GGML_OP_ADD:
16107
- case GGML_OP_ADD1:
16108
- {
16109
- node->n_tasks = n_threads;
16428
+ work_size = MAX(work_size, cur);
16429
+ } break;
16430
+ case GGML_OP_ADD:
16431
+ case GGML_OP_ADD1:
16432
+ {
16433
+ n_tasks = n_threads;
16110
16434
 
16111
- size_t cur = 0;
16435
+ size_t cur = 0;
16112
16436
 
16113
- if (ggml_is_quantized(node->src0->type)) {
16114
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
16115
- }
16437
+ if (ggml_is_quantized(node->src[0]->type)) {
16438
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[0]->ne[0] * n_tasks;
16439
+ }
16116
16440
 
16117
- work_size = MAX(work_size, cur);
16118
- } break;
16119
- case GGML_OP_ACC:
16120
- {
16121
- node->n_tasks = n_threads;
16441
+ work_size = MAX(work_size, cur);
16442
+ } break;
16443
+ case GGML_OP_ACC:
16444
+ {
16445
+ n_tasks = n_threads;
16122
16446
 
16123
- size_t cur = 0;
16447
+ size_t cur = 0;
16124
16448
 
16125
- if (ggml_is_quantized(node->src0->type)) {
16126
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads;
16127
- }
16449
+ if (ggml_is_quantized(node->src[0]->type)) {
16450
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src[1]->ne[0] * n_tasks;
16451
+ }
16452
+
16453
+ work_size = MAX(work_size, cur);
16454
+ } break;
16455
+ case GGML_OP_SUB:
16456
+ case GGML_OP_DIV:
16457
+ case GGML_OP_SQR:
16458
+ case GGML_OP_SQRT:
16459
+ case GGML_OP_LOG:
16460
+ case GGML_OP_SUM:
16461
+ case GGML_OP_SUM_ROWS:
16462
+ case GGML_OP_MEAN:
16463
+ case GGML_OP_ARGMAX:
16464
+ case GGML_OP_REPEAT:
16465
+ case GGML_OP_REPEAT_BACK:
16466
+ case GGML_OP_ABS:
16467
+ case GGML_OP_SGN:
16468
+ case GGML_OP_NEG:
16469
+ case GGML_OP_STEP:
16470
+ case GGML_OP_TANH:
16471
+ case GGML_OP_ELU:
16472
+ case GGML_OP_RELU:
16473
+ {
16474
+ n_tasks = 1;
16475
+ } break;
16476
+ case GGML_OP_MUL:
16477
+ case GGML_OP_GELU:
16478
+ case GGML_OP_GELU_QUICK:
16479
+ case GGML_OP_SILU:
16480
+ case GGML_OP_SILU_BACK:
16481
+ case GGML_OP_NORM:
16482
+ case GGML_OP_RMS_NORM:
16483
+ case GGML_OP_RMS_NORM_BACK:
16484
+ {
16485
+ n_tasks = n_threads;
16486
+ } break;
16487
+ case GGML_OP_MUL_MAT:
16488
+ case GGML_OP_OUT_PROD:
16489
+ {
16490
+ n_tasks = n_threads;
16491
+
16492
+ // TODO: use different scheduling for different matrix sizes
16493
+ //const int nr0 = ggml_nrows(node->src[0]);
16494
+ //const int nr1 = ggml_nrows(node->src[1]);
16495
+
16496
+ //n_tasks = MIN(n_threads, MAX(1, nr0/128));
16497
+ //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks);
16128
16498
 
16129
- work_size = MAX(work_size, cur);
16130
- } break;
16131
- case GGML_OP_SUB:
16132
- case GGML_OP_DIV:
16133
- case GGML_OP_SQR:
16134
- case GGML_OP_SQRT:
16135
- case GGML_OP_LOG:
16136
- case GGML_OP_SUM:
16137
- case GGML_OP_SUM_ROWS:
16138
- case GGML_OP_MEAN:
16139
- case GGML_OP_ARGMAX:
16140
- case GGML_OP_REPEAT:
16141
- case GGML_OP_REPEAT_BACK:
16142
- case GGML_OP_ABS:
16143
- case GGML_OP_SGN:
16144
- case GGML_OP_NEG:
16145
- case GGML_OP_STEP:
16146
- case GGML_OP_TANH:
16147
- case GGML_OP_ELU:
16148
- case GGML_OP_RELU:
16149
- {
16150
- node->n_tasks = 1;
16151
- } break;
16152
- case GGML_OP_MUL:
16153
- case GGML_OP_GELU:
16154
- case GGML_OP_GELU_QUICK:
16155
- case GGML_OP_SILU:
16156
- case GGML_OP_SILU_BACK:
16157
- case GGML_OP_NORM:
16158
- case GGML_OP_RMS_NORM:
16159
- case GGML_OP_RMS_NORM_BACK:
16160
- {
16161
- node->n_tasks = n_threads;
16162
- } break;
16163
- case GGML_OP_MUL_MAT:
16164
- case GGML_OP_OUT_PROD:
16165
- {
16166
- node->n_tasks = n_threads;
16167
-
16168
- // TODO: use different scheduling for different matrix sizes
16169
- //const int nr0 = ggml_nrows(node->src0);
16170
- //const int nr1 = ggml_nrows(node->src1);
16171
-
16172
- //node->n_tasks = MIN(n_threads, MAX(1, nr0/128));
16173
- //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks);
16174
-
16175
- size_t cur = 0;
16176
- const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type;
16499
+ size_t cur = 0;
16500
+ const enum ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
16177
16501
 
16178
16502
  #if defined(GGML_USE_CUBLAS)
16179
- if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
16180
- node->n_tasks = 1; // TODO: this actually is doing nothing
16181
- // the threads are still spinning
16182
- }
16183
- else
16503
+ if (ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) {
16504
+ n_tasks = 1; // TODO: this actually is doing nothing
16505
+ // the threads are still spinning
16506
+ } else
16184
16507
  #elif defined(GGML_USE_CLBLAST)
16185
- if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) {
16186
- node->n_tasks = 1; // TODO: this actually is doing nothing
16187
- // the threads are still spinning
16188
- cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node);
16189
- }
16190
- else
16508
+ if (ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) {
16509
+ n_tasks = 1; // TODO: this actually is doing nothing
16510
+ // the threads are still spinning
16511
+ cur = ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node);
16512
+ } else
16191
16513
  #endif
16192
16514
  #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
16193
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
16194
- node->n_tasks = 1; // TODO: this actually is doing nothing
16195
- // the threads are still spinning
16196
- if (node->src0->type != GGML_TYPE_F32) {
16197
- // here we need memory just for single 2D matrix from src0
16198
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
16199
- }
16200
- } else
16201
- #endif
16202
- if (node->src1->type != vec_dot_type) {
16203
- cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type];
16204
- } else {
16205
- cur = 0;
16515
+ if (ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) {
16516
+ n_tasks = 1; // TODO: this actually is doing nothing
16517
+ // the threads are still spinning
16518
+ if (node->src[0]->type != GGML_TYPE_F32) {
16519
+ // here we need memory just for single 2D matrix from src0
16520
+ cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src[0]->ne[0]*node->src[0]->ne[1]);
16206
16521
  }
16522
+ } else
16523
+ #endif
16524
+ if (node->src[1]->type != vec_dot_type) {
16525
+ cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src[1])/GGML_BLCK_SIZE[vec_dot_type];
16526
+ } else {
16527
+ cur = 0;
16528
+ }
16207
16529
 
16208
- work_size = MAX(work_size, cur);
16209
- } break;
16210
- case GGML_OP_SCALE:
16211
- {
16212
- node->n_tasks = 1;
16213
- } break;
16214
- case GGML_OP_SET:
16215
- case GGML_OP_CONT:
16216
- case GGML_OP_RESHAPE:
16217
- case GGML_OP_VIEW:
16218
- case GGML_OP_PERMUTE:
16219
- case GGML_OP_TRANSPOSE:
16220
- case GGML_OP_GET_ROWS:
16221
- case GGML_OP_GET_ROWS_BACK:
16222
- case GGML_OP_DIAG:
16223
- case GGML_OP_DIAG_MASK_ZERO:
- {
- node->n_tasks = 1;
- } break;
- case GGML_OP_DIAG_MASK_INF:
- case GGML_OP_SOFT_MAX:
- case GGML_OP_SOFT_MAX_BACK:
- case GGML_OP_ROPE:
- case GGML_OP_ROPE_BACK:
- {
- node->n_tasks = n_threads;
- } break;
- case GGML_OP_ALIBI:
- {
- node->n_tasks = 1; //TODO
- } break;
- case GGML_OP_CLAMP:
- {
- node->n_tasks = 1; //TODO
- } break;
- case GGML_OP_CONV_1D:
- {
- node->n_tasks = n_threads;
-
- GGML_ASSERT(node->src0->ne[3] == 1);
- GGML_ASSERT(node->src1->ne[2] == 1);
- GGML_ASSERT(node->src1->ne[3] == 1);
-
- size_t cur = 0;
- const int nk = node->src0->ne[0];
-
- if (node->src0->type == GGML_TYPE_F16 &&
- node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(ggml_fp16_t)*(
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
- );
- } else if (node->src0->type == GGML_TYPE_F32 &&
- node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*(
- nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] +
- ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1]
- );
- } else {
- GGML_ASSERT(false);
- }
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_SCALE:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_SET:
+ case GGML_OP_CONT:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_GET_ROWS:
+ case GGML_OP_GET_ROWS_BACK:
+ case GGML_OP_DIAG:
+ case GGML_OP_DIAG_MASK_ZERO:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_DIAG_MASK_INF:
+ case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
+ case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
+ {
+ n_tasks = n_threads;
+ } break;
+ case GGML_OP_ALIBI:
+ {
+ n_tasks = 1; //TODO
+ } break;
+ case GGML_OP_CLAMP:
+ {
+ n_tasks = 1; //TODO
+ } break;
+ case GGML_OP_CONV_1D:
+ {
+ n_tasks = n_threads;
+
+ GGML_ASSERT(node->src[0]->ne[3] == 1);
+ GGML_ASSERT(node->src[1]->ne[2] == 1);
+ GGML_ASSERT(node->src[1]->ne[3] == 1);
+
+ size_t cur = 0;
+ const int nk = node->src[0]->ne[0];
+
+ if (node->src[0]->type == GGML_TYPE_F16 &&
+ node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(ggml_fp16_t)*(
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
+ );
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
+ node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*(
+ nk*ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] +
+ ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1]
+ );
+ } else {
+ GGML_ASSERT(false);
+ }

- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_CONV_2D:
- {
- node->n_tasks = n_threads;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_CONV_2D:
+ {
+ n_tasks = n_threads;
+
+ const int64_t ne00 = node->src[0]->ne[0]; // W
+ const int64_t ne01 = node->src[0]->ne[1]; // H
+ const int64_t ne02 = node->src[0]->ne[2]; // C
+ const int64_t ne03 = node->src[0]->ne[3]; // N
+
+ const int64_t ne10 = node->src[1]->ne[0]; // W
+ const int64_t ne11 = node->src[1]->ne[1]; // H
+ const int64_t ne12 = node->src[1]->ne[2]; // C
+
+ const int64_t ne0 = node->ne[0];
+ const int64_t ne1 = node->ne[1];
+ const int64_t ne2 = node->ne[2];
+ const int64_t nk = ne00*ne01;
+ const int64_t ew0 = nk * ne02;
+
+ UNUSED(ne03);
+ UNUSED(ne2);
+
+ size_t cur = 0;
+
+ if (node->src[0]->type == GGML_TYPE_F16 &&
+ node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
+ } else if (node->src[0]->type == GGML_TYPE_F32 &&
+ node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(float)* (ne10*ne11*ne12);
+ } else {
+ GGML_ASSERT(false);
+ }

- GGML_ASSERT(node->src1->ne[3] == 1);
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_POOL_1D:
+ case GGML_OP_POOL_2D:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_FLASH_ATTN:
+ {
+ n_tasks = n_threads;

- const int64_t ne00 = node->src0->ne[0]; // W
- const int64_t ne01 = node->src0->ne[1]; // H
- const int64_t ne02 = node->src0->ne[2]; // C
- const int64_t ne03 = node->src0->ne[3]; // N
+ size_t cur = 0;

- const int64_t ne10 = node->src1->ne[0]; // W
- const int64_t ne11 = node->src1->ne[1]; // H
- const int64_t ne12 = node->src1->ne[2]; // C
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);

- const int64_t nk = ne00*ne01;
+ if (node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
+ }

- UNUSED(ne02);
- UNUSED(ne03);
- UNUSED(nk);
+ if (node->src[1]->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
+ }

- size_t cur = 0;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_FLASH_FF:
+ {
+ n_tasks = n_threads;

- if (node->src0->type == GGML_TYPE_F16 &&
- node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12);
- } else if (node->src0->type == GGML_TYPE_F32 &&
- node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)* (ne10*ne11*ne12);
- } else {
- GGML_ASSERT(false);
- }
+ size_t cur = 0;

- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_FLASH_ATTN:
- {
- node->n_tasks = n_threads;
+ if (node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
+ }

- size_t cur = 0;
+ if (node->src[1]->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2
+ }

- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_FLASH_ATTN_BACK:
+ {
+ n_tasks = n_threads;

- if (node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
- }
+ size_t cur = 0;

- if (node->src1->type == GGML_TYPE_F16) {
- cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2
- }
+ const int64_t D = node->src[0]->ne[0];
+ const int64_t ne11 = ggml_up(node->src[1]->ne[1], GGML_SOFT_MAX_UNROLL);
+ const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+ if (node->src[1]->type == GGML_TYPE_F32) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ }

- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_FLASH_FF:
- {
- node->n_tasks = n_threads;
+ if (node->src[1]->type == GGML_TYPE_F16) {
+ cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1)
+ cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2
+ }

- size_t cur = 0;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_WIN_PART:
+ case GGML_OP_WIN_UNPART:
+ case GGML_OP_MAP_UNARY:
+ case GGML_OP_MAP_BINARY:
+ case GGML_OP_MAP_CUSTOM1:
+ case GGML_OP_MAP_CUSTOM2:
+ case GGML_OP_MAP_CUSTOM3:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ {
+ n_tasks = n_threads;

- if (node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
- }
+ size_t cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks);

- if (node->src1->type == GGML_TYPE_F16) {
- cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
- }
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ {
+ n_tasks = n_threads;

- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_FLASH_ATTN_BACK:
- {
- node->n_tasks = n_threads;
+ size_t cur = ggml_type_size(node->type)*node->src[0]->ne[0]*n_tasks;

- size_t cur = 0;
+ work_size = MAX(work_size, cur);
+ } break;
+ case GGML_OP_NONE:
+ {
+ n_tasks = 1;
+ } break;
+ case GGML_OP_COUNT:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }

- const int64_t D = node->src0->ne[0];
- const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
- const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
- if (node->src1->type == GGML_TYPE_F32) {
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
- }
+ cplan.n_tasks[i] = n_tasks;
+ }

- if (node->src1->type == GGML_TYPE_F16) {
- cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
- cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
- }
+ if (work_size > 0) {
+ work_size += CACHE_LINE_SIZE*(n_threads - 1);
+ }

- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_WIN_PART:
- case GGML_OP_WIN_UNPART:
- case GGML_OP_MAP_UNARY:
- case GGML_OP_MAP_BINARY:
- case GGML_OP_MAP_CUSTOM1:
- case GGML_OP_MAP_CUSTOM2:
- case GGML_OP_MAP_CUSTOM3:
- {
- node->n_tasks = 1;
- } break;
- case GGML_OP_CROSS_ENTROPY_LOSS:
- {
- node->n_tasks = n_threads;
-
- size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks);
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
- {
- node->n_tasks = n_threads;
-
- size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks;
-
- work_size = MAX(work_size, cur);
- } break;
- case GGML_OP_NONE:
- {
- node->n_tasks = 1;
- } break;
- case GGML_OP_COUNT:
- {
- GGML_ASSERT(false);
- } break;
- }
- }
+ cplan.n_threads = n_threads;
+ cplan.work_size = work_size;
+ cplan.work_data = NULL;

- if (cgraph->work != NULL && work_size > cgraph->work_size) {
- GGML_ASSERT(false); // TODO: better handling
- }
+ return cplan;
+ }
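The hunk above ends the new ggml_graph_plan(): instead of stamping node->n_tasks onto each tensor, planning now returns a struct ggml_cplan carrying the per-node task counts and the required scratch size, and leaves work_data for the caller to fill. A minimal sketch of driving the split API by hand, assuming the 0.3.4 ggml.h is included and gf points at an already-built graph (the helper name compute_by_hand is illustrative, not from the source):

    #include <stdlib.h>

    static void compute_by_hand(struct ggml_cgraph * gf) {
        // plan: fills cplan.n_tasks[] and cplan.work_size; work_data starts NULL
        struct ggml_cplan cplan = ggml_graph_plan(gf, /*n_threads=*/4);

        // the caller now owns the scratch buffer the plan asked for
        void * work = NULL;
        if (cplan.work_size > 0) {
            work = malloc(cplan.work_size);
            cplan.work_data = work;
        }

        ggml_graph_compute(gf, &cplan); // in 0.3.4 this returns a status code

        free(work);
    }

The CACHE_LINE_SIZE*(n_threads - 1) padding added to work_size above appears to exist so that each worker thread can be handed its own cache-line-aligned slice of the buffer.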

- if (work_size > 0 && cgraph->work == NULL) {
- cgraph->work_size = work_size + CACHE_LINE_SIZE*(n_threads - 1);
+ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
+ {
+ GGML_ASSERT(cplan);
+ GGML_ASSERT(cplan->n_threads > 0);
+
+ if (cplan->work_size > 0) {
+ GGML_ASSERT(cplan->work_data);
+ }

- GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, cgraph->work_size);
- cgraph->work = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cgraph->work_size);
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ if (cgraph->nodes[i]->op != GGML_OP_NONE) {
+ GGML_ASSERT(cplan->n_tasks[i] > 0);
+ }
  }
  }

+ const int n_threads = cplan->n_threads;
+
+ struct ggml_compute_state_shared state_shared = {
+ /*.cgraph =*/ cgraph,
+ /*.cgraph_plan =*/ cplan,
+ /*.perf_node_start_cycles =*/ 0,
+ /*.perf_node_start_time_us =*/ 0,
+ /*.n_threads =*/ n_threads,
+ /*.n_active =*/ n_threads,
+ /*.node_n =*/ -1,
+ /*.abort_callback =*/ NULL,
+ /*.abort_callback_data =*/ NULL,
+ };
+ struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
+
  // create thread pool
  if (n_threads > 1) {
  for (int j = 1; j < n_threads; ++j) {
@@ -16432,12 +16792,12 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  const int64_t perf_start_time_us = ggml_perf_time_us();

  // this is a work thread too
- ggml_graph_compute_thread(&workers[0]);
+ int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);

  // don't leave affinity set on the main thread
  clear_numa_thread_affinity();

- // join thread pool
+ // join or kill thread pool
  if (n_threads > 1) {
  for (int j = 1; j < n_threads; j++) {
  const int rc = ggml_thread_join(workers[j].thrd, NULL);
@@ -16461,6 +16821,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
  (double) perf_time_us_cur / 1000.0,
  (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
  }
+
+ return compute_status;
  }

  void ggml_graph_reset(struct ggml_cgraph * cgraph) {
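ggml_graph_compute() now reports the status returned by the main worker thread, and the shared state gained abort_callback slots, so an interrupted run is observable by the caller. A sketch, reusing gf and cplan from the earlier example and assuming <stdio.h>; treating zero as success is an assumption here, since the status constants themselves are defined outside these hunks:

    const int status = ggml_graph_compute(gf, &cplan);
    if (status != 0) {
        fprintf(stderr, "ggml: graph compute stopped early (status %d)\n", status);
    }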
@@ -16473,6 +16835,17 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
  }
  }

+ void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads) {
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, n_threads);
+
+ struct ggml_tensor * buf = ggml_new_tensor_1d(ctx, GGML_TYPE_I8, cplan.work_size);
+ GGML_ASSERT(buf);
+
+ cplan.work_data = buf->data;
+
+ ggml_graph_compute(cgraph, &cplan);
+ }
+
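For callers that do not want to manage scratch memory themselves, the wrapper just added restores the old one-call flow: the plan's work buffer is allocated as a GGML_TYPE_I8 tensor inside the supplied context, so it is released together with the context. A usage sketch, assuming a valid ctx and a built graph gf as before:

    // plan + context-owned scratch tensor + compute, in one call
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/8);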
  struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) {
  for (int i = 0; i < cgraph->n_leafs; i++) {
  struct ggml_tensor * leaf = cgraph->leafs[i];
@@ -16511,22 +16884,18 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char
  const int64_t * ne = tensor->ne;
  const size_t * nb = tensor->nb;

- fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n",
+ fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n",
  arg,
  ggml_type_name(tensor->type),
  ggml_op_name (tensor->op),
  tensor->n_dims,
  ne[0], ne[1], ne[2], ne[3],
  nb[0], nb[1], nb[2], nb[3],
- tensor->n_tasks,
  tensor->data,
  tensor->name);
  }

  void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
- //assert(cgraph->work == NULL);
- //assert(cgraph->work_size == 0);
-
  uint64_t size_eval = 0;

  // compute size of intermediate results
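Since task counts now live in the plan rather than on tensors, the exported-node record above loses its %8d n_tasks column; graphs serialized by 0.3.4 carry one field fewer per tensor row. A sketch of the resulting layout (column names are descriptive, not taken from the source):

    // one exported tensor per line, per the new fprintf format:
    // ARG  TYPE  OP  N_DIMS  NE0 NE1 NE2 NE3  NB0 NB1 NB2 NB3  DATA  NAME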
@@ -16555,8 +16924,8 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
  ggml_graph_export_leaf(cgraph->leafs[i], fout);

  GGML_ASSERT(cgraph->leafs[i]->op == GGML_OP_NONE);
- GGML_ASSERT(cgraph->leafs[i]->src0 == NULL);
- GGML_ASSERT(cgraph->leafs[i]->src1 == NULL);
+ GGML_ASSERT(cgraph->leafs[i]->src[0] == NULL);
+ GGML_ASSERT(cgraph->leafs[i]->src[1] == NULL);
  }

  // header
@@ -16567,17 +16936,9 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
  for (int i = 0; i < cgraph->n_nodes; ++i) {
  ggml_graph_export_node(cgraph->nodes[i], "DST", fout);

- if (cgraph->nodes[i]->src0) {
- ggml_graph_export_node(cgraph->nodes[i]->src0, "SRC0", fout);
- }
-
- if (cgraph->nodes[i]->src1) {
- ggml_graph_export_node(cgraph->nodes[i]->src1, "SRC1", fout);
- }
-
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- if (cgraph->nodes[i]->opt[j]) {
- ggml_graph_export_node(cgraph->nodes[i]->opt[j], "OPT", fout);
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ if (cgraph->nodes[i]->src[j]) {
+ ggml_graph_export_node(cgraph->nodes[i]->src[j], "SRC", fout);
  }
  }

@@ -16668,16 +17029,13 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {

  // output the op arguments
  {
- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
-
- args[0] = tensor->src0;
- args[1] = tensor->src1;
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };

- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- args[2 + j] = tensor->opt[j];
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ args[j] = tensor->src[j];
  }

- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
  if (args[j]) {
  int32_t idx = -1;

@@ -16895,12 +17253,12 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **

  const char * ptr_name = ptr; ptr += GGML_MAX_NAME;

- const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += (2 + GGML_MAX_OPT)*sizeof(int32_t);
+ const int32_t * ptr_arg_idx = (const int32_t *) ptr; ptr += GGML_MAX_SRC*sizeof(int32_t);

- struct ggml_tensor * args[2 + GGML_MAX_OPT] = { NULL };
+ struct ggml_tensor * args[GGML_MAX_SRC] = { NULL };

  // parse args
- for (int j = 0; j < 2 + GGML_MAX_OPT; ++j) {
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
  const int32_t arg_idx = ptr_arg_idx[j];

  if (arg_idx == -1) {
@@ -16957,11 +17315,8 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
  tensor->nb[j] = nb[j];
  }

- tensor->src0 = args[0];
- tensor->src1 = args[1];
-
- for (int j = 0; j < GGML_MAX_OPT; ++j) {
- tensor->opt[j] = args[2 + j];
+ for (int j = 0; j < GGML_MAX_SRC; ++j) {
+ tensor->src[j] = args[j];
  }

  result.nodes[i] = tensor;
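This import loop is the reader-side half of the wider src0/src1/opt[] to src[GGML_MAX_SRC] consolidation that runs through the export, import, and dot-dump hunks here: every consumer now walks one uniform array instead of three special cases. A sketch of the pattern, assuming a struct ggml_tensor * tensor populated as above and <stdio.h>:

    // unused input slots stay NULL
    for (int j = 0; j < GGML_MAX_SRC; ++j) {
        if (tensor->src[j] != NULL) {
            printf("src %d -> %s\n", j, tensor->src[j]->name);
        }
    }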
@@ -16979,9 +17334,6 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {

  GGML_PRINT("=== GRAPH ===\n");

- GGML_PRINT_DEBUG("n_threads = %d\n", cgraph->n_threads);
- GGML_PRINT_DEBUG("total work size = %zu bytes\n", cgraph->work_size);
-
  GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
  for (int i = 0; i < cgraph->n_nodes; i++) {
  struct ggml_tensor * node = cgraph->nodes[i];
@@ -17160,19 +17512,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
  for (int i = 0; i < gb->n_nodes; i++) {
  struct ggml_tensor * node = gb->nodes[i];

- if (node->src0) {
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src0, "x");
- }
-
- if (node->src1) {
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->src1, "y");
- }
-
- for (int j = 0; j < GGML_MAX_OPT; j++) {
- if (node->opt[j]) {
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j]) {
  char label[16];
- snprintf(label, sizeof(label), "opt %d", j);
- ggml_graph_dump_dot_node_edge(fp, gb, node, node->opt[j], label);
+ snprintf(label, sizeof(label), "src %d", j);
+ ggml_graph_dump_dot_node_edge(fp, gb, node, node->src[j], label);
  }
  }
  }
@@ -17180,19 +17524,11 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
  for (int i = 0; i < gb->n_leafs; i++) {
  struct ggml_tensor * node = gb->leafs[i];

- if (node->src0) {
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src0, "x");
- }
-
- if (node->src1) {
- ggml_graph_dump_dot_leaf_edge(fp, node, node->src1, "y");
- }
-
- for (int j = 0; j < GGML_MAX_OPT; j++) {
- if (node->opt[j]) {
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j]) {
  char label[16];
- snprintf(label, sizeof(label), "opt %d", j);
- ggml_graph_dump_dot_leaf_edge(fp, node, node->opt[j], label);
+ snprintf(label, sizeof(label), "src %d", j);
+ ggml_graph_dump_dot_leaf_edge(fp, node, node->src[j], label);
  }
  }
  }
@@ -17254,9 +17590,6 @@ static enum ggml_opt_result ggml_opt_adam(
  struct ggml_cgraph * gb) {
  GGML_ASSERT(ggml_is_scalar(f));

- gf->n_threads = params.n_threads;
- gb->n_threads = params.n_threads;
-
  // these will store the parameters we want to optimize
  struct ggml_tensor * ps[GGML_MAX_PARAMS];

@@ -17303,7 +17636,8 @@ static enum ggml_opt_result ggml_opt_adam(
  // compute the function value
  ggml_graph_reset (gf);
  ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

  opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
  opt->adam.fx_best = opt->adam.fx_prev;
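The optimizer hunks from here on stop storing the thread count on the graphs (gf->n_threads and gb->n_threads no longer exist) and instead pass it with every evaluation. In sketch form:

    // 0.3.2: thread count was graph state
    //   gb->n_threads = params.n_threads;
    //   ggml_graph_compute(ctx, gb);
    // 0.3.4: thread count is an argument of each call
    ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

Because a fresh plan is built per call, the same graph can be evaluated with different thread counts without mutating shared state.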
@@ -17383,7 +17717,8 @@ static enum ggml_opt_result ggml_opt_adam(

  ggml_graph_reset (gf);
  ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

  const float fx = ggml_get_f32_1d(f, 0);

@@ -17505,7 +17840,8 @@ static enum ggml_opt_result linesearch_backtracking(

  ggml_graph_reset (gf);
  ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params->n_threads);

  ggml_opt_get_grad(np, ps, g);

@@ -17573,9 +17909,6 @@ static enum ggml_opt_result ggml_opt_lbfgs(
  }
  }

- gf->n_threads = params.n_threads;
- gb->n_threads = params.n_threads;
-
  const int m = params.lbfgs.m;

  // these will store the parameters we want to optimize
@@ -17627,7 +17960,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(

  ggml_graph_reset (gf);
  ggml_set_f32 (f->grad, 1.0f);
- ggml_graph_compute(ctx, gb);
+
+ ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);

  ggml_opt_get_grad(np, ps, g);