llama-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/llama/llama.h ADDED
@@ -0,0 +1,152 @@
+ #ifndef LLAMA_H
+ #define LLAMA_H
+
+ #include <stddef.h>
+ #include <stdint.h>
+ #include <stdbool.h>
+
+ #ifdef LLAMA_SHARED
+ #  if defined(_WIN32) && !defined(__MINGW32__)
+ #    ifdef LLAMA_BUILD
+ #      define LLAMA_API __declspec(dllexport)
+ #    else
+ #      define LLAMA_API __declspec(dllimport)
+ #    endif
+ #  else
+ #    define LLAMA_API __attribute__ ((visibility ("default")))
+ #  endif
+ #else
+ #  define LLAMA_API
+ #endif
+
+ #define LLAMA_FILE_VERSION 1
+ #define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
+ #define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ //
+ // C interface
+ //
+ // TODO: show sample usage
+ //
+
+ struct llama_context;
+
+ typedef int llama_token;
+
+ typedef struct llama_token_data {
+     llama_token id; // token id
+
+     float p;    // probability of the token
+     float plog; // log probability of the token
+
+ } llama_token_data;
+
+ typedef void (*llama_progress_callback)(float progress, void *ctx);
+
+ struct llama_context_params {
+     int n_ctx;   // text context
+     int n_parts; // -1 for default
+     int seed;    // RNG seed, 0 for random
+
+     bool f16_kv;     // use fp16 for KV cache
+     bool logits_all; // the llama_eval() call computes all logits, not just the last one
+     bool vocab_only; // only load the vocabulary, no weights
+     bool use_mlock;  // force system to keep model in RAM
+     bool embedding;  // embedding mode only
+
+     // called with a progress value between 0 and 1, pass NULL to disable
+     llama_progress_callback progress_callback;
+     // context pointer passed to the progress callback
+     void * progress_callback_user_data;
+ };
+
+ LLAMA_API struct llama_context_params llama_context_default_params();
+
+ // Various functions for loading a ggml llama model.
+ // Allocate (almost) all memory needed for the model.
+ // Return NULL on failure
+ LLAMA_API struct llama_context * llama_init_from_file(
+         const char * path_model,
+         struct llama_context_params params);
+
+ // Frees all allocated memory
+ LLAMA_API void llama_free(struct llama_context * ctx);
+
+ // TODO: not great API - very likely to change
+ // Returns 0 on success
+ LLAMA_API int llama_model_quantize(
+         const char * fname_inp,
+         const char * fname_out,
+         int itype);
+
+ // Run the llama inference to obtain the logits and probabilities for the next token.
+ // tokens + n_tokens is the provided batch of new tokens to process
+ // n_past is the number of tokens to use from previous eval calls
+ // Returns 0 on success
+ LLAMA_API int llama_eval(
+         struct llama_context * ctx,
+         const llama_token * tokens,
+         int n_tokens,
+         int n_past,
+         int n_threads);
+
+ // Convert the provided text into tokens.
+ // The tokens pointer must be large enough to hold the resulting tokens.
+ // Returns the number of tokens on success, no more than n_max_tokens
+ // Returns a negative number on failure - the number of tokens that would have been returned
+ // TODO: not sure if correct
+ LLAMA_API int llama_tokenize(
+         struct llama_context * ctx,
+         const char * text,
+         llama_token * tokens,
+         int n_max_tokens,
+         bool add_bos);
+
+ LLAMA_API int llama_n_vocab(struct llama_context * ctx);
+ LLAMA_API int llama_n_ctx  (struct llama_context * ctx);
+ LLAMA_API int llama_n_embd (struct llama_context * ctx);
+
+ // Token logits obtained from the last call to llama_eval()
+ // The logits for the last token are stored in the last row
+ // Can be mutated in order to change the probabilities of the next token
+ // Rows: n_tokens
+ // Cols: n_vocab
+ LLAMA_API float * llama_get_logits(struct llama_context * ctx);
+
+ // Get the embeddings for the input
+ // shape: [n_embd] (1-dimensional)
+ LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
+
+ // Token Id -> String. Uses the vocabulary in the provided context
+ LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
+
+ // Special tokens
+ LLAMA_API llama_token llama_token_bos();
+ LLAMA_API llama_token llama_token_eos();
+
+ // TODO: improve the last_n_tokens interface ?
+ LLAMA_API llama_token llama_sample_top_p_top_k(
+         struct llama_context * ctx,
+         const llama_token * last_n_tokens_data,
+         int last_n_tokens_size,
+         int top_k,
+         float top_p,
+         float temp,
+         float repeat_penalty);
+
+ // Performance information
+ LLAMA_API void llama_print_timings(struct llama_context * ctx);
+ LLAMA_API void llama_reset_timings(struct llama_context * ctx);
+
+ // Print system information
+ LLAMA_API const char * llama_print_system_info(void);
+
+ #ifdef __cplusplus
+ }
+ #endif
+
+ #endif
data/ext/llama/model.cpp ADDED
@@ -0,0 +1,192 @@
+ #include <rice/rice.hpp>
+
+ #include "common.h"
+ #include "llama.h"
+
+ #include <cassert>
+ #include <cinttypes>
+ #include <cmath>
+ #include <cstdio>
+ #include <cstring>
+ #include <fstream>
+ #include <iostream>
+ #include <string>
+ #include <vector>
+
+ class ModelCpp
+ {
+ public:
+     llama_context *ctx;
+     ModelCpp()
+     {
+         ctx = NULL;
+     }
+     void model_initialize(
+         const char *model,
+         const int32_t n_ctx,
+         const int32_t n_parts,
+         const int32_t seed,
+         const bool memory_f16,
+         const bool use_mlock
+     );
+     Rice::Object model_predict(
+         const char *prompt,
+         const int32_t n_predict
+     );
+     ~ModelCpp()
+     {
+
+         if (ctx != NULL) {
+             llama_free(ctx);
+         }
+     }
+ };
+
+ void ModelCpp::model_initialize(
+     const char *model,     // path to model file, e.g. "models/7B/ggml-model-q4_0.bin"
+     const int32_t n_ctx,   // context size
+     const int32_t n_parts, // amount of model parts (-1 = determine from model dimensions)
+     const int32_t seed,    // RNG seed
+     const bool memory_f16, // use f16 instead of f32 for memory kv
+     const bool use_mlock   // use mlock to keep model in memory
+ )
+ {
+     auto lparams = llama_context_default_params();
+
+     lparams.n_ctx = n_ctx;
+     lparams.n_parts = n_parts;
+     lparams.seed = seed;
+     lparams.f16_kv = memory_f16;
+     lparams.use_mlock = use_mlock;
+
+     ctx = llama_init_from_file(model, lparams);
+ }
+
+ Rice::Object ModelCpp::model_predict(
+     const char *prompt,     // string used as prompt
+     const int32_t n_predict // number of tokens to predict
+ )
+ {
+     std::string return_val = "";
+
+     gpt_params params;
+     params.prompt = prompt;
+     params.n_predict = n_predict;
+
+     // add a space in front of the first character to match OG llama tokenizer behavior
+     params.prompt.insert(0, 1, ' ');
+
+     // tokenize the prompt
+     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+     const int n_ctx = llama_n_ctx(ctx);
+
+     // determine newline token
+     auto llama_token_newline = ::llama_tokenize(ctx, "\n", false);
+
+     // generate output
+     {
+         std::vector<llama_token> last_n_tokens(n_ctx);
+         std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+
+         int n_past = 0;
+         int n_remain = params.n_predict;
+         int n_consumed = 0;
+
+         std::vector<llama_token> embd;
+
+         while (n_remain != 0) {
+             if (embd.size() > 0) {
+                 // infinite text generation via context swapping
+                 // if we run out of context:
+                 // - take the n_keep first tokens from the original prompt (via n_past)
+                 // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
+                 if (n_past + (int) embd.size() > n_ctx) {
+                     const int n_left = n_past - params.n_keep;
+
+                     n_past = params.n_keep;
+
+                     // insert n_left/2 tokens at the start of embd from last_n_tokens
+                     embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
+                 }
+
+                 if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+                     throw Rice::Exception(rb_eRuntimeError, "Failed to eval");
+                 }
+             }
+
+
+             n_past += embd.size();
+             embd.clear();
+
+             if ((int) embd_inp.size() <= n_consumed) {
+                 // out of user input, sample next token
+                 const int32_t top_k = params.top_k;
+                 const float top_p = params.top_p;
+                 const float temp = params.temp;
+                 const float repeat_penalty = params.repeat_penalty;
+
+                 llama_token id = 0;
+
+                 {
+                     auto logits = llama_get_logits(ctx);
+
+                     if (params.ignore_eos) {
+                         logits[llama_token_eos()] = 0;
+                     }
+
+                     id = llama_sample_top_p_top_k(ctx,
+                             last_n_tokens.data() + n_ctx - params.repeat_last_n,
+                             params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+
+                     last_n_tokens.erase(last_n_tokens.begin());
+                     last_n_tokens.push_back(id);
+                 }
+
+                 // replace end of text token with newline token when in interactive mode
+                 if (id == llama_token_eos() && params.interactive && !params.instruct) {
+                     id = llama_token_newline.front();
+                     if (params.antiprompt.size() != 0) {
+                         // tokenize and inject first reverse prompt
+                         const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                         embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                     }
+                 }
+
+                 // add it to the context
+                 embd.push_back(id);
+
+                 // decrement remaining sampling budget
+                 --n_remain;
+             } else {
+                 // some user input remains from prompt or interaction, forward it to processing
+                 while ((int) embd_inp.size() > n_consumed) {
+                     embd.push_back(embd_inp[n_consumed]);
+                     last_n_tokens.erase(last_n_tokens.begin());
+                     last_n_tokens.push_back(embd_inp[n_consumed]);
+                     ++n_consumed;
+                     if ((int) embd.size() >= params.n_batch) {
+                         break;
+                     }
+                 }
+             }
+
+             for (auto id : embd) {
+                 return_val += llama_token_to_str(ctx, id);
+             }
+         }
+     }
+
+     Rice::String ruby_return_val(return_val);
+     return ruby_return_val;
+ }
+
+ extern "C"
+ void Init_model()
+ {
+     Rice::Module rb_mLlama = Rice::define_module("Llama");
+     Rice::Data_Type<ModelCpp> rb_cModel = Rice::define_class_under<ModelCpp>(rb_mLlama, "Model");
+
+     rb_cModel.define_constructor(Rice::Constructor<ModelCpp>());
+     rb_cModel.define_method("initialize_cpp", &ModelCpp::model_initialize);
+     rb_cModel.define_method("predict_cpp", &ModelCpp::model_predict);
+ }
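
The Init_model entry point above registers the extension class as Llama::Model with a default constructor plus the initialize_cpp and predict_cpp instance methods. A minimal sketch of driving that raw binding directly, assuming the extension has already been compiled and using a placeholder model path (normally you would go through the Ruby wrapper in lib/llama/model.rb shown next, which makes these methods private):

require_relative 'model' # the compiled ext/llama/model bundle (assumed to be built)

raw = Llama::Model.new
# arguments mirror ModelCpp::model_initialize:
# path, n_ctx, n_parts, seed, memory_f16, use_mlock
raw.initialize_cpp('models/7B/ggml-model-q4_0.bin', 512, -1, 0, true, false)

puts raw.predict_cpp('The capital of France is', 16)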
data/lib/llama/model.rb ADDED
@@ -0,0 +1,86 @@
+ require 'tempfile'
+
+ module Llama
+   class Model
+     # move methods defined in `model.cpp` from public to private
+     private :initialize_cpp, :predict_cpp
+
+     # rubocop:disable Metrics/MethodLength
+     def self.new(
+       model,               # path to model file, e.g. "models/7B/ggml-model-q4_0.bin"
+       n_ctx: 512,          # context size
+       n_parts: -1,         # amount of model parts (-1 = determine from model dimensions)
+       seed: Time.now.to_i, # RNG seed
+       memory_f16: true,    # use f16 instead of f32 for memory kv
+       use_mlock: false     # use mlock to keep model in memory
+     )
+       instance = allocate
+
+       instance.instance_eval do
+         initialize
+
+         @model = model
+         @n_ctx = n_ctx
+         @n_parts = n_parts
+         @seed = seed
+         @memory_f16 = memory_f16
+         @use_mlock = use_mlock
+
+         capture_stderr do
+           initialize_cpp(
+             model,
+             n_ctx,
+             n_parts,
+             seed,
+             memory_f16,
+             use_mlock,
+           )
+         end
+       end
+
+       instance
+     end
+     # rubocop:enable Metrics/MethodLength
+
+     def predict(
+       prompt,        # string used as prompt
+       n_predict: 128 # number of tokens to predict
+     )
+       text = ''
+
+       capture_stderr { text = predict_cpp(prompt, n_predict) }
+
+       process_text(text)
+     end
+
+     attr_reader :model, :n_ctx, :n_parts, :seed, :memory_f16, :use_mlock, :stderr
+
+     private
+
+     def capture_stderr
+       previous = $stderr.dup
+       tmp = Tempfile.open('llama-rb-stderr')
+
+       begin
+         $stderr.reopen(tmp)
+
+         yield
+
+         tmp.rewind
+         @stderr = tmp.read
+       ensure
+         tmp.close(true)
+         $stderr.reopen(previous)
+       end
+     end
+
+     def process_text(text)
+       text = text.force_encoding(Encoding.default_external)
+
+       # remove the space that was added as a tokenizer hack in model.cpp
+       text[0] = '' if text.size.positive?
+
+       text
+     end
+   end
+ end
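
Taken together, the gem's public surface is Llama::Model.new plus Llama::Model#predict, with anything llama.cpp wrote to stderr captured on the #stderr reader. A short usage sketch under the defaults shown above; the model path is a placeholder and the gem (or its native extension) must be installed and built first:

require 'llama'

# keyword defaults come from Llama::Model.new above:
# n_ctx: 512, n_parts: -1, seed: Time.now.to_i, memory_f16: true, use_mlock: false
model = Llama::Model.new(
  'models/7B/ggml-model-q4_0.bin', # placeholder path to a ggml model file
  seed: 42,
)

text = model.predict('Here is a haiku about Ruby:', n_predict: 64)
puts text

# stderr output captured during the last call (load and timing chatter from llama.cpp)
puts model.stderr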
data/lib/llama/version.rb ADDED
@@ -0,0 +1,3 @@
+ module Llama
+   VERSION = '0.1.0'.freeze
+ end
data/lib/llama.rb ADDED
@@ -0,0 +1,6 @@
+ require_relative 'llama/version'
+ require_relative '../ext/llama/model'
+ require_relative 'llama/model'
+
+ module Llama
+ end
data/llama-rb.gemspec ADDED
@@ -0,0 +1,50 @@
+ require_relative 'lib/llama'
+
+ Gem::Specification.new do |spec|
+   spec.name = 'llama-rb'
+   spec.version = Llama::VERSION
+   spec.licenses = ['MIT']
+   spec.authors = ['zfletch']
+   spec.email = ['zfletch2@gmail.com']
+
+   spec.summary = 'Ruby interface for Llama'
+   spec.description = 'ggerganov/llama.cpp with Ruby hooks'
+   spec.homepage = 'https://github.com/zfletch/llama-rb'
+   spec.required_ruby_version = '>= 3.0.0'
+
+   spec.metadata['homepage_uri'] = spec.homepage
+   spec.metadata['source_code_uri'] = spec.homepage
+   spec.metadata['changelog_uri'] = "#{spec.homepage}/releases"
+
+   # Specify which files should be added to the gem when it is released.
+   # The `git ls-files -z` loads the files in the RubyGem that have been added into git.
+   spec.files = [
+     "Gemfile",
+     "Gemfile.lock",
+     "LICENSE",
+     "README.md",
+     "Rakefile",
+     "ext/llama/common.cpp",
+     "ext/llama/common.h",
+     "ext/llama/extconf.rb",
+     "ext/llama/ggml.c",
+     "ext/llama/ggml.h",
+     "ext/llama/llama.cpp",
+     "ext/llama/llama.h",
+     "ext/llama/model.cpp",
+     "lib/llama.rb",
+     "lib/llama/model.rb",
+     "lib/llama/version.rb",
+     "llama-rb.gemspec",
+     "llama.cpp",
+     "models/.gitkeep",
+   ]
+   spec.bindir = 'exe'
+   spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
+   spec.require_paths = ['lib']
+
+   spec.add_dependency 'rice', '~> 4.0.4'
+
+   spec.extensions = %w[ext/llama/extconf.rb]
+   spec.metadata['rubygems_mfa_required'] = 'true'
+ end
data/models/.gitkeep ADDED
File without changes
metadata ADDED
@@ -0,0 +1,80 @@
+ --- !ruby/object:Gem::Specification
+ name: llama-rb
+ version: !ruby/object:Gem::Version
+   version: 0.1.0
+ platform: ruby
+ authors:
+ - zfletch
+ autorequire:
+ bindir: exe
+ cert_chain: []
+ date: 2023-04-02 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: rice
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 4.0.4
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 4.0.4
+ description: ggerganov/llama.cpp with Ruby hooks
+ email:
+ - zfletch2@gmail.com
+ executables: []
+ extensions:
+ - ext/llama/extconf.rb
+ extra_rdoc_files: []
+ files:
+ - Gemfile
+ - Gemfile.lock
+ - LICENSE
+ - README.md
+ - Rakefile
+ - ext/llama/common.cpp
+ - ext/llama/common.h
+ - ext/llama/extconf.rb
+ - ext/llama/ggml.c
+ - ext/llama/ggml.h
+ - ext/llama/llama.cpp
+ - ext/llama/llama.h
+ - ext/llama/model.cpp
+ - lib/llama.rb
+ - lib/llama/model.rb
+ - lib/llama/version.rb
+ - llama-rb.gemspec
+ - models/.gitkeep
+ homepage: https://github.com/zfletch/llama-rb
+ licenses:
+ - MIT
+ metadata:
+   homepage_uri: https://github.com/zfletch/llama-rb
+   source_code_uri: https://github.com/zfletch/llama-rb
+   changelog_uri: https://github.com/zfletch/llama-rb/releases
+   rubygems_mfa_required: 'true'
+ post_install_message:
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: 3.0.0
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 3.3.7
+ signing_key:
+ specification_version: 4
+ summary: Ruby interface for Llama
+ test_files: []