@fugood/llama.node 0.0.1-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +85 -0
- package/README.md +56 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/lib/binding.js +13 -0
- package/lib/binding.ts +57 -0
- package/lib/index.js +24 -0
- package/lib/index.ts +13 -0
- package/package.json +65 -0
- package/src/addons.cpp +506 -0
- package/src/llama.cpp/CMakeLists.txt +1320 -0
- package/src/llama.cpp/build.zig +172 -0
- package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
- package/src/llama.cpp/common/CMakeLists.txt +87 -0
- package/src/llama.cpp/common/base64.hpp +392 -0
- package/src/llama.cpp/common/common.cpp +2949 -0
- package/src/llama.cpp/common/common.h +324 -0
- package/src/llama.cpp/common/console.cpp +501 -0
- package/src/llama.cpp/common/console.h +19 -0
- package/src/llama.cpp/common/grammar-parser.cpp +440 -0
- package/src/llama.cpp/common/grammar-parser.h +29 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
- package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
- package/src/llama.cpp/common/json.hpp +24766 -0
- package/src/llama.cpp/common/log.h +724 -0
- package/src/llama.cpp/common/ngram-cache.cpp +282 -0
- package/src/llama.cpp/common/ngram-cache.h +94 -0
- package/src/llama.cpp/common/sampling.cpp +353 -0
- package/src/llama.cpp/common/sampling.h +147 -0
- package/src/llama.cpp/common/stb_image.h +8396 -0
- package/src/llama.cpp/common/train.cpp +1513 -0
- package/src/llama.cpp/common/train.h +233 -0
- package/src/llama.cpp/examples/CMakeLists.txt +52 -0
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
- package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/batched/batched.cpp +262 -0
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
- package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
- package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
- package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/infill/infill.cpp +767 -0
- package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
- package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
- package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
- package/src/llama.cpp/examples/llava/clip.h +85 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
- package/src/llama.cpp/examples/llava/llava.cpp +426 -0
- package/src/llama.cpp/examples/llava/llava.h +50 -0
- package/src/llama.cpp/examples/llava/requirements.txt +3 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
- package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
- package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/main/main.cpp +957 -0
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
- package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
- package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
- package/src/llama.cpp/examples/server/httplib.h +9465 -0
- package/src/llama.cpp/examples/server/server.cpp +3826 -0
- package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
- package/src/llama.cpp/examples/server/utils.hpp +653 -0
- package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple/simple.cpp +183 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
- package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
- package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
- package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
- package/src/llama.cpp/ggml-alloc.c +985 -0
- package/src/llama.cpp/ggml-alloc.h +76 -0
- package/src/llama.cpp/ggml-backend-impl.h +141 -0
- package/src/llama.cpp/ggml-backend.c +2099 -0
- package/src/llama.cpp/ggml-backend.h +233 -0
- package/src/llama.cpp/ggml-common.h +1853 -0
- package/src/llama.cpp/ggml-cuda.h +43 -0
- package/src/llama.cpp/ggml-impl.h +265 -0
- package/src/llama.cpp/ggml-kompute.cpp +2006 -0
- package/src/llama.cpp/ggml-kompute.h +46 -0
- package/src/llama.cpp/ggml-metal.h +66 -0
- package/src/llama.cpp/ggml-mpi.c +216 -0
- package/src/llama.cpp/ggml-mpi.h +39 -0
- package/src/llama.cpp/ggml-opencl.cpp +2301 -0
- package/src/llama.cpp/ggml-opencl.h +36 -0
- package/src/llama.cpp/ggml-quants.c +12678 -0
- package/src/llama.cpp/ggml-quants.h +133 -0
- package/src/llama.cpp/ggml-sycl.cpp +17882 -0
- package/src/llama.cpp/ggml-sycl.h +49 -0
- package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
- package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
- package/src/llama.cpp/ggml-vulkan.h +29 -0
- package/src/llama.cpp/ggml.c +21819 -0
- package/src/llama.cpp/ggml.h +2403 -0
- package/src/llama.cpp/llama.cpp +17468 -0
- package/src/llama.cpp/llama.h +1117 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
- package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
- package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
- package/src/llama.cpp/prompts/alpaca.txt +1 -0
- package/src/llama.cpp/prompts/assistant.txt +31 -0
- package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
- package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
- package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
- package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
- package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
- package/src/llama.cpp/prompts/chat.txt +28 -0
- package/src/llama.cpp/prompts/dan-modified.txt +1 -0
- package/src/llama.cpp/prompts/dan.txt +1 -0
- package/src/llama.cpp/prompts/mnemonics.txt +93 -0
- package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
- package/src/llama.cpp/prompts/reason-act.txt +18 -0
- package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
- package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
- package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
- package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
- package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
- package/src/llama.cpp/requirements.txt +12 -0
- package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
- package/src/llama.cpp/scripts/xxd.cmake +16 -0
- package/src/llama.cpp/sgemm.cpp +999 -0
- package/src/llama.cpp/sgemm.h +12 -0
- package/src/llama.cpp/tests/CMakeLists.txt +78 -0
- package/src/llama.cpp/tests/get-model.cpp +21 -0
- package/src/llama.cpp/tests/get-model.h +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
- package/src/llama.cpp/tests/test-c.c +7 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
- package/src/llama.cpp/tests/test-double-float.cpp +57 -0
- package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
- package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
- package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
- package/src/llama.cpp/tests/test-opt.cpp +181 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
- package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
- package/src/llama.cpp/tests/test-rope.cpp +221 -0
- package/src/llama.cpp/tests/test-sampling.cpp +301 -0
- package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
- package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
- package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
- package/src/llama.cpp/unicode-data.cpp +1651 -0
- package/src/llama.cpp/unicode-data.h +16 -0
- package/src/llama.cpp/unicode.cpp +277 -0
- package/src/llama.cpp/unicode.h +28 -0
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
#pragma once

#include "ggml.h"
#include "ggml-backend.h"

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

#ifdef __cplusplus
extern "C" {
#endif

// Description of a Vulkan physical device as seen by the Kompute backend.
// Filled in by ggml_vk_available_devices() / ggml_vk_get_device().
struct ggml_vk_device {
    int index;
    int type; // same as VkPhysicalDeviceType
    size_t heapSize;          // total size of the device-local memory heap, in bytes
    const char * name;        // device name string (lifetime owned by the backend)
    const char * vendor;      // vendor name string (lifetime owned by the backend)
    int subgroupSize;
    uint64_t bufferAlignment;
    uint64_t maxAlloc;        // maximum size of a single allocation, in bytes
};

// Returns an array of devices that can satisfy `memoryRequired` bytes;
// the number of entries is written to *count.
struct ggml_vk_device * ggml_vk_available_devices(size_t memoryRequired, size_t * count);
// Selects a device by name; returns false if no suitable device was found.
bool ggml_vk_get_device(struct ggml_vk_device * device, size_t memoryRequired, const char * name);
// True if a Vulkan loader/instance is available on this system.
bool ggml_vk_has_vulkan(void);
// True if a device has been selected via ggml_vk_get_device().
bool ggml_vk_has_device(void);
// Returns the currently selected device (valid only after ggml_vk_get_device()).
struct ggml_vk_device ggml_vk_current_device(void);

//
// backend API
//

// forward declaration
typedef struct ggml_backend * ggml_backend_t;

// Initializes the Kompute backend on device `device` (index into the available-devices list).
GGML_API ggml_backend_t ggml_backend_kompute_init(int device);

// True if `backend` is a Kompute backend instance.
GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);

GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);

#ifdef __cplusplus
}
#endif
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
// An interface allowing to compute ggml_cgraph with Metal
//
// This is a fully functional interface that extends ggml with GPU support for Apple devices.
// A similar interface can be created for other GPU backends (e.g. Vulkan, CUDA, OpenCL, etc.)
//
// How it works?
//
// As long as your program can create and evaluate a ggml_cgraph on the CPU, you can use this
// interface to evaluate the same graph on the GPU. Instead of using ggml_graph_compute(), you
// use ggml_metal_graph_compute() (or ggml_vulkan_graph_compute(), etc.)
//
// You only need to make sure that all memory buffers that you used during the graph creation
// are mapped to the device memory with the ggml_metal_add_buffer() function. This mapping is
// used during the graph evaluation to determine the arguments of the compute kernels.
//
// Synchronization between device and host memory (for example for input and output tensors)
// is done with the ggml_metal_set_tensor() and ggml_metal_get_tensor() functions.
//

#pragma once

#include "ggml.h"
#include "ggml-backend.h"

#include <stddef.h>
#include <stdbool.h>

// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64

struct ggml_tensor;
struct ggml_cgraph;

#ifdef __cplusplus
extern "C" {
#endif

//
// backend API
// user-code should use only these functions
//

// Redirects backend log output to `log_callback` (called with `user_data`).
GGML_API void ggml_backend_metal_log_set_callback(ggml_log_callback log_callback, void * user_data);

// Creates a Metal backend instance; returns NULL on failure.
GGML_API ggml_backend_t ggml_backend_metal_init(void);

// True if `backend` is a Metal backend instance.
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);

// Wraps an existing host allocation of `size` bytes as a Metal buffer
// (up to `max_size`), without copying.
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);

// Sets the number of command buffers used per graph evaluation.
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);

// helper to check if the device supports a specific family
// ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);

// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);

#ifdef __cplusplus
}
#endif
|
|
66
|
+
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
#include "ggml-mpi.h"
|
|
2
|
+
|
|
3
|
+
#include "ggml.h"
|
|
4
|
+
|
|
5
|
+
#include <mpi.h>
|
|
6
|
+
|
|
7
|
+
#include <stdio.h>
|
|
8
|
+
#include <stdlib.h>
|
|
9
|
+
|
|
10
|
+
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
|
11
|
+
|
|
12
|
+
#define UNUSED GGML_UNUSED
|
|
13
|
+
|
|
14
|
+
// Per-process MPI state: this process's rank and the communicator size,
// both queried once from MPI_COMM_WORLD in ggml_mpi_init().
struct ggml_mpi_context {
    int rank;   // rank of this process in MPI_COMM_WORLD
    int size;   // total number of processes in MPI_COMM_WORLD
};
|
|
18
|
+
|
|
19
|
+
// Initializes the MPI runtime (no command-line arguments are forwarded).
// Must be called once before any other ggml_mpi_* function.
void ggml_mpi_backend_init(void) {
    MPI_Init(NULL, NULL);
}
|
|
22
|
+
|
|
23
|
+
// Shuts down the MPI runtime; no MPI calls may be made afterwards.
void ggml_mpi_backend_free(void) {
    MPI_Finalize();
}
|
|
26
|
+
|
|
27
|
+
struct ggml_mpi_context * ggml_mpi_init(void) {
|
|
28
|
+
struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context));
|
|
29
|
+
|
|
30
|
+
MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank);
|
|
31
|
+
MPI_Comm_size(MPI_COMM_WORLD, &ctx->size);
|
|
32
|
+
|
|
33
|
+
return ctx;
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
// Releases a context created by ggml_mpi_init().
void ggml_mpi_free(struct ggml_mpi_context * ctx) {
    free(ctx);
}
|
|
39
|
+
|
|
40
|
+
// Returns the MPI rank cached in `ctx` (0 is the root/main node).
int ggml_mpi_rank(struct ggml_mpi_context * ctx) {
    return ctx->rank;
}
|
|
43
|
+
|
|
44
|
+
// Broadcasts the per-evaluation parameters (n_tokens, n_past, n_threads)
// from the root node (rank 0) to all worker nodes, so every process runs
// the evaluation with identical settings. Collective: every rank must call it.
void ggml_mpi_eval_init(
        struct ggml_mpi_context * ctx_mpi,
        int * n_tokens,
        int * n_past,
        int * n_threads) {
    UNUSED(ctx_mpi);

    // synchronize the worker node parameters with the root node
    MPI_Barrier(MPI_COMM_WORLD);

    MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD);
    MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD);
    MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD);
}
|
|
58
|
+
|
|
59
|
+
static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) {
|
|
60
|
+
struct ggml_tensor * t = ggml_graph_get_tensor(gf, name);
|
|
61
|
+
if (t == NULL) {
|
|
62
|
+
fprintf(stderr, "%s: tensor %s not found\n", __func__, name);
|
|
63
|
+
return -1;
|
|
64
|
+
}
|
|
65
|
+
|
|
66
|
+
for (int i = 0; i < gf->n_nodes; i++) {
|
|
67
|
+
if (gf->nodes[i] == t) {
|
|
68
|
+
return i;
|
|
69
|
+
}
|
|
70
|
+
}
|
|
71
|
+
|
|
72
|
+
fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name);
|
|
73
|
+
return -1;
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) {
|
|
77
|
+
MPI_Datatype mpi_type;
|
|
78
|
+
|
|
79
|
+
switch (t->type) {
|
|
80
|
+
case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break;
|
|
81
|
+
case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break;
|
|
82
|
+
default: GGML_ASSERT(false && "not implemented");
|
|
83
|
+
}
|
|
84
|
+
|
|
85
|
+
const int retval = MPI_Send(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, MPI_COMM_WORLD);
|
|
86
|
+
GGML_ASSERT(retval == MPI_SUCCESS);
|
|
87
|
+
}
|
|
88
|
+
|
|
89
|
+
// Blocking receive into tensor `t` from rank `mpi_rank_src` (any tag).
// Mirror of ggml_mpi_tensor_send(): only I32 and F32 tensors are supported.
static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) {
    MPI_Datatype mpi_type;

    switch (t->type) {
        case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break;
        case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break;
        default: GGML_ASSERT(false && "not implemented");
    }

    MPI_Status status; UNUSED(status);

    const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
    GGML_ASSERT(retval == MPI_SUCCESS);
}
|
|
103
|
+
|
|
104
|
+
// TODO: there are many improvements that can be done to this implementation
|
|
105
|
+
void ggml_mpi_graph_compute_pre(
|
|
106
|
+
struct ggml_mpi_context * ctx_mpi,
|
|
107
|
+
struct ggml_cgraph * gf,
|
|
108
|
+
int n_layers) {
|
|
109
|
+
const int mpi_rank = ctx_mpi->rank;
|
|
110
|
+
const int mpi_size = ctx_mpi->size;
|
|
111
|
+
|
|
112
|
+
struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens");
|
|
113
|
+
if (inp_tokens == NULL) {
|
|
114
|
+
fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__);
|
|
115
|
+
return;
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0");
|
|
119
|
+
if (inp0 == NULL) {
|
|
120
|
+
fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__);
|
|
121
|
+
return;
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
GGML_ASSERT(inp0 == gf->nodes[0]);
|
|
125
|
+
|
|
126
|
+
// distribute the compute graph into slices across the MPI nodes
|
|
127
|
+
//
|
|
128
|
+
// the main node (0) processes the last layers + the remainder of the compute graph
|
|
129
|
+
// and is responsible to pass the input tokens to the first node (1)
|
|
130
|
+
//
|
|
131
|
+
// node 1: [( 0) * n_per_node, ( 1) * n_per_node)
|
|
132
|
+
// node 2: [( 1) * n_per_node, ( 2) * n_per_node)
|
|
133
|
+
// ...
|
|
134
|
+
// node n-1: [(n-2) * n_per_node, (n-1) * n_per_node)
|
|
135
|
+
// node 0: [(n-1) * n_per_node, n_nodes)
|
|
136
|
+
//
|
|
137
|
+
if (mpi_rank > 0) {
|
|
138
|
+
if (mpi_rank == 1) {
|
|
139
|
+
// the first node (1) receives the input tokens from the main node (0)
|
|
140
|
+
ggml_mpi_tensor_recv(inp_tokens, 0);
|
|
141
|
+
} else {
|
|
142
|
+
// recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph)
|
|
143
|
+
ggml_mpi_tensor_recv(inp0, mpi_rank - 1);
|
|
144
|
+
}
|
|
145
|
+
} else if (mpi_size > 1) {
|
|
146
|
+
// node 0 sends the input tokens to node 1
|
|
147
|
+
ggml_mpi_tensor_send(inp_tokens, 1);
|
|
148
|
+
|
|
149
|
+
// recv the output data from the last node
|
|
150
|
+
ggml_mpi_tensor_recv(inp0, mpi_size - 1);
|
|
151
|
+
}
|
|
152
|
+
|
|
153
|
+
{
|
|
154
|
+
const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;
|
|
155
|
+
|
|
156
|
+
const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;
|
|
157
|
+
|
|
158
|
+
const int il0 = (mpi_idx + 0) * n_per_node;
|
|
159
|
+
const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node);
|
|
160
|
+
|
|
161
|
+
char name_l0[GGML_MAX_NAME];
|
|
162
|
+
char name_l1[GGML_MAX_NAME];
|
|
163
|
+
|
|
164
|
+
snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0);
|
|
165
|
+
snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1);
|
|
166
|
+
|
|
167
|
+
const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0);
|
|
168
|
+
const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes;
|
|
169
|
+
|
|
170
|
+
if (idx_l0 < 0 || idx_l1 < 0) {
|
|
171
|
+
fprintf(stderr, "%s: layer input nodes not found\n", __func__);
|
|
172
|
+
return;
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
// attach the input data to all nodes that need it
|
|
176
|
+
// TODO: not great - should be able to do this without modifying the compute graph (see next TODO below)
|
|
177
|
+
for (int i = idx_l0; i < idx_l1; i++) {
|
|
178
|
+
if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) {
|
|
179
|
+
gf->nodes[i]->src[0] = inp0;
|
|
180
|
+
}
|
|
181
|
+
if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) {
|
|
182
|
+
gf->nodes[i]->src[1] = inp0;
|
|
183
|
+
}
|
|
184
|
+
}
|
|
185
|
+
|
|
186
|
+
// TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph
|
|
187
|
+
for (int i = 1; i < idx_l1 - idx_l0; i++) {
|
|
188
|
+
gf->nodes[i] = gf->nodes[idx_l0 + i];
|
|
189
|
+
gf->grads[i] = gf->grads[idx_l0 + i];
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
// the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node
|
|
193
|
+
if (mpi_idx != 0) {
|
|
194
|
+
gf->nodes[0]->op = GGML_OP_NONE;
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
gf->n_nodes = idx_l1 - idx_l0;
|
|
198
|
+
|
|
199
|
+
//fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1);
|
|
200
|
+
}
|
|
201
|
+
}
|
|
202
|
+
|
|
203
|
+
void ggml_mpi_graph_compute_post(
|
|
204
|
+
struct ggml_mpi_context * ctx_mpi,
|
|
205
|
+
struct ggml_cgraph * gf,
|
|
206
|
+
int n_layers) {
|
|
207
|
+
UNUSED(n_layers);
|
|
208
|
+
|
|
209
|
+
const int mpi_rank = ctx_mpi->rank;
|
|
210
|
+
const int mpi_size = ctx_mpi->size;
|
|
211
|
+
|
|
212
|
+
// send the output data to the next node
|
|
213
|
+
if (mpi_rank > 0) {
|
|
214
|
+
ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size);
|
|
215
|
+
}
|
|
216
|
+
}
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
#pragma once

// Public interface of the (experimental) MPI pipeline-parallel backend:
// splits a ggml compute graph across MPI ranks, one slice of layers each.

struct ggml_context;
struct ggml_tensor;
struct ggml_cgraph;

#ifdef __cplusplus
extern "C" {
#endif

// Opaque per-process context (rank + world size).
struct ggml_mpi_context;

// Initialize / finalize the MPI runtime. Call init once before any other
// ggml_mpi_* function and free once at shutdown.
void ggml_mpi_backend_init(void);
void ggml_mpi_backend_free(void);

// Create/destroy a context; the caller owns the returned pointer.
struct ggml_mpi_context * ggml_mpi_init(void);
void ggml_mpi_free(struct ggml_mpi_context * ctx);

// Rank of this process (0 is the root node).
int ggml_mpi_rank(struct ggml_mpi_context * ctx);

// Broadcast evaluation parameters from rank 0 to all ranks (collective).
void ggml_mpi_eval_init(
        struct ggml_mpi_context * ctx_mpi,
        int * n_tokens,
        int * n_past,
        int * n_threads);

// Exchange slice inputs and trim `gf` to this rank's layer range;
// call before graph computation.
void ggml_mpi_graph_compute_pre(
        struct ggml_mpi_context * ctx_mpi,
        struct ggml_cgraph * gf,
        int n_layers);

// Forward this rank's output to the next node; call after graph computation.
void ggml_mpi_graph_compute_post(
        struct ggml_mpi_context * ctx_mpi,
        struct ggml_cgraph * gf,
        int n_layers);

#ifdef __cplusplus
}
#endif
|