red-candle 1.3.0 → 1.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +11 -20
- data/ext/candle/Cargo.toml +1 -1
- data/ext/candle/src/llm/constrained_generation_test.rs +79 -0
- data/ext/candle/src/llm/text_generation.rs +40 -50
- data/ext/candle/src/ruby/device.rs +8 -7
- data/ext/candle/src/ruby/dtype.rs +3 -2
- data/ext/candle/src/ruby/embedding_model.rs +31 -14
- data/ext/candle/src/ruby/errors.rs +6 -4
- data/ext/candle/src/ruby/llm.rs +78 -68
- data/ext/candle/src/ruby/ner.rs +106 -95
- data/ext/candle/src/ruby/reranker.rs +51 -38
- data/ext/candle/src/ruby/structured.rs +61 -16
- data/ext/candle/src/ruby/tensor.rs +7 -6
- data/ext/candle/src/ruby/tokenizer.rs +101 -84
- data/lib/candle/llm.rb +77 -3
- data/lib/candle/version.rb +1 -1
- metadata +31 -6
data/ext/candle/src/ruby/llm.rs
CHANGED
|
@@ -86,16 +86,16 @@ impl ModelType {
|
|
|
86
86
|
// Macro to extract parameters from Ruby hash to reduce boilerplate
|
|
87
87
|
macro_rules! extract_param {
|
|
88
88
|
// Basic parameter extraction
|
|
89
|
-
($kwargs:expr, $config:expr, $param:ident) => {
|
|
90
|
-
if let Some(value) = $kwargs.get(
|
|
89
|
+
($ruby:expr, $kwargs:expr, $config:expr, $param:ident) => {
|
|
90
|
+
if let Some(value) = $kwargs.get($ruby.to_symbol(stringify!($param))) {
|
|
91
91
|
if let Ok(v) = TryConvert::try_convert(value) {
|
|
92
92
|
$config.$param = v;
|
|
93
93
|
}
|
|
94
94
|
}
|
|
95
95
|
};
|
|
96
96
|
// Optional parameter extraction (wraps in Some)
|
|
97
|
-
($kwargs:expr, $config:expr, $param:ident, optional) => {
|
|
98
|
-
if let Some(value) = $kwargs.get(
|
|
97
|
+
($ruby:expr, $kwargs:expr, $config:expr, $param:ident, optional) => {
|
|
98
|
+
if let Some(value) = $kwargs.get($ruby.to_symbol(stringify!($param))) {
|
|
99
99
|
if let Ok(v) = TryConvert::try_convert(value) {
|
|
100
100
|
$config.$param = Some(v);
|
|
101
101
|
}
|
|
@@ -111,23 +111,24 @@ pub struct GenerationConfig {
|
|
|
111
111
|
|
|
112
112
|
impl GenerationConfig {
|
|
113
113
|
pub fn new(kwargs: RHash) -> Result<Self> {
|
|
114
|
+
let ruby = Ruby::get().unwrap();
|
|
114
115
|
let mut config = RustGenerationConfig::default();
|
|
115
|
-
|
|
116
|
+
|
|
116
117
|
// Extract basic parameters using macro
|
|
117
|
-
extract_param!(kwargs, config, max_length);
|
|
118
|
-
extract_param!(kwargs, config, temperature);
|
|
119
|
-
extract_param!(kwargs, config, top_p, optional);
|
|
120
|
-
extract_param!(kwargs, config, top_k, optional);
|
|
121
|
-
extract_param!(kwargs, config, repetition_penalty);
|
|
122
|
-
extract_param!(kwargs, config, repetition_penalty_last_n);
|
|
123
|
-
extract_param!(kwargs, config, seed);
|
|
124
|
-
extract_param!(kwargs, config, include_prompt);
|
|
125
|
-
extract_param!(kwargs, config, debug_tokens);
|
|
126
|
-
extract_param!(kwargs, config, stop_on_constraint_satisfaction);
|
|
127
|
-
extract_param!(kwargs, config, stop_on_match);
|
|
128
|
-
|
|
118
|
+
extract_param!(ruby, kwargs, config, max_length);
|
|
119
|
+
extract_param!(ruby, kwargs, config, temperature);
|
|
120
|
+
extract_param!(ruby, kwargs, config, top_p, optional);
|
|
121
|
+
extract_param!(ruby, kwargs, config, top_k, optional);
|
|
122
|
+
extract_param!(ruby, kwargs, config, repetition_penalty);
|
|
123
|
+
extract_param!(ruby, kwargs, config, repetition_penalty_last_n);
|
|
124
|
+
extract_param!(ruby, kwargs, config, seed);
|
|
125
|
+
extract_param!(ruby, kwargs, config, include_prompt);
|
|
126
|
+
extract_param!(ruby, kwargs, config, debug_tokens);
|
|
127
|
+
extract_param!(ruby, kwargs, config, stop_on_constraint_satisfaction);
|
|
128
|
+
extract_param!(ruby, kwargs, config, stop_on_match);
|
|
129
|
+
|
|
129
130
|
// Handle special cases that need custom logic
|
|
130
|
-
if let Some(value) = kwargs.get(
|
|
131
|
+
if let Some(value) = kwargs.get(ruby.to_symbol("stop_sequences")) {
|
|
131
132
|
if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
|
|
132
133
|
config.stop_sequences = arr
|
|
133
134
|
.into_iter()
|
|
@@ -135,13 +136,13 @@ impl GenerationConfig {
|
|
|
135
136
|
.collect();
|
|
136
137
|
}
|
|
137
138
|
}
|
|
138
|
-
|
|
139
|
-
if let Some(value) = kwargs.get(
|
|
139
|
+
|
|
140
|
+
if let Some(value) = kwargs.get(ruby.to_symbol("constraint")) {
|
|
140
141
|
if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
|
|
141
142
|
config.constraint = Some(Arc::clone(&constraint.index));
|
|
142
143
|
}
|
|
143
144
|
}
|
|
144
|
-
|
|
145
|
+
|
|
145
146
|
Ok(Self { inner: config })
|
|
146
147
|
}
|
|
147
148
|
|
|
@@ -204,19 +205,20 @@ impl GenerationConfig {
|
|
|
204
205
|
|
|
205
206
|
/// Get all options as a hash
|
|
206
207
|
pub fn options(&self) -> Result<RHash> {
|
|
207
|
-
let
|
|
208
|
-
|
|
208
|
+
let ruby = Ruby::get().unwrap();
|
|
209
|
+
let hash = ruby.hash_new();
|
|
210
|
+
|
|
209
211
|
hash.aset("max_length", self.inner.max_length)?;
|
|
210
212
|
hash.aset("temperature", self.inner.temperature)?;
|
|
211
|
-
|
|
213
|
+
|
|
212
214
|
if let Some(top_p) = self.inner.top_p {
|
|
213
215
|
hash.aset("top_p", top_p)?;
|
|
214
216
|
}
|
|
215
|
-
|
|
217
|
+
|
|
216
218
|
if let Some(top_k) = self.inner.top_k {
|
|
217
219
|
hash.aset("top_k", top_k)?;
|
|
218
220
|
}
|
|
219
|
-
|
|
221
|
+
|
|
220
222
|
hash.aset("repetition_penalty", self.inner.repetition_penalty)?;
|
|
221
223
|
hash.aset("repetition_penalty_last_n", self.inner.repetition_penalty_last_n)?;
|
|
222
224
|
hash.aset("seed", self.inner.seed)?;
|
|
@@ -225,11 +227,11 @@ impl GenerationConfig {
|
|
|
225
227
|
hash.aset("debug_tokens", self.inner.debug_tokens)?;
|
|
226
228
|
hash.aset("stop_on_constraint_satisfaction", self.inner.stop_on_constraint_satisfaction)?;
|
|
227
229
|
hash.aset("stop_on_match", self.inner.stop_on_match)?;
|
|
228
|
-
|
|
230
|
+
|
|
229
231
|
if self.inner.constraint.is_some() {
|
|
230
232
|
hash.aset("has_constraint", true)?;
|
|
231
233
|
}
|
|
232
|
-
|
|
234
|
+
|
|
233
235
|
Ok(hash)
|
|
234
236
|
}
|
|
235
237
|
}
|
|
@@ -245,18 +247,18 @@ pub struct LLM {
|
|
|
245
247
|
impl LLM {
|
|
246
248
|
/// Create a new LLM from a pretrained model
|
|
247
249
|
pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
|
|
250
|
+
let ruby = Ruby::get().unwrap();
|
|
251
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
248
252
|
let device = device.unwrap_or(Device::best());
|
|
249
253
|
let candle_device = device.as_device()?;
|
|
250
|
-
|
|
251
|
-
// For now, we'll use tokio runtime directly
|
|
252
|
-
// In production, you might want to share a runtime
|
|
254
|
+
|
|
253
255
|
let rt = tokio::runtime::Runtime::new()
|
|
254
|
-
.map_err(|e| Error::new(
|
|
255
|
-
|
|
256
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create runtime: {}", e)))?;
|
|
257
|
+
|
|
256
258
|
// Determine model type from ID and whether it's quantized
|
|
257
259
|
let model_lower = model_id.to_lowercase();
|
|
258
260
|
let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
|
|
259
|
-
|
|
261
|
+
|
|
260
262
|
// Extract tokenizer source if provided in model_id (for both GGUF and regular models)
|
|
261
263
|
let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
|
|
262
264
|
let (id, _tok) = model_id.split_at(pos);
|
|
@@ -266,17 +268,17 @@ impl LLM {
|
|
|
266
268
|
};
|
|
267
269
|
|
|
268
270
|
let model = if is_quantized {
|
|
269
|
-
|
|
271
|
+
|
|
270
272
|
// Use unified GGUF loader for all quantized models
|
|
271
273
|
let gguf_model = rt.block_on(async {
|
|
272
274
|
RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
|
|
273
275
|
})
|
|
274
|
-
.map_err(|e| Error::new(
|
|
276
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load GGUF model: {}", e)))?;
|
|
275
277
|
ModelType::QuantizedGGUF(gguf_model)
|
|
276
278
|
} else {
|
|
277
279
|
// Load non-quantized models based on type
|
|
278
280
|
let model_lower_clean = model_id_clean.to_lowercase();
|
|
279
|
-
|
|
281
|
+
|
|
280
282
|
if model_lower_clean.contains("mistral") {
|
|
281
283
|
let mistral = if tokenizer_source.is_some() {
|
|
282
284
|
rt.block_on(async {
|
|
@@ -287,7 +289,7 @@ impl LLM {
|
|
|
287
289
|
RustMistral::from_pretrained(&model_id_clean, candle_device).await
|
|
288
290
|
})
|
|
289
291
|
}
|
|
290
|
-
.map_err(|e| Error::new(
|
|
292
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
291
293
|
ModelType::Mistral(mistral)
|
|
292
294
|
} else if model_lower_clean.contains("llama") || model_lower_clean.contains("meta-llama") || model_lower_clean.contains("tinyllama") {
|
|
293
295
|
let llama = if tokenizer_source.is_some() {
|
|
@@ -299,7 +301,7 @@ impl LLM {
|
|
|
299
301
|
RustLlama::from_pretrained(&model_id_clean, candle_device).await
|
|
300
302
|
})
|
|
301
303
|
}
|
|
302
|
-
.map_err(|e| Error::new(
|
|
304
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
303
305
|
ModelType::Llama(llama)
|
|
304
306
|
} else if model_lower_clean.contains("gemma") || model_lower_clean.contains("google/gemma") {
|
|
305
307
|
let gemma = if tokenizer_source.is_some() {
|
|
@@ -311,7 +313,7 @@ impl LLM {
|
|
|
311
313
|
RustGemma::from_pretrained(&model_id_clean, candle_device).await
|
|
312
314
|
})
|
|
313
315
|
}
|
|
314
|
-
.map_err(|e| Error::new(
|
|
316
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
315
317
|
ModelType::Gemma(gemma)
|
|
316
318
|
} else if model_lower_clean.contains("qwen") {
|
|
317
319
|
let qwen = if tokenizer_source.is_some() {
|
|
@@ -323,7 +325,7 @@ impl LLM {
|
|
|
323
325
|
RustQwen::from_pretrained(&model_id_clean, candle_device).await
|
|
324
326
|
})
|
|
325
327
|
}
|
|
326
|
-
.map_err(|e| Error::new(
|
|
328
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
327
329
|
ModelType::Qwen(qwen)
|
|
328
330
|
} else if model_lower_clean.contains("phi") {
|
|
329
331
|
let phi = if tokenizer_source.is_some() {
|
|
@@ -335,16 +337,16 @@ impl LLM {
|
|
|
335
337
|
RustPhi::from_pretrained(&model_id_clean, candle_device).await
|
|
336
338
|
})
|
|
337
339
|
}
|
|
338
|
-
.map_err(|e| Error::new(
|
|
340
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
339
341
|
ModelType::Phi(phi)
|
|
340
342
|
} else {
|
|
341
343
|
return Err(Error::new(
|
|
342
|
-
|
|
344
|
+
runtime_error,
|
|
343
345
|
format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, and Phi models are supported.", model_id_clean),
|
|
344
346
|
));
|
|
345
347
|
}
|
|
346
348
|
};
|
|
347
|
-
|
|
349
|
+
|
|
348
350
|
Ok(Self {
|
|
349
351
|
model: std::sync::Arc::new(std::sync::Mutex::new(RefCell::new(model))),
|
|
350
352
|
model_id,
|
|
@@ -354,18 +356,19 @@ impl LLM {
|
|
|
354
356
|
|
|
355
357
|
/// Generate text from a prompt
|
|
356
358
|
pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
|
|
359
|
+
let ruby = Ruby::get().unwrap();
|
|
357
360
|
let config = config
|
|
358
361
|
.map(|c| c.inner.clone())
|
|
359
362
|
.unwrap_or_default();
|
|
360
|
-
|
|
363
|
+
|
|
361
364
|
let model = match self.model.lock() {
|
|
362
365
|
Ok(guard) => guard,
|
|
363
366
|
Err(poisoned) => poisoned.into_inner(),
|
|
364
367
|
};
|
|
365
368
|
let mut model_ref = model.borrow_mut();
|
|
366
|
-
|
|
369
|
+
|
|
367
370
|
model_ref.generate(&prompt, &config)
|
|
368
|
-
.map_err(|e| Error::new(
|
|
371
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Generation failed: {}", e)))
|
|
369
372
|
}
|
|
370
373
|
|
|
371
374
|
/// Generate text with streaming output
|
|
@@ -373,26 +376,27 @@ impl LLM {
|
|
|
373
376
|
let config = config
|
|
374
377
|
.map(|c| c.inner.clone())
|
|
375
378
|
.unwrap_or_default();
|
|
376
|
-
|
|
379
|
+
|
|
377
380
|
let ruby = Ruby::get().unwrap();
|
|
381
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
378
382
|
let block = ruby.block_proc();
|
|
379
383
|
if let Err(_) = block {
|
|
380
|
-
return Err(Error::new(
|
|
384
|
+
return Err(Error::new(runtime_error, "No block given"));
|
|
381
385
|
}
|
|
382
386
|
let block = block.unwrap();
|
|
383
|
-
|
|
387
|
+
|
|
384
388
|
let model = match self.model.lock() {
|
|
385
389
|
Ok(guard) => guard,
|
|
386
390
|
Err(poisoned) => poisoned.into_inner(),
|
|
387
391
|
};
|
|
388
392
|
let mut model_ref = model.borrow_mut();
|
|
389
|
-
|
|
393
|
+
|
|
390
394
|
let result = model_ref.generate_stream(&prompt, &config, |token| {
|
|
391
395
|
// Call the Ruby block with each token
|
|
392
396
|
let _ = block.call::<(String,), Value>((token.to_string(),));
|
|
393
397
|
});
|
|
394
|
-
|
|
395
|
-
result.map_err(|e| Error::new(
|
|
398
|
+
|
|
399
|
+
result.map_err(|e| Error::new(runtime_error, format!("Generation failed: {}", e)))
|
|
396
400
|
}
|
|
397
401
|
|
|
398
402
|
/// Get the model name
|
|
@@ -477,40 +481,41 @@ impl LLM {
|
|
|
477
481
|
|
|
478
482
|
/// Apply chat template to messages
|
|
479
483
|
pub fn apply_chat_template(&self, messages: RArray) -> Result<String> {
|
|
484
|
+
let ruby = Ruby::get().unwrap();
|
|
480
485
|
// Convert Ruby array to JSON values
|
|
481
486
|
let json_messages: Vec<serde_json::Value> = messages
|
|
482
487
|
.into_iter()
|
|
483
488
|
.filter_map(|msg| {
|
|
484
489
|
if let Ok(hash) = <RHash as TryConvert>::try_convert(msg) {
|
|
485
490
|
let mut json_msg = serde_json::Map::new();
|
|
486
|
-
|
|
487
|
-
if let Some(role) = hash.get(
|
|
491
|
+
|
|
492
|
+
if let Some(role) = hash.get(ruby.to_symbol("role")) {
|
|
488
493
|
if let Ok(role_str) = <String as TryConvert>::try_convert(role) {
|
|
489
494
|
json_msg.insert("role".to_string(), serde_json::Value::String(role_str));
|
|
490
495
|
}
|
|
491
496
|
}
|
|
492
|
-
|
|
493
|
-
if let Some(content) = hash.get(
|
|
497
|
+
|
|
498
|
+
if let Some(content) = hash.get(ruby.to_symbol("content")) {
|
|
494
499
|
if let Ok(content_str) = <String as TryConvert>::try_convert(content) {
|
|
495
500
|
json_msg.insert("content".to_string(), serde_json::Value::String(content_str));
|
|
496
501
|
}
|
|
497
502
|
}
|
|
498
|
-
|
|
503
|
+
|
|
499
504
|
Some(serde_json::Value::Object(json_msg))
|
|
500
505
|
} else {
|
|
501
506
|
None
|
|
502
507
|
}
|
|
503
508
|
})
|
|
504
509
|
.collect();
|
|
505
|
-
|
|
510
|
+
|
|
506
511
|
let model = match self.model.lock() {
|
|
507
512
|
Ok(guard) => guard,
|
|
508
513
|
Err(poisoned) => poisoned.into_inner(),
|
|
509
514
|
};
|
|
510
515
|
let model_ref = model.borrow();
|
|
511
|
-
|
|
516
|
+
|
|
512
517
|
model_ref.apply_chat_template(&json_messages)
|
|
513
|
-
.map_err(|e| Error::new(
|
|
518
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to apply chat template: {}", e)))
|
|
514
519
|
}
|
|
515
520
|
|
|
516
521
|
/// Get the model ID
|
|
@@ -520,7 +525,8 @@ impl LLM {
|
|
|
520
525
|
|
|
521
526
|
/// Get model options
|
|
522
527
|
pub fn options(&self) -> Result<RHash> {
|
|
523
|
-
let
|
|
528
|
+
let ruby = Ruby::get().unwrap();
|
|
529
|
+
let hash = ruby.hash_new();
|
|
524
530
|
|
|
525
531
|
// Basic metadata
|
|
526
532
|
hash.aset("model_id", self.model_id.clone())?;
|
|
@@ -587,15 +593,19 @@ fn from_pretrained_wrapper(args: &[Value]) -> Result<LLM> {
|
|
|
587
593
|
let device: Device = TryConvert::try_convert(args[1])?;
|
|
588
594
|
LLM::from_pretrained(model_id, Some(device))
|
|
589
595
|
},
|
|
590
|
-
_ =>
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
596
|
+
_ => {
|
|
597
|
+
let ruby = Ruby::get().unwrap();
|
|
598
|
+
Err(Error::new(
|
|
599
|
+
ruby.exception_arg_error(),
|
|
600
|
+
"wrong number of arguments (expected 1..2)"
|
|
601
|
+
))
|
|
602
|
+
}
|
|
594
603
|
}
|
|
595
604
|
}
|
|
596
605
|
|
|
597
606
|
pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
598
|
-
let
|
|
607
|
+
let ruby = Ruby::get().unwrap();
|
|
608
|
+
let rb_generation_config = rb_candle.define_class("GenerationConfig", ruby.class_object())?;
|
|
599
609
|
rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
|
|
600
610
|
rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
|
|
601
611
|
|
|
@@ -613,7 +623,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
|
613
623
|
rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
|
|
614
624
|
rb_generation_config.define_method("options", method!(GenerationConfig::options, 0))?;
|
|
615
625
|
|
|
616
|
-
let rb_llm = rb_candle.define_class("LLM",
|
|
626
|
+
let rb_llm = rb_candle.define_class("LLM", ruby.class_object())?;
|
|
617
627
|
rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
|
|
618
628
|
rb_llm.define_method("_generate", method!(LLM::generate, 2))?;
|
|
619
629
|
rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;
|