@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
@@ -26,6 +26,9 @@ void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_REST
26
26
  void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
27
27
  void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
28
28
 
29
+ void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k);
30
+ void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k);
31
+
29
32
  void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
30
33
  void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
31
34
  void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
@@ -46,6 +49,9 @@ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
46
49
  void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
47
50
  void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
48
51
 
52
+ void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
53
+ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
54
+
49
55
  void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
50
56
  void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
51
57
  void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -67,6 +73,9 @@ void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRI
67
73
  void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
68
74
  void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
69
75
 
76
+ void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
77
+ void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
78
+
70
79
  void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
71
80
  void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
72
81
  void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +99,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
90
99
  void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
91
100
  void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
92
101
 
102
+ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
103
+ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
104
+
93
105
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
94
106
  void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
95
107
  void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -111,6 +123,9 @@ size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds
111
123
  size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
112
124
  size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
113
125
 
126
+ size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
127
+ size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
128
+
114
129
  size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
115
130
  size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
116
131
  size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@@ -1,5 +1,5 @@
1
1
  #include "ggml-rpc.h"
2
- #include "ggml.h"
2
+ #include "ggml-impl.h"
3
3
  #include "ggml-backend-impl.h"
4
4
 
5
5
  #include <cinttypes>
@@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
82
82
 
83
83
  // RPC commands
84
84
  enum rpc_cmd {
85
- ALLOC_BUFFER = 0,
86
- GET_ALIGNMENT,
87
- GET_MAX_SIZE,
88
- BUFFER_GET_BASE,
89
- FREE_BUFFER,
90
- BUFFER_CLEAR,
91
- SET_TENSOR,
92
- GET_TENSOR,
93
- COPY_TENSOR,
94
- GRAPH_COMPUTE,
95
- GET_DEVICE_MEMORY,
85
+ RPC_CMD_ALLOC_BUFFER = 0,
86
+ RPC_CMD_GET_ALIGNMENT,
87
+ RPC_CMD_GET_MAX_SIZE,
88
+ RPC_CMD_BUFFER_GET_BASE,
89
+ RPC_CMD_FREE_BUFFER,
90
+ RPC_CMD_BUFFER_CLEAR,
91
+ RPC_CMD_SET_TENSOR,
92
+ RPC_CMD_GET_TENSOR,
93
+ RPC_CMD_COPY_TENSOR,
94
+ RPC_CMD_GRAPH_COMPUTE,
95
+ RPC_CMD_GET_DEVICE_MEMORY,
96
+ RPC_CMD_COUNT,
96
97
  };
97
98
 
98
99
  // RPC data structures
@@ -197,6 +198,10 @@ static std::shared_ptr<socket_t> create_server_socket(const char * host, int por
197
198
  fprintf(stderr, "Failed to set SO_REUSEADDR\n");
198
199
  return nullptr;
199
200
  }
201
+ if (inet_addr(host) == INADDR_NONE) {
202
+ fprintf(stderr, "Invalid host address: %s\n", host);
203
+ return nullptr;
204
+ }
200
205
  struct sockaddr_in serv_addr;
201
206
  serv_addr.sin_family = AF_INET;
202
207
  serv_addr.sin_addr.s_addr = inet_addr(host);
@@ -314,25 +319,25 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
314
319
  return sock;
315
320
  }
316
321
 
317
- GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
322
+ static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
318
323
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
319
324
  return ctx->name.c_str();
320
325
  }
321
326
 
322
- GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
327
+ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
323
328
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
324
329
  // input serialization format: | remote_ptr (8 bytes) |
325
330
  std::vector<uint8_t> input(sizeof(uint64_t), 0);
326
331
  uint64_t remote_ptr = ctx->remote_ptr;
327
332
  memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
328
333
  std::vector<uint8_t> output;
329
- bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
334
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
330
335
  GGML_ASSERT(status);
331
336
  GGML_ASSERT(output.empty());
332
337
  delete ctx;
333
338
  }
334
339
 
335
- GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
340
+ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
336
341
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
337
342
  if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
338
343
  return ctx->base_cache[buffer];
@@ -342,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
342
347
  uint64_t remote_ptr = ctx->remote_ptr;
343
348
  memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
344
349
  std::vector<uint8_t> output;
345
- bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
350
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
346
351
  GGML_ASSERT(status);
347
352
  GGML_ASSERT(output.size() == sizeof(uint64_t));
348
353
  // output serialization format: | base_ptr (8 bytes) |
@@ -383,7 +388,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
383
388
  return result;
384
389
  }
385
390
 
386
- GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
391
+ static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
387
392
  UNUSED(buffer);
388
393
  if (ggml_is_quantized(tensor->type)) {
389
394
  // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
@@ -391,7 +396,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t
391
396
  }
392
397
  }
393
398
 
394
- GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
399
+ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
395
400
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
396
401
  // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
397
402
  size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
@@ -401,11 +406,11 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
401
406
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
402
407
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
403
408
  std::vector<uint8_t> output;
404
- bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
409
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
405
410
  GGML_ASSERT(status);
406
411
  }
407
412
 
408
- GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
413
+ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
409
414
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
410
415
  // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
411
416
  int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
@@ -415,14 +420,14 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
415
420
  memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
416
421
  memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
417
422
  std::vector<uint8_t> output;
418
- bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
423
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
419
424
  GGML_ASSERT(status);
420
425
  GGML_ASSERT(output.size() == size);
421
426
  // output serialization format: | data (size bytes) |
422
427
  memcpy(data, output.data(), size);
423
428
  }
424
429
 
425
- GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
430
+ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
426
431
  // check if src and dst are on the same server
427
432
  ggml_backend_buffer_t src_buffer = src->buffer;
428
433
  ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
@@ -440,14 +445,14 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
440
445
  memcpy(input.data(), &rpc_src, sizeof(rpc_src));
441
446
  memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
442
447
  std::vector<uint8_t> output;
443
- bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
448
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
444
449
  GGML_ASSERT(status);
445
450
  // output serialization format: | result (1 byte) |
446
451
  GGML_ASSERT(output.size() == 1);
447
452
  return output[0];
448
453
  }
449
454
 
450
- GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
455
+ static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
451
456
  ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
452
457
  // serialization format: | bufptr (8 bytes) | value (1 byte) |
453
458
  int input_size = sizeof(uint64_t) + sizeof(uint8_t);
@@ -455,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
455
460
  memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
456
461
  memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
457
462
  std::vector<uint8_t> output;
458
- bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
463
+ bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
459
464
  GGML_ASSERT(status);
460
465
  }
461
466
 
@@ -464,6 +469,7 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
464
469
  /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
465
470
  /* .get_base = */ ggml_backend_rpc_buffer_get_base,
466
471
  /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
472
+ /* .memset_tensor = */ NULL,
467
473
  /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
468
474
  /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
469
475
  /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
@@ -471,12 +477,12 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
471
477
  /* .reset = */ NULL,
472
478
  };
473
479
 
474
- GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
480
+ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
475
481
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
476
482
  return buft_ctx->name.c_str();
477
483
  }
478
484
 
479
- GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
485
+ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
480
486
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
481
487
  // input serialization format: | size (8 bytes) |
482
488
  int input_size = sizeof(uint64_t);
@@ -484,7 +490,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
484
490
  memcpy(input.data(), &size, sizeof(size));
485
491
  std::vector<uint8_t> output;
486
492
  auto sock = get_socket(buft_ctx->endpoint);
487
- bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
493
+ bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
488
494
  GGML_ASSERT(status);
489
495
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
490
496
  // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -507,7 +513,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
507
513
  // input serialization format: | 0 bytes |
508
514
  std::vector<uint8_t> input;
509
515
  std::vector<uint8_t> output;
510
- bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
516
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
511
517
  GGML_ASSERT(status);
512
518
  GGML_ASSERT(output.size() == sizeof(uint64_t));
513
519
  // output serialization format: | alignment (8 bytes) |
@@ -516,7 +522,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
516
522
  return alignment;
517
523
  }
518
524
 
519
- GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
525
+ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
520
526
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
521
527
  return buft_ctx->alignment;
522
528
  }
@@ -525,7 +531,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
525
531
  // input serialization format: | 0 bytes |
526
532
  std::vector<uint8_t> input;
527
533
  std::vector<uint8_t> output;
528
- bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
534
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
529
535
  GGML_ASSERT(status);
530
536
  GGML_ASSERT(output.size() == sizeof(uint64_t));
531
537
  // output serialization format: | max_size (8 bytes) |
@@ -534,12 +540,12 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
534
540
  return max_size;
535
541
  }
536
542
 
537
- GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
543
+ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
538
544
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
539
545
  return buft_ctx->max_size;
540
546
  }
541
547
 
542
- GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
548
+ static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
543
549
  UNUSED(buft);
544
550
  return ggml_nbytes(tensor);
545
551
  }
@@ -553,24 +559,24 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
553
559
  /* .is_host = */ NULL,
554
560
  };
555
561
 
556
- GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
562
+ static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
557
563
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
558
564
 
559
565
  return rpc_ctx->name.c_str();
560
566
  }
561
567
 
562
- GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
568
+ static void ggml_backend_rpc_free(ggml_backend_t backend) {
563
569
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
564
570
  delete rpc_ctx;
565
571
  delete backend;
566
572
  }
567
573
 
568
- GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
574
+ static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
569
575
  ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
570
576
  return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
571
577
  }
572
578
 
573
- GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
579
+ static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
574
580
  UNUSED(backend);
575
581
  // this is no-op because we don't have any async operations
576
582
  }
@@ -612,27 +618,27 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
612
618
  memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
613
619
  }
614
620
 
615
- GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
621
+ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
616
622
  ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
617
623
  std::vector<uint8_t> input;
618
624
  serialize_graph(cgraph, input);
619
625
  std::vector<uint8_t> output;
620
626
  auto sock = get_socket(rpc_ctx->endpoint);
621
- bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
627
+ bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
622
628
  GGML_ASSERT(status);
623
629
  GGML_ASSERT(output.size() == 1);
624
630
  return (enum ggml_status)output[0];
625
631
  }
626
632
 
627
- GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
633
+ static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
628
634
  UNUSED(backend);
629
635
  UNUSED(op);
630
636
  //TODO: call the remote backend and cache the results
631
637
  return true;
632
638
  }
633
639
 
634
- GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
635
- if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
640
+ static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
641
+ if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
636
642
  return false;
637
643
  }
638
644
  ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
@@ -656,14 +662,11 @@ static ggml_backend_i ggml_backend_rpc_interface = {
656
662
  /* .supports_op = */ ggml_backend_rpc_supports_op,
657
663
  /* .supports_buft = */ ggml_backend_rpc_supports_buft,
658
664
  /* .offload_op = */ NULL,
659
- /* .event_new = */ NULL,
660
- /* .event_free = */ NULL,
661
665
  /* .event_record = */ NULL,
662
666
  /* .event_wait = */ NULL,
663
- /* .event_synchronize = */ NULL,
664
667
  };
665
668
 
666
- GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
669
+ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
667
670
  static std::mutex mutex;
668
671
  std::lock_guard<std::mutex> lock(mutex);
669
672
  // NOTE: buffer types are allocated and never freed; this is by design
@@ -674,6 +677,7 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
674
677
  }
675
678
  auto sock = get_socket(endpoint);
676
679
  if (sock == nullptr) {
680
+ fprintf(stderr, "Failed to connect to %s\n", endpoint);
677
681
  return nullptr;
678
682
  }
679
683
  size_t alignment = get_alignment(sock);
@@ -687,13 +691,14 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
687
691
 
688
692
  ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
689
693
  /* .iface = */ ggml_backend_rpc_buffer_type_interface,
694
+ /* .device = */ nullptr,
690
695
  /* .context = */ buft_ctx
691
696
  };
692
697
  buft_map[endpoint] = buft;
693
698
  return buft;
694
699
  }
695
700
 
696
- GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
701
+ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
697
702
  ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
698
703
  /* .endpoint = */ endpoint,
699
704
  /* .name = */ "RPC[" + std::string(endpoint) + "]",
@@ -702,12 +707,13 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
702
707
  ggml_backend_t backend = new ggml_backend {
703
708
  /* .guid = */ ggml_backend_rpc_guid(),
704
709
  /* .interface = */ ggml_backend_rpc_interface,
710
+ /* .device = */ nullptr,
705
711
  /* .context = */ ctx
706
712
  };
707
713
  return backend;
708
714
  }
709
715
 
710
- GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
716
+ GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
711
717
  return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
712
718
  }
713
719
 
@@ -715,7 +721,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
715
721
  // input serialization format: | 0 bytes |
716
722
  std::vector<uint8_t> input;
717
723
  std::vector<uint8_t> output;
718
- bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
724
+ bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
719
725
  GGML_ASSERT(status);
720
726
  GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
721
727
  // output serialization format: | free (8 bytes) | total (8 bytes) |
@@ -727,7 +733,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
727
733
  *total = total_mem;
728
734
  }
729
735
 
730
- GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
736
+ GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
731
737
  auto sock = get_socket(endpoint);
732
738
  if (sock == nullptr) {
733
739
  *free = 0;
@@ -877,8 +883,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp
877
883
  }
878
884
  result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
879
885
  if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
880
- return nullptr;
886
+ result->buffer = nullptr;
881
887
  }
888
+
889
+ if (result->buffer) {
890
+ // require that the tensor data does not go beyond the buffer end
891
+ uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
892
+ uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
893
+ uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
894
+ GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
895
+ GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
896
+ }
897
+
882
898
  result->op = (ggml_op) tensor->op;
883
899
  for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
884
900
  result->op_params[i] = tensor->op_params[i];
@@ -898,7 +914,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
898
914
  const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
899
915
  uint64_t offset;
900
916
  memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
901
- size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
917
+ const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
902
918
 
903
919
  struct ggml_init_params params {
904
920
  /*.mem_size =*/ ggml_tensor_overhead(),
@@ -913,6 +929,17 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
913
929
  return false;
914
930
  }
915
931
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
932
+
933
+ // sanitize tensor->data
934
+ {
935
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
936
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
937
+
938
+ if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
939
+ GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
940
+ }
941
+ }
942
+
916
943
  const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
917
944
  ggml_backend_tensor_set(tensor, data, offset, size);
918
945
  ggml_free(ctx);
@@ -943,6 +970,17 @@ bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint
943
970
  return false;
944
971
  }
945
972
  GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
973
+
974
+ // sanitize tensor->data
975
+ {
976
+ const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
977
+ const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
978
+
979
+ if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
980
+ GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
981
+ }
982
+ }
983
+
946
984
  // output serialization format: | data (size bytes) |
947
985
  output.resize(size, 0);
948
986
  ggml_backend_tensor_get(tensor, output.data(), offset, size);
@@ -1024,7 +1062,7 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<u
1024
1062
  const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
1025
1063
  GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
1026
1064
 
1027
- static size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1065
+ size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
1028
1066
  struct ggml_init_params params = {
1029
1067
  /*.mem_size =*/ buf_size,
1030
1068
  /*.mem_buffer =*/ NULL,
@@ -1064,59 +1102,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1064
1102
  if (!recv_data(sockfd, &cmd, 1)) {
1065
1103
  break;
1066
1104
  }
1105
+ if (cmd >= RPC_CMD_COUNT) {
1106
+ // fail fast if the command is invalid
1107
+ fprintf(stderr, "Unknown command: %d\n", cmd);
1108
+ break;
1109
+ }
1067
1110
  std::vector<uint8_t> input;
1068
1111
  std::vector<uint8_t> output;
1069
1112
  uint64_t input_size;
1070
1113
  if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
1071
1114
  break;
1072
1115
  }
1073
- input.resize(input_size);
1116
+ try {
1117
+ input.resize(input_size);
1118
+ } catch (const std::bad_alloc & e) {
1119
+ fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
1120
+ break;
1121
+ }
1074
1122
  if (!recv_data(sockfd, input.data(), input_size)) {
1075
1123
  break;
1076
1124
  }
1077
1125
  bool ok = true;
1078
1126
  switch (cmd) {
1079
- case ALLOC_BUFFER: {
1127
+ case RPC_CMD_ALLOC_BUFFER: {
1080
1128
  ok = server.alloc_buffer(input, output);
1081
1129
  break;
1082
1130
  }
1083
- case GET_ALIGNMENT: {
1131
+ case RPC_CMD_GET_ALIGNMENT: {
1084
1132
  server.get_alignment(output);
1085
1133
  break;
1086
1134
  }
1087
- case GET_MAX_SIZE: {
1135
+ case RPC_CMD_GET_MAX_SIZE: {
1088
1136
  server.get_max_size(output);
1089
1137
  break;
1090
1138
  }
1091
- case BUFFER_GET_BASE: {
1139
+ case RPC_CMD_BUFFER_GET_BASE: {
1092
1140
  ok = server.buffer_get_base(input, output);
1093
1141
  break;
1094
1142
  }
1095
- case FREE_BUFFER: {
1143
+ case RPC_CMD_FREE_BUFFER: {
1096
1144
  ok = server.free_buffer(input);
1097
1145
  break;
1098
1146
  }
1099
- case BUFFER_CLEAR: {
1147
+ case RPC_CMD_BUFFER_CLEAR: {
1100
1148
  ok = server.buffer_clear(input);
1101
1149
  break;
1102
1150
  }
1103
- case SET_TENSOR: {
1151
+ case RPC_CMD_SET_TENSOR: {
1104
1152
  ok = server.set_tensor(input);
1105
1153
  break;
1106
1154
  }
1107
- case GET_TENSOR: {
1155
+ case RPC_CMD_GET_TENSOR: {
1108
1156
  ok = server.get_tensor(input, output);
1109
1157
  break;
1110
1158
  }
1111
- case COPY_TENSOR: {
1159
+ case RPC_CMD_COPY_TENSOR: {
1112
1160
  ok = server.copy_tensor(input, output);
1113
1161
  break;
1114
1162
  }
1115
- case GRAPH_COMPUTE: {
1163
+ case RPC_CMD_GRAPH_COMPUTE: {
1116
1164
  ok = server.graph_compute(input, output);
1117
1165
  break;
1118
1166
  }
1119
- case GET_DEVICE_MEMORY: {
1167
+ case RPC_CMD_GET_DEVICE_MEMORY: {
1120
1168
  // output serialization format: | free (8 bytes) | total (8 bytes) |
1121
1169
  output.resize(2*sizeof(uint64_t), 0);
1122
1170
  memcpy(output.data(), &free_mem, sizeof(free_mem));
@@ -1169,8 +1217,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
1169
1217
  return;
1170
1218
  }
1171
1219
  printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1220
+ fflush(stdout);
1172
1221
  rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
1173
1222
  printf("Client connection closed\n");
1223
+ fflush(stdout);
1174
1224
  }
1175
1225
  #ifdef _WIN32
1176
1226
  WSACleanup();
@@ -15,6 +15,7 @@
15
15
 
16
16
  #include "concat.hpp"
17
17
  #include "common.hpp"
18
+ #include "conv.hpp"
18
19
  #include "convert.hpp"
19
20
  #include "dequantize.hpp"
20
21
  #include "dmmv.hpp"
@@ -23,5 +24,7 @@
23
24
  #include "rope.hpp"
24
25
  #include "norm.hpp"
25
26
  #include "softmax.hpp"
27
+ #include "tsembd.hpp"
28
+ #include "im2col.hpp"
26
29
 
27
30
  #endif // GGML_SYCL_BACKEND_HPP
@@ -51,3 +51,14 @@ void ggml_sycl_host_free(void* ptr) try {
51
51
  << ", line:" << __LINE__ << std::endl;
52
52
  std::exit(1);
53
53
  }
54
+
55
+ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) {
56
+ const int64_t max_range = std::numeric_limits<int>::max();
57
+ int64_t sycl_down_blk_size = block_size;
58
+ int64_t global_range = accumulate_block_num * sycl_down_blk_size;
59
+ while(global_range > max_range) {
60
+ sycl_down_blk_size /= 2;
61
+ global_range = accumulate_block_num * sycl_down_blk_size;
62
+ }
63
+ return sycl_down_blk_size;
64
+ }