red-candle 0.0.5 → 1.0.0.pre.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/Cargo.toml CHANGED
@@ -1,2 +1,6 @@
  [workspace]
  members = ["ext/candle"]
+ resolver = "2"
+
+ [profile.test]
+ opt-level = 3
data/README.md CHANGED
@@ -3,7 +3,7 @@
  [![build](https://github.com/assaydepot/red-candle/actions/workflows/build.yml/badge.svg)](https://github.com/assaydepot/red-candle/actions/workflows/build.yml)
  [![Gem Version](https://badge.fury.io/rb/red-candle.svg)](https://badge.fury.io/rb/red-candle)
 
- 🕯️ [candle](https://github.com/huggingface/candle) - Minimalist ML framework - for Ruby
+ [candle](https://github.com/huggingface/candle) - Minimalist ML framework - for Ruby
 
  ## Usage
 
@@ -20,19 +20,119 @@ x = x.reshape([3, 2])
 
  ```ruby
  require 'candle'
- model = Candle::Model.new
+
+ # Default model (JinaBERT) on CPU
+ model = Candle::EmbeddingModel.new
+ embedding = model.embedding("Hi there!")
+
+ # Specify device (CPU, Metal, or CUDA)
+ device = Candle::Device.cpu # or Candle::Device.metal, Candle::Device.cuda
+ model = Candle::EmbeddingModel.new(
+   model_path: "jinaai/jina-embeddings-v2-base-en",
+   device: device
+ )
  embedding = model.embedding("Hi there!")
+
+ # Reranker also supports device selection
+ reranker = Candle::Reranker.new(
+   model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2",
+   device: device
+ )
+ results = reranker.rerank("query", ["doc1", "doc2", "doc3"])
+ ```
+
+ ## LLM Support
+
+ Red-Candle now supports Large Language Models (LLMs) with GPU acceleration!
+
+ > ### ⚠️ Huggingface login warning
+ >
+ > Many models, including the one below, require you to agree to their terms. You'll need to:
+ > 1. Log in to [Huggingface](https://huggingface.co)
+ > 2. Agree to the terms. For example: [here](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
+ > 3. Authenticate your session. The simplest way is with `huggingface-cli login`. Details here: [Huggingface CLI](https://huggingface.co/docs/huggingface_hub/en/guides/cli)
+ >
+ > More details here: [Huggingface Authentication](HUGGINGFACE.md)
+
+ ```ruby
+ require 'candle'
+
+ # Choose your device
+ device = Candle::Device.cpu # CPU (default)
+ device = Candle::Device.metal # Apple GPU (Metal)
+ device = Candle::Device.cuda # NVIDIA GPU (CUDA)
+
+ # Load a model
+ llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
+
+ # Generate text
+ response = llm.generate("What is Ruby?", config: Candle::GenerationConfig.balanced)
+
+ # Stream generation
+ llm.generate_stream("Tell me a story", config: Candle::GenerationConfig.balanced) do |token|
+   print token
+ end
+
+ # Chat interface
+ messages = [
+   { role: "system", content: "You are a helpful assistant." },
+   { role: "user", content: "Explain Ruby in one sentence." }
+ ]
+ response = llm.chat(messages)
+ ```
+
+ ### GPU Acceleration
+
+ ```ruby
+ # CPU works for all models
+ device = Candle::Device.cpu
+ llm = Candle::LLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device: device)
+
+ # Metal
+ device = Candle::Device.metal
+
+ # CUDA support (for NVIDIA GPUs COMING SOON)
+ device = Candle::Device.cuda # Linux/Windows with NVIDIA GPU
+ ```
+
+ ## ⚠️ Model Format Requirement: Safetensors Only
+
+ Red-Candle **only supports embedding models that provide their weights in the [safetensors](https://github.com/huggingface/safetensors) format** (i.e., the model repo must contain a `model.safetensors` file). If the model repo does not provide the required file, loading will fail with a clear error. Most official BERT and DistilBERT models do **not** provide safetensors; many Sentence Transformers and JinaBERT models do.
+
+ **If you encounter an error like:**
+
+ ```
+ RuntimeError: model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors.
+ ```
+
+ this means the selected model is not compatible. Please choose a model repo that provides the required file.
+
+ ## Supported Embedding Models
+
+ Red-Candle supports the following embedding model types from Hugging Face:
+
+ 1. `Candle::EmbeddingModelType::JINA_BERT` - Jina BERT models (e.g., `jinaai/jina-embeddings-v2-base-en`) (**safetensors required**)
+ 2. `Candle::EmbeddingModelType::STANDARD_BERT` - Standard BERT models (e.g., `sentence-transformers/all-MiniLM-L6-v2`) (**safetensors required**)
+ 3. `Candle::EmbeddingModelType::DISTILBERT` - DistilBERT models (e.g., `distilbert-base-uncased-finetuned-sst-2-english`) (**safetensors required**)
+
+ > **Note:** Most official BERT and DistilBERT models do _not_ provide safetensors. Please check the model repo before use.
+
+ You can get a list of all supported model types and suggested model paths:
+
+ ```ruby
+ Candle::EmbeddingModelType.all # Returns all supported model types
+ Candle::EmbeddingModelType.suggested_model_paths # Returns hash of suggested models for each type
  ```
 
  ## A note on memory usage
- The `Candle::Model` defaults to the `jinaai/jina-embeddings-v2-base-en` model with the `sentence-transformers/all-MiniLM-L6-v2` tokenizer (both from [HuggingFace](https://huggingface.co)). With this configuration the model takes a little more than 3GB of memory running on my Mac. The memory stays with the instantiated `Candle::Model` class, if you instantiate more than one, you'll use more memory. Likewise, if you let it go out of scope and call the garbage collector, you'll free the memory. For example:
+ The default model (`jinaai/jina-embeddings-v2-base-en` with the `sentence-transformers/all-MiniLM-L6-v2` tokenizer, both from [HuggingFace](https://huggingface.co)) takes a little more than 3GB of memory running on a Mac. The memory stays with the instantiated `Candle::EmbeddingModel` object; if you instantiate more than one, you'll use more memory. Likewise, if you let it go out of scope and call the garbage collector, you'll free the memory. For example:
 
  ```ruby
  > require 'candle'
  # Ruby memory = 25.9 MB
- > model = Candle::Model.new
+ > model = Candle::EmbeddingModel.new
  # Ruby memory = 3.50 GB
- > model2 = Candle::Model.new
+ > model2 = Candle::EmbeddingModel.new
  # Ruby memory = 7.04 GB
  > model2 = nil
  > GC.start
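The embedding examples in this hunk return one dense vector per input, which is the bi-encoder approach the README later contrasts with cross-encoders. Such vectors are commonly compared with cosine similarity. The following is a minimal plain-Ruby sketch. It assumes each embedding has already been flattened into an `Array` of floats. `cosine_similarity` and the toy vectors are hypothetical and not part of red-candle.

```ruby
# Hypothetical helper: cosine similarity of two equal-length float arrays.
def cosine_similarity(a, b)
  raise ArgumentError, "vectors must have the same length" unless a.length == b.length

  dot = a.zip(b).sum { |x, y| x * y }
  norm_a = Math.sqrt(a.sum { |x| x * x })
  norm_b = Math.sqrt(b.sum { |x| x * x })
  # Treat a zero vector as having similarity 0.0 instead of dividing by zero.
  return 0.0 if norm_a.zero? || norm_b.zero?

  dot / (norm_a * norm_b)
end

# Toy 3-dimensional vectors standing in for real embeddings.
query_vec = [0.1, 0.3, 0.5]
doc_vecs = {
  "doc_a" => [0.1, 0.29, 0.52],
  "doc_b" => [0.9, -0.2, 0.0]
}

# Rank the toy documents by similarity to the query, highest first.
doc_vecs.sort_by { |_, vec| -cosine_similarity(query_vec, vec) }.each do |name, vec|
  puts format("%s: %.4f", name, cosine_similarity(query_vec, vec))
end
```

Because each vector is computed independently, document embeddings can be precomputed and cached. This is the efficiency a cross-encoder gives up in exchange for scoring each query/document pair jointly.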
@@ -58,10 +158,128 @@ And the following ruby:
 
  ```ruby
  require 'candle'
- model = Candle::Model.new
+ model = Candle::EmbeddingModel.new
  embedding = model.embedding("Hi there!")
  ```
 
+ ## Document Reranking
+
+ Red-Candle includes support for cross-encoder reranking models, which can be used to reorder documents by relevance to a query. This is particularly useful for improving search results or implementing retrieval-augmented generation (RAG) systems.
+
+ ### Basic Usage
+
+ ```ruby
+ require 'candle'
+
+ # Initialize the reranker with a cross-encoder model
+ reranker = Candle::Reranker.new(model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2")
+
+ # Define your query and candidate documents
+ query = "How many people live in London?"
+ documents = [
+   "London is known for its financial district",
+   "Around 9 Million people live in London",
+   "The weather in London is often rainy",
+   "London is the capital of England"
+ ]
+
+ # Rerank documents by relevance to the query (raw logits)
+ ranked_results = reranker.rerank(query, documents, pooling_method: "pooler", apply_sigmoid: false)
+
+ # Or apply sigmoid activation to get scores between 0 and 1
+ sigmoid_results = reranker.rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
+
+ # "pooler" and apply_sigmoid: true are both defaults (and both recommended for cross-encoders), so this is equivalent to the sigmoid_results call above:
+ ranked_results = reranker.rerank(query, documents)
+
+ # Results are returned as an array of hashes, sorted by relevance
+ # For example:
+ ranked_results.each do |result|
+   puts "Score: #{result[:score].round(4)} - Doc ##{result[:doc_id]}: #{result[:text]}"
+ end
+ # Output:
+ # Score: 1.0 - Doc #1: Around 9 Million people live in London
+ # Score: 0.0438 - Doc #3: London is the capital of England
+ # Score: 0.0085 - Doc #0: London is known for its financial district
+ # Score: 0.0005 - Doc #2: The weather in London is often rainy
+ ```
+
+ ### Arguments & Activation Functions
+
+ By default, `apply_sigmoid` is `true` (scores between 0 and 1). Set it to `false` to get raw logits. You can also select the pooling method:
+
+ - `pooling_method: "pooler"` (default)
+ - `pooling_method: "cls"`
+ - `pooling_method: "mean"`
+
+ Example without sigmoid activation:
+
+ ```ruby
+ # Get raw logits
+ ranked_results = reranker.rerank(query, documents, apply_sigmoid: false)
+
+ ranked_results.each do |result|
+   puts "Score: #{result[:score].round(4)} - Doc ##{result[:doc_id]}: #{result[:text]}"
+ end
+ # Output:
+ # Score: 10.3918 - Doc #1: Around 9 Million people live in London
+ # Score: -3.0829 - Doc #3: London is the capital of England
+ # Score: -4.7619 - Doc #0: London is known for its financial district
+ # Score: -7.5251 - Doc #2: The weather in London is often rainy
+ ```
+
+ ### Output Format
+
+ The reranker returns an array of hashes, each with the following keys:
+ - `:text` – The original document text
+ - `:score` – The relevance score (raw logit or sigmoid-activated)
+ - `:doc_id` – The original 0-based index of the document in the input array
+
+ This format is compatible with the Informers gem, which returns results as hashes with `:doc_id` and `:score` keys. The `doc_id` allows you to map results back to your original data structure.
+
+ ### Pooling Methods
+
+ The reranker supports different pooling strategies for aggregating BERT embeddings:
+
+ ```ruby
+ # Use alternative pooling methods
+ # "pooler" (default) - Uses the pooler layer with tanh activation (most accurate for cross-encoders)
+ # "cls" - Uses raw [CLS] token embeddings without the pooler layer
+ # "mean" - Mean pooling across all tokens (not recommended for cross-encoders)
+
+ # With raw logits
+ results = reranker.rerank_with_pooling(query, documents, "cls")
+
+ # With sigmoid activation
+ results = reranker.rerank_sigmoid_with_pooling(query, documents, "cls")
+ ```
+
+ Note: The default "pooler" method is recommended as it matches how cross-encoder models are trained. Other pooling methods may produce different ranking results.
+
+ ### CUDA Support
+
+ For faster inference on NVIDIA GPUs:
+
+ ```ruby
+ # Initialize with CUDA if available (falls back to CPU if not)
+ reranker = Candle::Reranker.new(model_path: "cross-encoder/ms-marco-MiniLM-L-12-v2", cuda: true)
+ ```
+
+ ### How It Works
+
+ Cross-encoder reranking models differ from bi-encoder embedding models:
+
+ - **Bi-encoders** (like the embedding models above) encode queries and documents separately into dense vectors
+ - **Cross-encoders** process the query and document together, allowing for more nuanced relevance scoring
+
+ The reranker uses a BERT-based architecture that:
+ 1. Concatenates the query and document with special tokens: `[CLS] query [SEP] document [SEP]`
+ 2. Processes them jointly through BERT layers
+ 3. Applies a pooler layer (dense + tanh) to the [CLS] token
+ 4. Uses a classifier layer to produce a single relevance score
+
+ This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
+
  ## Development
 
  FORK IT!
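The two example outputs in the README's reranking section are consistent with each other. Applying the logistic sigmoid, 1 / (1 + e^-x), to each raw logit and rounding to four places gives the 0-to-1 scores shown. The plain-Ruby sketch below checks that and assembles the documented `:text` / `:score` / `:doc_id` hashes in descending score order. `sigmoid` and `rank_results` are hypothetical helpers for illustration. The logits are copied from the README output, and this is not necessarily how red-candle computes or sorts scores internally.

```ruby
# Logistic sigmoid: maps a raw logit to a score in (0, 1).
def sigmoid(x)
  1.0 / (1.0 + Math.exp(-x))
end

# Hypothetical helper: attach each document's 0-based input index as :doc_id,
# optionally apply the sigmoid, then sort by score, highest first.
def rank_results(documents, logits, apply_sigmoid: true)
  results = documents.each_with_index.map do |text, doc_id|
    score = apply_sigmoid ? sigmoid(logits[doc_id]) : logits[doc_id]
    { text: text, score: score, doc_id: doc_id }
  end
  results.sort_by { |r| -r[:score] }
end

documents = [
  "London is known for its financial district",
  "Around 9 Million people live in London",
  "The weather in London is often rainy",
  "London is the capital of England"
]
# Raw logits in input order, copied from the README's apply_sigmoid: false output.
logits = [-4.7619, 10.3918, -7.5251, -3.0829]

rank_results(documents, logits).each do |result|
  puts "Score: #{result[:score].round(4)} - Doc ##{result[:doc_id]}: #{result[:text]}"
end
# Score: 1.0 - Doc #1: Around 9 Million people live in London
# Score: 0.0438 - Doc #3: London is the capital of England
# Score: 0.0085 - Doc #0: London is known for its financial district
# Score: 0.0005 - Doc #2: The weather in London is often rainy
```

Because the sigmoid is monotonic, it never changes the ranking, only the scale of the scores. Keeping `doc_id` as the pre-sort index is what lets callers map each result back to their own records.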
@@ -2,18 +2,29 @@
  name = "candle"
  version = "0.1.0"
  edition = "2021"
+ build = "build.rs"
 
  [lib]
  crate-type = ["cdylib"]
 
  [dependencies]
- candle-core = "0.4.1"
- candle-nn = "0.4.1"
- candle-transformers = "0.4.1"
- tokenizers = { version = "0.15.0", default-features = true, features = ["fancy-regex"], exclude = ["onig"] }
- hf-hub = "0.3.0"
- half = "2"
- magnus = "0.6"
+ candle-core = { version = "0.9.1" }
+ candle-nn = { version = "0.9.1" }
+ candle-transformers = { version = "0.9.1" }
+ tokenizers = { version = "0.21.1", default-features = true, features = ["fancy-regex"] }
+ hf-hub = "0.4.3"
+ half = "2.6.0"
+ magnus = "0.7.1"
+ safetensors = "0.3"
+ serde_json = "1.0"
+ serde = { version = "1.0", features = ["derive"] }
+ tokio = { version = "1.45", features = ["rt", "macros"] }
+ rand = "0.8"
 
- [profile.test]
- opt-level = 3
+ [features]
+ default = []
+ metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"]
+ cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+ cudnn = ["candle-core/cudnn", "cuda"]
+ mkl = ["candle-core/mkl"]
+ accelerate = ["candle-core/accelerate"]
@@ -1,4 +1,80 @@
  require "mkmf"
  require "rb_sys/mkmf"
 
- create_rust_makefile("candle/candle")
+ # Detect available hardware acceleration
+ features = []
+
+ # Force CPU-only build if requested
+ if ENV['CANDLE_FORCE_CPU']
+   puts "CANDLE_FORCE_CPU is set, building CPU-only version"
+ else
+   # Check for CUDA
+   cuda_available = ENV['CUDA_ROOT'] || ENV['CUDA_PATH'] || ENV['CANDLE_CUDA_PATH'] ||
+                    File.exist?('/usr/local/cuda') || File.exist?('/opt/cuda') ||
+                    (RbConfig::CONFIG['host_os'] =~ /mswin|mingw|cygwin/ &&
+                     (File.exist?('C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA') ||
+                      File.exist?('C:\CUDA')))
+
+   cuda_enabled = ENV['CANDLE_ENABLE_CUDA']
+
+   if cuda_available && cuda_enabled
+     puts "CUDA detected and enabled via CANDLE_ENABLE_CUDA"
+     features << 'cuda'
+
+     # Check if CUDNN should be enabled
+     if ENV['CANDLE_CUDNN'] || ENV['CUDNN_ROOT']
+       puts "CUDNN support enabled"
+       features << 'cudnn'
+     end
+   elsif cuda_available && !cuda_enabled
+     puts "=" * 80
+     puts "CUDA detected but not enabled."
+     puts "To enable CUDA support (coming soon), set CANDLE_ENABLE_CUDA=1"
+     puts "=" * 80
+   end
+
+   # Check for Metal (macOS only)
+   if RbConfig::CONFIG['host_os'] =~ /darwin/
+     puts "macOS detected, enabling Metal support"
+     features << 'metal'
+
+     # Also enable Accelerate framework on macOS
+     puts "Enabling Accelerate framework support"
+     features << 'accelerate'
+   end
+
+   # Check for Intel MKL
+   mkl_available = ENV['MKLROOT'] || ENV['MKL_ROOT'] ||
+                   File.exist?('/opt/intel/mkl') ||
+                   File.exist?('/opt/intel/oneapi/mkl/latest')
+
+   if mkl_available && !features.include?('accelerate') # Don't use both MKL and Accelerate
+     puts "Intel MKL detected, enabling MKL support"
+     features << 'mkl'
+   end
+ end
+
+ # Allow manual override of features
+ if ENV['CANDLE_FEATURES']
+   manual_features = ENV['CANDLE_FEATURES'].split(',').map(&:strip)
+   puts "Manual features override: #{manual_features.join(', ')}"
+   features = manual_features
+ end
+
+ # Display selected features
+ unless features.empty?
+   puts "Building with features: #{features.join(', ')}"
+ else
+   puts "Building CPU-only version (no acceleration features detected)"
+ end
+
+ # Create the Rust makefile with proper feature configuration
+ create_rust_makefile("candle/candle") do |r|
+   # Pass the features to rb_sys
+   r.features = features unless features.empty?
+
+   # Pass through any additional cargo flags
+   if ENV['CANDLE_CARGO_FLAGS']
+     r.extra_cargo_args = ENV['CANDLE_CARGO_FLAGS'].split(' ')
+   end
+ end
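The precedence in this `extconf.rb` is easy to misread. `CANDLE_FORCE_CPU` only skips auto-detection, while the `CANDLE_FEATURES` block runs afterwards and replaces the list, so it wins even when `CANDLE_FORCE_CPU` is set. The sketch below restates the same branching as a pure Ruby function so it can be exercised without CUDA, MKL or macOS. `select_features` and its `cuda_present` / `mkl_present` flags are hypothetical stand-ins for the script's `File.exist?` probes (including the Windows CUDA paths). The real script reads `ENV` and `RbConfig` directly and also prints progress messages.

```ruby
# Hypothetical pure-function restatement of the feature selection in extconf.rb.
# cuda_present / mkl_present stand in for the script's File.exist? probes.
def select_features(env, host_os:, cuda_present: false, mkl_present: false)
  features = []

  unless env["CANDLE_FORCE_CPU"]
    cuda_available = cuda_present || env["CUDA_ROOT"] || env["CUDA_PATH"] || env["CANDLE_CUDA_PATH"]
    # CUDA is only used when it is both detected and explicitly enabled.
    if cuda_available && env["CANDLE_ENABLE_CUDA"]
      features << "cuda"
      features << "cudnn" if env["CANDLE_CUDNN"] || env["CUDNN_ROOT"]
    end

    # On macOS, Metal and Accelerate are enabled together.
    features.push("metal", "accelerate") if host_os =~ /darwin/

    # MKL is skipped when Accelerate is already selected.
    mkl_available = mkl_present || env["MKLROOT"] || env["MKL_ROOT"]
    features << "mkl" if mkl_available && !features.include?("accelerate")
  end

  # CANDLE_FEATURES replaces the list, even when CANDLE_FORCE_CPU is set.
  features = env["CANDLE_FEATURES"].split(",").map(&:strip) if env["CANDLE_FEATURES"]

  features
end

p select_features({}, host_os: "darwin23")
# => ["metal", "accelerate"]
p select_features({ "MKLROOT" => "/opt/intel/mkl" }, host_os: "linux-gnu")
# => ["mkl"]
p select_features({ "CANDLE_FORCE_CPU" => "1", "CANDLE_FEATURES" => "cuda, cudnn" }, host_os: "linux-gnu")
# => ["cuda", "cudnn"]
```

An empty result corresponds to the script's "Building CPU-only version" path, where `r.features` is left unset.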
@@ -1,103 +1,142 @@
  use magnus::{function, method, prelude::*, Ruby};
 
- use crate::model::{candle_utils, RbModel, RbDType, RbDevice, RbQTensor, RbResult, RbTensor};
+ use crate::ruby::candle_utils;
+ use crate::ruby::{DType, Device, QTensor, Result as RbResult, Tensor};
 
- pub mod model;
+ pub mod llm;
+ pub mod reranker;
+ pub mod ruby;
+
+ // Configuration detection from build.rs
+ #[cfg(all(has_metal, not(force_cpu)))]
+ const DEFAULT_DEVICE: &str = "metal";
+
+ #[cfg(all(has_cuda, not(has_metal), not(force_cpu)))]
+ const DEFAULT_DEVICE: &str = "cuda";
+
+ #[cfg(any(force_cpu, not(any(has_metal, has_cuda))))]
+ const DEFAULT_DEVICE: &str = "cpu";
+
+ // Export build configuration for runtime checks
+ pub fn get_build_info() -> magnus::RHash {
+     let ruby = magnus::Ruby::get().unwrap();
+     let hash = ruby.hash_new();
+
+     let _ = hash.aset("default_device", DEFAULT_DEVICE);
+     let _ = hash.aset("cuda_available", cfg!(feature = "cuda"));
+     let _ = hash.aset("metal_available", cfg!(feature = "metal"));
+     let _ = hash.aset("mkl_available", cfg!(feature = "mkl"));
+     let _ = hash.aset("accelerate_available", cfg!(feature = "accelerate"));
+     let _ = hash.aset("cudnn_available", cfg!(feature = "cudnn"));
+
+     hash
+ }
 
  #[magnus::init]
  fn init(ruby: &Ruby) -> RbResult<()> {
      let rb_candle = ruby.define_module("Candle")?;
+
+     // Export build info
+     rb_candle.define_singleton_method("build_info", function!(get_build_info, 0))?;
+
+     ruby::init_embedding_model(rb_candle)?;
+     ruby::init_llm(rb_candle)?;
+     reranker::init(rb_candle)?;
      candle_utils(rb_candle)?;
      let rb_tensor = rb_candle.define_class("Tensor", Ruby::class_object(ruby))?;
-     rb_tensor.define_singleton_method("new", function!(RbTensor::new, 2))?;
-     // rb_tensor.define_singleton_method("cat", function!(RbTensor::cat, 2))?;
-     // rb_tensor.define_singleton_method("stack", function!(RbTensor::stack, 2))?;
-     rb_tensor.define_singleton_method("rand", function!(RbTensor::rand, 1))?;
-     rb_tensor.define_singleton_method("randn", function!(RbTensor::randn, 1))?;
-     rb_tensor.define_singleton_method("ones", function!(RbTensor::ones, 1))?;
-     rb_tensor.define_singleton_method("zeros", function!(RbTensor::zeros, 1))?;
-     rb_tensor.define_method("values", method!(RbTensor::values, 0))?;
-     rb_tensor.define_method("shape", method!(RbTensor::shape, 0))?;
-     rb_tensor.define_method("stride", method!(RbTensor::stride, 0))?;
-     rb_tensor.define_method("dtype", method!(RbTensor::dtype, 0))?;
-     rb_tensor.define_method("device", method!(RbTensor::device, 0))?;
-     rb_tensor.define_method("rank", method!(RbTensor::rank, 0))?;
-     rb_tensor.define_method("elem_count", method!(RbTensor::elem_count, 0))?;
-     rb_tensor.define_method("sin", method!(RbTensor::sin, 0))?;
-     rb_tensor.define_method("cos", method!(RbTensor::cos, 0))?;
-     rb_tensor.define_method("log", method!(RbTensor::log, 0))?;
-     rb_tensor.define_method("sqr", method!(RbTensor::sqr, 0))?;
-     rb_tensor.define_method("sqrt", method!(RbTensor::sqrt, 0))?;
-     rb_tensor.define_method("recip", method!(RbTensor::recip, 0))?;
-     rb_tensor.define_method("exp", method!(RbTensor::exp, 0))?;
-     rb_tensor.define_method("powf", method!(RbTensor::powf, 1))?;
-     rb_tensor.define_method("index_select", method!(RbTensor::index_select, 2))?;
-     rb_tensor.define_method("matmul", method!(RbTensor::matmul, 1))?;
-     rb_tensor.define_method("broadcast_add", method!(RbTensor::broadcast_add, 1))?;
-     rb_tensor.define_method("broadcast_sub", method!(RbTensor::broadcast_sub, 1))?;
-     rb_tensor.define_method("broadcast_mul", method!(RbTensor::broadcast_mul, 1))?;
-     rb_tensor.define_method("broadcast_div", method!(RbTensor::broadcast_div, 1))?;
-     rb_tensor.define_method("where_cond", method!(RbTensor::where_cond, 2))?;
-     rb_tensor.define_method("+", method!(RbTensor::__add__, 1))?;
-     rb_tensor.define_method("*", method!(RbTensor::__mul__, 1))?;
-     rb_tensor.define_method("-", method!(RbTensor::__sub__, 1))?;
-     rb_tensor.define_method("reshape", method!(RbTensor::reshape, 1))?;
-     rb_tensor.define_method("broadcast_as", method!(RbTensor::broadcast_as, 1))?;
-     rb_tensor.define_method("broadcast_left", method!(RbTensor::broadcast_left, 1))?;
-     rb_tensor.define_method("squeeze", method!(RbTensor::squeeze, 1))?;
-     rb_tensor.define_method("unsqueeze", method!(RbTensor::unsqueeze, 1))?;
-     rb_tensor.define_method("get", method!(RbTensor::get, 1))?;
-     rb_tensor.define_method("[]", method!(RbTensor::get, 1))?;
-     rb_tensor.define_method("transpose", method!(RbTensor::transpose, 2))?;
-     rb_tensor.define_method("narrow", method!(RbTensor::narrow, 3))?;
-     rb_tensor.define_method("argmax_keepdim", method!(RbTensor::argmax_keepdim, 1))?;
-     rb_tensor.define_method("argmin_keepdim", method!(RbTensor::argmin_keepdim, 1))?;
-     rb_tensor.define_method("max_keepdim", method!(RbTensor::max_keepdim, 1))?;
-     rb_tensor.define_method("min_keepdim", method!(RbTensor::min_keepdim, 1))?;
-     // rb_tensor.define_method("eq", method!(RbTensor::eq, 1))?;
-     // rb_tensor.define_method("ne", method!(RbTensor::ne, 1))?;
-     // rb_tensor.define_method("lt", method!(RbTensor::lt, 1))?;
-     // rb_tensor.define_method("gt", method!(RbTensor::gt, 1))?;
-     // rb_tensor.define_method("ge", method!(RbTensor::ge, 1))?;
-     // rb_tensor.define_method("le", method!(RbTensor::le, 1))?;
-     rb_tensor.define_method("sum_all", method!(RbTensor::sum_all, 0))?;
-     rb_tensor.define_method("mean_all", method!(RbTensor::mean_all, 0))?;
-     rb_tensor.define_method("flatten_from", method!(RbTensor::flatten_from, 1))?;
-     rb_tensor.define_method("flatten_to", method!(RbTensor::flatten_to, 1))?;
-     rb_tensor.define_method("flatten_all", method!(RbTensor::flatten_all, 0))?;
-     rb_tensor.define_method("t", method!(RbTensor::t, 0))?;
-     rb_tensor.define_method("contiguous", method!(RbTensor::contiguous, 0))?;
-     rb_tensor.define_method("is_contiguous", method!(RbTensor::is_contiguous, 0))?;
+     rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
+     // rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
+     // rb_tensor.define_singleton_method("stack", function!(Tensor::stack, 2))?;
+     rb_tensor.define_singleton_method("rand", function!(Tensor::rand, 2))?;
+     rb_tensor.define_singleton_method("randn", function!(Tensor::randn, 2))?;
+     rb_tensor.define_singleton_method("ones", function!(Tensor::ones, 2))?;
+     rb_tensor.define_singleton_method("zeros", function!(Tensor::zeros, 2))?;
+     rb_tensor.define_method("values", method!(Tensor::values, 0))?;
+     rb_tensor.define_method("values_f32", method!(Tensor::values_f32, 0))?;
+     rb_tensor.define_method("item", method!(Tensor::item, 0))?;
+     rb_tensor.define_method("shape", method!(Tensor::shape, 0))?;
+     rb_tensor.define_method("stride", method!(Tensor::stride, 0))?;
+     rb_tensor.define_method("dtype", method!(Tensor::dtype, 0))?;
+     rb_tensor.define_method("device", method!(Tensor::device, 0))?;
+     rb_tensor.define_method("rank", method!(Tensor::rank, 0))?;
+     rb_tensor.define_method("elem_count", method!(Tensor::elem_count, 0))?;
+     rb_tensor.define_method("sin", method!(Tensor::sin, 0))?;
+     rb_tensor.define_method("cos", method!(Tensor::cos, 0))?;
+     rb_tensor.define_method("log", method!(Tensor::log, 0))?;
+     rb_tensor.define_method("sqr", method!(Tensor::sqr, 0))?;
+     rb_tensor.define_method("mean", method!(Tensor::mean, 1))?;
+     rb_tensor.define_method("sum", method!(Tensor::sum, 1))?;
+     rb_tensor.define_method("sqrt", method!(Tensor::sqrt, 0))?;
+     rb_tensor.define_method("/", method!(Tensor::__truediv__, 1))?; // Accepts Tensor, Float, or Integer
+     rb_tensor.define_method("recip", method!(Tensor::recip, 0))?;
+     rb_tensor.define_method("exp", method!(Tensor::exp, 0))?;
+     rb_tensor.define_method("powf", method!(Tensor::powf, 1))?;
+     rb_tensor.define_method("index_select", method!(Tensor::index_select, 2))?;
+     rb_tensor.define_method("matmul", method!(Tensor::matmul, 1))?;
+     rb_tensor.define_method("broadcast_add", method!(Tensor::broadcast_add, 1))?;
+     rb_tensor.define_method("broadcast_sub", method!(Tensor::broadcast_sub, 1))?;
+     rb_tensor.define_method("broadcast_mul", method!(Tensor::broadcast_mul, 1))?;
+     rb_tensor.define_method("broadcast_div", method!(Tensor::broadcast_div, 1))?;
+     rb_tensor.define_method("where_cond", method!(Tensor::where_cond, 2))?;
+     rb_tensor.define_method("+", method!(Tensor::__add__, 1))?;
+     rb_tensor.define_method("*", method!(Tensor::__mul__, 1))?;
+     rb_tensor.define_method("-", method!(Tensor::__sub__, 1))?;
+     rb_tensor.define_method("reshape", method!(Tensor::reshape, 1))?;
+     rb_tensor.define_method("broadcast_as", method!(Tensor::broadcast_as, 1))?;
+     rb_tensor.define_method("broadcast_left", method!(Tensor::broadcast_left, 1))?;
+     rb_tensor.define_method("squeeze", method!(Tensor::squeeze, 1))?;
+     rb_tensor.define_method("unsqueeze", method!(Tensor::unsqueeze, 1))?;
+     rb_tensor.define_method("get", method!(Tensor::get, 1))?;
+     rb_tensor.define_method("[]", method!(Tensor::get, 1))?;
+     rb_tensor.define_method("transpose", method!(Tensor::transpose, 2))?;
+     rb_tensor.define_method("narrow", method!(Tensor::narrow, 3))?;
+     rb_tensor.define_method("argmax_keepdim", method!(Tensor::argmax_keepdim, 1))?;
+     rb_tensor.define_method("argmin_keepdim", method!(Tensor::argmin_keepdim, 1))?;
+     rb_tensor.define_method("max_keepdim", method!(Tensor::max_keepdim, 1))?;
+     rb_tensor.define_method("min_keepdim", method!(Tensor::min_keepdim, 1))?;
+     // rb_tensor.define_method("eq", method!(Tensor::eq, 1))?;
+     // rb_tensor.define_method("ne", method!(Tensor::ne, 1))?;
+     // rb_tensor.define_method("lt", method!(Tensor::lt, 1))?;
+     // rb_tensor.define_method("gt", method!(Tensor::gt, 1))?;
+     // rb_tensor.define_method("ge", method!(Tensor::ge, 1))?;
+     // rb_tensor.define_method("le", method!(Tensor::le, 1))?;
+     rb_tensor.define_method("sum_all", method!(Tensor::sum_all, 0))?;
+     rb_tensor.define_method("mean_all", method!(Tensor::mean_all, 0))?;
+     rb_tensor.define_method("flatten_from", method!(Tensor::flatten_from, 1))?;
+     rb_tensor.define_method("flatten_to", method!(Tensor::flatten_to, 1))?;
+     rb_tensor.define_method("flatten_all", method!(Tensor::flatten_all, 0))?;
+     rb_tensor.define_method("t", method!(Tensor::t, 0))?;
+     rb_tensor.define_method("contiguous", method!(Tensor::contiguous, 0))?;
+     rb_tensor.define_method("is_contiguous", method!(Tensor::is_contiguous, 0))?;
      rb_tensor.define_method(
          "is_fortran_contiguous",
-         method!(RbTensor::is_fortran_contiguous, 0),
+         method!(Tensor::is_fortran_contiguous, 0),
      )?;
-     rb_tensor.define_method("detach", method!(RbTensor::detach, 0))?;
-     rb_tensor.define_method("copy", method!(RbTensor::copy, 0))?;
-     rb_tensor.define_method("to_dtype", method!(RbTensor::to_dtype, 1))?;
-     rb_tensor.define_method("to_device", method!(RbTensor::to_device, 1))?;
-     rb_tensor.define_method("to_s", method!(RbTensor::__str__, 0))?;
-     rb_tensor.define_method("inspect", method!(RbTensor::__repr__, 0))?;
+     rb_tensor.define_method("detach", method!(Tensor::detach, 0))?;
+     rb_tensor.define_method("copy", method!(Tensor::copy, 0))?;
+     rb_tensor.define_method("to_dtype", method!(Tensor::to_dtype, 1))?;
+     rb_tensor.define_method("to_device", method!(Tensor::to_device, 1))?;
+     rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
+     rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
 
      let rb_dtype = rb_candle.define_class("DType", Ruby::class_object(ruby))?;
-     rb_dtype.define_method("to_s", method!(RbDType::__str__, 0))?;
-     rb_dtype.define_method("inspect", method!(RbDType::__repr__, 0))?;
+     rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
+     rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
 
      let rb_device = rb_candle.define_class("Device", Ruby::class_object(ruby))?;
-     rb_device.define_method("to_s", method!(RbDevice::__str__, 0))?;
-     rb_device.define_method("inspect", method!(RbDevice::__repr__, 0))?;
+     rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
+     rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
+     rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
+     rb_device.define_singleton_method("available_devices", function!(ruby::device::available_devices, 0))?;
+     rb_device.define_singleton_method("default", function!(ruby::device::default_device, 0))?;
+     rb_device.define_method("to_s", method!(Device::__str__, 0))?;
+     rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
 
      let rb_qtensor = rb_candle.define_class("QTensor", Ruby::class_object(ruby))?;
-     rb_qtensor.define_method("ggml_dtype", method!(RbQTensor::ggml_dtype, 0))?;
-     rb_qtensor.define_method("rank", method!(RbQTensor::rank, 0))?;
-     rb_qtensor.define_method("shape", method!(RbQTensor::shape, 0))?;
-     rb_qtensor.define_method("dequantize", method!(RbQTensor::dequantize, 0))?;
-
-     let rb_model = rb_candle.define_class("Model", Ruby::class_object(ruby))?;
-     rb_model.define_singleton_method("new", function!(RbModel::new, 0))?;
-     rb_model.define_method("embedding", method!(RbModel::embedding, 1))?;
-     rb_model.define_method("to_s", method!(RbModel::__str__, 0))?;
-     rb_model.define_method("inspect", method!(RbModel::__repr__, 0))?;
+     rb_qtensor.define_method("ggml_dtype", method!(QTensor::ggml_dtype, 0))?;
+     rb_qtensor.define_method("rank", method!(QTensor::rank, 0))?;
+     rb_qtensor.define_method("shape", method!(QTensor::shape, 0))?;
+     rb_qtensor.define_method("dequantize", method!(QTensor::dequantize, 0))?;
 
      Ok(())
  }
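In this `lib.rs`, exactly one of the three `#[cfg(...)]` attributes guarding `DEFAULT_DEVICE` must be active. Otherwise the constant would be defined twice or not at all, and the crate would not compile. The plain-Ruby sketch below evaluates the three predicates, copied from the attributes, over all eight combinations of `has_metal`, `has_cuda` and `force_cpu`. It models only the boolean logic. The flags themselves are assumed to be emitted by `build.rs`, as the `// Configuration detection from build.rs` comment suggests, but that file is not shown in this diff.

```ruby
# Hypothetical Ruby model of the #[cfg(...)] predicates guarding DEFAULT_DEVICE in lib.rs.
BRANCHES = {
  "metal" => ->(metal, _cuda, force_cpu) { metal && !force_cpu },
  "cuda" => ->(metal, cuda, force_cpu) { cuda && !metal && !force_cpu },
  "cpu" => ->(metal, cuda, force_cpu) { force_cpu || !(metal || cuda) }
}.freeze

# Names of the branches whose predicate holds for the given flags.
def matching_devices(metal:, cuda:, force_cpu:)
  BRANCHES.select { |_, pred| pred.call(metal, cuda, force_cpu) }.keys
end

# Enumerate all 2^3 flag combinations.
[true, false].product([true, false], [true, false]).each do |metal, cuda, force_cpu|
  devices = matching_devices(metal: metal, cuda: cuda, force_cpu: force_cpu)
  puts "has_metal=#{metal} has_cuda=#{cuda} force_cpu=#{force_cpu} -> #{devices.join(', ')}"
end
```

Each combination matches exactly one branch. Metal takes precedence over CUDA when both flags are set, and `force_cpu` overrides both. Note that `get_build_info` takes `default_device` from these cfg flags but the `*_available` keys from Cargo features (`cfg!(feature = "...")`). The two can therefore disagree if `build.rs` sets the flags differently from the enabled features.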