llama-rb 0.1.0

data/ext/llama/llama.h ADDED
@@ -0,0 +1,152 @@
+ #ifndef LLAMA_H
+ #define LLAMA_H
+
+ #include <stddef.h>
+ #include <stdint.h>
+ #include <stdbool.h>
+
+ #ifdef LLAMA_SHARED
+ #    if defined(_WIN32) && !defined(__MINGW32__)
+ #        ifdef LLAMA_BUILD
+ #            define LLAMA_API __declspec(dllexport)
+ #        else
+ #            define LLAMA_API __declspec(dllimport)
+ #        endif
+ #    else
+ #        define LLAMA_API __attribute__ ((visibility ("default")))
+ #    endif
+ #else
+ #    define LLAMA_API
+ #endif
+
+ #define LLAMA_FILE_VERSION 1
+ #define LLAMA_FILE_MAGIC 0x67676a74 // 'ggjt' in hex
+ #define LLAMA_FILE_MAGIC_UNVERSIONED 0x67676d6c // pre-versioned files
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ //
+ // C interface
+ //
+ // TODO: show sample usage
+ //
+
+ struct llama_context;
+
+ typedef int llama_token;
+
+ typedef struct llama_token_data {
+     llama_token id; // token id
+
+     float p;    // probability of the token
+     float plog; // log probability of the token
+
+ } llama_token_data;
+
+ typedef void (*llama_progress_callback)(float progress, void *ctx);
+
+ struct llama_context_params {
+     int n_ctx;   // text context
+     int n_parts; // -1 for default
+     int seed;    // RNG seed, 0 for random
+
+     bool f16_kv;     // use fp16 for KV cache
+     bool logits_all; // the llama_eval() call computes all logits, not just the last one
+     bool vocab_only; // only load the vocabulary, no weights
+     bool use_mlock;  // force system to keep model in RAM
+     bool embedding;  // embedding mode only
+
+     // called with a progress value between 0 and 1, pass NULL to disable
+     llama_progress_callback progress_callback;
+     // context pointer passed to the progress callback
+     void * progress_callback_user_data;
+ };
+
+ LLAMA_API struct llama_context_params llama_context_default_params();
+
+ // Various functions for loading a ggml llama model.
+ // Allocate (almost) all memory needed for the model.
+ // Return NULL on failure
+ LLAMA_API struct llama_context * llama_init_from_file(
+     const char * path_model,
+     struct llama_context_params params);
+
+ // Frees all allocated memory
+ LLAMA_API void llama_free(struct llama_context * ctx);
+
+ // TODO: not great API - very likely to change
+ // Returns 0 on success
+ LLAMA_API int llama_model_quantize(
+     const char * fname_inp,
+     const char * fname_out,
+     int itype);
+
+ // Run the llama inference to obtain the logits and probabilities for the next token.
+ // tokens + n_tokens is the provided batch of new tokens to process
+ // n_past is the number of tokens to use from previous eval calls
+ // Returns 0 on success
+ LLAMA_API int llama_eval(
+     struct llama_context * ctx,
+     const llama_token * tokens,
+     int n_tokens,
+     int n_past,
+     int n_threads);
+
+ // Convert the provided text into tokens.
+ // The tokens pointer must be large enough to hold the resulting tokens.
+ // Returns the number of tokens on success, no more than n_max_tokens
+ // Returns a negative number on failure - the number of tokens that would have been returned
+ // TODO: not sure if correct
+ LLAMA_API int llama_tokenize(
+     struct llama_context * ctx,
+     const char * text,
+     llama_token * tokens,
+     int n_max_tokens,
+     bool add_bos);
+
+ LLAMA_API int llama_n_vocab(struct llama_context * ctx);
+ LLAMA_API int llama_n_ctx  (struct llama_context * ctx);
+ LLAMA_API int llama_n_embd (struct llama_context * ctx);
+
+ // Token logits obtained from the last call to llama_eval()
+ // The logits for the last token are stored in the last row
+ // Can be mutated in order to change the probabilities of the next token
+ // Rows: n_tokens
+ // Cols: n_vocab
+ LLAMA_API float * llama_get_logits(struct llama_context * ctx);
+
+ // Get the embeddings for the input
+ // shape: [n_embd] (1-dimensional)
+ LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
+
+ // Token Id -> String. Uses the vocabulary in the provided context
+ LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
+
+ // Special tokens
+ LLAMA_API llama_token llama_token_bos();
+ LLAMA_API llama_token llama_token_eos();
+
+ // TODO: improve the last_n_tokens interface ?
+ LLAMA_API llama_token llama_sample_top_p_top_k(
+     struct llama_context * ctx,
+     const llama_token * last_n_tokens_data,
+     int last_n_tokens_size,
+     int top_k,
+     float top_p,
+     float temp,
+     float repeat_penalty);
+
+ // Performance information
+ LLAMA_API void llama_print_timings(struct llama_context * ctx);
+ LLAMA_API void llama_reset_timings(struct llama_context * ctx);
+
+ // Print system information
+ LLAMA_API const char * llama_print_system_info(void);
+
+ #ifdef __cplusplus
+ }
+ #endif
+
+ #endif
data/ext/llama/model.cpp ADDED
@@ -0,0 +1,192 @@
+ #include <rice/rice.hpp>
+
+ #include "common.h"
+ #include "llama.h"
+
+ #include <cassert>
+ #include <cinttypes>
+ #include <cmath>
+ #include <cstdio>
+ #include <cstring>
+ #include <fstream>
+ #include <iostream>
+ #include <string>
+ #include <vector>
+
+ class ModelCpp
+ {
+     public:
+         llama_context *ctx;
+         ModelCpp()
+         {
+             ctx = NULL;
+         }
+         void model_initialize(
+             const char *model,
+             const int32_t n_ctx,
+             const int32_t n_parts,
+             const int32_t seed,
+             const bool memory_f16,
+             const bool use_mlock
+         );
+         Rice::Object model_predict(
+             const char *prompt,
+             const int32_t n_predict
+         );
+         ~ModelCpp()
+         {
+
+             if (ctx != NULL) {
+                 llama_free(ctx);
+             }
+         }
+ };
+
+ void ModelCpp::model_initialize(
+     const char *model,     // path to model file, e.g. "models/7B/ggml-model-q4_0.bin"
+     const int32_t n_ctx,   // context size
+     const int32_t n_parts, // amount of model parts (-1 = determine from model dimensions)
+     const int32_t seed,    // RNG seed
+     const bool memory_f16, // use f16 instead of f32 for memory kv
+     const bool use_mlock   // use mlock to keep model in memory
+ )
+ {
+     auto lparams = llama_context_default_params();
+
+     lparams.n_ctx     = n_ctx;
+     lparams.n_parts   = n_parts;
+     lparams.seed      = seed;
+     lparams.f16_kv    = memory_f16;
+     lparams.use_mlock = use_mlock;
+
+     ctx = llama_init_from_file(model, lparams);
+ }
+
+ Rice::Object ModelCpp::model_predict(
+     const char *prompt,     // string used as prompt
+     const int32_t n_predict // number of tokens to predict
+ )
+ {
+     std::string return_val = "";
+
+     gpt_params params;
+     params.prompt = prompt;
+     params.n_predict = n_predict;
+
+     // add a space in front of the first character to match OG llama tokenizer behavior
+     params.prompt.insert(0, 1, ' ');
+
+     // tokenize the prompt
+     auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+     const int n_ctx = llama_n_ctx(ctx);
+
+     // determine newline token
+     auto llama_token_newline = ::llama_tokenize(ctx, "\n", false);
+
+     // generate output
+     {
+         std::vector<llama_token> last_n_tokens(n_ctx);
+         std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
+
+         int n_past = 0;
+         int n_remain = params.n_predict;
+         int n_consumed = 0;
+
+         std::vector<llama_token> embd;
+
+         while (n_remain != 0) {
+             if (embd.size() > 0) {
+                 // infinite text generation via context swapping
+                 // if we run out of context:
+                 // - take the n_keep first tokens from the original prompt (via n_past)
+                 // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch
+                 if (n_past + (int) embd.size() > n_ctx) {
+                     const int n_left = n_past - params.n_keep;
+
+                     n_past = params.n_keep;
+
+                     // insert n_left/2 tokens at the start of embd from last_n_tokens
+                     embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
+                 }
+
+                 if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
+                     throw Rice::Exception(rb_eRuntimeError, "Failed to eval");
+                 }
+             }
+
+
+             n_past += embd.size();
+             embd.clear();
+
+             if ((int) embd_inp.size() <= n_consumed) {
+                 // out of user input, sample next token
+                 const int32_t top_k = params.top_k;
+                 const float top_p = params.top_p;
+                 const float temp = params.temp;
+                 const float repeat_penalty = params.repeat_penalty;
+
+                 llama_token id = 0;
+
+                 {
+                     auto logits = llama_get_logits(ctx);
+
+                     if (params.ignore_eos) {
+                         logits[llama_token_eos()] = 0;
+                     }
+
+                     id = llama_sample_top_p_top_k(ctx,
+                         last_n_tokens.data() + n_ctx - params.repeat_last_n,
+                         params.repeat_last_n, top_k, top_p, temp, repeat_penalty);
+
+                     last_n_tokens.erase(last_n_tokens.begin());
+                     last_n_tokens.push_back(id);
+                 }
+
+                 // replace end of text token with newline token when in interactive mode
+                 if (id == llama_token_eos() && params.interactive && !params.instruct) {
+                     id = llama_token_newline.front();
+                     if (params.antiprompt.size() != 0) {
+                         // tokenize and inject first reverse prompt
+                         const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+                         embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
+                     }
+                 }
+
+                 // add it to the context
+                 embd.push_back(id);
+
+                 // decrement remaining sampling budget
+                 --n_remain;
+             } else {
+                 // some user input remains from prompt or interaction, forward it to processing
+                 while ((int) embd_inp.size() > n_consumed) {
+                     embd.push_back(embd_inp[n_consumed]);
+                     last_n_tokens.erase(last_n_tokens.begin());
+                     last_n_tokens.push_back(embd_inp[n_consumed]);
+                     ++n_consumed;
+                     if ((int) embd.size() >= params.n_batch) {
+                         break;
+                     }
+                 }
+             }
+
+             for (auto id : embd) {
+                 return_val += llama_token_to_str(ctx, id);
+             }
+         }
+     }
+
+     Rice::String ruby_return_val(return_val);
+     return ruby_return_val;
+ }
+
+ extern "C"
+ void Init_model()
+ {
+     Rice::Module rb_mLlama = Rice::define_module("Llama");
+     Rice::Data_Type<ModelCpp> rb_cModel = Rice::define_class_under<ModelCpp>(rb_mLlama, "Model");
+
+     rb_cModel.define_constructor(Rice::Constructor<ModelCpp>());
+     rb_cModel.define_method("initialize_cpp", &ModelCpp::model_initialize);
+     rb_cModel.define_method("predict_cpp", &ModelCpp::model_predict);
+ }
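
For orientation, the sketch below shows what the two methods registered in Init_model look like from the Ruby side. It is an illustration only, not part of the gem: it drives the raw binding the same way lib/llama/model.rb (next file) does, and the model path is a placeholder for a ggml weights file the user must supply.

# Sketch only: assumes the extension has been compiled and a ggml model file exists locally.
require 'llama'

raw = Llama::Model.allocate           # bypass the wrapper's custom self.new
raw.instance_eval do
  initialize                          # construct the underlying ModelCpp object (Rice constructor)
  initialize_cpp(
    'models/7B/ggml-model-q4_0.bin',  # placeholder model path
    512,                              # n_ctx
    -1,                               # n_parts
    0,                                # seed (0 = random)
    true,                             # memory_f16
    false                             # use_mlock
  )
  puts predict_cpp('Hello', 16)       # prompt, n_predict
end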
data/lib/llama/model.rb ADDED
@@ -0,0 +1,86 @@
+ require 'tempfile'
+
+ module Llama
+   class Model
+     # move methods defined in `model.cpp` from public to private
+     private :initialize_cpp, :predict_cpp
+
+     # rubocop:disable Metrics/MethodLength
+     def self.new(
+       model,               # path to model file, e.g. "models/7B/ggml-model-q4_0.bin"
+       n_ctx: 512,          # context size
+       n_parts: -1,         # amount of model parts (-1 = determine from model dimensions)
+       seed: Time.now.to_i, # RNG seed
+       memory_f16: true,    # use f16 instead of f32 for memory kv
+       use_mlock: false     # use mlock to keep model in memory
+     )
+       instance = allocate
+
+       instance.instance_eval do
+         initialize
+
+         @model = model
+         @n_ctx = n_ctx
+         @n_parts = n_parts
+         @seed = seed
+         @memory_f16 = memory_f16
+         @use_mlock = use_mlock
+
+         capture_stderr do
+           initialize_cpp(
+             model,
+             n_ctx,
+             n_parts,
+             seed,
+             memory_f16,
+             use_mlock,
+           )
+         end
+       end
+
+       instance
+     end
+     # rubocop:enable Metrics/MethodLength
+
+     def predict(
+       prompt,        # string used as prompt
+       n_predict: 128 # number of tokens to predict
+     )
+       text = ''
+
+       capture_stderr { text = predict_cpp(prompt, n_predict) }
+
+       process_text(text)
+     end
+
+     attr_reader :model, :n_ctx, :n_parts, :seed, :memory_f16, :use_mlock, :stderr
+
+     private
+
+     def capture_stderr
+       previous = $stderr.dup
+       tmp = Tempfile.open('llama-rb-stderr')
+
+       begin
+         $stderr.reopen(tmp)
+
+         yield
+
+         tmp.rewind
+         @stderr = tmp.read
+       ensure
+         tmp.close(true)
+         $stderr.reopen(previous)
+       end
+     end
+
+     def process_text(text)
+       text = text.force_encoding(Encoding.default_external)
+
+       # remove the space that was added as a tokenizer hack in model.cpp
+       text[0] = '' if text.size.positive?
+
+       text
+     end
+   end
+ end
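
In normal use the two public entry points are Llama::Model.new and Llama::Model#predict. A minimal usage sketch follows; the model path is a placeholder, and any ggml-format weights file supplied by the user works the same way.

require 'llama'

# Load a model; the keyword arguments are optional and default as documented above.
model = Llama::Model.new(
  'models/7B/ggml-model-q4_0.bin', # placeholder path to a ggml model file
  n_ctx: 512,
  seed: 42
)

puts model.predict('The capital of France is', n_predict: 16)
puts model.stderr # llama.cpp diagnostics captured via capture_stderr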
data/lib/llama/version.rb ADDED
@@ -0,0 +1,3 @@
+ module Llama
+   VERSION = '0.1.0'.freeze
+ end
data/lib/llama.rb ADDED
@@ -0,0 +1,6 @@
+ require_relative 'llama/version'
+ require_relative '../ext/llama/model'
+ require_relative 'llama/model'
+
+ module Llama
+ end
data/llama-rb.gemspec ADDED
@@ -0,0 +1,50 @@
+ require_relative 'lib/llama'
+
+ Gem::Specification.new do |spec|
+   spec.name = 'llama-rb'
+   spec.version = Llama::VERSION
+   spec.licenses = ['MIT']
+   spec.authors = ['zfletch']
+   spec.email = ['zfletch2@gmail.com']
+
+   spec.summary = 'Ruby interface for Llama'
+   spec.description = 'ggerganov/llama.cpp with Ruby hooks'
+   spec.homepage = 'https://github.com/zfletch/llama-rb'
+   spec.required_ruby_version = '>= 3.0.0'
+
+   spec.metadata['homepage_uri'] = spec.homepage
+   spec.metadata['source_code_uri'] = spec.homepage
+   spec.metadata['changelog_uri'] = "#{spec.homepage}/releases"
+
+   # Specify which files should be added to the gem when it is released.
+   # The `git ls-files -z` loads the files in the RubyGem that have been added into git.
+   spec.files = [
+     "Gemfile",
+     "Gemfile.lock",
+     "LICENSE",
+     "README.md",
+     "Rakefile",
+     "ext/llama/common.cpp",
+     "ext/llama/common.h",
+     "ext/llama/extconf.rb",
+     "ext/llama/ggml.c",
+     "ext/llama/ggml.h",
+     "ext/llama/llama.cpp",
+     "ext/llama/llama.h",
+     "ext/llama/model.cpp",
+     "lib/llama.rb",
+     "lib/llama/model.rb",
+     "lib/llama/version.rb",
+     "llama-rb.gemspec",
+     "llama.cpp",
+     "models/.gitkeep",
+   ]
+   spec.bindir = 'exe'
+   spec.executables = spec.files.grep(%r{\Aexe/}) { |f| File.basename(f) }
+   spec.require_paths = ['lib']
+
+   spec.add_dependency 'rice', '~> 4.0.4'
+
+   spec.extensions = %w[ext/llama/extconf.rb]
+   spec.metadata['rubygems_mfa_required'] = 'true'
+ end
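
For consumers installing the released gem through Bundler rather than building from this source tree, the corresponding Gemfile line would be the one below; installing it compiles the native extension declared in spec.extensions.

# Gemfile of a consuming project
gem 'llama-rb', '~> 0.1.0'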
data/models/.gitkeep ADDED
File without changes
metadata ADDED
@@ -0,0 +1,80 @@
+ --- !ruby/object:Gem::Specification
+ name: llama-rb
+ version: !ruby/object:Gem::Version
+   version: 0.1.0
+ platform: ruby
+ authors:
+ - zfletch
+ autorequire:
+ bindir: exe
+ cert_chain: []
+ date: 2023-04-02 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: rice
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 4.0.4
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 4.0.4
+ description: ggerganov/llama.cpp with Ruby hooks
+ email:
+ - zfletch2@gmail.com
+ executables: []
+ extensions:
+ - ext/llama/extconf.rb
+ extra_rdoc_files: []
+ files:
+ - Gemfile
+ - Gemfile.lock
+ - LICENSE
+ - README.md
+ - Rakefile
+ - ext/llama/common.cpp
+ - ext/llama/common.h
+ - ext/llama/extconf.rb
+ - ext/llama/ggml.c
+ - ext/llama/ggml.h
+ - ext/llama/llama.cpp
+ - ext/llama/llama.h
+ - ext/llama/model.cpp
+ - lib/llama.rb
+ - lib/llama/model.rb
+ - lib/llama/version.rb
+ - llama-rb.gemspec
+ - models/.gitkeep
+ homepage: https://github.com/zfletch/llama-rb
+ licenses:
+ - MIT
+ metadata:
+   homepage_uri: https://github.com/zfletch/llama-rb
+   source_code_uri: https://github.com/zfletch/llama-rb
+   changelog_uri: https://github.com/zfletch/llama-rb/releases
+   rubygems_mfa_required: 'true'
+ post_install_message:
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: 3.0.0
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 3.3.7
+ signing_key:
+ specification_version: 4
+ summary: Ruby interface for Llama
+ test_files: []