red-candle 1.8.0.pre2-x86_64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5193 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +33 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
data/lib/candle/llm.rb ADDED
@@ -0,0 +1,595 @@
1
+ require 'json'
2
+
3
+ module Candle
4
+ class LLM
5
# Memoized lookup of the model's EOS token.
#
# Returns whatever `eos_token` returns, or nil when the receiver does not
# respond to `eos_token` or when that call raises a StandardError.
#
# The outcome is computed once per instance. A nil result is cached as well,
# so a model without a usable EOS token is not re-queried on every call
# (a plain `||=` would retry each time because nil is falsy).
def cached_eos_token
  return @cached_eos_token if defined?(@cached_eos_token)

  @cached_eos_token =
    if respond_to?(:eos_token)
      begin
        eos_token
      rescue StandardError
        # Best effort: a failing lookup is treated as "no EOS token".
        nil
      end
    end
end
13
+
14
# List of strings treated as end-of-sequence markers for this model (memoized).
#
# * When the model reports an EOS token, that token is used; for models whose
#   name contains "gemma" the chat markers "<end_of_turn>" and "</s>" are
#   appended as well.
# * When no EOS token is available, a fixed list of common markers is used.
#
# Duplicates are removed before the list is cached.
def model_eos_tokens
  @model_eos_tokens ||= begin
    primary = cached_eos_token
    candidates =
      if !primary
        # Fallback to common tokens only if the model doesn't provide one
        ["</s>", "<|endoftext|>", "<|im_end|>", "<end>"]
      elsif model_name.downcase.include?("gemma")
        [primary, "<end_of_turn>", "</s>"]
      else
        [primary]
      end
    candidates.uniq
  end
end
34
# Build a StructuredConstraint from a JSON schema.
#
# schema - a JSON String, or any object JSON.generate can serialize.
#
# The constraint is built from the id returned by `tokenizer_source_model`
# when one is available. The tokenizer-based builder is used instead when no
# id can be resolved, or when the model-based builder raises a RuntimeError
# whose message mentions "UnsupportedTokenizer". Any other RuntimeError is
# re-raised.
def constraint_from_schema(schema)
  schema_json = schema.is_a?(String) ? schema : JSON.generate(schema)
  source_id = tokenizer_source_model

  # No resolvable model id: go straight to the tokenizer-based builder.
  return StructuredConstraint.from_schema(schema_json, tokenizer) unless source_id

  begin
    StructuredConstraint.from_schema_with_model(schema_json, source_id)
  rescue RuntimeError => error
    raise unless error.message.include?("UnsupportedTokenizer")

    StructuredConstraint.from_schema(schema_json, tokenizer)
  end
end
58
+
59
# Build a StructuredConstraint from a regular expression.
#
# pattern - a Regexp (its `source` is used) or anything responding to `to_s`.
#
# The constraint is built from the id returned by `tokenizer_source_model`
# when one is available. The tokenizer-based builder is used instead when no
# id can be resolved, or when the model-based builder raises a RuntimeError
# whose message mentions "UnsupportedTokenizer". Any other RuntimeError is
# re-raised.
def constraint_from_regex(pattern)
  regex_source = pattern.is_a?(Regexp) ? pattern.source : pattern.to_s
  source_id = tokenizer_source_model

  # No resolvable model id: go straight to the tokenizer-based builder.
  return StructuredConstraint.from_regex(regex_source, tokenizer) unless source_id

  begin
    StructuredConstraint.from_regex_with_model(regex_source, source_id)
  rescue RuntimeError => error
    raise unless error.message.include?("UnsupportedTokenizer")

    StructuredConstraint.from_regex(regex_source, tokenizer)
  end
end
82
+
83
+ private
84
+
85
# Resolve the model id whose vocabulary should be loaded for constraints.
#
# Resolution order:
#   1. options["tokenizer_source"]
#   2. options["base_model"]
#   3. options["model_id"], or `model_id` when that call succeeds, with any
#      "@..." suffix removed; if the remaining id mentions "gguf", the class
#      level `guess_tokenizer` is consulted and its answer is used when it
#      differs from the id.
#
# Returns nil when no model id can be determined. A failing `options` call is
# treated as an empty options hash.
def tokenizer_source_model
  settings = begin
    options
  rescue StandardError
    {}
  end

  explicit = settings["tokenizer_source"] || settings["base_model"]
  return explicit if explicit

  repo = settings["model_id"] || begin
    model_id
  rescue StandardError
    nil
  end
  return nil unless repo

  # Drop the "@<file>" part used to address a specific GGUF file.
  repo = repo.split("@").first if repo.include?("@")
  return repo unless repo.downcase.include?("gguf")

  candidate = self.class.guess_tokenizer(repo)
  (candidate && candidate != repo) ? candidate : repo
end
117
+
118
+ public
119
+
120
# Generate text constrained to match a regular expression.
#
# prompt        - prompt passed to `generate`.
# pattern:      - Regexp or String handed to `constraint_from_regex`.
# stop_on_match: - default for :stop_on_constraint_satisfaction and also
#                 passed through as :stop_on_match (default: true).
# **options     - remaining keys are forwarded to GenerationConfig.balanced,
#                 except the call-control keys :config and :reset_cache,
#                 which are consumed here and are not generation settings.
#
# NOTE(review): when `config:` is supplied it is used as-is, so the constraint
# built from `pattern` is not attached to it by this method.
#
# Returns the result of `generate`.
def generate_regex(prompt, pattern:, stop_on_match: true, **options)
  constraint = constraint_from_regex(pattern)

  # Keep call-control keys out of the generation config.
  sampling_opts = options.reject { |key, _| %i[config reset_cache].include?(key) }

  # Configure generation with early stopping by default
  config_opts = sampling_opts.merge(
    constraint: constraint,
    stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, stop_on_match),
    stop_on_match: stop_on_match
  )
  config = options[:config] || GenerationConfig.balanced(**config_opts)

  generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))
end
134
+
135
# Generate output constrained by a JSON schema and try to parse it.
#
# prompt    - prompt passed to `generate`.
# schema:   - schema handed to `constraint_from_schema`.
# **options - remaining keys are forwarded to GenerationConfig.balanced,
#             except the call-control keys :config, :reset_cache and
#             :warn_on_parse_error, which are consumed here and are not
#             generation settings.
#
# NOTE(review): when `config:` is supplied it is used as-is, so the constraint
# built from `schema` is not attached to it by this method.
#
# Returns the parsed JSON value, or the raw generated String when the
# extracted content is not valid JSON (a warning is logged in that case only
# if :warn_on_parse_error is truthy).
def generate_structured(prompt, schema:, **options)
  constraint = constraint_from_schema(schema)

  # Keep call-control keys out of the generation config.
  sampling_opts = options.reject do |key, _|
    %i[config reset_cache warn_on_parse_error].include?(key)
  end

  # Configure generation with early stopping by default
  config_opts = sampling_opts.merge(
    constraint: constraint,
    stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, true)
  )
  config = options[:config] || GenerationConfig.balanced(**config_opts)

  result = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))

  begin
    # Strip stop tokens / trailing text before handing the text to the parser.
    JSON.parse(extract_json_content(result))
  rescue JSON::ParserError => e
    # Return the raw string if parsing fails
    Candle.logger.warn "Generated output is not valid JSON: #{e.message}" if options[:warn_on_parse_error]
    result
  end
end
159
# Maps GGUF repositories to the repository that provides their tokenizer.
#
# String keys are exact repository ids. The :patterns key holds an ordered
# list of [Regexp, tokenizer_id] pairs used as fallbacks; the first pattern
# that matches wins, so more specific patterns must come before generic ones.
#
# The hash is deliberately left unfrozen: register_tokenizer adds entries to
# it at runtime.
TOKENIZER_REGISTRY = {
  # --- Exact repository ids -------------------------------------------------
  "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" => "mistralai/Mistral-7B-Instruct-v0.2",
  "TheBloke/Mistral-7B-v0.1-GGUF"          => "mistralai/Mistral-7B-v0.1",
  "TheBloke/Llama-2-7B-Chat-GGUF"          => "meta-llama/Llama-2-7b-chat-hf",
  "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",

  # Qwen official GGUF repositories
  "Qwen/Qwen3-8B-GGUF"  => "Qwen/Qwen3-8B",
  "Qwen/Qwen3-4B-GGUF"  => "Qwen/Qwen3-4B",
  "Qwen/Qwen3-14B-GGUF" => "Qwen/Qwen3-14B",
  "Qwen/Qwen3-32B-GGUF" => "Qwen/Qwen3-32B",
  "Qwen/Qwen3-72B-GGUF" => "Qwen/Qwen3-72B",

  # Phi GGUF repositories
  "TheBloke/phi-2-GGUF"                  => "microsoft/phi-2",
  "microsoft/phi-4-gguf"                 => "microsoft/phi-4",
  "bartowski/Phi-3.5-mini-instruct-GGUF" => "microsoft/Phi-3.5-mini-instruct",

  # --- Pattern fallbacks (checked top to bottom) ----------------------------
  patterns: [
    # Mistral
    [/mistral.*?7b.*?instruct.*?v0\.2/i, "mistralai/Mistral-7B-Instruct-v0.2"],
    [/mistral.*?7b.*?instruct.*?v0\.1/i, "mistralai/Mistral-7B-Instruct-v0.1"],
    [/mistral.*?7b/i,                    "mistralai/Mistral-7B-v0.1"],

    # Llama
    [/llama.*?3.*?8b/i,         "meta-llama/Meta-Llama-3-8B"],
    [/llama.*?3.*?70b/i,        "meta-llama/Meta-Llama-3-70B"],
    [/llama.*?2.*?7b.*?chat/i,  "meta-llama/Llama-2-7b-chat-hf"],
    [/llama.*?2.*?13b.*?chat/i, "meta-llama/Llama-2-13b-chat-hf"],
    [/llama.*?2.*?70b.*?chat/i, "meta-llama/Llama-2-70b-chat-hf"],
    [/tinyllama/i,              "TinyLlama/TinyLlama-1.1B-Chat-v1.0"],

    # Gemma
    [/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
    [/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
    [/gemma.*?7b/i,     "google/gemma-7b"],
    [/gemma.*?2b/i,     "google/gemma-2b"],

    # Qwen
    [/qwen.*?3.*?72b/i,   "Qwen/Qwen3-72B"],
    [/qwen.*?3.*?32b/i,   "Qwen/Qwen3-32B"],
    [/qwen.*?3.*?14b/i,   "Qwen/Qwen3-14B"],
    [/qwen.*?3.*?8b/i,    "Qwen/Qwen3-8B"],
    [/qwen.*?3.*?4b/i,    "Qwen/Qwen3-4B"],
    [/qwen.*?3.*?1\.8b/i, "Qwen/Qwen3-1.8B"],
    [/qwen.*?3.*?0\.5b/i, "Qwen/Qwen3-0.5B"],
    [/qwen.*?2\.5/i,      "Qwen/Qwen2.5-0.5B"],
    [/qwen.*?2/i,         "Qwen/Qwen2-1.5B"],
    [/qwen/i,             "Qwen/Qwen-1_8B"],

    # Phi (order matters - more specific patterns first)
    [/phi.*?3\.5.*?mini/i,   "microsoft/Phi-3.5-mini-instruct"],
    [/phi.*?3.*?mini.*?4k/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi.*?3.*?medium/i,    "microsoft/Phi-3-medium-4k-instruct"],
    [/phi.*?3.*?small/i,     "microsoft/Phi-3-small-8k-instruct"],
    [/phi.*?3.*?mini/i,      "microsoft/Phi-3-mini-4k-instruct"],
    [/phi.*?3/i,             "microsoft/Phi-3-mini-4k-instruct"],
    [/phi-4/i,               "microsoft/phi-4"],
    [/phi.*?2/i,             "microsoft/phi-2"],
    [/phi.*?1\.5/i,          "microsoft/phi-1_5"],
    [/phi/i,                 "microsoft/phi-2"]
  ]
}
225
+
226
# Register a custom model-to-tokenizer mapping in TOKENIZER_REGISTRY.
#
# model_pattern - String: stored as an exact-match entry.
#                 Regexp: put at the front of the :patterns list, so it is
#                 checked before the existing patterns.
# tokenizer_id  - tokenizer source to associate with the pattern.
#
# Raises ArgumentError for any other kind of model_pattern.
def self.register_tokenizer(model_pattern, tokenizer_id)
  case model_pattern
  when String
    TOKENIZER_REGISTRY[model_pattern] = tokenizer_id
  when Regexp
    (TOKENIZER_REGISTRY[:patterns] ||= []).unshift([model_pattern, tokenizer_id])
  else
    raise ArgumentError, "model_pattern must be a String or Regexp"
  end
end
237
+
238
# Guess which repository provides the tokenizer for `model_id`.
#
# Lookup order:
#   1. an exact-match entry in TOKENIZER_REGISTRY;
#   2. the first [Regexp, tokenizer_id] pair in TOKENIZER_REGISTRY[:patterns]
#      whose pattern matches `model_id`;
#   3. `model_id` with "-gguf" occurrences and a trailing "-q<digits>_<word>"
#      quantization suffix removed (case-insensitive).
def self.guess_tokenizer(model_id)
  exact = TOKENIZER_REGISTRY[model_id]
  return exact if exact

  matched = (TOKENIZER_REGISTRY[:patterns] || []).find { |regex, _| model_id.match?(regex) }
  return matched[1] if matched

  # Default: try removing common GGUF suffixes
  model_id.gsub(/-gguf|-q\d+_\w+$/i, "")
end
254
+
255
# Chat interface: renders `messages` with `apply_chat_template` and runs one
# generation on the resulting prompt. Keyword options are forwarded to
# `generate` unchanged, and the value returned by `generate` is returned.
def chat(messages, **options)
  generate(apply_chat_template(messages), **options)
end
260
+
261
# Streaming chat interface: renders `messages` with `apply_chat_template` and
# forwards the prompt, the keyword options and the given block to
# `generate_stream`.
def chat_stream(messages, **options, &block)
  generate_stream(apply_chat_template(messages), **options, &block)
end
266
+
267
# Chat with tool calling - always returns a ToolCallResult.
#
# messages  - chat messages; a tool-instruction system prompt is injected
#             before they are sent to `chat`.
# tools:    - tool objects; each must respond to `name`, and to `call` when
#             execution is requested.
# execute:  - when true, every parsed tool call is run (default: false).
# **options - forwarded to `chat`.
#
# With `execute: true`, `tool_results` holds one hash per tool call:
#   { tool_call:, result:, error: }
# `error` is "Unknown tool: <name>" when no tool has that name, or the message
# of a StandardError raised by the tool. Non-StandardError exceptions
# (Interrupt, SystemExit, NoMemoryError, ...) are deliberately NOT captured,
# so Ctrl-C and process shutdown are not swallowed.
#
# When no tool calls were parsed, `text_response` is the raw model response.
def chat_with_tools(messages, tools:, execute: false, **options)
  tool_prompt = build_tool_system_prompt(tools)
  augmented = inject_tool_instructions(messages, tool_prompt)

  raw_response = chat(augmented, **options)
  result = ToolCallParser.parse(raw_response, available_tools: tools)
  has_calls = result.has_tool_calls?

  tool_results =
    if has_calls && execute
      result.tool_calls.map do |tool_call|
        tool = tools.find { |t| t.name == tool_call.name }
        if tool
          begin
            { tool_call: tool_call, result: tool.call(tool_call.arguments), error: nil }
          rescue StandardError => e
            { tool_call: tool_call, result: nil, error: e.message }
          end
        else
          { tool_call: tool_call, result: nil, error: "Unknown tool: #{tool_call.name}" }
        end
      end
    else
      []
    end

  ToolCallResult.new(
    tool_calls: result.tool_calls,
    tool_results: tool_results,
    text_response: has_calls ? result.text_response : raw_response,
    raw_response: raw_response
  )
end
307
+
308
# Compact one-line description for debugging, e.g.
#   #<Candle::LLM model=... gguf=... device=... type=... arch=...>
#
# Values come from `options` (an empty hash is used when that call fails).
# The model segment uses options["base_model"], then options["model_id"],
# then `model_id` if the receiver responds to it. The gguf and arch segments
# only appear when the corresponding option is set. The device falls back to
# `self.device.to_s`, and to "unknown" if that lookup raises.
def inspect
  opts = begin
    options
  rescue StandardError
    {}
  end

  model_type = opts["model_type"] || "Unknown"
  device = begin
    opts["device"] || self.device.to_s
  rescue StandardError
    "unknown"
  end

  segments = ["#<Candle::LLM"]

  named_model = opts["base_model"] || opts["model_id"]
  if named_model
    segments << "model=#{named_model}"
  elsif respond_to?(:model_id)
    segments << "model=#{model_id}"
  end

  segments << "gguf=#{opts["gguf_file"]}" if opts["gguf_file"]
  segments << "device=#{device}"
  segments << "type=#{model_type}"
  segments << "arch=#{opts["architecture"]}" if opts["architecture"]

  "#{segments.join(" ")}>"
end
346
+
347
# Run a generation for `prompt` via `_generate`.
#
# config:      - generation settings (default: GenerationConfig.balanced).
# reset_cache: - when truthy, `clear_cache` is called after the generation,
#                including when `_generate` raises (default: true).
#
# Returns the value returned by `_generate`.
def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
  _generate(prompt, config)
ensure
  clear_cache if reset_cache
end
354
+
355
# Run a streaming generation for `prompt` via `_generate_stream`, forwarding
# the given block to it.
#
# config:      - generation settings (default: GenerationConfig.balanced).
# reset_cache: - when truthy, `clear_cache` is called after the generation,
#                including when `_generate_stream` raises (default: true).
#
# Returns the value returned by `_generate_stream`.
def generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block)
  _generate_stream(prompt, config, &block)
ensure
  clear_cache if reset_cache
end
362
+
363
# Load a model through `_from_pretrained`.
#
# model_id   - model repository id.
# device:    - device to load on (default: Candle::Device.best).
# gguf_file: - optional GGUF file name; appended to the id as "@<file>".
# tokenizer: - optional tokenizer source; appended as "@@<tokenizer>".
#
# For ids containing "gguf" with no tokenizer given, the model is first
# loaded as-is. If that fails with a message containing "No tokenizer found",
# a tokenizer is guessed with `guess_tokenizer`, logged, and the load is
# retried with it. Any other error is re-raised.
def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
  model_str = gguf_file ? "#{model_id}@#{gguf_file}" : model_id
  autodetect = model_str.downcase.include?("gguf") && tokenizer.nil?

  unless autodetect
    # Either the user named a tokenizer, or none is needed
    # (non-GGUF model, which ships its own).
    model_str = "#{model_str}@@#{tokenizer}" if tokenizer
    return _from_pretrained(model_str, device)
  end

  begin
    # Try to load without tokenizer first
    _from_pretrained(model_str, device)
  rescue => error
    raise error unless error.message.include?("No tokenizer found")

    detected_tokenizer = guess_tokenizer(model_id)
    Candle.logger.info "No tokenizer found in GGUF repo. Using tokenizer from: #{detected_tokenizer}"
    _from_pretrained("#{model_str}@@#{detected_tokenizer}", device)
  end
end
395
+
396
+ private
397
+
398
# Build the system prompt that advertises `tools` to the model.
#
# Each tool's `to_tool_definition` is serialized as JSON; the definitions are
# separated by blank lines and followed by instructions describing the
# <tool_call> ... </tool_call> response format.
#
# Returns the prompt String (no trailing newline).
def build_tool_system_prompt(tools)
  tool_defs = tools.map { |tool| JSON.generate(tool.to_tool_definition) }.join("\n\n")

  <<~PROMPT.chomp
    You are a helpful assistant with access to the following tools:

    #{tool_defs}

    When you need to use a tool, respond with a tool call in the following format:
    <tool_call>
    {"name": "tool_name", "arguments": {"arg1": "value1"}}
    </tool_call>

    If you don't need to use a tool, respond normally with text.
  PROMPT
end
408
+
409
# Return a copy of `messages` that carries the tool instructions.
#
# Each message is shallow-copied, so the caller's array and hashes are left
# untouched. If the first message has role "system" (looked up with the
# :role key), `tool_prompt` is prepended to its content, separated by a blank
# line; otherwise a new system message holding `tool_prompt` is put first.
def inject_tool_instructions(messages, tool_prompt)
  copies = messages.map(&:dup)
  head = copies.first

  if head && head[:role] == "system"
    head[:content] = "#{tool_prompt}\n\n#{head[:content]}"
  else
    copies.unshift({ role: "system", content: tool_prompt })
  end

  copies
end
418
+
419
# Extract the JSON portion of generated text.
#
# Steps:
#   1. Cut the text at the earliest occurrence of any of `model_eos_tokens`.
#   2. Locate the first "{" or "[".
#   3. Return the shortest prefix starting there whose outer brackets are
#      balanced (brackets inside JSON strings are ignored).
#
# Returns the stripped, EOS-truncated text when no opening bracket exists or
# the structure never closes. The result is not validated as JSON; parsing is
# left to the caller.
def extract_json_content(text)
  cleaned = text

  # Truncating for every token in turn leaves the text before the earliest one.
  model_eos_tokens.each do |token|
    cut_at = cleaned.index(token)
    cleaned = cleaned[0...cut_at] if cut_at
  end

  start_idx = cleaned.index(/[\{\[]/)
  return cleaned.strip unless start_idx

  json_candidate = cleaned[start_idx..-1]
  opener = json_candidate[0]
  closer = opener == '{' ? '}' : ']'

  balanced_json_prefix(json_candidate, opener, closer) || cleaned.strip
end

# Scan `candidate` (which starts with `opener`) and return the prefix ending at
# the `closer` that balances the first `opener`, or nil if it never closes.
# Only the given bracket pair is counted; quotes and backslash escapes are
# tracked so brackets inside string literals do not affect the depth.
def balanced_json_prefix(candidate, opener, closer)
  depth = 0
  in_string = false
  escape_next = false

  candidate.each_char.with_index do |char, idx|
    if in_string
      in_string = false if char == '"' && !escape_next
    elsif char == opener
      depth += 1
    elsif char == closer
      depth -= 1
      return candidate[0..idx] if depth == 0
    elsif char == '"'
      in_string = true unless escape_next
    end

    # A backslash escapes the next character unless it is itself escaped.
    escape_next = (!escape_next && char == '\\')
  end

  nil
end
500
+
501
# Legacy plain-text chat formatting, kept for backward compatibility.
# Use apply_chat_template for model-specific formatting.
#
# Messages with role "system", "user" or "assistant" are rendered as
# "<Role>: <content>"; any other role contributes its content unlabelled.
# Entries are separated by blank lines, and a trailing "Assistant:" cue is
# appended for the model to continue from.
def format_messages(messages)
  lines = messages.map do |msg|
    label =
      case msg[:role]
      when "system" then "System"
      when "user" then "User"
      when "assistant" then "Assistant"
      end

    label ? "#{label}: #{msg[:content]}" : msg[:content]
  end

  "#{lines.join("\n\n")}\n\nAssistant:"
end
520
+ end
521
+
522
class GenerationConfig
  # Build a new config from this one with `overrides` applied.
  #
  # Only the fields listed below (plus @constraint, when that instance
  # variable is set) are carried over; fields whose current value is nil are
  # omitted so the new instance falls back to its own defaults for them.
  def with(**overrides)
    carried = %i[
      max_length temperature top_p top_k repetition_penalty
      seed stop_sequences include_prompt
    ].to_h { |field| [field, __send__(field)] }
    carried[:constraint] = @constraint if defined?(@constraint)

    self.class.new(carried.compact.merge(overrides))
  end

  # Deterministic preset: temperature 0, top_k 1, no top_p, seed 42.
  # Any key in `opts` overrides the preset value.
  def self.deterministic(**opts)
    new({ temperature: 0.0, top_p: nil, top_k: 1, seed: 42, **opts })
  end

  # Creative preset: temperature 1.0, top_p 0.95, top_k 50 and a repetition
  # penalty of 1.2; no seed is set here. Any key in `opts` overrides the
  # preset value.
  def self.creative(**opts)
    new({ temperature: 1.0, top_p: 0.95, top_k: 50, repetition_penalty: 1.2, **opts })
  end

  # Balanced preset: temperature 0.7, top_p 0.9, top_k 40; no seed is set
  # here. Any key in `opts` overrides the preset value.
  def self.balanced(**opts)
    new({ temperature: 0.7, top_p: 0.9, top_k: 40, **opts })
  end

  # Compact one-line description for debugging, built from `options`
  # (an empty hash is used when that call fails). Only parameters and flags
  # with a truthy value are shown.
  def inspect
    opts = begin
      options
    rescue StandardError
      {}
    end

    parts = ["#<Candle::GenerationConfig"]

    # Key configuration parameters, as label => option key
    {
      "temp" => "temperature",
      "max" => "max_length",
      "top_p" => "top_p",
      "top_k" => "top_k",
      "seed" => "seed"
    }.each do |label, key|
      value = opts[key]
      parts << "#{label}=#{value}" if value
    end

    # Flags, as label => option key
    flags = {
      "debug" => "debug_tokens",
      "constraint" => "has_constraint",
      "stop_on_match" => "stop_on_match"
    }.select { |_label, key| opts[key] }.keys
    parts << "flags=[#{flags.join(",")}]" unless flags.empty?

    "#{parts.join(" ")}>"
  end
end
595
+ end