red-candle 0.0.6 → 1.0.0.pre.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,267 @@
1
+ use magnus::{class, function, method, prelude::*, Error, RModule, Float, RArray};
2
+ use candle_transformers::models::bert::{BertModel, Config};
3
+ use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
4
+ use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
5
+ use hf_hub::{api::sync::Api, Repo, RepoType};
6
+ use tokenizers::{PaddingParams, Tokenizer, EncodeInput};
7
+ use std::thread;
8
+ use crate::ruby::{Device as RbDevice, Result as RbResult};
9
+
10
+ #[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
11
+ pub struct Reranker {
12
+ model: BertModel,
13
+ tokenizer: Tokenizer,
14
+ pooler: Linear,
15
+ classifier: Linear,
16
+ device: CoreDevice,
17
+ }
18
+
19
+ impl Reranker {
20
+ pub fn new(model_id: String, device: Option<RbDevice>) -> RbResult<Self> {
21
+ let device = device.unwrap_or(RbDevice::Cpu).as_device()?;
22
+ Self::new_with_core_device(model_id, device)
23
+ }
24
+
25
+ fn new_with_core_device(model_id: String, device: CoreDevice) -> Result<Self, Error> {
26
+ let device_clone = device.clone();
27
+ let handle = thread::spawn(move || -> Result<(BertModel, Tokenizer, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
28
+ let api = Api::new()?;
29
+ let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
30
+
31
+ // Download model files
32
+ let config_filename = repo.get("config.json")?;
33
+ let tokenizer_filename = repo.get("tokenizer.json")?;
34
+ let weights_filename = repo.get("model.safetensors")?;
35
+
36
+ // Load config
37
+ let config = std::fs::read_to_string(config_filename)?;
38
+ let config: Config = serde_json::from_str(&config)?;
39
+
40
+ // Setup tokenizer with padding
41
+ let mut tokenizer = Tokenizer::from_file(tokenizer_filename)?;
42
+ let pp = PaddingParams {
43
+ strategy: tokenizers::PaddingStrategy::BatchLongest,
44
+ ..Default::default()
45
+ };
46
+ tokenizer.with_padding(Some(pp));
47
+
48
+ // Load model weights
49
+ let vb = unsafe {
50
+ VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device_clone)?
51
+ };
52
+
53
+ // Load BERT model
54
+ let model = BertModel::load(vb.pp("bert"), &config)?;
55
+
56
+ // Load pooler layer (dense + tanh activation)
57
+ let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
58
+
59
+ // Load classifier layer for cross-encoder (single output score)
60
+ let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
61
+
62
+ Ok((model, tokenizer, pooler, classifier))
63
+ });
64
+
65
+ match handle.join() {
66
+ Ok(Ok((model, tokenizer, pooler, classifier))) => {
67
+ Ok(Self { model, tokenizer, pooler, classifier, device })
68
+ }
69
+ Ok(Err(e)) => Err(Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e))),
70
+ Err(_) => Err(Error::new(magnus::exception::runtime_error(), "Thread panicked while loading model")),
71
+ }
72
+ }
73
+
74
+ pub fn debug_tokenization(&self, query: String, document: String) -> Result<magnus::RHash, Error> {
75
+ // Create query-document pair for cross-encoder
76
+ let query_doc_pair: EncodeInput = (query.clone(), document.clone()).into();
77
+
78
+ // Tokenize
79
+ let encoding = self.tokenizer.encode(query_doc_pair, true)
80
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
81
+
82
+ // Get token information
83
+ let token_ids = encoding.get_ids().to_vec();
84
+ let token_type_ids = encoding.get_type_ids().to_vec();
85
+ let attention_mask = encoding.get_attention_mask().to_vec();
86
+ let tokens = encoding.get_tokens().iter().map(|t| t.to_string()).collect::<Vec<_>>();
87
+
88
+ // Create result hash
89
+ let result = magnus::RHash::new();
90
+ result.aset("token_ids", RArray::from_vec(token_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
91
+ result.aset("token_type_ids", RArray::from_vec(token_type_ids.iter().map(|&id| id as i64).collect::<Vec<_>>()))?;
92
+ result.aset("attention_mask", RArray::from_vec(attention_mask.iter().map(|&mask| mask as i64).collect::<Vec<_>>()))?;
93
+ result.aset("tokens", RArray::from_vec(tokens))?;
94
+
95
+ Ok(result)
96
+ }
97
+
98
+ pub fn rerank_with_options(&self, query: String, documents: RArray, pooling_method: String, apply_sigmoid: bool) -> Result<RArray, Error> {
99
+ let documents: Vec<String> = documents.to_vec()?;
100
+
101
+ // Create query-document pairs for cross-encoder
102
+ let query_and_docs: Vec<EncodeInput> = documents
103
+ .iter()
104
+ .map(|d| (query.clone(), d.clone()).into())
105
+ .collect();
106
+
107
+ // Tokenize batch
108
+ let encodings = self.tokenizer.encode_batch(query_and_docs, true)
109
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tokenization failed: {}", e)))?;
110
+
111
+ // Convert to tensors
112
+ let token_ids = encodings
113
+ .iter()
114
+ .map(|e| e.get_ids().to_vec())
115
+ .collect::<Vec<_>>();
116
+
117
+ let token_type_ids = encodings
118
+ .iter()
119
+ .map(|e| e.get_type_ids().to_vec())
120
+ .collect::<Vec<_>>();
121
+
122
+ let token_ids = Tensor::new(token_ids, &self.device)
123
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create tensor: {}", e)))?;
124
+ let token_type_ids = Tensor::new(token_type_ids, &self.device)
125
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create token type ids tensor: {}", e)))?;
126
+ let attention_mask = token_ids.ne(0u32)
127
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create attention mask: {}", e)))?;
128
+
129
+ // Forward pass through BERT
130
+ let embeddings = self.model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
131
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Model forward pass failed: {}", e)))?;
132
+
133
+ // Apply pooling based on the specified method
134
+ let pooled_embeddings = match pooling_method.as_str() {
135
+ "pooler" => {
136
+ // Extract [CLS] token and apply pooler (dense + tanh)
137
+ // Work around Metal indexing issue by using narrow instead of i((.., 0))
138
+ let cls_embeddings = if self.device.is_metal() {
139
+ // Metal has issues with tensor indexing, use a different approach
140
+ let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
141
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
142
+
143
+ // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
144
+ let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
145
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
146
+
147
+ // Extract CLS tokens (first token of each sequence)
148
+ let mut cls_vecs = Vec::new();
149
+ for i in 0..batch_size {
150
+ let start_idx = i * _seq_len;
151
+ let cls_vec = reshaped.narrow(0, start_idx, 1)
152
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
153
+ cls_vecs.push(cls_vec);
154
+ }
155
+
156
+ // Stack the CLS vectors
157
+ Tensor::cat(&cls_vecs, 0)
158
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
159
+ .contiguous()
160
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
161
+ } else {
162
+ embeddings.i((.., 0))
163
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
164
+ };
165
+ // Ensure tensor is contiguous before linear layer
166
+ let cls_embeddings = cls_embeddings.contiguous()
167
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make cls_embeddings contiguous: {}", e)))?;
168
+ let pooled = self.pooler.forward(&cls_embeddings)
169
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Pooler forward failed: {}", e)))?;
170
+ pooled.tanh()
171
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Tanh activation failed: {}", e)))?
172
+ },
173
+ "cls" => {
174
+ // Just use the [CLS] token embeddings directly (no pooler layer)
175
+ // Work around Metal indexing issue
176
+ let cls_embeddings = if self.device.is_metal() {
177
+ // Use same approach as pooler method
178
+ let (batch_size, _seq_len, hidden_size) = embeddings.dims3()
179
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get dims: {}", e)))?;
180
+
181
+ let reshaped = embeddings.reshape((batch_size * _seq_len, hidden_size))
182
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to reshape: {}", e)))?;
183
+
184
+ let mut cls_vecs = Vec::new();
185
+ for i in 0..batch_size {
186
+ let start_idx = i * _seq_len;
187
+ let cls_vec = reshaped.narrow(0, start_idx, 1)
188
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS: {}", e)))?;
189
+ cls_vecs.push(cls_vec);
190
+ }
191
+
192
+ Tensor::cat(&cls_vecs, 0)
193
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to cat CLS tokens: {}", e)))?
194
+ .contiguous()
195
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make contiguous: {}", e)))?
196
+ } else {
197
+ embeddings.i((.., 0))
198
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to extract CLS token: {}", e)))?
199
+ };
200
+ // Ensure contiguous for classifier
201
+ cls_embeddings.contiguous()
202
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make CLS embeddings contiguous: {}", e)))?
203
+ },
204
+ "mean" => {
205
+ // Mean pooling across all tokens
206
+ let (_batch, seq_len, _hidden) = embeddings.dims3()
207
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to get tensor dimensions: {}", e)))?;
208
+ let sum = embeddings.sum(1)
209
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to sum embeddings: {}", e)))?;
210
+ (sum / (seq_len as f64))
211
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to compute mean: {}", e)))?
212
+ },
213
+ _ => return Err(Error::new(magnus::exception::runtime_error(),
214
+ format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
215
+ };
216
+
217
+ // Apply classifier to get relevance scores (raw logits)
218
+ // Ensure tensor is contiguous before linear layer
219
+ let pooled_embeddings = pooled_embeddings.contiguous()
220
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
221
+ let logits = self.classifier.forward(&pooled_embeddings)
222
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Classifier forward failed: {}", e)))?;
223
+ let scores = logits.squeeze(1)
224
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to squeeze tensor: {}", e)))?;
225
+
226
+ // Optionally apply sigmoid activation
227
+ let scores = if apply_sigmoid {
228
+ sigmoid(&scores)
229
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Sigmoid failed: {}", e)))?
230
+ } else {
231
+ scores
232
+ };
233
+
234
+ let scores_vec: Vec<f32> = scores.to_vec1()
235
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to convert scores to vec: {}", e)))?;
236
+
237
+ // Create tuples with document, score, and original index
238
+ let mut ranked_docs: Vec<(String, f32, usize)> = documents
239
+ .into_iter()
240
+ .zip(scores_vec)
241
+ .enumerate()
242
+ .map(|(idx, (doc, score))| (doc, score, idx))
243
+ .collect();
244
+
245
+ // Sort documents by relevance score (descending)
246
+ ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
247
+
248
+ // Build result array with [doc, score, doc_id]
249
+ let result_array = RArray::new();
250
+ for (doc, score, doc_id) in ranked_docs {
251
+ let tuple = RArray::new();
252
+ tuple.push(doc)?;
253
+ tuple.push(Float::from_f64(score as f64))?;
254
+ tuple.push(doc_id)?;
255
+ result_array.push(tuple)?;
256
+ }
257
+ Ok(result_array)
258
+ }
259
+ }
260
+
261
+ pub fn init(rb_candle: RModule) -> Result<(), Error> {
262
+ let c_reranker = rb_candle.define_class("Reranker", class::object())?;
263
+ c_reranker.define_singleton_method("_create", function!(Reranker::new, 2))?;
264
+ c_reranker.define_method("rerank_with_options", method!(Reranker::rerank_with_options, 4))?;
265
+ c_reranker.define_method("debug_tokenization", method!(Reranker::debug_tokenization, 2))?;
266
+ Ok(())
267
+ }
@@ -0,0 +1,197 @@
1
+ use magnus::Error;
2
+ use magnus::{function, method, class, RModule, Module, Object};
3
+
4
+ use ::candle_core::Device as CoreDevice;
5
+ use crate::ruby::Result as RbResult;
6
+
7
+ #[cfg(any(feature = "cuda", feature = "metal"))]
8
+ use crate::ruby::errors::wrap_candle_err;
9
+
10
/// CUDA device (ordinal 0), cached after it is first created successfully.
/// Later conversions clone it instead of creating a new device.
#[cfg(feature = "cuda")]
static CUDA_DEVICE: std::sync::Mutex<Option<CoreDevice>> = std::sync::Mutex::new(None);

/// Metal device (ordinal 0), cached the same way as `CUDA_DEVICE`.
#[cfg(feature = "metal")]
static METAL_DEVICE: std::sync::Mutex<Option<CoreDevice>> = std::sync::Mutex::new(None);
15
+
16
/// Lists the device names this build supports, decided purely at compile time.
///
/// `"cpu"` always comes first, followed by `"cuda"` and then `"metal"` when those
/// features are enabled. Setting the `force_cpu` cfg drops both GPU entries.
pub fn available_devices() -> Vec<String> {
    let candidates = [
        Some("cpu"),
        cfg!(all(feature = "cuda", not(force_cpu))).then_some("cuda"),
        cfg!(all(feature = "metal", not(force_cpu))).then_some("metal"),
    ];
    candidates
        .iter()
        .flatten()
        .map(|&name| name.to_string())
        .collect()
}
36
+
37
+ /// Get the default device based on what's available
38
+ pub fn default_device() -> Device {
39
+ // Return based on compiled features, not detection
40
+ #[cfg(all(feature = "metal", not(force_cpu)))]
41
+ {
42
+ Device::Metal
43
+ }
44
+
45
+ #[cfg(all(feature = "cuda", not(feature = "metal"), not(force_cpu)))]
46
+ {
47
+ Device::Cuda
48
+ }
49
+
50
+ #[cfg(not(any(all(feature = "metal", not(force_cpu)), all(feature = "cuda", not(feature = "metal"), not(force_cpu)))))]
51
+ {
52
+ Device::Cpu
53
+ }
54
+ }
55
+
56
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
57
+ #[magnus::wrap(class = "Candle::Device")]
58
+ pub enum Device {
59
+ Cpu,
60
+ Cuda,
61
+ Metal,
62
+ }
63
+
64
+ impl Device {
65
+ /// Create a CPU device
66
+ pub fn cpu() -> Self {
67
+ Self::Cpu
68
+ }
69
+
70
+ /// Create a CUDA device (GPU)
71
+ pub fn cuda() -> RbResult<Self> {
72
+ #[cfg(not(feature = "cuda"))]
73
+ {
74
+ return Err(Error::new(
75
+ magnus::exception::runtime_error(),
76
+ "CUDA support not compiled in. Rebuild with CUDA available.",
77
+ ));
78
+ }
79
+
80
+ #[cfg(feature = "cuda")]
81
+ Ok(Self::Cuda)
82
+ }
83
+
84
+ /// Create a Metal device (Apple GPU)
85
+ pub fn metal() -> RbResult<Self> {
86
+ #[cfg(not(feature = "metal"))]
87
+ {
88
+ return Err(Error::new(
89
+ magnus::exception::runtime_error(),
90
+ "Metal support not compiled in. Rebuild on macOS.",
91
+ ));
92
+ }
93
+
94
+ #[cfg(feature = "metal")]
95
+ Ok(Self::Metal)
96
+ }
97
+
98
+ pub fn from_device(device: &CoreDevice) -> Self {
99
+ match device {
100
+ CoreDevice::Cpu => Self::Cpu,
101
+ CoreDevice::Cuda(_) => Self::Cuda,
102
+ CoreDevice::Metal(_) => Self::Metal,
103
+ }
104
+ }
105
+
106
+ pub fn as_device(&self) -> RbResult<CoreDevice> {
107
+ match self {
108
+ Self::Cpu => Ok(CoreDevice::Cpu),
109
+ Self::Cuda => {
110
+ #[cfg(not(feature = "cuda"))]
111
+ {
112
+ return Err(Error::new(
113
+ magnus::exception::runtime_error(),
114
+ "CUDA support not compiled in. Rebuild with CUDA available.",
115
+ ));
116
+ }
117
+
118
+ #[cfg(feature = "cuda")]
119
+ {
120
+ let mut device = CUDA_DEVICE.lock().unwrap();
121
+ if let Some(device) = device.as_ref() {
122
+ return Ok(device.clone());
123
+ };
124
+ // Note: new_cuda() is used here (not cuda_if_available) because
125
+ // we want to fail if CUDA isn't available at runtime, not fall back to CPU
126
+ let d = CoreDevice::new_cuda(0).map_err(wrap_candle_err)?;
127
+ *device = Some(d.clone());
128
+ Ok(d)
129
+ }
130
+ }
131
+ Self::Metal => {
132
+ #[cfg(not(feature = "metal"))]
133
+ {
134
+ return Err(Error::new(
135
+ magnus::exception::runtime_error(),
136
+ "Metal support not compiled in. Rebuild on macOS.",
137
+ ));
138
+ }
139
+
140
+ #[cfg(feature = "metal")]
141
+ {
142
+ let mut device = METAL_DEVICE.lock().unwrap();
143
+ if let Some(device) = device.as_ref() {
144
+ return Ok(device.clone());
145
+ };
146
+ let d = CoreDevice::new_metal(0).map_err(wrap_candle_err)?;
147
+ *device = Some(d.clone());
148
+ Ok(d)
149
+ }
150
+ }
151
+ }
152
+ }
153
+
154
+ pub fn __repr__(&self) -> String {
155
+ match self {
156
+ Self::Cpu => "cpu".to_string(),
157
+ Self::Cuda => "cuda".to_string(),
158
+ Self::Metal => "metal".to_string(),
159
+ }
160
+ }
161
+
162
+ pub fn __str__(&self) -> String {
163
+ self.__repr__()
164
+ }
165
+ }
166
+
167
+ impl magnus::TryConvert for Device {
168
+ fn try_convert(val: magnus::Value) -> RbResult<Self> {
169
+ // First check if it's already a wrapped Device object
170
+ if let Ok(device) = <magnus::typed_data::Obj<Device> as magnus::TryConvert>::try_convert(val) {
171
+ return Ok(*device);
172
+ }
173
+
174
+ // Otherwise try to convert from string
175
+ let device = magnus::RString::try_convert(val)?;
176
+ let device = unsafe { device.as_str() }.unwrap();
177
+ let device = match device {
178
+ "cpu" => Device::Cpu,
179
+ "cuda" => Device::Cuda,
180
+ "metal" => Device::Metal,
181
+ _ => return Err(Error::new(magnus::exception::arg_error(), "invalid device")),
182
+ };
183
+ Ok(device)
184
+ }
185
+ }
186
+
187
+ pub fn init(rb_candle: RModule) -> Result<(), Error> {
188
+ let rb_device = rb_candle.define_class("Device", class::object())?;
189
+ rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
190
+ rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
191
+ rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
192
+ rb_device.define_singleton_method("available_devices", function!(available_devices, 0))?;
193
+ rb_device.define_singleton_method("default", function!(default_device, 0))?;
194
+ rb_device.define_method("to_s", method!(Device::__str__, 0))?;
195
+ rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
196
+ Ok(())
197
+ }
@@ -0,0 +1,37 @@
1
+ use magnus::value::ReprValue;
2
+ use magnus::{method, class, RModule, Error, Module};
3
+
4
+ use ::candle_core::DType as CoreDType;
5
+ use crate::ruby::Result as RbResult;
6
+
7
+ #[derive(Clone, Copy, Debug, PartialEq, Eq)]
8
+ #[magnus::wrap(class = "Candle::DType", free_immediately, size)]
9
+
10
+ /// A `candle` dtype.
11
+ pub struct DType(pub CoreDType);
12
+
13
+ impl DType {
14
+ pub fn __repr__(&self) -> String {
15
+ format!("{:?}", self.0)
16
+ }
17
+
18
+ pub fn __str__(&self) -> String {
19
+ self.__repr__()
20
+ }
21
+ }
22
+
23
+ impl DType {
24
+ pub fn from_rbobject(dtype: magnus::Symbol) -> RbResult<Self> {
25
+ let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
26
+ use std::str::FromStr;
27
+ let dtype = CoreDType::from_str(&dtype).unwrap();
28
+ Ok(Self(dtype))
29
+ }
30
+ }
31
+
32
+ pub fn init(rb_candle: RModule) -> Result<(), Error> {
33
+ let rb_dtype = rb_candle.define_class("DType", class::object())?;
34
+ rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
35
+ rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
36
+ Ok(())
37
+ }