red-candle 1.1.0 → 1.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +65 -1
- data/Rakefile +40 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +199 -6
- data/ext/candle/src/llm/gemma.rs +21 -5
- data/ext/candle/src/llm/generation_config.rs +6 -0
- data/ext/candle/src/llm/llama.rs +21 -5
- data/ext/candle/src/llm/mistral.rs +21 -5
- data/ext/candle/src/llm/phi.rs +21 -5
- data/ext/candle/src/llm/quantized_gguf.rs +35 -6
- data/ext/candle/src/llm/qwen.rs +21 -5
- data/ext/candle/src/llm/text_generation.rs +121 -28
- data/ext/candle/src/ner.rs +25 -51
- data/ext/candle/src/reranker.rs +41 -68
- data/ext/candle/src/ruby/device.rs +2 -1
- data/ext/candle/src/ruby/dtype.rs +1 -0
- data/ext/candle/src/ruby/errors.rs +1 -0
- data/ext/candle/src/ruby/llm.rs +81 -55
- data/ext/candle/src/ruby/tensor.rs +2 -1
- data/ext/candle/src/tokenizer/mod.rs +2 -1
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/llm.rb +129 -34
- data/lib/candle/version.rb +1 -1
- metadata +4 -2
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -83,6 +83,26 @@ impl ModelType {
     }
 }
 
+// Macro to extract parameters from Ruby hash to reduce boilerplate
+macro_rules! extract_param {
+    // Basic parameter extraction
+    ($kwargs:expr, $config:expr, $param:ident) => {
+        if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                $config.$param = v;
+            }
+        }
+    };
+    // Optional parameter extraction (wraps in Some)
+    ($kwargs:expr, $config:expr, $param:ident, optional) => {
+        if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                $config.$param = Some(v);
+            }
+        }
+    };
+}
+
 #[derive(Clone, Debug)]
 #[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
 pub struct GenerationConfig {
@@ -93,55 +113,20 @@ impl GenerationConfig {
     pub fn new(kwargs: RHash) -> Result<Self> {
         let mut config = RustGenerationConfig::default();
 
-        // Extract …
-        … (removed lines 97–108 not shown in this rendering)
-        if let Some(value) = kwargs.get(magnus::Symbol::new("top_p")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.top_p = Some(v);
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("top_k")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.top_k = Some(v);
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.repetition_penalty = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty_last_n")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.repetition_penalty_last_n = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("seed")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.seed = v;
-            }
-        }
-
-        if let Some(value) = kwargs.get(magnus::Symbol::new("include_prompt")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.include_prompt = v;
-            }
-        }
+        // Extract basic parameters using macro
+        extract_param!(kwargs, config, max_length);
+        extract_param!(kwargs, config, temperature);
+        extract_param!(kwargs, config, top_p, optional);
+        extract_param!(kwargs, config, top_k, optional);
+        extract_param!(kwargs, config, repetition_penalty);
+        extract_param!(kwargs, config, repetition_penalty_last_n);
+        extract_param!(kwargs, config, seed);
+        extract_param!(kwargs, config, include_prompt);
+        extract_param!(kwargs, config, debug_tokens);
+        extract_param!(kwargs, config, stop_on_constraint_satisfaction);
+        extract_param!(kwargs, config, stop_on_match);
 
+        // Handle special cases that need custom logic
         if let Some(value) = kwargs.get(magnus::Symbol::new("stop_sequences")) {
             if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
                 config.stop_sequences = arr
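Taken together with the Ruby-side changes later in this diff, the two new flags flow from keyword arguments straight into RustGenerationConfig. A minimal Ruby sketch of supplying them (hedged: the option values are illustrative; `balanced` is the preset the gem's own llm.rb calls with exactly these keys):

    config = Candle::GenerationConfig.balanced(
      temperature: 0.7,
      top_p: 0.9,                             # optional params become Some(v) on the Rust side
      stop_on_constraint_satisfaction: true,  # new in this release
      stop_on_match: true                     # new in this release
    )
    config.stop_on_match                      # => true, via the reader added below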
@@ -151,13 +136,6 @@ impl GenerationConfig {
             }
         }
 
-        if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
-            if let Ok(v) = TryConvert::try_convert(value) {
-                config.debug_tokens = v;
-            }
-        }
-
-        // Handle constraint parameter
         if let Some(value) = kwargs.get(magnus::Symbol::new("constraint")) {
             if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
                 config.constraint = Some(Arc::clone(&constraint.index));
@@ -209,6 +187,15 @@ impl GenerationConfig {
     pub fn debug_tokens(&self) -> bool {
         self.inner.debug_tokens
     }
+
+    pub fn stop_on_constraint_satisfaction(&self) -> bool {
+        self.inner.stop_on_constraint_satisfaction
+    }
+
+    pub fn stop_on_match(&self) -> bool {
+        self.inner.stop_on_match
+    }
+
     pub fn constraint(&self) -> Option<StructuredConstraint> {
         self.inner.constraint.as_ref().map(|c| StructuredConstraint {
             index: Arc::clone(c),
@@ -372,6 +359,42 @@ impl LLM {
             ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
         }
     }
+
+    /// Get the EOS token string for this model
+    pub fn eos_token(&self) -> Result<String> {
+        let (eos_token_id, tokenizer_clone) = {
+            let model = match self.model.lock() {
+                Ok(guard) => guard,
+                Err(poisoned) => poisoned.into_inner(),
+            };
+            let model_ref = model.borrow();
+
+            // Get both EOS token ID and tokenizer clone in one lock scope
+            let eos_id = match &*model_ref {
+                ModelType::Mistral(m) => m.eos_token_id(),
+                ModelType::Llama(m) => m.eos_token_id(),
+                ModelType::Gemma(m) => m.eos_token_id(),
+                ModelType::Qwen(m) => m.eos_token_id(),
+                ModelType::Phi(m) => m.eos_token_id(),
+                ModelType::QuantizedGGUF(m) => m.eos_token_id(),
+            };
+
+            let tokenizer = match &*model_ref {
+                ModelType::Mistral(m) => m.tokenizer().clone(),
+                ModelType::Llama(m) => m.tokenizer().clone(),
+                ModelType::Gemma(m) => m.tokenizer().clone(),
+                ModelType::Qwen(m) => m.tokenizer().clone(),
+                ModelType::Phi(m) => m.tokenizer().clone(),
+                ModelType::QuantizedGGUF(m) => m.tokenizer().clone(),
+            };
+
+            (eos_id, tokenizer)
+        }; // Lock is released here
+
+        // Convert ID to string using the tokenizer
+        let tokenizer_wrapper = crate::ruby::tokenizer::Tokenizer(tokenizer_clone);
+        tokenizer_wrapper.id_to_token(eos_token_id as i64)
+    }
 
     /// Clear the model's cache (e.g., KV cache for transformers)
     pub fn clear_cache(&self) -> Result<()> {
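Since eos_token is registered on the Ruby class further down in this diff, callers can query a loaded model for its end-of-sequence marker directly. A hedged usage sketch (assuming `llm` is an already-constructed Candle::LLM; the returned string depends on the model, e.g. "</s>" or "<|im_end|>"):

    eos = llm.eos_token      # String form of the model's EOS token id
    puts "generation stops at #{eos.inspect}"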
@@ -460,6 +483,8 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
     rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
     rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
     rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
+    rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
+    rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
     rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
 
     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
@@ -469,6 +494,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
     rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
     rb_llm.define_method("device", method!(LLM::device, 0))?;
     rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
+    rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
     rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
     rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
 
data/ext/candle/tests/device_tests.rs
ADDED
@@ -0,0 +1,43 @@
+use candle_core::Device as CoreDevice;
+
+#[test]
+fn test_device_creation() {
+    // CPU device should always work
+    let cpu = CoreDevice::Cpu;
+    assert!(matches!(cpu, CoreDevice::Cpu));
+
+    // Test device display
+    assert_eq!(format!("{:?}", cpu), "Cpu");
+}
+
+#[cfg(feature = "cuda")]
+#[test]
+#[ignore = "requires CUDA hardware"]
+fn test_cuda_device_creation() {
+    // This might fail if no CUDA device is available
+    match CoreDevice::new_cuda(0) {
+        Ok(device) => assert!(matches!(device, CoreDevice::Cuda(_))),
+        Err(_) => println!("No CUDA device available for testing"),
+    }
+}
+
+#[cfg(feature = "metal")]
+#[test]
+#[ignore = "requires Metal hardware"]
+fn test_metal_device_creation() {
+    // This might fail if no Metal device is available
+    match CoreDevice::new_metal(0) {
+        Ok(device) => assert!(matches!(device, CoreDevice::Metal(_))),
+        Err(_) => println!("No Metal device available for testing"),
+    }
+}
+
+#[test]
+fn test_device_matching() {
+    let cpu1 = CoreDevice::Cpu;
+    let cpu2 = CoreDevice::Cpu;
+
+    // Same device types should match
+    assert!(matches!(cpu1, CoreDevice::Cpu));
+    assert!(matches!(cpu2, CoreDevice::Cpu));
+}
data/ext/candle/tests/tensor_tests.rs
ADDED
@@ -0,0 +1,162 @@
+use candle_core::{Tensor, Device, DType};
+
+#[test]
+fn test_tensor_creation() {
+    let device = Device::Cpu;
+
+    // Test tensor creation from slice
+    let data = vec![1.0f32, 2.0, 3.0, 4.0];
+    let tensor = Tensor::new(&data[..], &device).unwrap();
+    assert_eq!(tensor.dims(), &[4]);
+    assert_eq!(tensor.dtype(), DType::F32);
+
+    // Test zeros
+    let zeros = Tensor::zeros(&[2, 3], DType::F32, &device).unwrap();
+    assert_eq!(zeros.dims(), &[2, 3]);
+
+    // Test ones
+    let ones = Tensor::ones(&[3, 2], DType::F32, &device).unwrap();
+    assert_eq!(ones.dims(), &[3, 2]);
+}
+
+#[test]
+fn test_tensor_arithmetic() {
+    let device = Device::Cpu;
+
+    let a = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
+    let b = Tensor::new(&[4.0f32, 5.0, 6.0], &device).unwrap();
+
+    // Addition
+    let sum = a.add(&b).unwrap();
+    let sum_vec: Vec<f32> = sum.to_vec1().unwrap();
+    assert_eq!(sum_vec, vec![5.0, 7.0, 9.0]);
+
+    // Subtraction
+    let diff = a.sub(&b).unwrap();
+    let diff_vec: Vec<f32> = diff.to_vec1().unwrap();
+    assert_eq!(diff_vec, vec![-3.0, -3.0, -3.0]);
+
+    // Multiplication
+    let prod = a.mul(&b).unwrap();
+    let prod_vec: Vec<f32> = prod.to_vec1().unwrap();
+    assert_eq!(prod_vec, vec![4.0, 10.0, 18.0]);
+}
+
+#[test]
+fn test_tensor_reshape() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device).unwrap();
+
+    // Reshape to 2x3
+    let reshaped = tensor.reshape(&[2, 3]).unwrap();
+    assert_eq!(reshaped.dims(), &[2, 3]);
+
+    // Reshape to 3x2
+    let reshaped = tensor.reshape(&[3, 2]).unwrap();
+    assert_eq!(reshaped.dims(), &[3, 2]);
+}
+
+#[test]
+fn test_tensor_transpose() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)
+        .unwrap()
+        .reshape(&[2, 2])
+        .unwrap();
+
+    let transposed = tensor.transpose(0, 1).unwrap();
+    assert_eq!(transposed.dims(), &[2, 2]);
+
+    let values: Vec<f32> = transposed.flatten_all().unwrap().to_vec1().unwrap();
+    assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
+}
+
+#[test]
+fn test_tensor_reduction() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device).unwrap();
+
+    // Sum
+    let sum = tensor.sum_all().unwrap();
+    let sum_val: f32 = sum.to_scalar().unwrap();
+    assert_eq!(sum_val, 10.0);
+
+    // Mean
+    let mean = tensor.mean_all().unwrap();
+    let mean_val: f32 = mean.to_scalar().unwrap();
+    assert_eq!(mean_val, 2.5);
+}
+
+#[test]
+fn test_tensor_indexing() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[10.0f32, 20.0, 30.0, 40.0], &device).unwrap();
+
+    // Get element at index 0
+    let elem = tensor.get(0).unwrap();
+    let val: f32 = elem.to_scalar().unwrap();
+    assert_eq!(val, 10.0);
+
+    // Get element at index 2
+    let elem = tensor.get(2).unwrap();
+    let val: f32 = elem.to_scalar().unwrap();
+    assert_eq!(val, 30.0);
+}
+
+#[test]
+fn test_tensor_matmul() {
+    let device = Device::Cpu;
+
+    // 2x3 matrix
+    let a = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device)
+        .unwrap()
+        .reshape(&[2, 3])
+        .unwrap();
+
+    // 3x2 matrix
+    let b = Tensor::new(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0], &device)
+        .unwrap()
+        .reshape(&[3, 2])
+        .unwrap();
+
+    // Matrix multiplication
+    let result = a.matmul(&b).unwrap();
+    assert_eq!(result.dims(), &[2, 2]);
+
+    let values: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
+    // [1*7 + 2*9 + 3*11, 1*8 + 2*10 + 3*12, 4*7 + 5*9 + 6*11, 4*8 + 5*10 + 6*12]
+    // = [58, 64, 139, 154]
+    assert_eq!(values, vec![58.0, 64.0, 139.0, 154.0]);
+}
+
+#[test]
+fn test_tensor_where() {
+    let device = Device::Cpu;
+
+    // Create a condition tensor where values > 0 are treated as true
+    let cond_values = Tensor::new(&[1.0f32, 0.0, 1.0], &device).unwrap();
+    let cond = cond_values.gt(&Tensor::zeros(cond_values.shape(), DType::F32, &device).unwrap()).unwrap();
+
+    let on_true = Tensor::new(&[10.0f32, 20.0, 30.0], &device).unwrap();
+    let on_false = Tensor::new(&[100.0f32, 200.0, 300.0], &device).unwrap();
+
+    let result = cond.where_cond(&on_true, &on_false).unwrap();
+    let values: Vec<f32> = result.to_vec1().unwrap();
+    assert_eq!(values, vec![10.0, 200.0, 30.0]);
+}
+
+#[test]
+fn test_tensor_narrow() {
+    let device = Device::Cpu;
+
+    let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &device).unwrap();
+
+    // Narrow from index 1, length 3
+    let narrowed = tensor.narrow(0, 1, 3).unwrap();
+    let values: Vec<f32> = narrowed.to_vec1().unwrap();
+    assert_eq!(values, vec![2.0, 3.0, 4.0]);
+}
data/lib/candle/llm.rb
CHANGED
@@ -2,6 +2,35 @@ require 'json'
 
 module Candle
   class LLM
+    # Cache for EOS token to avoid repeated calls
+    def cached_eos_token
+      @cached_eos_token ||= begin
+        if respond_to?(:eos_token)
+          eos_token rescue nil
+        end
+      end
+    end
+
+    # Get model-specific EOS tokens
+    def model_eos_tokens
+      @model_eos_tokens ||= begin
+        tokens = []
+        if model_eos = cached_eos_token
+          tokens << model_eos
+          # For Gemma, also include end_of_turn for chat scenarios and </s>
+          # Even though </s> is technically an HTML tag in Gemma's vocabulary,
+          # it seems to use it as a generation boundary in practice
+          if model_name.downcase.include?("gemma")
+            tokens << "<end_of_turn>"
+            tokens << "</s>"
+          end
+        else
+          # Fallback to common tokens only if model doesn't provide one
+          tokens = ["</s>", "<|endoftext|>", "<|im_end|>", "<end>"]
+        end
+        tokens.uniq
+      end
+    end
     # Create a structured constraint from a JSON schema
     def constraint_from_schema(schema)
       schema_str = schema.is_a?(String) ? schema : JSON.generate(schema)
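These two helpers memoize the model-reported EOS token and add Gemma-specific extras, which the extract_json_content helper added later in this file uses for truncation. A hedged sketch of their behavior, assuming `llm` is a loaded Candle::LLM:

    llm.cached_eos_token   # model-reported EOS string, or nil if the binding has no #eos_token
    llm.model_eos_tokens   # [eos] for most models, [eos, "<end_of_turn>", "</s>"] for Gemma,
                           # or the generic fallback list when no EOS is reported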
@@ -15,48 +44,39 @@ module Candle
     end
 
     # Generate with regex constraint
-    def generate_regex(prompt, pattern:, **options)
+    def generate_regex(prompt, pattern:, stop_on_match: true, **options)
       constraint = constraint_from_regex(pattern)
 
-      # …
-      … (removed lines 22–25 not shown in this rendering)
+      # Configure generation with early stopping by default
+      config_opts = options.merge(
+        constraint: constraint,
+        stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, stop_on_match),
+        stop_on_match: stop_on_match
+      )
       config = options[:config] || GenerationConfig.balanced(**config_opts)
 
-      … (removed lines 28–29 not shown in this rendering)
-      # Clean up any trailing EOS tokens
-      result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
+      generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
     end
 
     # Generate and parse structured output from a JSON schema
     def generate_structured(prompt, schema:, **options)
       constraint = constraint_from_schema(schema)
-
+
+      # Configure generation with early stopping by default
+      config_opts = options.merge(
+        constraint: constraint,
+        stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, true)
+      )
       config = options[:config] || GenerationConfig.balanced(**config_opts)
 
       result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
 
-      # Clean up the result - remove common end-of-sequence tokens
-      # that might appear after valid JSON
-      cleaned_result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '')
-
       # Try to parse as JSON
       begin
-        JSON …
+        # First, try to extract JSON if there's content after stop tokens
+        json_content = extract_json_content(result)
+        JSON.parse(json_content)
       rescue JSON::ParserError => e
-        # If cleaning didn't help, try to extract JSON from the result
-        # Look for the first complete JSON object/array
-        if match = cleaned_result.match(/(\{[^{}]*\}|\[[^\[\]]*\])/m)
-          begin
-            return JSON.parse(match[1])
-          rescue JSON::ParserError
-            # Fall through to warning
-          end
-        end
-
         # Return the raw string if parsing fails
         warn "Warning: Generated output is not valid JSON: #{e.message}" if options[:warn_on_parse_error]
         result
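With stop_on_match defaulting to true, constrained generation now stops at the constraint boundary instead of relying on the old gsub cleanup. A hedged usage sketch (prompt, pattern, and schema values are illustrative; `llm` is an already-loaded Candle::LLM):

    # Regex-constrained generation; pass stop_on_match: false to keep generating past the first match.
    phone = llm.generate_regex("Call me at ", pattern: '\d{3}-\d{4}')

    # Schema-constrained generation; returns parsed JSON when possible, the raw string otherwise.
    person = llm.generate_structured(
      "Describe a person as JSON: ",
      schema: { type: "object", properties: { name: { type: "string" }, age: { type: "integer" } } },
      warn_on_parse_error: true
    )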
@@ -172,14 +192,7 @@ module Candle
 
     def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
       begin
-        … (removed lines 175–176 not shown in this rendering)
-        # If there's a constraint, clean up common EOS tokens that appear after the constrained content
-        if config.constraint
-          result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
-        end
-
-        result
+        _generate(prompt, config)
       ensure
         clear_cache if reset_cache
       end
@@ -228,6 +241,88 @@ module Candle
 
     private
 
+    # Extract JSON content from generated text, handling stop tokens and extra content
+    def extract_json_content(text)
+      # Remove any content after common stop tokens
+      cleaned = text
+
+      # Check for EOS tokens and truncate at the first one found
+      model_eos_tokens.each do |token|
+        if idx = cleaned.index(token)
+          cleaned = cleaned[0...idx]
+        end
+      end
+
+      # Try to find valid JSON boundaries
+      # First try a simple approach - find the first { or [ and match to its closing } or ]
+      start_idx = cleaned.index(/[\{\[]/)
+      return cleaned.strip unless start_idx
+
+      # Extract from the start position
+      json_candidate = cleaned[start_idx..-1]
+
+      # Try to find a valid JSON object or array
+      # This regex handles nested structures better
+      if json_candidate[0] == '{'
+        # Match a JSON object
+        bracket_count = 0
+        in_string = false
+        escape_next = false
+
+        json_candidate.chars.each_with_index do |char, idx|
+          if !in_string
+            case char
+            when '{'
+              bracket_count += 1
+            when '}'
+              bracket_count -= 1
+              if bracket_count == 0
+                return json_candidate[0..idx]
+              end
+            when '"'
+              in_string = true unless escape_next
+            end
+          else
+            if char == '"' && !escape_next
+              in_string = false
+            end
+          end
+
+          escape_next = (!escape_next && char == '\\')
+        end
+      elsif json_candidate[0] == '['
+        # Match a JSON array (similar logic)
+        bracket_count = 0
+        in_string = false
+        escape_next = false
+
+        json_candidate.chars.each_with_index do |char, idx|
+          if !in_string
+            case char
+            when '['
+              bracket_count += 1
+            when ']'
+              bracket_count -= 1
+              if bracket_count == 0
+                return json_candidate[0..idx]
+              end
+            when '"'
+              in_string = true unless escape_next
+            end
+          else
+            if char == '"' && !escape_next
+              in_string = false
+            end
+          end
+
+          escape_next = (!escape_next && char == '\\')
+        end
+      end
+
+      # If no valid JSON structure found, return the cleaned string
+      cleaned.strip
+    end
+
     # Legacy format messages method - kept for backward compatibility
     # Use apply_chat_template for proper model-specific formatting
     def format_messages(messages)
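The new private helper trims the output at the first known EOS token and then walks the text with a small brace-and-string state machine, so nested structures survive where the old single-level regex did not. A hedged illustration (private method, called via send purely for demonstration; `llm` is a loaded Candle::LLM and the input string is invented):

    raw = '{"name": "Ada", "tags": ["math", "code"]}</s> trailing chatter'
    llm.send(:extract_json_content, raw)
    # => '{"name": "Ada", "tags": ["math", "code"]}'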
data/lib/candle/version.rb
CHANGED
metadata
CHANGED
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: red-candle
 version: !ruby/object:Gem::Version
-  version: 1.1.0
+  version: 1.1.2
 platform: ruby
 authors:
 - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2025-…
+date: 2025-08-06 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rb_sys
@@ -196,6 +196,8 @@ files:
 - ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs
 - ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs
 - ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs
+- ext/candle/tests/device_tests.rs
+- ext/candle/tests/tensor_tests.rs
 - lib/candle.rb
 - lib/candle/build_info.rb
 - lib/candle/device_utils.rb
|