red-candle 1.7.0 → 1.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 3d1b83311f0ad99adaffb886efda12291e9499ba112742e471e7b1eace390cee
4
- data.tar.gz: 6156968c937204767fda21adfc6069e55a9571018f064ac56821d61472cc9cc1
3
+ metadata.gz: 2a4feba5be08ae2cb7a70b4e61671a5cd434c65a1fa29d98b77841187caf2893
4
+ data.tar.gz: 4c453e5cbf48837fa3b68af52f1f409ddaf039fe3f0b49e13a3c66693614f18f
5
5
  SHA512:
6
- metadata.gz: 57abc7d285ebb3c67c438d861563247ec1130b03bee5fab418670bd926de58ad2bc06b2328732c362f46a4f8a8ccb7d85d7bd0e4b7a384838e290dadb7ed2be4
7
- data.tar.gz: 2f10da2df023629d59b2c2454efa886a87e0738d99801e6bb25b951431856b083780167190a6ba735c3c51cb8d170ed16a9c163f19aaf44fc8b95a72f8ea4b5e
6
+ metadata.gz: 9c5b31367d0efd7d99ec866ab44615d94cbbd98cd401edb3d8393b892fdac0d1d3657a2903c02f60b20b2f8e3649a3eeac29d9f6348f3a0a51ec1fd97ad60757
7
+ data.tar.gz: 60da5b3c14e101b44e53321276c17e3a80074fb56e5e927469908ffb1deb17443b215c317405c526bb79a2d22df9e8837a32a3f7d486851b89e51a9cdc2ccf77
@@ -0,0 +1,58 @@
//! GVL (Global VM Lock) release support for Ruby.
//!
//! Ruby's GVL prevents other Ruby threads from running while native code
//! executes. For long-running operations (LLM inference, reranking, embedding),
//! we release the GVL so other threads (TUI render loops, HTTP servers, etc.)
//! can run concurrently.
//!
//! SAFETY: Code running without the GVL must NOT call any Ruby API.

use std::os::raw::c_void;
use std::panic::{self, AssertUnwindSafe};

/// C signature of Ruby's `rb_unblock_function_t`.
type UnblockFn = unsafe extern "C" fn(*mut c_void);

extern "C" {
    // `rb_thread_call_without_gvl2` is used instead of `rb_thread_call_without_gvl`
    // on purpose. The latter checks for pending interrupts (Thread#raise,
    // Thread#kill, Timeout, Ctrl-C) before and after running `func` and may
    // *raise* at those points. A Ruby raise is a longjmp, which would skip the
    // destructors of the Rust frames between here and the enclosing method
    // boundary (e.g. a `RefCell` borrow held across the call would never be
    // released). The `2` variant never raises: if an interrupt is already pending
    // it returns without calling `func`, and it leaves interrupts to be handled
    // once control is back in Ruby.
    fn rb_thread_call_without_gvl2(
        func: unsafe extern "C" fn(*mut c_void) -> *mut c_void,
        data1: *mut c_void,
        ubf: Option<UnblockFn>,
        data2: *mut c_void,
    ) -> *mut c_void;
}

/// Run a closure without the GVL and return its result.
///
/// The closure runs synchronously on the calling thread while other Ruby
/// threads are free to run. It must not call any Ruby API or touch Ruby
/// objects. It also must not rely on the GVL to serialize access to shared
/// Rust state, since another Ruby thread may call into the same object while
/// the closure runs.
///
/// No unblock function is registered, so the closure itself cannot be
/// interrupted. Interrupts raised against this thread are handled after the
/// native method returns to Ruby. If an interrupt is already pending on entry,
/// Ruby does not release the GVL, and the closure runs with the GVL held
/// instead. The result is the same, only without the concurrency.
///
/// # Panics
///
/// A panic inside `f` is caught before it can unwind through Ruby's C frames.
/// Unwinding out of an `extern "C"` function is undefined behaviour on older
/// toolchains and aborts on newer ones. The panic is resumed here, in Rust
/// frames, after the GVL has been reacquired.
pub fn without_gvl<F, R>(f: F) -> R
where
    F: FnOnce() -> R,
{
    struct CallData<F, R> {
        func: Option<F>,
        result: Option<std::thread::Result<R>>,
    }

    unsafe extern "C" fn call_func<F, R>(data: *mut c_void) -> *mut c_void
    where
        F: FnOnce() -> R,
    {
        // SAFETY: `data` is the `CallData<F, R>` that `without_gvl` passed in;
        // it outlives this call and nothing else accesses it meanwhile.
        let data = &mut *(data as *mut CallData<F, R>);
        if let Some(func) = data.func.take() {
            // Never let a panic escape this `extern "C"` frame; hand it back
            // to `without_gvl` instead.
            data.result = Some(panic::catch_unwind(AssertUnwindSafe(func)));
        }
        std::ptr::null_mut()
    }

    let mut data = CallData {
        func: Some(f),
        result: None,
    };

    // SAFETY: `call_func::<F, R>` is instantiated with the same `F`/`R` as the
    // `CallData<F, R>` passed as `data1`, and `data` lives in this frame until
    // the call returns. `rb_thread_call_without_gvl2` does not raise, so this
    // frame cannot be skipped by a longjmp.
    unsafe {
        rb_thread_call_without_gvl2(
            call_func::<F, R>,
            &mut data as *mut CallData<F, R> as *mut c_void,
            None,
            std::ptr::null_mut(),
        );
    }

    match data.result {
        Some(Ok(value)) => value,
        Some(Err(payload)) => panic::resume_unwind(payload),
        // An interrupt was pending, so Ruby returned without calling
        // `call_func`. Do the work here with the GVL still held; the pending
        // interrupt is handled once we are back in Ruby.
        None => {
            let func = data
                .func
                .take()
                .expect("without_gvl: closure was neither run nor kept");
            func()
        }
    }
}
@@ -4,6 +4,7 @@ use crate::ruby::candle_utils;
4
4
  use crate::ruby::utils::ensure_hf_cache_dir;
5
5
  use crate::ruby::Result;
6
6
 
7
+ pub mod gvl;
7
8
  pub mod llm;
8
9
  pub mod ruby;
9
10
  pub mod structured;
@@ -104,14 +104,15 @@ impl EmbeddingModel {
104
104
  /// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
105
105
  pub fn embedding(&self, input: String, pooling_method: String) -> Result<Tensor> {
106
106
  let ruby = Ruby::get().unwrap();
107
- match &self.0.model {
108
- Some(model) => {
109
- match &self.0.tokenizer {
110
- Some(tokenizer) => Ok(Tensor(self.compute_embedding(input, model, tokenizer, &pooling_method)?)),
111
- None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found"))
112
- }
107
+ match (&self.0.model, &self.0.tokenizer) {
108
+ (Some(model), Some(tokenizer)) => {
109
+ let result = crate::gvl::without_gvl(|| {
110
+ self.compute_embedding(input, model, tokenizer, &pooling_method)
111
+ });
112
+ Ok(Tensor(result?))
113
113
  }
114
- None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found"))
114
+ (None, _) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found")),
115
+ (_, None) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found")),
115
116
  }
116
117
  }
117
118
 
@@ -119,14 +120,15 @@ impl EmbeddingModel {
119
120
  /// &RETURNS&: Tensor
120
121
  pub fn embeddings(&self, input: String) -> Result<Tensor> {
121
122
  let ruby = Ruby::get().unwrap();
122
- match &self.0.model {
123
- Some(model) => {
124
- match &self.0.tokenizer {
125
- Some(tokenizer) => Ok(Tensor(self.compute_embeddings(input, model, tokenizer)?)),
126
- None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found"))
127
- }
123
+ match (&self.0.model, &self.0.tokenizer) {
124
+ (Some(model), Some(tokenizer)) => {
125
+ let result = crate::gvl::without_gvl(|| {
126
+ self.compute_embeddings(input, model, tokenizer)
127
+ });
128
+ Ok(Tensor(result?))
128
129
  }
129
- None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found"))
130
+ (None, _) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found")),
131
+ (_, None) => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found")),
130
132
  }
131
133
  }
132
134
 
@@ -5,6 +5,7 @@ use std::sync::Arc;
5
5
  use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, qwen::Qwen as RustQwen, qwen3::Qwen3 as RustQwen3, phi::Phi as RustPhi, granite::Granite as RustGranite, granitemoehybrid::GraniteMoeHybrid as RustGraniteMoeHybrid, glm4::Glm4 as RustGlm4, QuantizedGGUF as RustQuantizedGGUF};
6
6
  use crate::ruby::{Result, Device};
7
7
  use crate::ruby::structured::StructuredConstraint;
8
+ use crate::gvl;
8
9
 
9
10
  // Use an enum to handle different model types instead of trait objects
10
11
  enum ModelType {
@@ -422,7 +423,7 @@ impl LLM {
422
423
  })
423
424
  }
424
425
 
425
- /// Generate text from a prompt
426
+ /// Generate text from a prompt (releases GVL during inference)
426
427
  pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
427
428
  let ruby = Ruby::get().unwrap();
428
429
  let config = config
@@ -435,8 +436,13 @@ impl LLM {
435
436
  };
436
437
  let mut model_ref = model.borrow_mut();
437
438
 
438
- model_ref.generate(&prompt, &config)
439
- .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Generation failed: {}", e)))
439
+ // Release the GVL during inference so other Ruby threads can run
440
+ // (e.g., TUI render loops, HTTP servers, etc.)
441
+ let result = gvl::without_gvl(|| {
442
+ model_ref.generate(&prompt, &config)
443
+ });
444
+
445
+ result.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Generation failed: {}", e)))
440
446
  }
441
447
 
442
448
  /// Generate text with streaming output
@@ -173,8 +173,10 @@ impl NER {
173
173
  let ruby = Ruby::get().unwrap();
174
174
  let threshold = confidence_threshold.unwrap_or(0.9) as f32;
175
175
 
176
- // Use common tokenization and prediction logic
177
- let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
176
+ // Release GVL during tokenization + model forward pass
177
+ let (encoding, probs_vec) = crate::gvl::without_gvl(|| {
178
+ self.tokenize_and_predict(&text)
179
+ })?;
178
180
 
179
181
  let tokens = encoding.get_tokens();
180
182
  let offsets = encoding.get_offsets();
@@ -208,8 +210,10 @@ impl NER {
208
210
  /// Get token-level predictions with labels and confidence scores
209
211
  pub fn predict_tokens(&self, text: String) -> Result<RArray> {
210
212
  let ruby = Ruby::get().unwrap();
211
- // Use common tokenization and prediction logic
212
- let (encoding, probs_vec) = self.tokenize_and_predict(&text)?;
213
+ // Release GVL during tokenization + model forward pass
214
+ let (encoding, probs_vec) = crate::gvl::without_gvl(|| {
215
+ self.tokenize_and_predict(&text)
216
+ })?;
213
217
 
214
218
  let tokens = encoding.get_tokens();
215
219
 
@@ -18,6 +18,7 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
18
18
  use tokenizers::{EncodeInput, Tokenizer};
19
19
  use std::cell::RefCell;
20
20
  use crate::ruby::{Device, Result};
21
+ use crate::gvl;
21
22
  use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
22
23
 
23
24
  enum RerankerModel {
@@ -164,39 +165,27 @@ impl Reranker {
164
165
  }
165
166
 
166
167
  /// Extract CLS embeddings from the model output, handling Metal device workarounds
167
- fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, Error> {
168
- let ruby = Ruby::get().unwrap();
169
- let runtime_error = ruby.exception_runtime_error();
170
-
168
+ fn extract_cls_embeddings(&self, embeddings: &Tensor) -> std::result::Result<Tensor, String> {
171
169
  let cls_embeddings = if self.device.is_metal() {
172
- // Metal has issues with tensor indexing, use a different approach
173
170
  let (batch_size, seq_len, hidden_size) = embeddings.dims3()
174
- .map_err(|e| Error::new(runtime_error, format!("Failed to get dims: {}", e)))?;
175
-
176
- // Reshape to [batch * seq_len, hidden] then take first hidden vectors for each batch
171
+ .map_err(|e| format!("Failed to get dims: {}", e))?;
177
172
  let reshaped = embeddings.reshape((batch_size * seq_len, hidden_size))
178
- .map_err(|e| Error::new(runtime_error, format!("Failed to reshape: {}", e)))?;
179
-
180
- // Extract CLS tokens (first token of each sequence)
173
+ .map_err(|e| format!("Failed to reshape: {}", e))?;
181
174
  let mut cls_vecs = Vec::new();
182
175
  for i in 0..batch_size {
183
176
  let start_idx = i * seq_len;
184
177
  let cls_vec = reshaped.narrow(0, start_idx, 1)
185
- .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS: {}", e)))?;
178
+ .map_err(|e| format!("Failed to extract CLS: {}", e))?;
186
179
  cls_vecs.push(cls_vec);
187
180
  }
188
-
189
- // Stack the CLS vectors
190
181
  Tensor::cat(&cls_vecs, 0)
191
- .map_err(|e| Error::new(runtime_error, format!("Failed to cat CLS tokens: {}", e)))?
182
+ .map_err(|e| format!("Failed to cat CLS tokens: {}", e))?
192
183
  } else {
193
184
  embeddings.i((.., 0))
194
- .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS token: {}", e)))?
185
+ .map_err(|e| format!("Failed to extract CLS token: {}", e))?
195
186
  };
196
-
197
- // Ensure tensor is contiguous for downstream operations
198
187
  cls_embeddings.contiguous()
199
- .map_err(|e| Error::new(runtime_error, format!("Failed to make CLS embeddings contiguous: {}", e)))
188
+ .map_err(|e| format!("Failed to make CLS embeddings contiguous: {}", e))
200
189
  }
201
190
 
202
191
  pub fn debug_tokenization(&self, query: String, document: String) -> std::result::Result<RHash, Error> {
@@ -231,124 +220,147 @@ impl Reranker {
231
220
  let runtime_error = ruby.exception_runtime_error();
232
221
  let documents: Vec<String> = documents.to_vec()?;
233
222
 
223
+ // Release the GVL for the entire compute portion (tokenization + inference + scoring).
224
+ // None of this calls Ruby API.
225
+ let ranked_docs = gvl::without_gvl(|| -> std::result::Result<Vec<(String, f32, usize)>, String> {
226
+ self.compute_rerank(&query, &documents, &pooling_method, apply_sigmoid)
227
+ });
228
+
229
+ let ranked_docs = ranked_docs
230
+ .map_err(|e| Error::new(runtime_error, e))?;
231
+
232
+ // Build result array (requires GVL for Ruby object creation)
233
+ let result_array = ruby.ary_new();
234
+ for (doc, score, doc_id) in ranked_docs {
235
+ let tuple = ruby.ary_new();
236
+ tuple.push(doc)?;
237
+ tuple.push(ruby.float_from_f64(score as f64))?;
238
+ tuple.push(doc_id)?;
239
+ result_array.push(tuple)?;
240
+ }
241
+ Ok(result_array)
242
+ }
243
+
244
+ /// Pure compute portion of reranking — no Ruby API calls.
245
+ /// Returns ranked (document, score, original_index) tuples.
246
+ fn compute_rerank(&self, query: &str, documents: &[String], pooling_method: &str, apply_sigmoid: bool) -> std::result::Result<Vec<(String, f32, usize)>, String> {
234
247
  // Create query-document pairs for cross-encoder
235
248
  let query_and_docs: Vec<EncodeInput> = documents
236
249
  .iter()
237
- .map(|d| (query.clone(), d.clone()).into())
250
+ .map(|d| (query.to_string(), d.clone()).into())
238
251
  .collect();
239
252
 
240
- // Tokenize batch using inner tokenizer for access to token type IDs
253
+ // Tokenize batch
241
254
  let encodings = self.tokenizer.inner().encode_batch(query_and_docs, true)
242
- .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
255
+ .map_err(|e| format!("Tokenization failed: {}", e))?;
243
256
 
244
- // Convert to tensors
245
- let token_ids = encodings
257
+ let token_ids_vec = encodings
246
258
  .iter()
247
259
  .map(|e| e.get_ids().to_vec())
248
260
  .collect::<Vec<_>>();
249
261
 
250
- let token_type_ids = encodings
262
+ let token_type_ids_vec = encodings
251
263
  .iter()
252
264
  .map(|e| e.get_type_ids().to_vec())
253
265
  .collect::<Vec<_>>();
254
266
 
255
- let token_ids = Tensor::new(token_ids, &self.device)
256
- .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?;
257
- let token_type_ids = Tensor::new(token_type_ids, &self.device)
258
- .map_err(|e| Error::new(runtime_error, format!("Failed to create token type ids tensor: {}", e)))?;
267
+ let token_ids = Tensor::new(token_ids_vec, &self.device)
268
+ .map_err(|e| format!("Failed to create tensor: {}", e))?;
269
+ let token_type_ids = Tensor::new(token_type_ids_vec, &self.device)
270
+ .map_err(|e| format!("Failed to create token type ids tensor: {}", e))?;
259
271
 
260
272
  // Compute scores based on model type
261
273
  let scores = match &self.model {
262
274
  RerankerModel::Bert { model, pooler, classifier } => {
263
275
  let attention_mask = token_ids.ne(0u32)
264
- .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
276
+ .map_err(|e| format!("Failed to create attention mask: {}", e))?;
265
277
 
266
278
  // Forward pass through BERT
267
279
  let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
268
- .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
280
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
269
281
 
270
282
  // Apply pooling based on the specified method
271
- let pooled_embeddings = match pooling_method.as_str() {
283
+ let pooled_embeddings = match pooling_method {
272
284
  "pooler" => {
273
285
  let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
274
286
  let pooled = pooler.forward(&cls_embeddings)
275
- .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
287
+ .map_err(|e| format!("Pooler forward failed: {}", e))?;
276
288
  pooled.tanh()
277
- .map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
289
+ .map_err(|e| format!("Tanh activation failed: {}", e))?
278
290
  },
279
291
  "cls" => {
280
292
  self.extract_cls_embeddings(&embeddings)?
281
293
  },
282
294
  "mean" => {
283
295
  let (_batch, seq_len, _hidden) = embeddings.dims3()
284
- .map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
296
+ .map_err(|e| format!("Failed to get tensor dimensions: {}", e))?;
285
297
  let sum = embeddings.sum(1)
286
- .map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
298
+ .map_err(|e| format!("Failed to sum embeddings: {}", e))?;
287
299
  (sum / (seq_len as f64))
288
- .map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
300
+ .map_err(|e| format!("Failed to compute mean: {}", e))?
289
301
  },
290
- _ => return Err(Error::new(runtime_error,
291
- format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
302
+ _ => return Err(
303
+ format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method))
292
304
  };
293
305
 
294
306
  let pooled_embeddings = pooled_embeddings.contiguous()
295
- .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
307
+ .map_err(|e| format!("Failed to make pooled_embeddings contiguous: {}", e))?;
296
308
  let logits = classifier.forward(&pooled_embeddings)
297
- .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
309
+ .map_err(|e| format!("Classifier forward failed: {}", e))?;
298
310
  logits.squeeze(1)
299
- .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
311
+ .map_err(|e| format!("Failed to squeeze tensor: {}", e))?
300
312
  }
301
313
  RerankerModel::XLMRoberta { model, pad_token_id } => {
302
314
  let attention_mask = token_ids.ne(*pad_token_id)
303
- .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
315
+ .map_err(|e| format!("Failed to create attention mask: {}", e))?;
304
316
 
305
317
  // XLMRobertaForSequenceClassification returns logits directly
306
318
  let logits = model.forward(&token_ids, &attention_mask, &token_type_ids)
307
- .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
319
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
308
320
  logits.squeeze(1)
309
- .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
321
+ .map_err(|e| format!("Failed to squeeze tensor: {}", e))?
310
322
  }
311
323
  RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id } => {
312
324
  let attention_mask = token_ids.ne(*pad_token_id)
313
- .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
325
+ .map_err(|e| format!("Failed to create attention mask: {}", e))?;
314
326
 
315
327
  // Forward through DeBERTa encoder
316
328
  let encoder_output = model.forward(&token_ids, Some(token_type_ids.clone()), Some(attention_mask))
317
- .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
329
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
318
330
 
319
331
  // Pool and classify
320
332
  let pooled = pooler.forward(&encoder_output)
321
- .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
333
+ .map_err(|e| format!("Pooler forward failed: {}", e))?;
322
334
  let logits = classifier.forward(&pooled)
323
- .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
335
+ .map_err(|e| format!("Classifier forward failed: {}", e))?;
324
336
  logits.squeeze(1)
325
- .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
337
+ .map_err(|e| format!("Failed to squeeze tensor: {}", e))?
326
338
  }
327
339
  RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id } => {
328
340
  let attention_mask = token_ids.ne(*pad_token_id)
329
- .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
341
+ .map_err(|e| format!("Failed to create attention mask: {}", e))?;
330
342
  let attention_mask_f32 = attention_mask.to_dtype(DType::F32)
331
- .map_err(|e| Error::new(runtime_error, format!("Failed to convert attention mask: {}", e)))?;
343
+ .map_err(|e| format!("Failed to convert attention mask: {}", e))?;
332
344
 
333
345
  // Forward through ModernBERT encoder
334
346
  let encoder_output = model.forward(&token_ids, &attention_mask_f32)
335
- .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
347
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
336
348
 
337
349
  // CLS pooling, then head (dense + GELU + norm) + classifier
338
350
  let cls = encoder_output.i((.., 0, ..))
339
- .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS: {}", e)))?
351
+ .map_err(|e| format!("Failed to extract CLS: {}", e))?
340
352
  .contiguous()
341
- .map_err(|e| Error::new(runtime_error, format!("Failed to make contiguous: {}", e)))?;
353
+ .map_err(|e| format!("Failed to make contiguous: {}", e))?;
342
354
  let hidden = head_dense.forward(&cls)
343
- .map_err(|e| Error::new(runtime_error, format!("Head dense failed: {}", e)))?;
355
+ .map_err(|e| format!("Head dense failed: {}", e))?;
344
356
  let hidden = hidden.gelu_erf()
345
- .map_err(|e| Error::new(runtime_error, format!("GELU activation failed: {}", e)))?;
357
+ .map_err(|e| format!("GELU activation failed: {}", e))?;
346
358
  let hidden = head_norm.forward(&hidden)
347
- .map_err(|e| Error::new(runtime_error, format!("Head norm failed: {}", e)))?;
359
+ .map_err(|e| format!("Head norm failed: {}", e))?;
348
360
  let logits = classifier.forward(&hidden)
349
- .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
361
+ .map_err(|e| format!("Classifier forward failed: {}", e))?;
350
362
  logits.squeeze(1)
351
- .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
363
+ .map_err(|e| format!("Failed to squeeze tensor: {}", e))?
352
364
  }
353
365
  RerankerModel::Qwen3 { model, yes_token_id, no_token_id } => {
354
366
  // Qwen3 reranker: decoder-based yes/no scoring
@@ -356,7 +368,7 @@ impl Reranker {
356
368
  let mut scores_vec: Vec<f32> = Vec::with_capacity(documents.len());
357
369
  let mut model = model.borrow_mut();
358
370
 
359
- for doc in &documents {
371
+ for doc in documents.iter() {
360
372
  // Build the Qwen3 reranker prompt
361
373
  let prompt = format!(
362
374
  "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {}\n<Document>: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
@@ -365,7 +377,7 @@ impl Reranker {
365
377
 
366
378
  // Tokenize the prompt
367
379
  let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
368
- .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
380
+ .map_err(|e| format!("Tokenization failed: {}", e))?;
369
381
  let input_ids: Vec<u32> = encoding.get_ids().to_vec();
370
382
 
371
383
  // Clear KV cache for each document
@@ -373,28 +385,28 @@ impl Reranker {
373
385
 
374
386
  // Forward pass — get logits for the last token position
375
387
  let input_tensor = Tensor::new(&input_ids[..], &self.device)
376
- .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?
388
+ .map_err(|e| format!("Failed to create tensor: {}", e))?
377
389
  .unsqueeze(0)
378
- .map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze: {}", e)))?;
390
+ .map_err(|e| format!("Failed to unsqueeze: {}", e))?;
379
391
 
380
392
  let logits = model.forward(&input_tensor, 0)
381
- .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
393
+ .map_err(|e| format!("Model forward pass failed: {}", e))?;
382
394
 
383
395
  // logits shape: [1, 1, vocab_size] → flatten to [vocab_size]
384
396
  let logits = logits.flatten_all()
385
- .map_err(|e| Error::new(runtime_error, format!("Failed to flatten: {}", e)))?
397
+ .map_err(|e| format!("Failed to flatten: {}", e))?
386
398
  .to_dtype(DType::F32)
387
- .map_err(|e| Error::new(runtime_error, format!("Failed to convert dtype: {}", e)))?;
399
+ .map_err(|e| format!("Failed to convert dtype: {}", e))?;
388
400
 
389
401
  // Extract yes/no logits and compute score
390
402
  let yes_logit: f32 = logits.i(*yes_token_id as usize)
391
- .map_err(|e| Error::new(runtime_error, format!("Failed to get yes logit: {}", e)))?
403
+ .map_err(|e| format!("Failed to get yes logit: {}", e))?
392
404
  .to_scalar()
393
- .map_err(|e| Error::new(runtime_error, format!("Failed to convert yes logit: {}", e)))?;
405
+ .map_err(|e| format!("Failed to convert yes logit: {}", e))?;
394
406
  let no_logit: f32 = logits.i(*no_token_id as usize)
395
- .map_err(|e| Error::new(runtime_error, format!("Failed to get no logit: {}", e)))?
407
+ .map_err(|e| format!("Failed to get no logit: {}", e))?
396
408
  .to_scalar()
397
- .map_err(|e| Error::new(runtime_error, format!("Failed to convert no logit: {}", e)))?;
409
+ .map_err(|e| format!("Failed to convert no logit: {}", e))?;
398
410
 
399
411
  // softmax over [yes, no] → P(yes)
400
412
  let max_logit = yes_logit.max(no_logit);
@@ -407,24 +419,25 @@ impl Reranker {
407
419
 
408
420
  // Build scores tensor for uniform handling below
409
421
  Tensor::new(scores_vec.as_slice(), &self.device)
410
- .map_err(|e| Error::new(runtime_error, format!("Failed to create scores tensor: {}", e)))?
422
+ .map_err(|e| format!("Failed to create scores tensor: {}", e))?
411
423
  }
412
424
  };
413
425
 
414
426
  // Optionally apply sigmoid activation
415
427
  let scores = if apply_sigmoid {
416
428
  sigmoid(&scores)
417
- .map_err(|e| Error::new(runtime_error, format!("Sigmoid failed: {}", e)))?
429
+ .map_err(|e| format!("Sigmoid failed: {}", e))?
418
430
  } else {
419
431
  scores
420
432
  };
421
433
 
422
434
  let scores_vec: Vec<f32> = scores.to_vec1()
423
- .map_err(|e| Error::new(runtime_error, format!("Failed to convert scores to vec: {}", e)))?;
435
+ .map_err(|e| format!("Failed to convert scores to vec: {}", e))?;
424
436
 
425
437
  // Create tuples with document, score, and original index
426
438
  let mut ranked_docs: Vec<(String, f32, usize)> = documents
427
- .into_iter()
439
+ .iter()
440
+ .cloned()
428
441
  .zip(scores_vec)
429
442
  .enumerate()
430
443
  .map(|(idx, (doc, score))| (doc, score, idx))
@@ -433,16 +446,7 @@ impl Reranker {
433
446
  // Sort documents by relevance score (descending)
434
447
  ranked_docs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
435
448
 
436
- // Build result array with [doc, score, doc_id]
437
- let result_array = ruby.ary_new();
438
- for (doc, score, doc_id) in ranked_docs {
439
- let tuple = ruby.ary_new();
440
- tuple.push(doc)?;
441
- tuple.push(ruby.float_from_f64(score as f64))?;
442
- tuple.push(doc_id)?;
443
- result_array.push(tuple)?;
444
- }
445
- Ok(result_array)
449
+ Ok(ranked_docs)
446
450
  }
447
451
 
448
452
  /// Get the tokenizer used by this model
@@ -1,5 +1,5 @@
1
1
  # :nocov:
2
2
  module Candle
3
- VERSION = "1.7.0"
3
+ VERSION = "1.7.1"
4
4
  end
5
5
  # :nocov:
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-candle
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.7.0
4
+ version: 1.7.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Christopher Petersen
@@ -215,6 +215,7 @@ files:
215
215
  - ext/candle/build.rs
216
216
  - ext/candle/extconf.rb
217
217
  - ext/candle/rustfmt.toml
218
+ - ext/candle/src/gvl.rs
218
219
  - ext/candle/src/lib.rs
219
220
  - ext/candle/src/llm/constrained_generation_test.rs
220
221
  - ext/candle/src/llm/gemma.rs