red-candle 0.0.5 → 1.0.0.pre.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +2627 -603
- data/Cargo.toml +4 -0
- data/README.md +224 -6
- data/ext/candle/Cargo.toml +20 -9
- data/ext/candle/extconf.rb +77 -1
- data/ext/candle/src/lib.rs +121 -82
- data/lib/candle/build_info.rb +66 -0
- data/lib/candle/device_utils.rb +20 -0
- data/lib/candle/embedding_model.rb +32 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +107 -0
- data/lib/candle/reranker.rb +24 -0
- data/lib/candle/tensor.rb +68 -3
- data/lib/candle/version.rb +2 -2
- data/lib/candle.rb +6 -0
- metadata +9 -3
data/Cargo.toml
CHANGED
data/README.md
CHANGED
@@ -3,7 +3,7 @@
|
|
3
3
|
[](https://github.com/assaydepot/red-candle/actions/workflows/build.yml)
|
4
4
|
[](https://badge.fury.io/rb/red-candle)
|
5
5
|
|
6
|
-
|
6
|
+
[candle](https://github.com/huggingface/candle) - Minimalist ML framework - for Ruby
|
7
7
|
|
8
8
|
## Usage
|
9
9
|
|
@@ -20,19 +20,119 @@ x = x.reshape([3, 2])
|
|
20
20
|
|
21
21
|
```ruby
|
22
22
|
require 'candle'
|
23
|
-
|
23
|
+
|
24
|
+
# Default model (JinaBERT) on CPU
|
25
|
+
model = Candle::EmbeddingModel.new
|
26
|
+
embedding = model.embedding("Hi there!")
|
27
|
+
|
28
|
+
# Specify device (CPU, Metal, or CUDA)
|
29
|
+
device = Candle::Device.cpu # or Candle::Device.metal, Candle::Device.cuda
|
30
|
+
model = Candle::EmbeddingModel.new(
|
31
|
+
model_path: "jinaai/jina-embeddings-v2-base-en",
|
32
|
+
device: device
|
33
|
+
)
|
24
34
|
embedding = model.embedding("Hi there!")
|
35
|
+
|
36
|
+
# Reranker also supports device selection
|
37
|
+
reranker = Candle::Reranker.new(
|
38
|
+
model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2",
|
39
|
+
device: device
|
40
|
+
)
|
41
|
+
results = reranker.rerank("query", ["doc1", "doc2", "doc3"])
|
42
|
+
```
|
43
|
+
|
44
|
+
## LLM Support
|
45
|
+
|
46
|
+
Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
|
47
|
+
|
48
|
+
> ### ⚠️ Huggingface login warning
|
49
|
+
>
|
50
|
+
> Many models, including the one below, require you to accept their terms of use before you can download them. You'll need to:
|
51
|
+
> 1. Login to [Huggingface](https://huggingface.co)
|
52
|
+
> 2. Agree to the terms. For example: [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
53
|
+
> 3. Authenticate your session. The simplest way is with `huggingface-cli login`. Details here: [Huggingface CLI](https://huggingface.co/docs/huggingface_hub/en/guides/cli)
|
54
|
+
>
|
55
|
+
> More details here: [Huggingface Authentication](HUGGINGFACE.md)
|
56
|
+
|
57
|
+
```ruby
|
58
|
+
require 'candle'
|
59
|
+
|
60
|
+
# Choose your device
|
61
|
+
device = Candle::Device.cpu # CPU (default)
|
62
|
+
# device = Candle::Device.metal  # Apple GPU (Metal)
|
63
|
+
# device = Candle::Device.cuda   # NVIDIA GPU (CUDA)
|
64
|
+
|
65
|
+
# Load a model
|
66
|
+
llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
|
67
|
+
|
68
|
+
# Generate text
|
69
|
+
response = llm.generate("What is Ruby?", config: Candle::GenerationConfig.balanced)
|
70
|
+
|
71
|
+
# Stream generation
|
72
|
+
llm.generate_stream("Tell me a story", config: Candle::GenerationConfig.balanced) do |token|
|
73
|
+
print token
|
74
|
+
end
|
75
|
+
|
76
|
+
# Chat interface
|
77
|
+
messages = [
|
78
|
+
{ role: "system", content: "You are a helpful assistant." },
|
79
|
+
{ role: "user", content: "Explain Ruby in one sentence." }
|
80
|
+
]
|
81
|
+
response = llm.chat(messages)
|
82
|
+
```
|
83
|
+
|
84
|
+
### GPU Acceleration
|
85
|
+
|
86
|
+
```ruby
|
87
|
+
# CPU works for all models
|
88
|
+
device = Candle::Device.cpu
|
89
|
+
llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
|
90
|
+
|
91
|
+
# Metal
|
92
|
+
device = Candle::Device.metal
|
93
|
+
|
94
|
+
# CUDA (NVIDIA GPUs) — support coming soon
|
95
|
+
device = Candle::Device.cuda # Linux/Windows with NVIDIA GPU
|
96
|
+
```
|
97
|
+
|
98
|
+
## ⚠️ Model Format Requirement: Safetensors Only
|
99
|
+
|
100
|
+
Red-Candle **only supports embedding models that provide their weights in the [safetensors](https://github.com/huggingface/safetensors) format** (i.e., the model repo must contain a `model.safetensors` file). If the model repo does not provide the required file, loading will fail with a clear error. Most official BERT and DistilBERT models do **not** provide safetensors; many Sentence Transformers and JinaBERT models do.
|
101
|
+
|
102
|
+
**If you encounter an error like:**
|
103
|
+
|
104
|
+
```
|
105
|
+
RuntimeError: model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors.
|
106
|
+
```
|
107
|
+
|
108
|
+
This means the selected model is not compatible. Please choose a model repo that provides a `model.safetensors` file.
|
109
|
+
|
110
|
+
## Supported Embedding Models
|
111
|
+
|
112
|
+
Red-Candle supports the following embedding model types from Hugging Face:
|
113
|
+
|
114
|
+
1. `Candle::EmbeddingModelType::JINA_BERT` - Jina BERT models (e.g., `jinaai/jina-embeddings-v2-base-en`) (**safetensors required**)
|
115
|
+
2. `Candle::EmbeddingModelType::STANDARD_BERT` - Standard BERT models (e.g., `sentence-transformers/all-MiniLM-L6-v2`) (**safetensors required**)
|
116
|
+
3. `Candle::EmbeddingModelType::DISTILBERT` - DistilBERT models (e.g., `distilbert-base-uncased-finetuned-sst-2-english`) (**safetensors required**)
|
117
|
+
|
118
|
+
> **Note:** Most official BERT and DistilBERT models do _not_ provide safetensors. Please check the model repo before use.
|
119
|
+
|
120
|
+
You can get a list of all supported model types and suggested model paths:
|
121
|
+
|
122
|
+
```ruby
|
123
|
+
Candle::EmbeddingModelType.all # Returns all supported model types
|
124
|
+
Candle::EmbeddingModelType.suggested_model_paths # Returns hash of suggested models for each type
|
25
125
|
```
|
26
126
|
|
27
127
|
## A note on memory usage
|
28
|
-
The
|
128
|
+
The default model (`jinaai/jina-embeddings-v2-base-en` with the `sentence-transformers/all-MiniLM-L6-v2` tokenizer, both from [HuggingFace](https://huggingface.co)) takes a little more than 3GB of memory running on a Mac. The memory is held by each instantiated `Candle::EmbeddingModel` object, so if you instantiate more than one, you'll use more memory. Likewise, if you let an instance go out of scope and call the garbage collector, you'll free its memory. For example:
|
29
129
|
|
30
130
|
```ruby
|
31
131
|
> require 'candle'
|
32
132
|
# Ruby memory = 25.9 MB
|
33
|
-
> model = Candle::
|
133
|
+
> model = Candle::EmbeddingModel.new
|
34
134
|
# Ruby memory = 3.50 GB
|
35
|
-
> model2 = Candle::
|
135
|
+
> model2 = Candle::EmbeddingModel.new
|
36
136
|
# Ruby memory = 7.04 GB
|
37
137
|
> model2 = nil
|
38
138
|
> GC.start
|
@@ -58,10 +158,128 @@ And the following ruby:
|
|
58
158
|
|
59
159
|
```ruby
|
60
160
|
require 'candle'
|
61
|
-
model = Candle::
|
161
|
+
model = Candle::EmbeddingModel.new
|
62
162
|
embedding = model.embedding("Hi there!")
|
63
163
|
```
|
64
164
|
|
165
|
+
## Document Reranking
|
166
|
+
|
167
|
+
Red-Candle includes support for cross-encoder reranking models, which can be used to reorder documents by relevance to a query. This is particularly useful for improving search results or implementing retrieval-augmented generation (RAG) systems.
|
168
|
+
|
169
|
+
### Basic Usage
|
170
|
+
|
171
|
+
```ruby
|
172
|
+
require 'candle'
|
173
|
+
|
174
|
+
# Initialize the reranker with a cross-encoder model
|
175
|
+
reranker = Candle::Reranker.new(model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2")
|
176
|
+
|
177
|
+
# Define your query and candidate documents
|
178
|
+
query = "How many people live in London?"
|
179
|
+
documents = [
|
180
|
+
"London is known for its financial district",
|
181
|
+
"Around 9 Million people live in London",
|
182
|
+
"The weather in London is often rainy",
|
183
|
+
"London is the capital of England"
|
184
|
+
]
|
185
|
+
|
186
|
+
# Rerank documents by relevance to the query (raw logits)
|
187
|
+
ranked_results = reranker.rerank(query, documents, pooling_method: "pooler", apply_sigmoid: false)
|
188
|
+
|
189
|
+
# Or apply sigmoid activation to get scores between 0 and 1
|
190
|
+
sigmoid_results = reranker.rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
|
191
|
+
|
192
|
+
# "pooler" pooling and apply_sigmoid: true are both the defaults (and are recommended for cross-encoders), so the sigmoid call above is equivalent to:
|
193
|
+
ranked_results = reranker.rerank(query, documents)
|
194
|
+
|
195
|
+
# Results are returned as an array of hashes, sorted by relevance
|
196
|
+
# e.g.
|
197
|
+
ranked_results.each do |result|
|
198
|
+
puts "Score: #{result[:score].round(4)} - Doc ##{result[:doc_id]}: #{result[:text]}"
|
199
|
+
end
|
200
|
+
# Output:
|
201
|
+
# Score: 1.0 - Doc #1: Around 9 Million people live in London
|
202
|
+
# Score: 0.0438 - Doc #3: London is the capital of England
|
203
|
+
# Score: 0.0085 - Doc #0: London is known for its financial district
|
204
|
+
# Score: 0.0005 - Doc #2: The weather in London is often rainy
|
205
|
+
```
|
206
|
+
|
207
|
+
### Arguments & Activation Functions
|
208
|
+
|
209
|
+
By default, `apply_sigmoid` is `true` (scores between 0 and 1). Set it to `false` to get raw logits. You can also select the pooling method:
|
210
|
+
|
211
|
+
- `pooling_method: "pooler"` (default)
|
212
|
+
- `pooling_method: "cls"`
|
213
|
+
- `pooling_method: "mean"`
|
214
|
+
|
215
|
+
Example without sigmoid activation:
|
216
|
+
|
217
|
+
```ruby
|
218
|
+
# Get raw logits
|
219
|
+
ranked_results = reranker.rerank(query, documents, apply_sigmoid: false)
|
220
|
+
|
221
|
+
ranked_results.each do |result|
|
222
|
+
puts "Score: #{result[:score].round(4)} - Doc ##{result[:doc_id]}: #{result[:text]}"
|
223
|
+
end
|
224
|
+
# Output:
|
225
|
+
# Score: 10.3918 - Doc #1: Around 9 Million people live in London
|
226
|
+
# Score: -3.0829 - Doc #3: London is the capital of England
|
227
|
+
# Score: -4.7619 - Doc #0: London is known for its financial district
|
228
|
+
# Score: -7.5251 - Doc #2: The weather in London is often rainy
|
229
|
+
```
|
230
|
+
|
231
|
+
### Output Format
|
232
|
+
|
233
|
+
The reranker returns an array of hashes, each with the following keys:
|
234
|
+
- `:text` – The original document text
|
235
|
+
- `:score` – The relevance score (raw logit or sigmoid-activated)
|
236
|
+
- `:doc_id` – The original 0-based index of the document in the input array
|
237
|
+
|
238
|
+
This format is compatible with the Informers gem, which returns results as hashes with `:doc_id` and `:score` keys. The `doc_id` allows you to map results back to your original data structure.
|
239
|
+
|
240
|
+
### Pooling Methods
|
241
|
+
|
242
|
+
The reranker supports different pooling strategies for aggregating BERT embeddings:
|
243
|
+
|
244
|
+
```ruby
|
245
|
+
# Use alternative pooling methods
|
246
|
+
# "pooler" (default) - Uses the pooler layer with tanh activation (most accurate for cross-encoders)
|
247
|
+
# "cls" - Uses raw [CLS] token embeddings without the pooler layer
|
248
|
+
# "mean" - Mean pooling across all tokens (not recommended for cross-encoders)
|
249
|
+
|
250
|
+
# With raw logits
|
251
|
+
results = reranker.rerank_with_pooling(query, documents, "cls")
|
252
|
+
|
253
|
+
# With sigmoid activation
|
254
|
+
results = reranker.rerank_sigmoid_with_pooling(query, documents, "cls")
|
255
|
+
```
|
256
|
+
|
257
|
+
Note: The default "pooler" method is recommended as it matches how cross-encoder models are trained. Other pooling methods may produce different ranking results.
|
258
|
+
|
259
|
+
### CUDA Support
|
260
|
+
|
261
|
+
For faster inference on NVIDIA GPUs:
|
262
|
+
|
263
|
+
```ruby
|
264
|
+
# Initialize with CUDA if available (falls back to CPU if not)
|
265
|
+
reranker = Candle::Reranker.new(model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2", cuda: true)
|
266
|
+
```
|
267
|
+
|
268
|
+
### How It Works
|
269
|
+
|
270
|
+
Cross-encoder reranking models differ from bi-encoder embedding models:
|
271
|
+
|
272
|
+
- **Bi-encoders** (like the embedding models above) encode queries and documents separately into dense vectors
|
273
|
+
- **Cross-encoders** process the query and document together, allowing for more nuanced relevance scoring
|
274
|
+
|
275
|
+
The reranker uses a BERT-based architecture that:
|
276
|
+
1. Concatenates the query and document with special tokens: `[CLS] query [SEP] document [SEP]`
|
277
|
+
2. Processes them jointly through BERT layers
|
278
|
+
3. Applies a pooler layer (dense + tanh) to the [CLS] token
|
279
|
+
4. Uses a classifier layer to produce a single relevance score
|
280
|
+
|
281
|
+
This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
|
282
|
+
|
65
283
|
## Development
|
66
284
|
|
67
285
|
FORK IT!
|
data/ext/candle/Cargo.toml
CHANGED
@@ -2,18 +2,29 @@
|
|
2
2
|
name = "candle"
|
3
3
|
version = "0.1.0"
|
4
4
|
edition = "2021"
|
5
|
+
build = "build.rs"
|
5
6
|
|
6
7
|
[lib]
|
7
8
|
crate-type = ["cdylib"]
|
8
9
|
|
9
10
|
[dependencies]
|
10
|
-
candle-core = "0.
|
11
|
-
candle-nn = "0.
|
12
|
-
candle-transformers = "0.
|
13
|
-
tokenizers = { version = "0.
|
14
|
-
hf-hub = "0.3
|
15
|
-
half = "2"
|
16
|
-
magnus = "0.
|
11
|
+
candle-core = { version = "0.9.1" }
|
12
|
+
candle-nn = { version = "0.9.1" }
|
13
|
+
candle-transformers = { version = "0.9.1" }
|
14
|
+
tokenizers = { version = "0.21.1", default-features = true, features = ["fancy-regex"] }
|
15
|
+
hf-hub = "0.4.3"
|
16
|
+
half = "2.6.0"
|
17
|
+
magnus = "0.7.1"
|
18
|
+
safetensors = "0.3"
|
19
|
+
serde_json = "1.0"
|
20
|
+
serde = { version = "1.0", features = ["derive"] }
|
21
|
+
tokio = { version = "1.45", features = ["rt", "macros"] }
|
22
|
+
rand = "0.8"
|
17
23
|
|
18
|
-
[
|
19
|
-
|
24
|
+
[features]
|
25
|
+
default = []
|
26
|
+
metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
|
27
|
+
cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
28
|
+
cudnn = ["candle-core/cudnn", "cuda"]
|
29
|
+
mkl = ["candle-core/mkl"]
|
30
|
+
accelerate = ["candle-core/accelerate"]
|
data/ext/candle/extconf.rb
CHANGED
@@ -1,4 +1,80 @@
|
|
1
1
|
require "mkmf"
|
2
2
|
require "rb_sys/mkmf"
|
3
3
|
|
4
|
-
|
4
|
+
# Detect available hardware acceleration
|
5
|
+
features = []
|
6
|
+
|
7
|
+
# Force CPU-only build if requested
|
8
|
+
if ENV['CANDLE_FORCE_CPU']
|
9
|
+
puts "CANDLE_FORCE_CPU is set, building CPU-only version"
|
10
|
+
else
|
11
|
+
# Check for CUDA
|
12
|
+
cuda_available = ENV['CUDA_ROOT'] || ENV['CUDA_PATH'] || ENV['CANDLE_CUDA_PATH'] ||
|
13
|
+
File.exist?('/usr/local/cuda') || File.exist?('/opt/cuda') ||
|
14
|
+
(RbConfig::CONFIG['host_os'] =~ /mswin|mingw|cygwin/ &&
|
15
|
+
(File.exist?('C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA') ||
|
16
|
+
File.exist?('C:\CUDA')))
|
17
|
+
|
18
|
+
cuda_enabled = ENV['CANDLE_ENABLE_CUDA']
|
19
|
+
|
20
|
+
if cuda_available && cuda_enabled
|
21
|
+
puts "CUDA detected and enabled via CANDLE_ENABLE_CUDA"
|
22
|
+
features << 'cuda'
|
23
|
+
|
24
|
+
# Check if CUDNN should be enabled
|
25
|
+
if ENV['CANDLE_CUDNN'] || ENV['CUDNN_ROOT']
|
26
|
+
puts "CUDNN support enabled"
|
27
|
+
features << 'cudnn'
|
28
|
+
end
|
29
|
+
elsif cuda_available && !cuda_enabled
|
30
|
+
puts "=" * 80
|
31
|
+
puts "CUDA detected but not enabled."
|
32
|
+
puts "To enable CUDA support (coming soon), set CANDLE_ENABLE_CUDA=1"
|
33
|
+
puts "=" * 80
|
34
|
+
end
|
35
|
+
|
36
|
+
# Check for Metal (macOS only)
|
37
|
+
if RbConfig::CONFIG['host_os'] =~ /darwin/
|
38
|
+
puts "macOS detected, enabling Metal support"
|
39
|
+
features << 'metal'
|
40
|
+
|
41
|
+
# Also enable Accelerate framework on macOS
|
42
|
+
puts "Enabling Accelerate framework support"
|
43
|
+
features << 'accelerate'
|
44
|
+
end
|
45
|
+
|
46
|
+
# Check for Intel MKL
|
47
|
+
mkl_available = ENV['MKLROOT'] || ENV['MKL_ROOT'] ||
|
48
|
+
File.exist?('/opt/intel/mkl') ||
|
49
|
+
File.exist?('/opt/intel/oneapi/mkl/latest')
|
50
|
+
|
51
|
+
if mkl_available && !features.include?('accelerate') # Don't use both MKL and Accelerate
|
52
|
+
puts "Intel MKL detected, enabling MKL support"
|
53
|
+
features << 'mkl'
|
54
|
+
end
|
55
|
+
end
|
56
|
+
|
57
|
+
# Allow manual override of features
|
58
|
+
if ENV['CANDLE_FEATURES']
|
59
|
+
manual_features = ENV['CANDLE_FEATURES'].split(',').map(&:strip)
|
60
|
+
puts "Manual features override: #{manual_features.join(', ')}"
|
61
|
+
features = manual_features
|
62
|
+
end
|
63
|
+
|
64
|
+
# Display selected features
|
65
|
+
unless features.empty?
|
66
|
+
puts "Building with features: #{features.join(', ')}"
|
67
|
+
else
|
68
|
+
puts "Building CPU-only version (no acceleration features detected)"
|
69
|
+
end
|
70
|
+
|
71
|
+
# Create the Rust makefile with proper feature configuration
|
72
|
+
create_rust_makefile("candle/candle") do |r|
|
73
|
+
# Pass the features to rb_sys
|
74
|
+
r.features = features unless features.empty?
|
75
|
+
|
76
|
+
# Pass through any additional cargo flags
|
77
|
+
if ENV['CANDLE_CARGO_FLAGS']
|
78
|
+
r.extra_cargo_args = ENV['CANDLE_CARGO_FLAGS'].split(' ')
|
79
|
+
end
|
80
|
+
end
|
data/ext/candle/src/lib.rs
CHANGED
@@ -1,103 +1,142 @@
|
|
1
1
|
use magnus::{function, method, prelude::*, Ruby};
|
2
2
|
|
3
|
-
use crate::
|
3
|
+
use crate::ruby::candle_utils;
|
4
|
+
use crate::ruby::{DType, Device, QTensor, Result as RbResult, Tensor};
|
4
5
|
|
5
|
-
pub mod
|
6
|
+
pub mod llm;
|
7
|
+
pub mod reranker;
|
8
|
+
pub mod ruby;
|
9
|
+
|
10
|
+
// Configuration detection from build.rs
|
11
|
+
#[cfg(all(has_metal, not(force_cpu)))]
|
12
|
+
const DEFAULT_DEVICE: &str = "metal";
|
13
|
+
|
14
|
+
#[cfg(all(has_cuda, not(has_metal), not(force_cpu)))]
|
15
|
+
const DEFAULT_DEVICE: &str = "cuda";
|
16
|
+
|
17
|
+
#[cfg(any(force_cpu, not(any(has_metal, has_cuda))))]
|
18
|
+
const DEFAULT_DEVICE: &str = "cpu";
|
19
|
+
|
20
|
+
// Export build configuration for runtime checks
|
21
|
+
pub fn get_build_info() -> magnus::RHash {
|
22
|
+
let ruby = magnus::Ruby::get().unwrap();
|
23
|
+
let hash = ruby.hash_new();
|
24
|
+
|
25
|
+
let _ = hash.aset("default_device", DEFAULT_DEVICE);
|
26
|
+
let _ = hash.aset("cuda_available", cfg!(feature = "cuda"));
|
27
|
+
let _ = hash.aset("metal_available", cfg!(feature = "metal"));
|
28
|
+
let _ = hash.aset("mkl_available", cfg!(feature = "mkl"));
|
29
|
+
let _ = hash.aset("accelerate_available", cfg!(feature = "accelerate"));
|
30
|
+
let _ = hash.aset("cudnn_available", cfg!(feature = "cudnn"));
|
31
|
+
|
32
|
+
hash
|
33
|
+
}
|
6
34
|
|
7
35
|
#[magnus::init]
|
8
36
|
fn init(ruby: &Ruby) -> RbResult<()> {
|
9
37
|
let rb_candle = ruby.define_module("Candle")?;
|
38
|
+
|
39
|
+
// Export build info
|
40
|
+
rb_candle.define_singleton_method("build_info", function!(get_build_info, 0))?;
|
41
|
+
|
42
|
+
ruby::init_embedding_model(rb_candle)?;
|
43
|
+
ruby::init_llm(rb_candle)?;
|
44
|
+
reranker::init(rb_candle)?;
|
10
45
|
candle_utils(rb_candle)?;
|
11
46
|
let rb_tensor = rb_candle.define_class("Tensor", Ruby::class_object(ruby))?;
|
12
|
-
rb_tensor.define_singleton_method("new", function!(
|
13
|
-
// rb_tensor.define_singleton_method("cat", function!(
|
14
|
-
// rb_tensor.define_singleton_method("stack", function!(
|
15
|
-
rb_tensor.define_singleton_method("rand", function!(
|
16
|
-
rb_tensor.define_singleton_method("randn", function!(
|
17
|
-
rb_tensor.define_singleton_method("ones", function!(
|
18
|
-
rb_tensor.define_singleton_method("zeros", function!(
|
19
|
-
rb_tensor.define_method("values", method!(
|
20
|
-
rb_tensor.define_method("
|
21
|
-
rb_tensor.define_method("
|
22
|
-
rb_tensor.define_method("
|
23
|
-
rb_tensor.define_method("
|
24
|
-
rb_tensor.define_method("
|
25
|
-
rb_tensor.define_method("
|
26
|
-
rb_tensor.define_method("
|
27
|
-
rb_tensor.define_method("
|
28
|
-
rb_tensor.define_method("
|
29
|
-
rb_tensor.define_method("
|
30
|
-
rb_tensor.define_method("
|
31
|
-
rb_tensor.define_method("
|
32
|
-
rb_tensor.define_method("
|
33
|
-
rb_tensor.define_method("
|
34
|
-
rb_tensor.define_method("
|
35
|
-
rb_tensor.define_method("
|
36
|
-
rb_tensor.define_method("
|
37
|
-
rb_tensor.define_method("
|
38
|
-
rb_tensor.define_method("
|
39
|
-
rb_tensor.define_method("
|
40
|
-
rb_tensor.define_method("
|
41
|
-
rb_tensor.define_method("
|
42
|
-
rb_tensor.define_method("
|
43
|
-
rb_tensor.define_method("
|
44
|
-
rb_tensor.define_method("
|
45
|
-
rb_tensor.define_method("
|
46
|
-
rb_tensor.define_method("
|
47
|
-
rb_tensor.define_method("
|
48
|
-
rb_tensor.define_method("
|
49
|
-
rb_tensor.define_method("
|
50
|
-
rb_tensor.define_method("
|
51
|
-
rb_tensor.define_method("
|
52
|
-
rb_tensor.define_method("
|
53
|
-
rb_tensor.define_method("
|
54
|
-
rb_tensor.define_method("
|
55
|
-
rb_tensor.define_method("
|
56
|
-
rb_tensor.define_method("
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
// rb_tensor.define_method("
|
63
|
-
rb_tensor.define_method("
|
64
|
-
rb_tensor.define_method("
|
65
|
-
rb_tensor.define_method("
|
66
|
-
rb_tensor.define_method("
|
67
|
-
rb_tensor.define_method("
|
68
|
-
rb_tensor.define_method("
|
69
|
-
rb_tensor.define_method("
|
70
|
-
rb_tensor.define_method("
|
47
|
+
rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
|
48
|
+
// rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
|
49
|
+
// rb_tensor.define_singleton_method("stack", function!(Tensor::stack, 2))?;
|
50
|
+
rb_tensor.define_singleton_method("rand", function!(Tensor::rand, 2))?;
|
51
|
+
rb_tensor.define_singleton_method("randn", function!(Tensor::randn, 2))?;
|
52
|
+
rb_tensor.define_singleton_method("ones", function!(Tensor::ones, 2))?;
|
53
|
+
rb_tensor.define_singleton_method("zeros", function!(Tensor::zeros, 2))?;
|
54
|
+
rb_tensor.define_method("values", method!(Tensor::values, 0))?;
|
55
|
+
rb_tensor.define_method("values_f32", method!(Tensor::values_f32, 0))?;
|
56
|
+
rb_tensor.define_method("item", method!(Tensor::item, 0))?;
|
57
|
+
rb_tensor.define_method("shape", method!(Tensor::shape, 0))?;
|
58
|
+
rb_tensor.define_method("stride", method!(Tensor::stride, 0))?;
|
59
|
+
rb_tensor.define_method("dtype", method!(Tensor::dtype, 0))?;
|
60
|
+
rb_tensor.define_method("device", method!(Tensor::device, 0))?;
|
61
|
+
rb_tensor.define_method("rank", method!(Tensor::rank, 0))?;
|
62
|
+
rb_tensor.define_method("elem_count", method!(Tensor::elem_count, 0))?;
|
63
|
+
rb_tensor.define_method("sin", method!(Tensor::sin, 0))?;
|
64
|
+
rb_tensor.define_method("cos", method!(Tensor::cos, 0))?;
|
65
|
+
rb_tensor.define_method("log", method!(Tensor::log, 0))?;
|
66
|
+
rb_tensor.define_method("sqr", method!(Tensor::sqr, 0))?;
|
67
|
+
rb_tensor.define_method("mean", method!(Tensor::mean, 1))?;
|
68
|
+
rb_tensor.define_method("sum", method!(Tensor::sum, 1))?;
|
69
|
+
rb_tensor.define_method("sqrt", method!(Tensor::sqrt, 0))?;
|
70
|
+
rb_tensor.define_method("/", method!(Tensor::__truediv__, 1))?; // Accepts Tensor, Float, or Integer
|
71
|
+
rb_tensor.define_method("recip", method!(Tensor::recip, 0))?;
|
72
|
+
rb_tensor.define_method("exp", method!(Tensor::exp, 0))?;
|
73
|
+
rb_tensor.define_method("powf", method!(Tensor::powf, 1))?;
|
74
|
+
rb_tensor.define_method("index_select", method!(Tensor::index_select, 2))?;
|
75
|
+
rb_tensor.define_method("matmul", method!(Tensor::matmul, 1))?;
|
76
|
+
rb_tensor.define_method("broadcast_add", method!(Tensor::broadcast_add, 1))?;
|
77
|
+
rb_tensor.define_method("broadcast_sub", method!(Tensor::broadcast_sub, 1))?;
|
78
|
+
rb_tensor.define_method("broadcast_mul", method!(Tensor::broadcast_mul, 1))?;
|
79
|
+
rb_tensor.define_method("broadcast_div", method!(Tensor::broadcast_div, 1))?;
|
80
|
+
rb_tensor.define_method("where_cond", method!(Tensor::where_cond, 2))?;
|
81
|
+
rb_tensor.define_method("+", method!(Tensor::__add__, 1))?;
|
82
|
+
rb_tensor.define_method("*", method!(Tensor::__mul__, 1))?;
|
83
|
+
rb_tensor.define_method("-", method!(Tensor::__sub__, 1))?;
|
84
|
+
rb_tensor.define_method("reshape", method!(Tensor::reshape, 1))?;
|
85
|
+
rb_tensor.define_method("broadcast_as", method!(Tensor::broadcast_as, 1))?;
|
86
|
+
rb_tensor.define_method("broadcast_left", method!(Tensor::broadcast_left, 1))?;
|
87
|
+
rb_tensor.define_method("squeeze", method!(Tensor::squeeze, 1))?;
|
88
|
+
rb_tensor.define_method("unsqueeze", method!(Tensor::unsqueeze, 1))?;
|
89
|
+
rb_tensor.define_method("get", method!(Tensor::get, 1))?;
|
90
|
+
rb_tensor.define_method("[]", method!(Tensor::get, 1))?;
|
91
|
+
rb_tensor.define_method("transpose", method!(Tensor::transpose, 2))?;
|
92
|
+
rb_tensor.define_method("narrow", method!(Tensor::narrow, 3))?;
|
93
|
+
rb_tensor.define_method("argmax_keepdim", method!(Tensor::argmax_keepdim, 1))?;
|
94
|
+
rb_tensor.define_method("argmin_keepdim", method!(Tensor::argmin_keepdim, 1))?;
|
95
|
+
rb_tensor.define_method("max_keepdim", method!(Tensor::max_keepdim, 1))?;
|
96
|
+
rb_tensor.define_method("min_keepdim", method!(Tensor::min_keepdim, 1))?;
|
97
|
+
// rb_tensor.define_method("eq", method!(Tensor::eq, 1))?;
|
98
|
+
// rb_tensor.define_method("ne", method!(Tensor::ne, 1))?;
|
99
|
+
// rb_tensor.define_method("lt", method!(Tensor::lt, 1))?;
|
100
|
+
// rb_tensor.define_method("gt", method!(Tensor::gt, 1))?;
|
101
|
+
// rb_tensor.define_method("ge", method!(Tensor::ge, 1))?;
|
102
|
+
// rb_tensor.define_method("le", method!(Tensor::le, 1))?;
|
103
|
+
rb_tensor.define_method("sum_all", method!(Tensor::sum_all, 0))?;
|
104
|
+
rb_tensor.define_method("mean_all", method!(Tensor::mean_all, 0))?;
|
105
|
+
rb_tensor.define_method("flatten_from", method!(Tensor::flatten_from, 1))?;
|
106
|
+
rb_tensor.define_method("flatten_to", method!(Tensor::flatten_to, 1))?;
|
107
|
+
rb_tensor.define_method("flatten_all", method!(Tensor::flatten_all, 0))?;
|
108
|
+
rb_tensor.define_method("t", method!(Tensor::t, 0))?;
|
109
|
+
rb_tensor.define_method("contiguous", method!(Tensor::contiguous, 0))?;
|
110
|
+
rb_tensor.define_method("is_contiguous", method!(Tensor::is_contiguous, 0))?;
|
71
111
|
rb_tensor.define_method(
|
72
112
|
"is_fortran_contiguous",
|
73
|
-
method!(
|
113
|
+
method!(Tensor::is_fortran_contiguous, 0),
|
74
114
|
)?;
|
75
|
-
rb_tensor.define_method("detach", method!(
|
76
|
-
rb_tensor.define_method("copy", method!(
|
77
|
-
rb_tensor.define_method("to_dtype", method!(
|
78
|
-
rb_tensor.define_method("to_device", method!(
|
79
|
-
rb_tensor.define_method("to_s", method!(
|
80
|
-
rb_tensor.define_method("inspect", method!(
|
115
|
+
rb_tensor.define_method("detach", method!(Tensor::detach, 0))?;
|
116
|
+
rb_tensor.define_method("copy", method!(Tensor::copy, 0))?;
|
117
|
+
rb_tensor.define_method("to_dtype", method!(Tensor::to_dtype, 1))?;
|
118
|
+
rb_tensor.define_method("to_device", method!(Tensor::to_device, 1))?;
|
119
|
+
rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
|
120
|
+
rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
|
81
121
|
|
82
122
|
let rb_dtype = rb_candle.define_class("DType", Ruby::class_object(ruby))?;
|
83
|
-
rb_dtype.define_method("to_s", method!(
|
84
|
-
rb_dtype.define_method("inspect", method!(
|
123
|
+
rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
|
124
|
+
rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
|
85
125
|
|
86
126
|
let rb_device = rb_candle.define_class("Device", Ruby::class_object(ruby))?;
|
87
|
-
rb_device.
|
88
|
-
rb_device.
|
127
|
+
rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
|
128
|
+
rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
|
129
|
+
rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
|
130
|
+
rb_device.define_singleton_method("available_devices", function!(ruby::device::available_devices, 0))?;
|
131
|
+
rb_device.define_singleton_method("default", function!(ruby::device::default_device, 0))?;
|
132
|
+
rb_device.define_method("to_s", method!(Device::__str__, 0))?;
|
133
|
+
rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
|
89
134
|
|
90
135
|
let rb_qtensor = rb_candle.define_class("QTensor", Ruby::class_object(ruby))?;
|
91
|
-
rb_qtensor.define_method("ggml_dtype", method!(
|
92
|
-
rb_qtensor.define_method("rank", method!(
|
93
|
-
rb_qtensor.define_method("shape", method!(
|
94
|
-
rb_qtensor.define_method("dequantize", method!(
|
95
|
-
|
96
|
-
let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
|
97
|
-
rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
|
98
|
-
rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
|
99
|
-
rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
|
100
|
-
rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;
|
136
|
+
rb_qtensor.define_method("ggml_dtype", method!(QTensor::ggml_dtype, 0))?;
|
137
|
+
rb_qtensor.define_method("rank", method!(QTensor::rank, 0))?;
|
138
|
+
rb_qtensor.define_method("shape", method!(QTensor::shape, 0))?;
|
139
|
+
rb_qtensor.define_method("dequantize", method!(QTensor::dequantize, 0))?;
|
101
140
|
|
102
141
|
Ok(())
|
103
142
|
}
|