cui-llama.rn 1.2.0 → 1.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. package/README.md +2 -0
  2. package/android/src/main/CMakeLists.txt +2 -2
  3. package/android/src/main/java/com/rnllama/LlamaContext.java +31 -9
  4. package/android/src/main/java/com/rnllama/RNLlama.java +39 -0
  5. package/android/src/newarch/java/com/rnllama/RNLlamaModule.java +5 -0
  6. package/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +5 -0
  7. package/cpp/common.cpp +36 -1
  8. package/cpp/common.h +5 -1
  9. package/cpp/ggml-aarch64.c +2 -11
  10. package/cpp/ggml-alloc.h +1 -1
  11. package/cpp/ggml-backend-impl.h +151 -78
  12. package/cpp/{ggml-backend.c → ggml-backend.cpp} +565 -269
  13. package/cpp/ggml-backend.h +147 -62
  14. package/cpp/ggml-impl.h +15 -0
  15. package/cpp/ggml-metal.h +8 -9
  16. package/cpp/ggml-metal.m +2428 -2111
  17. package/cpp/ggml-quants.c +2 -2
  18. package/cpp/ggml-quants.h +0 -4
  19. package/cpp/ggml.c +799 -1121
  20. package/cpp/ggml.h +79 -72
  21. package/cpp/llama-vocab.cpp +189 -106
  22. package/cpp/llama-vocab.h +18 -9
  23. package/cpp/llama.cpp +736 -341
  24. package/cpp/llama.h +9 -4
  25. package/cpp/unicode-data.cpp +6 -4
  26. package/cpp/unicode-data.h +4 -4
  27. package/cpp/unicode.cpp +14 -7
  28. package/lib/commonjs/NativeRNLlama.js.map +1 -1
  29. package/lib/commonjs/index.js +4 -0
  30. package/lib/commonjs/index.js.map +1 -1
  31. package/lib/module/NativeRNLlama.js.map +1 -1
  32. package/lib/module/index.js +3 -0
  33. package/lib/module/index.js.map +1 -1
  34. package/lib/typescript/NativeRNLlama.d.ts +6 -0
  35. package/lib/typescript/NativeRNLlama.d.ts.map +1 -1
  36. package/lib/typescript/index.d.ts +2 -1
  37. package/lib/typescript/index.d.ts.map +1 -1
  38. package/package.json +1 -1
  39. package/src/NativeRNLlama.ts +7 -0
  40. package/src/index.ts +5 -0
package/cpp/ggml.c CHANGED
@@ -39,9 +39,6 @@
  #include <unistd.h>
  #endif
 
- #if defined(__ARM_FEATURE_SVE)
- int lm_ggml_sve_cnt_b = 0;
- #endif
  #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
  #undef LM_GGML_USE_LLAMAFILE
  #endif
@@ -322,26 +319,63 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) {
  // logging
  //
 
+ struct lm_ggml_logger_state {
+ lm_ggml_log_callback log_callback;
+ void * log_callback_user_data;
+ };
+ static struct lm_ggml_logger_state g_logger_state = {lm_ggml_log_callback_default, NULL};
+
+ static void lm_ggml_log_internal_v(enum lm_ggml_log_level level, const char * format, va_list args) {
+ if (format == NULL)
+ return;
+ va_list args_copy;
+ va_copy(args_copy, args);
+ char buffer[128];
+ int len = vsnprintf(buffer, 128, format, args);
+ if (len < 128) {
+ g_logger_state.log_callback(level, buffer, g_logger_state.log_callback_user_data);
+ } else {
+ char * buffer2 = (char *) calloc(len + 1, sizeof(char));
+ vsnprintf(buffer2, len + 1, format, args_copy);
+ buffer2[len] = 0;
+ g_logger_state.log_callback(level, buffer2, g_logger_state.log_callback_user_data);
+ free(buffer2);
+ }
+ va_end(args_copy);
+ }
+
+ void lm_ggml_log_internal(enum lm_ggml_log_level level, const char * format, ...) {
+ va_list args;
+ va_start(args, format);
+ lm_ggml_log_internal_v(level, format, args);
+ va_end(args);
+ }
+
+ void lm_ggml_log_callback_default(enum lm_ggml_log_level level, const char * text, void * user_data) {
+ (void) level;
+ (void) user_data;
+ fputs(text, stderr);
+ fflush(stderr);
+ }
+
  #if (LM_GGML_DEBUG >= 1)
- #define LM_GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
+ #define LM_GGML_PRINT_DEBUG(...) LM_GGML_LOG_DEBUG(__VA_ARGS__)
  #else
  #define LM_GGML_PRINT_DEBUG(...)
  #endif
 
  #if (LM_GGML_DEBUG >= 5)
- #define LM_GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__)
+ #define LM_GGML_PRINT_DEBUG_5(...) LM_GGML_LOG_DEBUG(__VA_ARGS__)
  #else
  #define LM_GGML_PRINT_DEBUG_5(...)
  #endif
 
  #if (LM_GGML_DEBUG >= 10)
- #define LM_GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__)
+ #define LM_GGML_PRINT_DEBUG_10(...) LM_GGML_LOG_DEBUG(__VA_ARGS__)
  #else
  #define LM_GGML_PRINT_DEBUG_10(...)
  #endif
 
- #define LM_GGML_PRINT(...) printf(__VA_ARGS__)
-
  //
  // end of logging block
  //
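With this change every internal message goes through g_logger_state.log_callback instead of printf. A minimal consumer-side sketch, assuming the prefixed setter lm_ggml_log_set() is exposed by ggml.h in this release (only the callback signature is taken from the hunk above; the log-level enum values come from ggml.h):

    #include <stdio.h>
    #include "ggml.h"

    // Custom sink: drop debug chatter, forward everything else to a user-supplied FILE *.
    static void my_log_cb(enum lm_ggml_log_level level, const char * text, void * user_data) {
        FILE * sink = (FILE *) user_data;
        if (level == LM_GGML_LOG_LEVEL_DEBUG) {
            return;
        }
        fputs(text, sink);
        fflush(sink);
    }

    // During app initialization (assumed API, mirroring lm_ggml_log_callback_default above):
    // lm_ggml_log_set(my_log_cb, stderr);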
@@ -358,7 +392,7 @@ void lm_ggml_abort(const char * file, int line, const char * fmt, ...) {
358
392
  #else
359
393
  inline static void * lm_ggml_aligned_malloc(size_t size) {
360
394
  if (size == 0) {
361
- LM_GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for lm_ggml_aligned_malloc!\n");
395
+ LM_GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for lm_ggml_aligned_malloc!\n");
362
396
  return NULL;
363
397
  }
364
398
  void * aligned_memory = NULL;
@@ -380,7 +414,7 @@ inline static void * lm_ggml_aligned_malloc(size_t size) {
380
414
  error_desc = "insufficient memory";
381
415
  break;
382
416
  }
383
- LM_GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
417
+ LM_GGML_LOG_ERROR("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0));
384
418
  LM_GGML_ABORT("fatal error");
385
419
  return NULL;
386
420
  }
@@ -396,12 +430,12 @@ inline static void * lm_ggml_aligned_malloc(size_t size) {
396
430
 
397
431
  inline static void * lm_ggml_malloc(size_t size) {
398
432
  if (size == 0) {
399
- LM_GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for lm_ggml_malloc!\n");
433
+ LM_GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for lm_ggml_malloc!\n");
400
434
  return NULL;
401
435
  }
402
436
  void * result = malloc(size);
403
437
  if (result == NULL) {
404
- LM_GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
438
+ LM_GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
405
439
  LM_GGML_ABORT("fatal error");
406
440
  }
407
441
  return result;
@@ -410,12 +444,12 @@ inline static void * lm_ggml_malloc(size_t size) {
410
444
  // calloc
411
445
  inline static void * lm_ggml_calloc(size_t num, size_t size) {
412
446
  if (num == 0 || size == 0) {
413
- LM_GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for lm_ggml_calloc!\n");
447
+ LM_GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for lm_ggml_calloc!\n");
414
448
  return NULL;
415
449
  }
416
450
  void * result = calloc(num, size);
417
451
  if (result == NULL) {
418
- LM_GGML_PRINT("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
452
+ LM_GGML_LOG_ERROR("%s: failed to allocate %6.2f MB\n", __func__, size/(1024.0*1024.0));
419
453
  LM_GGML_ABORT("fatal error");
420
454
  }
421
455
  return result;
@@ -455,7 +489,16 @@ static lm_ggml_fp16_t lm_ggml_table_gelu_quick_f16[1 << 16];
455
489
  // precomputed f32 table for f16 (256 KB) (ggml-impl.h)
456
490
  float lm_ggml_table_f32_f16[1 << 16];
457
491
 
458
- LM_GGML_CALL const char * lm_ggml_status_to_string(enum lm_ggml_status status) {
492
+ #if defined(__ARM_ARCH)
493
+ struct lm_ggml_arm_arch_features_type {
494
+ int has_neon;
495
+ int has_i8mm;
496
+ int has_sve;
497
+ int sve_cnt;
498
+ } lm_ggml_arm_arch_features = {-1, -1, -1, 0};
499
+ #endif
500
+
501
+ const char * lm_ggml_status_to_string(enum lm_ggml_status status) {
459
502
  switch (status) {
460
503
  case LM_GGML_STATUS_ALLOC_FAILED: return "GGML status: error (failed to allocate memory)";
461
504
  case LM_GGML_STATUS_FAILED: return "GGML status: error (operation failed)";
@@ -661,6 +704,19 @@ FILE * lm_ggml_fopen(const char * fname, const char * mode) {
  }
 
  return file;
+ #elif defined(RNLLAMA_USE_FD_FILE)
+ // [RNLLAMA] VERY UNSAFE, ASSUMES GIVEN fname is FileDescriptor id
+
+ if (strchr(fname, '/') == NULL) {
+ char *endptr;
+ long num = strtol(fname, &endptr, 10);
+ FILE *file = fdopen(dup(num), mode);
+
+ if (file != NULL) {
+ return file;
+ }
+ }
+ return fopen(fname, mode);
  #else
  return fopen(fname, mode);
  #endif
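The RNLLAMA_USE_FD_FILE branch lets the React Native layer hand over an already-open file descriptor instead of a path: a fname without '/' is parsed with strtol, dup()ed and fdopen()ed, and anything else falls back to a plain fopen. A hypothetical caller-side sketch (open_model_from_fd and content_fd are illustrative names, not part of the package):

    #include <stdio.h>
    #include "ggml.h"

    // content_fd is assumed to be a descriptor handed over by the Java side,
    // e.g. one obtained for a content:// URI.
    static FILE * open_model_from_fd(int content_fd, const char * mode) {
        char fname[16];
        snprintf(fname, sizeof(fname), "%d", content_fd);
        // With RNLLAMA_USE_FD_FILE defined, the digits-only name is treated as a
        // file descriptor; otherwise lm_ggml_fopen behaves like fopen().
        return lm_ggml_fopen(fname, mode);
    }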
@@ -686,7 +742,7 @@ static void lm_ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const floa
686
742
  static void lm_ggml_vec_dot_f16(int n, float * restrict s, size_t bs, lm_ggml_fp16_t * restrict x, size_t bx, lm_ggml_fp16_t * restrict y, size_t by, int nrc);
687
743
  static void lm_ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, lm_ggml_bf16_t * restrict x, size_t bx, lm_ggml_bf16_t * restrict y, size_t by, int nrc);
688
744
 
689
- static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = {
745
+ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = {
690
746
  [LM_GGML_TYPE_I8] = {
691
747
  .type_name = "i8",
692
748
  .blck_size = 1,
@@ -1108,9 +1164,9 @@ static const lm_ggml_type_traits_t type_traits[LM_GGML_TYPE_COUNT] = {
  };
 
  // For internal test use
- lm_ggml_type_traits_t lm_ggml_internal_get_type_traits(enum lm_ggml_type type) {
+ const struct lm_ggml_type_traits * lm_ggml_get_type_traits(enum lm_ggml_type type) {
  LM_GGML_ASSERT(type < LM_GGML_TYPE_COUNT);
- return type_traits[type];
+ return &type_traits[type];
  }
 
  //
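Call sites move from the by-value lm_ggml_internal_get_type_traits() to the pointer-returning lm_ggml_get_type_traits(). A small sketch of the updated usage (field names are the ones shown in the traits table above):

    #include <stdio.h>
    #include "ggml.h"

    static void print_type_info(enum lm_ggml_type type) {
        // traits are now returned by const pointer instead of by value
        const struct lm_ggml_type_traits * tt = lm_ggml_get_type_traits(type);
        printf("%s: block size = %lld, type size = %zu, quantized = %d\n",
               tt->type_name,
               (long long) tt->blck_size,
               tt->type_size,
               (int) tt->is_quantized);
    }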
@@ -2951,6 +3007,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
2951
3007
  "SUM_ROWS",
2952
3008
  "MEAN",
2953
3009
  "ARGMAX",
3010
+ "COUNT_EQUAL",
2954
3011
  "REPEAT",
2955
3012
  "REPEAT_BACK",
2956
3013
  "CONCAT",
@@ -3024,7 +3081,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = {
3024
3081
  "OPT_STEP_ADAMW",
3025
3082
  };
3026
3083
 
3027
- static_assert(LM_GGML_OP_COUNT == 80, "LM_GGML_OP_COUNT != 80");
3084
+ static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81");
3028
3085
 
3029
3086
  static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
3030
3087
  "none",
@@ -3045,6 +3102,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
3045
3102
  "Σx_k",
3046
3103
  "Σx/n",
3047
3104
  "argmax(x)",
3105
+ "count_equal(x)",
3048
3106
  "repeat(x)",
3049
3107
  "repeat_back(x)",
3050
3108
  "concat(x, y)",
@@ -3118,7 +3176,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = {
3118
3176
  "adamw(x)",
3119
3177
  };
3120
3178
 
3121
- static_assert(LM_GGML_OP_COUNT == 80, "LM_GGML_OP_COUNT != 80");
3179
+ static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81");
3122
3180
 
3123
3181
  static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2");
3124
3182
 
@@ -3341,7 +3399,7 @@ void lm_ggml_numa_init(enum lm_ggml_numa_strategy numa_flag) {
3341
3399
  if (fptr != NULL) {
3342
3400
  char buf[42];
3343
3401
  if (fgets(buf, sizeof(buf), fptr) && strncmp(buf, "0\n", sizeof(buf)) != 0) {
3344
- LM_GGML_PRINT("WARNING: /proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
3402
+ LM_GGML_LOG_WARN("/proc/sys/kernel/numa_balancing is enabled, this has been observed to impair performance\n");
3345
3403
  }
3346
3404
  fclose(fptr);
3347
3405
  }
@@ -3359,36 +3417,36 @@ bool lm_ggml_is_numa(void) {
3359
3417
  ////////////////////////////////////////////////////////////////////////////////
3360
3418
 
3361
3419
  void lm_ggml_print_object(const struct lm_ggml_object * obj) {
3362
- LM_GGML_PRINT(" - lm_ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
3420
+ LM_GGML_LOG_INFO(" - lm_ggml_object: type = %d, offset = %zu, size = %zu, next = %p\n",
3363
3421
  obj->type, obj->offs, obj->size, (const void *) obj->next);
3364
3422
  }
3365
3423
 
3366
3424
  void lm_ggml_print_objects(const struct lm_ggml_context * ctx) {
3367
3425
  struct lm_ggml_object * obj = ctx->objects_begin;
3368
3426
 
3369
- LM_GGML_PRINT("%s: objects in context %p:\n", __func__, (const void *) ctx);
3427
+ LM_GGML_LOG_INFO("%s: objects in context %p:\n", __func__, (const void *) ctx);
3370
3428
 
3371
3429
  while (obj != NULL) {
3372
3430
  lm_ggml_print_object(obj);
3373
3431
  obj = obj->next;
3374
3432
  }
3375
3433
 
3376
- LM_GGML_PRINT("%s: --- end ---\n", __func__);
3434
+ LM_GGML_LOG_INFO("%s: --- end ---\n", __func__);
3377
3435
  }
3378
3436
 
3379
- LM_GGML_CALL int64_t lm_ggml_nelements(const struct lm_ggml_tensor * tensor) {
3437
+ int64_t lm_ggml_nelements(const struct lm_ggml_tensor * tensor) {
3380
3438
  static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function");
3381
3439
 
3382
3440
  return tensor->ne[0]*tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
3383
3441
  }
3384
3442
 
3385
- LM_GGML_CALL int64_t lm_ggml_nrows(const struct lm_ggml_tensor * tensor) {
3443
+ int64_t lm_ggml_nrows(const struct lm_ggml_tensor * tensor) {
3386
3444
  static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function");
3387
3445
 
3388
3446
  return tensor->ne[1]*tensor->ne[2]*tensor->ne[3];
3389
3447
  }
3390
3448
 
3391
- LM_GGML_CALL size_t lm_ggml_nbytes(const struct lm_ggml_tensor * tensor) {
3449
+ size_t lm_ggml_nbytes(const struct lm_ggml_tensor * tensor) {
3392
3450
  size_t nbytes;
3393
3451
  size_t blck_size = lm_ggml_blck_size(tensor->type);
3394
3452
  if (blck_size == 1) {
@@ -3411,15 +3469,15 @@ size_t lm_ggml_nbytes_pad(const struct lm_ggml_tensor * tensor) {
3411
3469
  return LM_GGML_PAD(lm_ggml_nbytes(tensor), LM_GGML_MEM_ALIGN);
3412
3470
  }
3413
3471
 
3414
- LM_GGML_CALL int64_t lm_ggml_blck_size(enum lm_ggml_type type) {
3472
+ int64_t lm_ggml_blck_size(enum lm_ggml_type type) {
3415
3473
  return type_traits[type].blck_size;
3416
3474
  }
3417
3475
 
3418
- LM_GGML_CALL size_t lm_ggml_type_size(enum lm_ggml_type type) {
3476
+ size_t lm_ggml_type_size(enum lm_ggml_type type) {
3419
3477
  return type_traits[type].type_size;
3420
3478
  }
3421
3479
 
3422
- LM_GGML_CALL size_t lm_ggml_row_size(enum lm_ggml_type type, int64_t ne) {
3480
+ size_t lm_ggml_row_size(enum lm_ggml_type type, int64_t ne) {
3423
3481
  assert(ne % lm_ggml_blck_size(type) == 0);
3424
3482
  return lm_ggml_type_size(type)*ne/lm_ggml_blck_size(type);
3425
3483
  }
@@ -3428,15 +3486,15 @@ double lm_ggml_type_sizef(enum lm_ggml_type type) {
3428
3486
  return ((double)(type_traits[type].type_size))/type_traits[type].blck_size;
3429
3487
  }
3430
3488
 
3431
- LM_GGML_CALL const char * lm_ggml_type_name(enum lm_ggml_type type) {
3489
+ const char * lm_ggml_type_name(enum lm_ggml_type type) {
3432
3490
  return type < LM_GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE";
3433
3491
  }
3434
3492
 
3435
- LM_GGML_CALL bool lm_ggml_is_quantized(enum lm_ggml_type type) {
3493
+ bool lm_ggml_is_quantized(enum lm_ggml_type type) {
3436
3494
  return type_traits[type].is_quantized;
3437
3495
  }
3438
3496
 
3439
- LM_GGML_CALL const char * lm_ggml_op_name(enum lm_ggml_op op) {
3497
+ const char * lm_ggml_op_name(enum lm_ggml_op op) {
3440
3498
  return LM_GGML_OP_NAME[op];
3441
3499
  }
3442
3500
 
@@ -3448,7 +3506,7 @@ const char * lm_ggml_unary_op_name(enum lm_ggml_unary_op op) {
3448
3506
  return LM_GGML_UNARY_OP_NAME[op];
3449
3507
  }
3450
3508
 
3451
- LM_GGML_CALL const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) {
3509
+ const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) {
3452
3510
  if (t->op == LM_GGML_OP_UNARY) {
3453
3511
  enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(t);
3454
3512
  return lm_ggml_unary_op_name(uop);
@@ -3456,7 +3514,7 @@ LM_GGML_CALL const char * lm_ggml_op_desc(const struct lm_ggml_tensor * t) {
3456
3514
  return lm_ggml_op_name(t->op);
3457
3515
  }
3458
3516
 
3459
- LM_GGML_CALL size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor) {
3517
+ size_t lm_ggml_element_size(const struct lm_ggml_tensor * tensor) {
3460
3518
  return lm_ggml_type_size(tensor->type);
3461
3519
  }
3462
3520
 
@@ -3549,7 +3607,7 @@ size_t lm_ggml_tensor_overhead(void) {
3549
3607
  return LM_GGML_OBJECT_SIZE + LM_GGML_TENSOR_SIZE;
3550
3608
  }
3551
3609
 
3552
- LM_GGML_CALL bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor) {
3610
+ bool lm_ggml_is_transposed(const struct lm_ggml_tensor * tensor) {
3553
3611
  return tensor->nb[0] > tensor->nb[1];
3554
3612
  }
3555
3613
 
@@ -3575,23 +3633,23 @@ static bool lm_ggml_is_contiguous_n(const struct lm_ggml_tensor * tensor, int n)
3575
3633
  return true;
3576
3634
  }
3577
3635
 
3578
- LM_GGML_CALL bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor) {
3636
+ bool lm_ggml_is_contiguous(const struct lm_ggml_tensor * tensor) {
3579
3637
  return lm_ggml_is_contiguous_0(tensor);
3580
3638
  }
3581
3639
 
3582
- LM_GGML_CALL bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor) {
3640
+ bool lm_ggml_is_contiguous_0(const struct lm_ggml_tensor * tensor) {
3583
3641
  return lm_ggml_is_contiguous_n(tensor, 0);
3584
3642
  }
3585
3643
 
3586
- LM_GGML_CALL bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor) {
3644
+ bool lm_ggml_is_contiguous_1(const struct lm_ggml_tensor * tensor) {
3587
3645
  return lm_ggml_is_contiguous_n(tensor, 1);
3588
3646
  }
3589
3647
 
3590
- LM_GGML_CALL bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) {
3648
+ bool lm_ggml_is_contiguous_2(const struct lm_ggml_tensor * tensor) {
3591
3649
  return lm_ggml_is_contiguous_n(tensor, 2);
3592
3650
  }
3593
3651
 
3594
- LM_GGML_CALL bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) {
3652
+ bool lm_ggml_is_permuted(const struct lm_ggml_tensor * tensor) {
3595
3653
  static_assert(LM_GGML_MAX_DIMS == 4, "LM_GGML_MAX_DIMS is not 4 - update this function");
3596
3654
 
3597
3655
  return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
@@ -3606,7 +3664,7 @@ static inline bool lm_ggml_is_padded_1d(const struct lm_ggml_tensor * tensor) {
3606
3664
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3607
3665
  }
3608
3666
 
3609
- LM_GGML_CALL bool lm_ggml_is_empty(const struct lm_ggml_tensor * tensor) {
3667
+ bool lm_ggml_is_empty(const struct lm_ggml_tensor * tensor) {
3610
3668
  for (int i = 0; i < LM_GGML_MAX_DIMS; ++i) {
3611
3669
  if (tensor->ne[i] == 0) {
3612
3670
  // empty if any dimension has no elements
@@ -3673,6 +3731,70 @@ static inline int lm_ggml_up(int n, int m) {
3673
3731
 
3674
3732
  ////////////////////////////////////////////////////////////////////////////////
3675
3733
 
3734
+ #if defined(__ARM_ARCH)
3735
+
3736
+ #if defined(__linux__) && defined(__aarch64__)
3737
+ #include <sys/auxv.h>
3738
+ #elif defined(__APPLE__)
3739
+ #include <sys/sysctl.h>
3740
+ #endif
3741
+
3742
+ #if !defined(HWCAP2_I8MM)
3743
+ #define HWCAP2_I8MM 0
3744
+ #endif
3745
+
3746
+ static void lm_ggml_init_arm_arch_features(void) {
3747
+ #if defined(__linux__) && defined(__aarch64__)
3748
+ uint32_t hwcap = getauxval(AT_HWCAP);
3749
+ uint32_t hwcap2 = getauxval(AT_HWCAP2);
3750
+
3751
+ lm_ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD);
3752
+ lm_ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM);
3753
+ lm_ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE);
3754
+
3755
+ #if defined(__ARM_FEATURE_SVE)
3756
+ lm_ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
3757
+ #endif
3758
+ #elif defined(__APPLE__)
3759
+ int oldp = 0;
3760
+ size_t size = sizeof(oldp);
3761
+ if (sysctlbyname("hw.optional.AdvSIMD", &oldp, &size, NULL, 0) != 0) {
3762
+ oldp = 0;
3763
+ }
3764
+ lm_ggml_arm_arch_features.has_neon = oldp;
3765
+
3766
+ if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) {
3767
+ oldp = 0;
3768
+ }
3769
+ lm_ggml_arm_arch_features.has_i8mm = oldp;
3770
+
3771
+ lm_ggml_arm_arch_features.has_sve = 0;
3772
+ lm_ggml_arm_arch_features.sve_cnt = 0;
3773
+ #else
3774
+ // Run-time CPU feature detection not implemented for this platform, fallback to compile time
3775
+ #if defined(__ARM_NEON)
3776
+ lm_ggml_arm_arch_features.has_neon = 1;
3777
+ #else
3778
+ lm_ggml_arm_arch_features.has_neon = 0;
3779
+ #endif
3780
+
3781
+ #if defined(__ARM_FEATURE_MATMUL_INT8)
3782
+ lm_ggml_arm_arch_features.has_i8mm = 1;
3783
+ #else
3784
+ lm_ggml_arm_arch_features.has_i8mm = 0;
3785
+ #endif
3786
+
3787
+ #if defined(__ARM_FEATURE_SVE)
3788
+ lm_ggml_arm_arch_features.has_sve = 1;
3789
+ lm_ggml_arm_arch_features.sve_cnt = 16;
3790
+ #else
3791
+ lm_ggml_arm_arch_features.has_sve = 0;
3792
+ lm_ggml_arm_arch_features.sve_cnt = 0;
3793
+ #endif
3794
+ #endif
3795
+ }
3796
+ #endif
3797
+
3676
3798
  struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) {
3677
3799
  // make this function thread safe
3678
3800
  lm_ggml_critical_section_start();
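The per-feature probe above replaces the old lm_ggml_sve_cnt_b global and runs once, during the first lm_ggml_init() call (see the next hunk). A minimal init sketch, with illustrative parameter values:

    #include "ggml.h"

    static struct lm_ggml_context * make_ctx(void) {
        struct lm_ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        // On ARM builds the first call also runs lm_ggml_init_arm_arch_features().
        return lm_ggml_init(params);
    }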
@@ -3723,6 +3845,10 @@ struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) {
3723
3845
  LM_GGML_PRINT_DEBUG("%s: g_state initialized in %f ms\n", __func__, (t_end - t_start)/1000.0f);
3724
3846
  }
3725
3847
 
3848
+ #if defined(__ARM_ARCH)
3849
+ lm_ggml_init_arm_arch_features();
3850
+ #endif
3851
+
3726
3852
  is_first_call = false;
3727
3853
  }
3728
3854
 
@@ -3771,12 +3897,6 @@ struct lm_ggml_context * lm_ggml_init(struct lm_ggml_init_params params) {
3771
3897
 
3772
3898
  LM_GGML_ASSERT_ALIGNED(ctx->mem_buffer);
3773
3899
 
3774
- #if defined(__ARM_FEATURE_SVE)
3775
- if (!lm_ggml_sve_cnt_b) {
3776
- lm_ggml_sve_cnt_b = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
3777
- }
3778
- #endif
3779
-
3780
3900
  LM_GGML_PRINT_DEBUG("%s: context initialized\n", __func__);
3781
3901
 
3782
3902
  lm_ggml_critical_section_end();
@@ -3894,7 +4014,7 @@ static struct lm_ggml_object * lm_ggml_new_object(struct lm_ggml_context * ctx,
3894
4014
  struct lm_ggml_object * const obj_new = (struct lm_ggml_object *)(mem_buffer + cur_end);
3895
4015
 
3896
4016
  if (cur_end + size_needed + LM_GGML_OBJECT_SIZE > ctx->mem_size) {
3897
- LM_GGML_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
4017
+ LM_GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n",
3898
4018
  __func__, cur_end + size_needed + LM_GGML_OBJECT_SIZE, ctx->mem_size);
3899
4019
  assert(false);
3900
4020
  return NULL;
@@ -3958,7 +4078,7 @@ static struct lm_ggml_tensor * lm_ggml_new_tensor_impl(
3958
4078
  if (ctx->scratch.data != NULL) {
3959
4079
  // allocate tensor data in the scratch buffer
3960
4080
  if (ctx->scratch.offs + data_size > ctx->scratch.size) {
3961
- LM_GGML_PRINT("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
4081
+ LM_GGML_LOG_WARN("%s: not enough space in the scratch memory pool (needed %zu, available %zu)\n",
3962
4082
  __func__, ctx->scratch.offs + data_size, ctx->scratch.size);
3963
4083
  assert(false);
3964
4084
  return NULL;
@@ -4127,9 +4247,13 @@ static void lm_ggml_set_op_params_f32(struct lm_ggml_tensor * tensor, uint32_t i
4127
4247
  }
4128
4248
 
4129
4249
  struct lm_ggml_tensor * lm_ggml_set_zero(struct lm_ggml_tensor * tensor) {
4250
+ if (lm_ggml_is_empty(tensor)) {
4251
+ return tensor;
4252
+ }
4130
4253
  if (tensor->buffer) {
4131
4254
  lm_ggml_backend_tensor_memset(tensor, 0, 0, lm_ggml_nbytes(tensor));
4132
4255
  } else {
4256
+ LM_GGML_ASSERT(tensor->data);
4133
4257
  memset(tensor->data, 0, lm_ggml_nbytes(tensor));
4134
4258
  }
4135
4259
  return tensor;
@@ -4560,7 +4684,7 @@ float * lm_ggml_get_data_f32(const struct lm_ggml_tensor * tensor) {
4560
4684
  return (float *)(tensor->data);
4561
4685
  }
4562
4686
 
4563
- LM_GGML_CALL enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor) {
4687
+ enum lm_ggml_unary_op lm_ggml_get_unary_op(const struct lm_ggml_tensor * tensor) {
4564
4688
  LM_GGML_ASSERT(tensor->op == LM_GGML_OP_UNARY);
4565
4689
  return (enum lm_ggml_unary_op) lm_ggml_get_op_params_i32(tensor, 0);
4566
4690
  }
@@ -4657,18 +4781,11 @@ struct lm_ggml_tensor * lm_ggml_get_tensor(struct lm_ggml_context * ctx, const c
4657
4781
 
4658
4782
  static struct lm_ggml_tensor * lm_ggml_dup_impl(
4659
4783
  struct lm_ggml_context * ctx,
4660
- struct lm_ggml_tensor * a,
4661
- bool inplace) {
4662
- bool is_node = false;
4663
-
4664
- if (!inplace && (a->grad)) {
4665
- is_node = true;
4666
- }
4667
-
4784
+ struct lm_ggml_tensor * a,
4785
+ bool inplace) {
4668
4786
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
4669
4787
 
4670
- result->op = LM_GGML_OP_DUP;
4671
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
4788
+ result->op = LM_GGML_OP_DUP;
4672
4789
  result->src[0] = a;
4673
4790
 
4674
4791
  return result;
@@ -4676,13 +4793,13 @@ static struct lm_ggml_tensor * lm_ggml_dup_impl(
4676
4793
 
4677
4794
  struct lm_ggml_tensor * lm_ggml_dup(
4678
4795
  struct lm_ggml_context * ctx,
4679
- struct lm_ggml_tensor * a) {
4796
+ struct lm_ggml_tensor * a) {
4680
4797
  return lm_ggml_dup_impl(ctx, a, false);
4681
4798
  }
4682
4799
 
4683
4800
  struct lm_ggml_tensor * lm_ggml_dup_inplace(
4684
4801
  struct lm_ggml_context * ctx,
4685
- struct lm_ggml_tensor * a) {
4802
+ struct lm_ggml_tensor * a) {
4686
4803
  return lm_ggml_dup_impl(ctx, a, true);
4687
4804
  }
4688
4805
 
@@ -4690,21 +4807,14 @@ struct lm_ggml_tensor * lm_ggml_dup_inplace(
4690
4807
 
4691
4808
  static struct lm_ggml_tensor * lm_ggml_add_impl(
4692
4809
  struct lm_ggml_context * ctx,
4693
- struct lm_ggml_tensor * a,
4694
- struct lm_ggml_tensor * b,
4695
- bool inplace) {
4810
+ struct lm_ggml_tensor * a,
4811
+ struct lm_ggml_tensor * b,
4812
+ bool inplace) {
4696
4813
  LM_GGML_ASSERT(lm_ggml_can_repeat(b, a));
4697
4814
 
4698
- bool is_node = false;
4699
-
4700
- if (!inplace && (a->grad || b->grad)) {
4701
- is_node = true;
4702
- }
4703
-
4704
4815
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
4705
4816
 
4706
- result->op = LM_GGML_OP_ADD;
4707
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
4817
+ result->op = LM_GGML_OP_ADD;
4708
4818
  result->src[0] = a;
4709
4819
  result->src[1] = b;
4710
4820
 
@@ -4713,15 +4823,15 @@ static struct lm_ggml_tensor * lm_ggml_add_impl(
4713
4823
 
4714
4824
  struct lm_ggml_tensor * lm_ggml_add(
4715
4825
  struct lm_ggml_context * ctx,
4716
- struct lm_ggml_tensor * a,
4717
- struct lm_ggml_tensor * b) {
4826
+ struct lm_ggml_tensor * a,
4827
+ struct lm_ggml_tensor * b) {
4718
4828
  return lm_ggml_add_impl(ctx, a, b, false);
4719
4829
  }
4720
4830
 
4721
4831
  struct lm_ggml_tensor * lm_ggml_add_inplace(
4722
4832
  struct lm_ggml_context * ctx,
4723
- struct lm_ggml_tensor * a,
4724
- struct lm_ggml_tensor * b) {
4833
+ struct lm_ggml_tensor * a,
4834
+ struct lm_ggml_tensor * b) {
4725
4835
  return lm_ggml_add_impl(ctx, a, b, true);
4726
4836
  }
4727
4837
 
@@ -4729,9 +4839,9 @@ struct lm_ggml_tensor * lm_ggml_add_inplace(
4729
4839
 
4730
4840
  static struct lm_ggml_tensor * lm_ggml_add_cast_impl(
4731
4841
  struct lm_ggml_context * ctx,
4732
- struct lm_ggml_tensor * a,
4733
- struct lm_ggml_tensor * b,
4734
- enum lm_ggml_type type) {
4842
+ struct lm_ggml_tensor * a,
4843
+ struct lm_ggml_tensor * b,
4844
+ enum lm_ggml_type type) {
4735
4845
  // TODO: support less-strict constraint
4736
4846
  // LM_GGML_ASSERT(lm_ggml_can_repeat(b, a));
4737
4847
  LM_GGML_ASSERT(lm_ggml_can_repeat_rows(b, a));
@@ -4741,18 +4851,9 @@ static struct lm_ggml_tensor * lm_ggml_add_cast_impl(
4741
4851
  a->type == LM_GGML_TYPE_F16 ||
4742
4852
  a->type == LM_GGML_TYPE_BF16);
4743
4853
 
4744
- bool is_node = false;
4745
-
4746
- if (a->grad || b->grad) {
4747
- // TODO: support backward pass for broadcasting
4748
- LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
4749
- is_node = true;
4750
- }
4751
-
4752
4854
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne);
4753
4855
 
4754
- result->op = LM_GGML_OP_ADD;
4755
- result->grad = is_node ? lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, LM_GGML_MAX_DIMS, a->ne) : NULL;
4856
+ result->op = LM_GGML_OP_ADD;
4756
4857
  result->src[0] = a;
4757
4858
  result->src[1] = b;
4758
4859
 
@@ -4761,9 +4862,9 @@ static struct lm_ggml_tensor * lm_ggml_add_cast_impl(
4761
4862
 
4762
4863
  struct lm_ggml_tensor * lm_ggml_add_cast(
4763
4864
  struct lm_ggml_context * ctx,
4764
- struct lm_ggml_tensor * a,
4765
- struct lm_ggml_tensor * b,
4766
- enum lm_ggml_type type) {
4865
+ struct lm_ggml_tensor * a,
4866
+ struct lm_ggml_tensor * b,
4867
+ enum lm_ggml_type type) {
4767
4868
  return lm_ggml_add_cast_impl(ctx, a, b, type);
4768
4869
  }
4769
4870
 
@@ -4771,22 +4872,15 @@ struct lm_ggml_tensor * lm_ggml_add_cast(
4771
4872
 
4772
4873
  static struct lm_ggml_tensor * lm_ggml_add1_impl(
4773
4874
  struct lm_ggml_context * ctx,
4774
- struct lm_ggml_tensor * a,
4775
- struct lm_ggml_tensor * b,
4776
- bool inplace) {
4875
+ struct lm_ggml_tensor * a,
4876
+ struct lm_ggml_tensor * b,
4877
+ bool inplace) {
4777
4878
  LM_GGML_ASSERT(lm_ggml_is_scalar(b));
4778
4879
  LM_GGML_ASSERT(lm_ggml_is_padded_1d(a));
4779
4880
 
4780
- bool is_node = false;
4781
-
4782
- if (a->grad || b->grad) {
4783
- is_node = true;
4784
- }
4785
-
4786
4881
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
4787
4882
 
4788
- result->op = LM_GGML_OP_ADD1;
4789
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
4883
+ result->op = LM_GGML_OP_ADD1;
4790
4884
  result->src[0] = a;
4791
4885
  result->src[1] = b;
4792
4886
 
@@ -4795,15 +4889,15 @@ static struct lm_ggml_tensor * lm_ggml_add1_impl(
4795
4889
 
4796
4890
  struct lm_ggml_tensor * lm_ggml_add1(
4797
4891
  struct lm_ggml_context * ctx,
4798
- struct lm_ggml_tensor * a,
4799
- struct lm_ggml_tensor * b) {
4892
+ struct lm_ggml_tensor * a,
4893
+ struct lm_ggml_tensor * b) {
4800
4894
  return lm_ggml_add1_impl(ctx, a, b, false);
4801
4895
  }
4802
4896
 
4803
4897
  struct lm_ggml_tensor * lm_ggml_add1_inplace(
4804
4898
  struct lm_ggml_context * ctx,
4805
- struct lm_ggml_tensor * a,
4806
- struct lm_ggml_tensor * b) {
4899
+ struct lm_ggml_tensor * a,
4900
+ struct lm_ggml_tensor * b) {
4807
4901
  return lm_ggml_add1_impl(ctx, a, b, true);
4808
4902
  }
4809
4903
 
@@ -4811,31 +4905,24 @@ struct lm_ggml_tensor * lm_ggml_add1_inplace(
4811
4905
 
4812
4906
  static struct lm_ggml_tensor * lm_ggml_acc_impl(
4813
4907
  struct lm_ggml_context * ctx,
4814
- struct lm_ggml_tensor * a,
4815
- struct lm_ggml_tensor * b,
4816
- size_t nb1,
4817
- size_t nb2,
4818
- size_t nb3,
4819
- size_t offset,
4820
- bool inplace) {
4908
+ struct lm_ggml_tensor * a,
4909
+ struct lm_ggml_tensor * b,
4910
+ size_t nb1,
4911
+ size_t nb2,
4912
+ size_t nb3,
4913
+ size_t offset,
4914
+ bool inplace) {
4821
4915
  LM_GGML_ASSERT(lm_ggml_nelements(b) <= lm_ggml_nelements(a));
4822
4916
  LM_GGML_ASSERT(lm_ggml_is_contiguous(a));
4823
4917
  LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32);
4824
4918
  LM_GGML_ASSERT(b->type == LM_GGML_TYPE_F32);
4825
4919
 
4826
- bool is_node = false;
4827
-
4828
- if (!inplace && (a->grad || b->grad)) {
4829
- is_node = true;
4830
- }
4831
-
4832
4920
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
4833
4921
 
4834
4922
  int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
4835
4923
  lm_ggml_set_op_params(result, params, sizeof(params));
4836
4924
 
4837
- result->op = LM_GGML_OP_ACC;
4838
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
4925
+ result->op = LM_GGML_OP_ACC;
4839
4926
  result->src[0] = a;
4840
4927
  result->src[1] = b;
4841
4928
 
@@ -4844,23 +4931,23 @@ static struct lm_ggml_tensor * lm_ggml_acc_impl(
4844
4931
 
4845
4932
  struct lm_ggml_tensor * lm_ggml_acc(
4846
4933
  struct lm_ggml_context * ctx,
4847
- struct lm_ggml_tensor * a,
4848
- struct lm_ggml_tensor * b,
4849
- size_t nb1,
4850
- size_t nb2,
4851
- size_t nb3,
4852
- size_t offset) {
4934
+ struct lm_ggml_tensor * a,
4935
+ struct lm_ggml_tensor * b,
4936
+ size_t nb1,
4937
+ size_t nb2,
4938
+ size_t nb3,
4939
+ size_t offset) {
4853
4940
  return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
4854
4941
  }
4855
4942
 
4856
4943
  struct lm_ggml_tensor * lm_ggml_acc_inplace(
4857
4944
  struct lm_ggml_context * ctx,
4858
- struct lm_ggml_tensor * a,
4859
- struct lm_ggml_tensor * b,
4860
- size_t nb1,
4861
- size_t nb2,
4862
- size_t nb3,
4863
- size_t offset) {
4945
+ struct lm_ggml_tensor * a,
4946
+ struct lm_ggml_tensor * b,
4947
+ size_t nb1,
4948
+ size_t nb2,
4949
+ size_t nb3,
4950
+ size_t offset) {
4864
4951
  return lm_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
4865
4952
  }
4866
4953
 
@@ -4868,23 +4955,14 @@ struct lm_ggml_tensor * lm_ggml_acc_inplace(
4868
4955
 
4869
4956
  static struct lm_ggml_tensor * lm_ggml_sub_impl(
4870
4957
  struct lm_ggml_context * ctx,
4871
- struct lm_ggml_tensor * a,
4872
- struct lm_ggml_tensor * b,
4873
- bool inplace) {
4958
+ struct lm_ggml_tensor * a,
4959
+ struct lm_ggml_tensor * b,
4960
+ bool inplace) {
4874
4961
  LM_GGML_ASSERT(lm_ggml_can_repeat(b, a));
4875
4962
 
4876
- bool is_node = false;
4877
-
4878
- if (!inplace && (a->grad || b->grad)) {
4879
- // TODO: support backward pass for broadcasting
4880
- LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
4881
- is_node = true;
4882
- }
4883
-
4884
4963
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
4885
4964
 
4886
- result->op = LM_GGML_OP_SUB;
4887
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
4965
+ result->op = LM_GGML_OP_SUB;
4888
4966
  result->src[0] = a;
4889
4967
  result->src[1] = b;
4890
4968
 
@@ -4893,15 +4971,15 @@ static struct lm_ggml_tensor * lm_ggml_sub_impl(
4893
4971
 
4894
4972
  struct lm_ggml_tensor * lm_ggml_sub(
4895
4973
  struct lm_ggml_context * ctx,
4896
- struct lm_ggml_tensor * a,
4897
- struct lm_ggml_tensor * b) {
4974
+ struct lm_ggml_tensor * a,
4975
+ struct lm_ggml_tensor * b) {
4898
4976
  return lm_ggml_sub_impl(ctx, a, b, false);
4899
4977
  }
4900
4978
 
4901
4979
  struct lm_ggml_tensor * lm_ggml_sub_inplace(
4902
4980
  struct lm_ggml_context * ctx,
4903
- struct lm_ggml_tensor * a,
4904
- struct lm_ggml_tensor * b) {
4981
+ struct lm_ggml_tensor * a,
4982
+ struct lm_ggml_tensor * b) {
4905
4983
  return lm_ggml_sub_impl(ctx, a, b, true);
4906
4984
  }
4907
4985
 
@@ -4909,27 +4987,14 @@ struct lm_ggml_tensor * lm_ggml_sub_inplace(
4909
4987
 
4910
4988
  static struct lm_ggml_tensor * lm_ggml_mul_impl(
4911
4989
  struct lm_ggml_context * ctx,
4912
- struct lm_ggml_tensor * a,
4913
- struct lm_ggml_tensor * b,
4914
- bool inplace) {
4990
+ struct lm_ggml_tensor * a,
4991
+ struct lm_ggml_tensor * b,
4992
+ bool inplace) {
4915
4993
  LM_GGML_ASSERT(lm_ggml_can_repeat(b, a));
4916
4994
 
4917
- bool is_node = false;
4918
-
4919
- if (!inplace && (a->grad || b->grad)) {
4920
- // TODO: support backward pass for broadcasting
4921
- LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
4922
- is_node = true;
4923
- }
4924
-
4925
- if (inplace) {
4926
- LM_GGML_ASSERT(!is_node);
4927
- }
4928
-
4929
4995
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
4930
4996
 
4931
- result->op = LM_GGML_OP_MUL;
4932
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
4997
+ result->op = LM_GGML_OP_MUL;
4933
4998
  result->src[0] = a;
4934
4999
  result->src[1] = b;
4935
5000
 
@@ -4954,25 +5019,14 @@ struct lm_ggml_tensor * lm_ggml_mul_inplace(
4954
5019
 
4955
5020
  static struct lm_ggml_tensor * lm_ggml_div_impl(
4956
5021
  struct lm_ggml_context * ctx,
4957
- struct lm_ggml_tensor * a,
4958
- struct lm_ggml_tensor * b,
4959
- bool inplace) {
5022
+ struct lm_ggml_tensor * a,
5023
+ struct lm_ggml_tensor * b,
5024
+ bool inplace) {
4960
5025
  LM_GGML_ASSERT(lm_ggml_can_repeat(b, a));
4961
5026
 
4962
- bool is_node = false;
4963
-
4964
- if (!inplace && (a->grad || b->grad)) {
4965
- is_node = true;
4966
- }
4967
-
4968
- if (inplace) {
4969
- LM_GGML_ASSERT(!is_node);
4970
- }
4971
-
4972
5027
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
4973
5028
 
4974
- result->op = LM_GGML_OP_DIV;
4975
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5029
+ result->op = LM_GGML_OP_DIV;
4976
5030
  result->src[0] = a;
4977
5031
  result->src[1] = b;
4978
5032
 
@@ -4997,18 +5051,11 @@ struct lm_ggml_tensor * lm_ggml_div_inplace(
4997
5051
 
4998
5052
  static struct lm_ggml_tensor * lm_ggml_sqr_impl(
4999
5053
  struct lm_ggml_context * ctx,
5000
- struct lm_ggml_tensor * a,
5001
- bool inplace) {
5002
- bool is_node = false;
5003
-
5004
- if (!inplace && (a->grad)) {
5005
- is_node = true;
5006
- }
5007
-
5054
+ struct lm_ggml_tensor * a,
5055
+ bool inplace) {
5008
5056
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5009
5057
 
5010
- result->op = LM_GGML_OP_SQR;
5011
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5058
+ result->op = LM_GGML_OP_SQR;
5012
5059
  result->src[0] = a;
5013
5060
 
5014
5061
  return result;
@@ -5030,18 +5077,11 @@ struct lm_ggml_tensor * lm_ggml_sqr_inplace(
5030
5077
 
5031
5078
  static struct lm_ggml_tensor * lm_ggml_sqrt_impl(
5032
5079
  struct lm_ggml_context * ctx,
5033
- struct lm_ggml_tensor * a,
5034
- bool inplace) {
5035
- bool is_node = false;
5036
-
5037
- if (!inplace && (a->grad)) {
5038
- is_node = true;
5039
- }
5040
-
5080
+ struct lm_ggml_tensor * a,
5081
+ bool inplace) {
5041
5082
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5042
5083
 
5043
- result->op = LM_GGML_OP_SQRT;
5044
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5084
+ result->op = LM_GGML_OP_SQRT;
5045
5085
  result->src[0] = a;
5046
5086
 
5047
5087
  return result;
@@ -5064,17 +5104,10 @@ struct lm_ggml_tensor * lm_ggml_sqrt_inplace(
5064
5104
  static struct lm_ggml_tensor * lm_ggml_log_impl(
5065
5105
  struct lm_ggml_context * ctx,
5066
5106
  struct lm_ggml_tensor * a,
5067
- bool inplace) {
5068
- bool is_node = false;
5069
-
5070
- if (!inplace && (a->grad)) {
5071
- is_node = true;
5072
- }
5073
-
5107
+ bool inplace) {
5074
5108
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5075
5109
 
5076
- result->op = LM_GGML_OP_LOG;
5077
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5110
+ result->op = LM_GGML_OP_LOG;
5078
5111
  result->src[0] = a;
5079
5112
 
5080
5113
  return result;
@@ -5097,17 +5130,10 @@ struct lm_ggml_tensor * lm_ggml_log_inplace(
5097
5130
  static struct lm_ggml_tensor * lm_ggml_sin_impl(
5098
5131
  struct lm_ggml_context * ctx,
5099
5132
  struct lm_ggml_tensor * a,
5100
- bool inplace) {
5101
- bool is_node = false;
5102
-
5103
- if (!inplace && (a->grad)) {
5104
- is_node = true;
5105
- }
5106
-
5133
+ bool inplace) {
5107
5134
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5108
5135
 
5109
- result->op = LM_GGML_OP_SIN;
5110
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5136
+ result->op = LM_GGML_OP_SIN;
5111
5137
  result->src[0] = a;
5112
5138
 
5113
5139
  return result;
@@ -5130,17 +5156,10 @@ struct lm_ggml_tensor * lm_ggml_sin_inplace(
5130
5156
  static struct lm_ggml_tensor * lm_ggml_cos_impl(
5131
5157
  struct lm_ggml_context * ctx,
5132
5158
  struct lm_ggml_tensor * a,
5133
- bool inplace) {
5134
- bool is_node = false;
5135
-
5136
- if (!inplace && (a->grad)) {
5137
- is_node = true;
5138
- }
5139
-
5159
+ bool inplace) {
5140
5160
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5141
5161
 
5142
- result->op = LM_GGML_OP_COS;
5143
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5162
+ result->op = LM_GGML_OP_COS;
5144
5163
  result->src[0] = a;
5145
5164
 
5146
5165
  return result;
@@ -5162,17 +5181,10 @@ struct lm_ggml_tensor * lm_ggml_cos_inplace(
5162
5181
 
5163
5182
  struct lm_ggml_tensor * lm_ggml_sum(
5164
5183
  struct lm_ggml_context * ctx,
5165
- struct lm_ggml_tensor * a) {
5166
- bool is_node = false;
5167
-
5168
- if (a->grad) {
5169
- is_node = true;
5170
- }
5171
-
5184
+ struct lm_ggml_tensor * a) {
5172
5185
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, a->type, 1);
5173
5186
 
5174
- result->op = LM_GGML_OP_SUM;
5175
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5187
+ result->op = LM_GGML_OP_SUM;
5176
5188
  result->src[0] = a;
5177
5189
 
5178
5190
  return result;
@@ -5182,13 +5194,7 @@ struct lm_ggml_tensor * lm_ggml_sum(
5182
5194
 
5183
5195
  struct lm_ggml_tensor * lm_ggml_sum_rows(
5184
5196
  struct lm_ggml_context * ctx,
5185
- struct lm_ggml_tensor * a) {
5186
- bool is_node = false;
5187
-
5188
- if (a->grad) {
5189
- is_node = true;
5190
- }
5191
-
5197
+ struct lm_ggml_tensor * a) {
5192
5198
  int64_t ne[LM_GGML_MAX_DIMS] = { 1 };
5193
5199
  for (int i = 1; i < LM_GGML_MAX_DIMS; ++i) {
5194
5200
  ne[i] = a->ne[i];
@@ -5196,8 +5202,7 @@ struct lm_ggml_tensor * lm_ggml_sum_rows(
5196
5202
 
5197
5203
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne);
5198
5204
 
5199
- result->op = LM_GGML_OP_SUM_ROWS;
5200
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5205
+ result->op = LM_GGML_OP_SUM_ROWS;
5201
5206
  result->src[0] = a;
5202
5207
 
5203
5208
  return result;
@@ -5207,19 +5212,11 @@ struct lm_ggml_tensor * lm_ggml_sum_rows(
5207
5212
 
5208
5213
  struct lm_ggml_tensor * lm_ggml_mean(
5209
5214
  struct lm_ggml_context * ctx,
5210
- struct lm_ggml_tensor * a) {
5211
- bool is_node = false;
5212
-
5213
- if (a->grad) {
5214
- LM_GGML_ABORT("fatal error"); // TODO: implement
5215
- is_node = true;
5216
- }
5217
-
5215
+ struct lm_ggml_tensor * a) {
5218
5216
  int64_t ne[4] = { 1, a->ne[1], a->ne[2], a->ne[3] };
5219
5217
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
5220
5218
 
5221
- result->op = LM_GGML_OP_MEAN;
5222
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5219
+ result->op = LM_GGML_OP_MEAN;
5223
5220
  result->src[0] = a;
5224
5221
 
5225
5222
  return result;
@@ -5229,42 +5226,45 @@ struct lm_ggml_tensor * lm_ggml_mean(
 
  struct lm_ggml_tensor * lm_ggml_argmax(
  struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a) {
+ struct lm_ggml_tensor * a) {
  LM_GGML_ASSERT(lm_ggml_is_matrix(a));
- bool is_node = false;
-
- if (a->grad) {
- LM_GGML_ABORT("fatal error");
- is_node = true;
- }
 
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, a->ne[1]);
 
- result->op = LM_GGML_OP_ARGMAX;
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->op = LM_GGML_OP_ARGMAX;
  result->src[0] = a;
 
  return result;
  }
 
- // lm_ggml_repeat
+ // lm_ggml_count_equal
 
- struct lm_ggml_tensor * lm_ggml_repeat(
+ struct lm_ggml_tensor * lm_ggml_count_equal(
  struct lm_ggml_context * ctx,
- struct lm_ggml_tensor * a,
- struct lm_ggml_tensor * b) {
- LM_GGML_ASSERT(lm_ggml_can_repeat(a, b));
-
- bool is_node = false;
+ struct lm_ggml_tensor * a,
+ struct lm_ggml_tensor * b) {
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
 
- if (a->grad) {
- is_node = true;
- }
+ struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I64, 1);
+
+ result->op = LM_GGML_OP_COUNT_EQUAL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+ }
+
+ // lm_ggml_repeat
+
+ struct lm_ggml_tensor * lm_ggml_repeat(
+ struct lm_ggml_context * ctx,
+ struct lm_ggml_tensor * a,
+ struct lm_ggml_tensor * b) {
+ LM_GGML_ASSERT(lm_ggml_can_repeat(a, b));
 
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne);
 
- result->op = LM_GGML_OP_REPEAT;
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
+ result->op = LM_GGML_OP_REPEAT;
  result->src[0] = a;
 
  return result;
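lm_ggml_count_equal() is new in this release: it takes two same-shape tensors and yields a single I64 scalar holding the number of matching entries. A usage sketch, assuming the usual graph helpers from the rest of ggml.h (lm_ggml_new_graph, lm_ggml_build_forward_expand, lm_ggml_graph_compute_with_ctx), which are not part of this hunk:

    #include <stdint.h>
    #include "ggml.h"

    static int64_t count_matches(struct lm_ggml_context * ctx,
                                 struct lm_ggml_tensor  * a,
                                 struct lm_ggml_tensor  * b) {
        // a and b must have the same shape; the op produces one I64 value
        struct lm_ggml_tensor * eq = lm_ggml_count_equal(ctx, a, b);

        struct lm_ggml_cgraph * gf = lm_ggml_new_graph(ctx);
        lm_ggml_build_forward_expand(gf, eq);
        lm_ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

        return ((const int64_t *) eq->data)[0];
    }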
@@ -5274,24 +5274,13 @@ struct lm_ggml_tensor * lm_ggml_repeat(
5274
5274
 
5275
5275
  struct lm_ggml_tensor * lm_ggml_repeat_back(
5276
5276
  struct lm_ggml_context * ctx,
5277
- struct lm_ggml_tensor * a,
5278
- struct lm_ggml_tensor * b) {
5277
+ struct lm_ggml_tensor * a,
5278
+ struct lm_ggml_tensor * b) {
5279
5279
  LM_GGML_ASSERT(lm_ggml_can_repeat(b, a));
5280
5280
 
5281
- bool is_node = false;
5282
-
5283
- if (a->grad) {
5284
- is_node = true;
5285
- }
5286
-
5287
- if (lm_ggml_are_same_shape(a, b) && !is_node) {
5288
- return a;
5289
- }
5290
-
5291
5281
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, b->ne);
5292
5282
 
5293
- result->op = LM_GGML_OP_REPEAT_BACK;
5294
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5283
+ result->op = LM_GGML_OP_REPEAT_BACK;
5295
5284
  result->src[0] = a;
5296
5285
 
5297
5286
  return result;
@@ -5301,9 +5290,9 @@ struct lm_ggml_tensor * lm_ggml_repeat_back(
5301
5290
 
5302
5291
  struct lm_ggml_tensor * lm_ggml_concat(
5303
5292
  struct lm_ggml_context * ctx,
5304
- struct lm_ggml_tensor * a,
5305
- struct lm_ggml_tensor * b,
5306
- int dim) {
5293
+ struct lm_ggml_tensor * a,
5294
+ struct lm_ggml_tensor * b,
5295
+ int dim) {
5307
5296
  LM_GGML_ASSERT(dim >= 0 && dim < LM_GGML_MAX_DIMS);
5308
5297
 
5309
5298
  int64_t ne[LM_GGML_MAX_DIMS];
@@ -5316,19 +5305,11 @@ struct lm_ggml_tensor * lm_ggml_concat(
5316
5305
  ne[d] = a->ne[d];
5317
5306
  }
5318
5307
 
5319
- bool is_node = false;
5320
-
5321
- if (a->grad || b->grad) {
5322
- LM_GGML_ABORT("fatal error"); // TODO: implement
5323
- is_node = true;
5324
- }
5325
-
5326
5308
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, LM_GGML_MAX_DIMS, ne);
5327
5309
 
5328
5310
  lm_ggml_set_op_params_i32(result, 0, dim);
5329
5311
 
5330
- result->op = LM_GGML_OP_CONCAT;
5331
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5312
+ result->op = LM_GGML_OP_CONCAT;
5332
5313
  result->src[0] = a;
5333
5314
  result->src[1] = b;
5334
5315
 
@@ -5437,20 +5418,14 @@ struct lm_ggml_tensor * lm_ggml_relu_inplace(
5437
5418
 
5438
5419
  struct lm_ggml_tensor * lm_ggml_leaky_relu(
5439
5420
  struct lm_ggml_context * ctx,
5440
- struct lm_ggml_tensor * a, float negative_slope, bool inplace) {
5441
- bool is_node = false;
5442
-
5443
- if (!inplace && (a->grad)) {
5444
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
5445
- is_node = true;
5446
- }
5447
-
5421
+ struct lm_ggml_tensor * a,
5422
+ float negative_slope,
5423
+ bool inplace) {
5448
5424
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5449
5425
 
5450
5426
  lm_ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
5451
5427
 
5452
- result->op = LM_GGML_OP_LEAKY_RELU;
5453
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5428
+ result->op = LM_GGML_OP_LEAKY_RELU;
5454
5429
  result->src[0] = a;
5455
5430
 
5456
5431
  return result;
@@ -5518,17 +5493,9 @@ struct lm_ggml_tensor * lm_ggml_silu_back(
5518
5493
  struct lm_ggml_context * ctx,
5519
5494
  struct lm_ggml_tensor * a,
5520
5495
  struct lm_ggml_tensor * b) {
5521
- bool is_node = false;
5522
-
5523
- if (a->grad || b->grad) {
5524
- // TODO: implement backward
5525
- is_node = true;
5526
- }
5527
-
5528
5496
  struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a);
5529
5497
 
5530
- result->op = LM_GGML_OP_SILU_BACK;
5531
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5498
+ result->op = LM_GGML_OP_SILU_BACK;
5532
5499
  result->src[0] = a;
5533
5500
  result->src[1] = b;
5534
5501
 
@@ -5536,6 +5503,7 @@ struct lm_ggml_tensor * lm_ggml_silu_back(
5536
5503
  }
5537
5504
 
5538
5505
  // ggml hardswish
5506
+
5539
5507
  struct lm_ggml_tensor * lm_ggml_hardswish(
5540
5508
  struct lm_ggml_context * ctx,
5541
5509
  struct lm_ggml_tensor * a) {
@@ -5543,6 +5511,7 @@ struct lm_ggml_tensor * lm_ggml_hardswish(
5543
5511
  }
5544
5512
 
5545
5513
  // ggml hardsigmoid
5514
+
5546
5515
  struct lm_ggml_tensor * lm_ggml_hardsigmoid(
5547
5516
  struct lm_ggml_context * ctx,
5548
5517
  struct lm_ggml_tensor * a) {
@@ -5550,6 +5519,7 @@ struct lm_ggml_tensor * lm_ggml_hardsigmoid(
5550
5519
  }
5551
5520
 
5552
5521
  // ggml exp
5522
+
5553
5523
  struct lm_ggml_tensor * lm_ggml_exp(
5554
5524
  struct lm_ggml_context * ctx,
5555
5525
  struct lm_ggml_tensor * a) {
@@ -5567,21 +5537,13 @@ struct lm_ggml_tensor * lm_ggml_exp_inplace(
5567
5537
  static struct lm_ggml_tensor * lm_ggml_norm_impl(
5568
5538
  struct lm_ggml_context * ctx,
5569
5539
  struct lm_ggml_tensor * a,
5570
- float eps,
5571
- bool inplace) {
5572
- bool is_node = false;
5573
-
5574
- if (!inplace && (a->grad)) {
5575
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
5576
- is_node = true;
5577
- }
5578
-
5540
+ float eps,
5541
+ bool inplace) {
5579
5542
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5580
5543
 
5581
5544
  lm_ggml_set_op_params(result, &eps, sizeof(eps));
5582
5545
 
5583
- result->op = LM_GGML_OP_NORM;
5584
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5546
+ result->op = LM_GGML_OP_NORM;
5585
5547
  result->src[0] = a;
5586
5548
 
5587
5549
  return result;
@@ -5590,14 +5552,14 @@ static struct lm_ggml_tensor * lm_ggml_norm_impl(
5590
5552
  struct lm_ggml_tensor * lm_ggml_norm(
5591
5553
  struct lm_ggml_context * ctx,
5592
5554
  struct lm_ggml_tensor * a,
5593
- float eps) {
5555
+ float eps) {
5594
5556
  return lm_ggml_norm_impl(ctx, a, eps, false);
5595
5557
  }
5596
5558
 
5597
5559
  struct lm_ggml_tensor * lm_ggml_norm_inplace(
5598
5560
  struct lm_ggml_context * ctx,
5599
5561
  struct lm_ggml_tensor * a,
5600
- float eps) {
5562
+ float eps) {
5601
5563
  return lm_ggml_norm_impl(ctx, a, eps, true);
5602
5564
  }
5603
5565
 
@@ -5606,20 +5568,13 @@ struct lm_ggml_tensor * lm_ggml_norm_inplace(
5606
5568
  static struct lm_ggml_tensor * lm_ggml_rms_norm_impl(
5607
5569
  struct lm_ggml_context * ctx,
5608
5570
  struct lm_ggml_tensor * a,
5609
- float eps,
5610
- bool inplace) {
5611
- bool is_node = false;
5612
-
5613
- if (!inplace && (a->grad)) {
5614
- is_node = true;
5615
- }
5616
-
5571
+ float eps,
5572
+ bool inplace) {
5617
5573
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5618
5574
 
5619
5575
  lm_ggml_set_op_params(result, &eps, sizeof(eps));
5620
5576
 
5621
- result->op = LM_GGML_OP_RMS_NORM;
5622
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5577
+ result->op = LM_GGML_OP_RMS_NORM;
5623
5578
  result->src[0] = a;
5624
5579
 
5625
5580
  return result;
@@ -5628,14 +5583,14 @@ static struct lm_ggml_tensor * lm_ggml_rms_norm_impl(
5628
5583
  struct lm_ggml_tensor * lm_ggml_rms_norm(
5629
5584
  struct lm_ggml_context * ctx,
5630
5585
  struct lm_ggml_tensor * a,
5631
- float eps) {
5586
+ float eps) {
5632
5587
  return lm_ggml_rms_norm_impl(ctx, a, eps, false);
5633
5588
  }
5634
5589
 
5635
5590
  struct lm_ggml_tensor * lm_ggml_rms_norm_inplace(
5636
5591
  struct lm_ggml_context * ctx,
5637
5592
  struct lm_ggml_tensor * a,
5638
- float eps) {
5593
+ float eps) {
5639
5594
  return lm_ggml_rms_norm_impl(ctx, a, eps, true);
5640
5595
  }
5641
5596
 
@@ -5645,20 +5600,12 @@ struct lm_ggml_tensor * lm_ggml_rms_norm_back(
5645
5600
  struct lm_ggml_context * ctx,
5646
5601
  struct lm_ggml_tensor * a,
5647
5602
  struct lm_ggml_tensor * b,
5648
- float eps) {
5649
- bool is_node = false;
5650
-
5651
- if (a->grad) {
5652
- // TODO: implement backward
5653
- is_node = true;
5654
- }
5655
-
5603
+ float eps) {
5656
5604
  struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a);
5657
5605
 
5658
5606
  lm_ggml_set_op_params(result, &eps, sizeof(eps));
5659
5607
 
5660
- result->op = LM_GGML_OP_RMS_NORM_BACK;
5661
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5608
+ result->op = LM_GGML_OP_RMS_NORM_BACK;
5662
5609
  result->src[0] = a;
5663
5610
  result->src[1] = b;
5664
5611
 
@@ -5668,43 +5615,35 @@ struct lm_ggml_tensor * lm_ggml_rms_norm_back(
5668
5615
  // lm_ggml_group_norm
5669
5616
 
5670
5617
  static struct lm_ggml_tensor * lm_ggml_group_norm_impl(
5671
- struct lm_ggml_context * ctx,
5672
- struct lm_ggml_tensor * a,
5673
- int n_groups,
5674
- float eps,
5675
- bool inplace) {
5676
-
5677
- bool is_node = false;
5678
- if (!inplace && (a->grad)) {
5679
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
5680
- is_node = true;
5681
- }
5682
-
5618
+ struct lm_ggml_context * ctx,
5619
+ struct lm_ggml_tensor * a,
5620
+ int n_groups,
5621
+ float eps,
5622
+ bool inplace) {
5683
5623
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5684
5624
 
5685
5625
  lm_ggml_set_op_params_i32(result, 0, n_groups);
5686
5626
  lm_ggml_set_op_params_f32(result, 1, eps);
5687
5627
 
5688
- result->op = LM_GGML_OP_GROUP_NORM;
5689
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5628
+ result->op = LM_GGML_OP_GROUP_NORM;
5690
5629
  result->src[0] = a;
5691
5630
 
5692
5631
  return result;
5693
5632
  }
5694
5633
 
5695
5634
  struct lm_ggml_tensor * lm_ggml_group_norm(
5696
- struct lm_ggml_context * ctx,
5697
- struct lm_ggml_tensor * a,
5698
- int n_groups,
5699
- float eps) {
5635
+ struct lm_ggml_context * ctx,
5636
+ struct lm_ggml_tensor * a,
5637
+ int n_groups,
5638
+ float eps) {
5700
5639
  return lm_ggml_group_norm_impl(ctx, a, n_groups, eps, false);
5701
5640
  }
5702
5641
 
5703
5642
  struct lm_ggml_tensor * lm_ggml_group_norm_inplace(
5704
- struct lm_ggml_context * ctx,
5705
- struct lm_ggml_tensor * a,
5706
- int n_groups,
5707
- float eps) {
5643
+ struct lm_ggml_context * ctx,
5644
+ struct lm_ggml_tensor * a,
5645
+ int n_groups,
5646
+ float eps) {
5708
5647
  return lm_ggml_group_norm_impl(ctx, a, n_groups, eps, true);
5709
5648
  }
5710
5649
 
@@ -5717,17 +5656,10 @@ struct lm_ggml_tensor * lm_ggml_mul_mat(
5717
5656
  LM_GGML_ASSERT(lm_ggml_can_mul_mat(a, b));
5718
5657
  LM_GGML_ASSERT(!lm_ggml_is_transposed(a));
5719
5658
 
5720
- bool is_node = false;
5721
-
5722
- if (a->grad || b->grad) {
5723
- is_node = true;
5724
- }
5725
-
5726
5659
  const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
5727
5660
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
5728
5661
 
5729
- result->op = LM_GGML_OP_MUL_MAT;
5730
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5662
+ result->op = LM_GGML_OP_MUL_MAT;
5731
5663
  result->src[0] = a;
5732
5664
  result->src[1] = b;
5733
5665
 
@@ -5773,17 +5705,10 @@ struct lm_ggml_tensor * lm_ggml_mul_mat_id(
5773
5705
  LM_GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
5774
5706
  LM_GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
5775
5707
 
5776
- bool is_node = false;
5777
-
5778
- if (as->grad || b->grad) {
5779
- is_node = true;
5780
- }
5781
-
5782
5708
  const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
5783
5709
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
5784
5710
 
5785
- result->op = LM_GGML_OP_MUL_MAT_ID;
5786
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5711
+ result->op = LM_GGML_OP_MUL_MAT_ID;
5787
5712
  result->src[0] = as;
5788
5713
  result->src[1] = b;
5789
5714
  result->src[2] = ids;
@@ -5800,18 +5725,11 @@ struct lm_ggml_tensor * lm_ggml_out_prod(
5800
5725
  LM_GGML_ASSERT(lm_ggml_can_out_prod(a, b));
5801
5726
  LM_GGML_ASSERT(!lm_ggml_is_transposed(a));
5802
5727
 
5803
- bool is_node = false;
5804
-
5805
- if (a->grad || b->grad) {
5806
- is_node = true;
5807
- }
5808
-
5809
5728
  // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3]
5810
5729
  const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] };
5811
5730
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
5812
5731
 
5813
- result->op = LM_GGML_OP_OUT_PROD;
5814
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5732
+ result->op = LM_GGML_OP_OUT_PROD;
5815
5733
  result->src[0] = a;
5816
5734
  result->src[1] = b;
5817
5735
 
@@ -5824,21 +5742,14 @@ static struct lm_ggml_tensor * lm_ggml_scale_impl(
5824
5742
  struct lm_ggml_context * ctx,
5825
5743
  struct lm_ggml_tensor * a,
5826
5744
  float s,
5827
- bool inplace) {
5745
+ bool inplace) {
5828
5746
  LM_GGML_ASSERT(lm_ggml_is_padded_1d(a));
5829
5747
 
5830
- bool is_node = false;
5831
-
5832
- if (a->grad) {
5833
- is_node = true;
5834
- }
5835
-
5836
5748
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5837
5749
 
5838
5750
  lm_ggml_set_op_params(result, &s, sizeof(s));
5839
5751
 
5840
- result->op = LM_GGML_OP_SCALE;
5841
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5752
+ result->op = LM_GGML_OP_SCALE;
5842
5753
  result->src[0] = a;
5843
5754
 
5844
5755
  return result;
@@ -5846,15 +5757,15 @@ static struct lm_ggml_tensor * lm_ggml_scale_impl(
5846
5757
 
5847
5758
  struct lm_ggml_tensor * lm_ggml_scale(
5848
5759
  struct lm_ggml_context * ctx,
5849
- struct lm_ggml_tensor * a,
5850
- float s) {
5760
+ struct lm_ggml_tensor * a,
5761
+ float s) {
5851
5762
  return lm_ggml_scale_impl(ctx, a, s, false);
5852
5763
  }
5853
5764
 
5854
5765
  struct lm_ggml_tensor * lm_ggml_scale_inplace(
5855
5766
  struct lm_ggml_context * ctx,
5856
- struct lm_ggml_tensor * a,
5857
- float s) {
5767
+ struct lm_ggml_tensor * a,
5768
+ float s) {
5858
5769
  return lm_ggml_scale_impl(ctx, a, s, true);
5859
5770
  }
5860
5771
 
@@ -5868,15 +5779,9 @@ static struct lm_ggml_tensor * lm_ggml_set_impl(
5868
5779
  size_t nb2,
5869
5780
  size_t nb3,
5870
5781
  size_t offset,
5871
- bool inplace) {
5782
+ bool inplace) {
5872
5783
  LM_GGML_ASSERT(lm_ggml_nelements(a) >= lm_ggml_nelements(b));
5873
5784
 
5874
- bool is_node = false;
5875
-
5876
- if (a->grad || b->grad) {
5877
- is_node = true;
5878
- }
5879
-
5880
5785
  // make a view of the destination
5881
5786
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
5882
5787
 
@@ -5884,8 +5789,7 @@ static struct lm_ggml_tensor * lm_ggml_set_impl(
5884
5789
  int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
5885
5790
  lm_ggml_set_op_params(result, params, sizeof(params));
5886
5791
 
5887
- result->op = LM_GGML_OP_SET;
5888
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5792
+ result->op = LM_GGML_OP_SET;
5889
5793
  result->src[0] = a;
5890
5794
  result->src[1] = b;
5891
5795
 
@@ -5894,8 +5798,8 @@ static struct lm_ggml_tensor * lm_ggml_set_impl(
5894
5798
 
5895
5799
  struct lm_ggml_tensor * lm_ggml_set(
5896
5800
  struct lm_ggml_context * ctx,
5897
- struct lm_ggml_tensor * a,
5898
- struct lm_ggml_tensor * b,
5801
+ struct lm_ggml_tensor * a,
5802
+ struct lm_ggml_tensor * b,
5899
5803
  size_t nb1,
5900
5804
  size_t nb2,
5901
5805
  size_t nb3,
@@ -5905,8 +5809,8 @@ struct lm_ggml_tensor * lm_ggml_set(
5905
5809
 
5906
5810
  struct lm_ggml_tensor * lm_ggml_set_inplace(
5907
5811
  struct lm_ggml_context * ctx,
5908
- struct lm_ggml_tensor * a,
5909
- struct lm_ggml_tensor * b,
5812
+ struct lm_ggml_tensor * a,
5813
+ struct lm_ggml_tensor * b,
5910
5814
  size_t nb1,
5911
5815
  size_t nb2,
5912
5816
  size_t nb3,
@@ -5916,24 +5820,24 @@ struct lm_ggml_tensor * lm_ggml_set_inplace(
5916
5820
 
5917
5821
  struct lm_ggml_tensor * lm_ggml_set_1d(
5918
5822
  struct lm_ggml_context * ctx,
5919
- struct lm_ggml_tensor * a,
5920
- struct lm_ggml_tensor * b,
5823
+ struct lm_ggml_tensor * a,
5824
+ struct lm_ggml_tensor * b,
5921
5825
  size_t offset) {
5922
5826
  return lm_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, false);
5923
5827
  }
5924
5828
 
5925
5829
  struct lm_ggml_tensor * lm_ggml_set_1d_inplace(
5926
5830
  struct lm_ggml_context * ctx,
5927
- struct lm_ggml_tensor * a,
5928
- struct lm_ggml_tensor * b,
5831
+ struct lm_ggml_tensor * a,
5832
+ struct lm_ggml_tensor * b,
5929
5833
  size_t offset) {
5930
5834
  return lm_ggml_set_impl(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], offset, true);
5931
5835
  }
5932
5836
 
5933
5837
  struct lm_ggml_tensor * lm_ggml_set_2d(
5934
5838
  struct lm_ggml_context * ctx,
5935
- struct lm_ggml_tensor * a,
5936
- struct lm_ggml_tensor * b,
5839
+ struct lm_ggml_tensor * a,
5840
+ struct lm_ggml_tensor * b,
5937
5841
  size_t nb1,
5938
5842
  size_t offset) {
5939
5843
  return lm_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false);
@@ -5941,8 +5845,8 @@ struct lm_ggml_tensor * lm_ggml_set_2d(
5941
5845
 
5942
5846
  struct lm_ggml_tensor * lm_ggml_set_2d_inplace(
5943
5847
  struct lm_ggml_context * ctx,
5944
- struct lm_ggml_tensor * a,
5945
- struct lm_ggml_tensor * b,
5848
+ struct lm_ggml_tensor * a,
5849
+ struct lm_ggml_tensor * b,
5946
5850
  size_t nb1,
5947
5851
  size_t offset) {
5948
5852
  return lm_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, true);
@@ -5956,13 +5860,6 @@ static struct lm_ggml_tensor * lm_ggml_cpy_impl(
5956
5860
  struct lm_ggml_tensor * b) {
5957
5861
  LM_GGML_ASSERT(lm_ggml_nelements(a) == lm_ggml_nelements(b));
5958
5862
 
5959
- bool is_node = false;
5960
-
5961
- if (a->grad || b->grad) {
5962
- // inplace is false and either one have a grad
5963
- is_node = true;
5964
- }
5965
-
5966
5863
  // make a view of the destination
5967
5864
  struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, b);
5968
5865
  if (strlen(b->name) > 0) {
@@ -5971,8 +5868,7 @@ static struct lm_ggml_tensor * lm_ggml_cpy_impl(
5971
5868
  lm_ggml_format_name(result, "%s (copy)", a->name);
5972
5869
  }
5973
5870
 
5974
- result->op = LM_GGML_OP_CPY;
5975
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5871
+ result->op = LM_GGML_OP_CPY;
5976
5872
  result->src[0] = a;
5977
5873
  result->src[1] = b;
5978
5874
 
@@ -5990,13 +5886,10 @@ struct lm_ggml_tensor * lm_ggml_cast(
5990
5886
  struct lm_ggml_context * ctx,
5991
5887
  struct lm_ggml_tensor * a,
5992
5888
  enum lm_ggml_type type) {
5993
- bool is_node = false;
5994
-
5995
5889
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, type, LM_GGML_MAX_DIMS, a->ne);
5996
5890
  lm_ggml_format_name(result, "%s (copy)", a->name);
5997
5891
 
5998
- result->op = LM_GGML_OP_CPY;
5999
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5892
+ result->op = LM_GGML_OP_CPY;
6000
5893
  result->src[0] = a;
6001
5894
  result->src[1] = result;
6002
5895
 
@@ -6008,17 +5901,10 @@ struct lm_ggml_tensor * lm_ggml_cast(
6008
5901
  static struct lm_ggml_tensor * lm_ggml_cont_impl(
6009
5902
  struct lm_ggml_context * ctx,
6010
5903
  struct lm_ggml_tensor * a) {
6011
- bool is_node = false;
6012
-
6013
- if (a->grad) {
6014
- is_node = true;
6015
- }
6016
-
6017
5904
  struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a);
6018
5905
  lm_ggml_format_name(result, "%s (cont)", a->name);
6019
5906
 
6020
- result->op = LM_GGML_OP_CONT;
6021
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5907
+ result->op = LM_GGML_OP_CONT;
6022
5908
  result->src[0] = a;
6023
5909
 
6024
5910
  return result;
@@ -6064,13 +5950,10 @@ struct lm_ggml_tensor * lm_ggml_cont_4d(
6064
5950
  int64_t ne3) {
6065
5951
  LM_GGML_ASSERT(lm_ggml_nelements(a) == (ne0*ne1*ne2*ne3));
6066
5952
 
6067
- bool is_node = false;
6068
-
6069
5953
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
6070
5954
  lm_ggml_format_name(result, "%s (cont)", a->name);
6071
5955
 
6072
- result->op = LM_GGML_OP_CONT;
6073
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5956
+ result->op = LM_GGML_OP_CONT;
6074
5957
  result->src[0] = a;
6075
5958
 
6076
5959
  return result;
@@ -6086,22 +5969,10 @@ struct lm_ggml_tensor * lm_ggml_reshape(
6086
5969
  // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
6087
5970
  LM_GGML_ASSERT(lm_ggml_nelements(a) == lm_ggml_nelements(b));
6088
5971
 
6089
- bool is_node = false;
6090
-
6091
- if (a->grad) {
6092
- is_node = true;
6093
- }
6094
-
6095
- if (b->grad) {
6096
- // gradient propagation is not supported
6097
- //LM_GGML_ABORT("fatal error");
6098
- }
6099
-
6100
5972
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, LM_GGML_MAX_DIMS, b->ne, a, 0);
6101
5973
  lm_ggml_format_name(result, "%s (reshaped)", a->name);
6102
5974
 
6103
- result->op = LM_GGML_OP_RESHAPE;
6104
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5975
+ result->op = LM_GGML_OP_RESHAPE;
6105
5976
  result->src[0] = a;
6106
5977
 
6107
5978
  return result;
@@ -6114,18 +5985,11 @@ struct lm_ggml_tensor * lm_ggml_reshape_1d(
6114
5985
  LM_GGML_ASSERT(lm_ggml_is_contiguous(a));
6115
5986
  LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0);
6116
5987
 
6117
- bool is_node = false;
6118
-
6119
- if (a->grad) {
6120
- is_node = true;
6121
- }
6122
-
6123
5988
  const int64_t ne[1] = { ne0 };
6124
5989
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 1, ne, a, 0);
6125
5990
  lm_ggml_format_name(result, "%s (reshaped)", a->name);
6126
5991
 
6127
- result->op = LM_GGML_OP_RESHAPE;
6128
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
5992
+ result->op = LM_GGML_OP_RESHAPE;
6129
5993
  result->src[0] = a;
6130
5994
 
6131
5995
  return result;
@@ -6139,18 +6003,11 @@ struct lm_ggml_tensor * lm_ggml_reshape_2d(
6139
6003
  LM_GGML_ASSERT(lm_ggml_is_contiguous(a));
6140
6004
  LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1);
6141
6005
 
6142
- bool is_node = false;
6143
-
6144
- if (a->grad) {
6145
- is_node = true;
6146
- }
6147
-
6148
6006
  const int64_t ne[2] = { ne0, ne1 };
6149
6007
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 2, ne, a, 0);
6150
6008
  lm_ggml_format_name(result, "%s (reshaped)", a->name);
6151
6009
 
6152
- result->op = LM_GGML_OP_RESHAPE;
6153
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6010
+ result->op = LM_GGML_OP_RESHAPE;
6154
6011
  result->src[0] = a;
6155
6012
 
6156
6013
  return result;
@@ -6165,18 +6022,11 @@ struct lm_ggml_tensor * lm_ggml_reshape_3d(
6165
6022
  LM_GGML_ASSERT(lm_ggml_is_contiguous(a));
6166
6023
  LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1*ne2);
6167
6024
 
6168
- bool is_node = false;
6169
-
6170
- if (a->grad) {
6171
- is_node = true;
6172
- }
6173
-
6174
6025
  const int64_t ne[3] = { ne0, ne1, ne2 };
6175
6026
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 3, ne, a, 0);
6176
6027
  lm_ggml_format_name(result, "%s (reshaped)", a->name);
6177
6028
 
6178
- result->op = LM_GGML_OP_RESHAPE;
6179
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6029
+ result->op = LM_GGML_OP_RESHAPE;
6180
6030
  result->src[0] = a;
6181
6031
 
6182
6032
  return result;
@@ -6192,18 +6042,11 @@ struct lm_ggml_tensor * lm_ggml_reshape_4d(
6192
6042
  LM_GGML_ASSERT(lm_ggml_is_contiguous(a));
6193
6043
  LM_GGML_ASSERT(lm_ggml_nelements(a) == ne0*ne1*ne2*ne3);
6194
6044
 
6195
- bool is_node = false;
6196
-
6197
- if (a->grad) {
6198
- is_node = true;
6199
- }
6200
-
6201
6045
  const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
6202
6046
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, 4, ne, a, 0);
6203
6047
  lm_ggml_format_name(result, "%s (reshaped)", a->name);
6204
6048
 
6205
- result->op = LM_GGML_OP_RESHAPE;
6206
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6049
+ result->op = LM_GGML_OP_RESHAPE;
6207
6050
  result->src[0] = a;
6208
6051
 
6209
6052
  return result;
@@ -6215,20 +6058,12 @@ static struct lm_ggml_tensor * lm_ggml_view_impl(
6215
6058
  int n_dims,
6216
6059
  const int64_t * ne,
6217
6060
  size_t offset) {
6218
-
6219
- bool is_node = false;
6220
-
6221
- if (a->grad) {
6222
- is_node = true;
6223
- }
6224
-
6225
6061
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_impl(ctx, a->type, n_dims, ne, a, offset);
6226
6062
  lm_ggml_format_name(result, "%s (view)", a->name);
6227
6063
 
6228
6064
  lm_ggml_set_op_params(result, &offset, sizeof(offset));
6229
6065
 
6230
- result->op = LM_GGML_OP_VIEW;
6231
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6066
+ result->op = LM_GGML_OP_VIEW;
6232
6067
  result->src[0] = a;
6233
6068
 
6234
6069
  return result;
@@ -6241,7 +6076,6 @@ struct lm_ggml_tensor * lm_ggml_view_1d(
6241
6076
  struct lm_ggml_tensor * a,
6242
6077
  int64_t ne0,
6243
6078
  size_t offset) {
6244
-
6245
6079
  struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 1, &ne0, offset);
6246
6080
 
6247
6081
  return result;
@@ -6256,7 +6090,6 @@ struct lm_ggml_tensor * lm_ggml_view_2d(
6256
6090
  int64_t ne1,
6257
6091
  size_t nb1,
6258
6092
  size_t offset) {
6259
-
6260
6093
  const int64_t ne[2] = { ne0, ne1 };
6261
6094
 
6262
6095
  struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 2, ne, offset);
@@ -6279,7 +6112,6 @@ struct lm_ggml_tensor * lm_ggml_view_3d(
6279
6112
  size_t nb1,
6280
6113
  size_t nb2,
6281
6114
  size_t offset) {
6282
-
6283
6115
  const int64_t ne[3] = { ne0, ne1, ne2 };
6284
6116
 
6285
6117
  struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 3, ne, offset);
@@ -6304,7 +6136,6 @@ struct lm_ggml_tensor * lm_ggml_view_4d(
6304
6136
  size_t nb2,
6305
6137
  size_t nb3,
6306
6138
  size_t offset) {
6307
-
6308
6139
  const int64_t ne[4] = { ne0, ne1, ne2, ne3 };
6309
6140
 
6310
6141
  struct lm_ggml_tensor * result = lm_ggml_view_impl(ctx, a, 4, ne, offset);
@@ -6337,12 +6168,6 @@ struct lm_ggml_tensor * lm_ggml_permute(
6337
6168
  LM_GGML_ASSERT(axis1 != axis3);
6338
6169
  LM_GGML_ASSERT(axis2 != axis3);
6339
6170
 
6340
- bool is_node = false;
6341
-
6342
- if (a->grad) {
6343
- is_node = true;
6344
- }
6345
-
6346
6171
  struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
6347
6172
  lm_ggml_format_name(result, "%s (permuted)", a->name);
6348
6173
 
@@ -6369,8 +6194,7 @@ struct lm_ggml_tensor * lm_ggml_permute(
6369
6194
  result->nb[2] = nb[2];
6370
6195
  result->nb[3] = nb[3];
6371
6196
 
6372
- result->op = LM_GGML_OP_PERMUTE;
6373
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6197
+ result->op = LM_GGML_OP_PERMUTE;
6374
6198
  result->src[0] = a;
6375
6199
 
6376
6200
  int32_t params[] = { axis0, axis1, axis2, axis3 };
@@ -6384,12 +6208,6 @@ struct lm_ggml_tensor * lm_ggml_permute(
6384
6208
  struct lm_ggml_tensor * lm_ggml_transpose(
6385
6209
  struct lm_ggml_context * ctx,
6386
6210
  struct lm_ggml_tensor * a) {
6387
- bool is_node = false;
6388
-
6389
- if (a->grad) {
6390
- is_node = true;
6391
- }
6392
-
6393
6211
  struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
6394
6212
  lm_ggml_format_name(result, "%s (transposed)", a->name);
6395
6213
 
@@ -6399,8 +6217,7 @@ struct lm_ggml_tensor * lm_ggml_transpose(
6399
6217
  result->nb[0] = a->nb[1];
6400
6218
  result->nb[1] = a->nb[0];
6401
6219
 
6402
- result->op = LM_GGML_OP_TRANSPOSE;
6403
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6220
+ result->op = LM_GGML_OP_TRANSPOSE;
6404
6221
  result->src[0] = a;
6405
6222
 
6406
6223
  return result;
@@ -6416,12 +6233,6 @@ struct lm_ggml_tensor * lm_ggml_get_rows(
6416
6233
  LM_GGML_ASSERT(b->ne[3] == 1);
6417
6234
  LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32);
6418
6235
 
6419
- bool is_node = false;
6420
-
6421
- if (a->grad || b->grad) {
6422
- is_node = true;
6423
- }
6424
-
6425
6236
  // TODO: implement non F32 return
6426
6237
  enum lm_ggml_type type = LM_GGML_TYPE_F32;
6427
6238
  if (a->type == LM_GGML_TYPE_I32) {
@@ -6429,8 +6240,7 @@ struct lm_ggml_tensor * lm_ggml_get_rows(
6429
6240
  }
6430
6241
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, type, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
6431
6242
 
6432
- result->op = LM_GGML_OP_GET_ROWS;
6433
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6243
+ result->op = LM_GGML_OP_GET_ROWS;
6434
6244
  result->src[0] = a;
6435
6245
  result->src[1] = b;
6436
6246
 
@@ -6447,18 +6257,11 @@ struct lm_ggml_tensor * lm_ggml_get_rows_back(
6447
6257
  LM_GGML_ASSERT(lm_ggml_is_matrix(a) && lm_ggml_is_vector(b) && b->type == LM_GGML_TYPE_I32);
6448
6258
  LM_GGML_ASSERT(lm_ggml_is_matrix(c) && (a->ne[0] == c->ne[0]));
6449
6259
 
6450
- bool is_node = false;
6451
-
6452
- if (a->grad || b->grad) {
6453
- is_node = true;
6454
- }
6455
-
6456
6260
  // TODO: implement non F32 return
6457
6261
  //struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
6458
6262
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, c->ne[0], c->ne[1]);
6459
6263
 
6460
- result->op = LM_GGML_OP_GET_ROWS_BACK;
6461
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6264
+ result->op = LM_GGML_OP_GET_ROWS_BACK;
6462
6265
  result->src[0] = a;
6463
6266
  result->src[1] = b;
6464
6267
 
@@ -6471,17 +6274,11 @@ struct lm_ggml_tensor * lm_ggml_diag(
6471
6274
  struct lm_ggml_context * ctx,
6472
6275
  struct lm_ggml_tensor * a) {
6473
6276
  LM_GGML_ASSERT(a->ne[1] == 1);
6474
- bool is_node = false;
6475
-
6476
- if (a->grad) {
6477
- is_node = true;
6478
- }
6479
6277
 
6480
6278
  const int64_t ne[4] = { a->ne[0], a->ne[0], a->ne[2], a->ne[3] };
6481
6279
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, a->type, 4, ne);
6482
6280
 
6483
- result->op = LM_GGML_OP_DIAG;
6484
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6281
+ result->op = LM_GGML_OP_DIAG;
6485
6282
  result->src[0] = a;
6486
6283
 
6487
6284
  return result;
@@ -6494,19 +6291,12 @@ static struct lm_ggml_tensor * lm_ggml_diag_mask_inf_impl(
6494
6291
  struct lm_ggml_tensor * a,
6495
6292
  int n_past,
6496
6293
  bool inplace) {
6497
- bool is_node = false;
6498
-
6499
- if (a->grad) {
6500
- is_node = true;
6501
- }
6502
-
6503
6294
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
6504
6295
 
6505
6296
  int32_t params[] = { n_past };
6506
6297
  lm_ggml_set_op_params(result, params, sizeof(params));
6507
6298
 
6508
- result->op = LM_GGML_OP_DIAG_MASK_INF;
6509
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6299
+ result->op = LM_GGML_OP_DIAG_MASK_INF;
6510
6300
  result->src[0] = a;
6511
6301
 
6512
6302
  return result;
@@ -6533,19 +6323,12 @@ static struct lm_ggml_tensor * lm_ggml_diag_mask_zero_impl(
6533
6323
  struct lm_ggml_tensor * a,
6534
6324
  int n_past,
6535
6325
  bool inplace) {
6536
- bool is_node = false;
6537
-
6538
- if (a->grad) {
6539
- is_node = true;
6540
- }
6541
-
6542
6326
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
6543
6327
 
6544
6328
  int32_t params[] = { n_past };
6545
6329
  lm_ggml_set_op_params(result, params, sizeof(params));
6546
6330
 
6547
- result->op = LM_GGML_OP_DIAG_MASK_ZERO;
6548
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6331
+ result->op = LM_GGML_OP_DIAG_MASK_ZERO;
6549
6332
  result->src[0] = a;
6550
6333
 
6551
6334
  return result;
@@ -6588,19 +6371,12 @@ static struct lm_ggml_tensor * lm_ggml_soft_max_impl(
6588
6371
  LM_GGML_ASSERT(mask);
6589
6372
  }
6590
6373
 
6591
- bool is_node = false;
6592
-
6593
- if (a->grad) {
6594
- is_node = true;
6595
- }
6596
-
6597
6374
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
6598
6375
 
6599
6376
  float params[] = { scale, max_bias };
6600
6377
  lm_ggml_set_op_params(result, params, sizeof(params));
6601
6378
 
6602
- result->op = LM_GGML_OP_SOFT_MAX;
6603
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6379
+ result->op = LM_GGML_OP_SOFT_MAX;
6604
6380
  result->src[0] = a;
6605
6381
  result->src[1] = mask;
6606
6382
 
@@ -6635,16 +6411,9 @@ static struct lm_ggml_tensor * lm_ggml_soft_max_back_impl(
6635
6411
  struct lm_ggml_tensor * a,
6636
6412
  struct lm_ggml_tensor * b,
6637
6413
  bool inplace) {
6638
- bool is_node = false;
6639
-
6640
- if (a->grad || b->grad) {
6641
- is_node = true; // TODO : implement backward pass
6642
- }
6643
-
6644
6414
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
6645
6415
 
6646
- result->op = LM_GGML_OP_SOFT_MAX_BACK;
6647
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6416
+ result->op = LM_GGML_OP_SOFT_MAX_BACK;
6648
6417
  result->src[0] = a;
6649
6418
  result->src[1] = b;
6650
6419
 
@@ -6693,12 +6462,6 @@ static struct lm_ggml_tensor * lm_ggml_rope_impl(
6693
6462
  LM_GGML_ASSERT(c->ne[0] >= n_dims / 2);
6694
6463
  }
6695
6464
 
6696
- bool is_node = false;
6697
-
6698
- if (a->grad) {
6699
- is_node = true;
6700
- }
6701
-
6702
6465
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
6703
6466
 
6704
6467
  int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -6710,8 +6473,7 @@ static struct lm_ggml_tensor * lm_ggml_rope_impl(
6710
6473
  memcpy(params + 10, &beta_slow, sizeof(float));
6711
6474
  lm_ggml_set_op_params(result, params, sizeof(params));
6712
6475
 
6713
- result->op = LM_GGML_OP_ROPE;
6714
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6476
+ result->op = LM_GGML_OP_ROPE;
6715
6477
  result->src[0] = a;
6716
6478
  result->src[1] = b;
6717
6479
  result->src[2] = c;
@@ -6839,13 +6601,6 @@ struct lm_ggml_tensor * lm_ggml_rope_back(
6839
6601
  LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32);
6840
6602
  LM_GGML_ASSERT(a->ne[2] == b->ne[0]);
6841
6603
 
6842
- bool is_node = false;
6843
-
6844
- if (a->grad) {
6845
- LM_GGML_ASSERT(false && "backwards pass not implemented");
6846
- is_node = false;
6847
- }
6848
-
6849
6604
  struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a);
6850
6605
 
6851
6606
  int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig };
@@ -6857,8 +6612,7 @@ struct lm_ggml_tensor * lm_ggml_rope_back(
6857
6612
  memcpy(params + 10, &beta_slow, sizeof(float));
6858
6613
  lm_ggml_set_op_params(result, params, sizeof(params));
6859
6614
 
6860
- result->op = LM_GGML_OP_ROPE_BACK;
6861
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6615
+ result->op = LM_GGML_OP_ROPE_BACK;
6862
6616
  result->src[0] = a;
6863
6617
  result->src[1] = b;
6864
6618
  result->src[2] = c;
@@ -6873,21 +6627,13 @@ struct lm_ggml_tensor * lm_ggml_clamp(
6873
6627
  struct lm_ggml_tensor * a,
6874
6628
  float min,
6875
6629
  float max) {
6876
- bool is_node = false;
6877
-
6878
- if (a->grad) {
6879
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
6880
- is_node = true;
6881
- }
6882
-
6883
6630
  // TODO: when implement backward, fix this:
6884
6631
  struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
6885
6632
 
6886
6633
  float params[] = { min, max };
6887
6634
  lm_ggml_set_op_params(result, params, sizeof(params));
6888
6635
 
6889
- result->op = LM_GGML_OP_CLAMP;
6890
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6636
+ result->op = LM_GGML_OP_CLAMP;
6891
6637
  result->src[0] = a;
6892
6638
 
6893
6639
  return result;
@@ -6949,13 +6695,6 @@ LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d(
6949
6695
  LM_GGML_ASSERT(p0 == 0);
6950
6696
  LM_GGML_ASSERT(d0 == 1);
6951
6697
 
6952
- bool is_node = false;
6953
-
6954
- if (a->grad || b->grad) {
6955
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
6956
- is_node = true;
6957
- }
6958
-
6959
6698
  const int64_t ne[4] = {
6960
6699
  lm_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
6961
6700
  a->ne[1], b->ne[2], 1,
@@ -6965,8 +6704,7 @@ LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d(
6965
6704
  int32_t params[] = { s0, p0, d0 };
6966
6705
  lm_ggml_set_op_params(result, params, sizeof(params));
6967
6706
 
6968
- result->op = LM_GGML_OP_CONV_TRANSPOSE_1D;
6969
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6707
+ result->op = LM_GGML_OP_CONV_TRANSPOSE_1D;
6970
6708
  result->src[0] = a;
6971
6709
  result->src[1] = b;
6972
6710
 
@@ -6974,17 +6712,17 @@ LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d(
6974
6712
  }
6975
6713
 
6976
6714
  // lm_ggml_conv_depthwise
6977
- struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d(
6978
- struct lm_ggml_context * ctx,
6979
- struct lm_ggml_tensor * a,
6980
- struct lm_ggml_tensor * b,
6981
- int s0,
6982
- int s1,
6983
- int p0,
6984
- int p1,
6985
- int d0,
6986
- int d1) {
6987
6715
 
6716
+ struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d(
6717
+ struct lm_ggml_context * ctx,
6718
+ struct lm_ggml_tensor * a,
6719
+ struct lm_ggml_tensor * b,
6720
+ int s0,
6721
+ int s1,
6722
+ int p0,
6723
+ int p1,
6724
+ int d0,
6725
+ int d1) {
6988
6726
  struct lm_ggml_tensor * new_a = lm_ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
6989
6727
  struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, new_a,
6990
6728
  lm_ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
@@ -7004,29 +6742,23 @@ struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d(
7004
6742
  // b: [N, IC, IH, IW]
7005
6743
  // result: [N, OH, OW, IC*KH*KW]
7006
6744
  struct lm_ggml_tensor * lm_ggml_im2col(
7007
- struct lm_ggml_context * ctx,
7008
- struct lm_ggml_tensor * a,
7009
- struct lm_ggml_tensor * b,
7010
- int s0,
7011
- int s1,
7012
- int p0,
7013
- int p1,
7014
- int d0,
7015
- int d1,
7016
- bool is_2D,
7017
- enum lm_ggml_type dst_type) {
7018
-
6745
+ struct lm_ggml_context * ctx,
6746
+ struct lm_ggml_tensor * a,
6747
+ struct lm_ggml_tensor * b,
6748
+ int s0,
6749
+ int s1,
6750
+ int p0,
6751
+ int p1,
6752
+ int d0,
6753
+ int d1,
6754
+ bool is_2D,
6755
+ enum lm_ggml_type dst_type) {
7019
6756
  if(is_2D) {
7020
6757
  LM_GGML_ASSERT(a->ne[2] == b->ne[2]);
7021
6758
  } else {
7022
6759
  LM_GGML_ASSERT(a->ne[1] == b->ne[1]);
7023
6760
  LM_GGML_ASSERT(b->ne[3] == 1);
7024
6761
  }
7025
- bool is_node = false;
7026
-
7027
- if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
7028
- is_node = true;
7029
- }
7030
6762
 
7031
6763
  const int64_t OH = is_2D ? lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0;
7032
6764
  const int64_t OW = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0);
@@ -7045,8 +6777,7 @@ struct lm_ggml_tensor * lm_ggml_im2col(
7045
6777
  int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
7046
6778
  lm_ggml_set_op_params(result, params, sizeof(params));
7047
6779
 
7048
- result->op = LM_GGML_OP_IM2COL;
7049
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6780
+ result->op = LM_GGML_OP_IM2COL;
7050
6781
  result->src[0] = a;
7051
6782
  result->src[1] = b;
7052
6783
 
@@ -7054,30 +6785,22 @@ struct lm_ggml_tensor * lm_ggml_im2col(
7054
6785
  }
7055
6786
 
7056
6787
  struct lm_ggml_tensor * lm_ggml_im2col_back(
7057
- struct lm_ggml_context * ctx,
7058
- struct lm_ggml_tensor * a,
7059
- struct lm_ggml_tensor * b,
7060
- int64_t * ne,
7061
- int s0,
7062
- int s1,
7063
- int p0,
7064
- int p1,
7065
- int d0,
7066
- int d1,
7067
- bool is_2D) {
7068
-
7069
- bool is_node = false;
7070
-
7071
- if (/*a->grad ||*/ b->grad) { // a is only used for its shape, not its data
7072
- is_node = true;
7073
- }
7074
-
6788
+ struct lm_ggml_context * ctx,
6789
+ struct lm_ggml_tensor * a,
6790
+ struct lm_ggml_tensor * b,
6791
+ int64_t * ne,
6792
+ int s0,
6793
+ int s1,
6794
+ int p0,
6795
+ int p1,
6796
+ int d0,
6797
+ int d1,
6798
+ bool is_2D) {
7075
6799
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
7076
6800
  int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) };
7077
6801
  lm_ggml_set_op_params(result, params, sizeof(params));
7078
6802
 
7079
- result->op = LM_GGML_OP_IM2COL_BACK;
7080
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6803
+ result->op = LM_GGML_OP_IM2COL_BACK;
7081
6804
  result->src[0] = a;
7082
6805
  result->src[1] = b;
7083
6806
 
@@ -7091,12 +6814,12 @@ struct lm_ggml_tensor * lm_ggml_conv_2d(
7091
6814
  struct lm_ggml_context * ctx,
7092
6815
  struct lm_ggml_tensor * a,
7093
6816
  struct lm_ggml_tensor * b,
7094
- int s0,
7095
- int s1,
7096
- int p0,
7097
- int p1,
7098
- int d0,
7099
- int d1) {
6817
+ int s0,
6818
+ int s1,
6819
+ int p0,
6820
+ int p1,
6821
+ int d0,
6822
+ int d1) {
7100
6823
  struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, a, b, s0, s1, p0, p1, d0, d1, true, a->type); // [N, OH, OW, IC * KH * KW]
7101
6824
 
7102
6825
  struct lm_ggml_tensor * result =
@@ -7112,6 +6835,7 @@ struct lm_ggml_tensor * lm_ggml_conv_2d(
7112
6835
  }
7113
6836
 
7114
6837
  // lm_ggml_conv_2d_sk_p0
6838
+
7115
6839
  struct lm_ggml_tensor * lm_ggml_conv_2d_sk_p0(
7116
6840
  struct lm_ggml_context * ctx,
7117
6841
  struct lm_ggml_tensor * a,
@@ -7141,13 +6865,6 @@ struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0(
7141
6865
  int stride) {
7142
6866
  LM_GGML_ASSERT(a->ne[3] == b->ne[2]);
7143
6867
 
7144
- bool is_node = false;
7145
-
7146
- if (a->grad || b->grad) {
7147
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7148
- is_node = true;
7149
- }
7150
-
7151
6868
  const int64_t ne[4] = {
7152
6869
  lm_ggml_calc_conv_transpose_output_size(b->ne[0], a->ne[0], stride, 0 /*p0*/),
7153
6870
  lm_ggml_calc_conv_transpose_output_size(b->ne[1], a->ne[1], stride, 0 /*p1*/),
@@ -7158,8 +6875,7 @@ struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0(
7158
6875
 
7159
6876
  lm_ggml_set_op_params_i32(result, 0, stride);
7160
6877
 
7161
- result->op = LM_GGML_OP_CONV_TRANSPOSE_2D;
7162
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6878
+ result->op = LM_GGML_OP_CONV_TRANSPOSE_2D;
7163
6879
  result->src[0] = a;
7164
6880
  result->src[1] = b;
7165
6881
 
@@ -7181,14 +6897,6 @@ struct lm_ggml_tensor * lm_ggml_pool_1d(
7181
6897
  int k0,
7182
6898
  int s0,
7183
6899
  int p0) {
7184
-
7185
- bool is_node = false;
7186
-
7187
- if (a->grad) {
7188
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7189
- is_node = true;
7190
- }
7191
-
7192
6900
  const int64_t ne[4] = {
7193
6901
  lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
7194
6902
  a->ne[1],
@@ -7200,8 +6908,7 @@ struct lm_ggml_tensor * lm_ggml_pool_1d(
7200
6908
  int32_t params[] = { op, k0, s0, p0 };
7201
6909
  lm_ggml_set_op_params(result, params, sizeof(params));
7202
6910
 
7203
- result->op = LM_GGML_OP_POOL_1D;
7204
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6911
+ result->op = LM_GGML_OP_POOL_1D;
7205
6912
  result->src[0] = a;
7206
6913
 
7207
6914
  return result;
@@ -7219,13 +6926,6 @@ struct lm_ggml_tensor * lm_ggml_pool_2d(
7219
6926
  int s1,
7220
6927
  float p0,
7221
6928
  float p1) {
7222
-
7223
- bool is_node = false;
7224
-
7225
- if (a->grad) {
7226
- is_node = true;
7227
- }
7228
-
7229
6929
  struct lm_ggml_tensor * result;
7230
6930
  const int64_t ne[4] = {
7231
6931
  lm_ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
@@ -7238,9 +6938,9 @@ struct lm_ggml_tensor * lm_ggml_pool_2d(
7238
6938
  int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
7239
6939
  lm_ggml_set_op_params(result, params, sizeof(params));
7240
6940
 
7241
- result->op = LM_GGML_OP_POOL_2D;
7242
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6941
+ result->op = LM_GGML_OP_POOL_2D;
7243
6942
  result->src[0] = a;
6943
+
7244
6944
  return result;
7245
6945
  }
7246
6946
 
@@ -7255,100 +6955,74 @@ struct lm_ggml_tensor * lm_ggml_pool_2d_back(
7255
6955
  int s1,
7256
6956
  float p0,
7257
6957
  float p1) {
7258
-
7259
- bool is_node = false;
7260
-
7261
- if (a->grad) {
7262
- is_node = true;
7263
- }
7264
-
7265
6958
  struct lm_ggml_tensor * result;
7266
6959
  result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, af->ne);
7267
6960
 
7268
6961
  int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };
7269
6962
  lm_ggml_set_op_params(result, params, sizeof(params));
7270
6963
 
7271
- result->op = LM_GGML_OP_POOL_2D_BACK;
7272
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6964
+ result->op = LM_GGML_OP_POOL_2D_BACK;
7273
6965
  result->src[0] = a;
7274
6966
  result->src[1] = af;
6967
+
7275
6968
  return result;
7276
6969
  }
7277
6970
 
7278
6971
  // lm_ggml_upscale
7279
6972
 
7280
6973
  static struct lm_ggml_tensor * lm_ggml_upscale_impl(
7281
- struct lm_ggml_context * ctx,
7282
- struct lm_ggml_tensor * a,
7283
- int ne0,
7284
- int ne1,
7285
- int ne2,
7286
- int ne3) {
7287
- bool is_node = false;
7288
-
7289
- if (a->grad) {
7290
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7291
- is_node = true;
7292
- }
7293
-
6974
+ struct lm_ggml_context * ctx,
6975
+ struct lm_ggml_tensor * a,
6976
+ int ne0,
6977
+ int ne1,
6978
+ int ne2,
6979
+ int ne3) {
7294
6980
  LM_GGML_ASSERT(a->ne[0] <= ne0);
7295
6981
  LM_GGML_ASSERT(a->ne[1] <= ne1);
7296
6982
  LM_GGML_ASSERT(a->ne[2] <= ne2);
7297
6983
  LM_GGML_ASSERT(a->ne[3] <= ne3);
7298
6984
 
7299
- struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type,
7300
- ne0,
7301
- ne1,
7302
- ne2,
7303
- ne3
7304
- );
7305
-
7306
- result->op = LM_GGML_OP_UPSCALE;
6985
+ struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
7307
6986
 
7308
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
6987
+ result->op = LM_GGML_OP_UPSCALE;
7309
6988
  result->src[0] = a;
7310
6989
 
7311
6990
  return result;
7312
6991
  }
7313
6992
 
7314
6993
  struct lm_ggml_tensor * lm_ggml_upscale(
7315
- struct lm_ggml_context * ctx,
7316
- struct lm_ggml_tensor * a,
7317
- int scale_factor) {
6994
+ struct lm_ggml_context * ctx,
6995
+ struct lm_ggml_tensor * a,
6996
+ int scale_factor) {
7318
6997
  return lm_ggml_upscale_impl(ctx, a, a->ne[0] * scale_factor, a->ne[1] * scale_factor, a->ne[2], a->ne[3]);
7319
6998
  }
7320
6999
 
7321
7000
  struct lm_ggml_tensor * lm_ggml_upscale_ext(
7322
- struct lm_ggml_context * ctx,
7323
- struct lm_ggml_tensor * a,
7324
- int ne0,
7325
- int ne1,
7326
- int ne2,
7327
- int ne3) {
7001
+ struct lm_ggml_context * ctx,
7002
+ struct lm_ggml_tensor * a,
7003
+ int ne0,
7004
+ int ne1,
7005
+ int ne2,
7006
+ int ne3) {
7328
7007
  return lm_ggml_upscale_impl(ctx, a, ne0, ne1, ne2, ne3);
7329
7008
  }
7330
7009
 
7331
7010
  // lm_ggml_pad
7332
7011
 
7333
7012
  struct lm_ggml_tensor * lm_ggml_pad(
7334
- struct lm_ggml_context * ctx,
7335
- struct lm_ggml_tensor * a,
7336
- int p0, int p1, int p2, int p3) {
7337
- bool is_node = false;
7338
-
7339
- if (a->grad) {
7340
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7341
- is_node = true;
7342
- }
7343
-
7013
+ struct lm_ggml_context * ctx,
7014
+ struct lm_ggml_tensor * a,
7015
+ int p0,
7016
+ int p1,
7017
+ int p2,
7018
+ int p3) {
7344
7019
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type,
7345
7020
  a->ne[0] + p0,
7346
7021
  a->ne[1] + p1,
7347
7022
  a->ne[2] + p2,
7348
7023
  a->ne[3] + p3);
7349
7024
 
7350
- result->op = LM_GGML_OP_PAD;
7351
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7025
+ result->op = LM_GGML_OP_PAD;
7352
7026
  result->src[0] = a;
7353
7027
 
7354
7028
  return result;
@@ -7357,39 +7031,32 @@ struct lm_ggml_tensor * lm_ggml_pad(
7357
7031
  // lm_ggml_arange
7358
7032
 
7359
7033
  struct lm_ggml_tensor * lm_ggml_arange(
7360
- struct lm_ggml_context * ctx,
7361
- float start,
7362
- float stop,
7363
- float step) {
7364
-
7034
+ struct lm_ggml_context * ctx,
7035
+ float start,
7036
+ float stop,
7037
+ float step) {
7365
7038
  LM_GGML_ASSERT(stop > start);
7366
7039
 
7367
7040
  const int64_t steps = (int64_t) ceilf((stop - start) / step);
7368
7041
 
7369
7042
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, steps);
7370
7043
 
7371
- result->op = LM_GGML_OP_ARANGE;
7372
7044
  lm_ggml_set_op_params_f32(result, 0, start);
7373
7045
  lm_ggml_set_op_params_f32(result, 1, stop);
7374
7046
  lm_ggml_set_op_params_f32(result, 2, step);
7375
7047
 
7048
+ result->op = LM_GGML_OP_ARANGE;
7049
+
7376
7050
  return result;
7377
7051
  }
7378
7052
 
7379
7053
  // lm_ggml_timestep_embedding
7380
7054
 
7381
7055
  struct lm_ggml_tensor * lm_ggml_timestep_embedding(
7382
- struct lm_ggml_context * ctx,
7383
- struct lm_ggml_tensor * timesteps,
7384
- int dim,
7385
- int max_period) {
7386
- bool is_node = false;
7387
-
7388
- if (timesteps->grad) {
7389
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7390
- is_node = true;
7391
- }
7392
-
7056
+ struct lm_ggml_context * ctx,
7057
+ struct lm_ggml_tensor * timesteps,
7058
+ int dim,
7059
+ int max_period) {
7393
7060
  int actual_dim = dim;
7394
7061
  if (dim % 2 != 0) {
7395
7062
  actual_dim = dim + 1;
@@ -7397,11 +7064,10 @@ struct lm_ggml_tensor * lm_ggml_timestep_embedding(
7397
7064
 
7398
7065
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
7399
7066
 
7400
- result->op = LM_GGML_OP_TIMESTEP_EMBEDDING;
7401
7067
  lm_ggml_set_op_params_i32(result, 0, dim);
7402
7068
  lm_ggml_set_op_params_i32(result, 1, max_period);
7403
7069
 
7404
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7070
+ result->op = LM_GGML_OP_TIMESTEP_EMBEDDING;
7405
7071
  result->src[0] = timesteps;
7406
7072
 
7407
7073
  return result;
@@ -7410,22 +7076,14 @@ struct lm_ggml_tensor * lm_ggml_timestep_embedding(
7410
7076
  // lm_ggml_argsort
7411
7077
 
7412
7078
  struct lm_ggml_tensor * lm_ggml_argsort(
7413
- struct lm_ggml_context * ctx,
7414
- struct lm_ggml_tensor * a,
7415
- enum lm_ggml_sort_order order) {
7416
- bool is_node = false;
7417
-
7418
- if (a->grad) {
7419
- LM_GGML_ABORT("fatal error"); // TODO: not implemented
7420
- is_node = true;
7421
- }
7422
-
7079
+ struct lm_ggml_context * ctx,
7080
+ struct lm_ggml_tensor * a,
7081
+ enum lm_ggml_sort_order order) {
7423
7082
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne);
7424
7083
 
7425
7084
  lm_ggml_set_op_params_i32(result, 0, (int32_t) order);
7426
7085
 
7427
- result->op = LM_GGML_OP_ARGSORT;
7428
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7086
+ result->op = LM_GGML_OP_ARGSORT;
7429
7087
  result->src[0] = a;
7430
7088
 
7431
7089
  return result;
@@ -7478,10 +7136,6 @@ struct lm_ggml_tensor * lm_ggml_flash_attn_ext(
7478
7136
 
7479
7137
  bool is_node = false;
7480
7138
 
7481
- if (q->grad || k->grad || v->grad) {
7482
- is_node = true;
7483
- }
7484
-
7485
7139
  // permute(0, 2, 1, 3)
7486
7140
  int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
7487
7141
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
@@ -7608,17 +7262,9 @@ struct lm_ggml_tensor * lm_ggml_ssm_conv(
7608
7262
  LM_GGML_ASSERT(sx->ne[1] == d_inner);
7609
7263
  LM_GGML_ASSERT(n_t >= 0);
7610
7264
 
7611
- bool is_node = false;
7612
-
7613
- if (sx->grad || c->grad) {
7614
- LM_GGML_ABORT("fatal error"); // TODO: implement
7615
- is_node = true;
7616
- }
7617
-
7618
7265
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_3d(ctx, LM_GGML_TYPE_F32, d_inner, n_t, n_s);
7619
7266
 
7620
- result->op = LM_GGML_OP_SSM_CONV;
7621
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7267
+ result->op = LM_GGML_OP_SSM_CONV;
7622
7268
  result->src[0] = sx;
7623
7269
  result->src[1] = c;
7624
7270
 
@@ -7662,18 +7308,10 @@ struct lm_ggml_tensor * lm_ggml_ssm_scan(
7662
7308
  LM_GGML_ASSERT(B->ne[2] == n_seqs);
7663
7309
  }
7664
7310
 
7665
- bool is_node = false;
7666
-
7667
- if (s->grad || x->grad || dt->grad || A->grad || B->grad || C->grad) {
7668
- LM_GGML_ABORT("fatal error"); // TODO: implement
7669
- is_node = true;
7670
- }
7671
-
7672
7311
  // concatenated y + ssm_states
7673
7312
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_F32, lm_ggml_nelements(x) + lm_ggml_nelements(s));
7674
7313
 
7675
7314
  result->op = LM_GGML_OP_SSM_SCAN;
7676
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7677
7315
  result->src[0] = s;
7678
7316
  result->src[1] = x;
7679
7317
  result->src[2] = dt;
@@ -7693,13 +7331,6 @@ struct lm_ggml_tensor * lm_ggml_win_part(
7693
7331
  LM_GGML_ASSERT(a->ne[3] == 1);
7694
7332
  LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32);
7695
7333
 
7696
- bool is_node = false;
7697
-
7698
- if (a->grad) {
7699
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7700
- is_node = true;
7701
- }
7702
-
7703
7334
  // padding
7704
7335
  const int px = (w - a->ne[1]%w)%w;
7705
7336
  const int py = (w - a->ne[2]%w)%w;
@@ -7714,8 +7345,7 @@ struct lm_ggml_tensor * lm_ggml_win_part(
7714
7345
  int32_t params[] = { npx, npy, w };
7715
7346
  lm_ggml_set_op_params(result, params, sizeof(params));
7716
7347
 
7717
- result->op = LM_GGML_OP_WIN_PART;
7718
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7348
+ result->op = LM_GGML_OP_WIN_PART;
7719
7349
  result->src[0] = a;
7720
7350
 
7721
7351
  return result;
@@ -7731,21 +7361,13 @@ struct lm_ggml_tensor * lm_ggml_win_unpart(
7731
7361
  int w) {
7732
7362
  LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32);
7733
7363
 
7734
- bool is_node = false;
7735
-
7736
- if (a->grad) {
7737
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7738
- is_node = true;
7739
- }
7740
-
7741
7364
  const int64_t ne[4] = { a->ne[0], w0, h0, 1, };
7742
7365
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 3, ne);
7743
7366
 
7744
7367
  int32_t params[] = { w };
7745
7368
  lm_ggml_set_op_params(result, params, sizeof(params));
7746
7369
 
7747
- result->op = LM_GGML_OP_WIN_UNPART;
7748
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7370
+ result->op = LM_GGML_OP_WIN_UNPART;
7749
7371
  result->src[0] = a;
7750
7372
 
7751
7373
  return result;
@@ -7761,18 +7383,10 @@ struct lm_ggml_tensor * lm_ggml_get_rel_pos(
7761
7383
  LM_GGML_ASSERT(qh == kh);
7762
7384
  LM_GGML_ASSERT(2*MAX(qh, kh) - 1 == a->ne[1]);
7763
7385
 
7764
- bool is_node = false;
7765
-
7766
- if (a->grad) {
7767
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7768
- is_node = true;
7769
- }
7770
-
7771
7386
  const int64_t ne[4] = { a->ne[0], kh, qh, 1, };
7772
7387
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F16, 3, ne);
7773
7388
 
7774
- result->op = LM_GGML_OP_GET_REL_POS;
7775
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7389
+ result->op = LM_GGML_OP_GET_REL_POS;
7776
7390
  result->src[0] = a;
7777
7391
 
7778
7392
  return result;
@@ -7796,17 +7410,10 @@ static struct lm_ggml_tensor * lm_ggml_add_rel_pos_impl(
7796
7410
  LM_GGML_ASSERT(pw->ne[0]*pw->ne[0] == a->ne[0]);
7797
7411
  LM_GGML_ASSERT(pw->ne[1]*pw->ne[2] == a->ne[1]);
7798
7412
 
7799
- bool is_node = false;
7800
-
7801
- if (!inplace && (a->grad || pw->grad || ph->grad)) {
7802
- is_node = true;
7803
- }
7804
-
7805
7413
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
7806
7414
  lm_ggml_set_op_params_i32(result, 0, inplace ? 1 : 0);
7807
7415
 
7808
- result->op = LM_GGML_OP_ADD_REL_POS;
7809
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7416
+ result->op = LM_GGML_OP_ADD_REL_POS;
7810
7417
  result->src[0] = a;
7811
7418
  result->src[1] = pw;
7812
7419
  result->src[2] = ph;
@@ -7834,12 +7441,12 @@ struct lm_ggml_tensor * lm_ggml_add_rel_pos_inplace(
7834
7441
 
7835
7442
  struct lm_ggml_tensor * lm_ggml_rwkv_wkv(
7836
7443
  struct lm_ggml_context * ctx,
7837
- struct lm_ggml_tensor * k,
7838
- struct lm_ggml_tensor * v,
7839
- struct lm_ggml_tensor * r,
7840
- struct lm_ggml_tensor * tf,
7841
- struct lm_ggml_tensor * td,
7842
- struct lm_ggml_tensor * state) {
7444
+ struct lm_ggml_tensor * k,
7445
+ struct lm_ggml_tensor * v,
7446
+ struct lm_ggml_tensor * r,
7447
+ struct lm_ggml_tensor * tf,
7448
+ struct lm_ggml_tensor * td,
7449
+ struct lm_ggml_tensor * state) {
7843
7450
  LM_GGML_ASSERT(lm_ggml_is_contiguous(k));
7844
7451
  LM_GGML_ASSERT(lm_ggml_is_contiguous(v));
7845
7452
  LM_GGML_ASSERT(lm_ggml_is_contiguous(r));
@@ -7860,19 +7467,11 @@ struct lm_ggml_tensor * lm_ggml_rwkv_wkv(
7860
7467
  LM_GGML_ASSERT(lm_ggml_nelements(state) == S * S * H * n_seqs);
7861
7468
  }
7862
7469
 
7863
- bool is_node = false;
7864
-
7865
- if (k->grad || v->grad || r->grad || tf->grad || td->grad || state->grad) {
7866
- LM_GGML_ABORT("fatal error"); // TODO: implement backward
7867
- is_node = true;
7868
- }
7869
-
7870
7470
  // concat output and new_state
7871
7471
  const int64_t ne[4] = { S * H, n_tokens + S * n_seqs, 1, 1 };
7872
7472
  struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne);
7873
7473
 
7874
- result->op = LM_GGML_OP_RWKV_WKV;
7875
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7474
+ result->op = LM_GGML_OP_RWKV_WKV;
7876
7475
  result->src[0] = k;
7877
7476
  result->src[1] = v;
7878
7477
  result->src[2] = r;
@@ -7887,23 +7486,16 @@ struct lm_ggml_tensor * lm_ggml_rwkv_wkv(
7887
7486
 
7888
7487
  static struct lm_ggml_tensor * lm_ggml_unary_impl(
7889
7488
  struct lm_ggml_context * ctx,
7890
- struct lm_ggml_tensor * a,
7891
- enum lm_ggml_unary_op op,
7892
- bool inplace) {
7893
- LM_GGML_ASSERT(lm_ggml_is_contiguous_1(a));
7894
-
7895
- bool is_node = false;
7896
-
7897
- if (!inplace && (a->grad)) {
7898
- is_node = true;
7899
- }
7489
+ struct lm_ggml_tensor * a,
7490
+ enum lm_ggml_unary_op op,
7491
+ bool inplace) {
7492
+ LM_GGML_ASSERT(lm_ggml_is_contiguous_1(a));
7900
7493
 
7901
7494
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
7902
7495
 
7903
7496
  lm_ggml_set_op_params_i32(result, 0, (int32_t) op);
7904
7497
 
7905
- result->op = LM_GGML_OP_UNARY;
7906
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7498
+ result->op = LM_GGML_OP_UNARY;
7907
7499
  result->src[0] = a;
7908
7500
 
7909
7501
  return result;
@@ -7912,14 +7504,14 @@ static struct lm_ggml_tensor * lm_ggml_unary_impl(
7912
7504
  struct lm_ggml_tensor * lm_ggml_unary(
7913
7505
  struct lm_ggml_context * ctx,
7914
7506
  struct lm_ggml_tensor * a,
7915
- enum lm_ggml_unary_op op) {
7507
+ enum lm_ggml_unary_op op) {
7916
7508
  return lm_ggml_unary_impl(ctx, a, op, false);
7917
7509
  }
7918
7510
 
7919
7511
  struct lm_ggml_tensor * lm_ggml_unary_inplace(
7920
7512
  struct lm_ggml_context * ctx,
7921
7513
  struct lm_ggml_tensor * a,
7922
- enum lm_ggml_unary_op op) {
7514
+ enum lm_ggml_unary_op op) {
7923
7515
  return lm_ggml_unary_impl(ctx, a, op, true);
7924
7516
  }
7925
7517
 
@@ -7928,20 +7520,13 @@ struct lm_ggml_tensor * lm_ggml_unary_inplace(
7928
7520
  static struct lm_ggml_tensor * lm_ggml_map_unary_impl_f32(
7929
7521
  struct lm_ggml_context * ctx,
7930
7522
  struct lm_ggml_tensor * a,
7931
- const lm_ggml_unary_op_f32_t fun,
7932
- bool inplace) {
7933
- bool is_node = false;
7934
-
7935
- if (!inplace && a->grad) {
7936
- is_node = true;
7937
- }
7938
-
7523
+ const lm_ggml_unary_op_f32_t fun,
7524
+ bool inplace) {
7939
7525
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
7940
7526
 
7941
7527
  lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
7942
7528
 
7943
- result->op = LM_GGML_OP_MAP_UNARY;
7944
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7529
+ result->op = LM_GGML_OP_MAP_UNARY;
7945
7530
  result->src[0] = a;
7946
7531
 
7947
7532
  return result;
@@ -7950,14 +7535,14 @@ static struct lm_ggml_tensor * lm_ggml_map_unary_impl_f32(
7950
7535
  struct lm_ggml_tensor * lm_ggml_map_unary_f32(
7951
7536
  struct lm_ggml_context * ctx,
7952
7537
  struct lm_ggml_tensor * a,
7953
- const lm_ggml_unary_op_f32_t fun) {
7538
+ const lm_ggml_unary_op_f32_t fun) {
7954
7539
  return lm_ggml_map_unary_impl_f32(ctx, a, fun, false);
7955
7540
  }
7956
7541
 
7957
7542
  struct lm_ggml_tensor * lm_ggml_map_unary_inplace_f32(
7958
7543
  struct lm_ggml_context * ctx,
7959
7544
  struct lm_ggml_tensor * a,
7960
- const lm_ggml_unary_op_f32_t fun) {
7545
+ const lm_ggml_unary_op_f32_t fun) {
7961
7546
  return lm_ggml_map_unary_impl_f32(ctx, a, fun, true);
7962
7547
  }
7963
7548
 
@@ -7967,22 +7552,15 @@ static struct lm_ggml_tensor * lm_ggml_map_binary_impl_f32(
7967
7552
  struct lm_ggml_context * ctx,
7968
7553
  struct lm_ggml_tensor * a,
7969
7554
  struct lm_ggml_tensor * b,
7970
- const lm_ggml_binary_op_f32_t fun,
7971
- bool inplace) {
7555
+ const lm_ggml_binary_op_f32_t fun,
7556
+ bool inplace) {
7972
7557
  LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
7973
7558
 
7974
- bool is_node = false;
7975
-
7976
- if (!inplace && (a->grad || b->grad)) {
7977
- is_node = true;
7978
- }
7979
-
7980
7559
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
7981
7560
 
7982
7561
  lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
7983
7562
 
7984
- result->op = LM_GGML_OP_MAP_BINARY;
7985
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7563
+ result->op = LM_GGML_OP_MAP_BINARY;
7986
7564
  result->src[0] = a;
7987
7565
  result->src[1] = b;
7988
7566
 
@@ -7993,7 +7571,7 @@ struct lm_ggml_tensor * lm_ggml_map_binary_f32(
7993
7571
  struct lm_ggml_context * ctx,
7994
7572
  struct lm_ggml_tensor * a,
7995
7573
  struct lm_ggml_tensor * b,
7996
- const lm_ggml_binary_op_f32_t fun) {
7574
+ const lm_ggml_binary_op_f32_t fun) {
7997
7575
  return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, false);
7998
7576
  }
7999
7577
 
@@ -8001,7 +7579,7 @@ struct lm_ggml_tensor * lm_ggml_map_binary_inplace_f32(
8001
7579
  struct lm_ggml_context * ctx,
8002
7580
  struct lm_ggml_tensor * a,
8003
7581
  struct lm_ggml_tensor * b,
8004
- const lm_ggml_binary_op_f32_t fun) {
7582
+ const lm_ggml_binary_op_f32_t fun) {
8005
7583
  return lm_ggml_map_binary_impl_f32(ctx, a, b, fun, true);
8006
7584
  }
8007
7585
 
@@ -8011,19 +7589,12 @@ static struct lm_ggml_tensor * lm_ggml_map_custom1_impl_f32(
8011
7589
  struct lm_ggml_context * ctx,
8012
7590
  struct lm_ggml_tensor * a,
8013
7591
  const lm_ggml_custom1_op_f32_t fun,
8014
- bool inplace) {
8015
- bool is_node = false;
8016
-
8017
- if (!inplace && a->grad) {
8018
- is_node = true;
8019
- }
8020
-
7592
+ bool inplace) {
8021
7593
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
8022
7594
 
8023
7595
  lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
8024
7596
 
8025
- result->op = LM_GGML_OP_MAP_CUSTOM1_F32;
8026
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7597
+ result->op = LM_GGML_OP_MAP_CUSTOM1_F32;
8027
7598
  result->src[0] = a;
8028
7599
 
8029
7600
  return result;
@@ -8050,19 +7621,12 @@ static struct lm_ggml_tensor * lm_ggml_map_custom2_impl_f32(
8050
7621
  struct lm_ggml_tensor * a,
8051
7622
  struct lm_ggml_tensor * b,
8052
7623
  const lm_ggml_custom2_op_f32_t fun,
8053
- bool inplace) {
8054
- bool is_node = false;
8055
-
8056
- if (!inplace && (a->grad || b->grad)) {
8057
- is_node = true;
8058
- }
8059
-
7624
+ bool inplace) {
8060
7625
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
8061
7626
 
8062
7627
  lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
8063
7628
 
8064
- result->op = LM_GGML_OP_MAP_CUSTOM2_F32;
8065
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7629
+ result->op = LM_GGML_OP_MAP_CUSTOM2_F32;
8066
7630
  result->src[0] = a;
8067
7631
  result->src[1] = b;
8068
7632
 
@@ -8093,19 +7657,12 @@ static struct lm_ggml_tensor * lm_ggml_map_custom3_impl_f32(
8093
7657
  struct lm_ggml_tensor * b,
8094
7658
  struct lm_ggml_tensor * c,
8095
7659
  const lm_ggml_custom3_op_f32_t fun,
8096
- bool inplace) {
8097
- bool is_node = false;
8098
-
8099
- if (!inplace && (a->grad || b->grad || c->grad)) {
8100
- is_node = true;
8101
- }
8102
-
7660
+ bool inplace) {
8103
7661
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
8104
7662
 
8105
7663
  lm_ggml_set_op_params(result, (const void *) &fun, sizeof(fun));
8106
7664
 
8107
- result->op = LM_GGML_OP_MAP_CUSTOM3_F32;
8108
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7665
+ result->op = LM_GGML_OP_MAP_CUSTOM3_F32;
8109
7666
  result->src[0] = a;
8110
7667
  result->src[1] = b;
8111
7668
  result->src[2] = c;
@@ -8133,26 +7690,20 @@ struct lm_ggml_tensor * lm_ggml_map_custom3_inplace_f32(
8133
7690
 
8134
7691
  // lm_ggml_map_custom1
8135
7692
  struct lm_ggml_map_custom1_op_params {
8136
- lm_ggml_custom1_op_t fun;
8137
- int n_tasks;
8138
- void * userdata;
7693
+ lm_ggml_custom1_op_t fun;
7694
+ int n_tasks;
7695
+ void * userdata;
8139
7696
  };
8140
7697
 
8141
7698
  static struct lm_ggml_tensor * lm_ggml_map_custom1_impl(
8142
- struct lm_ggml_context * ctx,
8143
- struct lm_ggml_tensor * a,
8144
- const lm_ggml_custom1_op_t fun,
8145
- int n_tasks,
8146
- void * userdata,
8147
- bool inplace) {
7699
+ struct lm_ggml_context * ctx,
7700
+ struct lm_ggml_tensor * a,
7701
+ const lm_ggml_custom1_op_t fun,
7702
+ int n_tasks,
7703
+ void * userdata,
7704
+ bool inplace) {
8148
7705
  LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0);
8149
7706
 
8150
- bool is_node = false;
8151
-
8152
- if (!inplace && a->grad) {
8153
- is_node = true;
8154
- }
8155
-
8156
7707
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
8157
7708
 
8158
7709
  struct lm_ggml_map_custom1_op_params params = {
@@ -8162,55 +7713,48 @@ static struct lm_ggml_tensor * lm_ggml_map_custom1_impl(
8162
7713
  };
8163
7714
  lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
8164
7715
 
8165
- result->op = LM_GGML_OP_MAP_CUSTOM1;
8166
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7716
+ result->op = LM_GGML_OP_MAP_CUSTOM1;
8167
7717
  result->src[0] = a;
8168
7718
 
8169
7719
  return result;
8170
7720
  }
8171
7721
 
8172
7722
  struct lm_ggml_tensor * lm_ggml_map_custom1(
8173
- struct lm_ggml_context * ctx,
8174
- struct lm_ggml_tensor * a,
8175
- const lm_ggml_custom1_op_t fun,
8176
- int n_tasks,
8177
- void * userdata) {
7723
+ struct lm_ggml_context * ctx,
7724
+ struct lm_ggml_tensor * a,
7725
+ const lm_ggml_custom1_op_t fun,
7726
+ int n_tasks,
7727
+ void * userdata) {
8178
7728
  return lm_ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, false);
8179
7729
  }
8180
7730
 
8181
7731
  struct lm_ggml_tensor * lm_ggml_map_custom1_inplace(
8182
- struct lm_ggml_context * ctx,
8183
- struct lm_ggml_tensor * a,
8184
- const lm_ggml_custom1_op_t fun,
8185
- int n_tasks,
8186
- void * userdata) {
7732
+ struct lm_ggml_context * ctx,
7733
+ struct lm_ggml_tensor * a,
7734
+ const lm_ggml_custom1_op_t fun,
7735
+ int n_tasks,
7736
+ void * userdata) {
8187
7737
  return lm_ggml_map_custom1_impl(ctx, a, fun, n_tasks, userdata, true);
8188
7738
  }
8189
7739
 
8190
7740
  // lm_ggml_map_custom2
8191
7741
 
8192
7742
  struct lm_ggml_map_custom2_op_params {
8193
- lm_ggml_custom2_op_t fun;
8194
- int n_tasks;
8195
- void * userdata;
7743
+ lm_ggml_custom2_op_t fun;
7744
+ int n_tasks;
7745
+ void * userdata;
8196
7746
  };
8197
7747
 
8198
7748
  static struct lm_ggml_tensor * lm_ggml_map_custom2_impl(
8199
- struct lm_ggml_context * ctx,
8200
- struct lm_ggml_tensor * a,
8201
- struct lm_ggml_tensor * b,
8202
- const lm_ggml_custom2_op_t fun,
8203
- int n_tasks,
8204
- void * userdata,
8205
- bool inplace) {
7749
+ struct lm_ggml_context * ctx,
7750
+ struct lm_ggml_tensor * a,
7751
+ struct lm_ggml_tensor * b,
7752
+ const lm_ggml_custom2_op_t fun,
7753
+ int n_tasks,
7754
+ void * userdata,
7755
+ bool inplace) {
8206
7756
  LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0);
8207
7757
 
8208
- bool is_node = false;
8209
-
8210
- if (!inplace && (a->grad || b->grad)) {
8211
- is_node = true;
8212
- }
8213
-
8214
7758
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
8215
7759
 
8216
7760
  struct lm_ggml_map_custom2_op_params params = {
@@ -8220,8 +7764,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom2_impl(
8220
7764
  };
8221
7765
  lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
8222
7766
 
8223
- result->op = LM_GGML_OP_MAP_CUSTOM2;
8224
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7767
+ result->op = LM_GGML_OP_MAP_CUSTOM2;
8225
7768
  result->src[0] = a;
8226
7769
  result->src[1] = b;
8227
7770
 
@@ -8229,22 +7772,22 @@ static struct lm_ggml_tensor * lm_ggml_map_custom2_impl(
8229
7772
  }
8230
7773
 
8231
7774
  struct lm_ggml_tensor * lm_ggml_map_custom2(
8232
- struct lm_ggml_context * ctx,
8233
- struct lm_ggml_tensor * a,
8234
- struct lm_ggml_tensor * b,
8235
- const lm_ggml_custom2_op_t fun,
8236
- int n_tasks,
8237
- void * userdata) {
7775
+ struct lm_ggml_context * ctx,
7776
+ struct lm_ggml_tensor * a,
7777
+ struct lm_ggml_tensor * b,
7778
+ const lm_ggml_custom2_op_t fun,
7779
+ int n_tasks,
7780
+ void * userdata) {
8238
7781
  return lm_ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, false);
8239
7782
  }
8240
7783
 
8241
7784
  struct lm_ggml_tensor * lm_ggml_map_custom2_inplace(
8242
- struct lm_ggml_context * ctx,
8243
- struct lm_ggml_tensor * a,
8244
- struct lm_ggml_tensor * b,
8245
- const lm_ggml_custom2_op_t fun,
8246
- int n_tasks,
8247
- void * userdata) {
7785
+ struct lm_ggml_context * ctx,
7786
+ struct lm_ggml_tensor * a,
7787
+ struct lm_ggml_tensor * b,
7788
+ const lm_ggml_custom2_op_t fun,
7789
+ int n_tasks,
7790
+ void * userdata) {
8248
7791
  return lm_ggml_map_custom2_impl(ctx, a, b, fun, n_tasks, userdata, true);
8249
7792
  }
8250
7793
 
@@ -8257,22 +7800,16 @@ struct lm_ggml_map_custom3_op_params {
8257
7800
  };
8258
7801
 
8259
7802
  static struct lm_ggml_tensor * lm_ggml_map_custom3_impl(
8260
- struct lm_ggml_context * ctx,
8261
- struct lm_ggml_tensor * a,
8262
- struct lm_ggml_tensor * b,
8263
- struct lm_ggml_tensor * c,
8264
- const lm_ggml_custom3_op_t fun,
8265
- int n_tasks,
8266
- void * userdata,
8267
- bool inplace) {
7803
+ struct lm_ggml_context * ctx,
7804
+ struct lm_ggml_tensor * a,
7805
+ struct lm_ggml_tensor * b,
7806
+ struct lm_ggml_tensor * c,
7807
+ const lm_ggml_custom3_op_t fun,
7808
+ int n_tasks,
7809
+ void * userdata,
7810
+ bool inplace) {
8268
7811
  LM_GGML_ASSERT(n_tasks == LM_GGML_N_TASKS_MAX || n_tasks > 0);
8269
7812
 
8270
- bool is_node = false;
8271
-
8272
- if (!inplace && (a->grad || b->grad || c->grad)) {
8273
- is_node = true;
8274
- }
8275
-
8276
7813
  struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a);
8277
7814
 
8278
7815
  struct lm_ggml_map_custom3_op_params params = {
@@ -8282,8 +7819,7 @@ static struct lm_ggml_tensor * lm_ggml_map_custom3_impl(
8282
7819
  };
8283
7820
  lm_ggml_set_op_params(result, (const void *) &params, sizeof(params));
8284
7821
 
8285
- result->op = LM_GGML_OP_MAP_CUSTOM3;
8286
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7822
+ result->op = LM_GGML_OP_MAP_CUSTOM3;
8287
7823
  result->src[0] = a;
8288
7824
  result->src[1] = b;
8289
7825
  result->src[2] = c;
@@ -8292,44 +7828,38 @@ static struct lm_ggml_tensor * lm_ggml_map_custom3_impl(
8292
7828
  }
8293
7829
 
8294
7830
  struct lm_ggml_tensor * lm_ggml_map_custom3(
8295
- struct lm_ggml_context * ctx,
8296
- struct lm_ggml_tensor * a,
8297
- struct lm_ggml_tensor * b,
8298
- struct lm_ggml_tensor * c,
8299
- const lm_ggml_custom3_op_t fun,
8300
- int n_tasks,
8301
- void * userdata) {
7831
+ struct lm_ggml_context * ctx,
7832
+ struct lm_ggml_tensor * a,
7833
+ struct lm_ggml_tensor * b,
7834
+ struct lm_ggml_tensor * c,
7835
+ const lm_ggml_custom3_op_t fun,
7836
+ int n_tasks,
7837
+ void * userdata) {
8302
7838
  return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, false);
8303
7839
  }
8304
7840
 
8305
7841
  struct lm_ggml_tensor * lm_ggml_map_custom3_inplace(
8306
- struct lm_ggml_context * ctx,
8307
- struct lm_ggml_tensor * a,
8308
- struct lm_ggml_tensor * b,
8309
- struct lm_ggml_tensor * c,
8310
- const lm_ggml_custom3_op_t fun,
8311
- int n_tasks,
8312
- void * userdata) {
7842
+ struct lm_ggml_context * ctx,
7843
+ struct lm_ggml_tensor * a,
7844
+ struct lm_ggml_tensor * b,
7845
+ struct lm_ggml_tensor * c,
7846
+ const lm_ggml_custom3_op_t fun,
7847
+ int n_tasks,
7848
+ void * userdata) {
8313
7849
  return lm_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true);
8314
7850
  }
8315
7851
 
8316
7852
  // lm_ggml_cross_entropy_loss
8317
7853
 
8318
7854
  struct lm_ggml_tensor * lm_ggml_cross_entropy_loss(
8319
- struct lm_ggml_context * ctx,
8320
- struct lm_ggml_tensor * a,
8321
- struct lm_ggml_tensor * b) {
7855
+ struct lm_ggml_context * ctx,
7856
+ struct lm_ggml_tensor * a,
7857
+ struct lm_ggml_tensor * b) {
8322
7858
  LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
8323
- bool is_node = false;
8324
-
8325
- if (a->grad || b->grad) {
8326
- is_node = true;
8327
- }
8328
7859
 
8329
7860
  struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, a->type, 1);
8330
7861
 
8331
- result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS;
8332
- result->grad = is_node ? lm_ggml_dup_tensor(ctx, result) : NULL;
7862
+ result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS;
8333
7863
  result->src[0] = a;
8334
7864
  result->src[1] = b;
8335
7865
 
@@ -8339,17 +7869,16 @@ struct lm_ggml_tensor * lm_ggml_cross_entropy_loss(
8339
7869
  // lm_ggml_cross_entropy_loss_back
8340
7870
 
8341
7871
  struct lm_ggml_tensor * lm_ggml_cross_entropy_loss_back(
8342
- struct lm_ggml_context * ctx,
8343
- struct lm_ggml_tensor * a,
8344
- struct lm_ggml_tensor * b,
8345
- struct lm_ggml_tensor * c) {
7872
+ struct lm_ggml_context * ctx,
7873
+ struct lm_ggml_tensor * a,
7874
+ struct lm_ggml_tensor * b,
7875
+ struct lm_ggml_tensor * c) {
8346
7876
  LM_GGML_ASSERT(lm_ggml_are_same_shape(a, b));
8347
7877
  LM_GGML_ASSERT(lm_ggml_is_scalar(c));
8348
7878
 
8349
7879
  struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a);
8350
7880
 
8351
- result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK;
8352
- result->grad = NULL;
7881
+ result->op = LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK;
8353
7882
  result->src[0] = a;
8354
7883
  result->src[1] = b;
8355
7884
  result->src[2] = c;
@@ -8362,12 +7891,14 @@ struct lm_ggml_tensor * lm_ggml_cross_entropy_loss_back(
8362
7891
  struct lm_ggml_tensor * lm_ggml_opt_step_adamw(
8363
7892
  struct lm_ggml_context * ctx,
8364
7893
  struct lm_ggml_tensor * a,
7894
+ struct lm_ggml_tensor * grad,
8365
7895
  float alpha,
8366
7896
  float beta1,
8367
7897
  float beta2,
8368
7898
  float eps,
8369
7899
  float wd) {
8370
- LM_GGML_ASSERT(a->grad);
7900
+ LM_GGML_ASSERT(a->flags & LM_GGML_TENSOR_FLAG_PARAM);
7901
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(a, grad));
8371
7902
  LM_GGML_ASSERT(alpha > 0.0f);
8372
7903
  LM_GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
8373
7904
  LM_GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
@@ -8376,13 +7907,6 @@ struct lm_ggml_tensor * lm_ggml_opt_step_adamw(
8376
7907
 
8377
7908
  struct lm_ggml_tensor * result = lm_ggml_view_tensor(ctx, a);
8378
7909
 
8379
- result->op = LM_GGML_OP_OPT_STEP_ADAMW;
8380
- result->grad = NULL;
8381
- result->src[0] = a;
8382
- result->src[1] = a->grad;
8383
- result->src[2] = lm_ggml_dup_tensor(ctx, a->grad);
8384
- result->src[3] = lm_ggml_dup_tensor(ctx, a->grad);
8385
-
8386
7910
  const int64_t iter = 1;
8387
7911
  memcpy(&result->op_params[0], &iter, sizeof(int64_t));
8388
7912
  lm_ggml_set_op_params_f32(result, 2, alpha);
@@ -8391,26 +7915,17 @@ struct lm_ggml_tensor * lm_ggml_opt_step_adamw(
8391
7915
  lm_ggml_set_op_params_f32(result, 5, eps);
8392
7916
  lm_ggml_set_op_params_f32(result, 6, wd);
8393
7917
 
7918
+ result->op = LM_GGML_OP_OPT_STEP_ADAMW;
7919
+ result->src[0] = a;
7920
+ result->src[1] = grad;
7921
+ result->src[2] = lm_ggml_dup_tensor(ctx, grad);
7922
+ result->src[3] = lm_ggml_dup_tensor(ctx, grad);
7923
+
8394
7924
  return result;
8395
7925
  }
8396
7926
 
8397
7927
  ////////////////////////////////////////////////////////////////////////////////
8398
7928
 
8399
- void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) {
8400
- tensor->flags |= LM_GGML_TENSOR_FLAG_PARAM;
8401
-
8402
- LM_GGML_ASSERT(tensor->grad == NULL);
8403
- tensor->grad = lm_ggml_dup_tensor(ctx, tensor);
8404
- lm_ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
8405
- }
8406
-
8407
- void lm_ggml_set_loss(struct lm_ggml_tensor * tensor) {
8408
- LM_GGML_ASSERT(lm_ggml_is_scalar(tensor));
8409
- LM_GGML_ASSERT(tensor->type == LM_GGML_TYPE_F32);
8410
- LM_GGML_ASSERT(tensor->grad);
8411
- tensor->flags |= LM_GGML_TENSOR_FLAG_LOSS;
8412
- }
8413
-
8414
7929
  // lm_ggml_compute_forward_dup
8415
7930
 
8416
7931
  static void lm_ggml_compute_forward_dup_same_cont(
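Editor's note: with the hunk above, `lm_ggml_opt_step_adamw` no longer reads `a->grad` implicitly; the gradient tensor is passed explicitly and must be shape-compatible with the parameter, which must carry `LM_GGML_TENSOR_FLAG_PARAM`. A minimal sketch of the new call shape follows; the function name `append_adamw_step`, the hyperparameter values, and the placeholder arguments are illustrative, not from the diff.

```c
#include "ggml.h"

// Illustrative helper: append an AdamW update for parameter 'w' to graph 'gb'.
// The gradient is now an explicit argument instead of being read from w->grad.
void append_adamw_step(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gb,
                       struct lm_ggml_tensor * w, struct lm_ggml_tensor * w_grad) {
    struct lm_ggml_tensor * step = lm_ggml_opt_step_adamw(
        ctx, w, w_grad,
        /*alpha=*/1e-3f, /*beta1=*/0.9f, /*beta2=*/0.999f, /*eps=*/1e-8f, /*wd=*/0.0f);
    lm_ggml_build_forward_expand(gb, step);
}
```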
@@ -11326,6 +10841,86 @@ static void lm_ggml_compute_forward_argmax(
11326
10841
  }
11327
10842
  }
11328
10843
 
10844
+ // lm_ggml_compute_forward_count_equal
10845
+
10846
+ static void lm_ggml_compute_forward_count_equal_i32(
10847
+ const struct lm_ggml_compute_params * params,
10848
+ struct lm_ggml_tensor * dst) {
10849
+
10850
+ const struct lm_ggml_tensor * src0 = dst->src[0];
10851
+ const struct lm_ggml_tensor * src1 = dst->src[1];
10852
+
10853
+ LM_GGML_TENSOR_BINARY_OP_LOCALS;
10854
+
10855
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_I32);
10856
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_I32);
10857
+ LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1));
10858
+ LM_GGML_ASSERT(lm_ggml_is_scalar(dst));
10859
+ LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_I64);
10860
+
10861
+ const int64_t nr = lm_ggml_nrows(src0);
10862
+
10863
+ const int ith = params->ith;
10864
+ const int nth = params->nth;
10865
+
10866
+ int64_t * sums = (int64_t *) params->wdata;
10867
+ int64_t sum_thread = 0;
10868
+
10869
+ // rows per thread
10870
+ const int64_t dr = (nr + nth - 1)/nth;
10871
+
10872
+ // row range for this thread
10873
+ const int64_t ir0 = dr*ith;
10874
+ const int64_t ir1 = MIN(ir0 + dr, nr);
10875
+
10876
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
10877
+ const int64_t i03 = ir / (ne02*ne01);
10878
+ const int64_t i02 = (ir - i03*ne03) / ne01;
10879
+ const int64_t i01 = ir - i03*ne03 - i02*ne02;
10880
+
10881
+ const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
10882
+ const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
10883
+
10884
+ for (int64_t i00 = 0; i00 < ne00; ++i00) {
10885
+ const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
10886
+ const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
10887
+
10888
+ sum_thread += val0 == val1;
10889
+ }
10890
+ }
10891
+ if (ith != 0) {
10892
+ sums[ith] = sum_thread;
10893
+ }
10894
+ lm_ggml_barrier(params->threadpool);
10895
+
10896
+ if (ith != 0) {
10897
+ return;
10898
+ }
10899
+
10900
+ for (int ith_other = 1; ith_other < nth; ++ith_other) {
10901
+ sum_thread += sums[ith_other];
10902
+ }
10903
+ *((int64_t *) dst->data) = sum_thread;
10904
+ }
10905
+
10906
+ static void lm_ggml_compute_forward_count_equal(
10907
+ const struct lm_ggml_compute_params * params,
10908
+ struct lm_ggml_tensor * dst) {
10909
+
10910
+ const struct lm_ggml_tensor * src0 = dst->src[0];
10911
+
10912
+ switch (src0->type) {
10913
+ case LM_GGML_TYPE_I32:
10914
+ {
10915
+ lm_ggml_compute_forward_count_equal_i32(params, dst);
10916
+ } break;
10917
+ default:
10918
+ {
10919
+ LM_GGML_ABORT("fatal error");
10920
+ }
10921
+ }
10922
+ }
10923
+
11329
10924
  // lm_ggml_compute_forward_repeat
11330
10925
 
11331
10926
  static void lm_ggml_compute_forward_repeat_f32(
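Editor's note: the new `COUNT_EQUAL` forward kernel above compares two same-shape I32 tensors and writes the number of matching positions into an I64 scalar; each thread accumulates a partial sum in `params->wdata`, and thread 0 reduces the partial sums after the barrier. The sketch below shows how the op might be evaluated from a graph. It assumes the package also exposes a `lm_ggml_count_equal(ctx, a, b)` builder mirroring upstream ggml (only the CPU kernel appears in this hunk), plus the usual `lm_ggml_new_graph`/`lm_ggml_graph_compute_with_ctx` helpers.

```c
// Hypothetical usage sketch, not taken from the diff.
#include <stdint.h>
#include "ggml.h"

static int64_t count_matches(struct lm_ggml_context * ctx,
                             struct lm_ggml_tensor * a,   // I32, same shape as b
                             struct lm_ggml_tensor * b) { // I32
    // assumed builder mirroring upstream ggml_count_equal; result is an I64 scalar
    struct lm_ggml_tensor * n_eq = lm_ggml_count_equal(ctx, a, b);

    struct lm_ggml_cgraph * gf = lm_ggml_new_graph(ctx);
    lm_ggml_build_forward_expand(gf, n_eq);
    lm_ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/4); // COUNT_EQUAL schedules n_threads tasks

    return *(const int64_t *) n_eq->data;
}
```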
@@ -13289,6 +12884,10 @@ static void lm_ggml_compute_forward_out_prod_f32(
13289
12884
 
13290
12885
  LM_GGML_TENSOR_BINARY_OP_LOCALS
13291
12886
 
12887
+ LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32);
12888
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
12889
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
12890
+
13292
12891
  const int ith = params->ith;
13293
12892
  const int nth = params->nth;
13294
12893
 
@@ -14618,7 +14217,7 @@ static void lm_ggml_rope_cache_init(
14618
14217
  }
14619
14218
  }
14620
14219
 
14621
- LM_GGML_CALL void lm_ggml_rope_yarn_corr_dims(
14220
+ void lm_ggml_rope_yarn_corr_dims(
14622
14221
  int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
14623
14222
  ) {
14624
14223
  // start and end correction dims
@@ -17368,41 +16967,40 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32(
17368
16967
  const struct lm_ggml_tensor * src0 = dst->src[0];
17369
16968
  const struct lm_ggml_tensor * src1 = dst->src[1];
17370
16969
 
17371
- LM_GGML_ASSERT(lm_ggml_is_contiguous(src0));
17372
- LM_GGML_ASSERT(lm_ggml_is_contiguous(src1));
17373
- LM_GGML_ASSERT(lm_ggml_is_scalar(dst));
16970
+ LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32);
16971
+ LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
16972
+ LM_GGML_ASSERT(src0->nb[0] == lm_ggml_type_size(src0->type));
16973
+ LM_GGML_ASSERT(src1->nb[0] == lm_ggml_type_size(src1->type));
17374
16974
  LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, src1));
16975
+ LM_GGML_ASSERT(lm_ggml_is_scalar(dst));
16976
+ LM_GGML_ASSERT(dst->type == LM_GGML_TYPE_F32);
16977
+
16978
+ // TODO: handle transposed/permuted matrices
16979
+ const int64_t nc = src0->ne[0];
16980
+ const int64_t nr = lm_ggml_nrows(src0);
17375
16981
 
17376
16982
  const int ith = params->ith;
17377
16983
  const int nth = params->nth;
17378
16984
 
17379
- float * sums = (float *) params->wdata;
17380
-
17381
- // TODO: handle transposed/permuted matrices
17382
- const int nc = src0->ne[0];
17383
- const int nr = lm_ggml_nrows(src0);
16985
+ float * sums = (float *) params->wdata;
16986
+ float * st = ((float *) params->wdata) + nth + ith*nc;
16987
+ float sum_thread = 0.0f;
17384
16988
 
17385
16989
  LM_GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
17386
16990
 
17387
- if (ith == 0) {
17388
- memset(sums, 0, sizeof(float) * (nth + nth * nc));
17389
- }
17390
- lm_ggml_barrier(params->threadpool);
17391
-
17392
16991
  // rows per thread
17393
- const int dr = (nr + nth - 1)/nth;
16992
+ const int64_t dr = (nr + nth - 1)/nth;
17394
16993
 
17395
16994
  // row range for this thread
17396
- const int ir0 = dr*ith;
17397
- const int ir1 = MIN(ir0 + dr, nr);
16995
+ const int64_t ir0 = dr*ith;
16996
+ const int64_t ir1 = MIN(ir0 + dr, nr);
17398
16997
 
17399
- for (int i1 = ir0; i1 < ir1; i1++) {
17400
- float * s0 = (float *)((char *) src0->data + i1*src0->nb[1]);
17401
- float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
17402
- float * st = ((float *) params->wdata) + nth + ith*nc;
16998
+ for (int64_t i1 = ir0; i1 < ir1; ++i1) {
16999
+ const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
17000
+ const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
17403
17001
 
17404
17002
  #ifndef NDEBUG
17405
- for (int i = 0; i < nc; ++i) {
17003
+ for (int64_t i = 0; i < nc; ++i) {
17406
17004
  //printf("p[%d] = %f\n", i, p[i]);
17407
17005
  assert(!isnan(s0[i]));
17408
17006
  assert(!isnan(s1[i]));
@@ -17411,23 +17009,24 @@ static void lm_ggml_compute_forward_cross_entropy_loss_f32(
17411
17009
 
17412
17010
  float max = -INFINITY;
17413
17011
  lm_ggml_vec_max_f32(nc, &max, s0);
17414
- lm_ggml_float sum = lm_ggml_vec_log_soft_max_f32(nc, st, s0, max);
17415
- assert(sum >= 0.0);
17012
+ const lm_ggml_float sum_softmax = lm_ggml_vec_log_soft_max_f32(nc, st, s0, max);
17013
+ assert(sum_softmax >= 0.0);
17416
17014
 
17417
- lm_ggml_vec_add1_f32(nc, st, st, -sum);
17015
+ lm_ggml_vec_add1_f32(nc, st, st, -sum_softmax);
17418
17016
  lm_ggml_vec_mul_f32(nc, st, st, s1);
17419
17017
 
17420
- float st_sum = 0.0f;
17421
- lm_ggml_vec_sum_f32(nc, &st_sum, st);
17422
- sums[ith] += st_sum;
17018
+ float sum_st = 0.0f;
17019
+ lm_ggml_vec_sum_f32(nc, &sum_st, st);
17020
+ sum_thread += sum_st;
17423
17021
 
17424
17022
  #ifndef NDEBUG
17425
- for (int i = 0; i < nc; ++i) {
17023
+ for (int64_t i = 0; i < nc; ++i) {
17426
17024
  assert(!isnan(st[i]));
17427
17025
  assert(!isinf(st[i]));
17428
17026
  }
17429
17027
  #endif
17430
17028
  }
17029
+ sums[ith] = sum_thread;
17431
17030
  lm_ggml_barrier(params->threadpool);
17432
17031
 
17433
17032
  if (ith == 0) {
@@ -17493,7 +17092,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
17493
17092
  float * s1 = (float *)((char *) src1->data + i1*src1->nb[1]);
17494
17093
 
17495
17094
  #ifndef NDEBUG
17496
- for (int i = 0; i < nc; ++i) {
17095
+ for (int64_t i = 0; i < nc; ++i) {
17497
17096
  //printf("p[%d] = %f\n", i, p[i]);
17498
17097
  assert(!isnan(s0[i]));
17499
17098
  assert(!isnan(s1[i]));
@@ -17512,7 +17111,7 @@ static void lm_ggml_compute_forward_cross_entropy_loss_back_f32(
17512
17111
  lm_ggml_vec_scale_f32(nc, ds0, d_by_nr);
17513
17112
 
17514
17113
  #ifndef NDEBUG
17515
- for (int i = 0; i < nc; ++i) {
17114
+ for (int64_t i = 0; i < nc; ++i) {
17516
17115
  assert(!isnan(ds0[i]));
17517
17116
  assert(!isinf(ds0[i]));
17518
17117
  }
@@ -17700,6 +17299,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru
17700
17299
  {
17701
17300
  lm_ggml_compute_forward_argmax(params, tensor);
17702
17301
  } break;
17302
+ case LM_GGML_OP_COUNT_EQUAL:
17303
+ {
17304
+ lm_ggml_compute_forward_count_equal(params, tensor);
17305
+ } break;
17703
17306
  case LM_GGML_OP_REPEAT:
17704
17307
  {
17705
17308
  lm_ggml_compute_forward_repeat(params, tensor);
@@ -18130,7 +17733,7 @@ void lm_ggml_build_backward_gradient_checkpointing(
18130
17733
  struct lm_ggml_tensor * * checkpoints,
18131
17734
  int n_checkpoints) {
18132
17735
  lm_ggml_graph_cpy(gf, gb_tmp);
18133
- lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);
17736
+ lm_ggml_build_backward_expand(ctx, gf, gb_tmp, false);
18134
17737
 
18135
17738
  if (n_checkpoints <= 0) {
18136
17739
  lm_ggml_graph_cpy(gb_tmp, gb);
@@ -18450,6 +18053,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18450
18053
  } break;
18451
18054
  case LM_GGML_OP_MEAN:
18452
18055
  case LM_GGML_OP_ARGMAX:
18056
+ case LM_GGML_OP_COUNT_EQUAL:
18453
18057
  {
18454
18058
  LM_GGML_ABORT("fatal error"); // TODO: implement
18455
18059
  }
@@ -18782,7 +18386,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18782
18386
  lm_ggml_soft_max_back(ctx, tensor->grad, tensor),
18783
18387
  zero_table, acc_table);
18784
18388
  }
18785
-
18389
+ LM_GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented");
18786
18390
  } break;
18787
18391
  case LM_GGML_OP_SOFT_MAX_BACK:
18788
18392
  {
@@ -18823,6 +18427,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18823
18427
  beta_slow),
18824
18428
  zero_table, acc_table);
18825
18429
  }
18430
+ LM_GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented");
18826
18431
  } break;
18827
18432
  case LM_GGML_OP_ROPE_BACK:
18828
18433
  {
@@ -18944,6 +18549,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
18944
18549
  }
18945
18550
  case LM_GGML_OP_FLASH_ATTN_EXT:
18946
18551
  {
18552
+ LM_GGML_ABORT("FA backward pass not adapted after rework");
18947
18553
  struct lm_ggml_tensor * flash_grad = NULL;
18948
18554
  if (src0->grad || src1->grad || tensor->src[2]->grad) {
18949
18555
  int32_t t = lm_ggml_get_op_params_i32(tensor, 0);
@@ -19118,6 +18724,7 @@ static void lm_ggml_compute_backward(struct lm_ggml_context * ctx, struct lm_ggm
19118
18724
  tensor->grad),
19119
18725
  zero_table, acc_table);
19120
18726
  }
18727
+ LM_GGML_ASSERT(!src1->grad && "backward pass for labels not implemented");
19121
18728
  } break;
19122
18729
  case LM_GGML_OP_CROSS_ENTROPY_LOSS_BACK:
19123
18730
  {
@@ -19168,7 +18775,7 @@ static void lm_ggml_visit_parents(struct lm_ggml_cgraph * cgraph, struct lm_ggml
19168
18775
  }
19169
18776
  }
19170
18777
 
19171
- if (node->op == LM_GGML_OP_NONE && node->grad == NULL) {
18778
+ if (node->op == LM_GGML_OP_NONE && !(node->flags & LM_GGML_TENSOR_FLAG_PARAM)) {
19172
18779
  // reached a leaf node, not part of the gradient graph (e.g. a constant)
19173
18780
  LM_GGML_ASSERT(cgraph->n_leafs < cgraph->size);
19174
18781
 
@@ -19186,9 +18793,6 @@ static void lm_ggml_visit_parents(struct lm_ggml_cgraph * cgraph, struct lm_ggml
19186
18793
  }
19187
18794
 
19188
18795
  cgraph->nodes[cgraph->n_nodes] = node;
19189
- if (cgraph->grads) {
19190
- cgraph->grads[cgraph->n_nodes] = node->grad;
19191
- }
19192
18796
  cgraph->n_nodes++;
19193
18797
  }
19194
18798
  }
@@ -19216,20 +18820,62 @@ void lm_ggml_build_forward_expand(struct lm_ggml_cgraph * cgraph, struct lm_ggml
19216
18820
  lm_ggml_build_forward_impl(cgraph, tensor, true);
19217
18821
  }
19218
18822
 
19219
- void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate, bool keep) {
18823
+ void lm_ggml_build_backward_expand(struct lm_ggml_context * ctx, struct lm_ggml_cgraph * gf, struct lm_ggml_cgraph * gb, bool accumulate) {
19220
18824
  LM_GGML_ASSERT(gf->n_nodes > 0);
19221
18825
  LM_GGML_ASSERT(gf->grads);
19222
18826
 
19223
- // if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
19224
- if (keep) {
19225
- for (int i = 0; i < gf->n_nodes; i++) {
19226
- struct lm_ggml_tensor * node = gf->nodes[i];
18827
+ for (int i = 0; i < gf->n_nodes; ++i) {
18828
+ struct lm_ggml_tensor * node = gf->nodes[i];
18829
+
18830
+ if (node->type == LM_GGML_TYPE_I32) {
18831
+ continue;
18832
+ }
18833
+
18834
+ bool needs_grad = node->flags & LM_GGML_TENSOR_FLAG_PARAM;
18835
+ bool ignore_src[LM_GGML_MAX_SRC] = {false};
18836
+ switch (node->op) {
18837
+ // gradients in node->src[0] for one reason or another have no effect on output gradients
18838
+ case LM_GGML_OP_IM2COL: // only used for its shape
18839
+ case LM_GGML_OP_IM2COL_BACK: // same as IM2COL
18840
+ ignore_src[0] = true;
18841
+ break;
18842
+ case LM_GGML_OP_UNARY: {
18843
+ const enum lm_ggml_unary_op uop = lm_ggml_get_unary_op(node);
18844
+ // SGN and STEP unary ops are piecewise constant
18845
+ if (uop == LM_GGML_UNARY_OP_SGN || uop == LM_GGML_UNARY_OP_STEP) {
18846
+ ignore_src[0] = true;
18847
+ }
18848
+ } break;
18849
+
18850
+ // gradients in node->src[1] for one reason or another have no effect on output gradients
18851
+ case LM_GGML_OP_CPY: // gradients in CPY target are irrelevant
18852
+ case LM_GGML_OP_GET_ROWS: // row indices not differentiable
18853
+ case LM_GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
18854
+ case LM_GGML_OP_ROPE: // positions not differentiable
18855
+ ignore_src[1] = true;
18856
+ break;
19227
18857
 
19228
- if (node->grad) {
19229
- node->grad = lm_ggml_dup_tensor(ctx, node);
19230
- gf->grads[i] = node->grad;
18858
+ default:
18859
+ break;
18860
+ }
18861
+ for (int j = 0; j < LM_GGML_MAX_SRC; ++j) {
18862
+ if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) {
18863
+ continue;
19231
18864
  }
18865
+ LM_GGML_ASSERT(node->src[j]->type == LM_GGML_TYPE_F32 || node->src[j]->type == LM_GGML_TYPE_F16);
18866
+ needs_grad = true;
18867
+ break;
18868
+ }
18869
+ if (!needs_grad) {
18870
+ continue;
19232
18871
  }
18872
+
18873
+ // inplace operations are currently not supported
18874
+ LM_GGML_ASSERT(!node->view_src || node->op == LM_GGML_OP_CPY || node->op == LM_GGML_OP_VIEW ||
18875
+ node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE);
18876
+
18877
+ // create a new tensor with the same type and shape as the node and set it as grad
18878
+ node->grad = lm_ggml_dup_tensor(ctx, node);
19233
18879
  }
19234
18880
 
19235
18881
  // keep tables of original gradients for replacement/accumulation logic
@@ -19291,7 +18937,7 @@ void lm_ggml_build_opt_adamw(
19291
18937
 
19292
18938
  if (node->flags & LM_GGML_TENSOR_FLAG_PARAM) {
19293
18939
  LM_GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
19294
- struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
18940
+ struct lm_ggml_tensor * opt_step = lm_ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd);
19295
18941
  lm_ggml_build_forward_expand(gb, opt_step);
19296
18942
  }
19297
18943
  }
@@ -19588,6 +19234,13 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) {
19588
19234
  case LM_GGML_OP_SUM_ROWS:
19589
19235
  case LM_GGML_OP_MEAN:
19590
19236
  case LM_GGML_OP_ARGMAX:
19237
+ {
19238
+ n_tasks = 1;
19239
+ } break;
19240
+ case LM_GGML_OP_COUNT_EQUAL:
19241
+ {
19242
+ n_tasks = n_threads;
19243
+ } break;
19591
19244
  case LM_GGML_OP_REPEAT:
19592
19245
  case LM_GGML_OP_REPEAT_BACK:
19593
19246
  case LM_GGML_OP_LEAKY_RELU:
@@ -20086,6 +19739,10 @@ struct lm_ggml_cplan lm_ggml_graph_plan(
20086
19739
  cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks;
20087
19740
  }
20088
19741
  } break;
19742
+ case LM_GGML_OP_COUNT_EQUAL:
19743
+ {
19744
+ cur = lm_ggml_type_size(node->type)*n_tasks;
19745
+ } break;
20089
19746
  case LM_GGML_OP_MUL_MAT:
20090
19747
  {
20091
19748
  const enum lm_ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type;
@@ -20529,7 +20186,7 @@ enum lm_ggml_status lm_ggml_graph_compute(struct lm_ggml_cgraph * cgraph, struct
20529
20186
  }
20530
20187
  #else
20531
20188
  if (n_threads > threadpool->n_threads_max) {
20532
- LM_GGML_PRINT("WARNING: cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
20189
+ LM_GGML_LOG_WARN("cplan requested more threads (%d) than available (%d)\n", n_threads, threadpool->n_threads_max);
20533
20190
  n_threads = threadpool->n_threads_max;
20534
20191
  }
20535
20192
 
@@ -21068,30 +20725,30 @@ struct lm_ggml_cgraph * lm_ggml_graph_import(const char * fname, struct lm_ggml_
21068
20725
  }
21069
20726
 
21070
20727
  void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) {
21071
- LM_GGML_PRINT("=== GRAPH ===\n");
20728
+ LM_GGML_LOG_INFO("=== GRAPH ===\n");
21072
20729
 
21073
- LM_GGML_PRINT("n_nodes = %d\n", cgraph->n_nodes);
20730
+ LM_GGML_LOG_INFO("n_nodes = %d\n", cgraph->n_nodes);
21074
20731
  for (int i = 0; i < cgraph->n_nodes; i++) {
21075
20732
  struct lm_ggml_tensor * node = cgraph->nodes[i];
21076
20733
 
21077
- LM_GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
20734
+ LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n",
21078
20735
  i,
21079
20736
  node->ne[0], node->ne[1], node->ne[2],
21080
20737
  lm_ggml_op_name(node->op), (node->flags & LM_GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " ");
21081
20738
  }
21082
20739
 
21083
- LM_GGML_PRINT("n_leafs = %d\n", cgraph->n_leafs);
20740
+ LM_GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs);
21084
20741
  for (int i = 0; i < cgraph->n_leafs; i++) {
21085
20742
  struct lm_ggml_tensor * node = cgraph->leafs[i];
21086
20743
 
21087
- LM_GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
20744
+ LM_GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s %16s\n",
21088
20745
  i,
21089
20746
  node->ne[0], node->ne[1],
21090
20747
  lm_ggml_op_name(node->op),
21091
20748
  lm_ggml_get_name(node));
21092
20749
  }
21093
20750
 
21094
- LM_GGML_PRINT("========================================\n");
20751
+ LM_GGML_LOG_INFO("========================================\n");
21095
20752
  }
21096
20753
 
21097
20754
  // check if node is part of the graph
@@ -21262,7 +20919,7 @@ void lm_ggml_graph_dump_dot(const struct lm_ggml_cgraph * gb, const struct lm_gg
21262
20919
 
21263
20920
  fclose(fp);
21264
20921
 
21265
- LM_GGML_PRINT("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
20922
+ LM_GGML_LOG_INFO("%s: dot -Tpng %s -o %s.png && open %s.png\n", __func__, filename, filename, filename);
21266
20923
  }
21267
20924
 
21268
20925
  ////////////////////////////////////////////////////////////////////////////////
@@ -22094,8 +21751,6 @@ enum lm_ggml_opt_result lm_ggml_opt(
22094
21751
  struct lm_ggml_context * ctx,
22095
21752
  struct lm_ggml_opt_params params,
22096
21753
  struct lm_ggml_tensor * f) {
22097
- LM_GGML_ASSERT(f->grad && "lm_ggml_set_param called for at least one parent tensor.");
22098
-
22099
21754
  bool free_ctx = false;
22100
21755
  if (ctx == NULL) {
22101
21756
  struct lm_ggml_init_params params_ctx = {
@@ -22136,7 +21791,7 @@ enum lm_ggml_opt_result lm_ggml_opt_resume(
22136
21791
  lm_ggml_build_forward_expand(gf, f);
22137
21792
 
22138
21793
  struct lm_ggml_cgraph * gb = lm_ggml_graph_dup(ctx, gf);
22139
- lm_ggml_build_backward_expand(ctx, gf, gb, false, true);
21794
+ lm_ggml_build_backward_expand(ctx, gf, gb, false);
22140
21795
 
22141
21796
  return lm_ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
22142
21797
  }
@@ -22189,6 +21844,17 @@ void lm_ggml_set_output(struct lm_ggml_tensor * tensor) {
22189
21844
  tensor->flags |= LM_GGML_TENSOR_FLAG_OUTPUT;
22190
21845
  }
22191
21846
 
21847
+ void lm_ggml_set_param(struct lm_ggml_context * ctx, struct lm_ggml_tensor * tensor) {
21848
+ LM_GGML_UNUSED(ctx); // TODO: remove this parameter
21849
+ tensor->flags |= LM_GGML_TENSOR_FLAG_PARAM;
21850
+ }
21851
+
21852
+ void lm_ggml_set_loss(struct lm_ggml_tensor * tensor) {
21853
+ LM_GGML_ASSERT(lm_ggml_is_scalar(tensor));
21854
+ LM_GGML_ASSERT(tensor->type == LM_GGML_TYPE_F32);
21855
+ tensor->flags |= LM_GGML_TENSOR_FLAG_LOSS;
21856
+ }
21857
+
22192
21858
  ////////////////////////////////////////////////////////////////////////////////
22193
21859
 
22194
21860
  void lm_ggml_quantize_init(enum lm_ggml_type type) {
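Editor's note: the hunk above re-adds `lm_ggml_set_param` and `lm_ggml_set_loss` next to `lm_ggml_set_output`, with changed behaviour: `lm_ggml_set_param` now only sets `LM_GGML_TENSOR_FLAG_PARAM` (gradient tensors are allocated later by `lm_ggml_build_backward_expand`, whose `keep` argument was removed), and `lm_ggml_set_loss` no longer requires a pre-existing grad. A minimal sketch of the resulting setup order follows; it assumes the lm_-prefixed equivalents of `ggml_init`, `ggml_new_graph_custom`, and `LM_GGML_DEFAULT_GRAPH_SIZE`, and the tensor names are placeholders.

```c
// Minimal sketch of the reworked training-graph setup (illustrative, not from the diff).
#include "ggml.h"

void build_training_graph_example(void) {
    struct lm_ggml_init_params ip = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct lm_ggml_context * ctx = lm_ggml_init(ip);

    struct lm_ggml_tensor * w = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, 4, 4);
    struct lm_ggml_tensor * x = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, 4, 1);
    lm_ggml_set_param(ctx, w);                        // now only sets LM_GGML_TENSOR_FLAG_PARAM

    struct lm_ggml_tensor * y    = lm_ggml_mul_mat(ctx, w, x);
    struct lm_ggml_tensor * loss = lm_ggml_sum(ctx, y); // stand-in scalar loss
    lm_ggml_set_loss(loss);                           // no longer asserts loss->grad

    // grads=true so the graph carries gradient slots
    struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx, LM_GGML_DEFAULT_GRAPH_SIZE, true);
    lm_ggml_build_forward_expand(gf, loss);

    struct lm_ggml_cgraph * gb = lm_ggml_graph_dup(ctx, gf);
    lm_ggml_build_backward_expand(ctx, gf, gb, /*accumulate=*/false); // 'keep' parameter removed

    // w->grad is allocated by the call above; computing gb populates it
    lm_ggml_free(ctx);
}
```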
@@ -23578,16 +23244,16 @@ int lm_ggml_cpu_has_fma(void) {
23578
23244
  }
23579
23245
 
23580
23246
  int lm_ggml_cpu_has_neon(void) {
23581
- #if defined(__ARM_NEON)
23582
- return 1;
23247
+ #if defined(__ARM_ARCH)
23248
+ return lm_ggml_arm_arch_features.has_neon;
23583
23249
  #else
23584
23250
  return 0;
23585
23251
  #endif
23586
23252
  }
23587
23253
 
23588
23254
  int lm_ggml_cpu_has_sve(void) {
23589
- #if defined(__ARM_FEATURE_SVE)
23590
- return 1;
23255
+ #if defined(__ARM_ARCH)
23256
+ return lm_ggml_arm_arch_features.has_sve;
23591
23257
  #else
23592
23258
  return 0;
23593
23259
  #endif
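Editor's note: NEON and SVE support is now reported from the runtime-detected `lm_ggml_arm_arch_features` state rather than compile-time macros, and the next hunk does the same for int8 matmul and adds `lm_ggml_cpu_get_sve_cnt()` in place of the removed `lm_ggml_sve_cnt_b` global. A short illustrative query, assuming these declarations live in the package's `cpp/ggml.h`:

```c
// Illustrative sketch: report runtime-detected Arm CPU features.
#include <stdio.h>
#include "ggml.h"

void print_arm_features(void) {
    printf("NEON: %d  SVE: %d (sve_cnt: %d)  i8mm: %d\n",
           lm_ggml_cpu_has_neon(),
           lm_ggml_cpu_has_sve(),
           lm_ggml_cpu_get_sve_cnt(),        // added in the following hunk
           lm_ggml_cpu_has_matmul_int8());
}
```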
@@ -23734,11 +23400,23 @@ int lm_ggml_cpu_has_vsx(void) {
23734
23400
  }
23735
23401
 
23736
23402
  int lm_ggml_cpu_has_matmul_int8(void) {
23737
- #if defined(__ARM_FEATURE_MATMUL_INT8)
23738
- return 1;
23403
+ #if defined(__ARM_ARCH)
23404
+ return lm_ggml_arm_arch_features.has_i8mm;
23405
+ #else
23406
+ return 0;
23407
+ #endif
23408
+ }
23409
+
23410
+ int lm_ggml_cpu_get_sve_cnt(void) {
23411
+ #if defined(__ARM_ARCH)
23412
+ return lm_ggml_arm_arch_features.sve_cnt;
23739
23413
  #else
23740
23414
  return 0;
23741
23415
  #endif
23742
23416
  }
23743
23417
 
23418
+ void lm_ggml_log_set(lm_ggml_log_callback log_callback, void * user_data) {
23419
+ g_logger_state.log_callback = log_callback ? log_callback : lm_ggml_log_callback_default;
23420
+ g_logger_state.log_callback_user_data = user_data;
23421
+ }
23744
23422
  ////////////////////////////////////////////////////////////////////////////////
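Editor's note: the final addition exposes `lm_ggml_log_set`, which swaps the process-wide callback stored in `g_logger_state` and falls back to `lm_ggml_log_callback_default` when passed NULL. A minimal sketch of installing a custom sink; the callback parameters (level, text, user_data) follow how the internal logger invokes the callback, and `my_log_cb` is an illustrative name.

```c
// Minimal sketch: route ggml log output to stderr with a level prefix.
#include <stdio.h>
#include "ggml.h"

static void my_log_cb(enum lm_ggml_log_level level, const char * text, void * user_data) {
    (void) user_data;
    fprintf(stderr, "[ggml:%d] %s", (int) level, text); // messages typically carry their own newline
}

// somewhere during startup:
//   lm_ggml_log_set(my_log_cb, NULL);
//   lm_ggml_log_set(NULL, NULL);   // restore lm_ggml_log_callback_default
```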