red-candle 1.1.0 → 1.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -83,6 +83,26 @@ impl ModelType {
83
83
  }
84
84
  }
85
85
 
86
+ // Macro to extract parameters from Ruby hash to reduce boilerplate
87
+ macro_rules! extract_param {
88
+ // Basic parameter extraction
89
+ ($kwargs:expr, $config:expr, $param:ident) => {
90
+ if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
91
+ if let Ok(v) = TryConvert::try_convert(value) {
92
+ $config.$param = v;
93
+ }
94
+ }
95
+ };
96
+ // Optional parameter extraction (wraps in Some)
97
+ ($kwargs:expr, $config:expr, $param:ident, optional) => {
98
+ if let Some(value) = $kwargs.get(magnus::Symbol::new(stringify!($param))) {
99
+ if let Ok(v) = TryConvert::try_convert(value) {
100
+ $config.$param = Some(v);
101
+ }
102
+ }
103
+ };
104
+ }
105
+
86
106
  #[derive(Clone, Debug)]
87
107
  #[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
88
108
  pub struct GenerationConfig {
@@ -93,55 +113,20 @@ impl GenerationConfig {
93
113
  pub fn new(kwargs: RHash) -> Result<Self> {
94
114
  let mut config = RustGenerationConfig::default();
95
115
 
96
- // Extract values from kwargs manually
97
- if let Some(value) = kwargs.get(magnus::Symbol::new("max_length")) {
98
- if let Ok(v) = TryConvert::try_convert(value) {
99
- config.max_length = v;
100
- }
101
- }
102
-
103
- if let Some(value) = kwargs.get(magnus::Symbol::new("temperature")) {
104
- if let Ok(v) = TryConvert::try_convert(value) {
105
- config.temperature = v;
106
- }
107
- }
108
-
109
- if let Some(value) = kwargs.get(magnus::Symbol::new("top_p")) {
110
- if let Ok(v) = TryConvert::try_convert(value) {
111
- config.top_p = Some(v);
112
- }
113
- }
114
-
115
- if let Some(value) = kwargs.get(magnus::Symbol::new("top_k")) {
116
- if let Ok(v) = TryConvert::try_convert(value) {
117
- config.top_k = Some(v);
118
- }
119
- }
120
-
121
- if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty")) {
122
- if let Ok(v) = TryConvert::try_convert(value) {
123
- config.repetition_penalty = v;
124
- }
125
- }
126
-
127
- if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty_last_n")) {
128
- if let Ok(v) = TryConvert::try_convert(value) {
129
- config.repetition_penalty_last_n = v;
130
- }
131
- }
132
-
133
- if let Some(value) = kwargs.get(magnus::Symbol::new("seed")) {
134
- if let Ok(v) = TryConvert::try_convert(value) {
135
- config.seed = v;
136
- }
137
- }
138
-
139
- if let Some(value) = kwargs.get(magnus::Symbol::new("include_prompt")) {
140
- if let Ok(v) = TryConvert::try_convert(value) {
141
- config.include_prompt = v;
142
- }
143
- }
116
+ // Extract basic parameters using macro
117
+ extract_param!(kwargs, config, max_length);
118
+ extract_param!(kwargs, config, temperature);
119
+ extract_param!(kwargs, config, top_p, optional);
120
+ extract_param!(kwargs, config, top_k, optional);
121
+ extract_param!(kwargs, config, repetition_penalty);
122
+ extract_param!(kwargs, config, repetition_penalty_last_n);
123
+ extract_param!(kwargs, config, seed);
124
+ extract_param!(kwargs, config, include_prompt);
125
+ extract_param!(kwargs, config, debug_tokens);
126
+ extract_param!(kwargs, config, stop_on_constraint_satisfaction);
127
+ extract_param!(kwargs, config, stop_on_match);
144
128
 
129
+ // Handle special cases that need custom logic
145
130
  if let Some(value) = kwargs.get(magnus::Symbol::new("stop_sequences")) {
146
131
  if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
147
132
  config.stop_sequences = arr
@@ -151,13 +136,6 @@ impl GenerationConfig {
151
136
  }
152
137
  }
153
138
 
154
- if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
155
- if let Ok(v) = TryConvert::try_convert(value) {
156
- config.debug_tokens = v;
157
- }
158
- }
159
-
160
- // Handle constraint parameter
161
139
  if let Some(value) = kwargs.get(magnus::Symbol::new("constraint")) {
162
140
  if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
163
141
  config.constraint = Some(Arc::clone(&constraint.index));
@@ -209,6 +187,15 @@ impl GenerationConfig {
209
187
  pub fn debug_tokens(&self) -> bool {
210
188
  self.inner.debug_tokens
211
189
  }
190
+
191
+ pub fn stop_on_constraint_satisfaction(&self) -> bool {
192
+ self.inner.stop_on_constraint_satisfaction
193
+ }
194
+
195
+ pub fn stop_on_match(&self) -> bool {
196
+ self.inner.stop_on_match
197
+ }
198
+
212
199
  pub fn constraint(&self) -> Option<StructuredConstraint> {
213
200
  self.inner.constraint.as_ref().map(|c| StructuredConstraint {
214
201
  index: Arc::clone(c),
@@ -372,6 +359,42 @@ impl LLM {
372
359
  ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
373
360
  }
374
361
  }
362
+
363
+ /// Get the EOS token string for this model
364
+ pub fn eos_token(&self) -> Result<String> {
365
+ let (eos_token_id, tokenizer_clone) = {
366
+ let model = match self.model.lock() {
367
+ Ok(guard) => guard,
368
+ Err(poisoned) => poisoned.into_inner(),
369
+ };
370
+ let model_ref = model.borrow();
371
+
372
+ // Get both EOS token ID and tokenizer clone in one lock scope
373
+ let eos_id = match &*model_ref {
374
+ ModelType::Mistral(m) => m.eos_token_id(),
375
+ ModelType::Llama(m) => m.eos_token_id(),
376
+ ModelType::Gemma(m) => m.eos_token_id(),
377
+ ModelType::Qwen(m) => m.eos_token_id(),
378
+ ModelType::Phi(m) => m.eos_token_id(),
379
+ ModelType::QuantizedGGUF(m) => m.eos_token_id(),
380
+ };
381
+
382
+ let tokenizer = match &*model_ref {
383
+ ModelType::Mistral(m) => m.tokenizer().clone(),
384
+ ModelType::Llama(m) => m.tokenizer().clone(),
385
+ ModelType::Gemma(m) => m.tokenizer().clone(),
386
+ ModelType::Qwen(m) => m.tokenizer().clone(),
387
+ ModelType::Phi(m) => m.tokenizer().clone(),
388
+ ModelType::QuantizedGGUF(m) => m.tokenizer().clone(),
389
+ };
390
+
391
+ (eos_id, tokenizer)
392
+ }; // Lock is released here
393
+
394
+ // Convert ID to string using the tokenizer
395
+ let tokenizer_wrapper = crate::ruby::tokenizer::Tokenizer(tokenizer_clone);
396
+ tokenizer_wrapper.id_to_token(eos_token_id as i64)
397
+ }
375
398
 
376
399
  /// Clear the model's cache (e.g., KV cache for transformers)
377
400
  pub fn clear_cache(&self) -> Result<()> {
@@ -460,6 +483,8 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
460
483
  rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
461
484
  rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
462
485
  rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
486
+ rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
487
+ rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
463
488
  rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
464
489
 
465
490
  let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
@@ -469,6 +494,7 @@ pub fn init_llm(rb_candle: RModule) -> Result<()> {
469
494
  rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
470
495
  rb_llm.define_method("device", method!(LLM::device, 0))?;
471
496
  rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
497
+ rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
472
498
  rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
473
499
  rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
474
500
 
@@ -651,4 +651,5 @@ pub fn init(rb_candle: RModule) -> Result<()> {
651
651
  rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
652
652
  rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
653
653
  Ok(())
654
- }
654
+ }
655
+
@@ -100,4 +100,5 @@ impl TokenizerWrapper {
100
100
  pub fn inner_mut(&mut self) -> &mut Tokenizer {
101
101
  &mut self.tokenizer
102
102
  }
103
- }
103
+ }
104
+
@@ -0,0 +1,43 @@
1
+ use candle_core::Device as CoreDevice;
2
+
3
+ #[test]
4
+ fn test_device_creation() {
5
+ // CPU device should always work
6
+ let cpu = CoreDevice::Cpu;
7
+ assert!(matches!(cpu, CoreDevice::Cpu));
8
+
9
+ // Test device display
10
+ assert_eq!(format!("{:?}", cpu), "Cpu");
11
+ }
12
+
13
+ #[cfg(feature = "cuda")]
14
+ #[test]
15
+ #[ignore = "requires CUDA hardware"]
16
+ fn test_cuda_device_creation() {
17
+ // This might fail if no CUDA device is available
18
+ match CoreDevice::new_cuda(0) {
19
+ Ok(device) => assert!(matches!(device, CoreDevice::Cuda(_))),
20
+ Err(_) => println!("No CUDA device available for testing"),
21
+ }
22
+ }
23
+
24
+ #[cfg(feature = "metal")]
25
+ #[test]
26
+ #[ignore = "requires Metal hardware"]
27
+ fn test_metal_device_creation() {
28
+ // This might fail if no Metal device is available
29
+ match CoreDevice::new_metal(0) {
30
+ Ok(device) => assert!(matches!(device, CoreDevice::Metal(_))),
31
+ Err(_) => println!("No Metal device available for testing"),
32
+ }
33
+ }
34
+
35
+ #[test]
36
+ fn test_device_matching() {
37
+ let cpu1 = CoreDevice::Cpu;
38
+ let cpu2 = CoreDevice::Cpu;
39
+
40
+ // Same device types should match
41
+ assert!(matches!(cpu1, CoreDevice::Cpu));
42
+ assert!(matches!(cpu2, CoreDevice::Cpu));
43
+ }
@@ -0,0 +1,162 @@
1
+ use candle_core::{Tensor, Device, DType};
2
+
3
+ #[test]
4
+ fn test_tensor_creation() {
5
+ let device = Device::Cpu;
6
+
7
+ // Test tensor creation from slice
8
+ let data = vec![1.0f32, 2.0, 3.0, 4.0];
9
+ let tensor = Tensor::new(&data[..], &device).unwrap();
10
+ assert_eq!(tensor.dims(), &[4]);
11
+ assert_eq!(tensor.dtype(), DType::F32);
12
+
13
+ // Test zeros
14
+ let zeros = Tensor::zeros(&[2, 3], DType::F32, &device).unwrap();
15
+ assert_eq!(zeros.dims(), &[2, 3]);
16
+
17
+ // Test ones
18
+ let ones = Tensor::ones(&[3, 2], DType::F32, &device).unwrap();
19
+ assert_eq!(ones.dims(), &[3, 2]);
20
+ }
21
+
22
+ #[test]
23
+ fn test_tensor_arithmetic() {
24
+ let device = Device::Cpu;
25
+
26
+ let a = Tensor::new(&[1.0f32, 2.0, 3.0], &device).unwrap();
27
+ let b = Tensor::new(&[4.0f32, 5.0, 6.0], &device).unwrap();
28
+
29
+ // Addition
30
+ let sum = a.add(&b).unwrap();
31
+ let sum_vec: Vec<f32> = sum.to_vec1().unwrap();
32
+ assert_eq!(sum_vec, vec![5.0, 7.0, 9.0]);
33
+
34
+ // Subtraction
35
+ let diff = a.sub(&b).unwrap();
36
+ let diff_vec: Vec<f32> = diff.to_vec1().unwrap();
37
+ assert_eq!(diff_vec, vec![-3.0, -3.0, -3.0]);
38
+
39
+ // Multiplication
40
+ let prod = a.mul(&b).unwrap();
41
+ let prod_vec: Vec<f32> = prod.to_vec1().unwrap();
42
+ assert_eq!(prod_vec, vec![4.0, 10.0, 18.0]);
43
+ }
44
+
45
+ #[test]
46
+ fn test_tensor_reshape() {
47
+ let device = Device::Cpu;
48
+
49
+ let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device).unwrap();
50
+
51
+ // Reshape to 2x3
52
+ let reshaped = tensor.reshape(&[2, 3]).unwrap();
53
+ assert_eq!(reshaped.dims(), &[2, 3]);
54
+
55
+ // Reshape to 3x2
56
+ let reshaped = tensor.reshape(&[3, 2]).unwrap();
57
+ assert_eq!(reshaped.dims(), &[3, 2]);
58
+ }
59
+
60
+ #[test]
61
+ fn test_tensor_transpose() {
62
+ let device = Device::Cpu;
63
+
64
+ let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)
65
+ .unwrap()
66
+ .reshape(&[2, 2])
67
+ .unwrap();
68
+
69
+ let transposed = tensor.transpose(0, 1).unwrap();
70
+ assert_eq!(transposed.dims(), &[2, 2]);
71
+
72
+ let values: Vec<f32> = transposed.flatten_all().unwrap().to_vec1().unwrap();
73
+ assert_eq!(values, vec![1.0, 3.0, 2.0, 4.0]);
74
+ }
75
+
76
+ #[test]
77
+ fn test_tensor_reduction() {
78
+ let device = Device::Cpu;
79
+
80
+ let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device).unwrap();
81
+
82
+ // Sum
83
+ let sum = tensor.sum_all().unwrap();
84
+ let sum_val: f32 = sum.to_scalar().unwrap();
85
+ assert_eq!(sum_val, 10.0);
86
+
87
+ // Mean
88
+ let mean = tensor.mean_all().unwrap();
89
+ let mean_val: f32 = mean.to_scalar().unwrap();
90
+ assert_eq!(mean_val, 2.5);
91
+ }
92
+
93
+ #[test]
94
+ fn test_tensor_indexing() {
95
+ let device = Device::Cpu;
96
+
97
+ let tensor = Tensor::new(&[10.0f32, 20.0, 30.0, 40.0], &device).unwrap();
98
+
99
+ // Get element at index 0
100
+ let elem = tensor.get(0).unwrap();
101
+ let val: f32 = elem.to_scalar().unwrap();
102
+ assert_eq!(val, 10.0);
103
+
104
+ // Get element at index 2
105
+ let elem = tensor.get(2).unwrap();
106
+ let val: f32 = elem.to_scalar().unwrap();
107
+ assert_eq!(val, 30.0);
108
+ }
109
+
110
+ #[test]
111
+ fn test_tensor_matmul() {
112
+ let device = Device::Cpu;
113
+
114
+ // 2x3 matrix
115
+ let a = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &device)
116
+ .unwrap()
117
+ .reshape(&[2, 3])
118
+ .unwrap();
119
+
120
+ // 3x2 matrix
121
+ let b = Tensor::new(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0], &device)
122
+ .unwrap()
123
+ .reshape(&[3, 2])
124
+ .unwrap();
125
+
126
+ // Matrix multiplication
127
+ let result = a.matmul(&b).unwrap();
128
+ assert_eq!(result.dims(), &[2, 2]);
129
+
130
+ let values: Vec<f32> = result.flatten_all().unwrap().to_vec1().unwrap();
131
+ // [1*7 + 2*9 + 3*11, 1*8 + 2*10 + 3*12, 4*7 + 5*9 + 6*11, 4*8 + 5*10 + 6*12]
132
+ // = [58, 64, 139, 154]
133
+ assert_eq!(values, vec![58.0, 64.0, 139.0, 154.0]);
134
+ }
135
+
136
+ #[test]
137
+ fn test_tensor_where() {
138
+ let device = Device::Cpu;
139
+
140
+ // Create a condition tensor where values > 0 are treated as true
141
+ let cond_values = Tensor::new(&[1.0f32, 0.0, 1.0], &device).unwrap();
142
+ let cond = cond_values.gt(&Tensor::zeros(cond_values.shape(), DType::F32, &device).unwrap()).unwrap();
143
+
144
+ let on_true = Tensor::new(&[10.0f32, 20.0, 30.0], &device).unwrap();
145
+ let on_false = Tensor::new(&[100.0f32, 200.0, 300.0], &device).unwrap();
146
+
147
+ let result = cond.where_cond(&on_true, &on_false).unwrap();
148
+ let values: Vec<f32> = result.to_vec1().unwrap();
149
+ assert_eq!(values, vec![10.0, 200.0, 30.0]);
150
+ }
151
+
152
+ #[test]
153
+ fn test_tensor_narrow() {
154
+ let device = Device::Cpu;
155
+
156
+ let tensor = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &device).unwrap();
157
+
158
+ // Narrow from index 1, length 3
159
+ let narrowed = tensor.narrow(0, 1, 3).unwrap();
160
+ let values: Vec<f32> = narrowed.to_vec1().unwrap();
161
+ assert_eq!(values, vec![2.0, 3.0, 4.0]);
162
+ }
data/lib/candle/llm.rb CHANGED
@@ -2,6 +2,35 @@ require 'json'
2
2
 
3
3
  module Candle
4
4
  class LLM
5
+ # Cache for EOS token to avoid repeated calls
6
+ def cached_eos_token
7
+ @cached_eos_token ||= begin
8
+ if respond_to?(:eos_token)
9
+ eos_token rescue nil
10
+ end
11
+ end
12
+ end
13
+
14
+ # Get model-specific EOS tokens
15
+ def model_eos_tokens
16
+ @model_eos_tokens ||= begin
17
+ tokens = []
18
+ if model_eos = cached_eos_token
19
+ tokens << model_eos
20
+ # For Gemma, also include end_of_turn for chat scenarios and </s>
21
+ # Even though </s> is technically an HTML tag in Gemma's vocabulary,
22
+ # it seems to use it as a generation boundary in practice
23
+ if model_name.downcase.include?("gemma")
24
+ tokens << "<end_of_turn>"
25
+ tokens << "</s>"
26
+ end
27
+ else
28
+ # Fallback to common tokens only if model doesn't provide one
29
+ tokens = ["</s>", "<|endoftext|>", "<|im_end|>", "<end>"]
30
+ end
31
+ tokens.uniq
32
+ end
33
+ end
5
34
  # Create a structured constraint from a JSON schema
6
35
  def constraint_from_schema(schema)
7
36
  schema_str = schema.is_a?(String) ? schema : JSON.generate(schema)
@@ -15,48 +44,39 @@ module Candle
15
44
  end
16
45
 
17
46
  # Generate with regex constraint
18
- def generate_regex(prompt, pattern:, **options)
47
+ def generate_regex(prompt, pattern:, stop_on_match: true, **options)
19
48
  constraint = constraint_from_regex(pattern)
20
49
 
21
- # Add common EOS tokens as stop sequences for regex generation
22
- stop_sequences = options[:stop_sequences] || []
23
- stop_sequences += ["</s>", "<|endoftext|>", "<|im_end|>", "<end>", "\n"] unless options[:no_auto_stop]
24
-
25
- config_opts = options.merge(constraint: constraint, stop_sequences: stop_sequences)
50
+ # Configure generation with early stopping by default
51
+ config_opts = options.merge(
52
+ constraint: constraint,
53
+ stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, stop_on_match),
54
+ stop_on_match: stop_on_match
55
+ )
26
56
  config = options[:config] || GenerationConfig.balanced(**config_opts)
27
57
 
28
- result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
29
-
30
- # Clean up any trailing EOS tokens
31
- result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
58
+ generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
32
59
  end
33
60
 
34
61
  # Generate and parse structured output from a JSON schema
35
62
  def generate_structured(prompt, schema:, **options)
36
63
  constraint = constraint_from_schema(schema)
37
- config_opts = options.merge(constraint: constraint)
64
+
65
+ # Configure generation with early stopping by default
66
+ config_opts = options.merge(
67
+ constraint: constraint,
68
+ stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, true)
69
+ )
38
70
  config = options[:config] || GenerationConfig.balanced(**config_opts)
39
71
 
40
72
  result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
41
73
 
42
- # Clean up the result - remove common end-of-sequence tokens
43
- # that might appear after valid JSON
44
- cleaned_result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '')
45
-
46
74
  # Try to parse as JSON
47
75
  begin
48
- JSON.parse(cleaned_result)
76
+ # First, try to extract JSON if there's content after stop tokens
77
+ json_content = extract_json_content(result)
78
+ JSON.parse(json_content)
49
79
  rescue JSON::ParserError => e
50
- # If cleaning didn't help, try to extract JSON from the result
51
- # Look for the first complete JSON object/array
52
- if match = cleaned_result.match(/(\{[^{}]*\}|\[[^\[\]]*\])/m)
53
- begin
54
- return JSON.parse(match[1])
55
- rescue JSON::ParserError
56
- # Fall through to warning
57
- end
58
- end
59
-
60
80
  # Return the raw string if parsing fails
61
81
  warn "Warning: Generated output is not valid JSON: #{e.message}" if options[:warn_on_parse_error]
62
82
  result
@@ -172,14 +192,7 @@ module Candle
172
192
 
173
193
  def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
174
194
  begin
175
- result = _generate(prompt, config)
176
-
177
- # If there's a constraint, clean up common EOS tokens that appear after the constrained content
178
- if config.constraint
179
- result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
180
- end
181
-
182
- result
195
+ _generate(prompt, config)
183
196
  ensure
184
197
  clear_cache if reset_cache
185
198
  end
@@ -228,6 +241,88 @@ module Candle
228
241
 
229
242
  private
230
243
 
244
+ # Extract JSON content from generated text, handling stop tokens and extra content
245
+ def extract_json_content(text)
246
+ # Remove any content after common stop tokens
247
+ cleaned = text
248
+
249
+ # Check for EOS tokens and truncate at the first one found
250
+ model_eos_tokens.each do |token|
251
+ if idx = cleaned.index(token)
252
+ cleaned = cleaned[0...idx]
253
+ end
254
+ end
255
+
256
+ # Try to find valid JSON boundaries
257
+ # First try a simple approach - find the first { or [ and match to its closing } or ]
258
+ start_idx = cleaned.index(/[\{\[]/)
259
+ return cleaned.strip unless start_idx
260
+
261
+ # Extract from the start position
262
+ json_candidate = cleaned[start_idx..-1]
263
+
264
+ # Try to find a valid JSON object or array
265
+ # This regex handles nested structures better
266
+ if json_candidate[0] == '{'
267
+ # Match a JSON object
268
+ bracket_count = 0
269
+ in_string = false
270
+ escape_next = false
271
+
272
+ json_candidate.chars.each_with_index do |char, idx|
273
+ if !in_string
274
+ case char
275
+ when '{'
276
+ bracket_count += 1
277
+ when '}'
278
+ bracket_count -= 1
279
+ if bracket_count == 0
280
+ return json_candidate[0..idx]
281
+ end
282
+ when '"'
283
+ in_string = true unless escape_next
284
+ end
285
+ else
286
+ if char == '"' && !escape_next
287
+ in_string = false
288
+ end
289
+ end
290
+
291
+ escape_next = (!escape_next && char == '\\')
292
+ end
293
+ elsif json_candidate[0] == '['
294
+ # Match a JSON array (similar logic)
295
+ bracket_count = 0
296
+ in_string = false
297
+ escape_next = false
298
+
299
+ json_candidate.chars.each_with_index do |char, idx|
300
+ if !in_string
301
+ case char
302
+ when '['
303
+ bracket_count += 1
304
+ when ']'
305
+ bracket_count -= 1
306
+ if bracket_count == 0
307
+ return json_candidate[0..idx]
308
+ end
309
+ when '"'
310
+ in_string = true unless escape_next
311
+ end
312
+ else
313
+ if char == '"' && !escape_next
314
+ in_string = false
315
+ end
316
+ end
317
+
318
+ escape_next = (!escape_next && char == '\\')
319
+ end
320
+ end
321
+
322
+ # If no valid JSON structure found, return the cleaned string
323
+ cleaned.strip
324
+ end
325
+
231
326
  # Legacy format messages method - kept for backward compatibility
232
327
  # Use apply_chat_template for proper model-specific formatting
233
328
  def format_messages(messages)
@@ -1,5 +1,5 @@
1
1
  # :nocov:
2
2
  module Candle
3
- VERSION = "1.1.0"
3
+ VERSION = "1.1.2"
4
4
  end
5
5
  # :nocov:
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-candle
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.0
4
+ version: 1.1.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
9
9
  autorequire:
10
10
  bindir: bin
11
11
  cert_chain: []
12
- date: 2025-07-27 00:00:00.000000000 Z
12
+ date: 2025-08-06 00:00:00.000000000 Z
13
13
  dependencies:
14
14
  - !ruby/object:Gem::Dependency
15
15
  name: rb_sys
@@ -196,6 +196,8 @@ files:
196
196
  - ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs
197
197
  - ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs
198
198
  - ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs
199
+ - ext/candle/tests/device_tests.rs
200
+ - ext/candle/tests/tensor_tests.rs
199
201
  - lib/candle.rb
200
202
  - lib/candle/build_info.rb
201
203
  - lib/candle/device_utils.rb