llama_cpp 0.2.2 → 0.3.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +34 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +305 -133
- data/ext/llama_cpp/src/ggml-cuda.cu +367 -69
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +352 -175
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +262 -291
- data/ext/llama_cpp/src/llama.h +49 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +14 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
data/ext/llama_cpp/src/llama.h
CHANGED
@@ -26,6 +26,14 @@
 #    define LLAMA_API
 #endif
 
+#ifdef __GNUC__
+#    define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+#    define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+#    define DEPRECATED(func, hint) func
+#endif
+
 #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
@@ -38,6 +46,8 @@
 #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 1
 
+#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
+
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
@@ -53,6 +63,7 @@ extern "C" {
     // TODO: show sample usage
     //
 
+    struct llama_model;
     struct llama_context;
 
     typedef int llama_token;
@@ -72,11 +83,11 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
-
-
-
-
-
+        uint32_t seed; // RNG seed, -1 for random
+        int32_t n_ctx; // text context
+        int32_t n_batch; // prompt processing batch size
+        int32_t n_gpu_layers; // number of layers to store in VRAM
+        int32_t main_gpu; // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;
@@ -131,17 +142,29 @@ extern "C" {
 
     // TODO: not great API - very likely to change
     // Initialize the llama + ggml backend
+    // If numa is true, use NUMA optimizations
     // Call once at the start of the program
-    LLAMA_API void llama_init_backend();
+    LLAMA_API void llama_init_backend(bool numa);
 
     LLAMA_API int64_t llama_time_us();
 
+    LLAMA_API struct llama_model * llama_load_model_from_file(
+            const char * path_model,
+            struct llama_context_params params);
+
+    LLAMA_API void llama_free_model(struct llama_model * model);
+
+    LLAMA_API struct llama_context * llama_new_context_with_model(
+            struct llama_model * model,
+            struct llama_context_params params);
+
     // Various functions for loading a ggml llama model.
    // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
-    LLAMA_API struct llama_context * llama_init_from_file(
+    LLAMA_API DEPRECATED(struct llama_context * llama_init_from_file(
             const char * path_model,
-            struct llama_context_params params)
+            struct llama_context_params params),
+            "please use llama_load_model_from_file combined with llama_new_context_with_model instead");
 
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
@@ -158,8 +181,15 @@ extern "C" {
     // The model needs to be reloaded before applying a new adapter, otherwise the adapter
     // will be applied on top of the previous one
     // Returns 0 on success
-    LLAMA_API int llama_apply_lora_from_file(
+    LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
             struct llama_context * ctx,
+            const char * path_lora,
+            const char * path_base_model,
+            int n_threads),
+            "please use llama_model_apply_lora_from_file instead");
+
+    LLAMA_API int llama_model_apply_lora_from_file(
+            const struct llama_model * model,
             const char * path_lora,
             const char * path_base_model,
             int n_threads);
@@ -168,7 +198,7 @@ extern "C" {
     LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
 
     // Sets the current rng seed.
-    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx,
+    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
 
     // Returns the maximum size in bytes of the state (rng, logits, embedding
     // and kv_cache) - will often be smaller after compacting tokens
@@ -198,6 +228,14 @@ extern "C" {
             int n_past,
             int n_threads);
 
+    // Same as llama_eval, but use float matrix input directly.
+    LLAMA_API int llama_eval_embd(
+            struct llama_context * ctx,
+            const float * embd,
+            int n_tokens,
+            int n_past,
+            int n_threads);
+
     // Export a static computation graph for context of 511 and batch size of 1
     // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
     // parameters here to keep things simple
@@ -310,7 +348,7 @@ extern "C" {
 #include <string>
 struct ggml_tensor;
 
-std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
+const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
 
 #endif
 
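Taken together, the llama.h changes above deprecate llama_init_from_file in favour of a two-step flow (llama_load_model_from_file, then llama_new_context_with_model), add a NUMA flag to llama_init_backend, and widen the RNG seed to uint32_t. A minimal, untested Ruby sketch of how this surfaces in the 0.3.x bindings, based on the RBS signatures later in this diff; the model path is a placeholder and the seed value is arbitrary:

  require 'llama_cpp'

  # llama_init_backend(bool numa) now takes a NUMA flag.
  LLaMACpp.init_backend(numa: false)

  params = LLaMACpp::ContextParams.new
  params.seed = 42 # seed is now a uint32_t on the C side

  # The model/context split maps onto separate Model and Context objects.
  model   = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)
  context = LLaMACpp::Context.new(model: model)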
data/lib/llama_cpp/version.rb
CHANGED
@@ -3,8 +3,8 @@
 # llama_cpp.rb provides Ruby bindings for the llama.cpp.
 module LLaMACpp
   # The version of llama_cpp.rb you install.
-  VERSION = '0.2.2'
+  VERSION = '0.3.1'
 
   # The version of llama.cpp bundled with llama_cpp.rb.
-  LLAMA_CPP_VERSION = 'master-
+  LLAMA_CPP_VERSION = 'master-b8c8dda'
 end
data/lib/llama_cpp.rb
CHANGED
@@ -2,7 +2,6 @@
 
 require_relative 'llama_cpp/version'
 require_relative 'llama_cpp/llama_cpp'
-require_relative 'llama_cpp/client'
 
 # llama_cpp.rb provides Ruby bindings for the llama.cpp.
 module LLaMACpp
@@ -20,7 +19,6 @@ module LLaMACpp
   # @return [String]
   def generate(context, prompt, n_predict: 128, n_threads: 1) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
     raise ArgumentError, 'context must be an instance of LLaMACpp::Context' unless context.is_a?(LLaMACpp::Context)
-    raise ArgumentError, 'context must have loaded the model' if context.empty?
     raise ArgumentError, 'prompt must be a String' unless prompt.is_a?(String)
 
     spaced_prompt = " #{prompt}"
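With the client require gone, text generation goes through the module-level generate helper kept in this file. A short continuation of the earlier sketch (context built from a Model above; the prompt is illustrative):

  text = LLaMACpp.generate(context, 'Hello, my name is', n_predict: 64, n_threads: 4)
  puts text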
data/sig/llama_cpp.rbs
CHANGED
@@ -4,6 +4,7 @@ module LLaMACpp
   LLAMA_FILE_VERSION: String
   LLAMA_FILE_MAGIC: String
   LLAMA_FILE_MAGIC_UNVERSIONED: String
+  LLAMA_DEFALUT_SEED: String
 
   LLAMA_MAX_DEVICES: Integer
 
@@ -25,7 +26,7 @@ module LLaMACpp
   LLAMA_FTYPE_MOSTLY_Q5_K_M: Integer
   LLAMA_FTYPE_MOSTLY_Q6_K: Integer
 
-  def self?.init_backend: () -> void
+  def self?.init_backend: (?numa: bool) -> void
   def self?.model_quantize: (input_path: String, output_path: String, params: ModelQuantizeParams) -> void
   def self?.generate: (::LLaMACpp::Context, String, ?n_predict: Integer, ?n_threads: Integer) -> String
   def self?.print_system_info: () -> void
@@ -55,17 +56,25 @@ module LLaMACpp
     def sorted: () -> bool
   end
 
-  class Context
+  class Model
     public
 
     def initialize: (model_path: String, params: ::LLaMACpp::ContextParams) -> void
                   | () -> void
-    def embeddings: () -> Array[Float]
     def empty?: () -> bool
-    def eval: (tokens: Array[Integer], n_past: Integer, ?n_tokens: Integer, ?n_threads: Integer) -> void
-    def eval_export: (String) -> bool
     def free: () -> void
     def load: (model_path: String, params: ::LLaMACpp::ContextParams) -> void
+    def apply_lora_from_file: (lora_path: String, ?base_model_path: String, ?n_threads: Integer) -> void
+  end
+
+  class Context
+    public
+
+    def initialize: (model: ::LLaMACpp::Model) -> void
+    def embeddings: () -> Array[Float]
+    def eval: (tokens: Array[Integer], n_past: Integer, ?n_tokens: Integer, ?n_threads: Integer) -> void
+    def eval_embd: (tokens: Array[Float], n_past: Integer, ?n_tokens: Integer, ?n_threads: Integer) -> void
+    def eval_export: (String) -> bool
     def logits: () -> Array[Float]
     def n_ctx: () -> Integer
     def n_embd: () -> Integer
@@ -75,7 +84,6 @@ module LLaMACpp
     def reset_timings: () -> void
     def token_to_str: (Integer) -> String
     def tokenize: (text: String, ?n_max_tokens: Integer, ?add_bos: bool) -> Array[Integer]
-    def apply_lora_from_file: (lora_path: String, ?base_model_path: String, ?n_threads: Integer) -> void
     def kv_cache_token_count: () -> Integer
     def set_rng_seed: (Integer) -> void
     def load_session_file: (session_path: String) -> void
@@ -138,15 +146,4 @@ module LLaMACpp
   end
 
   class Params = ContextParams
-
-  class Client
-    def initialize(model_path: String, ?lora_adapter_path: String, ?lora_base_path: String,
-                   ?n_ctx: Integer, ?memory_f16: bool, ?use_mmap: bool, ?use_mlock: bool,
-                   ?embedding: bool, ?n_threads: Integer, ?seed: Integer) -> void
-    def completions(String, ?max_tokens: Integer, ?n_keep: Integer, ?repeat_last_n: Integer, ?n_batch: Integer,
-                    ?frequency: Float, ?presence: Float,
-                    ?top_k: Integer, ?top_p: Float, ?tfs_z: Float, ?typical_p: Float, ?temperature: Float,
-                    ?repeat_penalty: Float) -> String
-    def embeddings(String) -> Array[Float]
-  end
 end
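Per the signatures above, apply_lora_from_file moves from Context to Model, and Context gains eval_embd, mirroring llama_eval_embd. A hedged sketch reusing params from the earlier snippet; the paths are placeholders, and embd_row stands in for a flat Array[Float] of embedding values supplied by the caller:

  model = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)
  model.apply_lora_from_file(lora_path: '/path/to/lora.bin', n_threads: 4)

  context = LLaMACpp::Context.new(model: model)
  context.eval_embd(tokens: embd_row, n_past: 0, n_threads: 4)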
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: llama_cpp
 version: !ruby/object:Gem::Version
-  version: 0.2.2
+  version: 0.3.1
 platform: ruby
 authors:
 - yoshoku
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2023-
+date: 2023-07-02 00:00:00.000000000 Z
 dependencies: []
 description: llama_cpp.rb provides Ruby bindings for the llama.cpp.
 email:
@@ -44,7 +44,6 @@ files:
 - ext/llama_cpp/src/llama.cpp
 - ext/llama_cpp/src/llama.h
 - lib/llama_cpp.rb
-- lib/llama_cpp/client.rb
 - lib/llama_cpp/version.rb
 - sig/llama_cpp.rbs
 homepage: https://github.com/yoshoku/llama_cpp.rb
data/lib/llama_cpp/client.rb
DELETED
@@ -1,172 +0,0 @@
-# frozen_string_literal: true
-
-module LLaMACpp
-  # Client provides a high-level interface to the LLM model.
-  class Client # rubocop:disable Metrics/ClassLength
-    # Creates a new client.
-    #
-    # @param model_path [String] The path to the model file.
-    # @param lora_adapter_path [String] The path to the LoRA adapter file.
-    # @param lora_base_path [String] The path to the LoRA base model file.
-    # @param n_ctx [Integer] The context size.
-    # @param memory_f16 [Boolean] The flag wheter to use f16 instead of f32 for memory kv.
-    # @param use_mmap [Boolean] The flag whether to use mmap.
-    # @param use_mlock [Boolean] The flag hether to use mlock.
-    # @param embedding [Boolean] The flag whether to calculate embedding.
-    # @param n_threads [Integer] The number of threads to use.
-    # @param seed [Integer] The seed for the random number generator.
-    # @return [Client]
-    # rubocop:disable Metrics/MethodLength, Metrics/ParameterLists
-    def initialize(model_path:, lora_adapter_path: nil, lora_base_path: nil,
-                   n_ctx: 512, memory_f16: false, use_mmap: true, use_mlock: false,
-                   embedding: false,
-                   n_threads: 1, seed: 0)
-      @params = {
-        model_path: model_path,
-        lora_adapter_path: lora_adapter_path,
-        lora_base_path: lora_base_path,
-        n_ctx: n_ctx,
-        memory_f16: memory_f16,
-        use_mmap: use_mmap,
-        use_mlock: use_mlock,
-        embedding: embedding,
-        n_threads: n_threads,
-        seed: seed
-      }
-      @context_params = ContextParams.new
-      @context_params.n_ctx = n_ctx
-      @context_params.n_parts = n_parts
-      @context_params.f16_kv = memory_f16
-      @context_params.use_mmap = use_mmap
-      @context_params.use_mlock = use_mlock
-      @context_params.embedding = embedding
-      @context_params.seed = seed
-      @context = Context.new(model_path: model_path, params: @context_params)
-      return unless lora_adapter_path.is_a?(String)
-
-      if lora_base_path.is_a?(String)
-        @context.apply_lora_from_file(lora_path: lora_adapter_path, base_model_path: lora_base_path, n_threads: n_threads)
-      else
-        @context.apply_lora_from_file(lora_path: lora_adapter_path, n_threads: n_threads)
-      end
-    end
-    # rubocop:enable Metrics/MethodLength, Metrics/ParameterLists
-
-    # Generates completions for a given prompt.
-    #
-    # @param prompt [String] The prompt to generate completions for.
-    # @param max_tokens [Integer] The maximum number of tokens to generate.
-    # @param n_keep [Integer] The number of tokens to keep from the initial prompt.
-    # @param repeat_last_n [Integer] The number of tokens to use for repeat penalty.
-    # @param n_batch [Integer] The batch size.
-    # @param frequency [Float] The frequency penalty value.
-    # @param presence [Float] The presence penalty value.
-    # @param top_k [Integer] The top-k value.
-    # @param top_p [Float] The top-p value.
-    # @param tfs_z [Float] The tail free sampling parameter.
-    # @param typical_p [Float] The typical probability value.
-    # @param temperature [Float] The temperature value.
-    # @param repeat_penalty [Float] The repeat penalty value.
-    # @return [String]
-    # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/ParameterLists, Metrics/PerceivedComplexity
-    def completions(prompt, max_tokens: 64, n_keep: 10, repeat_last_n: 64, n_batch: 512,
-                    frequency: 0.0, presence: 0.0,
-                    top_k: 40, top_p: 0.95, tfs_z: 1.0, typical_p: 1.0, temperature: 0.8, repeat_penalty: 1.1)
-      embd_input = tokenize_prompt(prompt)
-
-      n_ctx = @context.n_ctx
-      raise ArgumentError, "prompt is too long #{embd_input.size} tokens, maximum is #{n_ctx - 4}" if embd_input.size > n_ctx - 4
-
-      last_n_tokens = [0] * n_ctx
-
-      embd = []
-      n_consumed = 0
-      n_past = 0
-      n_remain = max_tokens
-      n_vocab = @context.n_vocab
-      output = []
-
-      while n_remain != 0
-        unless embd.empty?
-          if n_past + embd.size > n_ctx
-            n_left = n_past - n_keep
-            n_past = n_keep
-            embd.insert(0, last_n_tokens[(n_ctx - (n_left / 2) - embd.size)...-embd.size])
-          end
-
-          @context.eval(tokens: embd, n_past: n_past, n_threads: @params[:n_threads])
-        end
-
-        n_past += embd.size
-        embd.clear
-
-        if embd_input.size <= n_consumed
-          logits = @context.logits
-          base_candidates = Array.new(n_vocab) { |i| LLaMACpp::TokenData.new(id: i, logit: logits[i], p: 0.0) }
-          candidates = LLaMACpp::TokenDataArray.new(base_candidates)
-
-          # apply penalties
-          last_n_repeat = [last_n_tokens.size, repeat_last_n, n_ctx].min
-          @context.sample_repetition_penalty(candidates, last_n_tokens[-last_n_repeat..], penalty: repeat_penalty)
-          @context.sample_frequency_and_presence_penalties(
-            candidates, last_n_tokens[-last_n_repeat..], frequency: frequency, presence: presence
-          )
-
-          # temperature sampling
-          @context.sample_top_k(candidates, k: top_k)
-          @context.sample_tail_free(candidates, z: tfs_z)
-          @context.sample_typical(candidates, prob: typical_p)
-          @context.sample_top_p(candidates, prob: top_p)
-          @context.sample_temperature(candidates, temperature: temperature)
-          id = @context.sample_token(candidates)
-
-          last_n_tokens.shift
-          last_n_tokens.push(id)
-
-          last_n_tokens.shift
-          last_n_tokens.push(id)
-
-          embd.push(id)
-          n_remain -= 1
-        else
-          while embd_input.size > n_consumed
-            embd.push(embd_input[n_consumed])
-            last_n_tokens.shift
-            last_n_tokens.push(embd_input[n_consumed])
-            n_consumed += 1
-            break if embd.size >= n_batch
-          end
-        end
-
-        embd.each { |token| output << @context.token_to_str(token) }
-
-        break if !embd.empty? && embd[-1] == LLaMACpp.token_eos
-      end
-
-      output.join.delete_prefix(" #{prompt}").strip
-    end
-    # rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/ParameterLists, Metrics/PerceivedComplexity
-
-    # def chat(prompt); end
-
-    # Obtains the embedding for a given text.
-    #
-    # @param text [String] The text to obtain the embedding for.
-    # @return [Array<Float>]
-    def embeddings(text)
-      raise 'The embedding option is set to false' unless @params[:embedding]
-
-      embd_input = tokenize_prompt(text)
-      raise 'The result of tokenizing the input text is empty' unless embd_input.size.positive?
-
-      @context.eval(tokens: embd_input, n_past: 0, n_threads: @params[:n_threads])
-      @context.embeddings
-    end
-
-    private
-
-    def tokenize_prompt(prompt)
-      @context.tokenize(text: " #{prompt}", add_bos: true)
-    end
-  end
-end
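The removed Client#embeddings can be approximated with the lower-level API that remains: enable embeddings in ContextParams, tokenize, eval, then read Context#embeddings. A rough, untested sketch assuming the 0.3.x Model/Context API and a placeholder model path; ContextParams#embedding= is assumed to still exist, as it did when the deleted Client used it:

  params = LLaMACpp::ContextParams.new
  params.embedding = true

  model   = LLaMACpp::Model.new(model_path: '/path/to/model.bin', params: params)
  context = LLaMACpp::Context.new(model: model)

  tokens = context.tokenize(text: ' Hello world', add_bos: true)
  context.eval(tokens: tokens, n_past: 0, n_threads: 4)
  vector = context.embeddings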