@fugood/llama.node 0.0.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204) hide show
  1. package/CMakeLists.txt +85 -0
  2. package/README.md +56 -0
  3. package/bin/darwin/arm64/llama-node.node +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/win32/arm64/llama-node.node +0 -0
  8. package/bin/win32/arm64/node.lib +0 -0
  9. package/bin/win32/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/node.lib +0 -0
  11. package/lib/binding.js +13 -0
  12. package/lib/binding.ts +57 -0
  13. package/lib/index.js +24 -0
  14. package/lib/index.ts +13 -0
  15. package/package.json +65 -0
  16. package/src/addons.cpp +506 -0
  17. package/src/llama.cpp/CMakeLists.txt +1320 -0
  18. package/src/llama.cpp/build.zig +172 -0
  19. package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
  20. package/src/llama.cpp/common/CMakeLists.txt +87 -0
  21. package/src/llama.cpp/common/base64.hpp +392 -0
  22. package/src/llama.cpp/common/common.cpp +2949 -0
  23. package/src/llama.cpp/common/common.h +324 -0
  24. package/src/llama.cpp/common/console.cpp +501 -0
  25. package/src/llama.cpp/common/console.h +19 -0
  26. package/src/llama.cpp/common/grammar-parser.cpp +440 -0
  27. package/src/llama.cpp/common/grammar-parser.h +29 -0
  28. package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/json.hpp +24766 -0
  31. package/src/llama.cpp/common/log.h +724 -0
  32. package/src/llama.cpp/common/ngram-cache.cpp +282 -0
  33. package/src/llama.cpp/common/ngram-cache.h +94 -0
  34. package/src/llama.cpp/common/sampling.cpp +353 -0
  35. package/src/llama.cpp/common/sampling.h +147 -0
  36. package/src/llama.cpp/common/stb_image.h +8396 -0
  37. package/src/llama.cpp/common/train.cpp +1513 -0
  38. package/src/llama.cpp/common/train.h +233 -0
  39. package/src/llama.cpp/examples/CMakeLists.txt +52 -0
  40. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
  41. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
  42. package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
  43. package/src/llama.cpp/examples/batched/batched.cpp +262 -0
  44. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
  45. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
  46. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
  47. package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
  48. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
  49. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
  50. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
  51. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
  52. package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
  54. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
  55. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
  56. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
  58. package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
  59. package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
  60. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
  61. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
  62. package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
  64. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
  65. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
  66. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
  67. package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
  68. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
  69. package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
  70. package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
  71. package/src/llama.cpp/examples/infill/infill.cpp +767 -0
  72. package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
  73. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
  74. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
  75. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
  76. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
  77. package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
  78. package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
  79. package/src/llama.cpp/examples/llava/clip.h +85 -0
  80. package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
  81. package/src/llama.cpp/examples/llava/llava.cpp +426 -0
  82. package/src/llama.cpp/examples/llava/llava.h +50 -0
  83. package/src/llama.cpp/examples/llava/requirements.txt +3 -0
  84. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
  85. package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
  86. package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
  87. package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
  88. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
  89. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
  90. package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
  91. package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
  92. package/src/llama.cpp/examples/main/main.cpp +957 -0
  93. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
  94. package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
  95. package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
  96. package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
  97. package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
  98. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
  99. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
  100. package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
  101. package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
  102. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
  103. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
  104. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
  106. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
  107. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
  108. package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
  109. package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
  110. package/src/llama.cpp/examples/server/httplib.h +9465 -0
  111. package/src/llama.cpp/examples/server/server.cpp +3826 -0
  112. package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
  113. package/src/llama.cpp/examples/server/utils.hpp +653 -0
  114. package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
  115. package/src/llama.cpp/examples/simple/simple.cpp +183 -0
  116. package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
  117. package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
  118. package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
  119. package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
  120. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
  121. package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
  122. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
  123. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
  124. package/src/llama.cpp/ggml-alloc.c +985 -0
  125. package/src/llama.cpp/ggml-alloc.h +76 -0
  126. package/src/llama.cpp/ggml-backend-impl.h +141 -0
  127. package/src/llama.cpp/ggml-backend.c +2099 -0
  128. package/src/llama.cpp/ggml-backend.h +233 -0
  129. package/src/llama.cpp/ggml-common.h +1853 -0
  130. package/src/llama.cpp/ggml-cuda.h +43 -0
  131. package/src/llama.cpp/ggml-impl.h +265 -0
  132. package/src/llama.cpp/ggml-kompute.cpp +2006 -0
  133. package/src/llama.cpp/ggml-kompute.h +46 -0
  134. package/src/llama.cpp/ggml-metal.h +66 -0
  135. package/src/llama.cpp/ggml-mpi.c +216 -0
  136. package/src/llama.cpp/ggml-mpi.h +39 -0
  137. package/src/llama.cpp/ggml-opencl.cpp +2301 -0
  138. package/src/llama.cpp/ggml-opencl.h +36 -0
  139. package/src/llama.cpp/ggml-quants.c +12678 -0
  140. package/src/llama.cpp/ggml-quants.h +133 -0
  141. package/src/llama.cpp/ggml-sycl.cpp +17882 -0
  142. package/src/llama.cpp/ggml-sycl.h +49 -0
  143. package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
  144. package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
  145. package/src/llama.cpp/ggml-vulkan.h +29 -0
  146. package/src/llama.cpp/ggml.c +21819 -0
  147. package/src/llama.cpp/ggml.h +2403 -0
  148. package/src/llama.cpp/llama.cpp +17468 -0
  149. package/src/llama.cpp/llama.h +1117 -0
  150. package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
  151. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
  152. package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
  153. package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
  154. package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
  155. package/src/llama.cpp/prompts/alpaca.txt +1 -0
  156. package/src/llama.cpp/prompts/assistant.txt +31 -0
  157. package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
  158. package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
  159. package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
  160. package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
  161. package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
  162. package/src/llama.cpp/prompts/chat.txt +28 -0
  163. package/src/llama.cpp/prompts/dan-modified.txt +1 -0
  164. package/src/llama.cpp/prompts/dan.txt +1 -0
  165. package/src/llama.cpp/prompts/mnemonics.txt +93 -0
  166. package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
  167. package/src/llama.cpp/prompts/reason-act.txt +18 -0
  168. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
  169. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
  170. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
  171. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
  172. package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
  173. package/src/llama.cpp/requirements.txt +12 -0
  174. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
  175. package/src/llama.cpp/scripts/xxd.cmake +16 -0
  176. package/src/llama.cpp/sgemm.cpp +999 -0
  177. package/src/llama.cpp/sgemm.h +12 -0
  178. package/src/llama.cpp/tests/CMakeLists.txt +78 -0
  179. package/src/llama.cpp/tests/get-model.cpp +21 -0
  180. package/src/llama.cpp/tests/get-model.h +2 -0
  181. package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
  182. package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
  183. package/src/llama.cpp/tests/test-c.c +7 -0
  184. package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
  185. package/src/llama.cpp/tests/test-double-float.cpp +57 -0
  186. package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
  187. package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
  188. package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
  189. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
  190. package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
  191. package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
  192. package/src/llama.cpp/tests/test-opt.cpp +181 -0
  193. package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
  194. package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
  195. package/src/llama.cpp/tests/test-rope.cpp +221 -0
  196. package/src/llama.cpp/tests/test-sampling.cpp +301 -0
  197. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
  198. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
  199. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
  200. package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
  201. package/src/llama.cpp/unicode-data.cpp +1651 -0
  202. package/src/llama.cpp/unicode-data.h +16 -0
  203. package/src/llama.cpp/unicode.cpp +277 -0
  204. package/src/llama.cpp/unicode.h +28 -0
@@ -0,0 +1,1861 @@
1
+ #include "ggml.h"
2
+ #include "ggml-alloc.h"
3
+ #include "ggml-backend.h"
4
+ #include "llama.h"
5
+ #include "common.h"
6
+ #include "train.h"
7
+ #include <vector>
8
+ #include <cstring>
9
+ #include <ctime>
10
+ #include <algorithm>
11
+ #include <string>
12
+
13
+ #if defined(_MSC_VER)
14
+ #pragma warning(disable: 4244 4267) // possible loss of data
15
+ #endif
16
+
17
+ struct my_llama_hparams {
18
+ uint32_t n_vocab = 32000;
19
+ uint32_t n_ctx = 512;
20
+ uint32_t n_embd = 4096;
21
+ uint32_t n_ff = 11008;
22
+ uint32_t n_head = 32;
23
+ uint32_t n_head_kv = 32;
24
+ uint32_t n_layer = 32;
25
+
26
+ // float f_norm_eps = 1e-5f; // falcon
27
+ float f_norm_rms_eps = 1e-5f; // llama
28
+
29
+ float rope_freq_base = 10000.0f;
30
+ float rope_freq_scale = 1.0f;
31
+
32
+ uint32_t n_gqa() const {
33
+ return n_head/n_head_kv;
34
+ }
35
+
36
+ uint32_t n_embd_head() const {
37
+ return n_embd/n_head;
38
+ }
39
+
40
+ uint32_t n_embd_gqa() const {
41
+ return n_embd/n_gqa();
42
+ }
43
+
44
+ bool operator!=(const my_llama_hparams& other) const {
45
+ return memcmp(this, &other, sizeof(other));
46
+ }
47
+ };
48
+
49
+ struct my_llama_layer {
50
+ // normalization
51
+ struct ggml_tensor * attention_norm;
52
+
53
+ // attention
54
+ struct ggml_tensor * wq;
55
+ struct ggml_tensor * wk;
56
+ struct ggml_tensor * wv;
57
+ struct ggml_tensor * wo;
58
+
59
+ // normalization
60
+ struct ggml_tensor * ffn_norm;
61
+
62
+ // ff
63
+ struct ggml_tensor * ffn_gate; // w1
64
+ struct ggml_tensor * ffn_down; // w2
65
+ struct ggml_tensor * ffn_up; // w3
66
+ };
67
+
68
+ struct my_llama_model {
69
+ struct my_llama_hparams hparams;
70
+
71
+ struct ggml_tensor * tok_embeddings;
72
+
73
+ struct ggml_tensor * norm;
74
+ struct ggml_tensor * output;
75
+
76
+ std::vector<my_llama_layer> layers;
77
+ };
78
+
79
+ struct my_llama_lora_hparams {
80
+ uint32_t lora_r = 1;
81
+ uint32_t lora_alpha = 1;
82
+ uint32_t n_rank_attention_norm = 1;
83
+ uint32_t n_rank_wq = 4;
84
+ uint32_t n_rank_wk = 4;
85
+ uint32_t n_rank_wv = 4;
86
+ uint32_t n_rank_wo = 4;
87
+ uint32_t n_rank_ffn_norm = 1;
88
+ uint32_t n_rank_ffn_gate = 4;
89
+ uint32_t n_rank_ffn_down = 4;
90
+ uint32_t n_rank_ffn_up = 4;
91
+ uint32_t n_rank_tok_embeddings = 4;
92
+ uint32_t n_rank_norm = 1;
93
+ uint32_t n_rank_output = 4;
94
+
95
+ bool operator!=(const my_llama_lora_hparams& other) const {
96
+ return memcmp(this, &other, sizeof(other));
97
+ }
98
+ };
99
+
100
+ struct my_llama_lora_layer {
101
+ // normalization
102
+ struct ggml_tensor * attention_norm_a;
103
+ struct ggml_tensor * attention_norm_b;
104
+
105
+ // attention
106
+ struct ggml_tensor * wq_a;
107
+ struct ggml_tensor * wq_b;
108
+ struct ggml_tensor * wk_a;
109
+ struct ggml_tensor * wk_b;
110
+ struct ggml_tensor * wv_a;
111
+ struct ggml_tensor * wv_b;
112
+ struct ggml_tensor * wo_a;
113
+ struct ggml_tensor * wo_b;
114
+
115
+ // normalization
116
+ struct ggml_tensor * ffn_norm_a;
117
+ struct ggml_tensor * ffn_norm_b;
118
+
119
+ // ff
120
+ struct ggml_tensor * ffn_gate_a;
121
+ struct ggml_tensor * ffn_gate_b;
122
+ struct ggml_tensor * ffn_down_a;
123
+ struct ggml_tensor * ffn_down_b;
124
+ struct ggml_tensor * ffn_up_a;
125
+ struct ggml_tensor * ffn_up_b;
126
+ };
127
+
128
+ struct my_llama_lora {
129
+ struct ggml_context * ctx = NULL;
130
+ ggml_backend_buffer_t data;
131
+
132
+ my_llama_lora_hparams hparams;
133
+
134
+ struct ggml_tensor * tok_embeddings_a;
135
+ struct ggml_tensor * tok_embeddings_b;
136
+
137
+ struct ggml_tensor * norm_a;
138
+ struct ggml_tensor * norm_b;
139
+ struct ggml_tensor * output_a;
140
+ struct ggml_tensor * output_b;
141
+
142
+ std::vector<my_llama_lora_layer> layers;
143
+ };
144
+
145
+ // gguf constants
146
+ static const char * LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora";
147
+ static const char * LLM_KV_TRAINING_TYPE = "training.type";
148
+
149
+ static const char * LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD = "training.lora.rank.token_embd";
150
+ static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM = "training.lora.rank.output_norm";
151
+ static const char * LLM_KV_TRAINING_LORA_RANK_OUTPUT = "training.lora.rank.output";
152
+ static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_NORM = "training.lora.rank.attn_norm";
153
+ static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_Q = "training.lora.rank.attn_q";
154
+ static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_K = "training.lora.rank.attn_k";
155
+ static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_V = "training.lora.rank.attn_v";
156
+ static const char * LLM_KV_TRAINING_LORA_RANK_ATTN_OUT = "training.lora.rank.attn_output";
157
+ static const char * LLM_KV_TRAINING_LORA_RANK_FFN_NORM = "training.lora.rank.ffn_norm";
158
+ static const char * LLM_KV_TRAINING_LORA_RANK_FFN_GATE = "training.lora.rank.ffn_gate";
159
+ static const char * LLM_KV_TRAINING_LORA_RANK_FFN_DOWN = "training.lora.rank.ffn_down";
160
+ static const char * LLM_KV_TRAINING_LORA_RANK_FFN_UP = "training.lora.rank.ffn_up";
161
+
162
+ // gguf constants (sync with gguf.py)
163
+
164
+ static const char * LLM_KV_GENERAL_ARCHITECTURE = "general.architecture";
165
+ static const char * LLM_KV_GENERAL_FILE_TYPE = "general.file_type";
166
+
167
+ static const char * LLM_KV_CONTEXT_LENGTH = "%s.context_length";
168
+ static const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
169
+ static const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
170
+ static const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
171
+ static const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
172
+ static const char * LLM_KV_ATTENTION_HEAD_COUNT_KV = "%s.attention.head_count_kv";
173
+ static const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
174
+ static const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
175
+ static const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
176
+ static const char * LLM_KV_ROPE_SCALE_LINEAR = "%s.rope.scale_linear";
177
+
178
+ static const char * LLM_TENSOR_TOKEN_EMBD = "token_embd";
179
+ static const char * LLM_TENSOR_OUTPUT_NORM = "output_norm";
180
+ static const char * LLM_TENSOR_OUTPUT = "output";
181
+ static const char * LLM_TENSOR_ATTN_NORM = "blk.%d.attn_norm";
182
+ static const char * LLM_TENSOR_ATTN_Q = "blk.%d.attn_q";
183
+ static const char * LLM_TENSOR_ATTN_K = "blk.%d.attn_k";
184
+ static const char * LLM_TENSOR_ATTN_V = "blk.%d.attn_v";
185
+ static const char * LLM_TENSOR_ATTN_OUT = "blk.%d.attn_output";
186
+ static const char * LLM_TENSOR_FFN_NORM = "blk.%d.ffn_norm";
187
+ static const char * LLM_TENSOR_FFN_GATE = "blk.%d.ffn_gate";
188
+ static const char * LLM_TENSOR_FFN_DOWN = "blk.%d.ffn_down";
189
+ static const char * LLM_TENSOR_FFN_UP = "blk.%d.ffn_up";
190
+
191
+ static void print_params(struct my_llama_hparams * params) {
192
+ printf("%s: n_vocab : %u\n", __func__, params->n_vocab);
193
+ printf("%s: n_ctx : %u\n", __func__, params->n_ctx);
194
+ printf("%s: n_embd : %u\n", __func__, params->n_embd);
195
+ printf("%s: n_ff : %u\n", __func__, params->n_ff);
196
+ printf("%s: n_head : %u\n", __func__, params->n_head);
197
+ printf("%s: n_head_kv : %u\n", __func__, params->n_head_kv);
198
+ printf("%s: n_layer : %u\n", __func__, params->n_layer);
199
+ printf("%s: norm_rms_eps : %f\n", __func__, params->f_norm_rms_eps);
200
+ printf("%s: rope_freq_base : %f\n", __func__, params->rope_freq_base);
201
+ printf("%s: rope_freq_scale : %f\n", __func__, params->rope_freq_scale);
202
+ }
203
+
204
+ static void print_lora_params(struct my_llama_lora_hparams * params) {
205
+ printf("%s: n_rank_attention_norm : %u\n", __func__, params->n_rank_attention_norm);
206
+ printf("%s: n_rank_wq : %u\n", __func__, params->n_rank_wq);
207
+ printf("%s: n_rank_wk : %u\n", __func__, params->n_rank_wk);
208
+ printf("%s: n_rank_wv : %u\n", __func__, params->n_rank_wv);
209
+ printf("%s: n_rank_wo : %u\n", __func__, params->n_rank_wo);
210
+ printf("%s: n_rank_ffn_norm : %u\n", __func__, params->n_rank_ffn_norm);
211
+ printf("%s: n_rank_ffn_gate : %u\n", __func__, params->n_rank_ffn_gate);
212
+ printf("%s: n_rank_ffn_down : %u\n", __func__, params->n_rank_ffn_down);
213
+ printf("%s: n_rank_ffn_up : %u\n", __func__, params->n_rank_ffn_up);
214
+ printf("%s: n_rank_tok_embeddings : %u\n", __func__, params->n_rank_tok_embeddings);
215
+ printf("%s: n_rank_norm : %u\n", __func__, params->n_rank_norm);
216
+ printf("%s: n_rank_output : %u\n", __func__, params->n_rank_output);
217
+ }
218
+
219
+ #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \
220
+ { \
221
+ const std::string skey(key); \
222
+ const int kid = gguf_find_key(ctx, skey.c_str()); \
223
+ if (kid >= 0) { \
224
+ enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
225
+ if (ktype != (type)) { \
226
+ die_fmt("key %s has wrong type: %s", skey.c_str(), gguf_type_name(ktype)); \
227
+ } \
228
+ (dst) = func(ctx, kid); \
229
+ } else if (req) { \
230
+ die_fmt("key not found in model: %s", skey.c_str()); \
231
+ } \
232
+ }
233
+
234
+ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_hparams * hparams, const char * expected_arch) {
235
+ std::string arch;
236
+
237
+ GGUF_GET_KEY(ctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
238
+ if (expected_arch != NULL) {
239
+ if (arch != expected_arch) {
240
+ printf("%s: arch=%s expected_arch=%s\n", __func__, arch.c_str(), expected_arch);
241
+ }
242
+ GGML_ASSERT(arch == expected_arch);
243
+ }
244
+
245
+ std::vector<char> keybuf;
246
+ keybuf.resize(512);
247
+ auto kv = [&arch, &keybuf](const char * key) -> const char * {
248
+ snprintf(keybuf.data(), keybuf.size(), key, arch.c_str());
249
+ return keybuf.data();
250
+ };
251
+
252
+ GGUF_GET_KEY(ctx, hparams->n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
253
+ GGUF_GET_KEY(ctx, hparams->n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
254
+ GGUF_GET_KEY(ctx, hparams->n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
255
+ GGUF_GET_KEY(ctx, hparams->n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
256
+ GGUF_GET_KEY(ctx, hparams->n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
257
+
258
+ // n_head_kv is optional, default to n_head
259
+ hparams->n_head_kv = hparams->n_head;
260
+ GGUF_GET_KEY(ctx, hparams->n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
261
+
262
+ float rope_freq_scale = 1.0f;
263
+ GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
264
+ GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
265
+ GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
266
+ if (rope_freq_scale != 1.0f) {
267
+ hparams->rope_freq_scale = 1.0f / rope_freq_scale;
268
+ }
269
+ }
270
+
271
+ static void init_model(struct llama_model * input, struct my_llama_model * model, const char * fn_model, uint32_t n_ctx) {
272
+ auto & hparams = model->hparams;
273
+
274
+ std::vector<char> tn_buf;
275
+ tn_buf.resize(GGML_MAX_NAME);
276
+ auto tn = [&tn_buf](const char * key) -> const char * {
277
+ snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", key);
278
+ return tn_buf.data();
279
+ };
280
+ auto tni = [&tn_buf](const char * key, int bid) -> const char * {
281
+ snprintf(tn_buf.data(), tn_buf.size(), key, bid);
282
+ std::string s = tn_buf.data();
283
+ snprintf(tn_buf.data(), tn_buf.size(), "%s.weight", s.c_str());
284
+ return tn_buf.data();
285
+ };
286
+
287
+
288
+ // get parameters directly from gguf file
289
+ {
290
+ struct gguf_init_params params = {
291
+ /*.no_alloc = */ false,
292
+ /*.ctx = */ NULL,
293
+ };
294
+ struct gguf_context * mctx = gguf_init_from_file(fn_model, params);
295
+
296
+ load_model_hparams_gguf(mctx, &hparams, "llama");
297
+
298
+ gguf_free(mctx);
299
+ }
300
+ hparams.n_vocab = llama_n_vocab(input);
301
+ hparams.n_ctx = n_ctx;
302
+
303
+ // get tensors from llama_model (possibly mmapped)
304
+ model->tok_embeddings = llama_get_model_tensor(input, tn(LLM_TENSOR_TOKEN_EMBD));
305
+ model->norm = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT_NORM));
306
+ model->output = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT));
307
+
308
+ assert_shape_2d(model->tok_embeddings, hparams.n_embd, hparams.n_vocab);
309
+ assert_shape_1d(model->norm, hparams.n_embd);
310
+ assert_shape_2d(model->output, hparams.n_embd, hparams.n_vocab);
311
+
312
+ model->layers.resize(hparams.n_layer);
313
+ for (uint32_t i = 0; i < hparams.n_layer; ++i) {
314
+ auto & layer = model->layers[i];
315
+
316
+ layer.attention_norm = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_NORM, i));
317
+ layer.wq = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_Q, i));
318
+ layer.wk = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_K, i));
319
+ layer.wv = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_V, i));
320
+ layer.wo = llama_get_model_tensor(input, tni(LLM_TENSOR_ATTN_OUT, i));
321
+ layer.ffn_norm = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_NORM, i));
322
+ layer.ffn_gate = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_GATE, i));
323
+ layer.ffn_down = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_DOWN, i));
324
+ layer.ffn_up = llama_get_model_tensor(input, tni(LLM_TENSOR_FFN_UP, i));
325
+
326
+ assert_shape_1d(layer.attention_norm, hparams.n_embd);
327
+ assert_shape_2d(layer.wq, hparams.n_embd, hparams.n_embd);
328
+ assert_shape_2d(layer.wk, hparams.n_embd, hparams.n_embd_gqa());
329
+ assert_shape_2d(layer.wv, hparams.n_embd, hparams.n_embd_gqa());
330
+ assert_shape_2d(layer.wo, hparams.n_embd, hparams.n_embd);
331
+ assert_shape_1d(layer.ffn_norm, hparams.n_embd);
332
+ assert_shape_2d(layer.ffn_gate, hparams.n_embd, hparams.n_ff);
333
+ assert_shape_2d(layer.ffn_down, hparams.n_ff, hparams.n_embd);
334
+ assert_shape_2d(layer.ffn_up, hparams.n_embd, hparams.n_ff);
335
+ }
336
+ }
337
+
338
+ static void set_param_lora(struct my_llama_lora * lora) {
339
+ const uint32_t n_layer = lora->layers.size();
340
+
341
+ struct ggml_context* ctx = lora->ctx;
342
+
343
+ ggml_set_param(ctx, lora->tok_embeddings_a);
344
+ ggml_set_param(ctx, lora->tok_embeddings_b);
345
+ ggml_set_param(ctx, lora->norm_a);
346
+ ggml_set_param(ctx, lora->norm_b);
347
+ ggml_set_param(ctx, lora->output_a);
348
+ ggml_set_param(ctx, lora->output_b);
349
+
350
+ for (uint32_t i = 0; i < n_layer; ++i) {
351
+ auto & layer = lora->layers[i];
352
+
353
+ ggml_set_param(ctx, layer.attention_norm_a);
354
+ ggml_set_param(ctx, layer.attention_norm_b);
355
+ ggml_set_param(ctx, layer.wq_a);
356
+ ggml_set_param(ctx, layer.wq_b);
357
+ ggml_set_param(ctx, layer.wk_a);
358
+ ggml_set_param(ctx, layer.wk_b);
359
+ ggml_set_param(ctx, layer.wv_a);
360
+ ggml_set_param(ctx, layer.wv_b);
361
+ ggml_set_param(ctx, layer.wo_a);
362
+ ggml_set_param(ctx, layer.wo_b);
363
+ ggml_set_param(ctx, layer.ffn_norm_a);
364
+ ggml_set_param(ctx, layer.ffn_norm_b);
365
+ ggml_set_param(ctx, layer.ffn_gate_a);
366
+ ggml_set_param(ctx, layer.ffn_gate_b);
367
+ ggml_set_param(ctx, layer.ffn_down_a);
368
+ ggml_set_param(ctx, layer.ffn_down_b);
369
+ ggml_set_param(ctx, layer.ffn_up_a);
370
+ ggml_set_param(ctx, layer.ffn_up_b);
371
+ }
372
+ }
373
+
374
+ static void init_lora(const struct my_llama_model * model, struct my_llama_lora * lora) {
375
+ const auto & lparams = lora->hparams;
376
+
377
+ const uint32_t n_embd = model->hparams.n_embd;
378
+ const uint32_t n_embd_gqa = model->hparams.n_embd_gqa();
379
+ const uint32_t n_layer = model->hparams.n_layer;
380
+ const uint32_t n_vocab = model->hparams.n_vocab;
381
+ const uint32_t n_ff = model->hparams.n_ff;
382
+
383
+ std::vector<char> tn_buf;
384
+ tn_buf.resize(GGML_MAX_NAME);
385
+ auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * {
386
+ snprintf(tn_buf.data(), tn_buf.size(), "%s%s", key, suffix);
387
+ return tn_buf.data();
388
+ };
389
+ auto tni = [&tn_buf](const char * key, const char * suffix, int bid) -> const char * {
390
+ snprintf(tn_buf.data(), tn_buf.size(), key, bid);
391
+ std::string s = tn_buf.data();
392
+ snprintf(tn_buf.data(), tn_buf.size(), "%s%s", s.c_str(), suffix);
393
+ return tn_buf.data();
394
+ };
395
+
396
+ // context for lora tensors without their data
397
+ struct ggml_init_params ctx_lora_params;
398
+ ctx_lora_params.mem_size = ggml_tensor_overhead()*2*(6 + n_layer*18);
399
+ ctx_lora_params.mem_buffer = NULL;
400
+ ctx_lora_params.no_alloc = true;
401
+
402
+ struct ggml_context * ctx = ggml_init(ctx_lora_params);
403
+ lora->ctx = ctx;
404
+
405
+ lora->tok_embeddings_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_tok_embeddings, n_embd);
406
+ lora->tok_embeddings_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_tok_embeddings, n_vocab);
407
+ lora->norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_norm, n_embd);
408
+ lora->norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_norm, 1);
409
+ lora->output_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_output, n_embd);
410
+ lora->output_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_output, n_vocab);
411
+
412
+ ggml_set_name(lora->tok_embeddings_a, tn(LLM_TENSOR_TOKEN_EMBD, ".weight.lora_a"));
413
+ ggml_set_name(lora->tok_embeddings_b, tn(LLM_TENSOR_TOKEN_EMBD, ".weight.lora_b"));
414
+ ggml_set_name(lora->norm_a, tn(LLM_TENSOR_OUTPUT_NORM, ".weight.lora_a"));
415
+ ggml_set_name(lora->norm_b, tn(LLM_TENSOR_OUTPUT_NORM, ".weight.lora_b"));
416
+ ggml_set_name(lora->output_a, tn(LLM_TENSOR_OUTPUT, ".weight.lora_a"));
417
+ ggml_set_name(lora->output_b, tn(LLM_TENSOR_OUTPUT, ".weight.lora_b"));
418
+
419
+ lora->layers.resize(n_layer);
420
+ for (uint32_t i = 0; i < n_layer; ++i) {
421
+ auto & layer = lora->layers[i];
422
+
423
+ layer.attention_norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_attention_norm, n_embd);
424
+ layer.attention_norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_attention_norm, 1);
425
+
426
+ layer.wq_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wq, n_embd);
427
+ layer.wq_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wq, n_embd);
428
+ layer.wk_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wk, n_embd);
429
+ layer.wk_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wk, n_embd_gqa);
430
+ layer.wv_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wv, n_embd);
431
+ layer.wv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wv, n_embd_gqa);
432
+ layer.wo_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wo, n_embd);
433
+ layer.wo_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_wo, n_embd);
434
+
435
+ layer.ffn_norm_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_norm, n_embd);
436
+ layer.ffn_norm_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_norm, 1);
437
+
438
+ layer.ffn_gate_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_gate, n_embd);
439
+ layer.ffn_gate_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_gate, n_ff);
440
+ layer.ffn_down_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_down, n_ff);
441
+ layer.ffn_down_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_down, n_embd);
442
+ layer.ffn_up_a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_up, n_embd);
443
+ layer.ffn_up_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, lparams.n_rank_ffn_up, n_ff);
444
+
445
+ ggml_set_name(layer.attention_norm_a, tni(LLM_TENSOR_ATTN_NORM, ".weight.lora_a", i));
446
+ ggml_set_name(layer.attention_norm_b, tni(LLM_TENSOR_ATTN_NORM, ".weight.lora_b", i));
447
+ ggml_set_name(layer.wq_a, tni(LLM_TENSOR_ATTN_Q, ".weight.lora_a", i));
448
+ ggml_set_name(layer.wq_b, tni(LLM_TENSOR_ATTN_Q, ".weight.lora_b", i));
449
+ ggml_set_name(layer.wk_a, tni(LLM_TENSOR_ATTN_K, ".weight.lora_a", i));
450
+ ggml_set_name(layer.wk_b, tni(LLM_TENSOR_ATTN_K, ".weight.lora_b", i));
451
+ ggml_set_name(layer.wv_a, tni(LLM_TENSOR_ATTN_V, ".weight.lora_a", i));
452
+ ggml_set_name(layer.wv_b, tni(LLM_TENSOR_ATTN_V, ".weight.lora_b", i));
453
+ ggml_set_name(layer.wo_a, tni(LLM_TENSOR_ATTN_OUT, ".weight.lora_a", i));
454
+ ggml_set_name(layer.wo_b, tni(LLM_TENSOR_ATTN_OUT, ".weight.lora_b", i));
455
+ ggml_set_name(layer.ffn_norm_a, tni(LLM_TENSOR_FFN_NORM, ".weight.lora_a", i));
456
+ ggml_set_name(layer.ffn_norm_b, tni(LLM_TENSOR_FFN_NORM, ".weight.lora_b", i));
457
+ ggml_set_name(layer.ffn_gate_a, tni(LLM_TENSOR_FFN_GATE, ".weight.lora_a", i));
458
+ ggml_set_name(layer.ffn_gate_b, tni(LLM_TENSOR_FFN_GATE, ".weight.lora_b", i));
459
+ ggml_set_name(layer.ffn_down_a, tni(LLM_TENSOR_FFN_DOWN, ".weight.lora_a", i));
460
+ ggml_set_name(layer.ffn_down_b, tni(LLM_TENSOR_FFN_DOWN, ".weight.lora_b", i));
461
+ ggml_set_name(layer.ffn_up_a, tni(LLM_TENSOR_FFN_UP, ".weight.lora_a", i));
462
+ ggml_set_name(layer.ffn_up_b, tni(LLM_TENSOR_FFN_UP, ".weight.lora_b", i));
463
+ }
464
+
465
+ set_param_lora(lora);
466
+
467
+ // allocate data for lora tensors
468
+ lora->data = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_cpu_buffer_type());
469
+ }
470
+
471
+ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, float std, float min, float max) {
472
+ const uint32_t n_layer = lora->layers.size();
473
+
474
+ struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
475
+
476
+ randomize_tensor_normal(lora->tok_embeddings_a, rnd);
477
+ ggml_set_zero(lora->tok_embeddings_b);
478
+ randomize_tensor_normal(lora->norm_a, rnd);
479
+ ggml_set_zero(lora->norm_b);
480
+ randomize_tensor_normal(lora->output_a, rnd);
481
+ ggml_set_zero(lora->output_b);
482
+
483
+ for (uint32_t i = 0; i < n_layer; ++i) {
484
+ auto & layer = lora->layers[i];
485
+ randomize_tensor_normal(layer.attention_norm_a, rnd);
486
+ ggml_set_zero(layer.attention_norm_b);
487
+
488
+ randomize_tensor_normal(layer.wq_a, rnd);
489
+ ggml_set_zero(layer.wq_b);
490
+ randomize_tensor_normal(layer.wk_a, rnd);
491
+ ggml_set_zero(layer.wk_b);
492
+ randomize_tensor_normal(layer.wv_a, rnd);
493
+ ggml_set_zero(layer.wv_b);
494
+ randomize_tensor_normal(layer.wo_a, rnd);
495
+ ggml_set_zero(layer.wo_b);
496
+
497
+ randomize_tensor_normal(layer.ffn_norm_a, rnd);
498
+ ggml_set_zero(layer.ffn_norm_b);
499
+
500
+ randomize_tensor_normal(layer.ffn_gate_a, rnd);
501
+ ggml_set_zero(layer.ffn_gate_b);
502
+ randomize_tensor_normal(layer.ffn_down_a, rnd);
503
+ ggml_set_zero(layer.ffn_down_b);
504
+ randomize_tensor_normal(layer.ffn_up_a, rnd);
505
+ ggml_set_zero(layer.ffn_up_b);
506
+ }
507
+
508
+ free_random_normal_distribution(rnd);
509
+ }
510
+
511
+ static struct ggml_tensor * llama_build_lora_finetune_graphs(
512
+ struct my_llama_model * model,
513
+ struct my_llama_lora * lora,
514
+ ggml_gallocr_t alloc,
515
+ struct ggml_context * ctx,
516
+ struct ggml_cgraph * gf,
517
+ struct ggml_cgraph * gb,
518
+ struct ggml_cgraph * gb_tmp,
519
+ struct ggml_tensor * * logits,
520
+ struct ggml_tensor * tokens_input,
521
+ struct ggml_tensor * targets,
522
+ const int n_tokens,
523
+ const int n_batch,
524
+ const bool enable_flash_attn,
525
+ const bool enable_checkpointing,
526
+ const bool measure_only) {
527
+
528
+ ggml_set_scratch(ctx, { 0, 0, nullptr, });
529
+ const int n_past = 0;
530
+ const int N = n_tokens;
531
+ const auto & hparams = model->hparams;
532
+ const int n_ctx = hparams.n_ctx;
533
+ const int n_vocab = hparams.n_vocab;
534
+ const int n_embd = hparams.n_embd;
535
+ const int n_layer = hparams.n_layer;
536
+ const int n_head = hparams.n_head;
537
+ const int n_head_kv = hparams.n_head_kv;
538
+ const int n_ff = hparams.n_ff;
539
+ const int n_rot = hparams.n_embd_head();
540
+ const int n_embd_head = hparams.n_embd_head();
541
+ const int n_embd_gqa = hparams.n_embd_gqa();
542
+
543
+ const float rms_norm_eps = hparams.f_norm_rms_eps;
544
+ const float rope_freq_base = hparams.rope_freq_base;
545
+ const float rope_freq_scale = hparams.rope_freq_scale;
546
+
547
+ GGML_ASSERT((size_t) n_layer == lora->layers.size());
548
+
549
+ auto set_name = [](struct ggml_tensor * t, const char * n) {
550
+ ggml_set_name(t, n);
551
+ if (t->grad) {
552
+ ggml_format_name(t->grad, "%s->grad", n);
553
+ }
554
+ };
555
+
556
+ // KQ_pos - contains the positions
557
+ struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
558
+ ggml_set_input(KQ_pos);
559
+
560
+ // rope has so much parameters that we make a custom function for it
561
+ auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
562
+ (struct ggml_tensor * t) -> struct ggml_tensor * {
563
+ // not capturing these, to silcence warnings
564
+ const int rope_mode = 0;
565
+
566
+ return ggml_rope_custom(ctx,
567
+ t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
568
+ rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
569
+ );
570
+ };
571
+
572
+ set_name(tokens_input, "tokens_input");
573
+ set_name(targets, "targets");
574
+
575
+ GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
576
+
577
+ auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
578
+ if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
579
+ return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
580
+ } else if (a->type == GGML_TYPE_F32) {
581
+ return ggml_add(ctx, a, b);
582
+ } else {
583
+ die_fmt("%s: Finetuning on tensors with type '%s' is not yet supported.\n",
584
+ __func__, ggml_type_name(a->type));
585
+ }
586
+ };
587
+
588
+ struct ggml_tensor * tok_embeddings = add_to_f32(ctx, model->tok_embeddings, ggml_mul_mat(ctx, lora->tok_embeddings_a, lora->tok_embeddings_b));
589
+ struct ggml_tensor * norm = add_to_f32(ctx, model->norm, ggml_mul_mat(ctx, lora->norm_a, lora->norm_b));
590
+ struct ggml_tensor * output = add_to_f32(ctx, model->output, ggml_mul_mat(ctx, lora->output_a, lora->output_b));
591
+
592
+ struct ggml_tensor * t00 = ggml_reshape_1d(ctx, tokens_input, N*n_batch); set_name(t00, "t00"); assert_shape_1d(t00, N*n_batch);
593
+ struct ggml_tensor * t01 = ggml_get_rows(ctx, tok_embeddings, t00); set_name(t01, "t01"); assert_shape_2d(t01, n_embd, N*n_batch);
594
+
595
+ struct ggml_tensor * cur = t01;
596
+
597
+ std::vector<struct ggml_tensor *> checkpoints;
598
+ if (enable_checkpointing) {
599
+ checkpoints.push_back(tokens_input);
600
+ checkpoints.push_back(targets);
601
+ checkpoints.push_back(t00);
602
+ checkpoints.push_back(t01);
603
+ }
604
+
605
+ const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
606
+
607
+ for (int il = 0; il < n_layer; ++il) {
608
+ struct my_llama_layer & layer = model->layers[il];
609
+ struct my_llama_lora_layer & llayer = lora->layers[il];
610
+
611
+ struct ggml_tensor * attention_norm = add_to_f32(ctx, layer.attention_norm, ggml_mul_mat(ctx, llayer.attention_norm_a, llayer.attention_norm_b));
612
+ struct ggml_tensor * ffn_norm = add_to_f32(ctx, layer.ffn_norm, ggml_mul_mat(ctx, llayer.ffn_norm_a, llayer.ffn_norm_b));
613
+ struct ggml_tensor * wq = add_to_f32(ctx, layer.wq, ggml_mul_mat(ctx, llayer.wq_a, llayer.wq_b));
614
+ struct ggml_tensor * wk = add_to_f32(ctx, layer.wk, ggml_mul_mat(ctx, llayer.wk_a, llayer.wk_b));
615
+ struct ggml_tensor * wv = add_to_f32(ctx, layer.wv, ggml_mul_mat(ctx, llayer.wv_a, llayer.wv_b));
616
+ struct ggml_tensor * wo = add_to_f32(ctx, layer.wo, ggml_mul_mat(ctx, llayer.wo_a, llayer.wo_b));
617
+ struct ggml_tensor * ffn_gate = add_to_f32(ctx, layer.ffn_gate, ggml_mul_mat(ctx, llayer.ffn_gate_a, llayer.ffn_gate_b));
618
+ struct ggml_tensor * ffn_down = add_to_f32(ctx, layer.ffn_down, ggml_mul_mat(ctx, llayer.ffn_down_a, llayer.ffn_down_b));
619
+ struct ggml_tensor * ffn_up = add_to_f32(ctx, layer.ffn_up, ggml_mul_mat(ctx, llayer.ffn_up_a, llayer.ffn_up_b));
620
+
621
+ struct ggml_tensor * t02 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t02, "t02"); assert_shape_2d(t02, n_embd, N*n_batch);
622
+ struct ggml_tensor * t03 = ggml_repeat (ctx, attention_norm, t02); set_name(t03, "t03"); assert_shape_2d(t03, n_embd, N*n_batch);
623
+ struct ggml_tensor * t04 = ggml_mul (ctx, t03, t02); set_name(t04, "t04"); assert_shape_2d(t04, n_embd, N*n_batch);
624
+ struct ggml_tensor * t05 = ggml_mul_mat (ctx, wq, t04); set_name(t05, "t05"); assert_shape_2d(t05, n_embd, N*n_batch);
625
+ struct ggml_tensor * t06 = ggml_reshape_4d (ctx, t05, n_embd_head, n_head, N, n_batch); set_name(t06, "t06"); assert_shape_4d(t06, n_embd_head, n_head, N, n_batch);
626
+ struct ggml_tensor * t07 = rope (t06); set_name(t07, "t07"); assert_shape_4d(t07, n_embd_head, n_head, N, n_batch);
627
+ struct ggml_tensor * t08 = ggml_mul_mat (ctx, wk, t04); set_name(t08, "t08"); assert_shape_2d(t08, n_embd_gqa, N*n_batch);
628
+ struct ggml_tensor * t09 = ggml_reshape_4d (ctx, t08, n_embd_head, n_head_kv, N, n_batch); set_name(t09, "t09"); assert_shape_4d(t09, n_embd_head, n_head_kv, N, n_batch);
629
+ struct ggml_tensor * t10 = rope (t09); set_name(t10, "t10"); assert_shape_4d(t10, n_embd_head, n_head_kv, N, n_batch);
630
+
631
+ struct ggml_tensor * t11;
632
+ if (ggml_is_quantized(wv->type)) {
633
+ struct ggml_tensor * t11_1 = ggml_mul_mat (ctx, wv, t04); set_name(t11_1, "t11_1"); assert_shape_2d(t11_1, n_embd_gqa, N*n_batch);
634
+ struct ggml_tensor * t11_2 = ggml_transpose(ctx, t11_1); set_name(t11_2, "t11_2"); assert_shape_2d(t11_2, N*n_batch, n_embd_gqa);
635
+ t11 = ggml_cont (ctx, t11_2); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd_gqa);
636
+ } else {
637
+ t11 = ggml_mul_mat (ctx, t04, wv); set_name(t11, "t11"); assert_shape_2d(t11, N*n_batch, n_embd_gqa);
638
+ }
639
+
640
+ struct ggml_tensor * t12 = ggml_reshape_4d (ctx, t11, N, n_batch, n_embd_head, n_head_kv); set_name(t12, "t12"); assert_shape_4d(t12, N, n_batch, n_embd_head, n_head_kv);
641
+ struct ggml_tensor * t13 = ggml_permute (ctx, t07, 0, 2, 1, 3); set_name(t13, "t13"); assert_shape_4d(t13, n_embd_head, N, n_head, n_batch);
642
+ struct ggml_tensor * t14 = ggml_permute (ctx, t10, 0, 2, 1, 3); set_name(t14, "t14"); assert_shape_4d(t14, n_embd_head, N, n_head_kv, n_batch);
643
+ struct ggml_tensor * t15 = ggml_permute (ctx, t12, 0, 3, 1, 2); set_name(t15, "t15"); assert_shape_4d(t15, N, n_embd_head, n_head_kv, n_batch);
644
+ struct ggml_tensor * t16;
645
+ if (enable_flash_attn) {
646
+ t16 = ggml_flash_attn(ctx, t13, t14, t15, true); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
647
+ } else {
648
+ struct ggml_tensor * t16_0 = ggml_mul_mat (ctx, t14, t13); set_name(t16_0, "t16_0"); assert_shape_4d(t16_0, N, N, n_head, n_batch);
649
+ struct ggml_tensor * t16_1 = ggml_scale_inplace (ctx, t16_0, kv_scale); set_name(t16_1, "t16_1"); assert_shape_4d(t16_1, N, N, n_head, n_batch);
650
+ struct ggml_tensor * t16_2 = ggml_diag_mask_inf_inplace(ctx, t16_1, n_past); set_name(t16_2, "t16_2"); assert_shape_4d(t16_2, N, N, n_head, n_batch);
651
+ struct ggml_tensor * t16_3 = ggml_soft_max_inplace (ctx, t16_2); set_name(t16_3, "t16_3"); assert_shape_4d(t16_3, N, N, n_head, n_batch);
652
+ t16 = ggml_mul_mat(ctx, t15, t16_3); set_name(t16, "t16"); assert_shape_4d(t16, n_embd_head, N, n_head, n_batch);
653
+ }
654
+ struct ggml_tensor * t17 = ggml_permute (ctx, t16, 0, 2, 1, 3); set_name(t17, "t17"); assert_shape_4d(t17, n_embd_head, n_head, N, n_batch);
655
+ struct ggml_tensor * t18 = ggml_cont (ctx, t17); set_name(t18, "t18"); assert_shape_4d(t18, n_embd_head, n_head, N, n_batch);
656
+ struct ggml_tensor * t19 = ggml_reshape_2d (ctx, t18, n_embd, N*n_batch); set_name(t19, "t19"); assert_shape_2d(t19, n_embd, N*n_batch);
657
+ struct ggml_tensor * t20 = ggml_mul_mat (ctx, wo, t19); set_name(t20, "t20"); assert_shape_2d(t20, n_embd, N*n_batch);
658
+ struct ggml_tensor * t21 = ggml_add (ctx, t20, cur); set_name(t21, "t21"); assert_shape_2d(t21, n_embd, N*n_batch);
659
+ struct ggml_tensor * t22 = ggml_rms_norm (ctx, t21, rms_norm_eps); set_name(t22, "t22"); assert_shape_2d(t22, n_embd, N*n_batch);
660
+ struct ggml_tensor * t23 = ggml_repeat (ctx, ffn_norm, t22); set_name(t23, "t23"); assert_shape_2d(t23, n_embd, N*n_batch);
661
+ struct ggml_tensor * t24 = ggml_mul (ctx, t23, t22); set_name(t24, "t24"); assert_shape_2d(t24, n_embd, N*n_batch);
662
+ struct ggml_tensor * t25 = ggml_mul_mat (ctx, ffn_up, t24); set_name(t25, "t25"); assert_shape_2d(t25, n_ff, N*n_batch);
663
+ struct ggml_tensor * t26 = ggml_mul_mat (ctx, ffn_gate, t24); set_name(t26, "t26"); assert_shape_2d(t26, n_ff, N*n_batch);
664
+ struct ggml_tensor * t27 = ggml_silu (ctx, t26); set_name(t27, "t27"); assert_shape_2d(t27, n_ff, N*n_batch);
665
+ struct ggml_tensor * t28 = ggml_mul (ctx, t27, t25); set_name(t28, "t28"); assert_shape_2d(t28, n_ff, N*n_batch);
666
+ struct ggml_tensor * t29 = ggml_mul_mat (ctx, ffn_down, t28); set_name(t29, "t29"); assert_shape_2d(t29, n_embd, N*n_batch);
667
+ struct ggml_tensor * t30 = ggml_add (ctx, t29, t21); set_name(t30, "t30"); assert_shape_2d(t30, n_embd, N*n_batch);
668
+ cur = t30;
669
+ if (enable_checkpointing) {
670
+ checkpoints.push_back(cur);
671
+ }
672
+ }
673
+ struct ggml_tensor * t31 = ggml_rms_norm (ctx, cur, rms_norm_eps); set_name(t31, "t31"); assert_shape_2d(t31, n_embd, N*n_batch);
674
+ struct ggml_tensor * t32 = ggml_repeat (ctx, norm, t31); set_name(t32, "t32"); assert_shape_2d(t32, n_embd, N*n_batch);
675
+ struct ggml_tensor * t33 = ggml_mul (ctx, t32, t31); set_name(t33, "t33"); assert_shape_2d(t33, n_embd, N*n_batch);
676
+ struct ggml_tensor * t34 = ggml_mul_mat (ctx, output, t33); set_name(t34, "t34"); assert_shape_2d(t34, n_vocab, N*n_batch);
677
+ struct ggml_tensor * t35 = ggml_reshape_3d (ctx, t34, n_vocab, N, n_batch); set_name(t35, "t35"); assert_shape_3d(t35, n_vocab, N, n_batch);
678
+ struct ggml_tensor * t36 = ggml_cross_entropy_loss(ctx, t35, targets); set_name(t36, "t36"); assert_shape_1d(t36, 1);
679
+
680
+ if (enable_checkpointing) {
681
+ checkpoints.push_back(t31);
682
+ checkpoints.push_back(t32);
683
+ checkpoints.push_back(t33);
684
+ checkpoints.push_back(t34);
685
+ checkpoints.push_back(t35);
686
+ checkpoints.push_back(t36);
687
+ }
688
+
689
+ ggml_build_forward_expand(gf, t36);
690
+
691
+ if (enable_checkpointing) {
692
+ ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
693
+ } else {
694
+ ggml_graph_cpy(gf, gb);
695
+ ggml_build_backward_expand(ctx, gf, gb, true);
696
+ }
697
+
698
+ GGML_ASSERT(alloc != NULL);
699
+
700
+ // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
701
+ int n_leafs_before = gb->n_leafs;
702
+ int n_nodes_before = gb->n_nodes;
703
+
704
+ // output tensors
705
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
706
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
707
+ // input gradient
708
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
709
+ GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
710
+ ggml_set_input(t36->grad);
711
+ // KQ_pos
712
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
713
+
714
+ // make sure base model tensors data cannot be used in viewable operations
715
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f));
716
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f));
717
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f));
718
+ for (int il = 0; il < n_layer; ++il) {
719
+ struct my_llama_layer & layer = model->layers[il];
720
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f));
721
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f));
722
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f));
723
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
724
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
725
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
726
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_gate, 1.0f));
727
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_down, 1.0f));
728
+ ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_up, 1.0f));
729
+ }
730
+
731
+ // allocating checkpoints in one block to reduce memory fragmentation
732
+ // note: they will be freed in reverse order
733
+ for (unsigned int i = 0; i < checkpoints.size(); ++i) {
734
+ if (checkpoints[i]->data == NULL && checkpoints[i]->view_src == NULL) {
735
+ ggml_set_input(checkpoints[i]);
736
+ }
737
+ }
738
+
739
+ if (measure_only) {
740
+ ggml_gallocr_reserve(alloc, gb);
741
+ } else {
742
+ ggml_gallocr_alloc_graph(alloc, gb);
743
+
744
+ // set KQ_pos
745
+ {
746
+ int * data = (int *) KQ_pos->data;
747
+ for (int i = 0; i < N; ++i) {
748
+ data[i] = n_past + i;
749
+ }
750
+ }
751
+ }
752
+
753
+ // remove the additional nodes and leafs
754
+ for (int i = n_leafs_before; i < gb->n_leafs; ++i) {
755
+ gb->leafs[i] = NULL;
756
+ }
757
+ for (int i = n_nodes_before; i < gb->n_nodes; ++i) {
758
+ gb->nodes[i] = NULL;
759
+ }
760
+ gb->n_leafs = n_leafs_before;
761
+ gb->n_nodes = n_nodes_before;
762
+
763
+ *logits = t35;
764
+ return t36;
765
+ }
766
+
767
+ static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora) {
768
+ // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
769
+
770
+ std::string arch;
771
+
772
+ std::vector<char> keybuf;
773
+ keybuf.resize(512);
774
+
775
+ GGUF_GET_KEY(fctx, arch, gguf_get_val_str, GGUF_TYPE_STRING, true, LLM_KV_GENERAL_ARCHITECTURE);
776
+ GGML_ASSERT(arch == "llama");
777
+
778
+ uint32_t ftype_u;
779
+ GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
780
+ GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
781
+
782
+ struct my_llama_hparams hparams;
783
+ load_model_hparams_gguf(fctx, &hparams, arch.c_str());
784
+
785
+ // parameters that define tensor shapes must match
786
+ GGML_ASSERT(hparams.n_embd == model->hparams.n_embd);
787
+ GGML_ASSERT(hparams.n_ff == model->hparams.n_ff);
788
+ GGML_ASSERT(hparams.n_head == model->hparams.n_head);
789
+ GGML_ASSERT(hparams.n_head_kv == model->hparams.n_head_kv);
790
+ GGML_ASSERT(hparams.n_layer == model->hparams.n_layer);
791
+
792
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_tok_embeddings, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD);
793
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_norm, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM);
794
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_output, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_OUTPUT);
795
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_attention_norm, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_NORM);
796
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_wq, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_Q);
797
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_wk, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_K);
798
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_wv, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_V);
799
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_wo, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_ATTN_OUT);
800
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_norm, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_NORM);
801
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_gate, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_GATE);
802
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_down, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN);
803
+ GGUF_GET_KEY(fctx, lora->hparams.n_rank_ffn_up, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_TRAINING_LORA_RANK_FFN_UP);
804
+
805
+ init_lora(model, lora);
806
+
807
+ copy_tensor_by_name(lora->tok_embeddings_a, f_ggml_ctx, ggml_get_name(lora->tok_embeddings_a));
808
+ copy_tensor_by_name(lora->tok_embeddings_b, f_ggml_ctx, ggml_get_name(lora->tok_embeddings_b));
809
+ copy_tensor_by_name(lora->norm_a, f_ggml_ctx, ggml_get_name(lora->norm_a));
810
+ copy_tensor_by_name(lora->norm_b, f_ggml_ctx, ggml_get_name(lora->norm_b));
811
+ copy_tensor_by_name(lora->output_a, f_ggml_ctx, ggml_get_name(lora->output_a));
812
+ copy_tensor_by_name(lora->output_b, f_ggml_ctx, ggml_get_name(lora->output_b));
813
+
814
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
815
+ auto & layer = lora->layers[i];
816
+ copy_tensor_by_name(layer.attention_norm_a, f_ggml_ctx, ggml_get_name(layer.attention_norm_a));
817
+ copy_tensor_by_name(layer.attention_norm_b, f_ggml_ctx, ggml_get_name(layer.attention_norm_b));
818
+ copy_tensor_by_name(layer.wq_a, f_ggml_ctx, ggml_get_name(layer.wq_a));
819
+ copy_tensor_by_name(layer.wq_b, f_ggml_ctx, ggml_get_name(layer.wq_b));
820
+ copy_tensor_by_name(layer.wk_a, f_ggml_ctx, ggml_get_name(layer.wk_a));
821
+ copy_tensor_by_name(layer.wk_b, f_ggml_ctx, ggml_get_name(layer.wk_b));
822
+ copy_tensor_by_name(layer.wv_a, f_ggml_ctx, ggml_get_name(layer.wv_a));
823
+ copy_tensor_by_name(layer.wv_b, f_ggml_ctx, ggml_get_name(layer.wv_b));
824
+ copy_tensor_by_name(layer.wo_a, f_ggml_ctx, ggml_get_name(layer.wo_a));
825
+ copy_tensor_by_name(layer.wo_b, f_ggml_ctx, ggml_get_name(layer.wo_b));
826
+ copy_tensor_by_name(layer.ffn_norm_a, f_ggml_ctx, ggml_get_name(layer.ffn_norm_a));
827
+ copy_tensor_by_name(layer.ffn_norm_b, f_ggml_ctx, ggml_get_name(layer.ffn_norm_b));
828
+ copy_tensor_by_name(layer.ffn_gate_a, f_ggml_ctx, ggml_get_name(layer.ffn_gate_a));
829
+ copy_tensor_by_name(layer.ffn_gate_b, f_ggml_ctx, ggml_get_name(layer.ffn_gate_b));
830
+ copy_tensor_by_name(layer.ffn_down_a, f_ggml_ctx, ggml_get_name(layer.ffn_down_a));
831
+ copy_tensor_by_name(layer.ffn_down_b, f_ggml_ctx, ggml_get_name(layer.ffn_down_b));
832
+ copy_tensor_by_name(layer.ffn_up_a, f_ggml_ctx, ggml_get_name(layer.ffn_up_a));
833
+ copy_tensor_by_name(layer.ffn_up_b, f_ggml_ctx, ggml_get_name(layer.ffn_up_b));
834
+ }
835
+ }
836
+
837
+ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora) {
838
+ const char * arch = "llama";
839
+ enum llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
840
+
841
+ std::vector<char> keybuf;
842
+ keybuf.resize(512);
843
+ auto kv = [arch, &keybuf](const char * key) -> const char * {
844
+ snprintf(keybuf.data(), keybuf.size(), key, arch);
845
+ return keybuf.data();
846
+ };
847
+
848
+ gguf_set_val_str(fctx, LLM_KV_GENERAL_ARCHITECTURE, arch);
849
+ gguf_set_val_u32(fctx, LLM_KV_GENERAL_FILE_TYPE, ftype);
850
+
851
+ gguf_set_val_u32(fctx, kv(LLM_KV_CONTEXT_LENGTH), model->hparams.n_ctx);
852
+ gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd);
853
+ gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff);
854
+ gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head);
855
+ gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV), model->hparams.n_head_kv);
856
+ gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer);
857
+ gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_embd_head());
858
+ gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), model->hparams.f_norm_rms_eps);
859
+ gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE), model->hparams.rope_freq_base);
860
+ gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR), model->hparams.rope_freq_scale);
861
+
862
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_TOKEN_EMBD, lora->hparams.n_rank_tok_embeddings);
863
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_OUTPUT_NORM, lora->hparams.n_rank_norm);
864
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_OUTPUT, lora->hparams.n_rank_output);
865
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_NORM, lora->hparams.n_rank_attention_norm);
866
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_Q, lora->hparams.n_rank_wq);
867
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_K, lora->hparams.n_rank_wk);
868
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_V, lora->hparams.n_rank_wv);
869
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_ATTN_OUT, lora->hparams.n_rank_wo);
870
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_NORM, lora->hparams.n_rank_ffn_norm);
871
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_GATE, lora->hparams.n_rank_ffn_gate);
872
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_DOWN, lora->hparams.n_rank_ffn_down);
873
+ gguf_set_val_u32(fctx, LLM_KV_TRAINING_LORA_RANK_FFN_UP, lora->hparams.n_rank_ffn_up);
874
+
875
+ gguf_add_tensor(fctx, lora->tok_embeddings_a);
876
+ gguf_add_tensor(fctx, lora->tok_embeddings_b);
877
+ gguf_add_tensor(fctx, lora->norm_a);
878
+ gguf_add_tensor(fctx, lora->norm_b);
879
+ gguf_add_tensor(fctx, lora->output_a);
880
+ gguf_add_tensor(fctx, lora->output_b);
881
+
882
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
883
+ auto & layer = lora->layers[i];
884
+
885
+ gguf_add_tensor(fctx, layer.attention_norm_a);
886
+ gguf_add_tensor(fctx, layer.attention_norm_b);
887
+ gguf_add_tensor(fctx, layer.wq_a);
888
+ gguf_add_tensor(fctx, layer.wq_b);
889
+ gguf_add_tensor(fctx, layer.wk_a);
890
+ gguf_add_tensor(fctx, layer.wk_b);
891
+ gguf_add_tensor(fctx, layer.wv_a);
892
+ gguf_add_tensor(fctx, layer.wv_b);
893
+ gguf_add_tensor(fctx, layer.wo_a);
894
+ gguf_add_tensor(fctx, layer.wo_b);
895
+ gguf_add_tensor(fctx, layer.ffn_norm_a);
896
+ gguf_add_tensor(fctx, layer.ffn_norm_b);
897
+ gguf_add_tensor(fctx, layer.ffn_gate_a);
898
+ gguf_add_tensor(fctx, layer.ffn_gate_b);
899
+ gguf_add_tensor(fctx, layer.ffn_down_a);
900
+ gguf_add_tensor(fctx, layer.ffn_down_b);
901
+ gguf_add_tensor(fctx, layer.ffn_up_a);
902
+ gguf_add_tensor(fctx, layer.ffn_up_b);
903
+ }
904
+ }
905
+
906
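+ // Checkpoint loader: asserts that the GGUF file was produced by LoRA finetuning
+ // (the LoRA-finetune training type), then restores the optimizer/train state and
+ // the LoRA tensors themselves.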
+ static void load_checkpoint_lora_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
907
+ std::string train_type = LLM_KV_TRAINING_TYPE_FINETUNE_LORA;
908
+ GGUF_GET_KEY(fctx, train_type, gguf_get_val_str, GGUF_TYPE_STRING, false, LLM_KV_TRAINING_TYPE);
909
+ GGML_ASSERT(train_type == LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
910
+
911
+ load_train_state_gguf(fctx, f_ggml_ctx, train);
912
+ load_llama_lora_gguf(fctx, f_ggml_ctx, model, lora);
913
+ }
914
+
915
+ static void save_checkpoint_lora_gguf(struct gguf_context * fctx, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
916
+ gguf_set_val_str(fctx, LLM_KV_TRAINING_TYPE, LLM_KV_TRAINING_TYPE_FINETUNE_LORA);
917
+ save_llama_lora_gguf(fctx, model, lora);
918
+ save_train_state_gguf(fctx, train);
919
+ }
920
+
921
+ static bool load_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
922
+ struct ggml_context * f_ggml_ctx;
923
+ struct gguf_init_params params;
924
+ params.no_alloc = false;
925
+ params.ctx = &f_ggml_ctx;
926
+ struct gguf_context * fctx = gguf_init_from_file(filename, params);
927
+ if (fctx == NULL) {
928
+ return false;
929
+ }
930
+
931
+ load_checkpoint_lora_gguf(fctx, f_ggml_ctx, model, lora, train);
932
+
933
+ gguf_free(fctx);
934
+ return true;
935
+ }
936
+
937
+ static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) {
938
+ printf("%s: saving to %s\n", __func__, filename);
939
+ struct gguf_context * fctx = gguf_init_empty();
940
+
941
+ save_checkpoint_lora_gguf(fctx, model, lora, train);
942
+
943
+ // write file
944
+ const bool only_meta = false;
945
+ gguf_write_to_file(fctx, filename, only_meta);
946
+ gguf_free(fctx);
947
+ }
948
+
949
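+ // Minimal file wrapper used for the legacy (non-GGUF) adapter export below.
+ // It uses 64-bit seek/tell on Windows; the read/write helpers die() on short
+ // reads or write errors instead of returning error codes.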
+ struct llama_file {
950
+ // use FILE * so we don't have to re-open the file to mmap
951
+ FILE * fp;
952
+ size_t size;
953
+
954
+ llama_file(const char * fname, const char * mode) {
955
+ fp = std::fopen(fname, mode);
956
+ if (fp == NULL) {
957
+ size = 0;
958
+ } else {
959
+ seek(0, SEEK_END);
960
+ size = tell();
961
+ seek(0, SEEK_SET);
962
+ }
963
+ }
964
+
965
+ size_t tell() const {
966
+ #ifdef _WIN32
967
+ __int64 ret = _ftelli64(fp);
968
+ #else
969
+ long ret = std::ftell(fp);
970
+ #endif
971
+ GGML_ASSERT(ret != -1); // this really shouldn't fail
972
+ return (size_t) ret;
973
+ }
974
+
975
+ void seek(size_t offset, int whence) {
976
+ #ifdef _WIN32
977
+ int ret = _fseeki64(fp, (__int64) offset, whence);
978
+ #else
979
+ int ret = std::fseek(fp, (long) offset, whence);
980
+ #endif
981
+ GGML_ASSERT(ret == 0); // same
982
+ }
983
+
984
+ void read_raw(void * ptr, size_t size) {
985
+ if (size == 0) {
986
+ return;
987
+ }
988
+ errno = 0;
989
+ std::size_t ret = std::fread(ptr, size, 1, fp);
990
+ if (ferror(fp)) {
991
+ die_fmt("read error: %s", strerror(errno));
992
+ }
993
+ if (ret != 1) {
994
+ die("unexpectedly reached end of file");
995
+ }
996
+ }
997
+
998
+ std::uint32_t read_u32() {
999
+ std::uint32_t ret;
1000
+ read_raw(&ret, sizeof(ret));
1001
+ return ret;
1002
+ }
1003
+
1004
+ std::string read_string(std::uint32_t len) {
1005
+ std::vector<char> chars(len);
1006
+ read_raw(chars.data(), len);
1007
+ return std::string(chars.data(), len);
1008
+ }
1009
+
1010
+ void write_raw(const void * ptr, size_t size) {
1011
+ if (size == 0) {
1012
+ return;
1013
+ }
1014
+ errno = 0;
1015
+ size_t ret = std::fwrite(ptr, size, 1, fp);
1016
+ if (ret != 1) {
1017
+ die_fmt("write error: %s", strerror(errno));
1018
+ }
1019
+ }
1020
+
1021
+ void write_u32(std::uint32_t val) {
1022
+ write_raw(&val, sizeof(val));
1023
+ }
1024
+
1025
+ ~llama_file() {
1026
+ if (fp) {
1027
+ std::fclose(fp);
1028
+ }
1029
+ }
1030
+ };
1031
+
1032
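+ // Writes one tensor record in the legacy adapter layout: n_dims, name length,
+ // ggml type, dims, name, then the raw tensor data aligned to a 32-byte boundary.
+ // A NULL tensor is written as a zero-sized placeholder record.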
+ static void write_tensor(struct llama_file * file, struct ggml_tensor * tensor, const char * name) {
1033
+ if (tensor == NULL) {
1034
+ file->write_u32(0);
1035
+ file->write_u32(0);
1036
+ file->write_u32(GGML_TYPE_F32);
1037
+ file->seek((0-file->tell()) & 31, SEEK_CUR);
1038
+ return;
1039
+ }
1040
+ if (name == NULL) {
1041
+ name = ggml_get_name(tensor);
1042
+ }
1043
+ uint32_t name_len = strlen(name);
1044
+ uint32_t nd = ggml_n_dims(tensor);
1045
+ uint32_t ne[4] = { (uint32_t)tensor->ne[0],
1046
+ (uint32_t)tensor->ne[1],
1047
+ (uint32_t)tensor->ne[2],
1048
+ (uint32_t)tensor->ne[3] };
1049
+ file->write_u32(nd);
1050
+ file->write_u32(name_len);
1051
+ file->write_u32(tensor->type);
1052
+ file->write_raw(ne, sizeof(ne[0]) * nd);
1053
+ file->write_raw(name, name_len);
1054
+ file->seek((0-file->tell()) & 31, SEEK_CUR);
1055
+ file->write_raw(tensor->data, ggml_nbytes(tensor));
1056
+ }
1057
+
1058
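+ // Exports the adapter in the legacy GGLA format used by llama.cpp's runtime LoRA
+ // loader (--lora): magic + version, lora_r and lora_alpha, followed by the A/B
+ // tensor pairs named "<base tensor>.weight.loraA" / ".weight.loraB".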
+ static void save_as_llama_lora(const char * filename, struct my_llama_lora * lora) {
1059
+ printf("%s: saving to %s\n", __func__, filename);
1060
+ struct llama_file file(filename, "wb");
1061
+ if (file.fp == NULL) {
1062
+ return;
1063
+ }
1064
+
1065
+ std::vector<char> tn_buf;
1066
+ tn_buf.resize(GGML_MAX_NAME);
1067
+
1068
+ auto tn = [&tn_buf](const char * key, const char * suffix) -> const char * {
1069
+ snprintf(tn_buf.data(), tn_buf.size(), "%s%s", key, suffix);
1070
+ return tn_buf.data();
1071
+ };
1072
+
1073
+ auto tni = [&tn_buf](const char * key, int bid, const char * suffix) -> const char * {
1074
+ snprintf(tn_buf.data(), tn_buf.size(), key, bid);
1075
+ std::string s = tn_buf.data();
1076
+ snprintf(tn_buf.data(), tn_buf.size(), "%s%s", s.c_str(), suffix);
1077
+ return tn_buf.data();
1078
+ };
1079
+
1080
+ // write_magic
1081
+ file.write_u32(LLAMA_FILE_MAGIC_GGLA); // magic
1082
+ file.write_u32(1); // version
1083
+ // write_hparams
1084
+ file.write_u32(lora->hparams.lora_r);
1085
+ file.write_u32(lora->hparams.lora_alpha);
1086
+ // write tensors
1087
+ write_tensor(&file, lora->tok_embeddings_a, tn(LLM_TENSOR_TOKEN_EMBD, ".weight.loraA"));
1088
+ write_tensor(&file, lora->tok_embeddings_b, tn(LLM_TENSOR_TOKEN_EMBD, ".weight.loraB"));
1089
+ write_tensor(&file, lora->norm_a, tn(LLM_TENSOR_OUTPUT_NORM, ".weight.loraA"));
1090
+ write_tensor(&file, lora->norm_b, tn(LLM_TENSOR_OUTPUT_NORM, ".weight.loraB"));
1091
+ write_tensor(&file, lora->output_a, tn(LLM_TENSOR_OUTPUT, ".weight.loraA"));
1092
+ write_tensor(&file, lora->output_b, tn(LLM_TENSOR_OUTPUT, ".weight.loraB"));
1093
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
1094
+ auto & layer = lora->layers[i];
1095
+ write_tensor(&file, layer.attention_norm_a, tni(LLM_TENSOR_ATTN_NORM, i, ".weight.loraA"));
1096
+ write_tensor(&file, layer.attention_norm_b, tni(LLM_TENSOR_ATTN_NORM, i, ".weight.loraB"));
1097
+ write_tensor(&file, layer.wq_a, tni(LLM_TENSOR_ATTN_Q, i, ".weight.loraA"));
1098
+ write_tensor(&file, layer.wq_b, tni(LLM_TENSOR_ATTN_Q, i, ".weight.loraB"));
1099
+ write_tensor(&file, layer.wk_a, tni(LLM_TENSOR_ATTN_K, i, ".weight.loraA"));
1100
+ write_tensor(&file, layer.wk_b, tni(LLM_TENSOR_ATTN_K, i, ".weight.loraB"));
1101
+ write_tensor(&file, layer.wv_a, tni(LLM_TENSOR_ATTN_V, i, ".weight.loraA"));
1102
+ write_tensor(&file, layer.wv_b, tni(LLM_TENSOR_ATTN_V, i, ".weight.loraB"));
1103
+ write_tensor(&file, layer.wo_a, tni(LLM_TENSOR_ATTN_OUT, i, ".weight.loraA"));
1104
+ write_tensor(&file, layer.wo_b, tni(LLM_TENSOR_ATTN_OUT, i, ".weight.loraB"));
1105
+ write_tensor(&file, layer.ffn_norm_a, tni(LLM_TENSOR_FFN_NORM, i, ".weight.loraA"));
1106
+ write_tensor(&file, layer.ffn_norm_b, tni(LLM_TENSOR_FFN_NORM, i, ".weight.loraB"));
1107
+ write_tensor(&file, layer.ffn_gate_a, tni(LLM_TENSOR_FFN_GATE, i, ".weight.loraA"));
1108
+ write_tensor(&file, layer.ffn_gate_b, tni(LLM_TENSOR_FFN_GATE, i, ".weight.loraB"));
1109
+ write_tensor(&file, layer.ffn_down_a, tni(LLM_TENSOR_FFN_DOWN, i, ".weight.loraA"));
1110
+ write_tensor(&file, layer.ffn_down_b, tni(LLM_TENSOR_FFN_DOWN, i, ".weight.loraB"));
1111
+ write_tensor(&file, layer.ffn_up_a, tni(LLM_TENSOR_FFN_UP, i, ".weight.loraA"));
1112
+ write_tensor(&file, layer.ffn_up_b, tni(LLM_TENSOR_FFN_UP, i, ".weight.loraB"));
1113
+ }
1114
+ }
1115
+
1116
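+ // Finetune-specific options layered on top of the shared train_params_common.
+ // Each custom_* flag records whether the user explicitly set the value, so that
+ // checkpoint values or lora_r-derived defaults can be used otherwise.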
+ struct train_params {
1117
+ struct train_params_common common;
1118
+
1119
+ const char * fn_model_base;
1120
+ const char * fn_lora_out;
1121
+
1122
+ bool only_write_lora;
1123
+
1124
+ float f_norm_rms_eps;
1125
+ float rope_freq_base;
1126
+ float rope_freq_scale;
1127
+
1128
+ bool custom_f_norm_rms_eps;
1129
+ bool custom_rope_freq_base;
1130
+ bool custom_rope_freq_scale;
1131
+
1132
+ int32_t lora_r;
1133
+ int32_t lora_alpha;
1134
+ bool custom_lora_alpha;
1135
+
1136
+ uint32_t n_rank_attention_norm;
1137
+ uint32_t n_rank_wq;
1138
+ uint32_t n_rank_wk;
1139
+ uint32_t n_rank_wv;
1140
+ uint32_t n_rank_wo;
1141
+ uint32_t n_rank_ffn_norm;
1142
+ uint32_t n_rank_ffn_gate;
1143
+ uint32_t n_rank_ffn_down;
1144
+ uint32_t n_rank_ffn_up;
1145
+ uint32_t n_rank_tok_embeddings;
1146
+ uint32_t n_rank_norm;
1147
+ uint32_t n_rank_output;
1148
+
1149
+ bool custom_n_rank_attention_norm;
1150
+ bool custom_n_rank_wq;
1151
+ bool custom_n_rank_wk;
1152
+ bool custom_n_rank_wv;
1153
+ bool custom_n_rank_wo;
1154
+ bool custom_n_rank_ffn_norm;
1155
+ bool custom_n_rank_ffn_gate;
1156
+ bool custom_n_rank_ffn_down;
1157
+ bool custom_n_rank_ffn_up;
1158
+ bool custom_n_rank_tok_embeddings;
1159
+ bool custom_n_rank_norm;
1160
+ bool custom_n_rank_output;
1161
+ };
1162
+
1163
+ static struct train_params get_default_train_params() {
1164
+ struct train_params params;
1165
+ params.common = get_default_train_params_common();
1166
+ params.fn_model_base = "";
1167
+ params.fn_lora_out = "ggml-lora-ITERATION-f32.gguf";
1168
+
1169
+ params.only_write_lora = false;
1170
+
1171
+ params.f_norm_rms_eps = 1e-5f;
1172
+ params.rope_freq_base = 10000.0f;
1173
+ params.rope_freq_scale = 1.0f;
1174
+
1175
+ params.custom_f_norm_rms_eps = false;
1176
+ params.custom_rope_freq_base = false;
1177
+ params.custom_rope_freq_scale = false;
1178
+
1179
+ params.lora_r = 4;
1180
+ params.lora_alpha = 4;
1181
+ params.custom_lora_alpha = false;
1182
+
1183
+ params.n_rank_attention_norm = 1;
1184
+ params.n_rank_wq = 4;
1185
+ params.n_rank_wk = 4;
1186
+ params.n_rank_wv = 4;
1187
+ params.n_rank_wo = 4;
1188
+ params.n_rank_ffn_norm = 1;
1189
+ params.n_rank_ffn_gate = 4;
1190
+ params.n_rank_ffn_down = 4;
1191
+ params.n_rank_ffn_up = 4;
1192
+ params.n_rank_tok_embeddings = 4;
1193
+ params.n_rank_norm = 1;
1194
+ params.n_rank_output = 4;
1195
+
1196
+ params.custom_n_rank_attention_norm = false;
1197
+ params.custom_n_rank_wq = false;
1198
+ params.custom_n_rank_wk = false;
1199
+ params.custom_n_rank_wv = false;
1200
+ params.custom_n_rank_wo = false;
1201
+ params.custom_n_rank_ffn_norm = false;
1202
+ params.custom_n_rank_ffn_gate = false;
1203
+ params.custom_n_rank_ffn_down = false;
1204
+ params.custom_n_rank_ffn_up = false;
1205
+ params.custom_n_rank_tok_embeddings = false;
1206
+ params.custom_n_rank_norm = false;
1207
+ params.custom_n_rank_output = false;
1208
+
1209
+ return params;
1210
+ }
1211
+
1212
+ static void train_print_usage(int argc, char ** argv, const struct train_params * params) {
1213
+ fprintf(stderr, "usage: %s [options]\n", argv[0]);
1214
+ fprintf(stderr, "\n");
1215
+ fprintf(stderr, "options:\n");
1216
+ fprintf(stderr, " -h, --help show this help message and exit\n");
1217
+
1218
+ fprintf(stderr, " --model-base FNAME model path from which to load base model (default '%s')\n", params->fn_model_base);
1219
+ fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out);
1220
+ fprintf(stderr, " --only-write-lora only save llama lora, don't do any training. use this if you only want to convert a checkpoint to a lora adapter.\n");
1221
+ fprintf(stderr, " --norm-rms-eps F RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
1222
+ fprintf(stderr, " --rope-freq-base F Frequency base for ROPE (default %f)\n", params->rope_freq_base);
1223
+ fprintf(stderr, " --rope-freq-scale F Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
1224
+ fprintf(stderr, " --lora-alpha N LORA alpha : resulting LORA scaling is alpha/r. (default %d)\n", params->lora_alpha);
1225
+ fprintf(stderr, " --lora-r N LORA r: default rank. Also specifies resulting scaling together with lora-alpha. (default %d)\n", params->lora_r);
1226
+ fprintf(stderr, " --rank-att-norm N LORA rank for attention norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
1227
+ fprintf(stderr, " --rank-ffn-norm N LORA rank for feed-forward norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
1228
+ fprintf(stderr, " --rank-out-norm N LORA rank for output norm tensor, overrides default rank. Norm tensors should generally have rank 1.\n");
1229
+ fprintf(stderr, " --rank-tok-embd N LORA rank for token embeddings tensor, overrides default rank.\n");
1230
+ fprintf(stderr, " --rank-out N LORA rank for output tensor, overrides default rank.\n");
1231
+ fprintf(stderr, " --rank-wq N LORA rank for wq tensor, overrides default rank.\n");
1232
+ fprintf(stderr, " --rank-wk N LORA rank for wk tensor, overrides default rank.\n");
1233
+ fprintf(stderr, " --rank-wv N LORA rank for wv tensor, overrides default rank.\n");
1234
+ fprintf(stderr, " --rank-wo N LORA rank for wo tensor, overrides default rank.\n");
1235
+ fprintf(stderr, " --rank-ffn_gate N LORA rank for ffn_gate tensor, overrides default rank.\n");
1236
+ fprintf(stderr, " --rank-ffn_down N LORA rank for ffn_down tensor, overrides default rank.\n");
1237
+ fprintf(stderr, " --rank-ffn_up N LORA rank for ffn_up tensor, overrides default rank.\n");
1238
+
1239
+ print_common_train_usage(argc, argv, &params->common);
1240
+ }
1241
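+ // Illustrative invocation (a sketch only, assuming the example binary is built
+ // as ./finetune and using a hypothetical base-model file; shared options such as
+ // the training-data and checkpoint paths are parsed by consume_common_train_arg
+ // and are not listed in this file):
+ //
+ //   ./finetune --model-base base-model-f16.gguf \
+ //              --lora-out lora-adapter.gguf \
+ //              --lora-r 8 --lora-alpha 16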
+
1242
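+ // Argument parser: normalizes '_' to '-' in option names, first offers each
+ // argument to the shared consume_common_train_arg(), then handles the
+ // finetune-specific flags; unknown or incomplete arguments print usage and exit.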
+ static bool train_params_parse(int argc, char ** argv, struct train_params * params) {
1243
+ bool invalid_param = false;
1244
+ std::string arg;
1245
+ struct train_params default_params = get_default_train_params();
1246
+ const std::string arg_prefix = "--";
1247
+
1248
+ for (int i = 1; i < argc; i++) {
1249
+ arg = argv[i];
1250
+ if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
1251
+ std::replace(arg.begin(), arg.end(), '_', '-');
1252
+ }
1253
+
1254
+ if (consume_common_train_arg(argc, argv, &i, &params->common, &invalid_param)) {
1255
+ if (invalid_param) {
1256
+ break;
1257
+ } else if (params->common.print_usage) {
1258
+ train_print_usage(argc, argv, &default_params);
1259
+ exit(0);
1260
+ }
1261
+ } else if (arg == "--model-base") {
1262
+ if (++i >= argc) {
1263
+ invalid_param = true;
1264
+ break;
1265
+ }
1266
+ params->fn_model_base = argv[i];
1267
+ } else if (arg == "--lora-out") {
1268
+ if (++i >= argc) {
1269
+ invalid_param = true;
1270
+ break;
1271
+ }
1272
+ params->fn_lora_out = argv[i];
1273
+ } else if (arg == "--only-write-lora") {
1274
+ params->only_write_lora = true;
1275
+ } else if (arg == "--norm-rms-eps") {
1276
+ if (++i >= argc) {
1277
+ invalid_param = true;
1278
+ break;
1279
+ }
1280
+ params->f_norm_rms_eps = std::stof(argv[i]);
1281
+ params->custom_f_norm_rms_eps = true;
1282
+ } else if (arg == "--rope-freq-base") {
1283
+ if (++i >= argc) {
1284
+ invalid_param = true;
1285
+ break;
1286
+ }
1287
+ params->rope_freq_base = std::stof(argv[i]);
1288
+ params->custom_rope_freq_base = true;
1289
+ } else if (arg == "--rope-freq-scale") {
1290
+ if (++i >= argc) {
1291
+ invalid_param = true;
1292
+ break;
1293
+ }
1294
+ params->rope_freq_scale = std::stof(argv[i]);
1295
+ params->custom_rope_freq_scale = true;
1296
+ } else if (arg == "--lora-alpha") {
1297
+ if (++i >= argc) {
1298
+ invalid_param = true;
1299
+ break;
1300
+ }
1301
+ params->lora_alpha = std::stoi(argv[i]);
1302
+ params->custom_lora_alpha = true;
1303
+ } else if (arg == "--lora-r") {
1304
+ if (++i >= argc) {
1305
+ invalid_param = true;
1306
+ break;
1307
+ }
1308
+ params->lora_r = std::stoi(argv[i]);
1309
+ } else if (arg == "--rank-att-norm") {
1310
+ if (++i >= argc) {
1311
+ invalid_param = true;
1312
+ break;
1313
+ }
1314
+ params->n_rank_attention_norm = std::stoi(argv[i]);
1315
+ params->custom_n_rank_attention_norm = true;
1316
+ } else if (arg == "--rank-ffn-norm") {
1317
+ if (++i >= argc) {
1318
+ invalid_param = true;
1319
+ break;
1320
+ }
1321
+ params->n_rank_ffn_norm = std::stoi(argv[i]);
1322
+ params->custom_n_rank_ffn_norm = true;
1323
+ } else if (arg == "--rank-out-norm") {
1324
+ if (++i >= argc) {
1325
+ invalid_param = true;
1326
+ break;
1327
+ }
1328
+ params->n_rank_norm = std::stoi(argv[i]);
1329
+ params->custom_n_rank_norm = true;
1330
+ } else if (arg == "--rank-tok-embd") {
1331
+ if (++i >= argc) {
1332
+ invalid_param = true;
1333
+ break;
1334
+ }
1335
+ params->n_rank_tok_embeddings = std::stoi(argv[i]);
1336
+ params->custom_n_rank_tok_embeddings = true;
1337
+ } else if (arg == "--rank-out") {
1338
+ if (++i >= argc) {
1339
+ invalid_param = true;
1340
+ break;
1341
+ }
1342
+ params->n_rank_output = std::stoi(argv[i]);
1343
+ params->custom_n_rank_output = true;
1344
+ } else if (arg == "--rank-wq") {
1345
+ if (++i >= argc) {
1346
+ invalid_param = true;
1347
+ break;
1348
+ }
1349
+ params->n_rank_wq = std::stoi(argv[i]);
1350
+ params->custom_n_rank_wq = true;
1351
+ } else if (arg == "--rank-wk") {
1352
+ if (++i >= argc) {
1353
+ invalid_param = true;
1354
+ break;
1355
+ }
1356
+ params->n_rank_wk = std::stoi(argv[i]);
1357
+ params->custom_n_rank_wk = true;
1358
+ } else if (arg == "--rank-wv") {
1359
+ if (++i >= argc) {
1360
+ invalid_param = true;
1361
+ break;
1362
+ }
1363
+ params->n_rank_wv = std::stoi(argv[i]);
1364
+ params->custom_n_rank_wv = true;
1365
+ } else if (arg == "--rank-wo") {
1366
+ if (++i >= argc) {
1367
+ invalid_param = true;
1368
+ break;
1369
+ }
1370
+ params->n_rank_wo = std::stoi(argv[i]);
1371
+ params->custom_n_rank_wo = true;
1372
+ } else if (arg == "--rank-ffn_gate") {
1373
+ if (++i >= argc) {
1374
+ invalid_param = true;
1375
+ break;
1376
+ }
1377
+ params->n_rank_ffn_gate = std::stoi(argv[i]);
1378
+ params->custom_n_rank_ffn_gate = true;
1379
+ } else if (arg == "--rank-ffn_down") {
1380
+ if (++i >= argc) {
1381
+ invalid_param = true;
1382
+ break;
1383
+ }
1384
+ params->n_rank_ffn_down = std::stoi(argv[i]);
1385
+ params->custom_n_rank_ffn_down = true;
1386
+ } else if (arg == "--rank-ffn_up") {
1387
+ if (++i >= argc) {
1388
+ invalid_param = true;
1389
+ break;
1390
+ }
1391
+ params->n_rank_ffn_up = std::stoi(argv[i]);
1392
+ params->custom_n_rank_ffn_up = true;
1393
+ } else {
1394
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
1395
+ train_print_usage(argc, argv, &default_params);
1396
+ exit(1);
1397
+ }
1398
+ }
1399
+ if (invalid_param) {
1400
+ fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
1401
+ train_print_usage(argc, argv, &default_params);
1402
+ exit(1);
1403
+ }
1404
+ finish_processing_train_args(&params->common);
1405
+ return true;
1406
+ }
1407
+
1408
+ struct save_train_files_data {
1409
+ const char * fn_checkpoint_out;
1410
+ const char * fn_lora_out;
1411
+ const char * pattern_fn_it;
1412
+ const char * fn_latest;
1413
+ struct my_llama_model * model;
1414
+ struct my_llama_lora * lora;
1415
+ };
1416
+
1417
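+ // Save callback invoked from the training loop: for each non-empty output path
+ // it writes one iteration-stamped file and one using the fn_latest name (passing
+ // -1 as the iteration), for the GGUF checkpoint and the GGLA adapter respectively.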
+ static void save_train_files(void * vdata, struct train_state * train) {
1418
+ struct save_train_files_data * data = (struct save_train_files_data *) vdata;
1419
+
1420
+ int64_t iter = train->opt->iter;
1421
+
1422
+ if (strlen(data->fn_checkpoint_out) > 0) {
1423
+ save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->model, data->lora, train);
1424
+ save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->model, data->lora, train);
1425
+ }
1426
+ if (strlen(data->fn_lora_out) > 0) {
1427
+ save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->lora);
1428
+ save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->lora);
1429
+ }
1430
+ }
1431
+
1432
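+ // Total number of trainable LoRA elements across all A/B tensors; used below to
+ // size the Adam optimizer state via ggml_opt_init().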
+ static int64_t get_parameter_count(struct my_llama_lora* lora) {
1433
+ int64_t nx = 0;
1434
+ nx += ggml_nelements(lora->tok_embeddings_a);
1435
+ nx += ggml_nelements(lora->tok_embeddings_b);
1436
+ nx += ggml_nelements(lora->norm_a);
1437
+ nx += ggml_nelements(lora->norm_b);
1438
+ nx += ggml_nelements(lora->output_a);
1439
+ nx += ggml_nelements(lora->output_b);
1440
+
1441
+ for (uint32_t i = 0; i < lora->layers.size(); ++i) {
1442
+ auto & layer = lora->layers[i];
1443
+ nx += ggml_nelements(layer.attention_norm_a);
1444
+ nx += ggml_nelements(layer.attention_norm_b);
1445
+ nx += ggml_nelements(layer.wq_a);
1446
+ nx += ggml_nelements(layer.wq_b);
1447
+ nx += ggml_nelements(layer.wk_a);
1448
+ nx += ggml_nelements(layer.wk_b);
1449
+ nx += ggml_nelements(layer.wv_a);
1450
+ nx += ggml_nelements(layer.wv_b);
1451
+ nx += ggml_nelements(layer.wo_a);
1452
+ nx += ggml_nelements(layer.wo_b);
1453
+ nx += ggml_nelements(layer.ffn_norm_a);
1454
+ nx += ggml_nelements(layer.ffn_norm_b);
1455
+ nx += ggml_nelements(layer.ffn_gate_a);
1456
+ nx += ggml_nelements(layer.ffn_gate_b);
1457
+ nx += ggml_nelements(layer.ffn_down_a);
1458
+ nx += ggml_nelements(layer.ffn_down_b);
1459
+ nx += ggml_nelements(layer.ffn_up_a);
1460
+ nx += ggml_nelements(layer.ffn_up_b);
1461
+ }
1462
+ return nx;
1463
+ }
1464
+
1465
+ int main(int argc, char ** argv) {
1466
+ struct train_params params = get_default_train_params();
1467
+
1468
+ if (!train_params_parse(argc, argv, &params)) {
1469
+ return 1;
1470
+ }
1471
+
1472
+ if (params.common.seed == LLAMA_DEFAULT_SEED) {
1473
+ params.common.seed = time(NULL);
1474
+ }
1475
+ printf("%s: seed: %u\n", __func__, params.common.seed);
1476
+ srand(params.common.seed);
1477
+
1478
+ struct llama_model_params llama_mparams = llama_model_default_params();
1479
+ llama_mparams.n_gpu_layers = params.common.n_gpu_layers;
1480
+ llama_mparams.vocab_only = false;
1481
+
1482
+ printf("%s: model base = '%s'\n", __func__, params.fn_model_base);
1483
+ struct llama_model * lmodel = llama_load_model_from_file(params.fn_model_base, llama_mparams);
1484
+
1485
+ struct llama_context_params llama_cparams = llama_context_default_params();
1486
+ struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_cparams);
1487
+
1488
+ struct my_llama_model model;
1489
+ init_model(lmodel, &model, params.fn_model_base, params.common.n_ctx);
1490
+
1491
+ struct my_llama_lora lora;
1492
+
1493
+ struct train_state * train = init_train_state();
1494
+ struct ggml_opt_context * opt = train->opt;
1495
+
1496
+ // set params from command line
1497
+ if (params.custom_f_norm_rms_eps) {
1498
+ model.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
1499
+ }
1500
+ if (params.custom_rope_freq_base) {
1501
+ model.hparams.rope_freq_base = params.rope_freq_base;
1502
+ }
1503
+ if (params.custom_rope_freq_scale) {
1504
+ model.hparams.rope_freq_scale = params.rope_freq_scale;
1505
+ }
1506
+ lora.hparams.lora_r = params.lora_r;
1507
+ lora.hparams.lora_alpha = params.custom_lora_alpha ? params.lora_alpha : params.lora_r;
1508
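+ // Per-tensor ranks fall back to lora_r when not given on the command line;
+ // norm tensors default to rank 1. With the default alpha == r, the effective
+ // adapter scaling alpha/r is 1.0 (see the --lora-alpha help text above).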
+ uint32_t n_rank_attention_norm = params.custom_n_rank_attention_norm ? params.n_rank_attention_norm : 1;
1509
+ uint32_t n_rank_wq = params.custom_n_rank_wq ? params.n_rank_wq : params.lora_r;
1510
+ uint32_t n_rank_wk = params.custom_n_rank_wk ? params.n_rank_wk : params.lora_r;
1511
+ uint32_t n_rank_wv = params.custom_n_rank_wv ? params.n_rank_wv : params.lora_r;
1512
+ uint32_t n_rank_wo = params.custom_n_rank_wo ? params.n_rank_wo : params.lora_r;
1513
+ uint32_t n_rank_ffn_norm = params.custom_n_rank_ffn_norm ? params.n_rank_ffn_norm : 1;
1514
+ uint32_t n_rank_ffn_gate = params.custom_n_rank_ffn_gate ? params.n_rank_ffn_gate : params.lora_r;
1515
+ uint32_t n_rank_ffn_down = params.custom_n_rank_ffn_down ? params.n_rank_ffn_down : params.lora_r;
1516
+ uint32_t n_rank_ffn_up = params.custom_n_rank_ffn_up ? params.n_rank_ffn_up : params.lora_r;
1517
+ uint32_t n_rank_tok_embeddings = params.custom_n_rank_tok_embeddings ? params.n_rank_tok_embeddings : params.lora_r;
1518
+ uint32_t n_rank_norm = params.custom_n_rank_norm ? params.n_rank_norm : 1;
1519
+ uint32_t n_rank_output = params.custom_n_rank_output ? params.n_rank_output : params.lora_r;
1520
+ lora.hparams.n_rank_attention_norm = n_rank_attention_norm;
1521
+ lora.hparams.n_rank_wq = n_rank_wq;
1522
+ lora.hparams.n_rank_wk = n_rank_wk;
1523
+ lora.hparams.n_rank_wv = n_rank_wv;
1524
+ lora.hparams.n_rank_wo = n_rank_wo;
1525
+ lora.hparams.n_rank_ffn_norm = n_rank_ffn_norm;
1526
+ lora.hparams.n_rank_ffn_gate = n_rank_ffn_gate;
1527
+ lora.hparams.n_rank_ffn_down = n_rank_ffn_down;
1528
+ lora.hparams.n_rank_ffn_up = n_rank_ffn_up;
1529
+ lora.hparams.n_rank_tok_embeddings = n_rank_tok_embeddings;
1530
+ lora.hparams.n_rank_norm = n_rank_norm;
1531
+ lora.hparams.n_rank_output = n_rank_output;
1532
+
1533
+ // set opt params from command line
1534
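+ // Configure ggml's ADAM optimizer: graph printing off, graph size capped at
+ // LLAMA_TRAIN_MAX_NODES, and the remaining knobs (threads, past/delta early
+ // stopping, gradient accumulation, Adam hyperparameters) taken from the shared
+ // training options parsed earlier.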
+ opt->params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM);
1535
+ opt->params.print_forward_graph = false;
1536
+ opt->params.print_backward_graph = false;
1537
+ opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
1538
+ opt->params.n_threads = params.common.n_threads;
1539
+ opt->params.past = params.common.opt_past;
1540
+ opt->params.delta = params.common.opt_delta;
1541
+ opt->params.max_no_improvement = params.common.opt_max_no_improvement;
1542
+ opt->params.n_gradient_accumulation = params.common.n_gradient_accumulation;
1543
+ opt->params.adam.n_iter = params.common.adam_n_iter;
1544
+ opt->params.adam.sched = 1.0f;
1545
+ opt->params.adam.alpha = params.common.adam_alpha;
1546
+ opt->params.adam.decay = params.common.adam_decay;
1547
+ opt->params.adam.decay_min_ndim = params.common.adam_decay_min_ndim;
1548
+ opt->params.adam.beta1 = params.common.adam_beta1;
1549
+ opt->params.adam.beta2 = params.common.adam_beta2;
1550
+ opt->params.adam.gclip = params.common.adam_gclip;
1551
+ opt->params.adam.eps_f = params.common.adam_eps_f;
1552
+
1553
+ printf("%s: init model\n", __func__);
1554
+ bool existed = load_checkpoint_lora_file(params.common.fn_checkpoint_in, &model, &lora, train);
1555
+
1556
+ if (existed) {
1557
+ // overwrite last n_ctx with user provided n_ctx
1558
+ if (params.common.custom_n_ctx) {
1559
+ model.hparams.n_ctx = params.common.n_ctx;
1560
+ }
1561
+
1562
+ const bool opt_param_count_changed = (
1563
+ (lora.hparams.n_rank_attention_norm != n_rank_attention_norm)
1564
+ || (lora.hparams.n_rank_wq != n_rank_wq)
1565
+ || (lora.hparams.n_rank_wk != n_rank_wk)
1566
+ || (lora.hparams.n_rank_wv != n_rank_wv)
1567
+ || (lora.hparams.n_rank_wo != n_rank_wo)
1568
+ || (lora.hparams.n_rank_ffn_norm != n_rank_ffn_norm)
1569
+ || (lora.hparams.n_rank_ffn_gate != n_rank_ffn_gate)
1570
+ || (lora.hparams.n_rank_ffn_down != n_rank_ffn_down)
1571
+ || (lora.hparams.n_rank_ffn_up != n_rank_ffn_up)
1572
+ || (lora.hparams.n_rank_tok_embeddings != n_rank_tok_embeddings)
1573
+ || (lora.hparams.n_rank_norm != n_rank_norm)
1574
+ || (lora.hparams.n_rank_output != n_rank_output)
1575
+ );
1576
+
1577
+ const bool opt_past_changed = opt->params.past != params.common.opt_past;
1578
+
1579
+ if (opt_param_count_changed) {
1580
+ print_lora_params(&lora.hparams);
1581
+ die("Provided rank differs from checkpoint file. To use different rank start finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting.");
1582
+ // need to discard previous optimizer gradient statistics and opt_init with new shapes
1583
+ // TODO
1584
+ }
1585
+ if (opt_past_changed) {
1586
+ die("Optimizer parameter '--opt-past N' differs from checkpoint file. To use different value finetune from scratch with empty input checkpoint, e.g --checkpoint-in ''. Aborting");
1587
+ // need to discard previous optimizer past function value statistics and opt_init with new shapes
1588
+ // TODO
1589
+ }
1590
+ } else { // existed == false
1591
+ init_lora(&model, &lora);
1592
+ randomize_lora(&lora, params.common.seed, 0.0f, 1.0f, -1.0f, +1.0f);
1593
+ if (!params.only_write_lora) {
1594
+ ggml_opt_init(opt->ctx, opt, opt->params, get_parameter_count(&lora));
1595
+ }
1596
+ }
1597
+ opt->iter = train->train_its;
1598
+
1599
+ print_params(&model.hparams);
1600
+ print_lora_params(&lora.hparams);
1601
+ printf("%s: total train_iterations %llu\n", __func__, (long long unsigned) train->train_its);
1602
+ printf("%s: seen train_samples %llu\n", __func__, (long long unsigned) train->train_samples);
1603
+ printf("%s: seen train_tokens %llu\n", __func__, (long long unsigned) train->train_tokens);
1604
+ printf("%s: completed train_epochs %llu\n", __func__, (long long unsigned) train->train_epochs);
1605
+ printf("%s: lora_size = %zu bytes (%.1f MB)\n", __func__, (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)), (float) (ggml_used_mem(lora.ctx) + ggml_backend_buffer_get_size(lora.data)) / (1024.0f*1024.0f));
1606
+
1607
+ if (params.only_write_lora) {
1608
+ save_train_files_data save_data;
1609
+ save_data.fn_checkpoint_out = "";
1610
+ save_data.fn_lora_out = params.fn_lora_out;
1611
+ save_data.pattern_fn_it = params.common.pattern_fn_it;
1612
+ save_data.fn_latest = params.common.fn_latest;
1613
+ save_data.model = &model;
1614
+ save_data.lora = &lora;
1615
+
1616
+ save_train_files(&save_data, train);
1617
+
1618
+ free_train_state(train);
1619
+ ggml_free(lora.ctx);
1620
+ llama_free(lctx);
1621
+ llama_free_model(lmodel);
1622
+ return 0;
1623
+ }
1624
+
1625
+ printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
1626
+ printf("%s: opt iter %d\n", __func__, opt->iter);
1627
+
1628
+ int n_tokens = model.hparams.n_ctx;
1629
+ int n_vocab = model.hparams.n_vocab;
1630
+ int n_batch = params.common.n_batch;
1631
+
1632
+ // context for input tensors without their data
1633
+ struct ggml_init_params ctx_input_params = {
1634
+ ggml_tensor_overhead() * 2, // mem_size
1635
+ NULL, // mem_buffer
1636
+ true, // no_alloc
1637
+ };
1638
+ struct ggml_context * ctx_input = ggml_init(ctx_input_params);
1639
+
1640
+ // the input tensors
1641
+ struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch);
1642
+ struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
1643
+
1644
+ // allocate input tensors
1645
+ // measure required memory for input tensors
1646
+ ggml_backend_buffer_t input_data = ggml_backend_alloc_ctx_tensors_from_buft(ctx_input, ggml_backend_cpu_buffer_type());
1647
+ size_t max_input_size = ggml_backend_buffer_get_size(input_data);
1648
+ printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f));
1649
+
1650
+ // context for compute tensors without their data
1651
+ const size_t estimated_compute_size_wo_data = (
1652
+ 2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
1653
+ (params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
1654
+ );
1655
+ struct ggml_init_params ctx_compute_params = {
1656
+ estimated_compute_size_wo_data, // mem_size
1657
+ NULL, // mem_buffer
1658
+ true, // no_alloc
1659
+ };
1660
+ struct ggml_context * ctx_compute = NULL;
1661
+
1662
+ struct ggml_tensor * loss = NULL;
1663
+ struct ggml_tensor * logits = NULL;
1664
+
1665
+ struct ggml_cgraph * gf = NULL;
1666
+ struct ggml_cgraph * gb = NULL;
1667
+ struct ggml_cgraph * gb_tmp = NULL;
1668
+
1669
+ // measure required memory for compute tensors
1670
+ size_t best_compute_size = SIZE_MAX;
1671
+ enum ggml_cgraph_eval_order best_order = GGML_CGRAPH_EVAL_ORDER_COUNT;
1672
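+ // Both graph evaluation orders are tried: the forward/backward graphs are built
+ // in a no_alloc context, the gallocr reports the buffer size each order would
+ // need, and the cheaper order is kept for the real allocation below.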
+ // find best evaluation order
1673
+ for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
1674
+ ctx_compute = ggml_init(ctx_compute_params);
1675
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
1676
+ gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
1677
+ gf->order = (enum ggml_cgraph_eval_order) order;
1678
+ gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
1679
+ gb_tmp = params.common.use_checkpointing
1680
+ ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
1681
+ : NULL;
1682
+ loss = llama_build_lora_finetune_graphs(
1683
+ &model, &lora, alloc, ctx_compute,
1684
+ gf, gb, gb_tmp,
1685
+ &logits, tokens_input, target_probs,
1686
+ n_tokens, n_batch,
1687
+ params.common.use_flash,
1688
+ params.common.use_checkpointing,
1689
+ true
1690
+ );
1691
+ size_t max_compute_size = ggml_gallocr_get_buffer_size(alloc, 0); // FIXME: this will still allocate the buffer
1692
+ if (max_compute_size < best_compute_size) {
1693
+ best_compute_size = max_compute_size;
1694
+ best_order = gf->order;
1695
+ }
1696
+ ggml_gallocr_free(alloc);
1697
+ ggml_free(ctx_compute);
1698
+ }
1699
+ size_t max_compute_size = best_compute_size;
1700
+ printf("%s: compute_size = %zu bytes (%.1f MB)\n", __func__, max_compute_size, (float) max_compute_size / (1024.0f*1024.0f));
1701
+ printf("%s: evaluation order = %s\n", __func__,
1702
+ (best_order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? "LEFT_TO_RIGHT" :
1703
+ (best_order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? "RIGHT_TO_LEFT" :
1704
+ "invalid");
1705
+
1706
+ // allocate compute tensors
1707
+ ctx_compute = ggml_init(ctx_compute_params);
1708
+ ggml_gallocr_t alloc = ggml_gallocr_new(ggml_backend_cpu_buffer_type());
1709
+ gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
1710
+ gf->order = best_order;
1711
+ gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
1712
+ gb_tmp = params.common.use_checkpointing
1713
+ ? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
1714
+ : NULL;
1715
+ loss = llama_build_lora_finetune_graphs(
1716
+ &model, &lora, alloc, ctx_compute,
1717
+ gf, gb, gb_tmp,
1718
+ &logits, tokens_input, target_probs,
1719
+ n_tokens, n_batch,
1720
+ params.common.use_flash,
1721
+ params.common.use_checkpointing,
1722
+ false
1723
+ );
1724
+
1725
+ // tokenize data
1726
+ std::vector<llama_token> train_tokens;
1727
+ std::vector<size_t> train_samples_begin;
1728
+ std::vector<size_t> train_samples_size;
1729
+ printf("%s: tokenize training data from %s\n", __func__, params.common.fn_train_data);
1730
+ printf("%s: sample-start: %s\n", __func__, params.common.sample_start.c_str());
1731
+ printf("%s: include-sample-start: %s\n", __func__, params.common.include_sample_start ? "true" : "false");
1732
+ tokenize_file(lctx,
1733
+ params.common.fn_train_data,
1734
+ params.common.sample_start,
1735
+ params.common.include_sample_start,
1736
+ params.common.overlapping_samples,
1737
+ n_tokens,
1738
+ train_tokens,
1739
+ train_samples_begin,
1740
+ train_samples_size);
1741
+ GGML_ASSERT(train_samples_begin.size() == train_samples_size.size());
1742
+
1743
+ printf("%s: number of training tokens: %zu\n", __func__, train_tokens.size());
1744
+
1745
+ std::vector<size_t> token_noccurs;
1746
+ token_noccurs.resize(model.hparams.n_vocab, 0);
1747
+ for (unsigned int i = 0; i < train_tokens.size(); ++i) {
1748
+ ++token_noccurs[train_tokens[i]];
1749
+ }
1750
+ int n_unique_tokens = 0;
1751
+ for (unsigned int i = 0; i < token_noccurs.size(); ++i) {
1752
+ if (token_noccurs[i] == 0) continue;
1753
+ ++n_unique_tokens;
1754
+ }
1755
+ printf("%s: number of unique tokens: %d\n", __func__, n_unique_tokens);
1756
+
1757
+ size_t shuffle_samples_hash = compute_samples_hash(params.common.fn_train_data, train_samples_begin.data(), train_samples_size.data(), train_samples_size.size());
1758
+ const bool changed_train_data = (shuffle_samples_hash != train->shuffle_samples_hash) || (train->shuffle_sample_count != train_samples_size.size());
1759
+ if (changed_train_data) {
1760
+ printf("%s: train data seems to have changed. restarting shuffled epoch.\n", __func__);
1761
+ }
1762
+ if (params.common.force_reshuffle) {
1763
+ printf("%s: forced reshuffling of data. restarting with newly shuffled epoch.\n", __func__);
1764
+ }
1765
+ if ((train->shuffle_rng_state_current == "") || changed_train_data || params.common.force_reshuffle) {
1766
+ train->shuffle_rng_state_current = mt19937_seed_to_state(params.common.seed);
1767
+ train->shuffle_sample_count = train_samples_size.size();
1768
+ train->shuffle_next_sample = 0;
1769
+ train->shuffle_samples_hash = shuffle_samples_hash;
1770
+ }
1771
+ std::vector<size_t> train_shuffled_samples_offs;
1772
+ std::vector<size_t> train_shuffled_samples_begin;
1773
+ std::vector<size_t> train_shuffled_samples_size;
1774
+ train_shuffled_samples_offs.resize(train_samples_begin.size());
1775
+ train_shuffled_samples_begin.resize(train_samples_begin.size());
1776
+ train_shuffled_samples_size.resize(train_samples_size.size());
1777
+ train->shuffle_rng_state_next = shuffle_samples(
1778
+ train->shuffle_rng_state_current,
1779
+ train_shuffled_samples_offs.data(),
1780
+ train_shuffled_samples_begin.data(),
1781
+ train_shuffled_samples_size.data(),
1782
+ train_samples_begin.data(),
1783
+ train_samples_size.data(),
1784
+ train_samples_size.size());
1785
+
1786
+ printf("%s: begin training\n", __func__);
1787
+
1788
+ save_train_files_data save_data;
1789
+ save_data.fn_checkpoint_out = params.common.fn_checkpoint_out;
1790
+ save_data.fn_lora_out = params.fn_lora_out;
1791
+ save_data.pattern_fn_it = params.common.pattern_fn_it;
1792
+ save_data.fn_latest = params.common.fn_latest;
1793
+ save_data.model = &model;
1794
+ save_data.lora = &lora;
1795
+
1796
+ struct train_opt_callback_data opt_cb_data;
1797
+ opt_cb_data.params = &params.common;
1798
+ opt_cb_data.train = train;
1799
+ opt_cb_data.save_cb = &save_train_files;
1800
+ opt_cb_data.save_data = &save_data;
1801
+ opt_cb_data.lctx = lctx;
1802
+ opt_cb_data.last_save_iter = opt->iter;
1803
+ opt_cb_data.tokens_data = train_tokens.data();
1804
+ opt_cb_data.tokens_size = train_tokens.size();
1805
+ opt_cb_data.samples_begin = train_samples_begin.data();
1806
+ opt_cb_data.samples_size = train_samples_size.data();
1807
+ opt_cb_data.shuffled_samples_offs = train_shuffled_samples_offs.data();
1808
+ opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data();
1809
+ opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data();
1810
+ opt_cb_data.samples_count = train_samples_size.size();
1811
+ opt_cb_data.tokens_input = tokens_input;
1812
+ opt_cb_data.target_probs = target_probs;
1813
+ opt_cb_data.first_iter = opt->iter;
1814
+ opt_cb_data.first_epoch = train->train_epochs;
1815
+ opt_cb_data.iter_at_last_epoch = -1;
1816
+ opt_cb_data.last_time = ggml_time_ms();
1817
+ opt_cb_data.millis_per_iter = 0.0;
1818
+
1819
+ // measure required memory for work buffer
1820
+ size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
1821
+ printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
1822
+
1823
+ // context for work buffer
1824
+ struct ggml_init_params ctx_work_params = {
1825
+ max_work_size, // mem_size
1826
+ NULL, // mem_buffer
1827
+ false, // no_alloc
1828
+ };
1829
+ struct ggml_context * ctx_work = ggml_init(ctx_work_params);
1830
+
1831
+ int64_t t0 = ggml_time_ms();
1832
+
1833
+ ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
1834
+
1835
+ ggml_free(ctx_work);
1836
+ ggml_free(ctx_compute);
1837
+ ggml_free(ctx_input);
1838
+ ggml_gallocr_free(alloc);
1839
+
1840
+
1841
+ int64_t t1 = ggml_time_ms();
1842
+ printf("%s: total training time: ", __func__);
1843
+ print_duration((double) (t1 - t0));
1844
+ printf("\n");
1845
+
1846
+ int new_iters = opt->iter - opt_cb_data.last_save_iter;
1847
+ if (new_iters > 0) {
1848
+ train->train_its += new_iters;
1849
+ train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
1850
+
1851
+ save_train_files(&save_data, train);
1852
+ opt_cb_data.last_save_iter = opt->iter;
1853
+ }
1854
+
1855
+ ggml_free(opt->ctx);
1856
+ free_train_state(train);
1857
+ ggml_free(lora.ctx);
1858
+ llama_free(lctx);
1859
+ llama_free_model(lmodel);
1860
+ return 0;
1861
+ }