red-candle 1.3.0 → 1.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -86,16 +86,16 @@ impl ModelType {
86
86
  // Macro to extract parameters from Ruby hash to reduce boilerplate
87
87
  macro_rules! extract_param {
88
88
  // Basic parameter extraction
89
- ($kwargs:expr, $config:expr, $param:ident) => {
90
- if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
89
+ ($ruby:expr, $kwargs:expr, $config:expr, $param:ident) => {
90
+ if let Some(value) = $kwargs.get($ruby.to_symbol(stringify!($param))) {
91
91
  if let Ok(v) = TryConvert::try_convert(value) {
92
92
  $config.$param = v;
93
93
  }
94
94
  }
95
95
  };
96
96
  // Optional parameter extraction (wraps in Some)
97
- ($kwargs:expr, $config:expr, $param:ident, optional) => {
98
- if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
97
+ ($ruby:expr, $kwargs:expr, $config:expr, $param:ident, optional) => {
98
+ if let Some(value) = $kwargs.get($ruby.to_symbol(stringify!($param))) {
99
99
  if let Ok(v) = TryConvert::try_convert(value) {
100
100
  $config.$param = Some(v);
101
101
  }
@@ -111,23 +111,24 @@ pub struct GenerationConfig {
111
111
 
112
112
  impl GenerationConfig {
113
113
  pub fn new(kwargs: RHash) -> Result<Self> {
114
+ let ruby = Ruby::get().unwrap();
114
115
  let mut config = RustGenerationConfig::default();
115
-
116
+
116
117
  // Extract basic parameters using macro
117
- extract_param!(kwargs, config, max_length);
118
- extract_param!(kwargs, config, temperature);
119
- extract_param!(kwargs, config, top_p, optional);
120
- extract_param!(kwargs, config, top_k, optional);
121
- extract_param!(kwargs, config, repetition_penalty);
122
- extract_param!(kwargs, config, repetition_penalty_last_n);
123
- extract_param!(kwargs, config, seed);
124
- extract_param!(kwargs, config, include_prompt);
125
- extract_param!(kwargs, config, debug_tokens);
126
- extract_param!(kwargs, config, stop_on_constraint_satisfaction);
127
- extract_param!(kwargs, config, stop_on_match);
128
-
118
+ extract_param!(ruby, kwargs, config, max_length);
119
+ extract_param!(ruby, kwargs, config, temperature);
120
+ extract_param!(ruby, kwargs, config, top_p, optional);
121
+ extract_param!(ruby, kwargs, config, top_k, optional);
122
+ extract_param!(ruby, kwargs, config, repetition_penalty);
123
+ extract_param!(ruby, kwargs, config, repetition_penalty_last_n);
124
+ extract_param!(ruby, kwargs, config, seed);
125
+ extract_param!(ruby, kwargs, config, include_prompt);
126
+ extract_param!(ruby, kwargs, config, debug_tokens);
127
+ extract_param!(ruby, kwargs, config, stop_on_constraint_satisfaction);
128
+ extract_param!(ruby, kwargs, config, stop_on_match);
129
+
129
130
  // Handle special cases that need custom logic
130
- if let Some(value) = kwargs.get(magnus::Symbol::new("stop_sequences")) {
131
+ if let Some(value) = kwargs.get(ruby.to_symbol("stop_sequences")) {
131
132
  if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
132
133
  config.stop_sequences = arr
133
134
  .into_iter()
@@ -135,13 +136,13 @@ impl GenerationConfig {
135
136
  .collect();
136
137
  }
137
138
  }
138
-
139
- if let Some(value) = kwargs.get(magnus::Symbol::new("constraint")) {
139
+
140
+ if let Some(value) = kwargs.get(ruby.to_symbol("constraint")) {
140
141
  if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
141
142
  config.constraint = Some(Arc::clone(&constraint.index));
142
143
  }
143
144
  }
144
-
145
+
145
146
  Ok(Self { inner: config })
146
147
  }
147
148
 
@@ -204,19 +205,20 @@ impl GenerationConfig {
204
205
 
205
206
  /// Get all options as a hash
206
207
  pub fn options(&self) -> Result<RHash> {
207
- let hash = RHash::new();
208
-
208
+ let ruby = Ruby::get().unwrap();
209
+ let hash = ruby.hash_new();
210
+
209
211
  hash.aset("max_length", self.inner.max_length)?;
210
212
  hash.aset("temperature", self.inner.temperature)?;
211
-
213
+
212
214
  if let Some(top_p) = self.inner.top_p {
213
215
  hash.aset("top_p", top_p)?;
214
216
  }
215
-
217
+
216
218
  if let Some(top_k) = self.inner.top_k {
217
219
  hash.aset("top_k", top_k)?;
218
220
  }
219
-
221
+
220
222
  hash.aset("repetition_penalty", self.inner.repetition_penalty)?;
221
223
  hash.aset("repetition_penalty_last_n", self.inner.repetition_penalty_last_n)?;
222
224
  hash.aset("seed", self.inner.seed)?;
@@ -225,11 +227,11 @@ impl GenerationConfig {
225
227
  hash.aset("debug_tokens", self.inner.debug_tokens)?;
226
228
  hash.aset("stop_on_constraint_satisfaction", self.inner.stop_on_constraint_satisfaction)?;
227
229
  hash.aset("stop_on_match", self.inner.stop_on_match)?;
228
-
230
+
229
231
  if self.inner.constraint.is_some() {
230
232
  hash.aset("has_constraint", true)?;
231
233
  }
232
-
234
+
233
235
  Ok(hash)
234
236
  }
235
237
  }
@@ -245,18 +247,18 @@ pub struct LLM {
245
247
  impl LLM {
246
248
  /// Create a new LLM from a pretrained model
247
249
  pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
250
+ let ruby = Ruby::get().unwrap();
251
+ let runtime_error = ruby.exception_runtime_error();
248
252
  let device = device.unwrap_or(Device::best());
249
253
  let candle_device = device.as_device()?;
250
-
251
- // For now, we'll use tokio runtime directly
252
- // In production, you might want to share a runtime
254
+
253
255
  let rt = tokio::runtime::Runtime::new()
254
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;
255
-
256
+ .map_err(|e| Error::new(runtime_error, format!("Failed to create runtime: {}", e)))?;
257
+
256
258
  // Determine model type from ID and whether it's quantized
257
259
  let model_lower = model_id.to_lowercase();
258
260
  let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
259
-
261
+
260
262
  // Extract tokenizer source if provided in model_id (for both GGUF and regular models)
261
263
  let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
262
264
  let (id, _tok) = model_id.split_at(pos);
@@ -266,17 +268,17 @@ impl LLM {
266
268
  };
267
269
 
268
270
  let model = if is_quantized {
269
-
271
+
270
272
  // Use unified GGUF loader for all quantized models
271
273
  let gguf_model = rt.block_on(async {
272
274
  RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
273
275
  })
274
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
276
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load GGUF model: {}", e)))?;
275
277
  ModelType::QuantizedGGUF(gguf_model)
276
278
  } else {
277
279
  // Load non-quantized models based on type
278
280
  let model_lower_clean = model_id_clean.to_lowercase();
279
-
281
+
280
282
  if model_lower_clean.contains("mistral") {
281
283
  let mistral = if tokenizer_source.is_some() {
282
284
  rt.block_on(async {
@@ -287,7 +289,7 @@ impl LLM {
287
289
  RustMistral::from_pretrained(&model_id_clean, candle_device).await
288
290
  })
289
291
  }
290
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
292
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
291
293
  ModelType::Mistral(mistral)
292
294
  } else if model_lower_clean.contains("llama") || model_lower_clean.contains("meta-llama") || model_lower_clean.contains("tinyllama") {
293
295
  let llama = if tokenizer_source.is_some() {
@@ -299,7 +301,7 @@ impl LLM {
299
301
  RustLlama::from_pretrained(&model_id_clean, candle_device).await
300
302
  })
301
303
  }
302
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
304
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
303
305
  ModelType::Llama(llama)
304
306
  } else if model_lower_clean.contains("gemma") || model_lower_clean.contains("google/gemma") {
305
307
  let gemma = if tokenizer_source.is_some() {
@@ -311,7 +313,7 @@ impl LLM {
311
313
  RustGemma::from_pretrained(&model_id_clean, candle_device).await
312
314
  })
313
315
  }
314
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
316
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
315
317
  ModelType::Gemma(gemma)
316
318
  } else if model_lower_clean.contains("qwen") {
317
319
  let qwen = if tokenizer_source.is_some() {
@@ -323,7 +325,7 @@ impl LLM {
323
325
  RustQwen::from_pretrained(&model_id_clean, candle_device).await
324
326
  })
325
327
  }
326
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
328
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
327
329
  ModelType::Qwen(qwen)
328
330
  } else if model_lower_clean.contains("phi") {
329
331
  let phi = if tokenizer_source.is_some() {
@@ -335,16 +337,16 @@ impl LLM {
335
337
  RustPhi::from_pretrained(&model_id_clean, candle_device).await
336
338
  })
337
339
  }
338
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
340
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
339
341
  ModelType::Phi(phi)
340
342
  } else {
341
343
  return Err(Error::new(
342
- magnus::exception::runtime_error(),
344
+ runtime_error,
343
345
  format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.", model_id_clean),
344
346
  ));
345
347
  }
346
348
  };
347
-
349
+
348
350
  Ok(Self {
349
351
  model: std::sync::Arc::new(std::sync::Mutex::new(RefCell::new(model))),
350
352
  model_id,
@@ -354,18 +356,19 @@ impl LLM {
354
356
 
355
357
  /// Generate text from a prompt
356
358
  pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
359
+ let ruby = Ruby::get().unwrap();
357
360
  let config = config
358
361
  .map(|c| c.inner.clone())
359
362
  .unwrap_or_default();
360
-
363
+
361
364
  let model = match self.model.lock() {
362
365
  Ok(guard) => guard,
363
366
  Err(poisoned) => poisoned.into_inner(),
364
367
  };
365
368
  let mut model_ref = model.borrow_mut();
366
-
369
+
367
370
  model_ref.generate(&prompt, &config)
368
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Generation failed: {}", e)))
371
+ .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Generation failed: {}", e)))
369
372
  }
370
373
 
371
374
  /// Generate text with streaming output
@@ -373,26 +376,27 @@ impl LLM {
373
376
  let config = config
374
377
  .map(|c| c.inner.clone())
375
378
  .unwrap_or_default();
376
-
379
+
377
380
  let ruby = Ruby::get().unwrap();
381
+ let runtime_error = ruby.exception_runtime_error();
378
382
  let block = ruby.block_proc();
379
383
  if let Err(_) = block {
380
- return Err(Error::new(magnus::exception::runtime_error(), "No block given"));
384
+ return Err(Error::new(runtime_error, "No block given"));
381
385
  }
382
386
  let block = block.unwrap();
383
-
387
+
384
388
  let model = match self.model.lock() {
385
389
  Ok(guard) => guard,
386
390
  Err(poisoned) => poisoned.into_inner(),
387
391
  };
388
392
  let mut model_ref = model.borrow_mut();
389
-
393
+
390
394
  let result = model_ref.generate_stream(&prompt, &config, |token| {
391
395
  // Call the Ruby block with each token
392
396
  let _ = block.call::<(String,), Value>((token.to_string(),));
393
397
  });
394
-
395
- result.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Generation failed: {}", e)))
398
+
399
+ result.map_err(|e| Error::new(runtime_error, format!("Generation failed: {}", e)))
396
400
  }
397
401
 
398
402
  /// Get the model name
@@ -477,40 +481,41 @@ impl LLM {
477
481
 
478
482
  /// Apply chat template to messages
479
483
  pub fn apply_chat_template(&self, messages: RArray) -> Result<String> {
484
+ let ruby = Ruby::get().unwrap();
480
485
  // Convert Ruby array to JSON values
481
486
  let json_messages: Vec<serde_json::Value> = messages
482
487
  .into_iter()
483
488
  .filter_map(|msg| {
484
489
  if let Ok(hash) = <RHash as TryConvert>::try_convert(msg) {
485
490
  let mut json_msg = serde_json::Map::new();
486
-
487
- if let Some(role) = hash.get(magnus::Symbol::new("role")) {
491
+
492
+ if let Some(role) = hash.get(ruby.to_symbol("role")) {
488
493
  if let Ok(role_str) = <String as TryConvert>::try_convert(role) {
489
494
  json_msg.insert("role".to_string(), serde_json::Value::String(role_str));
490
495
  }
491
496
  }
492
-
493
- if let Some(content) = hash.get(magnus::Symbol::new("content")) {
497
+
498
+ if let Some(content) = hash.get(ruby.to_symbol("content")) {
494
499
  if let Ok(content_str) = <String as TryConvert>::try_convert(content) {
495
500
  json_msg.insert("content".to_string(), serde_json::Value::String(content_str));
496
501
  }
497
502
  }
498
-
503
+
499
504
  Some(serde_json::Value::Object(json_msg))
500
505
  } else {
501
506
  None
502
507
  }
503
508
  })
504
509
  .collect();
505
-
510
+
506
511
  let model = match self.model.lock() {
507
512
  Ok(guard) => guard,
508
513
  Err(poisoned) => poisoned.into_inner(),
509
514
  };
510
515
  let model_ref = model.borrow();
511
-
516
+
512
517
  model_ref.apply_chat_template(&json_messages)
513
- .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to apply chat template: {}", e)))
518
+ .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to apply chat template: {}", e)))
514
519
  }
515
520
 
516
521
  /// Get the model ID
@@ -520,7 +525,8 @@ impl LLM {
520
525
 
521
526
  /// Get model options
522
527
  pub fn options(&self) -> Result<RHash> {
523
- let hash = RHash::new();
528
+ let ruby = Ruby::get().unwrap();
529
+ let hash = ruby.hash_new();
524
530
 
525
531
  // Basic metadata
526
532
  hash.aset("model_id", self.model_id.clone())?;
@@ -587,15 +593,19 @@ fn from_pretrained_wrapper(args: &[Value]) -> Result<LLM> {
587
593
  let device: Device = TryConvert::try_convert(args[1])?;
588
594
  LLM::from_pretrained(model_id, Some(device))
589
595
  },
590
- _ => Err(Error::new(
591
- magnus::exception::arg_error(),
592
- "wrong number of arguments (expected 1..2)"
593
- ))
596
+ _ => {
597
+ let ruby = Ruby::get().unwrap();
598
+ Err(Error::new(
599
+ ruby.exception_arg_error(),
600
+ "wrong number of arguments (expected 1..2)"
601
+ ))
602
+ }
594
603
  }
595
604
  }
596
605
 
597
606
  pub fn init_llm(rb_candle: RModule) -> Result<()> {
598
- let rb_generation_config = rb_candle.define_class("GenerationConfig", magnus::class::object())?;
607
+ let ruby = Ruby::get().unwrap();
608
+ let rb_generation_config = rb_candle.define_class("GenerationConfig", ruby.class_object())?;
599
609
  rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
600
610
  rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
601
611
 
@@ -613,7 +623,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
613
623
  rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
614
624
  rb_generation_config.define_method("options", method!(GenerationConfig::options, 0))?;
615
625
 
616
- let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
626
+ let rb_llm = rb_candle.define_class("LLM", ruby.class_object())?;
617
627
  rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
618
628
  rb_llm.define_method("_generate", method!(LLM::generate, 2))?;
619
629
  rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;