@fugood/llama.node 1.1.10 → 1.2.0-rc.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +20 -2
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +174 -388
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +67 -37
  27. package/src/llama.cpp/common/chat.cpp +263 -2
  28. package/src/llama.cpp/common/chat.h +4 -0
  29. package/src/llama.cpp/common/common.cpp +10 -3
  30. package/src/llama.cpp/common/common.h +5 -2
  31. package/src/llama.cpp/common/log.cpp +53 -2
  32. package/src/llama.cpp/common/log.h +10 -4
  33. package/src/llama.cpp/common/sampling.cpp +23 -2
  34. package/src/llama.cpp/common/sampling.h +3 -1
  35. package/src/llama.cpp/common/speculative.cpp +1 -1
  36. package/src/llama.cpp/ggml/CMakeLists.txt +4 -3
  37. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -0
  38. package/src/llama.cpp/ggml/include/ggml-cpu.h +0 -1
  39. package/src/llama.cpp/ggml/include/ggml.h +50 -1
  40. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +19 -16
  41. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  42. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +1 -7
  43. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +11 -37
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +3 -4
  45. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +43 -6
  46. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +4 -1
  47. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +18 -18
  48. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +232 -123
  49. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +234 -16
  50. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  51. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +80 -51
  52. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +161 -20
  53. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +399 -50
  54. package/src/llama.cpp/include/llama.h +32 -7
  55. package/src/llama.cpp/src/llama-adapter.cpp +101 -4
  56. package/src/llama.cpp/src/llama-adapter.h +6 -0
  57. package/src/llama.cpp/src/llama-arch.cpp +69 -2
  58. package/src/llama.cpp/src/llama-arch.h +6 -0
  59. package/src/llama.cpp/src/llama-context.cpp +92 -45
  60. package/src/llama.cpp/src/llama-context.h +1 -5
  61. package/src/llama.cpp/src/llama-graph.cpp +74 -19
  62. package/src/llama.cpp/src/llama-graph.h +10 -1
  63. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  64. package/src/llama.cpp/src/llama-hparams.h +9 -3
  65. package/src/llama.cpp/src/llama-impl.h +2 -0
  66. package/src/llama.cpp/src/llama-kv-cache.cpp +33 -120
  67. package/src/llama.cpp/src/llama-kv-cache.h +4 -13
  68. package/src/llama.cpp/src/llama-model-loader.cpp +1 -0
  69. package/src/llama.cpp/src/llama-model.cpp +434 -21
  70. package/src/llama.cpp/src/llama-model.h +1 -1
  71. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  72. package/src/llama.cpp/src/llama-vocab.cpp +1 -1
  73. package/src/llama.cpp/src/llama.cpp +12 -0
  74. package/src/anyascii.c +0 -22223
  75. package/src/anyascii.h +0 -42
  76. package/src/tts_utils.cpp +0 -371
  77. package/src/tts_utils.h +0 -103
package/CMakeLists.txt CHANGED
@@ -137,10 +137,6 @@ file(
    "src/LlamaCompletionWorker.h"
    "src/LlamaContext.cpp"
    "src/LlamaContext.h"
-   "src/TokenizeWorker.cpp"
-   "src/TokenizeWorker.h"
-   "src/DetokenizeWorker.cpp"
-   "src/DetokenizeWorker.h"
    "src/EmbeddingWorker.cpp"
    "src/EmbeddingWorker.h"
    "src/RerankWorker.cpp"
@@ -149,12 +145,13 @@ file(
    "src/LoadSessionWorker.h"
    "src/SaveSessionWorker.cpp"
    "src/SaveSessionWorker.h"
+   "src/TokenizeWorker.cpp"
+   "src/TokenizeWorker.h"
+   "src/DetokenizeWorker.cpp"
+   "src/DetokenizeWorker.h"
    "src/DecodeAudioTokenWorker.cpp"
    "src/DecodeAudioTokenWorker.h"
-   "src/tts_utils.cpp"
-   "src/tts_utils.h"
-   "src/anyascii.h"
-   "src/anyascii.c"
+   "src/rn-llama/*"
  )
 
  if (NOT MSVC AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
package/lib/binding.ts CHANGED
@@ -27,7 +27,8 @@ export type LlamaModelOptions = {
    n_ubatch?: number
    n_threads?: number
    n_gpu_layers?: number
-   flash_attn?: boolean
+   flash_attn_type?: 'auto' | 'on' | 'off'
+   flash_attn?: boolean // Deprecated: use flash_attn_type instead
    cache_type_k?:
      | 'f16'
      | 'f32'
@@ -149,6 +150,21 @@ export type LlamaCompletionOptions = {
     * Help prevent hallucinations by forcing the TTS to use the correct words.
     */
    guide_tokens?: number[] | Int32Array
+   /**
+    * Number of top token probabilities to return for each generated token.
+    * When > 0, completion_probabilities will be included in streaming callbacks and final result.
+    */
+   n_probs?: number
+ }
+
+ export type TokenProbability = {
+   tok_str: string
+   prob: number
+ }
+
+ export type CompletionProbability = {
+   content: string
+   probs: TokenProbability[]
  }
 
  export type LlamaCompletionResult = {
@@ -162,6 +178,7 @@ export type LlamaCompletionResult = {
    context_full: boolean
    interrupted: boolean
    audio_tokens?: Array<number>
+   completion_probabilities?: CompletionProbability[]
    timings: {
      prompt_n: number
      prompt_ms: number
@@ -180,6 +197,7 @@ export type LlamaCompletionToken = {
    reasoning_content?: string
    tool_calls?: ToolCall[]
    accumulated_text?: string
+   completion_probabilities?: CompletionProbability[]
  }
 
  export type TokenizeResult = {
@@ -308,7 +326,7 @@ export interface LlamaContext {
    stopCompletion(): void
    tokenize(text: string, media_paths?: string[]): Promise<TokenizeResult>
    detokenize(tokens: number[]): Promise<string>
-   embedding(text: string): Promise<EmbeddingResult>
+   embedding(text: string, params?: { embd_normalize?: number }): Promise<EmbeddingResult>
    rerank(query: string, documents: string[], params?: RerankParams): Promise<RerankResult[]>
    saveSession(path: string): Promise<void>
    loadSession(path: string): Promise<void>
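
Note: the flash_attn_type option and the n_probs / completion_probabilities fields above are the user-facing part of this release. A minimal usage sketch follows; it assumes the package's existing loadModel() export, the completion(options, callback) method, and a `model` path option, none of which are shown in this diff — only the new option and field names come from the types above:

// Sketch only: loadModel()/completion() and the `model` option are assumed
// from the existing @fugood/llama.node API; the new fields come from
// lib/binding.ts in this diff.
import { loadModel } from '@fugood/llama.node'

async function main() {
  const context = await loadModel({
    model: 'path/to/model.gguf',
    flash_attn_type: 'auto', // replaces the deprecated boolean flash_attn
  })

  const result = await context.completion(
    {
      prompt: 'The capital of France is',
      n_predict: 8,
      n_probs: 4, // request top-4 token probabilities per generated token
    },
    (token) => {
      // When n_probs > 0, streamed tokens may carry probability info
      token.completion_probabilities?.forEach((p) => {
        const alts = p.probs.map((t) => `${t.tok_str}=${t.prob.toFixed(3)}`).join(' ')
        console.log(p.content, '->', alts)
      })
    },
  )

  // The final result also carries the accumulated probabilities
  console.log(result.completion_probabilities)
}

main()
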
package/lib/index.js CHANGED
@@ -180,8 +180,8 @@ class LlamaContextWrapper {
    detokenize(tokens) {
      return this.ctx.detokenize(tokens);
    }
-   embedding(text) {
-     return this.ctx.embedding(text);
+   embedding(text, params) {
+     return this.ctx.embedding(text, params);
    }
    rerank(query, documents, params) {
      return this.ctx
package/lib/index.ts CHANGED
@@ -251,8 +251,8 @@ class LlamaContextWrapper {
      return this.ctx.detokenize(tokens)
    }
 
-   embedding(text: string): Promise<EmbeddingResult> {
-     return this.ctx.embedding(text)
+   embedding(text: string, params?: { embd_normalize?: number }): Promise<EmbeddingResult> {
+     return this.ctx.embedding(text, params)
    }
 
    rerank(
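
The embedding() change above threads an optional params object from the JS wrapper down to the native binding. A short caller-side sketch follows; assumptions: the LlamaContext type is importable from the package root, EmbeddingResult exposes an `embedding: number[]` field, and embd_normalize follows llama.cpp's common_params convention (2 = Euclidean/L2, -1 = no normalization):

// Sketch: the import path and EmbeddingResult shape are assumptions;
// only the embedding(text, { embd_normalize }) signature comes from this diff.
import type { LlamaContext } from '@fugood/llama.node'

async function embedL2(ctx: LlamaContext, text: string): Promise<number[]> {
  // 2 = Euclidean (L2) normalization per llama.cpp's embd_normalize convention;
  // pass -1 to get the raw, unnormalized embedding instead.
  const { embedding } = await ctx.embedding(text, { embd_normalize: 2 })
  return embedding
}
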
package/package.json CHANGED
@@ -1,11 +1,12 @@
  {
    "name": "@fugood/llama.node",
    "access": "public",
-   "version": "1.1.10",
+   "version": "1.2.0-rc.0",
    "description": "An another Node binding of llama.cpp",
    "main": "lib/index.js",
    "scripts": {
-     "bootstrap": "npm install --omit=optional",
+     "copy-rn-llama-source": "node scripts/copy-rn-llama-source.js",
+     "bootstrap": "npm run copy-rn-llama-source && npm install --omit=optional",
      "postinstall": "node scripts/check.js",
      "pretest": "node scripts/download-test-models.js",
      "test": "jest",
@@ -53,6 +54,7 @@
      "scripts/check.js",
      "scripts/llama.cpp.patch",
      "src/*.{cc,c,h,hpp}",
+     "src/rn-llama/*",
      "src/DecodeAudioTokenWorker.cpp",
      "src/DetokenizeWorker.cpp",
      "src/DisposeWorker.cpp",
@@ -62,7 +64,6 @@
      "src/LoadSessionWorker.cpp",
      "src/SaveSessionWorker.cpp",
      "src/TokenizeWorker.cpp",
-     "src/tts_utils.cpp",
      "src/llama.cpp/{common,src,include}/**/*.{h,hpp,cpp,cc,c}",
      "src/llama.cpp/ggml/include/*.h",
      "src/llama.cpp/ggml/src/ggml-cpu/**/*.{h,hpp,cpp,cc,c}",
@@ -71,19 +72,19 @@
      "CMakeLists.txt"
    ],
    "optionalDependencies": {
-     "@fugood/node-llama-linux-x64": "1.1.10",
-     "@fugood/node-llama-linux-x64-vulkan": "1.1.10",
-     "@fugood/node-llama-linux-x64-cuda": "1.1.10",
-     "@fugood/node-llama-linux-arm64": "1.1.10",
-     "@fugood/node-llama-linux-arm64-vulkan": "1.1.10",
-     "@fugood/node-llama-linux-arm64-cuda": "1.1.10",
-     "@fugood/node-llama-win32-x64": "1.1.10",
-     "@fugood/node-llama-win32-x64-vulkan": "1.1.10",
-     "@fugood/node-llama-win32-x64-cuda": "1.1.10",
-     "@fugood/node-llama-win32-arm64": "1.1.10",
-     "@fugood/node-llama-win32-arm64-vulkan": "1.1.10",
-     "@fugood/node-llama-darwin-x64": "1.1.10",
-     "@fugood/node-llama-darwin-arm64": "1.1.10"
+     "@fugood/node-llama-linux-x64": "1.2.0-rc.0",
+     "@fugood/node-llama-linux-x64-vulkan": "1.2.0-rc.0",
+     "@fugood/node-llama-linux-x64-cuda": "1.2.0-rc.0",
+     "@fugood/node-llama-linux-arm64": "1.2.0-rc.0",
+     "@fugood/node-llama-linux-arm64-vulkan": "1.2.0-rc.0",
+     "@fugood/node-llama-linux-arm64-cuda": "1.2.0-rc.0",
+     "@fugood/node-llama-win32-x64": "1.2.0-rc.0",
+     "@fugood/node-llama-win32-x64-vulkan": "1.2.0-rc.0",
+     "@fugood/node-llama-win32-x64-cuda": "1.2.0-rc.0",
+     "@fugood/node-llama-win32-arm64": "1.2.0-rc.0",
+     "@fugood/node-llama-win32-arm64-vulkan": "1.2.0-rc.0",
+     "@fugood/node-llama-darwin-x64": "1.2.0-rc.0",
+     "@fugood/node-llama-darwin-arm64": "1.2.0-rc.0"
    },
    "devDependencies": {
      "@babel/preset-env": "^7.24.4",
@@ -115,6 +116,9 @@
      ],
      "testMatch": [
        "**/*.test.ts"
+     ],
+     "testPathIgnorePatterns": [
+       "<rootDir>/src/llama.rn/"
      ]
    },
    "prettier": {
package/src/DecodeAudioTokenWorker.cpp CHANGED
@@ -1,40 +1,37 @@
  #include "DecodeAudioTokenWorker.h"
- #include "tts_utils.h"
- #include <vector>
+ #include "LlamaContext.h"
 
- DecodeAudioTokenWorker::DecodeAudioTokenWorker(
-     const Napi::CallbackInfo &info, llama_model *model, llama_context *ctx,
-     int n_threads, const std::vector<llama_token> &tokens)
-     : AsyncWorker(info.Env()), Deferred(info.Env()), _model(model), _ctx(ctx),
-       _n_threads(n_threads), _tokens(tokens) {}
+ DecodeAudioTokenWorker::DecodeAudioTokenWorker(const Napi::CallbackInfo &info,
+     rnllama::llama_rn_context* rn_ctx, std::vector<int32_t> tokens)
+     : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx), _tokens(tokens) {}
 
  void DecodeAudioTokenWorker::Execute() {
-   const int n_codes = _tokens.size();
-   llama_batch batch = llama_batch_init(n_codes, 0, 1);
-   for (size_t i = 0; i < _tokens.size(); ++i) {
-     common_batch_add(batch, _tokens[i], i, {0}, true);
+   try {
+     if (!_rn_ctx->tts_wrapper) {
+       SetError("Vocoder not initialized");
+       return;
+     }
+
+     // Convert to llama_token vector - rn-tts handles token adjustment internally
+     std::vector<llama_token> llama_tokens;
+     for (const auto& token : _tokens) {
+       llama_tokens.push_back(token);
+     }
+
+     // Use the rn-tts API instead of directly accessing the worker
+     _result = _rn_ctx->tts_wrapper->decodeAudioTokens(_rn_ctx, llama_tokens);
+   } catch (const std::exception &e) {
+     SetError(e.what());
    }
-   if (batch.n_tokens != n_codes) {
-     SetError("batch.n_tokens != n_codes");
-     return;
-   }
-   if (llama_encode(_ctx, batch) != 0) {
-     SetError("llama_encode() failed");
-     return;
-   }
-   llama_synchronize(_ctx);
-   const int n_embd = llama_model_n_embd(_model);
-   const float *embd = llama_get_embeddings(_ctx);
-   _result = embd_to_audio(embd, n_codes, n_embd, _n_threads);
  }
 
  void DecodeAudioTokenWorker::OnOK() {
-   auto result =
-       Napi::Float32Array::New(Napi::AsyncWorker::Env(), _result.size());
+   // Create Float32Array and copy the data
+   auto result = Napi::Float32Array::New(Napi::AsyncWorker::Env(), _result.size());
    memcpy(result.Data(), _result.data(), _result.size() * sizeof(float));
    Napi::Promise::Deferred::Resolve(result);
  }
 
  void DecodeAudioTokenWorker::OnError(const Napi::Error &err) {
    Napi::Promise::Deferred::Reject(err.Value());
- }
+ }
package/src/DecodeAudioTokenWorker.h CHANGED
@@ -1,12 +1,12 @@
  #include "common.hpp"
+ #include "rn-llama/rn-llama.h"
  #include <vector>
 
  class DecodeAudioTokenWorker : public Napi::AsyncWorker,
                                 public Napi::Promise::Deferred {
  public:
-   DecodeAudioTokenWorker(const Napi::CallbackInfo &info, llama_model *model,
-                          llama_context *ctx, int n_threads,
-                          const std::vector<llama_token> &tokens);
+   DecodeAudioTokenWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx,
+                          std::vector<int32_t> tokens);
 
  protected:
    void Execute();
@@ -14,9 +14,7 @@ protected:
    void OnError(const Napi::Error &err);
 
  private:
-   llama_model *_model;
-   llama_context *_ctx;
-   int _n_threads;
-   std::vector<llama_token> _tokens;
+   rnllama::llama_rn_context* _rn_ctx;
+   std::vector<int32_t> _tokens;
    std::vector<float> _result;
- };
+ };
package/src/DetokenizeWorker.cpp CHANGED
@@ -2,21 +2,18 @@
  #include "LlamaContext.h"
 
  DetokenizeWorker::DetokenizeWorker(const Napi::CallbackInfo &info,
-                                    LlamaSessionPtr &sess,
-                                    std::vector<llama_token> &tokens)
-     : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
-       _tokens(std::move(tokens)) {}
+                                    rnllama::llama_rn_context* rn_ctx, std::vector<int32_t> tokens)
+     : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx), _tokens(tokens) {}
 
  void DetokenizeWorker::Execute() {
-   const auto text = ::common_detokenize(_sess->context(), _tokens);
+   const auto text = tokens_to_str(_rn_ctx->ctx, _tokens.begin(), _tokens.end());
    _text = std::move(text);
  }
 
  void DetokenizeWorker::OnOK() {
-   Napi::Promise::Deferred::Resolve(
-       Napi::String::New(Napi::AsyncWorker::Env(), _text));
+   Napi::Promise::Deferred::Resolve(Napi::String::New(Napi::AsyncWorker::Env(), _text));
  }
 
  void DetokenizeWorker::OnError(const Napi::Error &err) {
    Napi::Promise::Deferred::Reject(err.Value());
- }
+ }
package/src/DetokenizeWorker.h CHANGED
@@ -1,11 +1,12 @@
  #include "common.hpp"
+ #include "rn-llama/rn-llama.h"
  #include <vector>
 
  class DetokenizeWorker : public Napi::AsyncWorker,
                           public Napi::Promise::Deferred {
  public:
-   DetokenizeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
-                    std::vector<llama_token> &tokens);
+   DetokenizeWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx,
+                    std::vector<int32_t> tokens);
 
  protected:
    void Execute();
@@ -13,7 +14,7 @@ protected:
    void OnError(const Napi::Error &err);
 
  private:
-   LlamaSessionPtr _sess;
-   std::vector<llama_token> _tokens;
+   rnllama::llama_rn_context* _rn_ctx;
+   std::vector<int32_t> _tokens;
    std::string _text;
- };
+ };
package/src/DisposeWorker.cpp CHANGED
@@ -1,10 +1,30 @@
  #include "DisposeWorker.h"
+ #include "rn-llama/rn-completion.h"
 
  DisposeWorker::DisposeWorker(const Napi::CallbackInfo &info,
-                              LlamaSessionPtr sess)
-     : AsyncWorker(info.Env()), Deferred(info.Env()), sess_(std::move(sess)) {}
+                              rnllama::llama_rn_context* rn_ctx, rnllama::llama_rn_context** parent_ptr)
+     : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx), _parent_ptr(parent_ptr) {}
 
- void DisposeWorker::Execute() { sess_->dispose(); }
+ void DisposeWorker::Execute() {
+   if (_rn_ctx) {
+     // Ensure all child contexts are properly cleaned up first
+     try {
+       // Now delete the main context
+       delete _rn_ctx;
+
+       // Set parent pointer to nullptr to prevent double free
+       if (_parent_ptr) {
+         *_parent_ptr = nullptr;
+       }
+     } catch (const std::exception& e) {
+       SetError(std::string("Error during context disposal: ") + e.what());
+       return;
+     } catch (...) {
+       SetError("Unknown error during context disposal");
+       return;
+     }
+   }
+ }
 
  void DisposeWorker::OnOK() { Resolve(AsyncWorker::Env().Undefined()); }
 
package/src/DisposeWorker.h CHANGED
@@ -1,8 +1,9 @@
  #include "common.hpp"
+ #include "rn-llama/rn-llama.h"
 
  class DisposeWorker : public Napi::AsyncWorker, public Napi::Promise::Deferred {
  public:
-   DisposeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr sess);
+   DisposeWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx, rnllama::llama_rn_context** parent_ptr);
 
  protected:
    void Execute();
@@ -10,5 +11,6 @@ protected:
    void OnError(const Napi::Error &err);
 
  private:
-   LlamaSessionPtr sess_;
+   rnllama::llama_rn_context* _rn_ctx;
+   rnllama::llama_rn_context** _parent_ptr; // Pointer to the parent's _rn_ctx pointer
  };
package/src/EmbeddingWorker.cpp CHANGED
@@ -2,46 +2,20 @@
  #include "LlamaContext.h"
 
  EmbeddingWorker::EmbeddingWorker(const Napi::CallbackInfo &info,
-                                  LlamaSessionPtr &sess, std::string text,
+                                  rnllama::llama_rn_context* rn_ctx, std::string text,
                                   common_params &params)
-     : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text),
+     : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx), _text(text),
        _params(params) {}
 
  void EmbeddingWorker::Execute() {
-   llama_memory_clear(llama_get_memory(_sess->context()), true);
-   auto tokens = ::common_tokenize(_sess->context(), _text, true);
-   // add SEP if not present
-   auto vocab = llama_model_get_vocab(_sess->model());
-   if (tokens.empty() || tokens.back() != llama_vocab_sep(vocab)) {
-     tokens.push_back(llama_vocab_sep(vocab));
-   }
-   const int n_embd = llama_model_n_embd(_sess->model());
-   do {
-     auto ctx = _sess->context();
-     int ret =
-         llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
-     if (ret < 0) {
-       SetError("Failed to inference, code: " + std::to_string(ret));
-       break;
-     }
+   try {
+     _rn_ctx->params.prompt = _text;
+     _rn_ctx->params.n_predict = 0;
 
-     float *embd;
-     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-     if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-       embd = llama_get_embeddings(ctx);
-     } else {
-       embd = llama_get_embeddings_seq(ctx, 0);
-     }
-     if (embd == nullptr) {
-       SetError("Failed to get embeddings");
-       break;
-     }
-     _result.embedding.resize(n_embd);
-     std::vector<float> embedding(embd, embd + n_embd), out(embd, embd + n_embd);
-     common_embd_normalize(embedding.data(), out.data(), n_embd,
-                           _params.embd_normalize);
-     memcpy(_result.embedding.data(), out.data(), n_embd * sizeof(float));
-   } while (false);
+     _result.embedding = _rn_ctx->completion->embedding(_params);
+   } catch (const std::exception &e) {
+     SetError(e.what());
+   }
  }
 
  void EmbeddingWorker::OnOK() {
package/src/EmbeddingWorker.h CHANGED
@@ -1,4 +1,5 @@
  #include "common.hpp"
+ #include "rn-llama/rn-llama.h"
  #include <vector>
 
  struct EmbeddingResult {
@@ -8,7 +9,7 @@ struct EmbeddingResult {
  class EmbeddingWorker : public Napi::AsyncWorker,
                          public Napi::Promise::Deferred {
  public:
-   EmbeddingWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
+   EmbeddingWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx,
                    std::string text, common_params &params);
 
  protected:
@@ -17,7 +18,7 @@ protected:
    void OnError(const Napi::Error &err);
 
  private:
-   LlamaSessionPtr _sess;
+   rnllama::llama_rn_context* _rn_ctx;
    std::string _text;
    common_params _params;
    EmbeddingResult _result;