@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190):
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -21,12 +21,6 @@
21
21
  #include "ggml.h"
22
22
  #include "llama.h"
23
23
  #include "common.h"
24
- #include "ggml-cuda.h"
25
- #include "ggml-sycl.h"
26
-
27
- #ifdef GGML_USE_CANN
28
- #include "ggml-cann.h"
29
- #endif
30
24
 
31
25
  #ifdef _WIN32
32
26
  #define WIN32_LEAN_AND_MEAN
@@ -82,95 +76,27 @@ static T stdev(const std::vector<T> & v) {
82
76
  }
83
77
 
84
78
  static std::string get_cpu_info() {
85
- std::string id;
86
- #ifdef __linux__
87
- FILE * f = fopen("/proc/cpuinfo", "r");
88
- if (f) {
89
- char buf[1024];
90
- while (fgets(buf, sizeof(buf), f)) {
91
- if (strncmp(buf, "model name", 10) == 0) {
92
- char * p = strchr(buf, ':');
93
- if (p) {
94
- p++;
95
- while (std::isspace(*p)) {
96
- p++;
97
- }
98
- while (std::isspace(p[strlen(p) - 1])) {
99
- p[strlen(p) - 1] = '\0';
100
- }
101
- id = p;
102
- break;
103
- }
104
- }
79
+ std::vector<std::string> cpu_list;
80
+ for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
81
+ auto * dev = ggml_backend_dev_get(i);
82
+ auto dev_type = ggml_backend_dev_type(dev);
83
+ if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU || dev_type == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
84
+ cpu_list.push_back(ggml_backend_dev_description(dev));
105
85
  }
106
- fclose(f);
107
86
  }
108
- #elif defined(_WIN32)
109
- HKEY hKey;
110
- if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,
111
- TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"),
112
- 0,
113
- KEY_READ,
114
- &hKey) != ERROR_SUCCESS) {
115
- // fail to open registry key
116
- return "";
117
- }
118
- char cpu_brand[256];
119
- DWORD cpu_brand_size = sizeof(cpu_brand);
120
- if (RegQueryValueExA(hKey,
121
- TEXT("ProcessorNameString"),
122
- NULL,
123
- NULL,
124
- (LPBYTE)cpu_brand,
125
- &cpu_brand_size) == ERROR_SUCCESS) {
126
- id.assign(cpu_brand, cpu_brand_size);
127
- if (id.find('\0') != std::string::npos) {
128
- id.resize(id.find('\0'));
129
- }
130
- }
131
- RegCloseKey(hKey);
132
- #endif
133
- // TODO: other platforms
134
- return id;
87
+ return join(cpu_list, ", ");
135
88
  }
136
89
 
137
90
  static std::string get_gpu_info() {
138
- std::string id;
139
- #ifdef GGML_USE_CUDA
140
- int count = ggml_backend_cuda_get_device_count();
141
- for (int i = 0; i < count; i++) {
142
- char buf[128];
143
- ggml_backend_cuda_get_device_description(i, buf, sizeof(buf));
144
- id += buf;
145
- if (i < count - 1) {
146
- id += "/";
147
- }
148
- }
149
- #endif
150
- #ifdef GGML_USE_SYCL
151
- int count = ggml_backend_sycl_get_device_count();
152
- for (int i = 0; i < count; i++) {
153
- char buf[128];
154
- ggml_sycl_get_device_description(i, buf, sizeof(buf));
155
- id += buf;
156
- if (i < count - 1) {
157
- id += "/";
158
- }
159
- }
160
- #endif
161
- #ifdef GGML_USE_CANN
162
- uint32_t count = ggml_backend_cann_get_device_count();
163
- for (uint32_t i = 0; i < count; i++) {
164
- char buf[128];
165
- ggml_backend_cann_get_device_description(i, buf, sizeof(buf));
166
- id += buf;
167
- if (i < count - 1) {
168
- id += "/";
91
+ std::vector<std::string> gpu_list;
92
+ for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
93
+ auto * dev = ggml_backend_dev_get(i);
94
+ auto dev_type = ggml_backend_dev_type(dev);
95
+ if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU) {
96
+ gpu_list.push_back(ggml_backend_dev_description(dev));
169
97
  }
170
98
  }
171
- #endif
172
- // TODO: other backends
173
- return id;
99
+ return join(gpu_list, ", ");
174
100
  }
175
101
 
176
102
  // command line params
@@ -304,9 +230,9 @@ static void print_usage(int /* argc */, char ** argv) {
304
230
  printf(" --cpu-strict <0|1> (default: %s)\n", join(cmd_params_defaults.cpu_strict, ",").c_str());
305
231
  printf(" --poll <0...100> (default: %s)\n", join(cmd_params_defaults.poll, ",").c_str());
306
232
  printf(" -ngl, --n-gpu-layers <n> (default: %s)\n", join(cmd_params_defaults.n_gpu_layers, ",").c_str());
307
- #ifdef GGML_USE_RPC
308
- printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
309
- #endif
233
+ if (llama_supports_rpc()) {
234
+ printf(" -rpc, --rpc <rpc_servers> (default: %s)\n", join(cmd_params_defaults.rpc_servers, ",").c_str());
235
+ }
310
236
  printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
311
237
  printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
312
238
  printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
@@ -330,6 +256,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
330
256
  if (s == "f16") {
331
257
  return GGML_TYPE_F16;
332
258
  }
259
+ if (s == "bf16") {
260
+ return GGML_TYPE_BF16;
261
+ }
333
262
  if (s == "q8_0") {
334
263
  return GGML_TYPE_Q8_0;
335
264
  }
@@ -497,14 +426,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
497
426
  }
498
427
  auto p = string_split<int>(argv[i], split_delim);
499
428
  params.n_gpu_layers.insert(params.n_gpu_layers.end(), p.begin(), p.end());
500
- #ifdef GGML_USE_RPC
501
- } else if (arg == "-rpc" || arg == "--rpc") {
429
+ } else if (llama_supports_rpc() && (arg == "-rpc" || arg == "--rpc")) {
502
430
  if (++i >= argc) {
503
431
  invalid_param = true;
504
432
  break;
505
433
  }
506
434
  params.rpc_servers.push_back(argv[i]);
507
- #endif
508
435
  } else if (arg == "-sm" || arg == "--split-mode") {
509
436
  if (++i >= argc) {
510
437
  invalid_param = true;
@@ -847,13 +774,6 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
847
774
  struct test {
848
775
  static const std::string build_commit;
849
776
  static const int build_number;
850
- static const bool cuda;
851
- static const bool vulkan;
852
- static const bool kompute;
853
- static const bool metal;
854
- static const bool sycl;
855
- static const bool gpu_blas;
856
- static const bool blas;
857
777
  static const std::string cpu_info;
858
778
  static const std::string gpu_info;
859
779
  std::string model_filename;
@@ -866,7 +786,6 @@ struct test {
866
786
  std::string cpu_mask;
867
787
  bool cpu_strict;
868
788
  int poll;
869
- bool has_rpc;
870
789
  ggml_type type_k;
871
790
  ggml_type type_v;
872
791
  int n_gpu_layers;
@@ -895,7 +814,6 @@ struct test {
895
814
  cpu_mask = inst.cpu_mask;
896
815
  cpu_strict = inst.cpu_strict;
897
816
  poll = inst.poll;
898
- has_rpc = !inst.rpc_servers.empty();
899
817
  type_k = inst.type_k;
900
818
  type_v = inst.type_v;
901
819
  n_gpu_layers = inst.n_gpu_layers;
@@ -940,36 +858,21 @@ struct test {
940
858
  }
941
859
 
942
860
  static std::string get_backend() {
943
- if (cuda) {
944
- return GGML_CUDA_NAME;
945
- }
946
- if (vulkan) {
947
- return "Vulkan";
948
- }
949
- if (kompute) {
950
- return "Kompute";
951
- }
952
- if (metal) {
953
- return "Metal";
954
- }
955
- if (sycl) {
956
- return GGML_SYCL_NAME;
957
- }
958
- if (gpu_blas) {
959
- return "GPU BLAS";
960
- }
961
- if (blas) {
962
- return "BLAS";
861
+ std::vector<std::string> backends;
862
+ for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
863
+ auto * reg = ggml_backend_reg_get(i);
864
+ std::string name = ggml_backend_reg_name(reg);
865
+ if (name != "CPU") {
866
+ backends.push_back(ggml_backend_reg_name(reg));
867
+ }
963
868
  }
964
-
965
- return "CPU";
869
+ return backends.empty() ? "CPU" : join(backends, ",");
966
870
  }
967
871
 
968
872
  static const std::vector<std::string> & get_fields() {
969
873
  static const std::vector<std::string> fields = {
970
874
  "build_commit", "build_number",
971
- "cuda", "vulkan", "kompute", "metal", "sycl", "rpc", "gpu_blas", "blas",
972
- "cpu_info", "gpu_info",
875
+ "cpu_info", "gpu_info", "backends",
973
876
  "model_filename", "model_type", "model_size", "model_n_params",
974
877
  "n_batch", "n_ubatch",
975
878
  "n_threads", "cpu_mask", "cpu_strict", "poll",
@@ -995,8 +898,7 @@ struct test {
995
898
  field == "avg_ns" || field == "stddev_ns") {
996
899
  return INT;
997
900
  }
998
- if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
999
- field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
901
+ if (field == "f16_kv" || field == "no_kv_offload" ||
1000
902
  field == "cpu_strict" ||
1001
903
  field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
1002
904
  return BOOL;
@@ -1025,9 +927,7 @@ struct test {
1025
927
  }
1026
928
  std::vector<std::string> values = {
1027
929
  build_commit, std::to_string(build_number),
1028
- std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan),
1029
- std::to_string(metal), std::to_string(sycl), std::to_string(has_rpc), std::to_string(gpu_blas), std::to_string(blas),
1030
- cpu_info, gpu_info,
930
+ cpu_info, gpu_info, get_backend(),
1031
931
  model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
1032
932
  std::to_string(n_batch), std::to_string(n_ubatch),
1033
933
  std::to_string(n_threads), cpu_mask, std::to_string(cpu_strict), std::to_string(poll),
@@ -1054,13 +954,6 @@ struct test {
1054
954
 
1055
955
  const std::string test::build_commit = LLAMA_COMMIT;
1056
956
  const int test::build_number = LLAMA_BUILD_NUMBER;
1057
- const bool test::cuda = !!ggml_cpu_has_cuda();
1058
- const bool test::vulkan = !!ggml_cpu_has_vulkan();
1059
- const bool test::kompute = !!ggml_cpu_has_kompute();
1060
- const bool test::metal = !!ggml_cpu_has_metal();
1061
- const bool test::gpu_blas = !!ggml_cpu_has_gpublas();
1062
- const bool test::blas = !!ggml_cpu_has_blas();
1063
- const bool test::sycl = !!ggml_cpu_has_sycl();
1064
957
  const std::string test::cpu_info = get_cpu_info();
1065
958
  const std::string test::gpu_info = get_gpu_info();
1066
959
 
@@ -1265,7 +1158,8 @@ struct markdown_printer : public printer {
1265
1158
  fields.emplace_back("size");
1266
1159
  fields.emplace_back("params");
1267
1160
  fields.emplace_back("backend");
1268
- bool is_cpu_backend = test::get_backend() == "CPU" || test::get_backend() == "BLAS";
1161
+ bool is_cpu_backend = test::get_backend().find("CPU") != std::string::npos ||
1162
+ test::get_backend().find("BLAS") != std::string::npos;
1269
1163
  if (!is_cpu_backend) {
1270
1164
  fields.emplace_back("n_gpu_layers");
1271
1165
  }
@@ -1355,9 +1249,6 @@ struct markdown_printer : public printer {
1355
1249
  value = buf;
1356
1250
  } else if (field == "backend") {
1357
1251
  value = test::get_backend();
1358
- if (t.has_rpc) {
1359
- value += "+RPC";
1360
- }
1361
1252
  } else if (field == "test") {
1362
1253
  if (t.n_prompt > 0 && t.n_gen == 0) {
1363
1254
  snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
@@ -1430,7 +1321,7 @@ struct sql_printer : public printer {
1430
1321
  }
1431
1322
  };
1432
1323
 
1433
- static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
1324
+ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
1434
1325
  llama_set_n_threads(ctx, n_threads, n_threads);
1435
1326
 
1436
1327
  const llama_model * model = llama_get_model(ctx);
@@ -1446,14 +1337,14 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
1446
1337
  for (int i = 1; i < n_tokens; i++) {
1447
1338
  tokens[i] = std::rand() % n_vocab;
1448
1339
  }
1449
- llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
1340
+ llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
1450
1341
  n_processed += n_tokens;
1451
1342
  }
1452
1343
 
1453
1344
  llama_synchronize(ctx);
1454
1345
  }
1455
1346
 
1456
- static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
1347
+ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
1457
1348
  llama_set_n_threads(ctx, n_threads, n_threads);
1458
1349
 
1459
1350
  const llama_model * model = llama_get_model(ctx);
@@ -1462,7 +1353,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
1462
1353
  llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
1463
1354
 
1464
1355
  for (int i = 0; i < n_gen; i++) {
1465
- llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
1356
+ llama_decode(ctx, llama_batch_get_one(&token, 1));
1466
1357
  llama_synchronize(ctx);
1467
1358
  token = std::rand() % n_vocab;
1468
1359
  }
@@ -1598,13 +1489,13 @@ int main(int argc, char ** argv) {
1598
1489
  fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
1599
1490
  }
1600
1491
  //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
1601
- test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
1492
+ test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
1602
1493
  }
1603
1494
  if (t.n_gen > 0) {
1604
1495
  if (params.progress) {
1605
1496
  fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
1606
1497
  }
1607
- test_gen(ctx, 1, 0, t.n_threads);
1498
+ test_gen(ctx, 1, t.n_threads);
1608
1499
  }
1609
1500
 
1610
1501
  for (int i = 0; i < params.reps; i++) {
@@ -1616,13 +1507,13 @@ int main(int argc, char ** argv) {
1616
1507
  if (params.progress) {
1617
1508
  fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
1618
1509
  }
1619
- test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
1510
+ test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
1620
1511
  }
1621
1512
  if (t.n_gen > 0) {
1622
1513
  if (params.progress) {
1623
1514
  fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
1624
1515
  }
1625
- test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
1516
+ test_gen(ctx, t.n_gen, t.n_threads);
1626
1517
  }
1627
1518
 
1628
1519
  uint64_t t_ns = get_time_ns() - t_start;
@@ -18,6 +18,7 @@ android {
18
18
  }
19
19
  externalNativeBuild {
20
20
  cmake {
21
+ arguments += "-DLLAMA_BUILD_COMMON=ON"
21
22
  arguments += "-DCMAKE_BUILD_TYPE=Release"
22
23
  cppFlags += listOf()
23
24
  arguments += listOf()
@@ -186,11 +186,11 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
186
186
  for (nri = 0; nri < nr; nri++) {
187
187
  LOGi("Benchmark prompt processing (pp)");
188
188
 
189
- llama_batch_clear(*batch);
189
+ common_batch_clear(*batch);
190
190
 
191
191
  const int n_tokens = pp;
192
192
  for (i = 0; i < n_tokens; i++) {
193
- llama_batch_add(*batch, 0, i, { 0 }, false);
193
+ common_batch_add(*batch, 0, i, { 0 }, false);
194
194
  }
195
195
 
196
196
  batch->logits[batch->n_tokens - 1] = true;
@@ -210,9 +210,9 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
210
210
  const auto t_tg_start = ggml_time_us();
211
211
  for (i = 0; i < tg; i++) {
212
212
 
213
- llama_batch_clear(*batch);
213
+ common_batch_clear(*batch);
214
214
  for (j = 0; j < pl; j++) {
215
- llama_batch_add(*batch, 0, i, { j }, true);
215
+ common_batch_add(*batch, 0, i, { j }, true);
216
216
  }
217
217
 
218
218
  LOGi("llama_decode() text generation: %d", i);
@@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
283
283
  nullptr,
284
284
  nullptr,
285
285
  nullptr,
286
- 0,
287
- 0,
288
- 0,
289
286
  };
290
287
 
291
288
  if (embd) {
@@ -357,7 +354,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
357
354
  const auto context = reinterpret_cast<llama_context *>(context_pointer);
358
355
  const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
359
356
 
360
- const auto tokens_list = llama_tokenize(context, text, 1);
357
+ const auto tokens_list = common_tokenize(context, text, 1);
361
358
 
362
359
  auto n_ctx = llama_n_ctx(context);
363
360
  auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
@@ -369,14 +366,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
369
366
  }
370
367
 
371
368
  for (auto id : tokens_list) {
372
- LOGi("%s", llama_token_to_piece(context, id).c_str());
369
+ LOGi("%s", common_token_to_piece(context, id).c_str());
373
370
  }
374
371
 
375
- llama_batch_clear(*batch);
372
+ common_batch_clear(*batch);
376
373
 
377
374
  // evaluate the initial prompt
378
375
  for (auto i = 0; i < tokens_list.size(); i++) {
379
- llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
376
+ common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
380
377
  }
381
378
 
382
379
  // llama_decode will output logits only for the last token of the prompt
@@ -419,7 +416,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
419
416
  return nullptr;
420
417
  }
421
418
 
422
- auto new_token_chars = llama_token_to_piece(context, new_token_id);
419
+ auto new_token_chars = common_token_to_piece(context, new_token_id);
423
420
  cached_token_chars += new_token_chars;
424
421
 
425
422
  jstring new_token = nullptr;
@@ -431,8 +428,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
431
428
  new_token = env->NewStringUTF("");
432
429
  }
433
430
 
434
- llama_batch_clear(*batch);
435
- llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
431
+ common_batch_clear(*batch);
432
+ common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
436
433
 
437
434
  env->CallVoidMethod(intvar_ncur, la_int_var_inc);
438
435
 
@@ -4,6 +4,7 @@
4
4
  // Note: Even when using identical normalized image inputs (see normalize_image_u8_to_f32()) we have a significant difference in resulting embeddings compared to pytorch
5
5
  #include "clip.h"
6
6
  #include "ggml.h"
7
+ #include "ggml-cpu.h"
7
8
  #include "ggml-alloc.h"
8
9
  #include "ggml-backend.h"
9
10
 
@@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
20
20
  if (n_eval > n_batch) {
21
21
  n_eval = n_batch;
22
22
  }
23
- if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
23
+ if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
24
24
  LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
25
25
  return false;
26
26
  }
@@ -37,21 +37,21 @@ static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {
37
37
 
38
38
  static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
39
39
  std::string str2 = str;
40
- std::vector<llama_token> embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos, true);
40
+ std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
41
41
  eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
42
42
  return true;
43
43
  }
44
44
 
45
- static const char * sample(struct gpt_sampler * smpl,
45
+ static const char * sample(struct common_sampler * smpl,
46
46
  struct llama_context * ctx_llama,
47
47
  int * n_past) {
48
- const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
49
- gpt_sampler_accept(smpl, id, true);
48
+ const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
49
+ common_sampler_accept(smpl, id, true);
50
50
  static std::string ret;
51
51
  if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
52
52
  ret = "</s>";
53
53
  } else {
54
- ret = llama_token_to_piece(ctx_llama, id);
54
+ ret = common_token_to_piece(ctx_llama, id);
55
55
  }
56
56
  eval_id(ctx_llama, id, n_past);
57
57
  return ret.c_str();
@@ -120,7 +120,7 @@ static void print_usage(int, char ** argv) {
120
120
  LOG("\n note: a lower temperature value like 0.1 is recommended for better quality.\n");
121
121
  }
122
122
 
123
- static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, const std::string & fname) {
123
+ static struct llava_image_embed * load_image(llava_context * ctx_llava, common_params * params, const std::string & fname) {
124
124
 
125
125
  // load and preprocess the image
126
126
  llava_image_embed * embed = NULL;
@@ -146,7 +146,7 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
146
146
  return embed;
147
147
  }
148
148
 
149
- static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
149
+ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, common_params * params, const std::string & prompt) {
150
150
  int n_past = 0;
151
151
 
152
152
  const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
@@ -159,16 +159,16 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
159
159
  user_prompt = prompt.substr(image_pos + std::string("<image>").length());
160
160
  LOG_INF("system_prompt: %s\n", system_prompt.c_str());
161
161
  if (params->verbose_prompt) {
162
- auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
162
+ auto tmp = common_tokenize(ctx_llava->ctx_llama, system_prompt, true, true);
163
163
  for (int i = 0; i < (int) tmp.size(); i++) {
164
- LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
164
+ LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
165
165
  }
166
166
  }
167
167
  LOG_INF("user_prompt: %s\n", user_prompt.c_str());
168
168
  if (params->verbose_prompt) {
169
- auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
169
+ auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
170
170
  for (int i = 0; i < (int) tmp.size(); i++) {
171
- LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
171
+ LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
172
172
  }
173
173
  }
174
174
  } else {
@@ -176,9 +176,9 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
176
176
  system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:";
177
177
  user_prompt = prompt + "\nASSISTANT:";
178
178
  if (params->verbose_prompt) {
179
- auto tmp = ::llama_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
179
+ auto tmp = common_tokenize(ctx_llava->ctx_llama, user_prompt, true, true);
180
180
  for (int i = 0; i < (int) tmp.size(); i++) {
181
- LOG_INF("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
181
+ LOG_INF("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_llava->ctx_llama, tmp[i]).c_str());
182
182
  }
183
183
  }
184
184
  }
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
191
191
 
192
192
  LOG("\n");
193
193
 
194
- struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
194
+ struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
195
195
  if (!smpl) {
196
196
  LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
197
197
  exit(1);
@@ -211,15 +211,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
211
211
  fflush(stdout);
212
212
  }
213
213
 
214
- gpt_sampler_free(smpl);
214
+ common_sampler_free(smpl);
215
215
  LOG("\n");
216
216
  }
217
217
 
218
- static struct llama_model * llava_init(gpt_params * params) {
218
+ static struct llama_model * llava_init(common_params * params) {
219
219
  llama_backend_init();
220
220
  llama_numa_init(params->numa);
221
221
 
222
- llama_model_params model_params = llama_model_params_from_gpt_params(*params);
222
+ llama_model_params model_params = common_model_params_to_llama(*params);
223
223
 
224
224
  llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
225
225
  if (model == NULL) {
@@ -229,7 +229,7 @@ static struct llama_model * llava_init(gpt_params * params) {
229
229
  return model;
230
230
  }
231
231
 
232
- static struct llava_context * llava_init_context(gpt_params * params, llama_model * model) {
232
+ static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
233
233
  const char * clip_path = params->mmproj.c_str();
234
234
 
235
235
  auto prompt = params->prompt;
@@ -240,7 +240,7 @@ static struct llava_context * llava_init_context(gpt_params * params, llama_mode
240
240
  auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
241
241
 
242
242
 
243
- llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
243
+ llama_context_params ctx_params = common_context_params_to_llama(*params);
244
244
  ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
245
245
 
246
246
  llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
@@ -272,13 +272,13 @@ static void llava_free(struct llava_context * ctx_llava) {
272
272
  int main(int argc, char ** argv) {
273
273
  ggml_time_init();
274
274
 
275
- gpt_params params;
275
+ common_params params;
276
276
 
277
- if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
277
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
278
278
  return 1;
279
279
  }
280
280
 
281
- gpt_init();
281
+ common_init();
282
282
 
283
283
  if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) {
284
284
  print_usage(argc, argv);
@@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
401
401
  return true;
402
402
  }
403
403
 
404
+ struct llava_embd_batch {
405
+ std::vector<llama_pos> pos;
406
+ std::vector<int32_t> n_seq_id;
407
+ std::vector<llama_seq_id> seq_id_0;
408
+ std::vector<llama_seq_id *> seq_ids;
409
+ std::vector<int8_t> logits;
410
+ llama_batch batch;
411
+ llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
412
+ pos .resize(n_tokens);
413
+ n_seq_id.resize(n_tokens);
414
+ seq_ids .resize(n_tokens + 1);
415
+ logits .resize(n_tokens);
416
+ seq_id_0.resize(1);
417
+ seq_id_0[0] = seq_id;
418
+ seq_ids [n_tokens] = nullptr;
419
+ batch = {
420
+ /*n_tokens =*/ n_tokens,
421
+ /*tokens =*/ nullptr,
422
+ /*embd =*/ embd,
423
+ /*pos =*/ pos.data(),
424
+ /*n_seq_id =*/ n_seq_id.data(),
425
+ /*seq_id =*/ seq_ids.data(),
426
+ /*logits =*/ logits.data(),
427
+ };
428
+ for (int i = 0; i < n_tokens; i++) {
429
+ batch.pos [i] = pos_0 + i;
430
+ batch.n_seq_id[i] = 1;
431
+ batch.seq_id [i] = seq_id_0.data();
432
+ batch.logits [i] = false;
433
+ }
434
+ }
435
+ };
436
+
404
437
  bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
405
438
  int n_embd = llama_n_embd(llama_get_model(ctx_llama));
406
439
 
@@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
409
442
  if (n_eval > n_batch) {
410
443
  n_eval = n_batch;
411
444
  }
412
- llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
413
- if (llama_decode(ctx_llama, batch)) {
445
+ float * embd = image_embed->embed+i*n_embd;
446
+ llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
447
+ if (llama_decode(ctx_llama, llava_batch.batch)) {
414
448
  LOG_ERR("%s : failed to eval\n", __func__);
415
449
  return false;
416
450
  }
@@ -432,7 +466,7 @@ struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * c
432
466
  bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);
433
467
  if (!image_embed_result) {
434
468
  clip_image_u8_free(img);
435
- LOG_ERR("%s: coulnd't embed the image\n", __func__);
469
+ LOG_ERR("%s: couldn't embed the image\n", __func__);
436
470
  return NULL;
437
471
  }
438
472