@fugood/llama.node 1.1.11 → 1.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. package/CMakeLists.txt +5 -8
  2. package/lib/binding.ts +18 -1
  3. package/lib/index.js +2 -2
  4. package/lib/index.ts +2 -2
  5. package/package.json +20 -16
  6. package/src/DecodeAudioTokenWorker.cpp +23 -26
  7. package/src/DecodeAudioTokenWorker.h +6 -8
  8. package/src/DetokenizeWorker.cpp +5 -8
  9. package/src/DetokenizeWorker.h +6 -5
  10. package/src/DisposeWorker.cpp +23 -3
  11. package/src/DisposeWorker.h +4 -2
  12. package/src/EmbeddingWorker.cpp +9 -35
  13. package/src/EmbeddingWorker.h +3 -2
  14. package/src/LlamaCompletionWorker.cpp +217 -315
  15. package/src/LlamaCompletionWorker.h +6 -12
  16. package/src/LlamaContext.cpp +166 -396
  17. package/src/LlamaContext.h +8 -13
  18. package/src/LoadSessionWorker.cpp +22 -19
  19. package/src/LoadSessionWorker.h +3 -2
  20. package/src/RerankWorker.h +3 -2
  21. package/src/SaveSessionWorker.cpp +22 -19
  22. package/src/SaveSessionWorker.h +3 -2
  23. package/src/TokenizeWorker.cpp +38 -35
  24. package/src/TokenizeWorker.h +12 -3
  25. package/src/common.hpp +0 -458
  26. package/src/llama.cpp/common/arg.cpp +50 -30
  27. package/src/llama.cpp/common/chat.cpp +250 -1
  28. package/src/llama.cpp/common/chat.h +4 -0
  29. package/src/llama.cpp/common/common.h +1 -1
  30. package/src/llama.cpp/common/json-schema-to-grammar.cpp +21 -1
  31. package/src/llama.cpp/common/log.cpp +53 -2
  32. package/src/llama.cpp/common/log.h +10 -4
  33. package/src/llama.cpp/common/sampling.cpp +23 -2
  34. package/src/llama.cpp/common/sampling.h +3 -1
  35. package/src/llama.cpp/common/speculative.cpp +1 -1
  36. package/src/llama.cpp/ggml/CMakeLists.txt +3 -2
  37. package/src/llama.cpp/ggml/include/ggml-backend.h +15 -0
  38. package/src/llama.cpp/ggml/include/ggml-cpu.h +1 -1
  39. package/src/llama.cpp/ggml/include/ggml-metal.h +0 -6
  40. package/src/llama.cpp/ggml/include/ggml.h +56 -2
  41. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +21 -14
  42. package/src/llama.cpp/ggml/src/ggml-cpu/arch/riscv/quants.c +210 -96
  43. package/src/llama.cpp/ggml/src/ggml-cpu/arch/s390/quants.c +57 -59
  44. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +6 -7
  45. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +25 -38
  46. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -4
  47. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +4 -12
  48. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +379 -4
  49. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +1 -0
  50. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +41 -37
  51. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +150 -28
  52. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +320 -73
  53. package/src/llama.cpp/include/llama.h +5 -6
  54. package/src/llama.cpp/src/llama-adapter.cpp +33 -0
  55. package/src/llama.cpp/src/llama-adapter.h +3 -0
  56. package/src/llama.cpp/src/llama-arch.cpp +28 -4
  57. package/src/llama.cpp/src/llama-arch.h +3 -0
  58. package/src/llama.cpp/src/llama-context.cpp +65 -57
  59. package/src/llama.cpp/src/llama-context.h +1 -1
  60. package/src/llama.cpp/src/llama-graph.cpp +57 -11
  61. package/src/llama.cpp/src/llama-graph.h +8 -0
  62. package/src/llama.cpp/src/llama-hparams.cpp +37 -0
  63. package/src/llama.cpp/src/llama-hparams.h +10 -3
  64. package/src/llama.cpp/src/llama-kv-cache.cpp +56 -38
  65. package/src/llama.cpp/src/llama-kv-cache.h +9 -0
  66. package/src/llama.cpp/src/llama-model.cpp +217 -97
  67. package/src/llama.cpp/src/llama-model.h +0 -1
  68. package/src/llama.cpp/src/llama-quant.cpp +3 -3
  69. package/src/llama.cpp/src/llama-sampling.cpp +226 -126
  70. package/src/llama.cpp/src/llama.cpp +53 -10
  71. package/src/anyascii.c +0 -22223
  72. package/src/anyascii.h +0 -42
  73. package/src/tts_utils.cpp +0 -371
  74. package/src/tts_utils.h +0 -103
package/CMakeLists.txt CHANGED
@@ -137,10 +137,6 @@ file(
   "src/LlamaCompletionWorker.h"
   "src/LlamaContext.cpp"
   "src/LlamaContext.h"
-  "src/TokenizeWorker.cpp"
-  "src/TokenizeWorker.h"
-  "src/DetokenizeWorker.cpp"
-  "src/DetokenizeWorker.h"
   "src/EmbeddingWorker.cpp"
   "src/EmbeddingWorker.h"
   "src/RerankWorker.cpp"
@@ -149,12 +145,13 @@ file(
   "src/LoadSessionWorker.h"
   "src/SaveSessionWorker.cpp"
   "src/SaveSessionWorker.h"
+  "src/TokenizeWorker.cpp"
+  "src/TokenizeWorker.h"
+  "src/DetokenizeWorker.cpp"
+  "src/DetokenizeWorker.h"
   "src/DecodeAudioTokenWorker.cpp"
   "src/DecodeAudioTokenWorker.h"
-  "src/tts_utils.cpp"
-  "src/tts_utils.h"
-  "src/anyascii.h"
-  "src/anyascii.c"
+  "src/rn-llama/*"
 )
 
 if (NOT MSVC AND CMAKE_SYSTEM_NAME STREQUAL "Windows")
package/lib/binding.ts CHANGED
@@ -150,6 +150,21 @@ export type LlamaCompletionOptions = {
    * Help prevent hallucinations by forcing the TTS to use the correct words.
    */
   guide_tokens?: number[] | Int32Array
+  /**
+   * Number of top token probabilities to return for each generated token.
+   * When > 0, completion_probabilities will be included in streaming callbacks and final result.
+   */
+  n_probs?: number
+}
+
+export type TokenProbability = {
+  tok_str: string
+  prob: number
+}
+
+export type CompletionProbability = {
+  content: string
+  probs: TokenProbability[]
 }
 
 export type LlamaCompletionResult = {
@@ -163,6 +178,7 @@ export type LlamaCompletionResult = {
   context_full: boolean
   interrupted: boolean
   audio_tokens?: Array<number>
+  completion_probabilities?: CompletionProbability[]
   timings: {
     prompt_n: number
     prompt_ms: number
@@ -181,6 +197,7 @@ export type LlamaCompletionToken = {
   reasoning_content?: string
   tool_calls?: ToolCall[]
   accumulated_text?: string
+  completion_probabilities?: CompletionProbability[]
 }
 
 export type TokenizeResult = {
@@ -309,7 +326,7 @@ export interface LlamaContext {
   stopCompletion(): void
   tokenize(text: string, media_paths?: string[]): Promise<TokenizeResult>
   detokenize(tokens: number[]): Promise<string>
-  embedding(text: string): Promise<EmbeddingResult>
+  embedding(text: string, params?: { embd_normalize?: number }): Promise<EmbeddingResult>
   rerank(query: string, documents: string[], params?: RerankParams): Promise<RerankResult[]>
   saveSession(path: string): Promise<void>
   loadSession(path: string): Promise<void>
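Note on the new completion-probability fields above: a minimal TypeScript sketch of how they could be consumed. The loadModel helper and the completion(options, onToken) call shape are assumptions based on the package's existing API surface and are not part of this diff.

import { loadModel } from '@fugood/llama.node'

const context = await loadModel({ model: './model.gguf' })

// Ask for the top 4 token probabilities per generated token via the new n_probs option.
const result = await context.completion(
  { prompt: 'Hello', n_predict: 16, n_probs: 4 },
  (token) => {
    // Streaming callbacks may now carry completion_probabilities as well.
    for (const p of token.completion_probabilities ?? []) {
      console.log(p.content, p.probs.map((t) => `${t.tok_str}:${t.prob.toFixed(3)}`))
    }
  },
)

// The final result can include the accumulated per-token probabilities.
console.log(result.completion_probabilities)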
package/lib/index.js CHANGED
@@ -180,8 +180,8 @@ class LlamaContextWrapper {
   detokenize(tokens) {
     return this.ctx.detokenize(tokens);
   }
-  embedding(text) {
-    return this.ctx.embedding(text);
+  embedding(text, params) {
+    return this.ctx.embedding(text, params);
   }
   rerank(query, documents, params) {
     return this.ctx
package/lib/index.ts CHANGED
@@ -251,8 +251,8 @@ class LlamaContextWrapper {
     return this.ctx.detokenize(tokens)
   }
 
-  embedding(text: string): Promise<EmbeddingResult> {
-    return this.ctx.embedding(text)
+  embedding(text: string, params?: { embd_normalize?: number }): Promise<EmbeddingResult> {
+    return this.ctx.embedding(text, params)
   }
 
   rerank(
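Note on the updated embedding signature above: a brief sketch of passing the optional embd_normalize parameter. The loadModel helper and the embedding load option are assumptions from the existing API, and the value semantics (-1 = no normalization, 2 = Euclidean norm) are assumed to follow llama.cpp's common_embd_normalize convention.

import { loadModel } from '@fugood/llama.node'

// embedding: true is an assumed load option mirroring llama.cpp's embedding mode.
const context = await loadModel({ model: './embedding-model.gguf', embedding: true })

// Default behaviour without a params object stays the same.
const normalized = await context.embedding('hello world')

// Explicitly request raw values; -1 is assumed to disable normalization.
const raw = await context.embedding('hello world', { embd_normalize: -1 })

console.log(normalized.embedding.length, raw.embedding.length)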
package/package.json CHANGED
@@ -1,11 +1,12 @@
 {
   "name": "@fugood/llama.node",
   "access": "public",
-  "version": "1.1.11",
+  "version": "1.2.0",
   "description": "An another Node binding of llama.cpp",
   "main": "lib/index.js",
   "scripts": {
-    "bootstrap": "npm install --omit=optional",
+    "copy-rn-llama-source": "node scripts/copy-rn-llama-source.js",
+    "bootstrap": "npm run copy-rn-llama-source && npm install --omit=optional",
     "postinstall": "node scripts/check.js",
     "pretest": "node scripts/download-test-models.js",
     "test": "jest",
@@ -53,6 +54,7 @@
     "scripts/check.js",
     "scripts/llama.cpp.patch",
     "src/*.{cc,c,h,hpp}",
+    "src/rn-llama/*",
     "src/DecodeAudioTokenWorker.cpp",
     "src/DetokenizeWorker.cpp",
     "src/DisposeWorker.cpp",
@@ -62,7 +64,6 @@
     "src/LoadSessionWorker.cpp",
     "src/SaveSessionWorker.cpp",
     "src/TokenizeWorker.cpp",
-    "src/tts_utils.cpp",
     "src/llama.cpp/{common,src,include}/**/*.{h,hpp,cpp,cc,c}",
     "src/llama.cpp/ggml/include/*.h",
     "src/llama.cpp/ggml/src/ggml-cpu/**/*.{h,hpp,cpp,cc,c}",
@@ -71,19 +72,19 @@
     "CMakeLists.txt"
   ],
   "optionalDependencies": {
-    "@fugood/node-llama-linux-x64": "1.1.11",
-    "@fugood/node-llama-linux-x64-vulkan": "1.1.11",
-    "@fugood/node-llama-linux-x64-cuda": "1.1.11",
-    "@fugood/node-llama-linux-arm64": "1.1.11",
-    "@fugood/node-llama-linux-arm64-vulkan": "1.1.11",
-    "@fugood/node-llama-linux-arm64-cuda": "1.1.11",
-    "@fugood/node-llama-win32-x64": "1.1.11",
-    "@fugood/node-llama-win32-x64-vulkan": "1.1.11",
-    "@fugood/node-llama-win32-x64-cuda": "1.1.11",
-    "@fugood/node-llama-win32-arm64": "1.1.11",
-    "@fugood/node-llama-win32-arm64-vulkan": "1.1.11",
-    "@fugood/node-llama-darwin-x64": "1.1.11",
-    "@fugood/node-llama-darwin-arm64": "1.1.11"
+    "@fugood/node-llama-linux-x64": "1.2.0",
+    "@fugood/node-llama-linux-x64-vulkan": "1.2.0",
+    "@fugood/node-llama-linux-x64-cuda": "1.2.0",
+    "@fugood/node-llama-linux-arm64": "1.2.0",
+    "@fugood/node-llama-linux-arm64-vulkan": "1.2.0",
+    "@fugood/node-llama-linux-arm64-cuda": "1.2.0",
+    "@fugood/node-llama-win32-x64": "1.2.0",
+    "@fugood/node-llama-win32-x64-vulkan": "1.2.0",
+    "@fugood/node-llama-win32-x64-cuda": "1.2.0",
+    "@fugood/node-llama-win32-arm64": "1.2.0",
+    "@fugood/node-llama-win32-arm64-vulkan": "1.2.0",
+    "@fugood/node-llama-darwin-x64": "1.2.0",
+    "@fugood/node-llama-darwin-arm64": "1.2.0"
   },
   "devDependencies": {
     "@babel/preset-env": "^7.24.4",
@@ -115,6 +116,9 @@
     ],
     "testMatch": [
       "**/*.test.ts"
+    ],
+    "testPathIgnorePatterns": [
+      "<rootDir>/src/llama.rn/"
     ]
   },
   "prettier": {
package/src/DecodeAudioTokenWorker.cpp CHANGED
@@ -1,40 +1,37 @@
 #include "DecodeAudioTokenWorker.h"
-#include "tts_utils.h"
-#include <vector>
+#include "LlamaContext.h"
 
-DecodeAudioTokenWorker::DecodeAudioTokenWorker(
-    const Napi::CallbackInfo &info, llama_model *model, llama_context *ctx,
-    int n_threads, const std::vector<llama_token> &tokens)
-    : AsyncWorker(info.Env()), Deferred(info.Env()), _model(model), _ctx(ctx),
-      _n_threads(n_threads), _tokens(tokens) {}
+DecodeAudioTokenWorker::DecodeAudioTokenWorker(const Napi::CallbackInfo &info,
+    rnllama::llama_rn_context* rn_ctx, std::vector<int32_t> tokens)
+    : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx), _tokens(tokens) {}
 
 void DecodeAudioTokenWorker::Execute() {
-  const int n_codes = _tokens.size();
-  llama_batch batch = llama_batch_init(n_codes, 0, 1);
-  for (size_t i = 0; i < _tokens.size(); ++i) {
-    common_batch_add(batch, _tokens[i], i, {0}, true);
+  try {
+    if (!_rn_ctx->tts_wrapper) {
+      SetError("Vocoder not initialized");
+      return;
+    }
+
+    // Convert to llama_token vector - rn-tts handles token adjustment internally
+    std::vector<llama_token> llama_tokens;
+    for (const auto& token : _tokens) {
+      llama_tokens.push_back(token);
+    }
+
+    // Use the rn-tts API instead of directly accessing the worker
+    _result = _rn_ctx->tts_wrapper->decodeAudioTokens(_rn_ctx, llama_tokens);
+  } catch (const std::exception &e) {
+    SetError(e.what());
   }
-  if (batch.n_tokens != n_codes) {
-    SetError("batch.n_tokens != n_codes");
-    return;
-  }
-  if (llama_encode(_ctx, batch) != 0) {
-    SetError("llama_encode() failed");
-    return;
-  }
-  llama_synchronize(_ctx);
-  const int n_embd = llama_model_n_embd(_model);
-  const float *embd = llama_get_embeddings(_ctx);
-  _result = embd_to_audio(embd, n_codes, n_embd, _n_threads);
 }
 
 void DecodeAudioTokenWorker::OnOK() {
-  auto result =
-      Napi::Float32Array::New(Napi::AsyncWorker::Env(), _result.size());
+  // Create Float32Array and copy the data
+  auto result = Napi::Float32Array::New(Napi::AsyncWorker::Env(), _result.size());
   memcpy(result.Data(), _result.data(), _result.size() * sizeof(float));
   Napi::Promise::Deferred::Resolve(result);
 }
 
 void DecodeAudioTokenWorker::OnError(const Napi::Error &err) {
   Napi::Promise::Deferred::Reject(err.Value());
-}
+}
package/src/DecodeAudioTokenWorker.h CHANGED
@@ -1,12 +1,12 @@
 #include "common.hpp"
+#include "rn-llama/rn-llama.h"
 #include <vector>
 
 class DecodeAudioTokenWorker : public Napi::AsyncWorker,
                                public Napi::Promise::Deferred {
 public:
-  DecodeAudioTokenWorker(const Napi::CallbackInfo &info, llama_model *model,
-                         llama_context *ctx, int n_threads,
-                         const std::vector<llama_token> &tokens);
+  DecodeAudioTokenWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx,
+                         std::vector<int32_t> tokens);
 
 protected:
   void Execute();
@@ -14,9 +14,7 @@ protected:
   void OnError(const Napi::Error &err);
 
 private:
-  llama_model *_model;
-  llama_context *_ctx;
-  int _n_threads;
-  std::vector<llama_token> _tokens;
+  rnllama::llama_rn_context* _rn_ctx;
+  std::vector<int32_t> _tokens;
   std::vector<float> _result;
-};
+};
package/src/DetokenizeWorker.cpp CHANGED
@@ -2,21 +2,18 @@
 #include "LlamaContext.h"
 
 DetokenizeWorker::DetokenizeWorker(const Napi::CallbackInfo &info,
-                                   LlamaSessionPtr &sess,
-                                   std::vector<llama_token> &tokens)
-    : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess),
-      _tokens(std::move(tokens)) {}
+                                   rnllama::llama_rn_context* rn_ctx, std::vector<int32_t> tokens)
+    : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx), _tokens(tokens) {}
 
 void DetokenizeWorker::Execute() {
-  const auto text = ::common_detokenize(_sess->context(), _tokens);
+  const auto text = tokens_to_str(_rn_ctx->ctx, _tokens.begin(), _tokens.end());
   _text = std::move(text);
 }
 
 void DetokenizeWorker::OnOK() {
-  Napi::Promise::Deferred::Resolve(
-      Napi::String::New(Napi::AsyncWorker::Env(), _text));
+  Napi::Promise::Deferred::Resolve(Napi::String::New(Napi::AsyncWorker::Env(), _text));
 }
 
 void DetokenizeWorker::OnError(const Napi::Error &err) {
   Napi::Promise::Deferred::Reject(err.Value());
-}
+}
package/src/DetokenizeWorker.h CHANGED
@@ -1,11 +1,12 @@
 #include "common.hpp"
+#include "rn-llama/rn-llama.h"
 #include <vector>
 
 class DetokenizeWorker : public Napi::AsyncWorker,
                          public Napi::Promise::Deferred {
 public:
-  DetokenizeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
-                   std::vector<llama_token> &tokens);
+  DetokenizeWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx,
+                   std::vector<int32_t> tokens);
 
 protected:
   void Execute();
@@ -13,7 +14,7 @@ protected:
   void OnError(const Napi::Error &err);
 
 private:
-  LlamaSessionPtr _sess;
-  std::vector<llama_token> _tokens;
+  rnllama::llama_rn_context* _rn_ctx;
+  std::vector<int32_t> _tokens;
   std::string _text;
-};
+};
package/src/DisposeWorker.cpp CHANGED
@@ -1,10 +1,30 @@
 #include "DisposeWorker.h"
+#include "rn-llama/rn-completion.h"
 
 DisposeWorker::DisposeWorker(const Napi::CallbackInfo &info,
-                             LlamaSessionPtr sess)
-    : AsyncWorker(info.Env()), Deferred(info.Env()), sess_(std::move(sess)) {}
+                             rnllama::llama_rn_context* rn_ctx, rnllama::llama_rn_context** parent_ptr)
+    : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx), _parent_ptr(parent_ptr) {}
 
-void DisposeWorker::Execute() { sess_->dispose(); }
+void DisposeWorker::Execute() {
+  if (_rn_ctx) {
+    // Ensure all child contexts are properly cleaned up first
+    try {
+      // Now delete the main context
+      delete _rn_ctx;
+
+      // Set parent pointer to nullptr to prevent double free
+      if (_parent_ptr) {
+        *_parent_ptr = nullptr;
+      }
+    } catch (const std::exception& e) {
+      SetError(std::string("Error during context disposal: ") + e.what());
+      return;
+    } catch (...) {
+      SetError("Unknown error during context disposal");
+      return;
+    }
+  }
+}
 
 void DisposeWorker::OnOK() { Resolve(AsyncWorker::Env().Undefined()); }
 
package/src/DisposeWorker.h CHANGED
@@ -1,8 +1,9 @@
 #include "common.hpp"
+#include "rn-llama/rn-llama.h"
 
 class DisposeWorker : public Napi::AsyncWorker, public Napi::Promise::Deferred {
 public:
-  DisposeWorker(const Napi::CallbackInfo &info, LlamaSessionPtr sess);
+  DisposeWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx, rnllama::llama_rn_context** parent_ptr);
 
 protected:
   void Execute();
@@ -10,5 +11,6 @@ protected:
   void OnError(const Napi::Error &err);
 
 private:
-  LlamaSessionPtr sess_;
+  rnllama::llama_rn_context* _rn_ctx;
+  rnllama::llama_rn_context** _parent_ptr; // Pointer to the parent's _rn_ctx pointer
 };
package/src/EmbeddingWorker.cpp CHANGED
@@ -2,46 +2,20 @@
 #include "LlamaContext.h"
 
 EmbeddingWorker::EmbeddingWorker(const Napi::CallbackInfo &info,
-                                 LlamaSessionPtr &sess, std::string text,
+                                 rnllama::llama_rn_context* rn_ctx, std::string text,
                                  common_params &params)
-    : AsyncWorker(info.Env()), Deferred(info.Env()), _sess(sess), _text(text),
+    : AsyncWorker(info.Env()), Deferred(info.Env()), _rn_ctx(rn_ctx), _text(text),
       _params(params) {}
 
 void EmbeddingWorker::Execute() {
-  llama_memory_clear(llama_get_memory(_sess->context()), true);
-  auto tokens = ::common_tokenize(_sess->context(), _text, true);
-  // add SEP if not present
-  auto vocab = llama_model_get_vocab(_sess->model());
-  if (tokens.empty() || tokens.back() != llama_vocab_sep(vocab)) {
-    tokens.push_back(llama_vocab_sep(vocab));
-  }
-  const int n_embd = llama_model_n_embd(_sess->model());
-  do {
-    auto ctx = _sess->context();
-    int ret =
-        llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
-    if (ret < 0) {
-      SetError("Failed to inference, code: " + std::to_string(ret));
-      break;
-    }
+  try {
+    _rn_ctx->params.prompt = _text;
+    _rn_ctx->params.n_predict = 0;
 
-    float *embd;
-    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-      embd = llama_get_embeddings(ctx);
-    } else {
-      embd = llama_get_embeddings_seq(ctx, 0);
-    }
-    if (embd == nullptr) {
-      SetError("Failed to get embeddings");
-      break;
-    }
-    _result.embedding.resize(n_embd);
-    std::vector<float> embedding(embd, embd + n_embd), out(embd, embd + n_embd);
-    common_embd_normalize(embedding.data(), out.data(), n_embd,
-                          _params.embd_normalize);
-    memcpy(_result.embedding.data(), out.data(), n_embd * sizeof(float));
-  } while (false);
+    _result.embedding = _rn_ctx->completion->embedding(_params);
+  } catch (const std::exception &e) {
+    SetError(e.what());
+  }
 }
 
 void EmbeddingWorker::OnOK() {
package/src/EmbeddingWorker.h CHANGED
@@ -1,4 +1,5 @@
 #include "common.hpp"
+#include "rn-llama/rn-llama.h"
 #include <vector>
 
 struct EmbeddingResult {
@@ -8,7 +9,7 @@ struct EmbeddingResult {
 class EmbeddingWorker : public Napi::AsyncWorker,
                         public Napi::Promise::Deferred {
 public:
-  EmbeddingWorker(const Napi::CallbackInfo &info, LlamaSessionPtr &sess,
+  EmbeddingWorker(const Napi::CallbackInfo &info, rnllama::llama_rn_context* rn_ctx,
                   std::string text, common_params &params);
 
 protected:
@@ -17,7 +18,7 @@ protected:
   void OnError(const Napi::Error &err);
 
 private:
-  LlamaSessionPtr _sess;
+  rnllama::llama_rn_context* _rn_ctx;
   std::string _text;
   common_params _params;
   EmbeddingResult _result;