node-llama-cpp 2.8.14 → 2.8.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +6 -6
- package/dist/config.js +1 -1
- package/dist/config.js.map +1 -1
- package/dist/utils/compileLLamaCpp.js +23 -27
- package/dist/utils/compileLLamaCpp.js.map +1 -1
- package/llama/addon.cpp +1 -0
- package/llama/binariesGithubRelease.json +1 -1
- package/llama/gitRelease.bundle +0 -0
- package/llamaBins/linux-arm64/libggml.so +0 -0
- package/llamaBins/linux-arm64/libllama.so +0 -0
- package/llamaBins/linux-arm64/llama-addon.node +0 -0
- package/llamaBins/linux-armv7l/libggml.so +0 -0
- package/llamaBins/linux-armv7l/libllama.so +0 -0
- package/llamaBins/linux-armv7l/llama-addon.node +0 -0
- package/llamaBins/linux-x64/libggml.so +0 -0
- package/llamaBins/linux-x64/libllama.so +0 -0
- package/llamaBins/linux-x64/llama-addon.node +0 -0
- package/llamaBins/mac-arm64/ggml-common.h +5 -1
- package/llamaBins/mac-arm64/ggml-metal.metal +225 -19
- package/llamaBins/mac-arm64/libggml.dylib +0 -0
- package/llamaBins/mac-arm64/libllama.dylib +0 -0
- package/llamaBins/mac-arm64/llama-addon.node +0 -0
- package/llamaBins/mac-x64/ggml-common.h +5 -1
- package/llamaBins/mac-x64/ggml-metal.metal +225 -19
- package/llamaBins/mac-x64/libggml.dylib +0 -0
- package/llamaBins/mac-x64/libllama.dylib +0 -0
- package/llamaBins/mac-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64/ggml.dll +0 -0
- package/llamaBins/win-x64/llama-addon.exp +0 -0
- package/llamaBins/win-x64/llama-addon.lib +0 -0
- package/llamaBins/win-x64/llama-addon.node +0 -0
- package/llamaBins/win-x64/llama.dll +0 -0
- package/package.json +2 -2
package/README.md
CHANGED
|
@@ -26,10 +26,10 @@
|
|
|
26
26
|
* Up-to-date with the latest version of `llama.cpp`. Download and compile the latest release with a single CLI command.
|
|
27
27
|
* Force a model to generate output in a parseable format, like JSON, or even force it to follow a specific JSON schema
|
|
28
28
|
|
|
29
|
-
## [Documentation](https://
|
|
30
|
-
* [Getting started guide](https://
|
|
31
|
-
* [API reference](https://
|
|
32
|
-
* [CLI help](https://
|
|
29
|
+
## [Documentation](https://node-llama-cpp.withcat.ai/)
|
|
30
|
+
* [Getting started guide](https://node-llama-cpp.withcat.ai/guide/)
|
|
31
|
+
* [API reference](https://node-llama-cpp.withcat.ai/api/classes/LlamaModel)
|
|
32
|
+
* [CLI help](https://node-llama-cpp.withcat.ai/guide/cli/)
|
|
33
33
|
* [Changelog](https://github.com/withcatai/node-llama-cpp/releases)
|
|
34
34
|
* [Roadmap](https://github.com/orgs/withcatai/projects/1)
|
|
35
35
|
|
|
@@ -72,10 +72,10 @@ const a2 = await session.prompt(q2);
|
|
|
72
72
|
console.log("AI: " + a2);
|
|
73
73
|
```
|
|
74
74
|
|
|
75
|
-
> For more examples, see the [getting started guide](https://
|
|
75
|
+
> For more examples, see the [getting started guide](https://node-llama-cpp.withcat.ai/guide/)
|
|
76
76
|
|
|
77
77
|
## Contributing
|
|
78
|
-
To contribute to `node-llama-cpp` read the [contribution guide](https://
|
|
78
|
+
To contribute to `node-llama-cpp` read the [contribution guide](https://node-llama-cpp.withcat.ai/guide/contributing).
|
|
79
79
|
|
|
80
80
|
## Acknowledgements
|
|
81
81
|
* llama.cpp: [ggerganov/llama.cpp](https://github.com/ggerganov/llama.cpp)
|
package/dist/config.js
CHANGED
|
@@ -53,7 +53,7 @@ export const defaultChatSystemPrompt = "You are a helpful, respectful and honest
|
|
|
53
53
|
"If you don't know the answer to a question, please don't share false information.";
|
|
54
54
|
export const cliBinName = "node-llama-cpp";
|
|
55
55
|
export const npxRunPrefix = "npx --no ";
|
|
56
|
-
const documentationUrl = "https://
|
|
56
|
+
const documentationUrl = "https://node-llama-cpp.withcat.ai";
|
|
57
57
|
export const documentationPageUrls = {
|
|
58
58
|
CUDA: documentationUrl + "/guide/CUDA"
|
|
59
59
|
};
|
package/dist/config.js.map
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"config.js","sourceRoot":"","sources":["../src/config.ts"],"names":[],"mappings":"AAAA,OAAO,EAAC,aAAa,EAAC,MAAM,KAAK,CAAC;AAClC,OAAO,KAAK,IAAI,MAAM,MAAM,CAAC;AAC7B,OAAO,KAAK,EAAE,MAAM,IAAI,CAAC;AACzB,OAAO,OAAO,MAAM,SAAS,CAAC;AAC9B,OAAO,MAAM,MAAM,SAAS,CAAC;AAC7B,OAAO,KAAK,IAAI,MAAM,MAAM,CAAC;AAC7B,OAAO,EAAC,wBAAwB,EAAC,MAAM,kCAAkC,CAAC;AAE1E,MAAM,SAAS,GAAG,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC;AAE/D,MAAM,GAAG,GAAG,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;AAGrC,MAAM,CAAC,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,EAAE,OAAO,CAAC,CAAC;AAClE,MAAM,CAAC,MAAM,wBAAwB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,YAAY,CAAC,CAAC;AAChF,MAAM,CAAC,MAAM,kBAAkB,GAAG,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,EAAE,WAAW,CAAC,CAAC;AAC1E,MAAM,CAAC,MAAM,0BAA0B,GAAG,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,EAAE,OAAO,EAAE,UAAU,CAAC,CAAC;AAC1F,MAAM,CAAC,MAAM,iBAAiB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,CAAC,CAAC;AACxE,MAAM,CAAC,MAAM,yBAAyB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;AAC5F,MAAM,CAAC,MAAM,qBAAqB,GAAG,IAAI,CAAC,IAAI,CAAC,EAAE,CAAC,MAAM,EAAE,EAAE,gBAAgB,EAAE,IAAI,CAAC,EAAE,EAAE,CAAC,CAAC;AACzF,MAAM,CAAC,MAAM,0BAA0B,GAAG,IAAI,CAAC,IAAI,CAAC,EAAE,CAAC,OAAO,EAAE,EAAE,mCAAmC,CAAC,CAAC;AACvG,MAAM,CAAC,MAAM,mBAAmB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,cAAc,CAAC,CAAC;AAC7E,MAAM,CAAC,MAAM,yBAAyB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,4BAA4B,CAAC,CAAC;AACjG,MAAM,CAAC,MAAM,4BAA4B,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,oBAAoB,CAAC,CAAC;AAC5F,MAAM,CAAC,MAAM,2BAA2B,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,mBAAmB,CAAC,CAAC;AAC1F,MAAM,CAAC,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,CAAC,CAAC;AACjE,MAAM,CAAC,MAAM,yBAAyB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,CAAC,CAAC;AAC5E,MAAM,CAAC,MAAM,yBAAyB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,CAAC,CAAC;AAC5E,MAAM,CAAC,MAAM,UAAU,GAAG,SAAS,CAAC;AAEpC,MAAM,CAAC,MAAM,IAAI,GAAG,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC;KAC5B,OAAO,CAAC,OAAO,CAAC;KAChB,MAAM,EAAE,CAAC;AACd,MAAM,CAAC,MAAM,yBAAyB,GAAG,GAAG,CAAC,GAAG,CAAC,qBAAqB,CAAC;KAClE,OAAO,CAAC,qBAAqB,CAAC;KAC9B,QAAQ,EAAE,CAAC;AAChB,MAAM,CAAC,MAAM,sBAAsB,GAAG,GAAG,CAAC,GAAG,CAAC,6BAA6B,CAAC;KACvE,OAAO,CAAC,MAAM,wBAAwB,EAAE,CAAC;KACzC,QAAQ,EAAE,CAAC;AAChB,MAAM,CAAC,MAAM,2BAA2B,GAAG,GAAG,CAAC,GAAG,CAAC,sBAAsB,CAAC;KACrE,OAAO,CAAC,OAAO,CAAC,QAAQ,KAAK,QAAQ,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,OAAO,CAAC;KACzD,MAAM,EAAE,CAAC;AACd,MAAM,CAAC,MAAM,0BAA0B,GAAG,GAAG,CAAC,GAAG,CAAC,qBAAqB,CAAC;KACnE,OAAO,CAAC,OAAO,CAAC;KAChB,MAAM,EAAE,CAAC;AACd,MAAM,CAAC,MAAM,mBAAmB,GAAG,GAAG,CAAC,GAAG,CAAC,8BAA8B,CAAC;KACrE,OAAO,CAAC,OAAO,CAAC;KAChB,MAAM,EAAE,CAAC;AACd,MAAM,CAAC,MAAM,2BAA2B,GAAG,GAAG,CAAC,GAAG,CAAC,oCAAoC,CAAC;KACnF,OAAO,CAAC,yBAAyB,CAAC;KAClC,QAAQ,EAAE,CAAC;AAChB,MAAM,CAAC,MAAM,2BAA2B,GAAG,GAAG,CAAC,GAAG,CAAC,oCAAoC,CAAC;KACnF,OAAO,CAAC,yBAAyB,CAAC;KAClC,QAAQ,EAAE,CAAC;AAChB,MAAM,CAAC,MAAM,8BAA8B,GAAG,8BAA8B,CAAC;AAC7E,MAAM,CAAC,MAAM,uBAAuB,GAAG,+FAA+F;IAClI,+HAA+H;IAC/H,mFAAmF,CAAC;AACxF,MAAM,CAAC,MAAM,UAAU,GAAG,gBAAgB,CAAC;AAC3C,MAAM,CAAC,MAAM,YAAY,GAAG,WAAW,CAAC;AAExC,MAAM,gBAAgB,GAAG,
|
|
1
|
+
{"version":3,"file":"config.js","sourceRoot":"","sources":["../src/config.ts"],"names":[],"mappings":"AAAA,OAAO,EAAC,aAAa,EAAC,MAAM,KAAK,CAAC;AAClC,OAAO,KAAK,IAAI,MAAM,MAAM,CAAC;AAC7B,OAAO,KAAK,EAAE,MAAM,IAAI,CAAC;AACzB,OAAO,OAAO,MAAM,SAAS,CAAC;AAC9B,OAAO,MAAM,MAAM,SAAS,CAAC;AAC7B,OAAO,KAAK,IAAI,MAAM,MAAM,CAAC;AAC7B,OAAO,EAAC,wBAAwB,EAAC,MAAM,kCAAkC,CAAC;AAE1E,MAAM,SAAS,GAAG,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC;AAE/D,MAAM,GAAG,GAAG,MAAM,CAAC,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC;AAGrC,MAAM,CAAC,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,EAAE,OAAO,CAAC,CAAC;AAClE,MAAM,CAAC,MAAM,wBAAwB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,YAAY,CAAC,CAAC;AAChF,MAAM,CAAC,MAAM,kBAAkB,GAAG,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,EAAE,WAAW,CAAC,CAAC;AAC1E,MAAM,CAAC,MAAM,0BAA0B,GAAG,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,IAAI,EAAE,OAAO,EAAE,UAAU,CAAC,CAAC;AAC1F,MAAM,CAAC,MAAM,iBAAiB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,CAAC,CAAC;AACxE,MAAM,CAAC,MAAM,yBAAyB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,EAAE,UAAU,CAAC,CAAC;AAC5F,MAAM,CAAC,MAAM,qBAAqB,GAAG,IAAI,CAAC,IAAI,CAAC,EAAE,CAAC,MAAM,EAAE,EAAE,gBAAgB,EAAE,IAAI,CAAC,EAAE,EAAE,CAAC,CAAC;AACzF,MAAM,CAAC,MAAM,0BAA0B,GAAG,IAAI,CAAC,IAAI,CAAC,EAAE,CAAC,OAAO,EAAE,EAAE,mCAAmC,CAAC,CAAC;AACvG,MAAM,CAAC,MAAM,mBAAmB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,cAAc,CAAC,CAAC;AAC7E,MAAM,CAAC,MAAM,yBAAyB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,4BAA4B,CAAC,CAAC;AACjG,MAAM,CAAC,MAAM,4BAA4B,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,oBAAoB,CAAC,CAAC;AAC5F,MAAM,CAAC,MAAM,2BAA2B,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,mBAAmB,CAAC,CAAC;AAC1F,MAAM,CAAC,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,CAAC,CAAC;AACjE,MAAM,CAAC,MAAM,yBAAyB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,CAAC,CAAC;AAC5E,MAAM,CAAC,MAAM,yBAAyB,GAAG,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,CAAC,CAAC;AAC5E,MAAM,CAAC,MAAM,UAAU,GAAG,SAAS,CAAC;AAEpC,MAAM,CAAC,MAAM,IAAI,GAAG,GAAG,CAAC,GAAG,CAAC,IAAI,CAAC;KAC5B,OAAO,CAAC,OAAO,CAAC;KAChB,MAAM,EAAE,CAAC;AACd,MAAM,CAAC,MAAM,yBAAyB,GAAG,GAAG,CAAC,GAAG,CAAC,qBAAqB,CAAC;KAClE,OAAO,CAAC,qBAAqB,CAAC;KAC9B,QAAQ,EAAE,CAAC;AAChB,MAAM,CAAC,MAAM,sBAAsB,GAAG,GAAG,CAAC,GAAG,CAAC,6BAA6B,CAAC;KACvE,OAAO,CAAC,MAAM,wBAAwB,EAAE,CAAC;KACzC,QAAQ,EAAE,CAAC;AAChB,MAAM,CAAC,MAAM,2BAA2B,GAAG,GAAG,CAAC,GAAG,CAAC,sBAAsB,CAAC;KACrE,OAAO,CAAC,OAAO,CAAC,QAAQ,KAAK,QAAQ,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,OAAO,CAAC;KACzD,MAAM,EAAE,CAAC;AACd,MAAM,CAAC,MAAM,0BAA0B,GAAG,GAAG,CAAC,GAAG,CAAC,qBAAqB,CAAC;KACnE,OAAO,CAAC,OAAO,CAAC;KAChB,MAAM,EAAE,CAAC;AACd,MAAM,CAAC,MAAM,mBAAmB,GAAG,GAAG,CAAC,GAAG,CAAC,8BAA8B,CAAC;KACrE,OAAO,CAAC,OAAO,CAAC;KAChB,MAAM,EAAE,CAAC;AACd,MAAM,CAAC,MAAM,2BAA2B,GAAG,GAAG,CAAC,GAAG,CAAC,oCAAoC,CAAC;KACnF,OAAO,CAAC,yBAAyB,CAAC;KAClC,QAAQ,EAAE,CAAC;AAChB,MAAM,CAAC,MAAM,2BAA2B,GAAG,GAAG,CAAC,GAAG,CAAC,oCAAoC,CAAC;KACnF,OAAO,CAAC,yBAAyB,CAAC;KAClC,QAAQ,EAAE,CAAC;AAChB,MAAM,CAAC,MAAM,8BAA8B,GAAG,8BAA8B,CAAC;AAC7E,MAAM,CAAC,MAAM,uBAAuB,GAAG,+FAA+F;IAClI,+HAA+H;IAC/H,mFAAmF,CAAC;AACxF,MAAM,CAAC,MAAM,UAAU,GAAG,gBAAgB,CAAC;AAC3C,MAAM,CAAC,MAAM,YAAY,GAAG,WAAW,CAAC;AAExC,MAAM,gBAAgB,GAAG,mCAAmC,CAAC;AAC7D,MAAM,CAAC,MAAM,qBAAqB,GAAG;IACjC,IAAI,EAAE,gBAAgB,GAAG,aAAa;CAChC,CAAC"}
|
|
@@ -18,34 +18,30 @@ export async function compileLlamaCpp({ arch = process.arch, nodeTarget = proces
|
|
|
18
18
|
const toolchainFile = await getToolchainFileForArch(arch);
|
|
19
19
|
const runtimeVersion = nodeTarget.startsWith("v") ? nodeTarget.slice("v".length) : nodeTarget;
|
|
20
20
|
const cmakeCustomOptions = new Map();
|
|
21
|
-
if ((metal && process.platform === "darwin") || process.env.
|
|
22
|
-
cmakeCustomOptions.set("
|
|
21
|
+
if ((metal && process.platform === "darwin") || process.env.GGML_METAL === "1")
|
|
22
|
+
cmakeCustomOptions.set("GGML_METAL", "1");
|
|
23
23
|
else
|
|
24
|
-
cmakeCustomOptions.set("
|
|
25
|
-
if (cuda || process.env.
|
|
26
|
-
cmakeCustomOptions.set("
|
|
27
|
-
if (process.env.
|
|
28
|
-
cmakeCustomOptions.set("
|
|
29
|
-
if (process.env.
|
|
30
|
-
cmakeCustomOptions.set("
|
|
31
|
-
if (process.env.
|
|
32
|
-
cmakeCustomOptions.set("
|
|
33
|
-
if (process.env.
|
|
34
|
-
cmakeCustomOptions.set("
|
|
35
|
-
if (process.env.
|
|
36
|
-
cmakeCustomOptions.set("
|
|
37
|
-
if (process.env.
|
|
38
|
-
cmakeCustomOptions.set("
|
|
39
|
-
if (process.env.
|
|
40
|
-
cmakeCustomOptions.set("
|
|
41
|
-
if (process.env.
|
|
42
|
-
cmakeCustomOptions.set("
|
|
43
|
-
if (process.env.
|
|
44
|
-
cmakeCustomOptions.set("
|
|
45
|
-
if (process.env.LLAMA_HIPBLAS === "1")
|
|
46
|
-
cmakeCustomOptions.set("LLAMA_HIPBLAS", "1");
|
|
47
|
-
if (process.env.LLAMA_CLBLAST === "1")
|
|
48
|
-
cmakeCustomOptions.set("LLAMA_CLBLAST", "1");
|
|
24
|
+
cmakeCustomOptions.set("GGML_METAL", "OFF");
|
|
25
|
+
if (cuda || process.env.GGML_CUDA === "1")
|
|
26
|
+
cmakeCustomOptions.set("GGML_CUDA", "1");
|
|
27
|
+
if (process.env.GGML_OPENBLAS === "1")
|
|
28
|
+
cmakeCustomOptions.set("GGML_OPENBLAS", "1");
|
|
29
|
+
if (process.env.GGML_BLAS_VENDOR != null)
|
|
30
|
+
cmakeCustomOptions.set("GGML_BLAS_VENDOR", process.env.GGML_BLAS_VENDOR);
|
|
31
|
+
if (process.env.GGML_CUDA_FORCE_DMMV != null)
|
|
32
|
+
cmakeCustomOptions.set("GGML_CUDA_FORCE_DMMV", process.env.GGML_CUDA_FORCE_DMMV);
|
|
33
|
+
if (process.env.GGML_CUDA_DMMV_X != null)
|
|
34
|
+
cmakeCustomOptions.set("GGML_CUDA_DMMV_X", process.env.GGML_CUDA_DMMV_X);
|
|
35
|
+
if (process.env.GGML_CUDA_MMV_Y != null)
|
|
36
|
+
cmakeCustomOptions.set("GGML_CUDA_MMV_Y", process.env.GGML_CUDA_MMV_Y);
|
|
37
|
+
if (process.env.GGML_CUDA_F16 != null)
|
|
38
|
+
cmakeCustomOptions.set("GGML_CUDA_F16", process.env.GGML_CUDA_F16);
|
|
39
|
+
if (process.env.GGML_CUDA_KQUANTS_ITER != null)
|
|
40
|
+
cmakeCustomOptions.set("GGML_CUDA_KQUANTS_ITER", process.env.GGML_CUDA_KQUANTS_ITER);
|
|
41
|
+
if (process.env.GGML_CUDA_PEER_MAX_BATCH_SIZE != null)
|
|
42
|
+
cmakeCustomOptions.set("GGML_CUDA_PEER_MAX_BATCH_SIZE", process.env.GGML_CUDA_PEER_MAX_BATCH_SIZE);
|
|
43
|
+
if (process.env.GGML_HIPBLAS === "1")
|
|
44
|
+
cmakeCustomOptions.set("GGML_HIPBLAS", "1");
|
|
49
45
|
if (toolchainFile != null)
|
|
50
46
|
cmakeCustomOptions.set("CMAKE_TOOLCHAIN_FILE", toolchainFile);
|
|
51
47
|
for (const key in process.env) {
|
|
@@ -1 +1 @@
|
|
|
1
|
-
{"version":3,"file":"compileLLamaCpp.js","sourceRoot":"","sources":["../../src/utils/compileLLamaCpp.ts"],"names":[],"mappings":"AAAA,OAAO,IAAI,MAAM,MAAM,CAAC;AACxB,OAAO,EAAC,aAAa,EAAC,MAAM,KAAK,CAAC;AAClC,OAAO,OAAO,MAAM,SAAS,CAAC;AAC9B,OAAO,EAAE,MAAM,UAAU,CAAC;AAC1B,OAAO,KAAK,MAAM,OAAO,CAAC;AAC1B,OAAO,EACH,8BAA8B,EAAE,qBAAqB,EAAE,iBAAiB,EAAE,cAAc,EAAE,wBAAwB,EACrH,MAAM,cAAc,CAAC;AACtB,OAAO,EAAC,eAAe,EAAC,MAAM,sBAAsB,CAAC;AACrD,OAAO,EAAC,cAAc,EAAC,MAAM,kBAAkB,CAAC;AAChD,OAAO,EAAC,YAAY,EAAC,MAAM,mBAAmB,CAAC;AAC/C,OAAO,EAAC,mBAAmB,EAAE,YAAY,EAAE,eAAe,EAAC,MAAM,YAAY,CAAC;AAE9E,MAAM,SAAS,GAAG,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC;AAE/D,MAAM,CAAC,KAAK,UAAU,eAAe,CAAC,EAClC,IAAI,GAAG,OAAO,CAAC,IAAI,EAAE,UAAU,GAAG,OAAO,CAAC,OAAO,EAAE,cAAc,EAAE,iBAAiB,GAAG,IAAI,EAAE,KAAK,GAAG,OAAO,CAAC,QAAQ,KAAK,QAAQ,EAClI,IAAI,GAAG,KAAK,EAGf;IACG,IAAI;QACA,IAAI,CAAC,CAAC,MAAM,EAAE,CAAC,UAAU,CAAC,iBAAiB,CAAC,CAAC,EAAE;YAC3C,MAAM,IAAI,KAAK,CAAC,IAAI,iBAAiB,4BAA4B,CAAC,CAAC;SACtE;QAED,MAAM,aAAa,GAAG,MAAM,gBAAgB,EAAE,CAAC;QAC/C,MAAM,aAAa,GAAG,MAAM,uBAAuB,CAAC,IAAI,CAAC,CAAC;QAC1D,MAAM,cAAc,GAAG,UAAU,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,KAAK,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;QAC9F,MAAM,kBAAkB,GAAG,IAAI,GAAG,EAAkB,CAAC;QAErD,IAAI,CAAC,KAAK,IAAI,OAAO,CAAC,QAAQ,KAAK,QAAQ,CAAC,IAAI,OAAO,CAAC,GAAG,CAAC,
|
|
1
|
+
{"version":3,"file":"compileLLamaCpp.js","sourceRoot":"","sources":["../../src/utils/compileLLamaCpp.ts"],"names":[],"mappings":"AAAA,OAAO,IAAI,MAAM,MAAM,CAAC;AACxB,OAAO,EAAC,aAAa,EAAC,MAAM,KAAK,CAAC;AAClC,OAAO,OAAO,MAAM,SAAS,CAAC;AAC9B,OAAO,EAAE,MAAM,UAAU,CAAC;AAC1B,OAAO,KAAK,MAAM,OAAO,CAAC;AAC1B,OAAO,EACH,8BAA8B,EAAE,qBAAqB,EAAE,iBAAiB,EAAE,cAAc,EAAE,wBAAwB,EACrH,MAAM,cAAc,CAAC;AACtB,OAAO,EAAC,eAAe,EAAC,MAAM,sBAAsB,CAAC;AACrD,OAAO,EAAC,cAAc,EAAC,MAAM,kBAAkB,CAAC;AAChD,OAAO,EAAC,YAAY,EAAC,MAAM,mBAAmB,CAAC;AAC/C,OAAO,EAAC,mBAAmB,EAAE,YAAY,EAAE,eAAe,EAAC,MAAM,YAAY,CAAC;AAE9E,MAAM,SAAS,GAAG,IAAI,CAAC,OAAO,CAAC,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC;AAE/D,MAAM,CAAC,KAAK,UAAU,eAAe,CAAC,EAClC,IAAI,GAAG,OAAO,CAAC,IAAI,EAAE,UAAU,GAAG,OAAO,CAAC,OAAO,EAAE,cAAc,EAAE,iBAAiB,GAAG,IAAI,EAAE,KAAK,GAAG,OAAO,CAAC,QAAQ,KAAK,QAAQ,EAClI,IAAI,GAAG,KAAK,EAGf;IACG,IAAI;QACA,IAAI,CAAC,CAAC,MAAM,EAAE,CAAC,UAAU,CAAC,iBAAiB,CAAC,CAAC,EAAE;YAC3C,MAAM,IAAI,KAAK,CAAC,IAAI,iBAAiB,4BAA4B,CAAC,CAAC;SACtE;QAED,MAAM,aAAa,GAAG,MAAM,gBAAgB,EAAE,CAAC;QAC/C,MAAM,aAAa,GAAG,MAAM,uBAAuB,CAAC,IAAI,CAAC,CAAC;QAC1D,MAAM,cAAc,GAAG,UAAU,CAAC,UAAU,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC,KAAK,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,UAAU,CAAC;QAC9F,MAAM,kBAAkB,GAAG,IAAI,GAAG,EAAkB,CAAC;QAErD,IAAI,CAAC,KAAK,IAAI,OAAO,CAAC,QAAQ,KAAK,QAAQ,CAAC,IAAI,OAAO,CAAC,GAAG,CAAC,UAAU,KAAK,GAAG;YAAE,kBAAkB,CAAC,GAAG,CAAC,YAAY,EAAE,GAAG,CAAC,CAAC;;YACrH,kBAAkB,CAAC,GAAG,CAAC,YAAY,EAAE,KAAK,CAAC,CAAC;QAEjD,IAAI,IAAI,IAAI,OAAO,CAAC,GAAG,CAAC,SAAS,KAAK,GAAG;YAAE,kBAAkB,CAAC,GAAG,CAAC,WAAW,EAAE,GAAG,CAAC,CAAC;QAEpF,IAAI,OAAO,CAAC,GAAG,CAAC,aAAa,KAAK,GAAG;YAAE,kBAAkB,CAAC,GAAG,CAAC,eAAe,EAAE,GAAG,CAAC,CAAC;QACpF,IAAI,OAAO,CAAC,GAAG,CAAC,gBAAgB,IAAI,IAAI;YAAE,kBAAkB,CAAC,GAAG,CAAC,kBAAkB,EAAE,OAAO,CAAC,GAAG,CAAC,gBAAgB,CAAC,CAAC;QACnH,IAAI,OAAO,CAAC,GAAG,CAAC,oBAAoB,IAAI,IAAI;YAAE,kBAAkB,CAAC,GAAG,CAAC,sBAAsB,EAAE,OAAO,CAAC,GAAG,CAAC,oBAAoB,CAAC,CAAC;QAC/H,IAAI,OAAO,CAAC,GAAG,CAAC,gBAAgB,IAAI,IAAI;YAAE,kBAAkB,CAAC,GAAG,CAAC,kBAAkB,EAAE,OAAO,CAAC,GAAG,CAAC,gBAAgB,CAAC,CAAC;QACnH,IAAI,OAAO,CAAC,GAAG,CAAC,eAAe,IAAI,IAAI;YAAE,kBAAkB,CAAC,GAAG,CAAC,iBAAiB,EAAE,OAAO,CAAC,GAAG,CAAC,eAAe,CAAC,CAAC;QAChH,IAAI,OAAO,CAAC,GAAG,CAAC,aAAa,IAAI,IAAI;YAAE,kBAAkB,CAAC,GAAG,CAAC,eAAe,EAAE,OAAO,CAAC,GAAG,CAAC,aAAa,CAAC,CAAC;QAC1G,IAAI,OAAO,CAAC,GAAG,CAAC,sBAAsB,IAAI,IAAI;YAAE,kBAAkB,CAAC,GAAG,CAAC,wBAAwB,EAAE,OAAO,CAAC,GAAG,CAAC,sBAAsB,CAAC,CAAC;QACrI,IAAI,OAAO,CAAC,GAAG,CAAC,6BAA6B,IAAI,IAAI;YAAE,kBAAkB,CAAC,GAAG,CAAC,+BAA+B,EAAE,OAAO,CAAC,GAAG,CAAC,6BAA6B,CAAC,CAAC;QAC1J,IAAI,OAAO,CAAC,GAAG,CAAC,YAAY,KAAK,GAAG;YAAE,kBAAkB,CAAC,GAAG,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC;QAElF,IAAI,aAAa,IAAI,IAAI;YACrB,kBAAkB,CAAC,GAAG,CAAC,sBAAsB,EAAE,aAAa,CAAC,CAAC;QAElE,KAAK,MAAM,GAAG,IAAI,OAAO,CAAC,GAAG,EAAE;YAC3B,IAAI,GAAG,CAAC,UAAU,CAAC,8BAA8B,CAAC,EAAE;gBAChD,MAAM,MAAM,GAAG,GAAG,CAAC,KAAK,CAAC,8BAA8B,CAAC,MAAM,CAAC,CAAC;gBAChE,MAAM,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,GAAG,CAAC,CAAC;gBAC/B,kBAAkB,CAAC,GAAG,CAAC,MAAM,EAAE,KAAM,CAAC,CAAC;aAC1C;SACJ;QAED,MAAM,eAAe,EAAE,CAAC;QAExB,MAAM,YAAY,CAAC,KAAK,EAAE,CAAC,KAAK,EAAE,IAAI,EAAE,gBAAgB,EAAE,IAAI,EAAE,OAAO,EAAE,aAAa,EAAE,MAAM,EAAE,GAAG,aAAa,CAAC,EAAE,SAAS,CAAC,CAAC;QAE9H,MAAM,YAAY,CACd,KAAK,EACL,CAAC,KAAK,EAAE,IAAI,EAAE,gBAAgB,EAAE,IAAI,EAAE,SAAS,EAAE,aAAa,EAAE,MAAM,EAAE,SAAS,GAAG,IAAI,EAAE,oBAAoB,GAAG,cAAc,EAAE,GAAG,aAAa,CAAC;aAC7I,MAAM,CAAC,CAAC,GAAG,kBAAkB,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,GAAG,EAAE,KAAK,CAAC,EAAE,EAAE,CAAC,MAAM,GAAG,GAAG,GAAG,GAAG,GAAG,KAAK,CAAC,CAAC,EACtF,SAAS,CACZ,CAAC;QAEF,MAAM,gBAAgB,GAAG;YACrB,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,EAAE,KAAK,CAAC;YACzC,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,EAAE,WAAW,EAAE,KAAK,CAAC;SACzD,CAAC;QACF,MAAM,qBAAqB,GAAG,MAAM,oBAAoB,CAAC,IAAI,CAAC,CAAC;QAE/D,KAAK,MAAM,eAAe,IAAI,gBAAgB,EAAE;YAC5C,IAAI,MAAM,EAAE,CAAC,UAAU,CAAC,eAAe,CAAC,EAAE;gBACtC,MAAM,KAAK,GAAG,MAAM,EAAE,CAAC,OAAO,CAAC,eAAe,CAAC,CAAC;gBAEhD,MAAM,OAAO,CAAC,GAAG,CACb,KAAK,CAAC,GAAG,CAAC,CAAC,QAAQ,EAAE,EAAE,CAAC,CACpB,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,eAAe,EAAE,QAAQ,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,qBAAqB,EAAE,QAAQ,CAAC,EAAE;oBACtF,SAAS,EAAE,KAAK;iBACnB,CAAC,CACL,CAAC,CACL,CAAC;aACL;SACJ;QAED,mBAAmB,CAAC,qBAAqB,EAAE,IAAI,CAAC,IAAI,CAAC,qBAAqB,EAAE,QAAQ,CAAC,CAAC,CAAC;QAEvF,IAAI,iBAAiB,EAAE;YACnB,MAAM,cAAc,CAAC,sBAAsB,CAAC,CAAC;SAChD;KACJ;IAAC,OAAO,GAAG,EAAE;QACV,IAAI,iBAAiB;YACjB,MAAM,cAAc,CAAC,kBAAkB,CAAC,CAAC;QAE7C,IAAI,IAAI;YACJ,OAAO,CAAC,IAAI,CAAC,IAAI;gBACb,KAAK,CAAC,IAAI,CAAC,mBAAmB,CAAC;gBAC/B,KAAK,CAAC,MAAM,CAAC,qEAAqE,CAAC;gBACnF,qBAAqB,CAAC,IAAI,CAC7B,CAAC;QAEN,MAAM,GAAG,CAAC;KACb;YAAS;QACN,MAAM,mBAAmB,EAAE,CAAC;KAC/B;AACL,CAAC;AAED,MAAM,CAAC,KAAK,UAAU,6BAA6B;IAC/C,MAAM,qBAAqB,GAAG,MAAM,oBAAoB,CAAC,KAAK,CAAC,CAAC;IAEhE,IAAI,qBAAqB,IAAI,IAAI;QAC7B,OAAO,IAAI,CAAC;IAEhB,MAAM,UAAU,GAAG,IAAI,CAAC,IAAI,CAAC,qBAAqB,EAAE,kBAAkB,CAAC,CAAC;IAExE,IAAI,MAAM,EAAE,CAAC,UAAU,CAAC,UAAU,CAAC;QAC/B,OAAO,UAAU,CAAC;IAEtB,OAAO,IAAI,CAAC;AAChB,CAAC;AAID,KAAK,UAAU,oBAAoB,CAAC,iBAA0B,KAAK;IAC/D,IAAI,MAAM,EAAE,CAAC,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,EAAE,SAAS,CAAC,CAAC,EAAE;QACpE,OAAO,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,EAAE,SAAS,CAAC,CAAC;KACxD;SAAM,IAAI,MAAM,EAAE,CAAC,UAAU,CAAC,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC,EAAE;QACzE,OAAO,IAAI,CAAC,IAAI,CAAC,cAAc,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC;KACtD;IAED,IAAI,cAAc;QACd,MAAM,IAAI,KAAK,CAAC,2CAA2C,CAAC,CAAC;IAEjE,OAAO,IAAI,CAAC;AAChB,CAAC;AAED,KAAK,UAAU,gBAAgB;IAC3B,IAAI,MAAM,eAAe,EAAE;QACvB,OAAO,EAAE,CAAC;IAEd,MAAM,SAAS,GAAG,MAAM,YAAY,EAAE,CAAC;IAEvC,IAAI,SAAS,IAAI,IAAI;QACjB,OAAO,EAAE,CAAC;IAEd,OAAO,CAAC,cAAc,EAAE,SAAS,CAAC,CAAC;AACvC,CAAC;AAED,KAAK,UAAU,uBAAuB,CAAC,UAAkB;IACrD,IAAI,OAAO,CAAC,IAAI,KAAK,UAAU;QAC3B,OAAO,IAAI,CAAC;IAEhB,MAAM,QAAQ,GAAG,OAAO,CAAC,QAAQ,CAAC;IAClC,MAAM,QAAQ,GAAG,OAAO,CAAC,IAAI,CAAC;IAE9B,MAAM,iBAAiB,GAAG,GAAG,QAAQ,SAAS,QAAQ,WAAW,UAAU,QAAQ,CAAC;IAEpF,MAAM,QAAQ,GAAG,IAAI,CAAC,IAAI,CAAC,wBAAwB,EAAE,iBAAiB,CAAC,CAAC;IAExE,IAAI,MAAM,EAAE,CAAC,UAAU,CAAC,QAAQ,CAAC;QAC7B,OAAO,QAAQ,CAAC;IAEpB,OAAO,IAAI,CAAC;AAChB,CAAC;AAED,KAAK,UAAU,mBAAmB,CAAC,aAAqB,EAAE,WAAmB;IACzE,MAAM,cAAc,GAAG,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,SAAS,CAAC,CAAC;IAE3D,IAAI,MAAM,EAAE,CAAC,UAAU,CAAC,cAAc,CAAC,EAAE;QACrC,MAAM,EAAE,CAAC,MAAM,CAAC,WAAW,CAAC,CAAC;QAC7B,MAAM,EAAE,CAAC,IAAI,CAAC,cAAc,EAAE,WAAW,CAAC,CAAC;QAE3C,MAAM,SAAS,GAAG,MAAM,EAAE,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC;QAEhD,MAAM,OAAO,CAAC,GAAG,CACb,SAAS,CAAC,GAAG,CAAC,CAAC,QAAQ,EAAE,EAAE,CAAC,CACxB,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,IAAI,CAAC,WAAW,EAAE,QAAQ,CAAC,EAAE,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,QAAQ,CAAC,EAAE;YAC1E,SAAS,EAAE,IAAI;SAClB,CAAC,CACL,CAAC,CACL,CAAC;QAEF,MAAM,EAAE,CAAC,MAAM,CAAC,WAAW,CAAC,CAAC;KAChC;AACL,CAAC"}
|
package/llama/addon.cpp
CHANGED
package/llama/gitRelease.bundle
CHANGED
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -19,7 +19,11 @@ typedef half2 ggml_half2;
|
|
|
19
19
|
|
|
20
20
|
#define GGML_COMMON_DECL
|
|
21
21
|
#elif defined(GGML_COMMON_DECL_CUDA)
|
|
22
|
+
#if defined(GGML_COMMON_DECL_MUSA)
|
|
23
|
+
#include <musa_fp16.h>
|
|
24
|
+
#else
|
|
22
25
|
#include <cuda_fp16.h>
|
|
26
|
+
#endif
|
|
23
27
|
#include <cstdint>
|
|
24
28
|
|
|
25
29
|
typedef half ggml_half;
|
|
@@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
|
|
|
415
419
|
#define GGML_TABLE_END() };
|
|
416
420
|
|
|
417
421
|
#define GGML_COMMON_IMPL
|
|
418
|
-
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
|
|
422
|
+
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
|
|
419
423
|
#include <cstdint>
|
|
420
424
|
|
|
421
425
|
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
|
|
@@ -17,7 +17,7 @@ enum ggml_sort_order {
|
|
|
17
17
|
GGML_SORT_ORDER_DESC,
|
|
18
18
|
};
|
|
19
19
|
|
|
20
|
-
// general-purpose kernel for addition, multiplication and division of two tensors
|
|
20
|
+
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
|
|
21
21
|
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
|
22
22
|
// cons: not very efficient
|
|
23
23
|
kernel void kernel_add(
|
|
@@ -70,6 +70,56 @@ kernel void kernel_add(
|
|
|
70
70
|
}
|
|
71
71
|
}
|
|
72
72
|
|
|
73
|
+
kernel void kernel_sub(
|
|
74
|
+
device const char * src0,
|
|
75
|
+
device const char * src1,
|
|
76
|
+
device char * dst,
|
|
77
|
+
constant int64_t & ne00,
|
|
78
|
+
constant int64_t & ne01,
|
|
79
|
+
constant int64_t & ne02,
|
|
80
|
+
constant int64_t & ne03,
|
|
81
|
+
constant uint64_t & nb00,
|
|
82
|
+
constant uint64_t & nb01,
|
|
83
|
+
constant uint64_t & nb02,
|
|
84
|
+
constant uint64_t & nb03,
|
|
85
|
+
constant int64_t & ne10,
|
|
86
|
+
constant int64_t & ne11,
|
|
87
|
+
constant int64_t & ne12,
|
|
88
|
+
constant int64_t & ne13,
|
|
89
|
+
constant uint64_t & nb10,
|
|
90
|
+
constant uint64_t & nb11,
|
|
91
|
+
constant uint64_t & nb12,
|
|
92
|
+
constant uint64_t & nb13,
|
|
93
|
+
constant int64_t & ne0,
|
|
94
|
+
constant int64_t & ne1,
|
|
95
|
+
constant int64_t & ne2,
|
|
96
|
+
constant int64_t & ne3,
|
|
97
|
+
constant uint64_t & nb0,
|
|
98
|
+
constant uint64_t & nb1,
|
|
99
|
+
constant uint64_t & nb2,
|
|
100
|
+
constant uint64_t & nb3,
|
|
101
|
+
constant int64_t & offs,
|
|
102
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
103
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
104
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
105
|
+
const int64_t i03 = tgpig.z;
|
|
106
|
+
const int64_t i02 = tgpig.y;
|
|
107
|
+
const int64_t i01 = tgpig.x;
|
|
108
|
+
|
|
109
|
+
const int64_t i13 = i03 % ne13;
|
|
110
|
+
const int64_t i12 = i02 % ne12;
|
|
111
|
+
const int64_t i11 = i01 % ne11;
|
|
112
|
+
|
|
113
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
|
114
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
115
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
|
116
|
+
|
|
117
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
118
|
+
const int i10 = i0 % ne10;
|
|
119
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
|
|
73
123
|
kernel void kernel_mul(
|
|
74
124
|
device const char * src0,
|
|
75
125
|
device const char * src1,
|
|
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
|
|
|
226
276
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
|
227
277
|
}
|
|
228
278
|
|
|
279
|
+
kernel void kernel_sub_row(
|
|
280
|
+
device const float4 * src0,
|
|
281
|
+
device const float4 * src1,
|
|
282
|
+
device float4 * dst,
|
|
283
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
284
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
285
|
+
dst[tpig] = src0[tpig] - src1[tpig % nb];
|
|
286
|
+
}
|
|
287
|
+
|
|
229
288
|
kernel void kernel_mul_row(
|
|
230
289
|
device const float4 * src0,
|
|
231
290
|
device const float4 * src1,
|
|
@@ -358,6 +417,27 @@ kernel void kernel_sqr(
|
|
|
358
417
|
dst[tpig] = src0[tpig] * src0[tpig];
|
|
359
418
|
}
|
|
360
419
|
|
|
420
|
+
kernel void kernel_sqrt(
|
|
421
|
+
device const float * src0,
|
|
422
|
+
device float * dst,
|
|
423
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
424
|
+
dst[tpig] = sqrt(src0[tpig]);
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
kernel void kernel_sin(
|
|
428
|
+
device const float * src0,
|
|
429
|
+
device float * dst,
|
|
430
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
431
|
+
dst[tpig] = sin(src0[tpig]);
|
|
432
|
+
}
|
|
433
|
+
|
|
434
|
+
kernel void kernel_cos(
|
|
435
|
+
device const float * src0,
|
|
436
|
+
device float * dst,
|
|
437
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
438
|
+
dst[tpig] = cos(src0[tpig]);
|
|
439
|
+
}
|
|
440
|
+
|
|
361
441
|
kernel void kernel_sum_rows(
|
|
362
442
|
device const float * src0,
|
|
363
443
|
device float * dst,
|
|
@@ -667,6 +747,127 @@ kernel void kernel_diag_mask_inf_8(
|
|
|
667
747
|
}
|
|
668
748
|
}
|
|
669
749
|
|
|
750
|
+
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
|
751
|
+
// TODO: optimize
|
|
752
|
+
kernel void kernel_ssm_conv_f32(
|
|
753
|
+
device const void * src0,
|
|
754
|
+
device const void * src1,
|
|
755
|
+
device float * dst,
|
|
756
|
+
constant int64_t & ne00,
|
|
757
|
+
constant int64_t & ne01,
|
|
758
|
+
constant int64_t & ne02,
|
|
759
|
+
constant uint64_t & nb00,
|
|
760
|
+
constant uint64_t & nb01,
|
|
761
|
+
constant uint64_t & nb02,
|
|
762
|
+
constant int64_t & ne10,
|
|
763
|
+
constant int64_t & ne11,
|
|
764
|
+
constant uint64_t & nb10,
|
|
765
|
+
constant uint64_t & nb11,
|
|
766
|
+
constant int64_t & ne0,
|
|
767
|
+
constant int64_t & ne1,
|
|
768
|
+
constant int64_t & ne2,
|
|
769
|
+
constant uint64_t & nb0,
|
|
770
|
+
constant uint64_t & nb1,
|
|
771
|
+
constant uint64_t & nb2,
|
|
772
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
773
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
774
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
775
|
+
const int64_t ir = tgpig.x;
|
|
776
|
+
const int64_t i2 = tgpig.y;
|
|
777
|
+
const int64_t i3 = tgpig.z;
|
|
778
|
+
|
|
779
|
+
const int64_t nc = ne10;
|
|
780
|
+
const int64_t ncs = ne00;
|
|
781
|
+
const int64_t nr = ne01;
|
|
782
|
+
const int64_t n_t = ne1;
|
|
783
|
+
const int64_t n_s = ne2;
|
|
784
|
+
|
|
785
|
+
device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
|
|
786
|
+
device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
|
|
787
|
+
device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
|
|
788
|
+
|
|
789
|
+
float sumf = 0.0f;
|
|
790
|
+
|
|
791
|
+
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
792
|
+
sumf += s[i0] * c[i0];
|
|
793
|
+
}
|
|
794
|
+
|
|
795
|
+
x[0] = sumf;
|
|
796
|
+
}
|
|
797
|
+
|
|
798
|
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
|
799
|
+
// TODO: optimize
|
|
800
|
+
kernel void kernel_ssm_scan_f32(
|
|
801
|
+
device const void * src0,
|
|
802
|
+
device const void * src1,
|
|
803
|
+
device const void * src2,
|
|
804
|
+
device const void * src3,
|
|
805
|
+
device const void * src4,
|
|
806
|
+
device const void * src5,
|
|
807
|
+
device float * dst,
|
|
808
|
+
constant int64_t & d_state,
|
|
809
|
+
constant int64_t & d_inner,
|
|
810
|
+
constant int64_t & n_seq_tokens,
|
|
811
|
+
constant int64_t & n_seqs,
|
|
812
|
+
constant uint64_t & nb00,
|
|
813
|
+
constant uint64_t & nb01,
|
|
814
|
+
constant uint64_t & nb02,
|
|
815
|
+
constant uint64_t & nb10,
|
|
816
|
+
constant uint64_t & nb11,
|
|
817
|
+
constant uint64_t & nb12,
|
|
818
|
+
constant uint64_t & nb13,
|
|
819
|
+
constant uint64_t & nb20,
|
|
820
|
+
constant uint64_t & nb21,
|
|
821
|
+
constant uint64_t & nb22,
|
|
822
|
+
constant uint64_t & nb30,
|
|
823
|
+
constant uint64_t & nb31,
|
|
824
|
+
constant uint64_t & nb40,
|
|
825
|
+
constant uint64_t & nb41,
|
|
826
|
+
constant uint64_t & nb42,
|
|
827
|
+
constant uint64_t & nb50,
|
|
828
|
+
constant uint64_t & nb51,
|
|
829
|
+
constant uint64_t & nb52,
|
|
830
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
831
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
832
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
833
|
+
const int64_t ir = tgpig.x;
|
|
834
|
+
const int64_t i3 = tgpig.y;
|
|
835
|
+
|
|
836
|
+
const int64_t nc = d_state;
|
|
837
|
+
const int64_t nr = d_inner;
|
|
838
|
+
const int64_t n_t = n_seq_tokens;
|
|
839
|
+
const int64_t n_s = n_seqs;
|
|
840
|
+
|
|
841
|
+
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
842
|
+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
|
|
843
|
+
device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
|
|
844
|
+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
|
|
845
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
|
|
846
|
+
device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
|
|
847
|
+
device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
|
|
848
|
+
device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
|
|
849
|
+
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
|
|
850
|
+
|
|
851
|
+
if (i2 > 0) {
|
|
852
|
+
s0 = s;
|
|
853
|
+
}
|
|
854
|
+
|
|
855
|
+
// i1 == 0
|
|
856
|
+
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
857
|
+
float x_dt = x[0] * dt_soft_plus;
|
|
858
|
+
float sumf = 0.0f;
|
|
859
|
+
|
|
860
|
+
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
861
|
+
int64_t i = i0;
|
|
862
|
+
float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
|
863
|
+
sumf += state * C[i0];
|
|
864
|
+
s[i] = state;
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
y[0] = sumf;
|
|
868
|
+
}
|
|
869
|
+
}
|
|
870
|
+
|
|
670
871
|
kernel void kernel_norm(
|
|
671
872
|
device const void * src0,
|
|
672
873
|
device float * dst,
|
|
@@ -1976,6 +2177,7 @@ typedef void (flash_attn_ext_f16_t)(
|
|
|
1976
2177
|
constant float & m0,
|
|
1977
2178
|
constant float & m1,
|
|
1978
2179
|
constant uint32_t & n_head_log2,
|
|
2180
|
+
constant float & logit_softcap,
|
|
1979
2181
|
threadgroup half * shared,
|
|
1980
2182
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1981
2183
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2014,6 +2216,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
2014
2216
|
constant float & m0,
|
|
2015
2217
|
constant float & m1,
|
|
2016
2218
|
constant uint32_t & n_head_log2,
|
|
2219
|
+
constant float & logit_softcap,
|
|
2017
2220
|
threadgroup half * shared [[threadgroup(0)]],
|
|
2018
2221
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2019
2222
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2138,19 +2341,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
2138
2341
|
}
|
|
2139
2342
|
|
|
2140
2343
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
2141
|
-
|
|
2142
|
-
const short tx = tiisg%4;
|
|
2143
|
-
const short ty = tiisg/4;
|
|
2144
|
-
|
|
2145
|
-
if (mask != q) {
|
|
2146
|
-
// mqk = mqk*scale + mask*slope
|
|
2147
|
-
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
|
2148
|
-
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
|
2149
|
-
} else {
|
|
2150
|
-
// mqk = mqk*scale
|
|
2151
|
-
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
|
2152
|
-
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
|
2153
|
-
}
|
|
2154
2344
|
}
|
|
2155
2345
|
}
|
|
2156
2346
|
|
|
@@ -2162,10 +2352,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
2162
2352
|
float ms[Q];
|
|
2163
2353
|
|
|
2164
2354
|
for (short j = 0; j < Q; ++j) {
|
|
2165
|
-
const short p = tiisg;
|
|
2166
|
-
|
|
2167
2355
|
const float m = M[j];
|
|
2168
|
-
|
|
2356
|
+
|
|
2357
|
+
// scale and apply the logitcap / mask
|
|
2358
|
+
float s = ss[j*TF + tiisg]*scale;
|
|
2359
|
+
|
|
2360
|
+
if (logit_softcap != 0.0f) {
|
|
2361
|
+
s = logit_softcap*precise::tanh(s);
|
|
2362
|
+
}
|
|
2363
|
+
|
|
2364
|
+
if (mask != q) {
|
|
2365
|
+
// mqk = mqk + mask*slope
|
|
2366
|
+
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
|
2367
|
+
}
|
|
2169
2368
|
|
|
2170
2369
|
smax = simd_max(max(smax, s));
|
|
2171
2370
|
M[j] = simd_max(max(M[j], s));
|
|
@@ -2176,7 +2375,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
2176
2375
|
S[j] = S[j]*ms[j] + simd_sum(vs);
|
|
2177
2376
|
|
|
2178
2377
|
// the P matrix from the paper (Q rows, C columns)
|
|
2179
|
-
ss[j*TF +
|
|
2378
|
+
ss[j*TF + tiisg] = vs;
|
|
2180
2379
|
}
|
|
2181
2380
|
|
|
2182
2381
|
// create a QxQ diagonal matrix for rescaling the output
|
|
@@ -2345,6 +2544,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
2345
2544
|
constant float & m0,
|
|
2346
2545
|
constant float & m1,
|
|
2347
2546
|
constant uint32_t & n_head_log2,
|
|
2547
|
+
constant float & logit_softcap,
|
|
2348
2548
|
threadgroup half * shared [[threadgroup(0)]],
|
|
2349
2549
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2350
2550
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2479,7 +2679,13 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
2479
2679
|
|
|
2480
2680
|
// mqk = mqk*scale + mask*slope
|
|
2481
2681
|
if (tiisg == 0) {
|
|
2482
|
-
mqk
|
|
2682
|
+
mqk *= scale;
|
|
2683
|
+
|
|
2684
|
+
if (logit_softcap != 0.0f) {
|
|
2685
|
+
mqk = logit_softcap*precise::tanh(mqk);
|
|
2686
|
+
}
|
|
2687
|
+
|
|
2688
|
+
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
|
|
2483
2689
|
|
|
2484
2690
|
ss4[cc] = mqk;
|
|
2485
2691
|
}
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
@@ -19,7 +19,11 @@ typedef half2 ggml_half2;
|
|
|
19
19
|
|
|
20
20
|
#define GGML_COMMON_DECL
|
|
21
21
|
#elif defined(GGML_COMMON_DECL_CUDA)
|
|
22
|
+
#if defined(GGML_COMMON_DECL_MUSA)
|
|
23
|
+
#include <musa_fp16.h>
|
|
24
|
+
#else
|
|
22
25
|
#include <cuda_fp16.h>
|
|
26
|
+
#endif
|
|
23
27
|
#include <cstdint>
|
|
24
28
|
|
|
25
29
|
typedef half ggml_half;
|
|
@@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
|
|
|
415
419
|
#define GGML_TABLE_END() };
|
|
416
420
|
|
|
417
421
|
#define GGML_COMMON_IMPL
|
|
418
|
-
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
|
|
422
|
+
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
|
|
419
423
|
#include <cstdint>
|
|
420
424
|
|
|
421
425
|
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
|
|
@@ -17,7 +17,7 @@ enum ggml_sort_order {
|
|
|
17
17
|
GGML_SORT_ORDER_DESC,
|
|
18
18
|
};
|
|
19
19
|
|
|
20
|
-
// general-purpose kernel for addition, multiplication and division of two tensors
|
|
20
|
+
// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
|
|
21
21
|
// pros: works for non-contiguous tensors, supports broadcast across all dims
|
|
22
22
|
// cons: not very efficient
|
|
23
23
|
kernel void kernel_add(
|
|
@@ -70,6 +70,56 @@ kernel void kernel_add(
|
|
|
70
70
|
}
|
|
71
71
|
}
|
|
72
72
|
|
|
73
|
+
kernel void kernel_sub(
|
|
74
|
+
device const char * src0,
|
|
75
|
+
device const char * src1,
|
|
76
|
+
device char * dst,
|
|
77
|
+
constant int64_t & ne00,
|
|
78
|
+
constant int64_t & ne01,
|
|
79
|
+
constant int64_t & ne02,
|
|
80
|
+
constant int64_t & ne03,
|
|
81
|
+
constant uint64_t & nb00,
|
|
82
|
+
constant uint64_t & nb01,
|
|
83
|
+
constant uint64_t & nb02,
|
|
84
|
+
constant uint64_t & nb03,
|
|
85
|
+
constant int64_t & ne10,
|
|
86
|
+
constant int64_t & ne11,
|
|
87
|
+
constant int64_t & ne12,
|
|
88
|
+
constant int64_t & ne13,
|
|
89
|
+
constant uint64_t & nb10,
|
|
90
|
+
constant uint64_t & nb11,
|
|
91
|
+
constant uint64_t & nb12,
|
|
92
|
+
constant uint64_t & nb13,
|
|
93
|
+
constant int64_t & ne0,
|
|
94
|
+
constant int64_t & ne1,
|
|
95
|
+
constant int64_t & ne2,
|
|
96
|
+
constant int64_t & ne3,
|
|
97
|
+
constant uint64_t & nb0,
|
|
98
|
+
constant uint64_t & nb1,
|
|
99
|
+
constant uint64_t & nb2,
|
|
100
|
+
constant uint64_t & nb3,
|
|
101
|
+
constant int64_t & offs,
|
|
102
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
103
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
104
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
105
|
+
const int64_t i03 = tgpig.z;
|
|
106
|
+
const int64_t i02 = tgpig.y;
|
|
107
|
+
const int64_t i01 = tgpig.x;
|
|
108
|
+
|
|
109
|
+
const int64_t i13 = i03 % ne13;
|
|
110
|
+
const int64_t i12 = i02 % ne12;
|
|
111
|
+
const int64_t i11 = i01 % ne11;
|
|
112
|
+
|
|
113
|
+
device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
|
|
114
|
+
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
|
|
115
|
+
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
|
|
116
|
+
|
|
117
|
+
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
|
|
118
|
+
const int i10 = i0 % ne10;
|
|
119
|
+
*((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
|
|
120
|
+
}
|
|
121
|
+
}
|
|
122
|
+
|
|
73
123
|
kernel void kernel_mul(
|
|
74
124
|
device const char * src0,
|
|
75
125
|
device const char * src1,
|
|
@@ -226,6 +276,15 @@ kernel void kernel_add_row(
|
|
|
226
276
|
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
|
227
277
|
}
|
|
228
278
|
|
|
279
|
+
kernel void kernel_sub_row(
|
|
280
|
+
device const float4 * src0,
|
|
281
|
+
device const float4 * src1,
|
|
282
|
+
device float4 * dst,
|
|
283
|
+
constant uint64_t & nb [[buffer(28)]],
|
|
284
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
285
|
+
dst[tpig] = src0[tpig] - src1[tpig % nb];
|
|
286
|
+
}
|
|
287
|
+
|
|
229
288
|
kernel void kernel_mul_row(
|
|
230
289
|
device const float4 * src0,
|
|
231
290
|
device const float4 * src1,
|
|
@@ -358,6 +417,27 @@ kernel void kernel_sqr(
|
|
|
358
417
|
dst[tpig] = src0[tpig] * src0[tpig];
|
|
359
418
|
}
|
|
360
419
|
|
|
420
|
+
kernel void kernel_sqrt(
|
|
421
|
+
device const float * src0,
|
|
422
|
+
device float * dst,
|
|
423
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
424
|
+
dst[tpig] = sqrt(src0[tpig]);
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
kernel void kernel_sin(
|
|
428
|
+
device const float * src0,
|
|
429
|
+
device float * dst,
|
|
430
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
431
|
+
dst[tpig] = sin(src0[tpig]);
|
|
432
|
+
}
|
|
433
|
+
|
|
434
|
+
kernel void kernel_cos(
|
|
435
|
+
device const float * src0,
|
|
436
|
+
device float * dst,
|
|
437
|
+
uint tpig[[thread_position_in_grid]]) {
|
|
438
|
+
dst[tpig] = cos(src0[tpig]);
|
|
439
|
+
}
|
|
440
|
+
|
|
361
441
|
kernel void kernel_sum_rows(
|
|
362
442
|
device const float * src0,
|
|
363
443
|
device float * dst,
|
|
@@ -667,6 +747,127 @@ kernel void kernel_diag_mask_inf_8(
|
|
|
667
747
|
}
|
|
668
748
|
}
|
|
669
749
|
|
|
750
|
+
// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
|
|
751
|
+
// TODO: optimize
|
|
752
|
+
kernel void kernel_ssm_conv_f32(
|
|
753
|
+
device const void * src0,
|
|
754
|
+
device const void * src1,
|
|
755
|
+
device float * dst,
|
|
756
|
+
constant int64_t & ne00,
|
|
757
|
+
constant int64_t & ne01,
|
|
758
|
+
constant int64_t & ne02,
|
|
759
|
+
constant uint64_t & nb00,
|
|
760
|
+
constant uint64_t & nb01,
|
|
761
|
+
constant uint64_t & nb02,
|
|
762
|
+
constant int64_t & ne10,
|
|
763
|
+
constant int64_t & ne11,
|
|
764
|
+
constant uint64_t & nb10,
|
|
765
|
+
constant uint64_t & nb11,
|
|
766
|
+
constant int64_t & ne0,
|
|
767
|
+
constant int64_t & ne1,
|
|
768
|
+
constant int64_t & ne2,
|
|
769
|
+
constant uint64_t & nb0,
|
|
770
|
+
constant uint64_t & nb1,
|
|
771
|
+
constant uint64_t & nb2,
|
|
772
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
773
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
774
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
775
|
+
const int64_t ir = tgpig.x;
|
|
776
|
+
const int64_t i2 = tgpig.y;
|
|
777
|
+
const int64_t i3 = tgpig.z;
|
|
778
|
+
|
|
779
|
+
const int64_t nc = ne10;
|
|
780
|
+
const int64_t ncs = ne00;
|
|
781
|
+
const int64_t nr = ne01;
|
|
782
|
+
const int64_t n_t = ne1;
|
|
783
|
+
const int64_t n_s = ne2;
|
|
784
|
+
|
|
785
|
+
device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
|
|
786
|
+
device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
|
|
787
|
+
device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
|
|
788
|
+
|
|
789
|
+
float sumf = 0.0f;
|
|
790
|
+
|
|
791
|
+
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
792
|
+
sumf += s[i0] * c[i0];
|
|
793
|
+
}
|
|
794
|
+
|
|
795
|
+
x[0] = sumf;
|
|
796
|
+
}
|
|
797
|
+
|
|
798
|
+
// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
|
|
799
|
+
// TODO: optimize
|
|
800
|
+
kernel void kernel_ssm_scan_f32(
|
|
801
|
+
device const void * src0,
|
|
802
|
+
device const void * src1,
|
|
803
|
+
device const void * src2,
|
|
804
|
+
device const void * src3,
|
|
805
|
+
device const void * src4,
|
|
806
|
+
device const void * src5,
|
|
807
|
+
device float * dst,
|
|
808
|
+
constant int64_t & d_state,
|
|
809
|
+
constant int64_t & d_inner,
|
|
810
|
+
constant int64_t & n_seq_tokens,
|
|
811
|
+
constant int64_t & n_seqs,
|
|
812
|
+
constant uint64_t & nb00,
|
|
813
|
+
constant uint64_t & nb01,
|
|
814
|
+
constant uint64_t & nb02,
|
|
815
|
+
constant uint64_t & nb10,
|
|
816
|
+
constant uint64_t & nb11,
|
|
817
|
+
constant uint64_t & nb12,
|
|
818
|
+
constant uint64_t & nb13,
|
|
819
|
+
constant uint64_t & nb20,
|
|
820
|
+
constant uint64_t & nb21,
|
|
821
|
+
constant uint64_t & nb22,
|
|
822
|
+
constant uint64_t & nb30,
|
|
823
|
+
constant uint64_t & nb31,
|
|
824
|
+
constant uint64_t & nb40,
|
|
825
|
+
constant uint64_t & nb41,
|
|
826
|
+
constant uint64_t & nb42,
|
|
827
|
+
constant uint64_t & nb50,
|
|
828
|
+
constant uint64_t & nb51,
|
|
829
|
+
constant uint64_t & nb52,
|
|
830
|
+
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
831
|
+
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
832
|
+
uint3 ntg[[threads_per_threadgroup]]) {
|
|
833
|
+
const int64_t ir = tgpig.x;
|
|
834
|
+
const int64_t i3 = tgpig.y;
|
|
835
|
+
|
|
836
|
+
const int64_t nc = d_state;
|
|
837
|
+
const int64_t nr = d_inner;
|
|
838
|
+
const int64_t n_t = n_seq_tokens;
|
|
839
|
+
const int64_t n_s = n_seqs;
|
|
840
|
+
|
|
841
|
+
for (int64_t i2 = 0; i2 < n_t; ++i2) {
|
|
842
|
+
device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
|
|
843
|
+
device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
|
|
844
|
+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
|
|
845
|
+
device const float * A = (device const float *) ((device const char *) src3 + ir*nb31);
|
|
846
|
+
device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
|
|
847
|
+
device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
|
|
848
|
+
device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
|
|
849
|
+
device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
|
|
850
|
+
|
|
851
|
+
if (i2 > 0) {
|
|
852
|
+
s0 = s;
|
|
853
|
+
}
|
|
854
|
+
|
|
855
|
+
// i1 == 0
|
|
856
|
+
float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
|
|
857
|
+
float x_dt = x[0] * dt_soft_plus;
|
|
858
|
+
float sumf = 0.0f;
|
|
859
|
+
|
|
860
|
+
for (int64_t i0 = 0; i0 < nc; ++i0) {
|
|
861
|
+
int64_t i = i0;
|
|
862
|
+
float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
|
|
863
|
+
sumf += state * C[i0];
|
|
864
|
+
s[i] = state;
|
|
865
|
+
}
|
|
866
|
+
|
|
867
|
+
y[0] = sumf;
|
|
868
|
+
}
|
|
869
|
+
}
|
|
870
|
+
|
|
670
871
|
kernel void kernel_norm(
|
|
671
872
|
device const void * src0,
|
|
672
873
|
device float * dst,
|
|
@@ -1976,6 +2177,7 @@ typedef void (flash_attn_ext_f16_t)(
|
|
|
1976
2177
|
constant float & m0,
|
|
1977
2178
|
constant float & m1,
|
|
1978
2179
|
constant uint32_t & n_head_log2,
|
|
2180
|
+
constant float & logit_softcap,
|
|
1979
2181
|
threadgroup half * shared,
|
|
1980
2182
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
1981
2183
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2014,6 +2216,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
2014
2216
|
constant float & m0,
|
|
2015
2217
|
constant float & m1,
|
|
2016
2218
|
constant uint32_t & n_head_log2,
|
|
2219
|
+
constant float & logit_softcap,
|
|
2017
2220
|
threadgroup half * shared [[threadgroup(0)]],
|
|
2018
2221
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2019
2222
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2138,19 +2341,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
2138
2341
|
}
|
|
2139
2342
|
|
|
2140
2343
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
|
2141
|
-
|
|
2142
|
-
const short tx = tiisg%4;
|
|
2143
|
-
const short ty = tiisg/4;
|
|
2144
|
-
|
|
2145
|
-
if (mask != q) {
|
|
2146
|
-
// mqk = mqk*scale + mask*slope
|
|
2147
|
-
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
|
2148
|
-
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
|
2149
|
-
} else {
|
|
2150
|
-
// mqk = mqk*scale
|
|
2151
|
-
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
|
2152
|
-
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
|
2153
|
-
}
|
|
2154
2344
|
}
|
|
2155
2345
|
}
|
|
2156
2346
|
|
|
@@ -2162,10 +2352,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
2162
2352
|
float ms[Q];
|
|
2163
2353
|
|
|
2164
2354
|
for (short j = 0; j < Q; ++j) {
|
|
2165
|
-
const short p = tiisg;
|
|
2166
|
-
|
|
2167
2355
|
const float m = M[j];
|
|
2168
|
-
|
|
2356
|
+
|
|
2357
|
+
// scale and apply the logitcap / mask
|
|
2358
|
+
float s = ss[j*TF + tiisg]*scale;
|
|
2359
|
+
|
|
2360
|
+
if (logit_softcap != 0.0f) {
|
|
2361
|
+
s = logit_softcap*precise::tanh(s);
|
|
2362
|
+
}
|
|
2363
|
+
|
|
2364
|
+
if (mask != q) {
|
|
2365
|
+
// mqk = mqk + mask*slope
|
|
2366
|
+
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
|
2367
|
+
}
|
|
2169
2368
|
|
|
2170
2369
|
smax = simd_max(max(smax, s));
|
|
2171
2370
|
M[j] = simd_max(max(M[j], s));
|
|
@@ -2176,7 +2375,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|
|
2176
2375
|
S[j] = S[j]*ms[j] + simd_sum(vs);
|
|
2177
2376
|
|
|
2178
2377
|
// the P matrix from the paper (Q rows, C columns)
|
|
2179
|
-
ss[j*TF +
|
|
2378
|
+
ss[j*TF + tiisg] = vs;
|
|
2180
2379
|
}
|
|
2181
2380
|
|
|
2182
2381
|
// create a QxQ diagonal matrix for rescaling the output
|
|
@@ -2345,6 +2544,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
2345
2544
|
constant float & m0,
|
|
2346
2545
|
constant float & m1,
|
|
2347
2546
|
constant uint32_t & n_head_log2,
|
|
2547
|
+
constant float & logit_softcap,
|
|
2348
2548
|
threadgroup half * shared [[threadgroup(0)]],
|
|
2349
2549
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
2350
2550
|
uint3 tpitg[[thread_position_in_threadgroup]],
|
|
@@ -2479,7 +2679,13 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|
|
2479
2679
|
|
|
2480
2680
|
// mqk = mqk*scale + mask*slope
|
|
2481
2681
|
if (tiisg == 0) {
|
|
2482
|
-
mqk
|
|
2682
|
+
mqk *= scale;
|
|
2683
|
+
|
|
2684
|
+
if (logit_softcap != 0.0f) {
|
|
2685
|
+
mqk = logit_softcap*precise::tanh(mqk);
|
|
2686
|
+
}
|
|
2687
|
+
|
|
2688
|
+
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
|
|
2483
2689
|
|
|
2484
2690
|
ss4[cc] = mqk;
|
|
2485
2691
|
}
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
|
Binary file
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "node-llama-cpp",
|
|
3
|
-
"version": "2.8.
|
|
3
|
+
"version": "2.8.16",
|
|
4
4
|
"description": "Run AI models locally on your machine with node.js bindings for llama.cpp. Force a JSON schema on the model output on the generation level",
|
|
5
5
|
"main": "dist/index.js",
|
|
6
6
|
"type": "module",
|
|
@@ -97,7 +97,7 @@
|
|
|
97
97
|
"bugs": {
|
|
98
98
|
"url": "https://github.com/withcatai/node-llama-cpp/issues"
|
|
99
99
|
},
|
|
100
|
-
"homepage": "https://
|
|
100
|
+
"homepage": "https://node-llama-cpp.withcat.ai",
|
|
101
101
|
"devDependencies": {
|
|
102
102
|
"@commitlint/cli": "^17.7.1",
|
|
103
103
|
"@commitlint/config-conventional": "^17.7.0",
|