red-candle 1.3.0 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 5b92d492e96b8192fba14141ab66ad42aa4afe0d942cc0658f8b64bab2bf916b
- data.tar.gz: fe3510382fe48853b45061beb336108499655b566c9cb8bf1889b36f76dcda0a
+ metadata.gz: 7a3ac57ccd28cb2eb647c5a71f42fdbdf84a9dcaabd906165daca3d57c4a8eb0
+ data.tar.gz: 283104e93802ac97f11525226c9e41dc3bebccb8706d5b69a6a648d62cf2ccad
  SHA512:
- metadata.gz: eeddd779bc811f2c2707439d8b92644a2711091d9e42750ed4ebbbf17054a482f1b79147a562200ef5cd5cf6f7620cfd5b543ca32624371121ca64bae40f210b
- data.tar.gz: cfdf7c9b76a8dda7bcfc9f215374251a606ba06d34c9310430e61390654ae873f3d5e61359767356e7c6554302e8f33a9385d3cbaee3d0da7c6cda771d2af970
+ metadata.gz: d5be5ca76fe5441ee1fd87ea83bcd02f5a630c3410a4dccc347da28290ef408c4c2cae428134f2a8fba839fe5aacf2b7593fe45274c967600e6684d09f212a01
+ data.tar.gz: 72289aaf0c17dea679acfa4f92be5371f2543293e27dab0752b39c77a1c2be98c480c0ce7266bc07519551770abdf53fdbbf0f4d3c1fdf6ba928c2452e598060
data/Cargo.lock CHANGED
@@ -167,7 +167,7 @@ dependencies = [
  "bitflags 2.9.4",
  "cexpr",
  "clang-sys",
- "itertools 0.12.1",
+ "itertools 0.11.0",
  "lazy_static",
  "lazycell",
  "proc-macro2",
@@ -1750,15 +1750,6 @@ dependencies = [
  "either",
  ]

- [[package]]
- name = "itertools"
- version = "0.12.1"
- source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
- dependencies = [
- "either",
- ]
-
  [[package]]
  name = "itertools"
  version = "0.13.0"
@@ -1890,9 +1881,9 @@ checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30"

  [[package]]
  name = "magnus"
- version = "0.7.1"
+ version = "0.8.2"
  source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "3d87ae53030f3a22e83879e666cb94e58a7bdf31706878a0ba48752994146dab"
+ checksum = "3b36a5b126bbe97eb0d02d07acfeb327036c6319fd816139a49824a83b7f9012"
  dependencies = [
  "magnus-macros",
  "rb-sys",
@@ -1902,9 +1893,9 @@ dependencies = [

  [[package]]
  name = "magnus-macros"
- version = "0.6.0"
+ version = "0.8.0"
  source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "5968c820e2960565f647819f5928a42d6e874551cab9d88d75e3e0660d7f71e3"
+ checksum = "47607461fd8e1513cb4f2076c197d8092d921a1ea75bd08af97398f593751892"
  dependencies = [
  "proc-macro2",
  "quote",
@@ -2656,18 +2647,18 @@ dependencies = [

  [[package]]
  name = "rb-sys"
- version = "0.9.117"
+ version = "0.9.124"
  source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "f900d1ce4629a2ebffaf5de74bd8f9c1188d4c5ed406df02f97e22f77a006f44"
+ checksum = "c85c4188462601e2aa1469def389c17228566f82ea72f137ed096f21591bc489"
  dependencies = [
  "rb-sys-build",
  ]

  [[package]]
  name = "rb-sys-build"
- version = "0.9.117"
+ version = "0.9.124"
  source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "ef1e9c857028f631056bcd6d88cec390c751e343ce2223ddb26d23eb4a151d59"
+ checksum = "568068db4102230882e6d4ae8de6632e224ca75fe5970f6e026a04e91ed635d3"
  dependencies = [
  "bindgen 0.69.5",
  "lazy_static",
@@ -2680,9 +2671,9 @@ dependencies = [

  [[package]]
  name = "rb-sys-env"
- version = "0.1.2"
+ version = "0.2.3"
  source = "registry+https://github.com/rust-lang/crates.io-index"
- checksum = "a35802679f07360454b418a5d1735c89716bde01d35b1560fc953c1415a0b3bb"
+ checksum = "cca7ad6a7e21e72151d56fe2495a259b5670e204c3adac41ee7ef676ea08117a"

  [[package]]
  name = "reborrow"
@@ -15,7 +15,7 @@ candle-transformers = { version = "0.9.1" }
  tokenizers = { version = "0.22.0", default-features = true, features = ["fancy-regex"] }
  hf-hub = "0.4.1"
  half = "2.6.0"
- magnus = "0.7.1"
+ magnus = "0.8"
  safetensors = "0.3"
  serde_json = "1.0"
  serde = { version = "1.0", features = ["derive"] }
@@ -313,4 +313,83 @@ mod constrained_generation_tests {
  // Verify tokens are being tracked
  assert_eq!(text_gen.get_tokens().len(), all_tokens.len(), "Internal tokens should match generated");
  }
+
+ #[test]
+ fn test_constraint_satisfied_not_triggered_by_large_allowed_set() {
+ // This test verifies the fix for the bug where is_constraint_satisfied_stop_on_match
+ // would incorrectly return true when many tokens are allowed (e.g., inside a JSON string).
+ // The old buggy code had: if allowed.len() > 1000 { return true; }
+ // This caused early termination when inside strings with many valid characters.
+
+ let config = GenerationConfig::default();
+ let mut text_gen = TextGeneration::new(&config);
+ text_gen.set_eos_token_id(50256);
+
+ // Without a constraint, should not be satisfied
+ assert!(!text_gen.is_constraint_satisfied(),
+ "Without constraint, should not be satisfied");
+ assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
+ "Without constraint, stop_on_match should not be satisfied");
+ }
+
+ #[test]
+ fn test_constraint_satisfied_only_when_empty_or_eos_only() {
+ // Test that constraint satisfaction only triggers when:
+ // 1. No tokens are allowed (empty set)
+ // 2. Only EOS token is allowed
+ // NOT when many tokens are allowed (like inside a JSON string)
+
+ let config = GenerationConfig::default();
+ let mut text_gen = TextGeneration::new(&config);
+ text_gen.set_eos_token_id(100); // Set EOS token
+
+ // Without constraint, should not be satisfied
+ assert!(!text_gen.is_constraint_satisfied());
+ assert!(!text_gen.is_constraint_satisfied_stop_on_match());
+
+ // The key insight: constraint satisfaction should NOT be triggered
+ // just because there are many allowed tokens. It should only trigger
+ // when the constraint is definitively complete (empty allowed set or only EOS).
+ }
+
+ #[tokio::test]
+ async fn test_constraint_with_json_schema_not_early_termination() {
+ // Integration test: Create a real JSON schema constraint and verify
+ // that being inside a string (many allowed tokens) doesn't trigger completion.
+
+ if let Ok(tokenizer) = TokenizerLoader::from_hf_hub("bert-base-uncased", None).await {
+ let wrapper = TokenizerWrapper::new(tokenizer);
+ let vocabulary = VocabularyAdapter::from_tokenizer(&wrapper)
+ .expect("Should create vocabulary");
+
+ let processor = SchemaProcessor::new();
+
+ // Schema with a string field - when generating content inside the string,
+ // many characters are valid, but the constraint is NOT complete
+ let schema = r#"{
+ "type": "object",
+ "properties": {
+ "name": { "type": "string" }
+ },
+ "required": ["name"]
+ }"#;
+
+ let index = processor.process_schema(schema, &vocabulary)
+ .expect("Should process schema");
+
+ let mut config = GenerationConfig::default();
+ config.constraint = Some(index);
+ config.max_length = 100;
+
+ let mut text_gen = TextGeneration::new(&config);
+ text_gen.set_eos_token_id(102); // BERT's [SEP]
+
+ // At the initial state, the constraint should NOT be satisfied
+ // (we haven't generated a complete JSON object yet)
+ assert!(!text_gen.is_constraint_satisfied(),
+ "Initial state should not be satisfied - JSON not yet generated");
+ assert!(!text_gen.is_constraint_satisfied_stop_on_match(),
+ "Initial state should not trigger stop_on_match");
+ }
+ }
  }
@@ -148,47 +148,28 @@ impl TextGeneration {
  if let (Some(ref constraint_index), Some(current_state)) = (&self.constraint, self.constraint_state) {
  // Get the next state
  let next_state = constraint_index.next_state(&current_state, &next_token);
-
+
  // Check if we're transitioning to a state with no allowed tokens (completion)
  if !self.constraint_completed && self.tokens.len() > self.tokens_since_constraint_start {
- // Check if we've transitioned from a constrained state to an unconstrained state
- // This happens when the pattern is complete and the FSM allows "anything"
-
- let current_constrained = if let Some(allowed) = constraint_index.allowed_tokens(&current_state) {
- // Consider it constrained if we have a limited set of allowed tokens
- allowed.len() < 1000 // Arbitrary threshold for "constrained"
- } else {
- true // No tokens allowed is definitely constrained
- };
-
- let next_constrained = if let Some(next_state_val) = next_state {
- if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
- allowed.is_empty() || allowed.len() < 1000
- } else {
- true
- }
- } else {
- true
- };
-
- // If we're transitioning from constrained to unconstrained, we've completed the pattern
- if current_constrained && !next_constrained {
- self.constraint_completed = true;
- }
-
- // Also check if next state has no allowed tokens at all
+ // Check if next state has no allowed tokens at all - this is definitive completion
  if let Some(next_state_val) = next_state {
  if let Some(allowed) = constraint_index.allowed_tokens(&next_state_val) {
  if allowed.is_empty() {
  self.constraint_completed = true;
  }
+ // Only mark as complete if ONLY EOS is allowed (not just if EOS is one of many options)
+ else if let Some(eos) = self.eos_token_id {
+ if allowed.len() == 1 && allowed.contains(&eos) {
+ self.constraint_completed = true;
+ }
+ }
  } else {
  // None means no tokens allowed - constraint is complete
  self.constraint_completed = true;
  }
  }
  }
-
+
  self.constraint_state = next_state;
  }

@@ -201,22 +182,22 @@ impl TextGeneration {
  if self.constraint_completed {
  return true;
  }
-
+
  // Also check the current state
  if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
- // Check if the constraint has reached a state where it could validly end
- // This happens when:
- // 1. We have no more allowed tokens (constraint fully satisfied)
- // 2. The EOS token is in the allowed tokens (optional ending)
+ // Check if the constraint has reached a state where it MUST end
+ // This happens when there are no more allowed tokens (constraint fully satisfied)
  if let Some(allowed) = constraint_index.allowed_tokens(&state) {
  // If no tokens are allowed, the constraint is fully satisfied
  if allowed.is_empty() {
  return true;
  }
-
- // If EOS token is allowed, we've reached an optional completion point
+
+ // For JSON schemas, check if ONLY the EOS token is allowed
+ // This means we've generated a complete, valid JSON structure
+ // Don't treat EOS as a satisfaction signal if other tokens are also allowed
  if let Some(eos) = self.eos_token_id {
- if allowed.contains(&eos) {
+ if allowed.len() == 1 && allowed.contains(&eos) {
  return true;
  }
  }
@@ -229,28 +210,37 @@ impl TextGeneration {
  }

  /// Check if the constraint is satisfied when stop_on_match is true
+ /// NOTE: For JSON schemas, this should only return true when the JSON structure is complete,
+ /// not just because we're in a state with many allowed tokens (like inside a string).
  pub fn is_constraint_satisfied_stop_on_match(&self) -> bool {
  // When stop_on_match is true, we stop as soon as the constraint is completed
  if self.constraint_completed {
  return true;
  }
-
- // Also check if we're currently in a state that could be a valid end
- // This is important for patterns like phone numbers where after matching
- // the pattern, the FSM might allow any token (including more numbers)
+
+ // For JSON and other structured outputs, don't use the "large allowed set" heuristic.
+ // Instead, only consider the constraint satisfied when:
+ // 1. There are no allowed tokens (definitive completion)
+ // 2. Only EOS is allowed (completion with optional termination)
  if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
- // Check if we've generated at least one token since constraint start
- if self.tokens.len() > self.tokens_since_constraint_start {
- if let Some(allowed) = constraint_index.allowed_tokens(&state) {
- // If the allowed tokens set is very large (unconstrained),
- // it means the pattern has been satisfied
- if allowed.len() > 1000 {
+ if let Some(allowed) = constraint_index.allowed_tokens(&state) {
+ // No more tokens allowed - definitely complete
+ if allowed.is_empty() {
+ return true;
+ }
+
+ // Only EOS is allowed - complete JSON structure
+ if let Some(eos) = self.eos_token_id {
+ if allowed.len() == 1 && allowed.contains(&eos) {
  return true;
  }
  }
+ } else {
+ // None means no tokens allowed - constraint is complete
+ return true;
  }
  }
-
+
  false
  }

@@ -259,13 +249,13 @@ impl TextGeneration {
  if self.tokens.len() >= max_length {
  return true;
  }
-
+
  if let Some(eos) = self.eos_token_id {
  if token == eos {
  return true;
  }
  }
-
+
  // Check if we've reached a final state in constraint
  // A state is considered final if it has no allowed tokens
  if let (Some(ref constraint_index), Some(state)) = (&self.constraint, self.constraint_state) {
@@ -278,7 +268,7 @@ impl TextGeneration {
  return true;
  }
  }
-
+
  false
  }
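
Note on the change above: the old `allowed.len() < 1000` / `allowed.len() > 1000` heuristics are gone, and completion is now signaled only when the constraint's current FSM state allows no further tokens, or allows nothing but the EOS token. A minimal standalone sketch of that rule follows; the `constraint_complete` helper and the `allowed` slice are illustrative stand-ins for `constraint_index.allowed_tokens(&state)`, not part of the gem's API.

    // Completion rule: an empty allowed set, or an EOS-only allowed set, ends
    // generation; a large allowed set (e.g. inside a JSON string) does not.
    fn constraint_complete(allowed: Option<&[u32]>, eos_token_id: Option<u32>) -> bool {
        match allowed {
            None => true,                               // no tokens allowed at all
            Some(tokens) if tokens.is_empty() => true,  // empty set: fully satisfied
            Some(tokens) => eos_token_id
                .map(|eos| tokens.len() == 1 && tokens.contains(&eos))
                .unwrap_or(false),
        }
    }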
 
@@ -1,5 +1,5 @@
  use magnus::Error;
- use magnus::{function, method, class, RModule, Module, Object};
+ use magnus::{function, method, RModule, Module, Object, Ruby};

  use ::candle_core::Device as CoreDevice;
  use crate::ruby::Result;
@@ -101,7 +101,7 @@ impl Device {
  #[cfg(not(feature = "cuda"))]
  {
  return Err(Error::new(
- magnus::exception::runtime_error(),
+ Ruby::get().unwrap().exception_runtime_error(),
  "CUDA support not compiled in. Rebuild with CUDA available.",
  ));
  }
@@ -115,7 +115,7 @@ impl Device {
  #[cfg(not(feature = "metal"))]
  {
  return Err(Error::new(
- magnus::exception::runtime_error(),
+ Ruby::get().unwrap().exception_runtime_error(),
  "Metal support not compiled in. Rebuild on macOS.",
  ));
  }
@@ -139,7 +139,7 @@ impl Device {
  #[cfg(not(feature = "cuda"))]
  {
  return Err(Error::new(
- magnus::exception::runtime_error(),
+ Ruby::get().unwrap().exception_runtime_error(),
  "CUDA support not compiled in. Rebuild with CUDA available.",
  ));
  }
@@ -161,7 +161,7 @@ impl Device {
  #[cfg(not(feature = "metal"))]
  {
  return Err(Error::new(
- magnus::exception::runtime_error(),
+ Ruby::get().unwrap().exception_runtime_error(),
  "Metal support not compiled in. Rebuild on macOS.",
  ));
  }
@@ -211,14 +211,15 @@ impl magnus::TryConvert for Device {
  "cpu" => Device::Cpu,
  "cuda" => Device::Cuda,
  "metal" => Device::Metal,
- _ => return Err(Error::new(magnus::exception::arg_error(), "invalid device")),
+ _ => return Err(Error::new(Ruby::get().unwrap().exception_arg_error(), "invalid device")),
  };
  Ok(device)
  }
  }

  pub fn init(rb_candle: RModule) -> Result<()> {
- let rb_device = rb_candle.define_class("Device", class::object())?;
+ let ruby = Ruby::get().unwrap();
+ let rb_device = rb_candle.define_class("Device", ruby.class_object())?;
  rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
  rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
  rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
@@ -1,5 +1,5 @@
  use magnus::value::ReprValue;
- use magnus::{method, class, RModule, Module};
+ use magnus::{method, RModule, Module, Ruby};

  use ::candle_core::DType as CoreDType;
  use crate::ruby::Result;
@@ -30,7 +30,8 @@ impl DType {
  }

  pub fn init(rb_candle: RModule) -> Result<()> {
- let rb_dtype = rb_candle.define_class("DType", class::object())?;
+ let ruby = Ruby::get().unwrap();
+ let rb_dtype = rb_candle.define_class("DType", ruby.class_object())?;
  rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
  rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
  Ok(())
@@ -13,7 +13,7 @@ use candle_transformers::models::{
  jina_bert::{BertModel as JinaBertModel, Config as JinaConfig},
  distilbert::{DistilBertModel, Config as DistilBertConfig}
  };
- use magnus::{class, function, method, prelude::*, Error, RModule, RHash};
+ use magnus::{function, method, prelude::*, Error, RModule, RHash, Ruby};
  use std::path::Path;
  use serde_json;

@@ -103,28 +103,30 @@ impl EmbeddingModel {
  /// &RETURNS&: Tensor
  /// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
  pub fn embedding(&self, input: String, pooling_method: String) -> Result<Tensor> {
+ let ruby = Ruby::get().unwrap();
  match &self.0.model {
  Some(model) => {
  match &self.0.tokenizer {
  Some(tokenizer) => Ok(Tensor(self.compute_embedding(input, model, tokenizer, &pooling_method)?)),
- None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer not found"))
+ None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found"))
  }
  }
- None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Model not found"))
+ None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found"))
  }
  }

  /// Returns the unpooled embedding tensor ([1, SEQLENGTH, DIM]) for the input text
  /// &RETURNS&: Tensor
  pub fn embeddings(&self, input: String) -> Result<Tensor> {
+ let ruby = Ruby::get().unwrap();
  match &self.0.model {
  Some(model) => {
  match &self.0.tokenizer {
  Some(tokenizer) => Ok(Tensor(self.compute_embeddings(input, model, tokenizer)?)),
- None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Tokenizer not found"))
+ None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Tokenizer not found"))
  }
  }
- None => Err(magnus::Error::new(magnus::exception::runtime_error(), "Model not found"))
+ None => Err(magnus::Error::new(ruby.exception_runtime_error(), "Model not found"))
  }
  }

@@ -165,7 +167,10 @@ impl EmbeddingModel {
  },
  Err(_) => None
  };
- inferred_emb_dim.ok_or_else(|| magnus::Error::new(magnus::exception::runtime_error(), "Could not infer embedding size from model file. Please specify embedding_size explicitly."))
+ inferred_emb_dim.ok_or_else(|| {
+ let ruby = Ruby::get().unwrap();
+ magnus::Error::new(ruby.exception_runtime_error(), "Could not infer embedding size from model file. Please specify embedding_size explicitly.")
+ })
  }
  }
  }
@@ -178,8 +183,9 @@ impl EmbeddingModel {
  EmbeddingModelType::JinaBert => {
  let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
  if !std::path::Path::new(&model_path).exists() {
+ let ruby = Ruby::get().unwrap();
  return Err(magnus::Error::new(
- magnus::exception::runtime_error(),
+ ruby.exception_runtime_error(),
  "model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
  ));
  }
@@ -196,8 +202,9 @@ impl EmbeddingModel {
  EmbeddingModelType::StandardBert => {
  let model_path = api.repo(repo).get("model.safetensors").map_err(wrap_hf_err)?;
  if !std::path::Path::new(&model_path).exists() {
+ let ruby = Ruby::get().unwrap();
  return Err(magnus::Error::new(
- magnus::exception::runtime_error(),
+ ruby.exception_runtime_error(),
  "model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
  ));
  }
@@ -214,8 +221,9 @@ impl EmbeddingModel {
  EmbeddingModelType::DistilBert => {
  let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
  if !std::path::Path::new(&model_path).exists() {
+ let ruby = Ruby::get().unwrap();
  return Err(magnus::Error::new(
- magnus::exception::runtime_error(),
+ ruby.exception_runtime_error(),
  "model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
  ));
  }
@@ -235,8 +243,9 @@ impl EmbeddingModel {
  EmbeddingModelType::MiniLM => {
  let model_path = api.repo(repo.clone()).get("model.safetensors").map_err(wrap_hf_err)?;
  if !std::path::Path::new(&model_path).exists() {
+ let ruby = Ruby::get().unwrap();
  return Err(magnus::Error::new(
- magnus::exception::runtime_error(),
+ ruby.exception_runtime_error(),
  "model.safetensors not found after download. Only safetensors models are supported. Please ensure your model repo contains model.safetensors."
  ));
  }
@@ -357,7 +366,10 @@ impl EmbeddingModel {
  "pooled" => Self::pooled_embedding(&result),
  "pooled_normalized" => Self::pooled_normalized_embedding(&result),
  "cls" => Self::pooled_cls_embedding(&result),
- _ => Err(magnus::Error::new(magnus::exception::runtime_error(), "Unknown pooling method")),
+ _ => {
+ let ruby = Ruby::get().unwrap();
+ Err(magnus::Error::new(ruby.exception_runtime_error(), "Unknown pooling method"))
+ },
  }
  }

@@ -390,7 +402,10 @@ impl EmbeddingModel {
  pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
  match &self.0.tokenizer {
  Some(tokenizer) => Ok(crate::ruby::tokenizer::Tokenizer(tokenizer.clone())),
- None => Err(magnus::Error::new(magnus::exception::runtime_error(), "No tokenizer loaded for this model"))
+ None => {
+ let ruby = Ruby::get().unwrap();
+ Err(magnus::Error::new(ruby.exception_runtime_error(), "No tokenizer loaded for this model"))
+ }
  }
  }

@@ -409,7 +424,8 @@ impl EmbeddingModel {

  /// Get all options as a hash
  pub fn options(&self) -> Result<RHash> {
- let hash = RHash::new();
+ let ruby = Ruby::get().unwrap();
+ let hash = ruby.hash_new();

  // Add model_id
  if let Some(model_id) = &self.0.model_id {
@@ -439,7 +455,8 @@ impl EmbeddingModel {
  }

  pub fn init(rb_candle: RModule) -> Result<()> {
- let rb_embedding_model = rb_candle.define_class("EmbeddingModel", class::object())?;
+ let ruby = Ruby::get().unwrap();
+ let rb_embedding_model = rb_candle.define_class("EmbeddingModel", ruby.class_object())?;
  rb_embedding_model.define_singleton_method("_create", function!(EmbeddingModel::new, 5))?;
  // Expose embedding with an optional pooling_method argument (default: "pooled")
  rb_embedding_model.define_method("_embedding", method!(EmbeddingModel::embedding, 2))?;
@@ -1,14 +1,16 @@
  use magnus::Error;

  pub fn wrap_std_err(err: Box<dyn std::error::Error + Send + Sync>) -> Error {
- Error::new(magnus::exception::runtime_error(), err.to_string())
+ let ruby = magnus::Ruby::get().unwrap();
+ Error::new(ruby.exception_runtime_error(), err.to_string())
  }

  pub fn wrap_candle_err(err: candle_core::Error) -> Error {
- Error::new(magnus::exception::runtime_error(), err.to_string())
+ let ruby = magnus::Ruby::get().unwrap();
+ Error::new(ruby.exception_runtime_error(), err.to_string())
  }

  pub fn wrap_hf_err(err: hf_hub::api::sync::ApiError) -> Error {
- Error::new(magnus::exception::runtime_error(), err.to_string())
+ let ruby = magnus::Ruby::get().unwrap();
+ Error::new(ruby.exception_runtime_error(), err.to_string())
  }
-