red-candle 1.8.0.pre3-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
data/lib/candle/llm.rb
ADDED
|
@@ -0,0 +1,595 @@
|
|
|
1
|
+
require 'json'
|
|
2
|
+
|
|
3
|
+
module Candle
|
|
4
|
+
class LLM
|
|
5
|
+
# Cache for EOS token to avoid repeated calls
|
|
6
|
+
# Returns the model's EOS token, performing the lookup at most once per
# instance.
#
# The outcome is memoized even when it is nil (the receiver has no
# `eos_token` method, or the call raised), so a model without an EOS token
# does not repeat the failing lookup on every call.
#
# Returns whatever `eos_token` returns, or nil when it is unavailable.
def cached_eos_token
  # `||=` would recompute whenever the cached value is nil, so check whether
  # the instance variable has been assigned at all.
  return @cached_eos_token if defined?(@cached_eos_token)

  @cached_eos_token =
    if respond_to?(:eos_token)
      begin
        eos_token
      rescue StandardError
        # Best effort: an unreadable EOS token is treated as "no EOS token".
        nil
      end
    end
end
|
|
13
|
+
|
|
14
|
+
# Get model-specific EOS tokens
|
|
15
|
+
# Returns the (memoized) list of strings treated as end-of-sequence markers.
#
# When the model reports an EOS token, that token is used; models whose name
# contains "gemma" additionally get "<end_of_turn>" and "</s>". When the model
# reports none, a fixed list of common EOS strings is used instead.
def model_eos_tokens
  @model_eos_tokens ||= begin
    primary = cached_eos_token
    candidates =
      if !primary
        # Fallback to common tokens only if the model doesn't provide one.
        ["</s>", "<|endoftext|>", "<|im_end|>", "<end>"]
      elsif model_name.downcase.include?("gemma")
        # Gemma also stops on end_of_turn (chat) and, in practice, on </s>.
        [primary, "<end_of_turn>", "</s>"]
      else
        [primary]
      end
    candidates.uniq
  end
end
|
|
34
|
+
# Create a structured constraint from a JSON schema
|
|
35
|
+
# Uses the model's vocabulary with proper byte encoding handling
|
|
36
|
+
# Builds a StructuredConstraint from a JSON schema.
#
# schema - a JSON String, or any object JSON.generate can serialize.
#
# The vocabulary is loaded via the tokenizer source model when one can be
# determined; otherwise, or when that path raises a RuntimeError mentioning
# "UnsupportedTokenizer", the legacy tokenizer-based constructor is used.
# Any other RuntimeError is re-raised.
def constraint_from_schema(schema)
  schema_json = schema.is_a?(String) ? schema : JSON.generate(schema)
  source_model = tokenizer_source_model

  # Without a known source model only the legacy path is available.
  return StructuredConstraint.from_schema(schema_json, tokenizer) unless source_model

  begin
    StructuredConstraint.from_schema_with_model(schema_json, source_model)
  rescue RuntimeError => error
    raise unless error.message.include?("UnsupportedTokenizer")

    StructuredConstraint.from_schema(schema_json, tokenizer)
  end
end
|
|
58
|
+
|
|
59
|
+
# Create a structured constraint from a regex pattern
|
|
60
|
+
# Uses the model's vocabulary with proper byte encoding handling
|
|
61
|
+
# Builds a StructuredConstraint from a regular expression.
#
# pattern - a Regexp (its source is used) or anything responding to to_s.
#
# The vocabulary is loaded via the tokenizer source model when one can be
# determined; otherwise, or when that path raises a RuntimeError mentioning
# "UnsupportedTokenizer", the legacy tokenizer-based constructor is used.
# Any other RuntimeError is re-raised.
def constraint_from_regex(pattern)
  regex_source = pattern.is_a?(Regexp) ? pattern.source : pattern.to_s
  source_model = tokenizer_source_model

  # Without a known source model only the legacy path is available.
  return StructuredConstraint.from_regex(regex_source, tokenizer) unless source_model

  begin
    StructuredConstraint.from_regex_with_model(regex_source, source_model)
  rescue RuntimeError => error
    raise unless error.message.include?("UnsupportedTokenizer")

    StructuredConstraint.from_regex(regex_source, tokenizer)
  end
end
|
|
82
|
+
|
|
83
|
+
private
|
|
84
|
+
|
|
85
|
+
# Get the model ID to use for vocabulary loading
|
|
86
|
+
# This handles GGUF models by extracting the tokenizer source
|
|
87
|
+
# Determines which model ID should be used to load the vocabulary.
#
# Resolution order: options["tokenizer_source"], options["base_model"], then
# options["model_id"] (or `model_id`) with any "@..." suffix removed. For IDs
# containing "gguf", the class-level tokenizer guess is preferred when it
# differs from the ID itself.
#
# Returns the chosen ID, or nil when no model ID can be found.
def tokenizer_source_model
  settings = begin
    options
  rescue StandardError
    {}
  end

  explicit = settings["tokenizer_source"] || settings["base_model"]
  return explicit if explicit

  repo = settings["model_id"]
  repo ||= begin
    model_id
  rescue StandardError
    nil
  end
  return nil unless repo

  # Drop the "@<gguf file>" part, keeping only the repository ID.
  repo = repo.split("@").first if repo.include?("@")

  if repo.downcase.include?("gguf")
    candidate = self.class.guess_tokenizer(repo)
    return candidate if candidate && candidate != repo
  end

  repo
end
|
|
117
|
+
|
|
118
|
+
public
|
|
119
|
+
|
|
120
|
+
# Generate with regex constraint
|
|
121
|
+
# Generates text constrained by a regular expression.
#
# prompt        - the prompt passed to `generate`.
# pattern       - Regexp or String handed to `constraint_from_regex`.
# stop_on_match - default for :stop_on_constraint_satisfaction and the value
#                 of :stop_on_match in the generated config.
# options       - :config (used as-is when given), :reset_cache (default
#                 true); all options are also forwarded to
#                 GenerationConfig.balanced when no :config is supplied.
def generate_regex(prompt, pattern:, stop_on_match: true, **options)
  regex_constraint = constraint_from_regex(pattern)

  config = options[:config] || begin
    overrides = {
      constraint: regex_constraint,
      stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, stop_on_match),
      stop_on_match: stop_on_match
    }
    GenerationConfig.balanced(**options.merge(overrides))
  end

  clear_after = options.fetch(:reset_cache, true)
  generate(prompt, config: config, reset_cache: clear_after)
end
|
|
134
|
+
|
|
135
|
+
# Generate and parse structured output from a JSON schema
|
|
136
|
+
# Generates output constrained by a JSON schema and tries to parse it.
#
# prompt  - the prompt passed to `generate`.
# schema  - schema handed to `constraint_from_schema`.
# options - :config (used as-is when given), :reset_cache (default true),
#           :warn_on_parse_error (log a warning when parsing fails); all
#           options are also forwarded to GenerationConfig.balanced when no
#           :config is supplied.
#
# Returns the parsed JSON value, or the raw generated output when the
# extracted content is not valid JSON.
def generate_structured(prompt, schema:, **options)
  schema_constraint = constraint_from_schema(schema)

  config = options[:config] || GenerationConfig.balanced(
    **options.merge(
      constraint: schema_constraint,
      stop_on_constraint_satisfaction: options.fetch(:stop_on_constraint_satisfaction, true)
    )
  )

  raw_output = generate(prompt, config: config, reset_cache: options.fetch(:reset_cache, true))

  begin
    # Strip stop tokens / trailing text before handing the payload to the parser.
    JSON.parse(extract_json_content(raw_output))
  rescue JSON::ParserError => parse_error
    if options[:warn_on_parse_error]
      Candle.logger.warn "Generated output is not valid JSON: #{parse_error.message}"
    end
    raw_output
  end
end
|
|
159
|
+
# Tokenizer registry for automatic detection
|
|
160
|
+
# Maps GGUF model repositories to the repository that provides their tokenizer.
#
# String keys are exact model IDs. The :patterns entry is an ordered list of
# [Regexp, tokenizer_id] pairs; the first matching pattern wins, so a pattern
# must be listed before any broader pattern that would also match the same ID.
#
# The hash is intentionally left unfrozen: register_tokenizer mutates it.
TOKENIZER_REGISTRY = {
  # Exact model matches
  "TheBloke/Mistral-7B-Instruct-v0.2-GGUF" => "mistralai/Mistral-7B-Instruct-v0.2",
  "TheBloke/Mistral-7B-v0.1-GGUF" => "mistralai/Mistral-7B-v0.1",
  "TheBloke/Llama-2-7B-Chat-GGUF" => "meta-llama/Llama-2-7b-chat-hf",
  "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",

  # Qwen official GGUF models
  "Qwen/Qwen3-8B-GGUF" => "Qwen/Qwen3-8B",
  "Qwen/Qwen3-4B-GGUF" => "Qwen/Qwen3-4B",
  "Qwen/Qwen3-14B-GGUF" => "Qwen/Qwen3-14B",
  "Qwen/Qwen3-32B-GGUF" => "Qwen/Qwen3-32B",
  "Qwen/Qwen3-72B-GGUF" => "Qwen/Qwen3-72B",

  # Phi GGUF models
  "TheBloke/phi-2-GGUF" => "microsoft/phi-2",
  "microsoft/phi-4-gguf" => "microsoft/phi-4",
  "bartowski/Phi-3.5-mini-instruct-GGUF" => "microsoft/Phi-3.5-mini-instruct",

  # Pattern-based fallbacks (evaluated in order)
  :patterns => [
    # Mistral models
    [/mistral.*?7b.*?instruct.*?v0\.2/i, "mistralai/Mistral-7B-Instruct-v0.2"],
    [/mistral.*?7b.*?instruct.*?v0\.1/i, "mistralai/Mistral-7B-Instruct-v0.1"],
    [/mistral.*?7b/i, "mistralai/Mistral-7B-v0.1"],

    # Llama models
    [/llama.*?3.*?8b/i, "meta-llama/Meta-Llama-3-8B"],
    [/llama.*?3.*?70b/i, "meta-llama/Meta-Llama-3-70B"],
    [/llama.*?2.*?7b.*?chat/i, "meta-llama/Llama-2-7b-chat-hf"],
    [/llama.*?2.*?13b.*?chat/i, "meta-llama/Llama-2-13b-chat-hf"],
    [/llama.*?2.*?70b.*?chat/i, "meta-llama/Llama-2-70b-chat-hf"],
    [/tinyllama/i, "TinyLlama/TinyLlama-1.1B-Chat-v1.0"],

    # Gemma models
    [/gemma.*?2.*?9b/i, "google/gemma-2-9b"],
    [/gemma.*?2.*?2b/i, "google/gemma-2-2b"],
    [/gemma.*?7b/i, "google/gemma-7b"],
    [/gemma.*?2b/i, "google/gemma-2b"],

    # Qwen models
    # Decimal sizes come first: "1.8b" also contains "8b", so the 8B pattern
    # would otherwise claim 1.8B model IDs.
    [/qwen.*?3.*?1\.8b/i, "Qwen/Qwen3-1.8B"],
    [/qwen.*?3.*?0\.5b/i, "Qwen/Qwen3-0.5B"],
    [/qwen.*?3.*?72b/i, "Qwen/Qwen3-72B"],
    [/qwen.*?3.*?32b/i, "Qwen/Qwen3-32B"],
    [/qwen.*?3.*?14b/i, "Qwen/Qwen3-14B"],
    [/qwen.*?3.*?8b/i, "Qwen/Qwen3-8B"],
    [/qwen.*?3.*?4b/i, "Qwen/Qwen3-4B"],
    [/qwen.*?2\.5/i, "Qwen/Qwen2.5-0.5B"],
    [/qwen.*?2/i, "Qwen/Qwen2-1.5B"],
    [/qwen/i, "Qwen/Qwen-1_8B"],

    # Phi models (order matters - more specific patterns first)
    [/phi.*?3\.5.*?mini/i, "microsoft/Phi-3.5-mini-instruct"],
    [/phi.*?3.*?mini.*?4k/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi.*?3.*?medium/i, "microsoft/Phi-3-medium-4k-instruct"],
    [/phi.*?3.*?small/i, "microsoft/Phi-3-small-8k-instruct"],
    [/phi.*?3.*?mini/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi.*?3/i, "microsoft/Phi-3-mini-4k-instruct"],
    [/phi-4/i, "microsoft/phi-4"],
    [/phi.*?2/i, "microsoft/phi-2"],
    [/phi.*?1\.5/i, "microsoft/phi-1_5"],
    [/phi/i, "microsoft/phi-2"]
  ]
}
|
|
225
|
+
|
|
226
|
+
# Allow users to register custom tokenizer mappings
|
|
227
|
+
# Registers a custom tokenizer mapping.
#
# model_pattern - a String (stored as an exact-match key) or a Regexp
#                 (prepended to the pattern list, so it is checked first).
# tokenizer_id  - the tokenizer to associate with the pattern.
#
# Raises ArgumentError for any other kind of model_pattern.
def self.register_tokenizer(model_pattern, tokenizer_id)
  case model_pattern
  when String
    TOKENIZER_REGISTRY[model_pattern] = tokenizer_id
  when Regexp
    (TOKENIZER_REGISTRY[:patterns] ||= []).unshift([model_pattern, tokenizer_id])
  else
    raise ArgumentError, "model_pattern must be a String or Regexp"
  end
end
|
|
237
|
+
|
|
238
|
+
# Guess the tokenizer for a model
|
|
239
|
+
# Guesses the tokenizer repository for a model ID.
#
# Lookup order: exact registry key, then the first matching registry pattern,
# then the model ID itself with "-gguf" occurrences and a trailing
# "-q<digits>_<word>" suffix removed (case-insensitive).
def self.guess_tokenizer(model_id)
  exact = TOKENIZER_REGISTRY[model_id]
  return exact if exact

  hit = (TOKENIZER_REGISTRY[:patterns] || []).find { |regex, _| model_id.match?(regex) }
  return hit[1] if hit

  # Default: try removing common GGUF suffixes.
  model_id.gsub(/-gguf|-q\d+_\w+$/i, "")
end
|
|
254
|
+
|
|
255
|
+
# Chat interface — always returns a String
|
|
256
|
+
# Chat interface: renders the messages through `apply_chat_template` and
# passes the resulting prompt, with all options, to `generate`.
def chat(messages, **options)
  generate(apply_chat_template(messages), **options)
end
|
|
260
|
+
|
|
261
|
+
# Streaming chat interface
|
|
262
|
+
# Streaming chat interface: renders the messages through
# `apply_chat_template` and forwards the prompt, options and block to
# `generate_stream`.
def chat_stream(messages, **options, &block)
  generate_stream(apply_chat_template(messages), **options, &block)
end
|
|
266
|
+
|
|
267
|
+
# Chat with tool calling — always returns a ToolCallResult
|
|
268
|
+
# Set execute: true to automatically run the tools (default: false)
|
|
269
|
+
# Chat with tool calling; always returns a ToolCallResult.
#
# messages - chat messages; a tool-describing system prompt is injected.
# tools    - objects responding to `name` and `call(arguments)` (and whatever
#            `build_tool_system_prompt` needs).
# execute  - when true and the response contains tool calls, each call is run
#            and its outcome collected in tool_results (default: false).
# options  - forwarded to `chat`.
#
# Each tool result is a Hash with :tool_call, :result and :error. A tool that
# raises a StandardError is reported through :error instead of aborting the
# whole call. Non-StandardError exceptions (Interrupt, SystemExit,
# NoMemoryError, ...) are deliberately not captured, so Ctrl-C and process
# exit still work while a tool is running.
def chat_with_tools(messages, tools:, execute: false, **options)
  tool_prompt = build_tool_system_prompt(tools)
  augmented = inject_tool_instructions(messages, tool_prompt)

  raw_response = chat(augmented, **options)

  result = ToolCallParser.parse(raw_response, available_tools: tools)

  if result.has_tool_calls? && execute
    tool_results = result.tool_calls.map do |tool_call|
      tool = tools.find { |t| t.name == tool_call.name }

      if tool
        begin
          output = tool.call(tool_call.arguments)
          { tool_call: tool_call, result: output, error: nil }
        rescue StandardError => e
          { tool_call: tool_call, result: nil, error: e.message }
        end
      else
        { tool_call: tool_call, result: nil, error: "Unknown tool: #{tool_call.name}" }
      end
    end

    ToolCallResult.new(
      tool_calls: result.tool_calls,
      tool_results: tool_results,
      text_response: result.text_response,
      raw_response: raw_response
    )
  else
    ToolCallResult.new(
      tool_calls: result.tool_calls,
      tool_results: [],
      # Without tool calls the whole raw response is the text answer.
      text_response: result.has_tool_calls? ? result.text_response : raw_response,
      raw_response: raw_response
    )
  end
end
|
|
307
|
+
|
|
308
|
+
# Inspect method for debugging and exploration
|
|
309
|
+
# Short, human-readable description for debugging, e.g.
# "#<Candle::LLM model=... device=... type=...>".
#
# Values come from `options` (treated as empty when reading it raises);
# the device falls back to `self.device` and then to "unknown".
def inspect
  settings = begin
    options
  rescue StandardError
    {}
  end

  device_label = begin
    settings["device"] || self.device.to_s
  rescue StandardError
    "unknown"
  end

  fields = ["#<Candle::LLM"]

  # Prefer the base model, then the configured model ID, then `model_id`.
  if (named = settings["base_model"] || settings["model_id"])
    fields << "model=#{named}"
  elsif respond_to?(:model_id)
    fields << "model=#{model_id}"
  end

  fields << "gguf=#{settings["gguf_file"]}" if settings["gguf_file"]
  fields << "device=#{device_label}"
  fields << "type=#{settings["model_type"] || "Unknown"}"
  fields << "arch=#{settings["architecture"]}" if settings["architecture"]

  "#{fields.join(" ")}>"
end
|
|
346
|
+
|
|
347
|
+
# Runs `_generate` with the given prompt and config and returns its result.
#
# When reset_cache is true (the default) `clear_cache` is called afterwards,
# whether or not generation raised.
def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
  _generate(prompt, config)
ensure
  clear_cache if reset_cache
end
|
|
354
|
+
|
|
355
|
+
# Runs `_generate_stream` with the given prompt, config and block.
#
# When reset_cache is true (the default) `clear_cache` is called afterwards,
# whether or not generation raised.
def generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block)
  _generate_stream(prompt, config, &block)
ensure
  clear_cache if reset_cache
end
|
|
362
|
+
|
|
363
|
+
# Loads a model through `_from_pretrained`.
#
# model_id  - repository ID of the model.
# device    - device to load onto (default: Candle::Device.best).
# gguf_file - optional file name, appended to the ID as "<id>@<file>".
# tokenizer - optional tokenizer source, appended as "...@@<tokenizer>".
#
# For model strings containing "gguf" and no explicit tokenizer, a first load
# is attempted as-is; if it fails with a message containing
# "No tokenizer found", the load is retried with a guessed tokenizer. Any
# other error is re-raised.
def self.from_pretrained(model_id, device: Candle::Device.best, gguf_file: nil, tokenizer: nil)
  spec = gguf_file ? "#{model_id}@#{gguf_file}" : model_id
  looks_like_gguf = spec.downcase.include?("gguf")

  # User specified tokenizer
  return _from_pretrained("#{spec}@@#{tokenizer}", device) if tokenizer

  # Non-GGUF model: nothing to auto-detect
  return _from_pretrained(spec, device) unless looks_like_gguf && tokenizer.nil?

  begin
    # The GGUF repo may ship its own tokenizer, so try without one first.
    _from_pretrained(spec, device)
  rescue => load_error
    raise load_error unless load_error.message.include?("No tokenizer found")

    fallback = guess_tokenizer(model_id)
    Candle.logger.info "No tokenizer found in GGUF repo. Using tokenizer from: #{fallback}"
    _from_pretrained("#{spec}@@#{fallback}", device)
  end
end
|
|
395
|
+
|
|
396
|
+
private
|
|
397
|
+
|
|
398
|
+
# Builds the system prompt that describes the available tools and the
# <tool_call> response format.
#
# tools - objects responding to `to_tool_definition`; each definition is
#         serialized with JSON.generate.
def build_tool_system_prompt(tools)
  definitions = tools.map { |tool| JSON.generate(tool.to_tool_definition) }

  call_format = [
    "When you need to use a tool, respond with a tool call in the following format:",
    "<tool_call>",
    '{"name": "tool_name", "arguments": {"arg1": "value1"}}',
    "</tool_call>"
  ].join("\n")

  [
    "You are a helpful assistant with access to the following tools:",
    definitions.join("\n\n"),
    call_format,
    "If you don't need to use a tool, respond normally with text."
  ].join("\n\n")
end
|
|
408
|
+
|
|
409
|
+
# Returns a copy of the messages with the tool prompt placed in the system
# message.
#
# If the first message has role "system", the tool prompt is prepended to its
# content (separated by a blank line); otherwise a new system message is
# inserted at the front. Each message is shallow-copied, so the caller's
# hashes are not modified.
def inject_tool_instructions(messages, tool_prompt)
  copies = messages.map(&:dup)
  head = copies.first

  unless head && head[:role] == "system"
    return [{ role: "system", content: tool_prompt }, *copies]
  end

  head[:content] = "#{tool_prompt}\n\n#{head[:content]}"
  copies
end
|
|
418
|
+
|
|
419
|
+
# Extract JSON content from generated text, handling stop tokens and extra content
|
|
420
|
+
# Extracts the JSON portion of generated text.
#
# 1. Everything from the first occurrence of any `model_eos_tokens` entry
#    onward is discarded.
# 2. Starting at the first "{" or "[", the text up to the matching closing
#    bracket is returned.
#
# Returns the stripped, EOS-truncated text when there is no opening bracket
# or the brackets never balance.
def extract_json_content(text)
  cleaned = text

  # Each truncation shortens the text, so the earliest EOS token wins.
  model_eos_tokens.each do |token|
    cut = cleaned.index(token)
    cleaned = cleaned[0...cut] if cut
  end

  start_idx = cleaned.index(/[\{\[]/)
  return cleaned.strip unless start_idx

  json_candidate = cleaned[start_idx..-1]
  opener = json_candidate[0]
  closer = opener == "{" ? "}" : "]"

  balanced_json_prefix(json_candidate, opener, closer) || cleaned.strip
end

# Returns the prefix of `candidate` ending at the bracket that closes its
# first `opener`, or nil when the brackets never balance.
#
# Only the given bracket pair is counted. Brackets inside double-quoted
# strings are ignored, and a backslash escapes the character after it.
def balanced_json_prefix(candidate, opener, closer)
  depth = 0
  in_string = false
  escape_next = false

  candidate.each_char.with_index do |char, idx|
    if in_string
      in_string = false if char == '"' && !escape_next
    elsif char == opener
      depth += 1
    elsif char == closer
      depth -= 1
      return candidate[0..idx] if depth.zero?
    elsif char == '"'
      in_string = true unless escape_next
    end

    # A backslash escapes the next character unless it is itself escaped.
    escape_next = (!escape_next && char == "\\")
  end

  nil
end
|
|
500
|
+
|
|
501
|
+
# Legacy format messages method - kept for backward compatibility
|
|
502
|
+
# Use apply_chat_template for proper model-specific formatting
|
|
503
|
+
# Legacy plain-text chat formatting, kept for backward compatibility; use
# apply_chat_template for model-specific formatting.
#
# Messages with role "system", "user" or "assistant" are rendered as
# "<Role>: <content>"; any other role contributes its content unprefixed.
# Entries are separated by blank lines and the result ends with "Assistant:"
# to prompt the reply.
def format_messages(messages)
  labels = { "system" => "System", "user" => "User", "assistant" => "Assistant" }

  lines = messages.map do |message|
    label = labels[message[:role]]
    label ? "#{label}: #{message[:content]}" : message[:content]
  end

  lines.join("\n\n") + "\n\nAssistant:"
end
|
|
520
|
+
end
|
|
521
|
+
|
|
522
|
+
class GenerationConfig
  # Returns a new config built from this one's non-nil settings with the
  # given overrides merged on top.
  #
  # NOTE(review): only the fields listed below (plus @constraint when set)
  # are carried over; other settings are not copied here.
  def with(**overrides)
    fields = %i[
      max_length temperature top_p top_k repetition_penalty
      seed stop_sequences include_prompt
    ]
    snapshot = fields.to_h { |field| [field, send(field)] }
    snapshot[:constraint] = @constraint if defined?(@constraint)

    self.class.new(snapshot.compact.merge(overrides))
  end

  # Deterministic preset (temperature 0, top_k 1, fixed seed); opts override.
  def self.deterministic(**opts)
    new({ temperature: 0.0, top_p: nil, top_k: 1, seed: 42 }.merge(opts))
  end

  # Creative preset (higher temperature, repetition penalty); opts override.
  def self.creative(**opts)
    new({ temperature: 1.0, top_p: 0.95, top_k: 50, repetition_penalty: 1.2 }.merge(opts))
  end

  # Balanced preset (moderate temperature); opts override.
  def self.balanced(**opts)
    new({ temperature: 0.7, top_p: 0.9, top_k: 40 }.merge(opts))
  end

  # Short description for debugging, listing the truthy settings read from
  # `options` (treated as empty when reading it raises).
  def inspect
    settings = begin
      options
    rescue StandardError
      {}
    end

    summary = ["#<Candle::GenerationConfig"]

    # label shown => key read from the options hash
    shown = {
      "temp" => "temperature",
      "max" => "max_length",
      "top_p" => "top_p",
      "top_k" => "top_k",
      "seed" => "seed"
    }
    shown.each do |label, key|
      value = settings[key]
      summary << "#{label}=#{value}" if value
    end

    flag_keys = {
      "debug" => "debug_tokens",
      "constraint" => "has_constraint",
      "stop_on_match" => "stop_on_match"
    }
    flags = flag_keys.select { |_flag, key| settings[key] }.keys
    summary << "flags=[#{flags.join(",")}]" if flags.any?

    "#{summary.join(" ")}>"
  end
end
|
|
595
|
+
end
|