red-candle 1.1.0 → 1.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/candle/llm.rb CHANGED
@@ -2,6 +2,35 @@ require 'json'
2
2
 
3
3
  module Candle
4
4
  class LLM
5
# Memoized EOS token for this model, or nil when none is available.
#
# #eos_token is consulted at most once per instance, and only if the object
# responds to it. A nil result or a lookup error is memoized as well: a plain
# `@ivar ||=` would re-run the lookup on every call in that case, which
# defeats the cache. Lookup errors are deliberately treated as "no EOS token"
# (best effort) rather than propagated.
#
# @return [String, nil] the model's EOS token, or nil
def cached_eos_token
  return @cached_eos_token if defined?(@cached_eos_token)

  @cached_eos_token =
    if respond_to?(:eos_token)
      begin
        eos_token
      rescue StandardError
        nil
      end
    end
end
13
+
14
# Stop tokens used to trim generated text (memoized per instance).
#
# If the model reports its own EOS token (via #cached_eos_token), that token
# is used. For models whose name contains "gemma" (case-insensitive),
# "<end_of_turn>" and "</s>" are added too, since Gemma appears to emit them
# as turn/generation boundaries in practice. Only when the model provides no
# EOS token is a generic list of common end-of-sequence markers used.
#
# @return [Array<String>] stop tokens with duplicates removed, in order
def model_eos_tokens
  @model_eos_tokens ||= begin
    eos = cached_eos_token
    candidates =
      if eos
        gemma_extras = model_name.downcase.include?("gemma") ? ["<end_of_turn>", "</s>"] : []
        [eos, *gemma_extras]
      else
        ["</s>", "<|endoftext|>", "<|im_end|>", "<end>"]
      end
    candidates.uniq
  end
end
5
34
  # Create a structured constraint from a JSON schema
6
35
  def constraint_from_schema(schema)
7
36
  schema_str = schema.is_a?(String) ? schema : JSON.generate(schema)
@@ -15,48 +44,39 @@ module Candle
15
44
  end
16
45
 
17
46
  # Generate with regex constraint
18
- def generate_regex(prompt, pattern:, **options)
47
+ def generate_regex(prompt, pattern:, stop_on_match: true, **options)
19
48
  constraint = constraint_from_regex(pattern)
20
49
 
21
- # Add common EOS tokens as stop sequences for regex generation
22
- stop_sequences = options[:stop_sequences] || []
23
- stop_sequences += ["</s>", "<|endoftext|>", "<|im_end|>", "<end>", "\n"] unless options[:no_auto_stop]
24
-
25
- config_opts = options.merge(constraint: constraint, stop_sequences: stop_sequences)
50
+ # Configure generation with early stopping by default
51
+ config_opts = options.merge(
52
+ constraint: constraint,
53
+ stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, stop_on_match),
54
+ stop_on_match: stop_on_match
55
+ )
26
56
  config = options[:config] || GenerationConfig.balanced(**config_opts)
27
57
 
28
- result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
29
-
30
- # Clean up any trailing EOS tokens
31
- result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
58
+ generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
32
59
  end
33
60
 
34
61
  # Generate and parse structured output from a JSON schema
35
62
  def generate_structured(prompt, schema:, **options)
36
63
  constraint = constraint_from_schema(schema)
37
- config_opts = options.merge(constraint: constraint)
64
+
65
+ # Configure generation with early stopping by default
66
+ config_opts = options.merge(
67
+ constraint: constraint,
68
+ stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, true)
69
+ )
38
70
  config = options[:config] || GenerationConfig.balanced(**config_opts)
39
71
 
40
72
  result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
41
73
 
42
- # Clean up the result - remove common end-of-sequence tokens
43
- # that might appear after valid JSON
44
- cleaned_result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '')
45
-
46
74
  # Try to parse as JSON
47
75
  begin
48
- JSON.parse(cleaned_result)
76
+ # First, try to extract JSON if there's content after stop tokens
77
+ json_content = extract_json_content(result)
78
+ JSON.parse(json_content)
49
79
  rescue JSON::ParserError => e
50
- # If cleaning didn't help, try to extract JSON from the result
51
- # Look for the first complete JSON object/array
52
- if match = cleaned_result.match(/(\{[^{}]*\}|\[[^\[\]]*\])/m)
53
- begin
54
- return JSON.parse(match[1])
55
- rescue JSON::ParserError
56
- # Fall through to warning
57
- end
58
- end
59
-
60
80
  # Return the raw string if parsing fails
61
81
  warn "Warning: Generated output is not valid JSON: #{e.message}" if options[:warn_on_parse_error]
62
82
  result
@@ -172,14 +192,7 @@ module Candle
172
192
 
173
193
  def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
174
194
  begin
175
- result = _generate(prompt, config)
176
-
177
- # If there's a constraint, clean up common EOS tokens that appear after the constrained content
178
- if config.constraint
179
- result = result.gsub(/(<\/s>|<\|endoftext\|>|<\|im_end\|>|<end>).*$/m, '').strip
180
- end
181
-
182
- result
195
+ _generate(prompt, config)
183
196
  ensure
184
197
  clear_cache if reset_cache
185
198
  end
@@ -228,6 +241,88 @@ module Candle
228
241
 
229
242
  private
230
243
 
244
# Extract the JSON object/array from generated text.
#
# Stop tokens (#model_eos_tokens) mark where the model's answer ends, so the
# search for the opening "{" or "[" is limited to the text before the
# earliest stop token. The matching close bracket, however, is located by
# scanning the original text. A stop-token literal such as "</s>" may
# legitimately appear inside a JSON string value, and truncating there first
# would cut a valid document short.
#
# @param text [String] raw generated text
# @return [String] the balanced JSON object/array when one is found;
#   otherwise the stop-token-truncated text with surrounding whitespace removed
def extract_json_content(text)
  # Truncate at the earliest occurrence of any stop token. Empty tokens are
  # skipped: String#index("") is 0 and would discard the entire text.
  cut = model_eos_tokens.reject(&:empty?).map { |token| text.index(token) }.compact.min
  cleaned = cut ? text[0...cut] : text

  start_idx = cleaned.index(/[{\[]/)
  return cleaned.strip unless start_idx

  # +cleaned+ is a prefix of +text+, so start_idx is valid in both.
  end_idx = json_structure_end(text, start_idx)
  return text[start_idx..end_idx] if end_idx

  # The structure never balances: fall back to the truncated text.
  cleaned.strip
end

# Index in +text+ of the bracket that closes the object/array opened at
# +start_idx+, or nil if it never balances. Only the bracket pair matching the
# opening character is counted. Brackets inside JSON strings are ignored,
# and a backslash-escaped quote does not end a string.
def json_structure_end(text, start_idx)
  open_char = text[start_idx]
  close_char = open_char == '{' ? '}' : ']'
  depth = 0
  in_string = false
  escape_next = false

  text[start_idx..-1].each_char.with_index do |char, offset|
    if in_string
      in_string = false if char == '"' && !escape_next
    else
      case char
      when open_char
        depth += 1
      when close_char
        depth -= 1
        return start_idx + offset if depth.zero?
      when '"'
        in_string = true unless escape_next
      end
    end
    # A backslash escapes the next char unless it was itself escaped.
    escape_next = !escape_next && char == '\\'
  end
  nil
end
325
+
231
326
  # Legacy format messages method - kept for backward compatibility
232
327
  # Use apply_chat_template for proper model-specific formatting
233
328
  def format_messages(messages)
@@ -1,5 +1,5 @@
1
1
  # :nocov:
2
2
  module Candle
3
- VERSION = "1.1.0"
3
+ VERSION = "1.1.1"
4
4
  end
5
5
  # :nocov:
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-candle
3
3
  version: !ruby/object:Gem::Version
4
- version: 1.1.0
4
+ version: 1.1.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
9
9
  autorequire:
10
10
  bindir: bin
11
11
  cert_chain: []
12
- date: 2025-07-27 00:00:00.000000000 Z
12
+ date: 2025-07-28 00:00:00.000000000 Z
13
13
  dependencies:
14
14
  - !ruby/object:Gem::Dependency
15
15
  name: rb_sys