llama_cpp 0.2.2 → 0.3.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +34 -0
- data/README.md +39 -6
- data/examples/chat.rb +2 -1
- data/examples/embedding.rb +3 -2
- data/ext/llama_cpp/extconf.rb +13 -0
- data/ext/llama_cpp/llama_cpp.cpp +305 -133
- data/ext/llama_cpp/src/ggml-cuda.cu +367 -69
- data/ext/llama_cpp/src/ggml-cuda.h +1 -0
- data/ext/llama_cpp/src/ggml-metal.m +36 -30
- data/ext/llama_cpp/src/ggml-metal.metal +328 -84
- data/ext/llama_cpp/src/ggml-opencl.cpp +352 -175
- data/ext/llama_cpp/src/ggml.c +800 -303
- data/ext/llama_cpp/src/ggml.h +68 -5
- data/ext/llama_cpp/src/k_quants.c +1712 -56
- data/ext/llama_cpp/src/k_quants.h +41 -6
- data/ext/llama_cpp/src/llama-util.h +19 -5
- data/ext/llama_cpp/src/llama.cpp +262 -291
- data/ext/llama_cpp/src/llama.h +49 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/lib/llama_cpp.rb +0 -2
- data/sig/llama_cpp.rbs +14 -17
- metadata +2 -3
- data/lib/llama_cpp/client.rb +0 -172
data/ext/llama_cpp/src/llama.h
CHANGED
@@ -26,6 +26,14 @@
 # define LLAMA_API
 #endif
 
+#ifdef __GNUC__
+# define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
+#elif defined(_MSC_VER)
+# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
+#else
+# define DEPRECATED(func, hint) func
+#endif
+
 #define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf'
@@ -38,6 +46,8 @@
 #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 1
 
+#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
+
 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
 // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
 #define LLAMA_SUPPORTS_GPU_OFFLOAD
@@ -53,6 +63,7 @@ extern "C" {
     // TODO: show sample usage
     //
 
+    struct llama_model;
     struct llama_context;
 
     typedef int llama_token;
@@ -72,11 +83,11 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
-        (original field declarations not captured in this view)
+        uint32_t seed;         // RNG seed, -1 for random
+        int32_t  n_ctx;        // text context
+        int32_t  n_batch;      // prompt processing batch size
+        int32_t  n_gpu_layers; // number of layers to store in VRAM
+        int32_t  main_gpu;     // the GPU that is used for scratch and small tensors
         float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;
@@ -131,17 +142,29 @@ extern "C" {
 
     // TODO: not great API - very likely to change
     // Initialize the llama + ggml backend
+    // If numa is true, use NUMA optimizations
    // Call once at the start of the program
-    LLAMA_API void llama_init_backend();
+    LLAMA_API void llama_init_backend(bool numa);
 
     LLAMA_API int64_t llama_time_us();
 
+    LLAMA_API struct llama_model * llama_load_model_from_file(
+            const char * path_model,
+            struct llama_context_params params);
+
+    LLAMA_API void llama_free_model(struct llama_model * model);
+
+    LLAMA_API struct llama_context * llama_new_context_with_model(
+            struct llama_model * model,
+            struct llama_context_params params);
+
     // Various functions for loading a ggml llama model.
     // Allocate (almost) all memory needed for the model.
     // Return NULL on failure
-    LLAMA_API struct llama_context * llama_init_from_file(
+    LLAMA_API DEPRECATED(struct llama_context * llama_init_from_file(
             const char * path_model,
-            struct llama_context_params params);
+            struct llama_context_params params),
+            "please use llama_load_model_from_file combined with llama_new_context_with_model instead");
 
     // Frees all allocated memory
     LLAMA_API void llama_free(struct llama_context * ctx);
@@ -158,8 +181,15 @@ extern "C" {
     // The model needs to be reloaded before applying a new adapter, otherwise the adapter
     // will be applied on top of the previous one
     // Returns 0 on success
-    LLAMA_API int llama_apply_lora_from_file(
+    LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
             struct llama_context * ctx,
+            const char * path_lora,
+            const char * path_base_model,
+            int n_threads),
+            "please use llama_model_apply_lora_from_file instead");
+
+    LLAMA_API int llama_model_apply_lora_from_file(
+            const struct llama_model * model,
             const char * path_lora,
             const char * path_base_model,
             int n_threads);
@@ -168,7 +198,7 @@ extern "C" {
     LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
 
     // Sets the current rng seed.
-    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, …
+    LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
 
     // Returns the maximum size in bytes of the state (rng, logits, embedding
     // and kv_cache) - will often be smaller after compacting tokens
@@ -198,6 +228,14 @@ extern "C" {
             int n_past,
             int n_threads);
 
+    // Same as llama_eval, but use float matrix input directly.
+    LLAMA_API int llama_eval_embd(
+            struct llama_context * ctx,
+            const float * embd,
+            int n_tokens,
+            int n_past,
+            int n_threads);
+
     // Export a static computation graph for context of 511 and batch size of 1
     // NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
     // parameters here to keep things simple
@@ -310,7 +348,7 @@
 #include <string>
 struct ggml_tensor;
 
-std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
+const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx);
 
 #endif
 
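The header changes above split model loading from context creation (llama_load_model_from_file, llama_free_model, llama_new_context_with_model) and deprecate the one-shot llama_init_from_file. The Ruby side of this gem mirrors the split with the Model and Context classes declared later in this diff. A minimal sketch of the new flow, assuming a GGML model file at ./model.bin (hypothetical path):

    require 'llama_cpp'

    # Backend init now takes the numa flag added in this release.
    LLaMACpp.init_backend(numa: false)

    params = LLaMACpp::ContextParams.new
    params.seed = 42

    # Two-step initialization replacing the deprecated llama_init_from_file:
    # load the weights once, then build a context bound to that model.
    model   = LLaMACpp::Model.new(model_path: './model.bin', params: params) # hypothetical path
    context = LLaMACpp::Context.new(model: model)

    puts LLaMACpp.generate(context, 'Hello, World.', n_predict: 32, n_threads: 4)

Keeping the model object separate from the context appears to be the point of the upstream split: the weights are loaded once and a context only holds per-session state.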
data/lib/llama_cpp/version.rb
CHANGED
@@ -3,8 +3,8 @@
 # llama_cpp.rb provides Ruby bindings for the llama.cpp.
 module LLaMACpp
   # The version of llama_cpp.rb you install.
-  VERSION = '0.2.2'
+  VERSION = '0.3.1'
 
   # The version of llama.cpp bundled with llama_cpp.rb.
-  LLAMA_CPP_VERSION = 'master-…'
+  LLAMA_CPP_VERSION = 'master-b8c8dda'
 end
data/lib/llama_cpp.rb
CHANGED
@@ -2,7 +2,6 @@
 
 require_relative 'llama_cpp/version'
 require_relative 'llama_cpp/llama_cpp'
-require_relative 'llama_cpp/client'
 
 # llama_cpp.rb provides Ruby bindings for the llama.cpp.
 module LLaMACpp
@@ -20,7 +19,6 @@ module LLaMACpp
   # @return [String]
   def generate(context, prompt, n_predict: 128, n_threads: 1) # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
     raise ArgumentError, 'context must be an instance of LLaMACpp::Context' unless context.is_a?(LLaMACpp::Context)
-    raise ArgumentError, 'context must have loaded the model' if context.empty?
     raise ArgumentError, 'prompt must be a String' unless prompt.is_a?(String)
 
     spaced_prompt = " #{prompt}"
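Note that LLaMACpp.generate no longer raises when the context has no model: in 0.3.x a Context is always built from a Model, and the empty? predicate stays on Model (see the sig changes below). A hedged sketch of keeping an equivalent guard at the call site, assuming Model#empty? is true until weights are loaded, as the old Context#empty? was:

    params = LLaMACpp::ContextParams.new
    model  = LLaMACpp::Model.new(model_path: './model.bin', params: params) # hypothetical path

    # Equivalent of the removed 'context must have loaded the model' check,
    # moved to the model object (assumption: Model#empty? mirrors the old Context#empty?).
    raise 'model has not been loaded' if model.empty?

    context = LLaMACpp::Context.new(model: model)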
data/sig/llama_cpp.rbs
CHANGED
@@ -4,6 +4,7 @@ module LLaMACpp
   LLAMA_FILE_VERSION: String
   LLAMA_FILE_MAGIC: String
   LLAMA_FILE_MAGIC_UNVERSIONED: String
+  LLAMA_DEFALUT_SEED: String
 
   LLAMA_MAX_DEVICES: Integer
 
@@ -25,7 +26,7 @@ module LLaMACpp
   LLAMA_FTYPE_MOSTLY_Q5_K_M: Integer
   LLAMA_FTYPE_MOSTLY_Q6_K: Integer
 
-  def self?.init_backend: () -> void
+  def self?.init_backend: (?numa: bool) -> void
   def self?.model_quantize: (input_path: String, output_path: String, params: ModelQuantizeParams) -> void
   def self?.generate: (::LLaMACpp::Context, String, ?n_predict: Integer, ?n_threads: Integer) -> String
   def self?.print_system_info: () -> void
@@ -55,17 +56,25 @@ module LLaMACpp
     def sorted: () -> bool
   end
 
-  class Context
+  class Model
     public
 
     def initialize: (model_path: String, params: ::LLaMACpp::ContextParams) -> void
                   | () -> void
-    def embeddings: () -> Array[Float]
     def empty?: () -> bool
-    def eval: (tokens: Array[Integer], n_past: Integer, ?n_tokens: Integer, ?n_threads: Integer) -> void
-    def eval_export: (String) -> bool
     def free: () -> void
     def load: (model_path: String, params: ::LLaMACpp::ContextParams) -> void
+    def apply_lora_from_file: (lora_path: String, ?base_model_path: String, ?n_threads: Integer) -> void
+  end
+
+  class Context
+    public
+
+    def initialize: (model: ::LLaMACpp::Model) -> void
+    def embeddings: () -> Array[Float]
+    def eval: (tokens: Array[Integer], n_past: Integer, ?n_tokens: Integer, ?n_threads: Integer) -> void
+    def eval_embd: (tokens: Array[Float], n_past: Integer, ?n_tokens: Integer, ?n_threads: Integer) -> void
+    def eval_export: (String) -> bool
     def logits: () -> Array[Float]
     def n_ctx: () -> Integer
     def n_embd: () -> Integer
@@ -75,7 +84,6 @@ module LLaMACpp
     def reset_timings: () -> void
     def token_to_str: (Integer) -> String
     def tokenize: (text: String, ?n_max_tokens: Integer, ?add_bos: bool) -> Array[Integer]
-    def apply_lora_from_file: (lora_path: String, ?base_model_path: String, ?n_threads: Integer) -> void
     def kv_cache_token_count: () -> Integer
     def set_rng_seed: (Integer) -> void
     def load_session_file: (session_path: String) -> void
@@ -138,15 +146,4 @@ module LLaMACpp
   end
 
   class Params = ContextParams
-
-  class Client
-    def initialize(model_path: String, ?lora_adapter_path: String, ?lora_base_path: String,
-                   ?n_ctx: Integer, ?memory_f16: bool, ?use_mmap: bool, ?use_mlock: bool,
-                   ?embedding: bool, ?n_threads: Integer, ?seed: Integer) -> void
-    def completions(String, ?max_tokens: Integer, ?n_keep: Integer, ?repeat_last_n: Integer, ?n_batch: Integer,
-                    ?frequency: Float, ?presence: Float,
-                    ?top_k: Integer, ?top_p: Float, ?tfs_z: Float, ?typical_p: Float, ?temperature: Float,
-                    ?repeat_penalty: Float) -> String
-    def embeddings(String) -> Array[Float]
-  end
 end
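The new Context#eval_embd in the signatures above wraps llama_eval_embd, which feeds float input instead of token ids. A rough sketch, assuming context is a LLaMACpp::Context built as in the earlier example, and assuming the input is a flat array of n_embd floats per position (that layout is inferred from the C declaration's const float * embd and n_tokens parameters, not stated in this diff):

    # ASSUMPTION: one position of float input, n_embd values long.
    n_embd = context.n_embd
    embd   = Array.new(n_embd, 0.0) # placeholder vector for illustration only
    context.eval_embd(tokens: embd, n_past: 0, n_tokens: 1, n_threads: 2)
    p context.logits.size

Also note that apply_lora_from_file now lives on Model rather than Context, matching the new llama_model_apply_lora_from_file in llama.h.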
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: llama_cpp
 version: !ruby/object:Gem::Version
-  version: 0.2.2
+  version: 0.3.1
 platform: ruby
 authors:
 - yoshoku
 autorequire:
 bindir: exe
 cert_chain: []
-date: 2023-…
+date: 2023-07-02 00:00:00.000000000 Z
 dependencies: []
 description: llama_cpp.rb provides Ruby bindings for the llama.cpp.
 email:
@@ -44,7 +44,6 @@ files:
 - ext/llama_cpp/src/llama.cpp
 - ext/llama_cpp/src/llama.h
 - lib/llama_cpp.rb
-- lib/llama_cpp/client.rb
 - lib/llama_cpp/version.rb
 - sig/llama_cpp.rbs
 homepage: https://github.com/yoshoku/llama_cpp.rb
data/lib/llama_cpp/client.rb
DELETED
@@ -1,172 +0,0 @@
-# frozen_string_literal: true
-
-module LLaMACpp
-  # Client provides a high-level interface to the LLM model.
-  class Client # rubocop:disable Metrics/ClassLength
-    # Creates a new client.
-    #
-    # @param model_path [String] The path to the model file.
-    # @param lora_adapter_path [String] The path to the LoRA adapter file.
-    # @param lora_base_path [String] The path to the LoRA base model file.
-    # @param n_ctx [Integer] The context size.
-    # @param memory_f16 [Boolean] The flag wheter to use f16 instead of f32 for memory kv.
-    # @param use_mmap [Boolean] The flag whether to use mmap.
-    # @param use_mlock [Boolean] The flag hether to use mlock.
-    # @param embedding [Boolean] The flag whether to calculate embedding.
-    # @param n_threads [Integer] The number of threads to use.
-    # @param seed [Integer] The seed for the random number generator.
-    # @return [Client]
-    # rubocop:disable Metrics/MethodLength, Metrics/ParameterLists
-    def initialize(model_path:, lora_adapter_path: nil, lora_base_path: nil,
-                   n_ctx: 512, memory_f16: false, use_mmap: true, use_mlock: false,
-                   embedding: false,
-                   n_threads: 1, seed: 0)
-      @params = {
-        model_path: model_path,
-        lora_adapter_path: lora_adapter_path,
-        lora_base_path: lora_base_path,
-        n_ctx: n_ctx,
-        memory_f16: memory_f16,
-        use_mmap: use_mmap,
-        use_mlock: use_mlock,
-        embedding: embedding,
-        n_threads: n_threads,
-        seed: seed
-      }
-      @context_params = ContextParams.new
-      @context_params.n_ctx = n_ctx
-      @context_params.n_parts = n_parts
-      @context_params.f16_kv = memory_f16
-      @context_params.use_mmap = use_mmap
-      @context_params.use_mlock = use_mlock
-      @context_params.embedding = embedding
-      @context_params.seed = seed
-      @context = Context.new(model_path: model_path, params: @context_params)
-      return unless lora_adapter_path.is_a?(String)
-
-      if lora_base_path.is_a?(String)
-        @context.apply_lora_from_file(lora_path: lora_adapter_path, base_model_path: lora_base_path, n_threads: n_threads)
-      else
-        @context.apply_lora_from_file(lora_path: lora_adapter_path, n_threads: n_threads)
-      end
-    end
-    # rubocop:enable Metrics/MethodLength, Metrics/ParameterLists
-
-    # Generates completions for a given prompt.
-    #
-    # @param prompt [String] The prompt to generate completions for.
-    # @param max_tokens [Integer] The maximum number of tokens to generate.
-    # @param n_keep [Integer] The number of tokens to keep from the initial prompt.
-    # @param repeat_last_n [Integer] The number of tokens to use for repeat penalty.
-    # @param n_batch [Integer] The batch size.
-    # @param frequency [Float] The frequency penalty value.
-    # @param presence [Float] The presence penalty value.
-    # @param top_k [Integer] The top-k value.
-    # @param top_p [Float] The top-p value.
-    # @param tfs_z [Float] The tail free sampling parameter.
-    # @param typical_p [Float] The typical probability value.
-    # @param temperature [Float] The temperature value.
-    # @param repeat_penalty [Float] The repeat penalty value.
-    # @return [String]
-    # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/ParameterLists, Metrics/PerceivedComplexity
-    def completions(prompt, max_tokens: 64, n_keep: 10, repeat_last_n: 64, n_batch: 512,
-                    frequency: 0.0, presence: 0.0,
-                    top_k: 40, top_p: 0.95, tfs_z: 1.0, typical_p: 1.0, temperature: 0.8, repeat_penalty: 1.1)
-      embd_input = tokenize_prompt(prompt)
-
-      n_ctx = @context.n_ctx
-      raise ArgumentError, "prompt is too long #{embd_input.size} tokens, maximum is #{n_ctx - 4}" if embd_input.size > n_ctx - 4
-
-      last_n_tokens = [0] * n_ctx
-
-      embd = []
-      n_consumed = 0
-      n_past = 0
-      n_remain = max_tokens
-      n_vocab = @context.n_vocab
-      output = []
-
-      while n_remain != 0
-        unless embd.empty?
-          if n_past + embd.size > n_ctx
-            n_left = n_past - n_keep
-            n_past = n_keep
-            embd.insert(0, last_n_tokens[(n_ctx - (n_left / 2) - embd.size)...-embd.size])
-          end
-
-          @context.eval(tokens: embd, n_past: n_past, n_threads: @params[:n_threads])
-        end
-
-        n_past += embd.size
-        embd.clear
-
-        if embd_input.size <= n_consumed
-          logits = @context.logits
-          base_candidates = Array.new(n_vocab) { |i| LLaMACpp::TokenData.new(id: i, logit: logits[i], p: 0.0) }
-          candidates = LLaMACpp::TokenDataArray.new(base_candidates)
-
-          # apply penalties
-          last_n_repeat = [last_n_tokens.size, repeat_last_n, n_ctx].min
-          @context.sample_repetition_penalty(candidates, last_n_tokens[-last_n_repeat..], penalty: repeat_penalty)
-          @context.sample_frequency_and_presence_penalties(
-            candidates, last_n_tokens[-last_n_repeat..], frequency: frequency, presence: presence
-          )
-
-          # temperature sampling
-          @context.sample_top_k(candidates, k: top_k)
-          @context.sample_tail_free(candidates, z: tfs_z)
-          @context.sample_typical(candidates, prob: typical_p)
-          @context.sample_top_p(candidates, prob: top_p)
-          @context.sample_temperature(candidates, temperature: temperature)
-          id = @context.sample_token(candidates)
-
-          last_n_tokens.shift
-          last_n_tokens.push(id)
-
-          last_n_tokens.shift
-          last_n_tokens.push(id)
-
-          embd.push(id)
-          n_remain -= 1
-        else
-          while embd_input.size > n_consumed
-            embd.push(embd_input[n_consumed])
-            last_n_tokens.shift
-            last_n_tokens.push(embd_input[n_consumed])
-            n_consumed += 1
-            break if embd.size >= n_batch
-          end
-        end
-
-        embd.each { |token| output << @context.token_to_str(token) }
-
-        break if !embd.empty? && embd[-1] == LLaMACpp.token_eos
-      end
-
-      output.join.delete_prefix(" #{prompt}").strip
-    end
-    # rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/ParameterLists, Metrics/PerceivedComplexity
-
-    # def chat(prompt); end
-
-    # Obtains the embedding for a given text.
-    #
-    # @param text [String] The text to obtain the embedding for.
-    # @return [Array<Float>]
-    def embeddings(text)
-      raise 'The embedding option is set to false' unless @params[:embedding]
-
-      embd_input = tokenize_prompt(text)
-      raise 'The result of tokenizing the input text is empty' unless embd_input.size.positive?
-
-      @context.eval(tokens: embd_input, n_past: 0, n_threads: @params[:n_threads])
-      @context.embeddings
-    end
-
-    private
-
-    def tokenize_prompt(prompt)
-      @context.tokenize(text: " #{prompt}", add_bos: true)
-    end
-  end
-end
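The deleted Client class was a convenience wrapper; its two public methods map roughly onto API that survives in 0.3.1. Completions correspond to LLaMACpp.generate (shown after the llama.h diff above), and Client#embeddings reduces to tokenize/eval/embeddings on a Context whose params enable embedding. A rough equivalent, assuming a model at ./model.bin (hypothetical path):

    require 'llama_cpp'

    LLaMACpp.init_backend(numa: false)

    params = LLaMACpp::ContextParams.new
    params.embedding = true # the old Client#embeddings raised unless this flag was set

    model   = LLaMACpp::Model.new(model_path: './model.bin', params: params) # hypothetical path
    context = LLaMACpp::Context.new(model: model)

    tokens = context.tokenize(text: ' Hello, World.', add_bos: true)
    context.eval(tokens: tokens, n_past: 0, n_threads: 2)
    p context.embeddings # Array[Float], one value per embedding dimension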