llama-rb 0.2.1 → 0.3.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '03801e4f99933be9c0e8d559008626991535c2167af88c8cb31defb31c88d0f6'
4
- data.tar.gz: 6f17e50818de906f33de2686cf1b75c0e17aa052f0fba60889bad85df0591f59
3
+ metadata.gz: 4f2bc2e51fa10f5dcdc890664eb5603d1f3a3742d3259d3aa8784c790ded070f
4
+ data.tar.gz: 2b08904fca31b95d35bb1b6ea2a2c78288898ad072aaae26b7cf3f3a8c64184a
5
5
  SHA512:
6
- metadata.gz: 40602fc8c253087a78fd4e5edf5fbae24f3a4ad0d9a3bb2f6730ef701753f6815e8716303220e8edcb1984484d5ffbd20c6adb7e07690244cd738ec6918c80e8
7
- data.tar.gz: 9cbf6bed4fa4359bd007d083f99976a885b1557b0bf01c4d22a55e231515adf7f66e58e951e01bf731e827b893bf6fc278a306f8a566be3e133039f210214bc2
6
+ metadata.gz: 3504f141131b27bca91c7348ef9617ec57d85b2ed1de67020afa46b89618fe008ce99c2df29cfe3ff1be1b01f9fe2b5b600389b298b66b7ff575767923eae6af
7
+ data.tar.gz: 404109c7650567a2bc2953324c0b9abda35381ecaa74bb7197e46f18e0f72e8c00f74d674009a4d25e2e458f3b77ceda605fcb91d4e954a8e9805fe2f26cc9bf
data/Gemfile.lock CHANGED
@@ -1,7 +1,7 @@
1
1
  PATH
2
2
  remote: .
3
3
  specs:
4
- llama-rb (0.2.0)
4
+ llama-rb (0.3.0)
5
5
 
6
6
  GEM
7
7
  remote: https://rubygems.org/
data/lib/llama/model.rb CHANGED
@@ -6,29 +6,61 @@ module Llama
6
6
  class ModelError < StandardError
7
7
  end
8
8
 
9
- def initialize(
9
+ def initialize( # rubocop:disable all
10
10
  model,
11
+ binary: default_binary,
11
12
  seed: Time.now.to_i,
12
- n_predict: 128,
13
- binary: default_binary
13
+ n_predict: nil,
14
+ threads: nil,
15
+ top_k: nil,
16
+ top_p: nil,
17
+ repeat_last_n: nil,
18
+ repeat_penalty: nil,
19
+ ctx_size: nil,
20
+ ignore_eos: nil,
21
+ memory_f32: nil,
22
+ temp: nil,
23
+ n_parts: nil,
24
+ batch_size: nil,
25
+ keep: nil,
26
+ mlock: nil
14
27
  )
15
28
  @model = model
16
29
  @seed = seed
17
30
  @n_predict = n_predict
18
31
  @binary = binary
32
+ @threads = threads
33
+ @top_k = top_k
34
+ @top_p = top_p
35
+ @repeat_last_n = repeat_last_n
36
+ @repeat_penalty = repeat_penalty
37
+ @ctx_size = ctx_size
38
+ @ignore_eos = ignore_eos
39
+ @memory_f32 = memory_f32
40
+ @temp = temp
41
+ @n_parts = n_parts
42
+ @batch_size = batch_size
43
+ @keep = keep
44
+ @mlock = mlock
19
45
  end
20
46
 
21
47
  def predict(prompt)
22
48
  stdout, @stderr, @status = Open3.capture3(command(prompt))
23
49
 
24
- raise ModelError, "Error #{status.to_i}" unless status.success?
50
+ unless status.success?
51
+ error_string = stderr.split("\n").first
52
+
53
+ raise ModelError, "Error #{error_string}"
54
+ end
25
55
 
26
56
  # remove the space that is added as a tokenizer hack in examples/main/main.cpp
27
57
  stdout[0] = ''
28
58
  stdout
29
59
  end
30
60
 
31
- attr_reader :model, :seed, :n_predict, :binary
61
+ attr_reader :model, :seed, :n_predict, :binary, :threads, :top_k, :top_p, :repeat_last_n,
62
+ :repeat_penalty, :ctx_size, :ignore_eos, :memory_f32, :temp, :n_parts, :batch_size, :keep,
63
+ :mlock
32
64
 
33
65
  private
34
66
 
@@ -38,19 +70,42 @@ module Llama
38
70
  File.join(File.dirname(__FILE__), '..', '..', 'bin', 'llama')
39
71
  end
40
72
 
41
- def command(prompt)
42
- escape_command(binary,
73
+ def command(prompt) # rubocop:disable all
74
+ escape_command(
75
+ binary,
43
76
  model: model,
44
77
  prompt: prompt,
45
78
  seed: seed,
46
- n_predict: n_predict)
79
+ n_predict: n_predict,
80
+ threads: threads,
81
+ top_k: top_k,
82
+ top_p: top_p,
83
+ repeat_last_n: repeat_last_n,
84
+ repeat_penalty: repeat_penalty,
85
+ ctx_size: ctx_size,
86
+ 'ignore-eos': !!ignore_eos,
87
+ memory_f32: !!memory_f32,
88
+ temp: temp,
89
+ n_parts: n_parts,
90
+ batch_size: batch_size,
91
+ keep: keep,
92
+ mlock: mlock,
93
+ )
47
94
  end
48
95
 
49
96
  def escape_command(command, **flags)
50
- flags_string = flags.map do |key, value|
51
- "--#{Shellwords.escape(key)} #{Shellwords.escape(value)}"
52
- end.join(' ')
97
+ flags_components = []
98
+
99
+ flags.each do |key, value|
100
+ if value == true
101
+ flags_components.push("--#{Shellwords.escape(key)}")
102
+ elsif value
103
+ flags_components.push("--#{Shellwords.escape(key)} #{Shellwords.escape(value)}")
104
+ end
105
+ end
106
+
53
107
  command_string = Shellwords.escape(command)
108
+ flags_string = flags_components.join(' ')
54
109
 
55
110
  "#{command_string} #{flags_string}"
56
111
  end
data/lib/llama/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Llama
2
- VERSION = '0.2.1'.freeze
2
+ VERSION = '0.3.0'.freeze
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: llama-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.1
4
+ version: 0.3.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - zfletch
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-04-06 00:00:00.000000000 Z
11
+ date: 2023-04-07 00:00:00.000000000 Z
12
12
  dependencies: []
13
13
  description: ggerganov/llama.cpp with Ruby hooks
14
14
  email: