llama_cpp 0.2.0 → 0.2.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +11 -0
- data/examples/README.md +92 -0
- data/examples/chat.rb +195 -0
- data/examples/embedding.rb +37 -0
- data/ext/llama_cpp/llama_cpp.cpp +52 -0
- data/ext/llama_cpp/src/ggml-cuda.cu +1218 -411
- data/ext/llama_cpp/src/ggml-cuda.h +4 -1
- data/ext/llama_cpp/src/ggml-metal.h +5 -1
- data/ext/llama_cpp/src/ggml-metal.m +703 -514
- data/ext/llama_cpp/src/ggml-metal.metal +574 -122
- data/ext/llama_cpp/src/ggml-opencl.cpp +496 -36
- data/ext/llama_cpp/src/ggml-opencl.h +1 -2
- data/ext/llama_cpp/src/ggml.c +2715 -476
- data/ext/llama_cpp/src/ggml.h +266 -11
- data/ext/llama_cpp/src/llama.cpp +266 -135
- data/ext/llama_cpp/src/llama.h +19 -11
- data/lib/llama_cpp/version.rb +2 -2
- data/sig/llama_cpp.rbs +3 -0
- metadata +5 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: e5e221d4831be790a990b121e6ac780d10b4cbfb85b2a9b4284d9c216f6e5604
|
4
|
+
data.tar.gz: fba76ac1a70bfd7b02b8d123c57e4c8096a29ac7f658bb090cda91c6a54752d2
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 994029383219077e134d170177954251c20ede6d1c83843ecd22c42eeae83584079d124b41702f55add7f3f237e9bdb14382fbd37dde2d0e74f8cffcfed1715b
|
7
|
+
data.tar.gz: ca4e94b6ddf4e4e9ddabbb2b8309cf4b2b06a881df09fdf4ad96e27c4f1f620ca0024ac46f69d9b474849c074a5c9ba9b0440777a0b52a12413bc356457a02f3
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,14 @@
|
|
1
|
+
## [[0.2.2](https://github.com/yoshoku/llama_cpp.rb/compare/v0.2.1...v0.2.2)] - 2023-06-24
|
2
|
+
|
3
|
+
- Bump bundled llama.cpp from master-a09f919 to master-7487137.
|
4
|
+
|
5
|
+
## [[0.2.1](https://github.com/yoshoku/llama_cpp.rb/compare/v0.2.0...v0.2.1)] - 2023-06-17
|
6
|
+
|
7
|
+
- Bump bundled llama.cpp from master-4de0334 to master-a09f919.
|
8
|
+
- Add `low_vram` parameter to ContextParams.
|
9
|
+
- Add `vocab` method to Context.
|
10
|
+
- Add example scripts: https://github.com/yoshoku/llama_cpp.rb/tree/main/examples
|
11
|
+
|
1
12
|
## [[0.2.0](https://github.com/yoshoku/llama_cpp.rb/compare/v0.1.4...v0.2.0)] - 2023-06-11
|
2
13
|
|
3
14
|
- Bump bundled llama.cpp from master-ffb06a3 to master-4de0334.
|
data/examples/README.md
ADDED
@@ -0,0 +1,92 @@
|
|
1
|
+
# llama_cpp.rb/examples
|
2
|
+
|
3
|
+
## chat.rb
|
4
|
+
|
5
|
+
### Usage
|
6
|
+
|
7
|
+
```sh
|
8
|
+
$ cd examples
|
9
|
+
$ gem install llama_cpp thor
|
10
|
+
$ ./chat.rb -m /path/to/quantized-model.bin -t 4
|
11
|
+
...
|
12
|
+
User: Please tell me the largest city in Japan.
|
13
|
+
Bob: Sure. The largest city in Japan is Tokyo.
|
14
|
+
User:
|
15
|
+
```
|
16
|
+
|
17
|
+
### Options
|
18
|
+
|
19
|
+
```sh
|
20
|
+
$ ./chat.rb help main
|
21
|
+
Usage:
|
22
|
+
chat.rb main -m, --model=MODEL
|
23
|
+
|
24
|
+
Options:
|
25
|
+
-s, [--seed=N] # random seed
|
26
|
+
# Default: -1
|
27
|
+
-t, [--threads=N] # number of threads
|
28
|
+
# Default: 2
|
29
|
+
-m, --model=MODEL # path to model file
|
30
|
+
-f, [--file=FILE] # prompt file to start generation
|
31
|
+
-r, [--reverse-prompt=REVERSE_PROMPT] # halt generation at PROMPT, return control in interactive mode
|
32
|
+
-b, [--batch-size=N] # batch size for prompt processing
|
33
|
+
# Default: 1024
|
34
|
+
-n, [--n-predict=N] # number of tokens to predict
|
35
|
+
# Default: 256
|
36
|
+
[--keep=N] # number of tokens to keep from the initial prompt
|
37
|
+
# Default: 48
|
38
|
+
[--repeat-last-n=N] # last n tokens to consider for penalize
|
39
|
+
# Default: 64
|
40
|
+
[--repeat-penalty=N] # penalize repeat sequence of tokens
|
41
|
+
# Default: 1.0
|
42
|
+
[--presence-penalty=N] # repeat alpha presence penalty
|
43
|
+
# Default: 0.0
|
44
|
+
[--frequency-penalty=N] # repeat alpha frequency penalty
|
45
|
+
# Default: 0.0
|
46
|
+
[--top-k=N] # top k sampling
|
47
|
+
# Default: 40
|
48
|
+
[--top-p=N] # top p sampling
|
49
|
+
# Default: 0.95
|
50
|
+
[--tfs-z=N] # tail free sampling, parameter z
|
51
|
+
# Default: 1.0
|
52
|
+
[--typical-p=N] # locally typical sampling, parameter p
|
53
|
+
# Default: 1.0
|
54
|
+
[--temp=N] # temperature
|
55
|
+
# Default: 0.8
|
56
|
+
[--n-gpu-layers=N] # number of layers on GPU
|
57
|
+
# Default: 0
|
58
|
+
|
59
|
+
Start chat
|
60
|
+
```
|
61
|
+
|
62
|
+
## embedding.rb
|
63
|
+
|
64
|
+
### Usage
|
65
|
+
|
66
|
+
```sh
|
67
|
+
$ cd examples
|
68
|
+
$ gem install llama_cpp thor
|
69
|
+
$ ./embedding.rb -m /path/to/quantized-model.bin -t 4 -p 'Hello, World.'
|
70
|
+
...
|
71
|
+
0.7191136479377747 0.5564611554145813 1.4210394620895386 -1.4874695539474487
|
72
|
+
```
|
73
|
+
|
74
|
+
### Options
|
75
|
+
|
76
|
+
```sh
|
77
|
+
$ ./embedding.rb help main
|
78
|
+
Usage:
|
79
|
+
embedding.rb main -m, --model=MODEL -p, --prompt=PROMPT
|
80
|
+
|
81
|
+
Options:
|
82
|
+
-s, [--seed=N] # random seed
|
83
|
+
# Default: -1
|
84
|
+
-t, [--threads=N] # number of threads
|
85
|
+
# Default: 2
|
86
|
+
-m, --model=MODEL # path to model file
|
87
|
+
-p, --prompt=PROMPT # prompt to generate embedding
|
88
|
+
[--n-gpu-layers=N] # number of layers on GPU
|
89
|
+
# Default: 0
|
90
|
+
|
91
|
+
Extract embedding from prompt
|
92
|
+
```
|
data/examples/chat.rb
ADDED
@@ -0,0 +1,195 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
# chat.rb is a simple chatbot that uses llama_cpp to generate text.
|
5
|
+
# It is created with reference to main.cpp and chat.sh in llama.cpp examples:
|
6
|
+
# - https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp
|
7
|
+
# - https://github.com/ggerganov/llama.cpp/blob/master/examples/chat.sh
|
8
|
+
|
9
|
+
require 'llama_cpp'
|
10
|
+
require 'thor'
|
11
|
+
require 'readline'
|
12
|
+
|
13
|
+
# Interactive chat front end built on LLaMACpp::Context, following the
# interactive loop of llama.cpp's examples/main/main.cpp and chat.sh.
class Chat < Thor # rubocop:disable Metrics/ClassLength, Style/Documentation
  default_command :main
  desc 'main', 'Start chat'
  option :seed, type: :numeric, aliases: '-s', desc: 'random seed', default: -1
  option :threads, type: :numeric, aliases: '-t', desc: 'number of threads', default: 2
  option :model, type: :string, aliases: '-m', desc: 'path to model file', required: true
  option :file, type: :string, aliases: '-f', desc: 'prompt file to start generation'
  option :reverse_prompt, type: :string, aliases: '-r', desc: 'halt generation at PROMPT, return control in interactive mode'
  option :batch_size, type: :numeric, aliases: '-b', desc: 'batch size for prompt processing', default: 1024
  option :n_predict, type: :numeric, aliases: '-n', desc: 'number of tokens to predict', default: 256
  option :keep, type: :numeric, desc: 'number of tokens to keep from the initial prompt', default: 48
  option :repeat_last_n, type: :numeric, desc: 'last n tokens to consider for penalize', default: 64
  option :repeat_penalty, type: :numeric, desc: 'penalize repeat sequence of tokens', default: 1.0
  option :presence_penalty, type: :numeric, desc: 'repeat alpha presence penalty', default: 0.0
  option :frequency_penalty, type: :numeric, desc: 'repeat alpha frequency penalty', default: 0.0
  option :top_k, type: :numeric, desc: 'top k sampling', default: 40
  option :top_p, type: :numeric, desc: 'top p sampling', default: 0.95
  option :tfs_z, type: :numeric, desc: 'tail free sampling, parameter z', default: 1.0
  option :typical_p, type: :numeric, desc: 'locally typical sampling, parameter p', default: 1.0
  option :temp, type: :numeric, desc: 'temperature', default: 0.8
  option :n_gpu_layers, type: :numeric, desc: 'number of layers on GPU', default: 0
  # Runs the chat loop. It feeds the start prompt, samples tokens one at a time,
  # and returns control to the user whenever the reverse prompt (antiprompt)
  # shows up at the end of the recent output. The loop ends on EOF from Readline.
  def main # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
    params = LLaMACpp::ContextParams.new
    params.seed = options[:seed]
    params.n_gpu_layers = options[:n_gpu_layers]
    context = LLaMACpp::Context.new(model_path: options[:model], params: params)

    antiprompt = options[:reverse_prompt] || 'User:'
    start_prompt = read_prompt(options[:file]) || default_prompt(antiprompt)

    embd_input = context.tokenize(text: start_prompt, add_bos: true)

    n_ctx = context.n_ctx
    raise ArgumentError, "prompt is too long #{embd_input.size} tokens, maximum is #{n_ctx - 4}" if embd_input.size > n_ctx - 4

    # Negative or oversized --keep values keep the whole prompt; a negative
    # n_keep would otherwise break the context-swap arithmetic below.
    n_keep = options[:keep]
    n_keep = embd_input.size if n_keep.negative? || n_keep > embd_input.size

    token_newline = context.tokenize(text: "\n", add_bos: false)

    last_n_tokens = [0] * n_ctx
    interactive = true
    is_interacting = false
    input_echo = true
    first_input = true
    embd = []
    n_consumed = 0
    n_past = 0
    n_remain = options[:n_predict]
    n_vocab = context.n_vocab

    while interactive
      unless embd.empty?
        if n_past + embd.size > n_ctx
          # Context is full: restart after the first n_keep tokens and re-feed the
          # most recent n_left / 2 tokens (excluding the pending ones) ahead of embd.
          n_left = n_past - n_keep
          n_past = [1, n_keep].max
          # Splat so the recycled tokens are inserted individually; passing the
          # slice itself would insert one nested Array element.
          embd.insert(0, *last_n_tokens[(n_ctx - (n_left / 2) - embd.size)...-embd.size])
        end

        0.step(embd.size - 1, options[:batch_size]) do |i|
          n_eval = [options[:batch_size], embd.size - i].min
          context.eval(tokens: embd[i...i + n_eval], n_past: n_past, n_threads: options[:threads])
          n_past += n_eval
        end
      end

      embd.clear

      if embd_input.size <= n_consumed && !is_interacting
        logits = context.logits
        base_candidates = Array.new(n_vocab) { |i| LLaMACpp::TokenData.new(id: i, logit: logits[i], p: 0.0) }
        candidates = LLaMACpp::TokenDataArray.new(base_candidates)

        last_n_repeat = [last_n_tokens.size, options[:repeat_last_n], n_ctx].min
        # arr[-0..] would return the whole window, so use Array#last and treat a
        # non-positive window as "no tokens to penalize".
        recent_tokens = last_n_repeat.positive? ? last_n_tokens.last(last_n_repeat) : []
        context.sample_repetition_penalty(candidates, recent_tokens, penalty: options[:repeat_penalty])
        context.sample_frequency_and_presence_penalties(
          candidates, recent_tokens,
          frequency: options[:frequency_penalty], presence: options[:presence_penalty]
        )

        context.sample_top_k(candidates, k: options[:top_k])
        context.sample_tail_free(candidates, z: options[:tfs_z])
        context.sample_typical(candidates, prob: options[:typical_p])
        context.sample_top_p(candidates, prob: options[:top_p])
        context.sample_temperature(candidates, temperature: options[:temp])
        id = context.sample_token(candidates)

        last_n_tokens.shift
        last_n_tokens.push(id)

        # On end-of-sequence, emit a newline instead and queue the antiprompt so
        # the transcript hands the turn back to the user.
        if id == LLaMACpp.token_eos
          id = token_newline.first
          unless antiprompt.empty?
            first_antiprompt = context.tokenize(text: antiprompt, add_bos: false)
            embd_input.concat(first_antiprompt)
          end
        end

        embd.push(id)
        input_echo = true
        n_remain -= 1
      else
        # Move pending prompt/user tokens into embd, at most one batch per pass.
        while embd_input.size > n_consumed
          embd.push(embd_input[n_consumed])
          last_n_tokens.shift
          last_n_tokens.push(embd_input[n_consumed])
          n_consumed += 1
          break if embd.size >= options[:batch_size]
        end
      end

      if input_echo
        output_str = embd.map { |token| context.token_to_str(token) }.join
        # The antiprompt is printed separately when the first user turn starts.
        output_str.chomp!(antiprompt) if first_input
        print(output_str)
      end

      if embd_input.size <= n_consumed
        if antiprompt.size.positive?
          last_output_str = last_n_tokens.map { |token| context.token_to_str(token) }.join
          # Only look for the antiprompt at the very end of the recent output.
          search_start_pos = last_output_str.size > antiprompt.size ? last_output_str.size - antiprompt.size : 0
          is_interacting = true unless last_output_str.index(antiprompt, search_start_pos).nil?
        end

        if n_past.positive? && is_interacting
          if first_input
            print("\r#{antiprompt}")
            first_input = false
          end
          buffer = Readline.readline(' ')
          break if buffer.nil?

          # Readline returns the line without its trailing newline, so any
          # non-empty reply (even a single character) is real user input.
          unless buffer.empty?
            line_input = context.tokenize(text: "#{buffer}\n", add_bos: false)
            embd_input.concat(line_input)
            n_remain -= line_input.size
          end

          input_echo = false
        end

        is_interacting = false if n_past.positive?
      end

      if n_remain <= 0 && options[:n_predict] != -1
        n_remain = options[:n_predict]
        is_interacting = true
      end
    end
  end

  private

  # Returns the contents of +filename+ without its trailing newline, or nil when
  # no file was given.
  def read_prompt(filename)
    return if filename.nil?

    File.read(filename).chomp
  end

  # Builds the built-in "chat with Bob" prompt and ends it with +antiprompt+ so
  # the model's first turn belongs to the user.
  def default_prompt(antiprompt)
    # Reference:
    # https://github.com/ggerganov/llama.cpp/blob/master/prompts/chat-with-bob.txt
    prompt = <<~MSG
      Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.

      User: Hello, Bob.
      Bob: Hello. How may I help you today?
      User: Please tell me the largest city in Europe.
      Bob: Sure. The largest city in Europe is Moscow, the capital of Russia.
    MSG
    prompt + antiprompt
  end
end

Chat.start(ARGV)
|
@@ -0,0 +1,37 @@
|
|
1
|
+
#!/usr/bin/env ruby
|
2
|
+
# frozen_string_literal: true
|
3
|
+
|
4
|
+
# embedding.rb extracts embedding from prompt.
|
5
|
+
# It is created with reference to embedding.cpp in llama.cpp examples:
|
6
|
+
# - https://github.com/ggerganov/llama.cpp/blob/master/examples/embedding/embedding.cpp
|
7
|
+
|
8
|
+
require 'llama_cpp'
|
9
|
+
require 'thor'
|
10
|
+
|
11
|
+
# Thor command that evaluates one prompt in embedding mode and prints the
# resulting embedding vector as space-separated numbers on a single line.
class Embedding < Thor # rubocop:disable Style/Documentation
  default_command :main
  desc 'main', 'Extract embedding from prompt'
  option :seed, type: :numeric, aliases: '-s', desc: 'random seed', default: -1
  option :threads, type: :numeric, aliases: '-t', desc: 'number of threads', default: 2
  option :model, type: :string, aliases: '-m', desc: 'path to model file', required: true
  option :prompt, type: :string, aliases: '-p', desc: 'prompt to generate embedding', required: true
  option :n_gpu_layers, type: :numeric, desc: 'number of layers on GPU', default: 0
  # Tokenizes --prompt (with BOS) and evaluates it from position 0. It then
  # prints every embedding value followed by a space, then a newline.
  # Prints nothing when the prompt tokenizes to zero tokens.
  def main # rubocop:disable Metrics/AbcSize
    context_params = LLaMACpp::ContextParams.new.tap do |cp|
      cp.seed = options[:seed]
      cp.n_gpu_layers = options[:n_gpu_layers]
      cp.embedding = true
    end
    context = LLaMACpp::Context.new(model_path: options[:model], params: context_params)

    tokens = context.tokenize(text: options[:prompt], add_bos: true)
    return if tokens.empty?

    context.eval(tokens: tokens, n_past: 0, n_threads: options[:threads])

    line = context.embeddings.map { |val| "#{val} " }.join
    print("#{line}\n")
  end
end

Embedding.start(ARGV)
|
data/ext/llama_cpp/llama_cpp.cpp
CHANGED
@@ -300,6 +300,8 @@ public:
|
|
300
300
|
rb_define_method(rb_cLLaMAContextParams, "main_gpu=", RUBY_METHOD_FUNC(_llama_context_params_set_main_gpu), 1);
|
301
301
|
rb_define_method(rb_cLLaMAContextParams, "main_gpu", RUBY_METHOD_FUNC(_llama_context_params_get_main_gpu), 0);
|
302
302
|
rb_define_method(rb_cLLaMAContextParams, "tensor_split", RUBY_METHOD_FUNC(_llama_context_params_get_tensor_split), 0);
|
303
|
+
rb_define_method(rb_cLLaMAContextParams, "low_vram=", RUBY_METHOD_FUNC(_llama_context_params_set_low_vram), 1);
|
304
|
+
rb_define_method(rb_cLLaMAContextParams, "low_vram", RUBY_METHOD_FUNC(_llama_context_params_get_low_vram), 0);
|
303
305
|
rb_define_method(rb_cLLaMAContextParams, "seed=", RUBY_METHOD_FUNC(_llama_context_params_set_seed), 1);
|
304
306
|
rb_define_method(rb_cLLaMAContextParams, "seed", RUBY_METHOD_FUNC(_llama_context_params_get_seed), 0);
|
305
307
|
rb_define_method(rb_cLLaMAContextParams, "f16_kv=", RUBY_METHOD_FUNC(_llama_context_params_set_f16_kv), 1);
|
@@ -386,6 +388,18 @@ private:
|
|
386
388
|
return ret;
|
387
389
|
};
|
388
390
|
|
391
|
+
// low_vram
|
392
|
+
static VALUE _llama_context_params_set_low_vram(VALUE self, VALUE low_vram) {
|
393
|
+
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
394
|
+
ptr->params.low_vram = low_vram == Qtrue ? true : false;
|
395
|
+
return ptr->params.low_vram ? Qtrue : Qfalse;
|
396
|
+
};
|
397
|
+
|
398
|
+
// Getter: reports llama_context_params.low_vram as a Ruby boolean.
static VALUE _llama_context_params_get_low_vram(VALUE self) {
  const LLaMAContextParamsWrapper* wrapper = get_llama_context_params(self);
  if (wrapper->params.low_vram) {
    return Qtrue;
  }
  return Qfalse;
}
|
402
|
+
|
389
403
|
// seed
|
390
404
|
static VALUE _llama_context_params_set_seed(VALUE self, VALUE seed) {
|
391
405
|
LLaMAContextParamsWrapper* ptr = get_llama_context_params(self);
|
@@ -641,6 +655,7 @@ public:
|
|
641
655
|
rb_define_method(rb_cLLaMAContext, "tokenize", RUBY_METHOD_FUNC(_llama_context_tokenize), -1);
|
642
656
|
rb_define_method(rb_cLLaMAContext, "logits", RUBY_METHOD_FUNC(_llama_context_logits), 0);
|
643
657
|
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
|
658
|
+
rb_define_method(rb_cLLaMAContext, "vocab", RUBY_METHOD_FUNC(_llama_context_vocab), -1);
|
644
659
|
rb_define_method(rb_cLLaMAContext, "token_to_str", RUBY_METHOD_FUNC(_llama_context_token_to_str), 1);
|
645
660
|
rb_define_method(rb_cLLaMAContext, "n_vocab", RUBY_METHOD_FUNC(_llama_context_n_vocab), 0);
|
646
661
|
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
|
@@ -896,6 +911,43 @@ private:
|
|
896
911
|
return output;
|
897
912
|
};
|
898
913
|
|
914
|
+
// Context#vocab(capacity:) -> [Array<String>, Array<Float>]
// Returns the first `capacity` vocabulary entries as two parallel arrays:
// token strings (UTF-8 tagged) and their scores.
// `capacity` is clamped to [0, n_vocab].
// Raises ArgumentError when capacity is not an Integer.
// Raises RuntimeError when the context is not initialized.
static VALUE _llama_context_vocab(int argc, VALUE* argv, VALUE self) {
  VALUE kw_args = Qnil;
  ID kw_table[1] = { rb_intern("capacity") };
  VALUE kw_values[1] = { Qundef };
  rb_scan_args(argc, argv, ":", &kw_args);
  rb_get_kwargs(kw_args, kw_table, 1, 0, kw_values);

  if (!RB_INTEGER_TYPE_P(kw_values[0])) {
    rb_raise(rb_eArgError, "capacity must be an integer");
    return Qnil;
  }

  LLaMAContextWrapper* ptr = get_llama_context(self);
  if (ptr->ctx == NULL) {
    rb_raise(rb_eRuntimeError, "LLaMA context is not initialized");
    return Qnil;
  }

  const int capacity = NUM2INT(kw_values[0]);
  const int n_vocab = llama_n_vocab(ptr->ctx);

  // Clamp locally so the buffers are sized by what was actually requested
  // (not always the full vocabulary), and a negative capacity simply yields
  // empty arrays instead of being forwarded to llama_get_vocab.
  int n_request = capacity < n_vocab ? capacity : n_vocab;
  if (n_request < 0) {
    n_request = 0;
  }
  if (n_request == 0) {
    return rb_ary_new_from_args(2, rb_ary_new(), rb_ary_new());
  }

  std::vector<const char*> strings(n_request, NULL);
  std::vector<float> scores(n_request, 0.0f);
  int n_filled = llama_get_vocab(ptr->ctx, strings.data(), scores.data(), n_request);
  // Never read past the buffers we allocated, whatever count is reported.
  if (n_filled > n_request) {
    n_filled = n_request;
  }

  VALUE ret_strings = rb_ary_new_capa(n_filled > 0 ? n_filled : 0);
  VALUE ret_scores = rb_ary_new_capa(n_filled > 0 ? n_filled : 0);
  for (int i = 0; i < n_filled; i++) {
    rb_ary_push(ret_strings, rb_utf8_str_new_cstr(strings[i]));
    rb_ary_push(ret_scores, DBL2NUM(static_cast<double>(scores[i])));
  }

  return rb_ary_new_from_args(2, ret_strings, ret_scores);
}
|
950
|
+
|
899
951
|
static VALUE _llama_context_n_vocab(VALUE self) {
|
900
952
|
LLaMAContextWrapper* ptr = get_llama_context(self);
|
901
953
|
if (ptr->ctx == NULL) {
|