@fugood/llama.node 1.1.6 → 1.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. package/lib/binding.ts +4 -0
  2. package/lib/index.js +6 -1
  3. package/lib/index.ts +6 -0
  4. package/lib/version.js +5 -0
  5. package/lib/version.ts +2 -0
  6. package/package.json +14 -14
  7. package/scripts/llama.cpp.patch +9 -9
  8. package/src/LlamaCompletionWorker.cpp +73 -20
  9. package/src/LlamaCompletionWorker.h +8 -0
  10. package/src/LlamaContext.cpp +9 -0
  11. package/src/common.hpp +8 -1
  12. package/src/llama.cpp/CMakeLists.txt +2 -0
  13. package/src/llama.cpp/common/arg.cpp +132 -41
  14. package/src/llama.cpp/common/chat-parser.cpp +9 -1
  15. package/src/llama.cpp/common/chat.cpp +311 -9
  16. package/src/llama.cpp/common/chat.h +4 -1
  17. package/src/llama.cpp/common/common.cpp +54 -0
  18. package/src/llama.cpp/common/common.h +46 -9
  19. package/src/llama.cpp/ggml/CMakeLists.txt +2 -0
  20. package/src/llama.cpp/ggml/include/ggml-opt.h +25 -6
  21. package/src/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
  22. package/src/llama.cpp/ggml/include/ggml.h +28 -2
  23. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
  24. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +1 -1
  25. package/src/llama.cpp/ggml/src/ggml-cpu/arch/powerpc/quants.c +66 -0
  26. package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +1136 -1077
  27. package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +14 -1
  28. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
  29. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
  30. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
  31. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +63 -2
  32. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -1
  33. package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +200 -51
  34. package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +11 -0
  35. package/src/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
  36. package/src/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
  37. package/src/llama.cpp/include/llama.h +25 -0
  38. package/src/llama.cpp/src/llama-batch.cpp +1 -1
  39. package/src/llama.cpp/src/llama-chat.cpp +2 -4
  40. package/src/llama.cpp/src/llama-context.cpp +29 -22
  41. package/src/llama.cpp/src/llama-context.h +6 -5
  42. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +12 -6
  43. package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +2 -2
  44. package/src/llama.cpp/src/llama-kv-cache-unified.cpp +89 -69
  45. package/src/llama.cpp/src/llama-kv-cache-unified.h +2 -2
  46. package/src/llama.cpp/src/llama-memory-hybrid.cpp +6 -2
  47. package/src/llama.cpp/src/llama-memory-hybrid.h +2 -2
  48. package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -2
  49. package/src/llama.cpp/src/llama-memory-recurrent.h +2 -2
  50. package/src/llama.cpp/src/llama-memory.h +2 -2
  51. package/src/llama.cpp/src/llama-model.cpp +81 -70
  52. package/src/llama.cpp/src/llama-model.h +2 -0
  53. package/src/llama.cpp/src/llama-quant.cpp +1 -1
  54. package/src/llama.cpp/src/llama-vocab.cpp +2 -1
File: package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h
@@ -40,18 +40,22 @@
40
40
  #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
41
41
  #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
42
42
  #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
43
+ #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
43
44
  #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
44
45
  #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
45
46
  #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
46
47
  #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
47
48
  #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
48
49
  #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
50
+ #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
49
51
  #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
50
52
  // repack.cpp
51
53
  #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
52
54
  #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
55
+ #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
53
56
  #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
54
57
  #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
58
+ #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
55
59
  #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
56
60
  #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
57
61
  // repack.cpp
@@ -69,7 +73,6 @@
69
73
  #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
70
74
  #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
71
75
  #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
72
- #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
73
76
  // repack.cpp
74
77
  #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
75
78
  #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -80,12 +83,14 @@
80
83
  #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
81
84
  #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
82
85
  #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
86
+ #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
83
87
  #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
84
88
  #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
85
89
  #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
86
90
  #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
87
91
  #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
88
92
  #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
93
+ #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
89
94
  #elif defined(__loongarch64)
90
95
  // quants.c
91
96
  #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -103,12 +108,14 @@
103
108
  #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
104
109
  #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
105
110
  #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
111
+ #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
106
112
  #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
107
113
  #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
108
114
  #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
109
115
  #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
110
116
  #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
111
117
  #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
118
+ #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
112
119
  #elif defined(__riscv)
113
120
  // quants.c
114
121
  #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -133,11 +140,13 @@
133
140
  #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
134
141
  #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
135
142
  #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
143
+ #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
136
144
  #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
137
145
  #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
138
146
  #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
139
147
  #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
140
148
  #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
149
+ #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
141
150
  #elif defined(__s390x__)
142
151
  // quants.c
143
152
  #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -164,12 +173,14 @@
164
173
  #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
165
174
  #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
166
175
  #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
176
+ #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
167
177
  #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
168
178
  #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
169
179
  #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
170
180
  #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
171
181
  #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
172
182
  #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
183
+ #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
173
184
  #elif defined(__wasm__)
174
185
  // quants.c
175
186
  #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@@ -195,10 +206,12 @@
195
206
  #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
196
207
  #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
197
208
  #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
209
+ #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
198
210
  #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
199
211
  #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
200
212
  #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
201
213
  #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
202
214
  #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
203
215
  #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
216
+ #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
204
217
  #endif
File: package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2022,6 +2022,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
2022
2022
  ggml_compute_forward_opt_step_adamw(params, tensor);
2023
2023
  }
2024
2024
  break;
2025
+ case GGML_OP_OPT_STEP_SGD:
2026
+ {
2027
+ ggml_compute_forward_opt_step_sgd(params, tensor);
2028
+ }
2029
+ break;
2025
2030
  case GGML_OP_NONE:
2026
2031
  {
2027
2032
  // nop
@@ -2325,6 +2330,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
2325
2330
  case GGML_OP_CROSS_ENTROPY_LOSS:
2326
2331
  case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
2327
2332
  case GGML_OP_OPT_STEP_ADAMW:
2333
+ case GGML_OP_OPT_STEP_SGD:
2328
2334
  {
2329
2335
  n_tasks = n_threads;
2330
2336
  } break;
File: package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp
@@ -35,7 +35,7 @@
35
35
 
36
36
  // ggml-backend interface
37
37
 
38
- std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type() {
38
+ std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {
39
39
  static std::vector<ggml_backend_buffer_type_t> bufts = []() {
40
40
  std::vector<ggml_backend_buffer_type_t> bufts;
41
41
 
@@ -57,8 +57,6 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
57
57
  }
58
58
  #endif
59
59
 
60
- bufts.push_back(NULL);
61
-
62
60
  return bufts;
63
61
  }();
64
62
 
@@ -66,14 +64,20 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type
66
64
  }
67
65
 
68
66
  static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
69
- return ggml_backend_cpu_get_extra_buffers_type().data();
67
+ static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {
68
+ std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();
69
+ bufts.push_back(nullptr);
70
+ return bufts;
71
+ }();
72
+
73
+ return extra_bufts.data();
70
74
 
71
75
  GGML_UNUSED(device);
72
76
  }
73
77
 
74
78
  static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
75
- for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
76
- if (extra && extra == buft) {
79
+ for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
80
+ if (extra == buft) {
77
81
  return true;
78
82
  }
79
83
  }
@@ -210,10 +214,10 @@ ggml_backend_t ggml_backend_cpu_init(void) {
210
214
  ctx->abort_callback_data = NULL;
211
215
 
212
216
  ggml_backend_t cpu_backend = new ggml_backend {
213
- /* .guid = */ ggml_backend_cpu_guid(),
214
- /* .interface = */ ggml_backend_cpu_i,
215
- /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
216
- /* .context = */ ctx,
217
+ /* .guid = */ ggml_backend_cpu_guid(),
218
+ /* .iface = */ ggml_backend_cpu_i,
219
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
220
+ /* .context = */ ctx,
217
221
  };
218
222
 
219
223
  if (cpu_backend == NULL) {
@@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
397
401
  return true;
398
402
  }
399
403
 
400
- // extra_buffer_op?
401
- for (auto extra : ggml_backend_cpu_get_extra_buffers_type()) {
402
- if (extra) {
403
- auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
404
- if (buf_extra && buf_extra->supports_op(dev, op)) {
405
- return true;
406
- }
407
- }
408
- }
409
-
410
- // the other case need host buffer.
411
- for (int i = 0; i < GGML_MAX_SRC; i++) {
412
- if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
413
- return false;
404
+ // check extra buffer types
405
+ // note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
406
+ for (int i = 0; i < 4; i++) {
407
+ if (op->src[i] && op->src[i]->buffer &&
408
+ ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
409
+ auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
410
+ return buf_extra->supports_op(dev, op);
414
411
  }
415
412
  }
416
413
 
File: package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp
@@ -259,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
259
259
  const int64_t m_start = 0;
260
260
 
261
261
  const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
262
- const int64_t num_threads = KAI_MIN(n / n_step, nth);
262
+ int64_t num_threads = KAI_MIN(n / n_step, nth);
263
+ if (num_threads <= 0) {
264
+ num_threads = 1;
265
+ }
263
266
 
264
267
  if (ith < num_threads) {
265
268
  const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
@@ -309,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
309
312
  GGML_ASSERT(kernel);
310
313
 
311
314
  const int ith = params->ith;
312
- const int nth = params->nth;
315
+ const int nth_raw = params->nth;
316
+ const int nth = nth_raw > 0 ? nth_raw : 1;
313
317
 
314
318
  const size_t k = ne00;
315
319
  const size_t m = ne11;
@@ -327,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
327
331
  const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
328
332
  const size_t n_start = ith * num_n_per_thread;
329
333
 
330
- size_t n_to_process = num_n_per_thread;
331
- if ((n_start + n_to_process) > n) {
332
- n_to_process = n - n_start;
334
+ size_t n_to_process = 0;
335
+ if (n_start < n) {
336
+ n_to_process = num_n_per_thread;
337
+ if ((n_start + n_to_process) > n) {
338
+ n_to_process = n - n_start;
339
+ }
333
340
  }
334
341
 
335
342
  // Calculate number of columns to be processed per thread
@@ -361,8 +368,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
361
368
  const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
362
369
  float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
363
370
 
364
- variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
365
- sizeof(float), -FLT_MAX, FLT_MAX);
371
+ if (n_to_process > 0) {
372
+ variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
373
+ sizeof(float), -FLT_MAX, FLT_MAX);
374
+ }
366
375
 
367
376
  return true;
368
377
  }
File: package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp
@@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
10330
10330
  const int ir1 = MIN(ir0 + dr, nr);
10331
10331
 
10332
10332
  const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
10333
+
10333
10334
  const float alpha = adamw_params_ptr[0];
10334
10335
  const float beta1 = adamw_params_ptr[1];
10335
10336
  const float beta2 = adamw_params_ptr[2];
@@ -10337,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
10337
10338
  const float wd = adamw_params_ptr[4];
10338
10339
  const float beta1h = adamw_params_ptr[5];
10339
10340
  const float beta2h = adamw_params_ptr[6];
10340
-
10341
+ const float keep = 1.f - alpha * wd;
10341
10342
  for (int ir = ir0; ir < ir1; ++ir) {
10342
10343
  const int64_t i03 = ir/(ne02*ne01);
10343
10344
  const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
10360
10361
  // The weight decay is applied independently of the Adam momenta m and v.
10361
10362
  // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
10362
10363
  // See: https://arxiv.org/pdf/1711.05101v3.pdf
10363
- w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
10364
+ w[i00] = w[i00] * keep - alpha * mh / vh;
10364
10365
  }
10365
10366
  }
10366
10367
  }
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
10382
10383
  }
10383
10384
  }
10384
10385
  }
10386
+
10387
+ static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
10388
+ const ggml_tensor * src0 = dst->src[0];
10389
+ const ggml_tensor * src0_grad = dst->src[1];
10390
+ const ggml_tensor * sgd_params = dst->src[2];
10391
+
10392
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10393
+ GGML_ASSERT(ggml_nelements(sgd_params) == 2);
10394
+
10395
+ const int ith = params->ith;
10396
+ const int nth = params->nth;
10397
+
10398
+ const int nr = ggml_nrows(src0);
10399
+
10400
+ GGML_TENSOR_UNARY_OP_LOCALS
10401
+ GGML_ASSERT(nb00 == sizeof(float));
10402
+
10403
+ // rows per thread
10404
+ const int dr = (nr + nth - 1) / nth;
10405
+
10406
+ // row range for this thread
10407
+ const int ir0 = dr * ith;
10408
+ const int ir1 = MIN(ir0 + dr, nr);
10409
+
10410
+ // using adamw param subset we care about - alpha, wd - could have a separate struct
10411
+ const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
10412
+ const float alpha = sgd_params_ptr[0];
10413
+ const float keep = 1.f - alpha * sgd_params_ptr[1];
10414
+
10415
+ for (int ir = ir0; ir < ir1; ++ir) {
10416
+ const int64_t i03 = ir / (ne02 * ne01);
10417
+ const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
10418
+ const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
10419
+
10420
+ const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
10421
+
10422
+ float * w = (float *) ((char *) src0->data + offset); // weight
10423
+ const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
10424
+
10425
+ for (int i00 = 0; i00 < ne00; ++i00) {
10426
+ w[i00] = w[i00] * keep - alpha * g[i00];
10427
+ }
10428
+ }
10429
+ }
10430
+
10431
+ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
10432
+ const ggml_tensor * src0 = dst->src[0];
10433
+
10434
+ switch (src0->type) {
10435
+ case GGML_TYPE_F32:
10436
+ {
10437
+ ggml_compute_forward_opt_step_sgd_f32(params, dst);
10438
+ }
10439
+ break;
10440
+ default:
10441
+ {
10442
+ GGML_ABORT("fatal error - sgd is F32 only");
10443
+ }
10444
+ }
10445
+ }
File: package/src/llama.cpp/ggml/src/ggml-cpu/ops.h
@@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params *
107
107
  void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
108
108
  void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
109
109
  void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
110
-
110
+ void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
111
111
  #ifdef __cplusplus
112
112
  }
113
113
  #endif