llama-rb 0.2.0 → 0.3.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/Gemfile.lock +1 -1
- data/README.md +2 -8
- data/lib/llama/model.rb +66 -11
- data/lib/llama/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 4f2bc2e51fa10f5dcdc890664eb5603d1f3a3742d3259d3aa8784c790ded070f
|
4
|
+
data.tar.gz: 2b08904fca31b95d35bb1b6ea2a2c78288898ad072aaae26b7cf3f3a8c64184a
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3504f141131b27bca91c7348ef9617ec57d85b2ed1de67020afa46b89618fe008ce99c2df29cfe3ff1be1b01f9fe2b5b600389b298b66b7ff575767923eae6af
|
7
|
+
data.tar.gz: 404109c7650567a2bc2953324c0b9abda35381ecaa74bb7197e46f18e0f72e8c00f74d674009a4d25e2e458f3b77ceda605fcb91d4e954a8e9805fe2f26cc9bf
|
data/Gemfile.lock
CHANGED
data/README.md
CHANGED
@@ -42,21 +42,15 @@ m.predict('hello world')
|
|
42
42
|
```ruby
|
43
43
|
def self.new(
|
44
44
|
model, # path to model file, e.g. "models/7B/ggml-model-q4_0.bin"
|
45
|
-
|
46
|
-
n_parts: -1, # amount of model parts (-1 = determine from model dimensions)
|
45
|
+
n_predict: nil, # number of tokens to predict (nil = flag omitted, so the binary's default applies)
|
47
46
|
seed: Time.now.to_i, # RNG seed
|
48
|
-
memory_f16: true, # use f16 instead of f32 for memory kv
|
49
|
-
use_mlock: false # use mlock to keep model in memory
|
50
47
|
)
|
51
48
|
```
|
52
49
|
|
53
50
|
#### Llama::Model#predict
|
54
51
|
|
55
52
|
```ruby
|
56
|
-
def predict(
|
57
|
-
prompt, # string used as prompt
|
58
|
-
n_predict: 128 # number of tokens to predict
|
59
|
-
)
|
53
|
+
def predict(prompt)
|
60
54
|
```
|
61
55
|
|
62
56
|
## Development
|
data/lib/llama/model.rb
CHANGED
@@ -6,29 +6,61 @@ module Llama
|
|
6
6
|
# Raised by #predict when the llama binary exits with an unsuccessful status.
ModelError = Class.new(StandardError)
|
8
8
|
|
9
|
-
# Captures the model path plus every llama command-line option for later
# calls to #predict. Nothing is validated or executed here.
#
# model  - path to the model file (README example: "models/7B/ggml-model-q4_0.bin").
# binary - executable to invoke; defaults to #default_binary (the gem's bin/llama).
# seed   - RNG seed; defaults to the current Unix time in seconds.
# Every other keyword defaults to nil. A nil option is left off the generated
# command line entirely, so the binary's own default applies.
def initialize( # rubocop:disable all
  model,
  binary: default_binary,
  seed: Time.now.to_i,
  n_predict: nil,
  threads: nil,
  top_k: nil,
  top_p: nil,
  repeat_last_n: nil,
  repeat_penalty: nil,
  ctx_size: nil,
  ignore_eos: nil,
  memory_f32: nil,
  temp: nil,
  n_parts: nil,
  batch_size: nil,
  keep: nil,
  mlock: nil
)
  # The positional model path and the two options with non-nil defaults.
  @model = model
  @binary = binary
  @seed = seed

  # Options forwarded as "--name value" whenever they are set.
  @n_predict = n_predict
  @temp = temp
  @top_k = top_k
  @top_p = top_p
  @repeat_last_n = repeat_last_n
  @repeat_penalty = repeat_penalty
  @ctx_size = ctx_size
  @batch_size = batch_size
  @keep = keep
  @n_parts = n_parts
  @threads = threads

  # Flag-style options, emitted without a value when true.
  @ignore_eos = ignore_eos
  @memory_f32 = memory_f32
  @mlock = mlock
end
|
20
46
|
|
21
47
|
# Runs the llama binary on +prompt+ and returns the generated text.
#
# The subprocess's stderr String and Process::Status are kept in @stderr and
# @status, and read back through `stderr` / `status` (defined elsewhere in
# the class — not visible here).
#
# prompt - String forwarded to #command.
# Returns the binary's stdout with its first character removed.
# Raises ModelError when the binary exits unsuccessfully. The message carries
# the first non-blank stderr line, or the exit status when stderr has none.
def predict(prompt)
  stdout, @stderr, @status = Open3.capture3(command(prompt))

  unless status.success?
    # Skip leading blank lines and fall back to the exit status, so the
    # message never degrades to a bare "Error ".
    error_string = stderr.each_line(chomp: true).find { |line| !line.strip.empty? } || status.to_s

    raise ModelError, "Error #{error_string}"
  end

  # remove the space that is added as a tokenizer hack in examples/main/main.cpp
  stdout[0] = ''
  stdout
end
|
30
60
|
|
31
|
-
# Public read-only accessors for the model path and each option captured by
# #initialize. They are declared before `private`, so callers can inspect them.
attr_reader(*%i[
  model seed n_predict binary
  threads top_k top_p repeat_last_n repeat_penalty ctx_size
  ignore_eos memory_f32 temp n_parts batch_size keep mlock
])
|
32
64
|
|
33
65
|
private
|
34
66
|
|
@@ -38,19 +70,42 @@ module Llama
|
|
38
70
|
File.join(File.dirname(__FILE__), '..', '..', 'bin', 'llama')
|
39
71
|
end
|
40
72
|
|
41
|
-
# Builds the shell command line that runs the llama binary for +prompt+.
#
# Every option stored by #initialize is handed to #escape_command. That method
# drops nil/false values, renders true as a bare "--name" switch, and renders
# anything else as "--name value". The on/off options (ignore-eos, memory_f32,
# mlock) are coerced with !! so that any truthy value yields the bare switch
# and never "--name <value>".
#
# prompt - text passed to the binary via --prompt.
# Returns the escaped command String.
def command(prompt) # rubocop:disable all
  escape_command(
    binary,
    model: model,
    prompt: prompt,
    seed: seed,
    n_predict: n_predict,
    threads: threads,
    top_k: top_k,
    top_p: top_p,
    repeat_last_n: repeat_last_n,
    repeat_penalty: repeat_penalty,
    ctx_size: ctx_size,
    # Spelled with a hyphen, unlike the other flag names — presumably to match
    # the binary's --ignore-eos option.
    'ignore-eos': !!ignore_eos,
    memory_f32: !!memory_f32,
    temp: temp,
    n_parts: n_parts,
    batch_size: batch_size,
    keep: keep,
    # Coerced like the switches above, so that e.g. `mlock: 1` produces
    # "--mlock" and not "--mlock 1" (a stray positional argument).
    mlock: !!mlock
  )
end
|
48
95
|
|
49
96
|
# Joins +command+ and +flags+ into one command line, escaped for a Bourne
# shell with Shellwords.escape.
#
# Each flag key becomes "--key". A value of true renders the bare switch, nil
# and false are dropped, and any other value renders as "--key <value>".
# Flag order follows the keyword order given by the caller.
#
# Returns a String. When +flags+ renders nothing, the result keeps a trailing
# space after the command.
def escape_command(command, **flags)
  rendered = flags.map do |name, setting|
    switch = "--#{Shellwords.escape(name)}"

    case setting
    when true then switch
    when false, nil then nil
    else "#{switch} #{Shellwords.escape(setting)}"
    end
  end

  [Shellwords.escape(command), rendered.compact.join(' ')].join(' ')
end
|
data/lib/llama/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: llama-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.0
|
4
|
+
version: 0.3.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- zfletch
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-04-
|
11
|
+
date: 2023-04-07 00:00:00.000000000 Z
|
12
12
|
dependencies: []
|
13
13
|
description: ggerganov/llama.cpp with Ruby hooks
|
14
14
|
email:
|