@fugood/llama.node 1.1.6 → 1.1.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/lib/binding.ts +4 -0
- package/lib/index.js +6 -1
- package/lib/index.ts +6 -0
- package/lib/version.js +5 -0
- package/lib/version.ts +2 -0
- package/package.json +14 -14
- package/scripts/llama.cpp.patch +9 -9
- package/src/LlamaCompletionWorker.cpp +73 -20
- package/src/LlamaCompletionWorker.h +8 -0
- package/src/llama.cpp/CMakeLists.txt +2 -0
- package/src/llama.cpp/common/arg.cpp +124 -40
- package/src/llama.cpp/common/chat-parser.cpp +9 -1
- package/src/llama.cpp/common/chat.cpp +312 -9
- package/src/llama.cpp/common/chat.h +4 -1
- package/src/llama.cpp/common/common.cpp +54 -0
- package/src/llama.cpp/common/common.h +41 -7
- package/src/llama.cpp/ggml/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +25 -6
- package/src/llama.cpp/ggml/include/ggml-zdnn.h +16 -0
- package/src/llama.cpp/ggml/include/ggml.h +28 -2
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/arch/x86/repack.cpp +1136 -1077
- package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +21 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +16 -7
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +63 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.cpp +200 -51
- package/src/llama.cpp/ggml/src/ggml-cpu/repack.h +11 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/traits.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/traits.h +1 -1
- package/src/llama.cpp/include/llama.h +25 -0
- package/src/llama.cpp/src/llama-batch.cpp +1 -1
- package/src/llama.cpp/src/llama-chat.cpp +2 -4
- package/src/llama.cpp/src/llama-context.cpp +29 -17
- package/src/llama.cpp/src/llama-context.h +6 -5
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.cpp +12 -6
- package/src/llama.cpp/src/llama-kv-cache-unified-iswa.h +2 -2
- package/src/llama.cpp/src/llama-kv-cache-unified.cpp +89 -69
- package/src/llama.cpp/src/llama-kv-cache-unified.h +2 -2
- package/src/llama.cpp/src/llama-memory-hybrid.cpp +6 -2
- package/src/llama.cpp/src/llama-memory-hybrid.h +2 -2
- package/src/llama.cpp/src/llama-memory-recurrent.cpp +6 -2
- package/src/llama.cpp/src/llama-memory-recurrent.h +2 -2
- package/src/llama.cpp/src/llama-memory.h +2 -2
- package/src/llama.cpp/src/llama-model.cpp +1 -0
- package/src/llama.cpp/src/llama-model.h +1 -0
- package/src/llama.cpp/src/llama-quant.cpp +1 -1
- package/src/llama.cpp/src/llama-vocab.cpp +2 -1
package/src/llama.cpp/ggml/src/ggml-cpu/arch-fallback.h

@@ -40,18 +40,22 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 // repack.cpp
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
 // repack.cpp
@@ -80,12 +84,14 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__loongarch64)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -103,12 +109,14 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__riscv)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -133,11 +141,13 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__s390x__)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
@@ -164,12 +174,14 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #elif defined(__wasm__)
 // quants.c
 #define ggml_vec_dot_q4_1_q8_1_generic ggml_vec_dot_q4_1_q8_1
@@ -195,10 +207,12 @@
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #endif
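These hunks register generic fallbacks for the new iq4_nl 8x8 repack kernels on every architecture that lacks a hand-tuned implementation: the `*_generic` symbol is #define-aliased onto the exported kernel name, so the portable translation unit compiles straight into the public symbol. Below is a minimal sketch of that aliasing pattern; the `demo_*` names and `DEMO_HAVE_SIMD_KERNEL` guard are hypothetical stand-ins for the real kernel names (e.g. ggml_gemv_iq4_nl_8x8_q8_0_generic -> ggml_gemv_iq4_nl_8x8_q8_0).

```c
#include <stdio.h>

/* On targets without a hand-written kernel, the generic body is compiled
 * under the exported name via a #define alias (hypothetical names). */
#if !defined(DEMO_HAVE_SIMD_KERNEL)
#define demo_gemv_generic demo_gemv /* generic body becomes the public symbol */
#endif

/* The "generic" reference implementation: y += A*x for a 2x2 tile. */
static void demo_gemv_generic(const float *a, const float *x, float *y) {
    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < 2; j++) {
            y[i] += a[i * 2 + j] * x[j];
        }
    }
}

int main(void) {
    const float a[4] = {1, 2, 3, 4};
    const float x[2] = {1, 1};
    float y[2] = {0, 0};
    demo_gemv(a, x, y); /* resolves to demo_gemv_generic here */
    printf("%g %g\n", y[0], y[1]); /* prints: 3 7 */
    return 0;
}
```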
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c

@@ -2022,6 +2022,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2325,6 +2330,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
            {
                n_tasks = n_threads;
            } break;
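These two hunks wire the new GGML_OP_OPT_STEP_SGD op into the CPU backend: the forward dispatcher routes it to ggml_compute_forward_opt_step_sgd, and ggml_get_n_tasks assigns it n_tasks = n_threads so the optimizer step is parallelized over rows, exactly like GGML_OP_OPT_STEP_ADAMW. A standalone sketch of the row partitioning such multithreaded ops use (the same dr/ir0/ir1 arithmetic appears in the SGD kernel in ops.cpp below):

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int nr  = 10; /* total rows */
    const int nth = 4;  /* worker threads (n_tasks == n_threads) */
    for (int ith = 0; ith < nth; ith++) {
        const int dr  = (nr + nth - 1) / nth; /* rows per thread, rounded up */
        const int ir0 = dr * ith;             /* first row for this thread */
        const int ir1 = MIN(ir0 + dr, nr);    /* one past the last row */
        printf("thread %d: rows [%d, %d)\n", ith, ir0, ir1);
    }
    return 0;
}
```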
package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp

@@ -35,7 +35,7 @@
 
 // ggml-backend interface
 
-std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffers_type() {
+std::vector<ggml_backend_buffer_type_t> & ggml_backend_cpu_get_extra_buffer_types() {
     static std::vector<ggml_backend_buffer_type_t> bufts = []() {
         std::vector<ggml_backend_buffer_type_t> bufts;
 
@@ -57,8 +57,6 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type()
         }
 #endif
 
-        bufts.push_back(NULL);
-
         return bufts;
     }();
 
@@ -66,14 +64,20 @@ std::vector<ggml_backend_buffer_type_t>& ggml_backend_cpu_get_extra_buffers_type()
 }
 
 static ggml_backend_buffer_type_t * ggml_backend_cpu_device_get_extra_buffers_type(ggml_backend_dev_t device) {
-    return ggml_backend_cpu_get_extra_buffers_type().data();
+    static std::vector<ggml_backend_buffer_type_t> extra_bufts = [] {
+        std::vector<ggml_backend_buffer_type_t> bufts = ggml_backend_cpu_get_extra_buffer_types();
+        bufts.push_back(nullptr);
+        return bufts;
+    }();
+
+    return extra_bufts.data();
 
     GGML_UNUSED(device);
 }
 
 static bool ggml_backend_cpu_is_extra_buffer_type(ggml_backend_buffer_type_t buft) {
-    for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
-        if (extra && extra == buft) {
+    for (auto * extra : ggml_backend_cpu_get_extra_buffer_types()) {
+        if (extra == buft) {
             return true;
         }
     }
@@ -210,10 +214,10 @@ ggml_backend_t ggml_backend_cpu_init(void) {
     ctx->abort_callback_data = NULL;
 
     ggml_backend_t cpu_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_cpu_guid(),
-        /* .interface = */ ggml_backend_cpu_i,
-        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
-        /* .context   = */ ctx,
+        /* .guid    = */ ggml_backend_cpu_guid(),
+        /* .iface   = */ ggml_backend_cpu_i,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context = */ ctx,
     };
 
     if (cpu_backend == NULL) {
@@ -397,20 +401,13 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
         return true;
     }
 
-    // extra_buffer_type checks
-    for (auto * extra : ggml_backend_cpu_get_extra_buffers_type()) {
-        if (extra) {
-            auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context;
-            if (buf_extra && buf_extra->supports_op(dev, op)) {
-                return true;
-            }
-        }
-    }
-
-    // the other case need host buffer.
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (op->src[i] && op->src[i]->buffer && !ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
-            return false;
+    // check extra buffer types
+    // note: only the first sources are checked for extra buffer types to reduce overhead, increase if necessary
+    for (int i = 0; i < 4; i++) {
+        if (op->src[i] && op->src[i]->buffer &&
+            ggml_backend_cpu_is_extra_buffer_type(op->src[i]->buffer->buft)) {
+            auto * buf_extra = (ggml::cpu::extra_buffer_type *) op->src[i]->buffer->buft->context;
+            return buf_extra->supports_op(dev, op);
         }
     }
 
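The refactor above renames the shared getter to ggml_backend_cpu_get_extra_buffer_types() and moves the NULL list terminator out of it: the getter now returns only real buffer types (so iteration and equality checks no longer have to skip a trailing NULL), while the device-level enumeration appends a nullptr sentinel for callers that walk the returned pointer array until NULL. A small sketch of that sentinel-terminated array convention, with illustrative item names:

```c
#include <stdio.h>

/* A NULL-terminated pointer array, mirroring what
 * ggml_backend_cpu_device_get_extra_buffers_type() hands to callers. */
static const char *extra_buffer_types[] = { "amx", "kleidiai", NULL };

int main(void) {
    /* consumers iterate until the sentinel, never past it */
    for (const char **p = extra_buffer_types; *p != NULL; p++) {
        printf("extra buffer type: %s\n", *p);
    }
    return 0;
}
```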
package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp

@@ -259,7 +259,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const int64_t m_start = 0;
 
         const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
-        const int64_t num_threads = KAI_MIN(n / n_step, nth);
+        int64_t num_threads = KAI_MIN(n / n_step, nth);
+        if (num_threads <= 0) {
+            num_threads = 1;
+        }
 
         if (ith < num_threads) {
             const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
@@ -309,7 +312,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         GGML_ASSERT(kernel);
 
         const int ith = params->ith;
-        const int nth = params->nth;
+        const int nth_raw = params->nth;
+        const int nth = nth_raw > 0 ? nth_raw : 1;
 
         const size_t k = ne00;
         const size_t m = ne11;
@@ -327,9 +331,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
         const size_t n_start = ith * num_n_per_thread;
 
-        size_t n_to_process = num_n_per_thread;
-        if ((n_start + n_to_process) > n) {
-            n_to_process = n - n_start;
+        size_t n_to_process = 0;
+        if (n_start < n) {
+            n_to_process = num_n_per_thread;
+            if ((n_start + n_to_process) > n) {
+                n_to_process = n - n_start;
+            }
         }
 
         // Calculate number of columns to be processed per thread
@@ -361,8 +368,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
         float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
 
-        variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
-                           sizeof(float), -FLT_MAX, FLT_MAX);
+        if (n_to_process > 0) {
+            variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                               sizeof(float), -FLT_MAX, FLT_MAX);
+        }
 
         return true;
     }
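The kleidiai changes are defensive guards around thread-count and tile-range arithmetic: when n < n_step, KAI_MIN(n / n_step, nth) yields zero workers (and a zero num_threads would later be used as a divisor); a zero params->nth is likewise clamped; threads whose column range falls past n now get n_to_process = 0 and skip the kernel call instead of computing an underflowed width. A sketch of the clamping, assuming only the KAI_MIN macro:

```c
#include <stdio.h>

#define KAI_MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const long n = 4, n_step = 8, nth = 6; /* fewer columns than one kernel step */
    long num_threads = KAI_MIN(n / n_step, nth); /* 4 / 8 == 0 -> zero workers */
    if (num_threads <= 0) {
        num_threads = 1; /* at least one worker, so n / num_threads is safe */
    }
    printf("num_threads = %ld\n", num_threads); /* prints: 1 */
    return 0;
}
```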
package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp

@@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha = adamw_params_ptr[0];
     const float beta1 = adamw_params_ptr[1];
     const float beta2 = adamw_params_ptr[2];
@@ -10337,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float   alpha          = sgd_params_ptr[0];
+    const float   keep           = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}
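The new SGD kernel applies plain gradient descent with decoupled weight decay, w ← w·(1 − α·wd) − α·g, reading α and wd from a two-element sgd_params tensor; the same hunk also hoists the AdamW decay factor keep = 1 − α·wd out of the inner loop. A self-contained sketch of one SGD step on a small weight vector, using the update rule from the diff:

```c
#include <stdio.h>

int main(void) {
    const float alpha = 0.1f;  /* learning rate, sgd_params[0] */
    const float wd    = 0.01f; /* weight decay,  sgd_params[1] */
    const float keep  = 1.f - alpha * wd; /* hoisted decay factor */

    float       w[3] = { 1.0f, -2.0f, 0.5f }; /* weights */
    const float g[3] = { 0.2f,  0.4f, -1.0f }; /* gradients */

    /* w <- w * (1 - alpha * wd) - alpha * g */
    for (int i = 0; i < 3; i++) {
        w[i] = w[i] * keep - alpha * g[i];
    }
    printf("%f %f %f\n", w[0], w[1], w[2]);
    return 0;
}
```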
package/src/llama.cpp/ggml/src/ggml-cpu/ops.h

@@ -107,7 +107,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif