red-candle 1.0.0.pre.1 → 1.0.0.pre.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile +12 -0
- data/LICENSE +22 -0
- data/Rakefile +95 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/lib.rs +6 -96
- data/ext/candle/src/llm/generation_config.rs +49 -0
- data/ext/candle/src/llm/mistral.rs +325 -0
- data/ext/candle/src/llm/mod.rs +68 -0
- data/ext/candle/src/llm/text_generation.rs +141 -0
- data/ext/candle/src/reranker.rs +267 -0
- data/ext/candle/src/ruby/device.rs +197 -0
- data/ext/candle/src/ruby/dtype.rs +37 -0
- data/ext/candle/src/ruby/embedding_model.rs +410 -0
- data/ext/candle/src/ruby/errors.rs +13 -0
- data/ext/candle/src/ruby/llm.rs +295 -0
- data/ext/candle/src/ruby/mod.rs +21 -0
- data/ext/candle/src/ruby/qtensor.rs +69 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/tensor.rs +654 -0
- data/ext/candle/src/ruby/utils.rs +88 -0
- data/lib/candle/version.rb +1 -1
- metadata +22 -1
@@ -0,0 +1,267 @@
|
|
1
|
+
use magnus::{class, function, method, prelude::*, Error, RModule, Float, RArray};
|
2
|
+
use candle_transformers::models::bert::{BertModel, Config};
|
3
|
+
use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
|
4
|
+
use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
|
5
|
+
use hf_hub::{api::sync::Api, Repo, RepoType};
|
6
|
+
use tokenizers::{PaddingParams, Tokenizer, EncodeInput};
|
7
|
+
use std::thread;
|
8
|
+
use crate::ruby::{Device as RbDevice, Result as RbResult};
|
9
|
+
|
10
|
+
#[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
|
11
|
+
pub struct Reranker {
|
12
|
+
model: BertModel,
|
13
|
+
tokenizer: Tokenizer,
|
14
|
+
pooler: Linear,
|
15
|
+
classifier: Linear,
|
16
|
+
device: CoreDevice,
|
17
|
+
}
|
18
|
+
|
19
|
+
impl Reranker {
|
20
|
+
pub fn new(model_id: String, device: Option<RbDevice>) -> RbResult<Self> {
|
21
|
+
let device = device.unwrap_or(RbDevice::Cpu).as_device()?;
|
22
|
+
Self::new_with_core_device(model_id, device)
|
23
|
+
}
|
24
|
+
|
25
|
+
fn new_with_core_device(model_id: String, device: CoreDevice) -> Result<Self, Error> {
|
26
|
+
let device_clone = device.clone();
|
27
|
+
let handle = thread::spawn(move || -> Result<(BertModel, Tokenizer, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
|
28
|
+
let api = Api::new()?;
|
29
|
+
let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
|
30
|
+
|
31
|
+
// Download model files
|
32
|
+
let config_filename = repo.get("config.json")?;
|
33
|
+
let tokenizer_filename = repo.get("tokenizer.json")?;
|
34
|
+
let weights_filename = repo.get("model.safetensors")?;
|
35
|
+
|
36
|
+
// Load config
|
37
|
+
let config = std::fs::read_to_string(config_filename)?;
|
38
|
+
let config: Config = serde_json::from_str(&config)?;
|
39
|
+
|
40
|
+
// Setup tokenizer with padding
|
41
|
+
let mut tokenizer = Tokenizer::from_file(tokenizer_filename)?;
|
42
|
+
let pp = PaddingParams {
|
43
|
+
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
44
|
+
..Default::default()
|
45
|
+
};
|
46
|
+
tokenizer.with_padding(Some(pp));
|
47
|
+
|
48
|
+
// Load model weights
|
49
|
+
let vb = unsafe {
|
50
|
+
VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
|
51
|
+
};
|
52
|
+
|
53
|
+
// Load BERT model
|
54
|
+
let model = BertModel::load(vb.pp("bert"), &config)?;
|
55
|
+
|
56
|
+
// Load pooler layer (dense + tanh activation)
|
57
|
+
let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
|
58
|
+
|
59
|
+
// Load classifier layer for cross-encoder (single output score)
|
60
|
+
let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
|
61
|
+
|
62
|
+
Ok((model, tokenizer, pooler, classifier))
|
63
|
+
});
|
64
|
+
|
65
|
+
match handle.join() {
|
66
|
+
Ok(Ok((model, tokenizer, pooler, classifier))) => {
|
67
|
+
Ok(Self { model, tokenizer, pooler, classifier, device })
|
68
|
+
}
|
69
|
+
Ok(Err(e)) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
|
70
|
+
Err(_) => Err(Error::new(magnus::exception::runtime_error(), "Thread panicked while loading model")),
|
71
|
+
}
|
72
|
+
}
|
73
|
+
|
74
|
+
pub fn debug_tokenization(&self, query: String, document: String) -> Result<magnus::RHash, Error> {
|
75
|
+
// Create query-document pair for cross-encoder
|
76
|
+
let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
|
77
|
+
|
78
|
+
// Tokenize
|
79
|
+
let encoding = self.tokenizer.encode(query_doc_pair, true)
|
80
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
|
81
|
+
|
82
|
+
// Get token information
|
83
|
+
let token_ids = encoding.get_ids().to_vec();
|
84
|
+
let token_type_ids = encoding.get_type_ids().to_vec();
|
85
|
+
let attention_mask = encoding.get_attention_mask().to_vec();
|
86
|
+
let tokens = encoding.get_tokens().iter().map(|t| t.to_string()).collect::<Vec<_>>();
|
87
|
+
|
88
|
+
// Create result hash
|
89
|
+
let result = magnus::RHash::new();
|
90
|
+
result.aset("token_ids", RArray::from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
|
91
|
+
result.aset("token_type_ids", RArray::from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
|
92
|
+
result.aset("attention_mask", RArray::from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
|
93
|
+
result.aset("tokens", RArray::from_vec(tokens))?;
|
94
|
+
|
95
|
+
Ok(result)
|
96
|
+
}
|
97
|
+
|
98
|
+
pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> Result<RArray, Error> {
|
99
|
+
let documents: Vec<String> = documents.to_vec()?;
|
100
|
+
|
101
|
+
// Create query-document pairs for cross-encoder
|
102
|
+
let query_and_docs: Vec<EncodeInput> = documents
|
103
|
+
.iter()
|
104
|
+
.map(|d| (query.clone(), d.clone()).into())
|
105
|
+
.collect();
|
106
|
+
|
107
|
+
// Tokenize batch
|
108
|
+
let encodings = self.tokenizer.encode_batch(query_and_docs, true)
|
109
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
|
110
|
+
|
111
|
+
// Convert to tensors
|
112
|
+
let token_ids = encodings
|
113
|
+
.iter()
|
114
|
+
.map(|e| e.get_ids().to_vec())
|
115
|
+
.collect::<Vec<_>>();
|
116
|
+
|
117
|
+
let token_type_ids = encodings
|
118
|
+
.iter()
|
119
|
+
.map(|e| e.get_type_ids().to_vec())
|
120
|
+
.collect::<Vec<_>>();
|
121
|
+
|
122
|
+
let token_ids = Tensor::new(token_ids, &self.device)
|
123
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create tensor: {}", e)))?;
|
124
|
+
let token_type_ids = Tensor::new(token_type_ids, &self.device)
|
125
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create token type ids tensor: {}", e)))?;
|
126
|
+
let attention_mask = token_ids.ne(0u32)
|
127
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create attention mask: {}", e)))?;
|
128
|
+
|
129
|
+
// Forward pass through BERT
|
130
|
+
let embeddings = self.model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
|
131
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Model forward pass failed: {}", e)))?;
|
132
|
+
|
133
|
+
// Apply pooling based on the specified method
|
134
|
+
let pooled_embeddings = match pooling_method.as_str() {
|
135
|
+
"pooler" => {
|
136
|
+
// Extract [CLS] token and apply pooler (dense + tanh)
|
137
|
+
// Work around Metal indexing issue by using narrow instead of i((.., 0))
|
138
|
+
let cls_embeddings = if self.device.is_metal() {
|
139
|
+
// Metal has issues with tensor indexing, use a different approach
|
140
|
+
let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
|
141
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
|
142
|
+
|
143
|
+
// Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
|
144
|
+
let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
|
145
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
|
146
|
+
|
147
|
+
// Extract CLS tokens (first token of each sequence)
|
148
|
+
let mut cls_vecs = Vec::new();
|
149
|
+
for i in 0..batch_size {
|
150
|
+
let start_idx = i * _seq_len;
|
151
|
+
let cls_vec = reshaped.narrow(0, start_idx, 1)
|
152
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
|
153
|
+
cls_vecs.push(cls_vec);
|
154
|
+
}
|
155
|
+
|
156
|
+
// Stack the CLS vectors
|
157
|
+
Tensor::cat(&cls_vecs, 0)
|
158
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
|
159
|
+
.contiguous()
|
160
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
|
161
|
+
} else {
|
162
|
+
embeddings.i((.., 0))
|
163
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
|
164
|
+
};
|
165
|
+
// Ensure tensor is contiguous before linear layer
|
166
|
+
let cls_embeddings = cls_embeddings.contiguous()
|
167
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make cls_embeddings contiguous: {}", e)))?;
|
168
|
+
let pooled = self.pooler.forward(&cls_embeddings)
|
169
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Pooler forward failed: {}", e)))?;
|
170
|
+
pooled.tanh()
|
171
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tanh activation failed: {}", e)))?
|
172
|
+
},
|
173
|
+
"cls" => {
|
174
|
+
// Just use the [CLS] token embeddings directly (no pooler layer)
|
175
|
+
// Work around Metal indexing issue
|
176
|
+
let cls_embeddings = if self.device.is_metal() {
|
177
|
+
// Use same approach as pooler method
|
178
|
+
let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
|
179
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
|
180
|
+
|
181
|
+
let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
|
182
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
|
183
|
+
|
184
|
+
let mut cls_vecs = Vec::new();
|
185
|
+
for i in 0..batch_size {
|
186
|
+
let start_idx = i * _seq_len;
|
187
|
+
let cls_vec = reshaped.narrow(0, start_idx, 1)
|
188
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
|
189
|
+
cls_vecs.push(cls_vec);
|
190
|
+
}
|
191
|
+
|
192
|
+
Tensor::cat(&cls_vecs, 0)
|
193
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
|
194
|
+
.contiguous()
|
195
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
|
196
|
+
} else {
|
197
|
+
embeddings.i((.., 0))
|
198
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
|
199
|
+
};
|
200
|
+
// Ensure contiguous for classifier
|
201
|
+
cls_embeddings.contiguous()
|
202
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))?
|
203
|
+
},
|
204
|
+
"mean" => {
|
205
|
+
// Mean pooling across all tokens
|
206
|
+
let (_batch, seq_len, _hidden) = embeddings.dims3()
|
207
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get tensor dimensions: {}", e)))?;
|
208
|
+
let sum = embeddings.sum(1)
|
209
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to sum embeddings: {}", e)))?;
|
210
|
+
(sum / (seq_len as f64))
|
211
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to compute mean: {}", e)))?
|
212
|
+
},
|
213
|
+
_ => return Err(Error::new(magnus::exception::runtime_error(),
|
214
|
+
format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
|
215
|
+
};
|
216
|
+
|
217
|
+
// Apply classifier to get relevance scores (raw logits)
|
218
|
+
// Ensure tensor is contiguous before linear layer
|
219
|
+
let pooled_embeddings = pooled_embeddings.contiguous()
|
220
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
|
221
|
+
let logits = self.classifier.forward(&pooled_embeddings)
|
222
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Classifier forward failed: {}", e)))?;
|
223
|
+
let scores = logits.squeeze(1)
|
224
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to squeeze tensor: {}", e)))?;
|
225
|
+
|
226
|
+
// Optionally apply sigmoid activation
|
227
|
+
let scores = if apply_sigmoid {
|
228
|
+
sigmoid(&scores)
|
229
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Sigmoid failed: {}", e)))?
|
230
|
+
} else {
|
231
|
+
scores
|
232
|
+
};
|
233
|
+
|
234
|
+
let scores_vec: Vec<f32> = scores.to_vec1()
|
235
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to convert scores to vec: {}", e)))?;
|
236
|
+
|
237
|
+
// Create tuples with document, score, and original index
|
238
|
+
let mut ranked_docs: Vec<(String, f32, usize)> = documents
|
239
|
+
.into_iter()
|
240
|
+
.zip(scores_vec)
|
241
|
+
.enumerate()
|
242
|
+
.map(|(idx, (doc, score))| (doc, score, idx))
|
243
|
+
.collect();
|
244
|
+
|
245
|
+
// Sort documents by relevance score (descending)
|
246
|
+
ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
247
|
+
|
248
|
+
// Build result array with [doc, score, doc_id]
|
249
|
+
let result_array = RArray::new();
|
250
|
+
for (doc, score, doc_id) in ranked_docs {
|
251
|
+
let tuple = RArray::new();
|
252
|
+
tuple.push(doc)?;
|
253
|
+
tuple.push(Float::from_f64(score as f64))?;
|
254
|
+
tuple.push(doc_id)?;
|
255
|
+
result_array.push(tuple)?;
|
256
|
+
}
|
257
|
+
Ok(result_array)
|
258
|
+
}
|
259
|
+
}
|
260
|
+
|
261
|
+
pub fn init(rb_candle: RModule) -> Result<(), Error> {
|
262
|
+
let c_reranker = rb_candle.define_class("Reranker", class::object())?;
|
263
|
+
c_reranker.define_singleton_method("_create", function!(Reranker::new, 2))?;
|
264
|
+
c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
|
265
|
+
c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
|
266
|
+
Ok(())
|
267
|
+
}
|
@@ -0,0 +1,197 @@
|
|
1
|
+
use magnus::Error;
|
2
|
+
use magnus::{function, method, class, RModule, Module, Object};
|
3
|
+
|
4
|
+
use ::candle_core::Device as CoreDevice;
|
5
|
+
use crate::ruby::Result as RbResult;
|
6
|
+
|
7
|
+
#[cfg(any(feature = "cuda", feature = "metal"))]
|
8
|
+
use crate::ruby::errors::wrap_candle_err;
|
9
|
+
|
10
|
+
/// Process-wide cache of the CUDA device; filled on first use by `Device::as_device`.
#[cfg(feature = "cuda")]
static CUDA_DEVICE: std::sync::Mutex<Option<CoreDevice>> = std::sync::Mutex::new(None);

/// Process-wide cache of the Metal device; filled on first use by `Device::as_device`.
#[cfg(feature = "metal")]
static METAL_DEVICE: std::sync::Mutex<Option<CoreDevice>> = std::sync::Mutex::new(None);
|
15
|
+
|
16
|
+
/// List the device names this build supports, always starting with `"cpu"`.
///
/// `"cuda"` and `"metal"` are appended (in that order) when the matching Cargo
/// feature is enabled and the `force_cpu` cfg flag is not set. This reflects
/// compile-time configuration only.
pub fn available_devices() -> Vec<String> {
    // `mut` is only exercised when at least one accelerator feature is enabled.
    #[allow(unused_mut)]
    let mut names = vec![String::from("cpu")];

    #[cfg(all(feature = "cuda", not(force_cpu)))]
    names.push(String::from("cuda"));

    #[cfg(all(feature = "metal", not(force_cpu)))]
    names.push(String::from("metal"));

    names
}
|
36
|
+
|
37
|
+
/// Get the default device based on what's available
|
38
|
+
pub fn default_device() -> Device {
|
39
|
+
// Return based on compiled features, not detection
|
40
|
+
#[cfg(all(feature = "metal", not(force_cpu)))]
|
41
|
+
{
|
42
|
+
Device::Metal
|
43
|
+
}
|
44
|
+
|
45
|
+
#[cfg(all(feature = "cuda", not(feature = "metal"), not(force_cpu)))]
|
46
|
+
{
|
47
|
+
Device::Cuda
|
48
|
+
}
|
49
|
+
|
50
|
+
#[cfg(not(any(all(feature = "metal", not(force_cpu)), all(feature = "cuda", not(feature = "metal"), not(force_cpu)))))]
|
51
|
+
{
|
52
|
+
Device::Cpu
|
53
|
+
}
|
54
|
+
}
|
55
|
+
|
56
|
+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
57
|
+
#[magnus::wrap(class = "Candle::Device")]
|
58
|
+
pub enum Device {
|
59
|
+
Cpu,
|
60
|
+
Cuda,
|
61
|
+
Metal,
|
62
|
+
}
|
63
|
+
|
64
|
+
impl Device {
|
65
|
+
/// Create a CPU device
|
66
|
+
pub fn cpu() -> Self {
|
67
|
+
Self::Cpu
|
68
|
+
}
|
69
|
+
|
70
|
+
/// Create a CUDA device (GPU)
|
71
|
+
pub fn cuda() -> RbResult<Self> {
|
72
|
+
#[cfg(not(feature = "cuda"))]
|
73
|
+
{
|
74
|
+
return Err(Error::new(
|
75
|
+
magnus::exception::runtime_error(),
|
76
|
+
"CUDA support not compiled in. Rebuild with CUDA available.",
|
77
|
+
));
|
78
|
+
}
|
79
|
+
|
80
|
+
#[cfg(feature = "cuda")]
|
81
|
+
Ok(Self::Cuda)
|
82
|
+
}
|
83
|
+
|
84
|
+
/// Create a Metal device (Apple GPU)
|
85
|
+
pub fn metal() -> RbResult<Self> {
|
86
|
+
#[cfg(not(feature = "metal"))]
|
87
|
+
{
|
88
|
+
return Err(Error::new(
|
89
|
+
magnus::exception::runtime_error(),
|
90
|
+
"Metal support not compiled in. Rebuild on macOS.",
|
91
|
+
));
|
92
|
+
}
|
93
|
+
|
94
|
+
#[cfg(feature = "metal")]
|
95
|
+
Ok(Self::Metal)
|
96
|
+
}
|
97
|
+
|
98
|
+
pub fn from_device(device: &CoreDevice) -> Self {
|
99
|
+
match device {
|
100
|
+
CoreDevice::Cpu => Self::Cpu,
|
101
|
+
CoreDevice::Cuda(_) => Self::Cuda,
|
102
|
+
CoreDevice::Metal(_) => Self::Metal,
|
103
|
+
}
|
104
|
+
}
|
105
|
+
|
106
|
+
pub fn as_device(&self) -> RbResult<CoreDevice> {
|
107
|
+
match self {
|
108
|
+
Self::Cpu => Ok(CoreDevice::Cpu),
|
109
|
+
Self::Cuda => {
|
110
|
+
#[cfg(not(feature = "cuda"))]
|
111
|
+
{
|
112
|
+
return Err(Error::new(
|
113
|
+
magnus::exception::runtime_error(),
|
114
|
+
"CUDA support not compiled in. Rebuild with CUDA available.",
|
115
|
+
));
|
116
|
+
}
|
117
|
+
|
118
|
+
#[cfg(feature = "cuda")]
|
119
|
+
{
|
120
|
+
let mut device = CUDA_DEVICE.lock().unwrap();
|
121
|
+
if let Some(device) = device.as_ref() {
|
122
|
+
return Ok(device.clone());
|
123
|
+
};
|
124
|
+
// Note: new_cuda() is used here (not cuda_if_available) because
|
125
|
+
// we want to fail if CUDA isn't available at runtime, not fall back to CPU
|
126
|
+
let d = CoreDevice::new_cuda(0).map_err(wrap_candle_err)?;
|
127
|
+
*device = Some(d.clone());
|
128
|
+
Ok(d)
|
129
|
+
}
|
130
|
+
}
|
131
|
+
Self::Metal => {
|
132
|
+
#[cfg(not(feature = "metal"))]
|
133
|
+
{
|
134
|
+
return Err(Error::new(
|
135
|
+
magnus::exception::runtime_error(),
|
136
|
+
"Metal support not compiled in. Rebuild on macOS.",
|
137
|
+
));
|
138
|
+
}
|
139
|
+
|
140
|
+
#[cfg(feature = "metal")]
|
141
|
+
{
|
142
|
+
let mut device = METAL_DEVICE.lock().unwrap();
|
143
|
+
if let Some(device) = device.as_ref() {
|
144
|
+
return Ok(device.clone());
|
145
|
+
};
|
146
|
+
let d = CoreDevice::new_metal(0).map_err(wrap_candle_err)?;
|
147
|
+
*device = Some(d.clone());
|
148
|
+
Ok(d)
|
149
|
+
}
|
150
|
+
}
|
151
|
+
}
|
152
|
+
}
|
153
|
+
|
154
|
+
pub fn __repr__(&self) -> String {
|
155
|
+
match self {
|
156
|
+
Self::Cpu => "cpu".to_string(),
|
157
|
+
Self::Cuda => "cuda".to_string(),
|
158
|
+
Self::Metal => "metal".to_string(),
|
159
|
+
}
|
160
|
+
}
|
161
|
+
|
162
|
+
pub fn __str__(&self) -> String {
|
163
|
+
self.__repr__()
|
164
|
+
}
|
165
|
+
}
|
166
|
+
|
167
|
+
impl magnus::TryConvert for Device {
|
168
|
+
fn try_convert(val: magnus::Value) -> RbResult<Self> {
|
169
|
+
// First check if it's already a wrapped Device object
|
170
|
+
if let Ok(device) = <magnus::typed_data::Obj<Device> as magnus::TryConvert>::try_convert(val) {
|
171
|
+
return Ok(*device);
|
172
|
+
}
|
173
|
+
|
174
|
+
// Otherwise try to convert from string
|
175
|
+
let device = magnus::RString::try_convert(val)?;
|
176
|
+
let device = unsafe { device.as_str() }.unwrap();
|
177
|
+
let device = match device {
|
178
|
+
"cpu" => Device::Cpu,
|
179
|
+
"cuda" => Device::Cuda,
|
180
|
+
"metal" => Device::Metal,
|
181
|
+
_ => return Err(Error::new(magnus::exception::arg_error(), "invalid device")),
|
182
|
+
};
|
183
|
+
Ok(device)
|
184
|
+
}
|
185
|
+
}
|
186
|
+
|
187
|
+
pub fn init(rb_candle: RModule) -> Result<(), Error> {
|
188
|
+
let rb_device = rb_candle.define_class("Device", class::object())?;
|
189
|
+
rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
|
190
|
+
rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
|
191
|
+
rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
|
192
|
+
rb_device.define_singleton_method("available_devices", function!(available_devices, 0))?;
|
193
|
+
rb_device.define_singleton_method("default", function!(default_device, 0))?;
|
194
|
+
rb_device.define_method("to_s", method!(Device::__str__, 0))?;
|
195
|
+
rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
|
196
|
+
Ok(())
|
197
|
+
}
|
@@ -0,0 +1,37 @@
|
|
1
|
+
use magnus::value::ReprValue;
|
2
|
+
use magnus::{method, class, RModule, Error, Module};
|
3
|
+
|
4
|
+
use ::candle_core::DType as CoreDType;
|
5
|
+
use crate::ruby::Result as RbResult;
|
6
|
+
|
7
|
+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
8
|
+
#[magnus::wrap(class = "Candle::DType", free_immediately, size)]
|
9
|
+
|
10
|
+
/// A `candle` dtype.
|
11
|
+
pub struct DType(pub CoreDType);
|
12
|
+
|
13
|
+
impl DType {
|
14
|
+
pub fn __repr__(&self) -> String {
|
15
|
+
format!("{:?}", self.0)
|
16
|
+
}
|
17
|
+
|
18
|
+
pub fn __str__(&self) -> String {
|
19
|
+
self.__repr__()
|
20
|
+
}
|
21
|
+
}
|
22
|
+
|
23
|
+
impl DType {
|
24
|
+
pub fn from_rbobject(dtype: magnus::Symbol) -> RbResult<Self> {
|
25
|
+
let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
|
26
|
+
use std::str::FromStr;
|
27
|
+
let dtype = CoreDType::from_str(&dtype).unwrap();
|
28
|
+
Ok(Self(dtype))
|
29
|
+
}
|
30
|
+
}
|
31
|
+
|
32
|
+
pub fn init(rb_candle: RModule) -> Result<(), Error> {
|
33
|
+
let rb_dtype = rb_candle.define_class("DType", class::object())?;
|
34
|
+
rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
|
35
|
+
rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
|
36
|
+
Ok(())
|
37
|
+
}
|