red-candle 1.8.0.pre3-aarch64-linux
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Cargo.lock +5021 -0
- data/Cargo.toml +6 -0
- data/Gemfile +3 -0
- data/LICENSE +22 -0
- data/README.md +1171 -0
- data/Rakefile +167 -0
- data/bin/console +11 -0
- data/bin/setup +17 -0
- data/ext/candle/Cargo.toml +38 -0
- data/ext/candle/build.rs +117 -0
- data/ext/candle/extconf.rb +79 -0
- data/ext/candle/rustfmt.toml +63 -0
- data/ext/candle/src/gvl.rs +58 -0
- data/ext/candle/src/lib.rs +59 -0
- data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
- data/ext/candle/src/llm/gemma.rs +313 -0
- data/ext/candle/src/llm/generation_config.rs +63 -0
- data/ext/candle/src/llm/glm4.rs +236 -0
- data/ext/candle/src/llm/granite.rs +308 -0
- data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
- data/ext/candle/src/llm/llama.rs +396 -0
- data/ext/candle/src/llm/mistral.rs +309 -0
- data/ext/candle/src/llm/mod.rs +49 -0
- data/ext/candle/src/llm/phi.rs +369 -0
- data/ext/candle/src/llm/quantized_gguf.rs +734 -0
- data/ext/candle/src/llm/qwen.rs +261 -0
- data/ext/candle/src/llm/qwen3.rs +257 -0
- data/ext/candle/src/llm/text_generation.rs +284 -0
- data/ext/candle/src/ruby/device.rs +234 -0
- data/ext/candle/src/ruby/dtype.rs +39 -0
- data/ext/candle/src/ruby/embedding_model.rs +477 -0
- data/ext/candle/src/ruby/errors.rs +16 -0
- data/ext/candle/src/ruby/llm.rs +730 -0
- data/ext/candle/src/ruby/mod.rs +24 -0
- data/ext/candle/src/ruby/ner.rs +444 -0
- data/ext/candle/src/ruby/reranker.rs +488 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/structured.rs +92 -0
- data/ext/candle/src/ruby/tensor.rs +731 -0
- data/ext/candle/src/ruby/tokenizer.rs +343 -0
- data/ext/candle/src/ruby/utils.rs +96 -0
- data/ext/candle/src/ruby/vlm.rs +330 -0
- data/ext/candle/src/structured/integration_test.rs +130 -0
- data/ext/candle/src/structured/mod.rs +31 -0
- data/ext/candle/src/structured/schema_processor.rs +215 -0
- data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
- data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
- data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +104 -0
- data/ext/candle/tests/device_tests.rs +43 -0
- data/ext/candle/tests/tensor_tests.rs +162 -0
- data/lib/candle/3.1/candle.so +0 -0
- data/lib/candle/3.2/candle.so +0 -0
- data/lib/candle/3.3/candle.so +0 -0
- data/lib/candle/3.4/candle.so +0 -0
- data/lib/candle/4.0/candle.so +0 -0
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/build_info.rb +67 -0
- data/lib/candle/device_utils.rb +10 -0
- data/lib/candle/embedding_model.rb +75 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +595 -0
- data/lib/candle/logger.rb +149 -0
- data/lib/candle/ner.rb +368 -0
- data/lib/candle/reranker.rb +45 -0
- data/lib/candle/tensor.rb +99 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +5 -0
- data/lib/candle/vlm.rb +31 -0
- data/lib/candle.rb +29 -0
- data/lib/red-candle.rb +1 -0
- metadata +309 -0
|
@@ -0,0 +1,730 @@
|
|
|
1
|
+
use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
|
|
2
|
+
use std::cell::RefCell;
|
|
3
|
+
use std::sync::Arc;
|
|
4
|
+
|
|
5
|
+
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, qwen::Qwen as RustQwen, qwen3::Qwen3 as RustQwen3, phi::Phi as RustPhi, granite::Granite as RustGranite, granitemoehybrid::GraniteMoeHybrid as RustGraniteMoeHybrid, glm4::Glm4 as RustGlm4, QuantizedGGUF as RustQuantizedGGUF};
|
|
6
|
+
use crate::ruby::{Result, Device};
|
|
7
|
+
use crate::ruby::structured::StructuredConstraint;
|
|
8
|
+
use crate::gvl;
|
|
9
|
+
|
|
10
|
+
// Use an enum to handle different model types instead of trait objects
|
|
11
|
+
enum ModelType {
|
|
12
|
+
Mistral(RustMistral),
|
|
13
|
+
Llama(RustLlama),
|
|
14
|
+
Gemma(RustGemma),
|
|
15
|
+
Qwen(RustQwen),
|
|
16
|
+
Qwen3(RustQwen3),
|
|
17
|
+
Phi(RustPhi),
|
|
18
|
+
Granite(RustGranite),
|
|
19
|
+
GraniteMoeHybrid(RustGraniteMoeHybrid),
|
|
20
|
+
Glm4(RustGlm4),
|
|
21
|
+
QuantizedGGUF(RustQuantizedGGUF),
|
|
22
|
+
}
|
|
23
|
+
|
|
24
|
+
impl ModelType {
|
|
25
|
+
fn generate(&mut self, prompt: &str, config: &RustGenerationConfig) -> candle_core::Result<String> {
|
|
26
|
+
match self {
|
|
27
|
+
ModelType::Mistral(m) => m.generate(prompt, config),
|
|
28
|
+
ModelType::Llama(m) => m.generate(prompt, config),
|
|
29
|
+
ModelType::Gemma(m) => m.generate(prompt, config),
|
|
30
|
+
ModelType::Qwen(m) => m.generate(prompt, config),
|
|
31
|
+
ModelType::Qwen3(m) => m.generate(prompt, config),
|
|
32
|
+
ModelType::Phi(m) => m.generate(prompt, config),
|
|
33
|
+
ModelType::Granite(m) => m.generate(prompt, config),
|
|
34
|
+
ModelType::GraniteMoeHybrid(m) => m.generate(prompt, config),
|
|
35
|
+
ModelType::Glm4(m) => m.generate(prompt, config),
|
|
36
|
+
ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
|
|
37
|
+
}
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
fn generate_stream(
|
|
41
|
+
&mut self,
|
|
42
|
+
prompt: &str,
|
|
43
|
+
config: &RustGenerationConfig,
|
|
44
|
+
callback: impl FnMut(&str),
|
|
45
|
+
) -> candle_core::Result<String> {
|
|
46
|
+
match self {
|
|
47
|
+
ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
|
|
48
|
+
ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
|
|
49
|
+
ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
|
|
50
|
+
ModelType::Qwen(m) => m.generate_stream(prompt, config, callback),
|
|
51
|
+
ModelType::Qwen3(m) => m.generate_stream(prompt, config, callback),
|
|
52
|
+
ModelType::Phi(m) => m.generate_stream(prompt, config, callback),
|
|
53
|
+
ModelType::Granite(m) => m.generate_stream(prompt, config, callback),
|
|
54
|
+
ModelType::GraniteMoeHybrid(m) => m.generate_stream(prompt, config, callback),
|
|
55
|
+
ModelType::Glm4(m) => m.generate_stream(prompt, config, callback),
|
|
56
|
+
ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
|
|
57
|
+
}
|
|
58
|
+
}
|
|
59
|
+
|
|
60
|
+
fn clear_cache(&mut self) {
|
|
61
|
+
match self {
|
|
62
|
+
ModelType::Mistral(m) => m.clear_cache(),
|
|
63
|
+
ModelType::Llama(m) => m.clear_cache(),
|
|
64
|
+
ModelType::Gemma(m) => m.clear_cache(),
|
|
65
|
+
ModelType::Qwen(m) => m.clear_cache(),
|
|
66
|
+
ModelType::Qwen3(m) => m.clear_cache(),
|
|
67
|
+
ModelType::Phi(m) => m.clear_cache(),
|
|
68
|
+
ModelType::Granite(m) => m.clear_cache(),
|
|
69
|
+
ModelType::GraniteMoeHybrid(m) => m.clear_cache(),
|
|
70
|
+
ModelType::Glm4(m) => m.clear_cache(),
|
|
71
|
+
ModelType::QuantizedGGUF(m) => m.clear_cache(),
|
|
72
|
+
}
|
|
73
|
+
}
|
|
74
|
+
|
|
75
|
+
fn apply_chat_template(&self, messages: &[serde_json::Value]) -> candle_core::Result<String> {
|
|
76
|
+
match self {
|
|
77
|
+
ModelType::Mistral(_) => {
|
|
78
|
+
// For now, use a simple template for Mistral
|
|
79
|
+
// In the future, we could implement proper Mistral chat templating
|
|
80
|
+
let mut prompt = String::new();
|
|
81
|
+
for message in messages {
|
|
82
|
+
let role = message["role"].as_str().unwrap_or("");
|
|
83
|
+
let content = message["content"].as_str().unwrap_or("");
|
|
84
|
+
match role {
|
|
85
|
+
"system" => prompt.push_str(&format!("System: {}\n\n", content)),
|
|
86
|
+
"user" => prompt.push_str(&format!("User: {}\n\n", content)),
|
|
87
|
+
"assistant" => prompt.push_str(&format!("Assistant: {}\n\n", content)),
|
|
88
|
+
_ => {}
|
|
89
|
+
}
|
|
90
|
+
}
|
|
91
|
+
prompt.push_str("Assistant: ");
|
|
92
|
+
Ok(prompt)
|
|
93
|
+
},
|
|
94
|
+
ModelType::Llama(m) => m.apply_chat_template(messages),
|
|
95
|
+
ModelType::Gemma(m) => m.apply_chat_template(messages),
|
|
96
|
+
ModelType::Qwen(m) => m.apply_chat_template(messages),
|
|
97
|
+
ModelType::Qwen3(m) => m.apply_chat_template(messages),
|
|
98
|
+
ModelType::Phi(m) => m.apply_chat_template(messages),
|
|
99
|
+
ModelType::Granite(m) => m.apply_chat_template(messages),
|
|
100
|
+
ModelType::GraniteMoeHybrid(m) => m.apply_chat_template(messages),
|
|
101
|
+
ModelType::Glm4(m) => m.apply_chat_template(messages),
|
|
102
|
+
ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
|
|
103
|
+
}
|
|
104
|
+
}
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
// Copies a keyword argument into the config field of the same name, to cut boilerplate.
//
// The value is looked up under the symbol `:$param`. If the key is absent, or its value fails
// `TryConvert` into the field's type, nothing happens and the field keeps its current value.
// Append `optional` for `Option<_>` fields: the converted value is then stored as `Some(v)`.
macro_rules! extract_param {
    // Field whose type is the converted value's type.
    ($ruby:expr, $kwargs:expr, $config:expr, $param:ident) => {
        if let Some(v) = $kwargs
            .get($ruby.to_symbol(stringify!($param)))
            .and_then(|raw| TryConvert::try_convert(raw).ok())
        {
            $config.$param = v;
        }
    };
    // `Option<_>` field: wrap the converted value in `Some`.
    ($ruby:expr, $kwargs:expr, $config:expr, $param:ident, optional) => {
        if let Some(v) = $kwargs
            .get($ruby.to_symbol(stringify!($param)))
            .and_then(|raw| TryConvert::try_convert(raw).ok())
        {
            $config.$param = Some(v);
        }
    };
}
|
|
126
|
+
|
|
127
|
+
#[derive(Clone, Debug)]
|
|
128
|
+
#[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
|
|
129
|
+
pub struct GenerationConfig {
|
|
130
|
+
inner: RustGenerationConfig,
|
|
131
|
+
}
|
|
132
|
+
|
|
133
|
+
impl GenerationConfig {
|
|
134
|
+
pub fn new(kwargs: RHash) -> Result<Self> {
|
|
135
|
+
let ruby = Ruby::get().unwrap();
|
|
136
|
+
let mut config = RustGenerationConfig::default();
|
|
137
|
+
|
|
138
|
+
// Extract basic parameters using macro
|
|
139
|
+
extract_param!(ruby, kwargs, config, max_length);
|
|
140
|
+
extract_param!(ruby, kwargs, config, temperature);
|
|
141
|
+
extract_param!(ruby, kwargs, config, top_p, optional);
|
|
142
|
+
extract_param!(ruby, kwargs, config, top_k, optional);
|
|
143
|
+
extract_param!(ruby, kwargs, config, repetition_penalty);
|
|
144
|
+
extract_param!(ruby, kwargs, config, repetition_penalty_last_n);
|
|
145
|
+
extract_param!(ruby, kwargs, config, seed);
|
|
146
|
+
extract_param!(ruby, kwargs, config, include_prompt);
|
|
147
|
+
extract_param!(ruby, kwargs, config, debug_tokens);
|
|
148
|
+
extract_param!(ruby, kwargs, config, stop_on_constraint_satisfaction);
|
|
149
|
+
extract_param!(ruby, kwargs, config, stop_on_match);
|
|
150
|
+
|
|
151
|
+
// Handle special cases that need custom logic
|
|
152
|
+
if let Some(value) = kwargs.get(ruby.to_symbol("stop_sequences")) {
|
|
153
|
+
if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
|
|
154
|
+
config.stop_sequences = arr
|
|
155
|
+
.into_iter()
|
|
156
|
+
.filter_map(|v| <String as TryConvert>::try_convert(v).ok())
|
|
157
|
+
.collect();
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
if let Some(value) = kwargs.get(ruby.to_symbol("constraint")) {
|
|
162
|
+
if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
|
|
163
|
+
config.constraint = Some(Arc::clone(&constraint.index));
|
|
164
|
+
}
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
Ok(Self { inner: config })
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
pub fn default() -> Self {
|
|
171
|
+
Self {
|
|
172
|
+
inner: RustGenerationConfig::default(),
|
|
173
|
+
}
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
// Getters
|
|
177
|
+
pub fn max_length(&self) -> usize {
|
|
178
|
+
self.inner.max_length
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
pub fn temperature(&self) -> f64 {
|
|
182
|
+
self.inner.temperature
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
pub fn top_p(&self) -> Option<f64> {
|
|
186
|
+
self.inner.top_p
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
pub fn top_k(&self) -> Option<usize> {
|
|
190
|
+
self.inner.top_k
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
pub fn repetition_penalty(&self) -> f32 {
|
|
194
|
+
self.inner.repetition_penalty
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
pub fn seed(&self) -> u64 {
|
|
198
|
+
self.inner.seed
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
pub fn stop_sequences(&self) -> Vec<String> {
|
|
202
|
+
self.inner.stop_sequences.clone()
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
pub fn include_prompt(&self) -> bool {
|
|
206
|
+
self.inner.include_prompt
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
pub fn debug_tokens(&self) -> bool {
|
|
210
|
+
self.inner.debug_tokens
|
|
211
|
+
}
|
|
212
|
+
|
|
213
|
+
pub fn stop_on_constraint_satisfaction(&self) -> bool {
|
|
214
|
+
self.inner.stop_on_constraint_satisfaction
|
|
215
|
+
}
|
|
216
|
+
|
|
217
|
+
pub fn stop_on_match(&self) -> bool {
|
|
218
|
+
self.inner.stop_on_match
|
|
219
|
+
}
|
|
220
|
+
|
|
221
|
+
pub fn constraint(&self) -> Option<StructuredConstraint> {
|
|
222
|
+
self.inner.constraint.as_ref().map(|c| StructuredConstraint {
|
|
223
|
+
index: Arc::clone(c),
|
|
224
|
+
})
|
|
225
|
+
}
|
|
226
|
+
|
|
227
|
+
/// Get all options as a hash
|
|
228
|
+
pub fn options(&self) -> Result<RHash> {
|
|
229
|
+
let ruby = Ruby::get().unwrap();
|
|
230
|
+
let hash = ruby.hash_new();
|
|
231
|
+
|
|
232
|
+
hash.aset("max_length", self.inner.max_length)?;
|
|
233
|
+
hash.aset("temperature", self.inner.temperature)?;
|
|
234
|
+
|
|
235
|
+
if let Some(top_p) = self.inner.top_p {
|
|
236
|
+
hash.aset("top_p", top_p)?;
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
if let Some(top_k) = self.inner.top_k {
|
|
240
|
+
hash.aset("top_k", top_k)?;
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
hash.aset("repetition_penalty", self.inner.repetition_penalty)?;
|
|
244
|
+
hash.aset("repetition_penalty_last_n", self.inner.repetition_penalty_last_n)?;
|
|
245
|
+
hash.aset("seed", self.inner.seed)?;
|
|
246
|
+
hash.aset("stop_sequences", self.inner.stop_sequences.clone())?;
|
|
247
|
+
hash.aset("include_prompt", self.inner.include_prompt)?;
|
|
248
|
+
hash.aset("debug_tokens", self.inner.debug_tokens)?;
|
|
249
|
+
hash.aset("stop_on_constraint_satisfaction", self.inner.stop_on_constraint_satisfaction)?;
|
|
250
|
+
hash.aset("stop_on_match", self.inner.stop_on_match)?;
|
|
251
|
+
|
|
252
|
+
if self.inner.constraint.is_some() {
|
|
253
|
+
hash.aset("has_constraint", true)?;
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
Ok(hash)
|
|
257
|
+
}
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
#[derive(Clone)]
|
|
261
|
+
#[magnus::wrap(class = "Candle::LLM", mark, free_immediately)]
|
|
262
|
+
pub struct LLM {
|
|
263
|
+
model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
|
|
264
|
+
model_id: String,
|
|
265
|
+
device: Device,
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
impl LLM {
|
|
269
|
+
/// Create a new LLM from a pretrained model
|
|
270
|
+
pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
|
|
271
|
+
let ruby = Ruby::get().unwrap();
|
|
272
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
273
|
+
let device = device.unwrap_or(Device::best());
|
|
274
|
+
let candle_device = device.as_device()?;
|
|
275
|
+
|
|
276
|
+
let rt = tokio::runtime::Runtime::new()
|
|
277
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to create runtime: {}", e)))?;
|
|
278
|
+
|
|
279
|
+
// Determine model type from ID and whether it's quantized
|
|
280
|
+
let model_lower = model_id.to_lowercase();
|
|
281
|
+
let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
|
|
282
|
+
|
|
283
|
+
// Extract tokenizer source if provided in model_id (for both GGUF and regular models)
|
|
284
|
+
let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
|
|
285
|
+
let (id, _tok) = model_id.split_at(pos);
|
|
286
|
+
(id.to_string(), Some(&model_id[pos+2..]))
|
|
287
|
+
} else {
|
|
288
|
+
(model_id.clone(), None)
|
|
289
|
+
};
|
|
290
|
+
|
|
291
|
+
let model = if is_quantized {
|
|
292
|
+
|
|
293
|
+
// Use unified GGUF loader for all quantized models
|
|
294
|
+
let gguf_model = rt.block_on(async {
|
|
295
|
+
RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
|
|
296
|
+
})
|
|
297
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load GGUF model: {}", e)))?;
|
|
298
|
+
ModelType::QuantizedGGUF(gguf_model)
|
|
299
|
+
} else {
|
|
300
|
+
// Load non-quantized models based on type
|
|
301
|
+
let model_lower_clean = model_id_clean.to_lowercase();
|
|
302
|
+
|
|
303
|
+
if model_lower_clean.contains("mistral") {
|
|
304
|
+
let mistral = if tokenizer_source.is_some() {
|
|
305
|
+
rt.block_on(async {
|
|
306
|
+
RustMistral::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
307
|
+
})
|
|
308
|
+
} else {
|
|
309
|
+
rt.block_on(async {
|
|
310
|
+
RustMistral::from_pretrained(&model_id_clean, candle_device).await
|
|
311
|
+
})
|
|
312
|
+
}
|
|
313
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
314
|
+
ModelType::Mistral(mistral)
|
|
315
|
+
} else if model_lower_clean.contains("llama") || model_lower_clean.contains("meta-llama") || model_lower_clean.contains("tinyllama") || model_lower_clean.contains("smollm") || model_lower_clean.contains("/yi-") || model_lower_clean.contains("01-ai") {
|
|
316
|
+
let llama = if tokenizer_source.is_some() {
|
|
317
|
+
rt.block_on(async {
|
|
318
|
+
RustLlama::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
319
|
+
})
|
|
320
|
+
} else {
|
|
321
|
+
rt.block_on(async {
|
|
322
|
+
RustLlama::from_pretrained(&model_id_clean, candle_device).await
|
|
323
|
+
})
|
|
324
|
+
}
|
|
325
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
326
|
+
ModelType::Llama(llama)
|
|
327
|
+
} else if model_lower_clean.contains("gemma") || model_lower_clean.contains("google/gemma") {
|
|
328
|
+
let gemma = if tokenizer_source.is_some() {
|
|
329
|
+
rt.block_on(async {
|
|
330
|
+
RustGemma::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
331
|
+
})
|
|
332
|
+
} else {
|
|
333
|
+
rt.block_on(async {
|
|
334
|
+
RustGemma::from_pretrained(&model_id_clean, candle_device).await
|
|
335
|
+
})
|
|
336
|
+
}
|
|
337
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
338
|
+
ModelType::Gemma(gemma)
|
|
339
|
+
} else if model_lower_clean.contains("qwen3") {
|
|
340
|
+
let qwen3 = if tokenizer_source.is_some() {
|
|
341
|
+
rt.block_on(async {
|
|
342
|
+
RustQwen3::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
343
|
+
})
|
|
344
|
+
} else {
|
|
345
|
+
rt.block_on(async {
|
|
346
|
+
RustQwen3::from_pretrained(&model_id_clean, candle_device).await
|
|
347
|
+
})
|
|
348
|
+
}
|
|
349
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
350
|
+
ModelType::Qwen3(qwen3)
|
|
351
|
+
} else if model_lower_clean.contains("qwen") {
|
|
352
|
+
let qwen = if tokenizer_source.is_some() {
|
|
353
|
+
rt.block_on(async {
|
|
354
|
+
RustQwen::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
355
|
+
})
|
|
356
|
+
} else {
|
|
357
|
+
rt.block_on(async {
|
|
358
|
+
RustQwen::from_pretrained(&model_id_clean, candle_device).await
|
|
359
|
+
})
|
|
360
|
+
}
|
|
361
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
362
|
+
ModelType::Qwen(qwen)
|
|
363
|
+
} else if model_lower_clean.contains("phi") {
|
|
364
|
+
let phi = if tokenizer_source.is_some() {
|
|
365
|
+
rt.block_on(async {
|
|
366
|
+
RustPhi::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
367
|
+
})
|
|
368
|
+
} else {
|
|
369
|
+
rt.block_on(async {
|
|
370
|
+
RustPhi::from_pretrained(&model_id_clean, candle_device).await
|
|
371
|
+
})
|
|
372
|
+
}
|
|
373
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
374
|
+
ModelType::Phi(phi)
|
|
375
|
+
} else if model_lower_clean.contains("granite-4") || model_lower_clean.contains("granitemoehybrid") {
|
|
376
|
+
let granite_moe = if tokenizer_source.is_some() {
|
|
377
|
+
rt.block_on(async {
|
|
378
|
+
RustGraniteMoeHybrid::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
379
|
+
})
|
|
380
|
+
} else {
|
|
381
|
+
rt.block_on(async {
|
|
382
|
+
RustGraniteMoeHybrid::from_pretrained(&model_id_clean, candle_device).await
|
|
383
|
+
})
|
|
384
|
+
}
|
|
385
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
386
|
+
ModelType::GraniteMoeHybrid(granite_moe)
|
|
387
|
+
} else if model_lower_clean.contains("granite") {
|
|
388
|
+
let granite = if tokenizer_source.is_some() {
|
|
389
|
+
rt.block_on(async {
|
|
390
|
+
RustGranite::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
391
|
+
})
|
|
392
|
+
} else {
|
|
393
|
+
rt.block_on(async {
|
|
394
|
+
RustGranite::from_pretrained(&model_id_clean, candle_device).await
|
|
395
|
+
})
|
|
396
|
+
}
|
|
397
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
398
|
+
ModelType::Granite(granite)
|
|
399
|
+
} else if model_lower_clean.contains("glm") {
|
|
400
|
+
let glm4 = if tokenizer_source.is_some() {
|
|
401
|
+
rt.block_on(async {
|
|
402
|
+
RustGlm4::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
|
|
403
|
+
})
|
|
404
|
+
} else {
|
|
405
|
+
rt.block_on(async {
|
|
406
|
+
RustGlm4::from_pretrained(&model_id_clean, candle_device).await
|
|
407
|
+
})
|
|
408
|
+
}
|
|
409
|
+
.map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
|
|
410
|
+
ModelType::Glm4(glm4)
|
|
411
|
+
} else {
|
|
412
|
+
return Err(Error::new(
|
|
413
|
+
runtime_error,
|
|
414
|
+
format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, Phi, Granite, and GLM-4 models are supported.", model_id_clean),
|
|
415
|
+
));
|
|
416
|
+
}
|
|
417
|
+
};
|
|
418
|
+
|
|
419
|
+
Ok(Self {
|
|
420
|
+
model: std::sync::Arc::new(std::sync::Mutex::new(RefCell::new(model))),
|
|
421
|
+
model_id,
|
|
422
|
+
device,
|
|
423
|
+
})
|
|
424
|
+
}
|
|
425
|
+
|
|
426
|
+
/// Generate text from a prompt (releases GVL during inference)
|
|
427
|
+
pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
|
|
428
|
+
let ruby = Ruby::get().unwrap();
|
|
429
|
+
let config = config
|
|
430
|
+
.map(|c| c.inner.clone())
|
|
431
|
+
.unwrap_or_default();
|
|
432
|
+
|
|
433
|
+
let model = match self.model.lock() {
|
|
434
|
+
Ok(guard) => guard,
|
|
435
|
+
Err(poisoned) => poisoned.into_inner(),
|
|
436
|
+
};
|
|
437
|
+
let mut model_ref = model.borrow_mut();
|
|
438
|
+
|
|
439
|
+
// Release the GVL during inference so other Ruby threads can run
|
|
440
|
+
// (e.g., TUI render loops, HTTP servers, etc.)
|
|
441
|
+
let result = gvl::without_gvl(|| {
|
|
442
|
+
model_ref.generate(&prompt, &config)
|
|
443
|
+
});
|
|
444
|
+
|
|
445
|
+
result.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Generation failed: {}", e)))
|
|
446
|
+
}
|
|
447
|
+
|
|
448
|
+
/// Generate text with streaming output
|
|
449
|
+
pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
|
|
450
|
+
let config = config
|
|
451
|
+
.map(|c| c.inner.clone())
|
|
452
|
+
.unwrap_or_default();
|
|
453
|
+
|
|
454
|
+
let ruby = Ruby::get().unwrap();
|
|
455
|
+
let runtime_error = ruby.exception_runtime_error();
|
|
456
|
+
let block = ruby.block_proc();
|
|
457
|
+
if let Err(_) = block {
|
|
458
|
+
return Err(Error::new(runtime_error, "No block given"));
|
|
459
|
+
}
|
|
460
|
+
let block = block.unwrap();
|
|
461
|
+
|
|
462
|
+
let model = match self.model.lock() {
|
|
463
|
+
Ok(guard) => guard,
|
|
464
|
+
Err(poisoned) => poisoned.into_inner(),
|
|
465
|
+
};
|
|
466
|
+
let mut model_ref = model.borrow_mut();
|
|
467
|
+
|
|
468
|
+
let result = model_ref.generate_stream(&prompt, &config, |token| {
|
|
469
|
+
// Call the Ruby block with each token
|
|
470
|
+
let _ = block.call::<(String,), Value>((token.to_string(),));
|
|
471
|
+
});
|
|
472
|
+
|
|
473
|
+
result.map_err(|e| Error::new(runtime_error, format!("Generation failed: {}", e)))
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
/// Get the model name
|
|
477
|
+
pub fn model_name(&self) -> String {
|
|
478
|
+
self.model_id.clone()
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
/// Get the device the model is running on
|
|
482
|
+
pub fn device(&self) -> Device {
|
|
483
|
+
self.device
|
|
484
|
+
}
|
|
485
|
+
|
|
486
|
+
/// Get the tokenizer used by this model
|
|
487
|
+
pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
|
|
488
|
+
let model = match self.model.lock() {
|
|
489
|
+
Ok(guard) => guard,
|
|
490
|
+
Err(poisoned) => poisoned.into_inner(),
|
|
491
|
+
};
|
|
492
|
+
let model_ref = model.borrow();
|
|
493
|
+
|
|
494
|
+
// Clone the tokenizer from the model
|
|
495
|
+
match &*model_ref {
|
|
496
|
+
ModelType::Mistral(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
497
|
+
ModelType::Llama(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
498
|
+
ModelType::Gemma(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
499
|
+
ModelType::Qwen(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
500
|
+
ModelType::Qwen3(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
501
|
+
ModelType::Phi(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
502
|
+
ModelType::Granite(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
503
|
+
ModelType::GraniteMoeHybrid(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
504
|
+
ModelType::Glm4(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
505
|
+
ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
|
|
506
|
+
}
|
|
507
|
+
}
|
|
508
|
+
|
|
509
|
+
/// Get the EOS token string for this model
|
|
510
|
+
pub fn eos_token(&self) -> Result<String> {
|
|
511
|
+
let (eos_token_id, tokenizer_clone) = {
|
|
512
|
+
let model = match self.model.lock() {
|
|
513
|
+
Ok(guard) => guard,
|
|
514
|
+
Err(poisoned) => poisoned.into_inner(),
|
|
515
|
+
};
|
|
516
|
+
let model_ref = model.borrow();
|
|
517
|
+
|
|
518
|
+
// Get both EOS token ID and tokenizer clone in one lock scope
|
|
519
|
+
let eos_id = match &*model_ref {
|
|
520
|
+
ModelType::Mistral(m) => m.eos_token_id(),
|
|
521
|
+
ModelType::Llama(m) => m.eos_token_id(),
|
|
522
|
+
ModelType::Gemma(m) => m.eos_token_id(),
|
|
523
|
+
ModelType::Qwen(m) => m.eos_token_id(),
|
|
524
|
+
ModelType::Qwen3(m) => m.eos_token_id(),
|
|
525
|
+
ModelType::Phi(m) => m.eos_token_id(),
|
|
526
|
+
ModelType::Granite(m) => m.eos_token_id(),
|
|
527
|
+
ModelType::GraniteMoeHybrid(m) => m.eos_token_id(),
|
|
528
|
+
ModelType::Glm4(m) => m.eos_token_id(),
|
|
529
|
+
ModelType::QuantizedGGUF(m) => m.eos_token_id(),
|
|
530
|
+
};
|
|
531
|
+
|
|
532
|
+
let tokenizer = match &*model_ref {
|
|
533
|
+
ModelType::Mistral(m) => m.tokenizer().clone(),
|
|
534
|
+
ModelType::Llama(m) => m.tokenizer().clone(),
|
|
535
|
+
ModelType::Gemma(m) => m.tokenizer().clone(),
|
|
536
|
+
ModelType::Qwen(m) => m.tokenizer().clone(),
|
|
537
|
+
ModelType::Qwen3(m) => m.tokenizer().clone(),
|
|
538
|
+
ModelType::Phi(m) => m.tokenizer().clone(),
|
|
539
|
+
ModelType::Granite(m) => m.tokenizer().clone(),
|
|
540
|
+
ModelType::GraniteMoeHybrid(m) => m.tokenizer().clone(),
|
|
541
|
+
ModelType::Glm4(m) => m.tokenizer().clone(),
|
|
542
|
+
ModelType::QuantizedGGUF(m) => m.tokenizer().clone(),
|
|
543
|
+
};
|
|
544
|
+
|
|
545
|
+
(eos_id, tokenizer)
|
|
546
|
+
}; // Lock is released here
|
|
547
|
+
|
|
548
|
+
// Convert ID to string using the tokenizer
|
|
549
|
+
let tokenizer_wrapper = crate::ruby::tokenizer::Tokenizer(tokenizer_clone);
|
|
550
|
+
tokenizer_wrapper.id_to_token(eos_token_id as i64)
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
/// Clear the model's cache (e.g., KV cache for transformers)
|
|
554
|
+
pub fn clear_cache(&self) -> Result<()> {
|
|
555
|
+
let model = match self.model.lock() {
|
|
556
|
+
Ok(guard) => guard,
|
|
557
|
+
Err(poisoned) => {
|
|
558
|
+
// If the mutex is poisoned, we can still recover the data
|
|
559
|
+
// This happens when another thread panicked while holding the lock
|
|
560
|
+
poisoned.into_inner()
|
|
561
|
+
}
|
|
562
|
+
};
|
|
563
|
+
let mut model_ref = model.borrow_mut();
|
|
564
|
+
model_ref.clear_cache();
|
|
565
|
+
Ok(())
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
/// Apply chat template to messages
|
|
569
|
+
pub fn apply_chat_template(&self, messages: RArray) -> Result<String> {
|
|
570
|
+
let ruby = Ruby::get().unwrap();
|
|
571
|
+
// Convert Ruby array to JSON values
|
|
572
|
+
let json_messages: Vec<serde_json::Value> = messages
|
|
573
|
+
.into_iter()
|
|
574
|
+
.filter_map(|msg| {
|
|
575
|
+
if let Ok(hash) = <RHash as TryConvert>::try_convert(msg) {
|
|
576
|
+
let mut json_msg = serde_json::Map::new();
|
|
577
|
+
|
|
578
|
+
if let Some(role) = hash.get(ruby.to_symbol("role")) {
|
|
579
|
+
if let Ok(role_str) = <String as TryConvert>::try_convert(role) {
|
|
580
|
+
json_msg.insert("role".to_string(), serde_json::Value::String(role_str));
|
|
581
|
+
}
|
|
582
|
+
}
|
|
583
|
+
|
|
584
|
+
if let Some(content) = hash.get(ruby.to_symbol("content")) {
|
|
585
|
+
if let Ok(content_str) = <String as TryConvert>::try_convert(content) {
|
|
586
|
+
json_msg.insert("content".to_string(), serde_json::Value::String(content_str));
|
|
587
|
+
}
|
|
588
|
+
}
|
|
589
|
+
|
|
590
|
+
Some(serde_json::Value::Object(json_msg))
|
|
591
|
+
} else {
|
|
592
|
+
None
|
|
593
|
+
}
|
|
594
|
+
})
|
|
595
|
+
.collect();
|
|
596
|
+
|
|
597
|
+
let model = match self.model.lock() {
|
|
598
|
+
Ok(guard) => guard,
|
|
599
|
+
Err(poisoned) => poisoned.into_inner(),
|
|
600
|
+
};
|
|
601
|
+
let model_ref = model.borrow();
|
|
602
|
+
|
|
603
|
+
model_ref.apply_chat_template(&json_messages)
|
|
604
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to apply chat template: {}", e)))
|
|
605
|
+
}
|
|
606
|
+
|
|
607
|
+
/// Get the model ID
|
|
608
|
+
pub fn model_id(&self) -> String {
|
|
609
|
+
self.model_id.clone()
|
|
610
|
+
}
|
|
611
|
+
|
|
612
|
+
/// Get model options
|
|
613
|
+
pub fn options(&self) -> Result<RHash> {
|
|
614
|
+
let ruby = Ruby::get().unwrap();
|
|
615
|
+
let hash = ruby.hash_new();
|
|
616
|
+
|
|
617
|
+
// Basic metadata
|
|
618
|
+
hash.aset("model_id", self.model_id.clone())?;
|
|
619
|
+
let device_str = match self.device {
|
|
620
|
+
Device::Cpu => "cpu",
|
|
621
|
+
Device::Cuda => "cuda",
|
|
622
|
+
Device::Metal => "metal",
|
|
623
|
+
};
|
|
624
|
+
hash.aset("device", device_str)?;
|
|
625
|
+
|
|
626
|
+
// Parse model_id to extract GGUF file if present
|
|
627
|
+
if let Some(at_pos) = self.model_id.find('@') {
|
|
628
|
+
let (base_model, gguf_part) = self.model_id.split_at(at_pos);
|
|
629
|
+
let gguf_part = &gguf_part[1..]; // Skip the @ character
|
|
630
|
+
|
|
631
|
+
// Check for tokenizer (@@)
|
|
632
|
+
if let Some(tokenizer_pos) = gguf_part.find("@@") {
|
|
633
|
+
let (gguf_file, tokenizer) = gguf_part.split_at(tokenizer_pos);
|
|
634
|
+
hash.aset("base_model", base_model)?;
|
|
635
|
+
hash.aset("gguf_file", gguf_file)?;
|
|
636
|
+
hash.aset("tokenizer_source", &tokenizer[2..])?;
|
|
637
|
+
} else {
|
|
638
|
+
hash.aset("base_model", base_model)?;
|
|
639
|
+
hash.aset("gguf_file", gguf_part)?;
|
|
640
|
+
}
|
|
641
|
+
}
|
|
642
|
+
|
|
643
|
+
// Add model type
|
|
644
|
+
let model = match self.model.lock() {
|
|
645
|
+
Ok(guard) => guard,
|
|
646
|
+
Err(poisoned) => poisoned.into_inner(),
|
|
647
|
+
};
|
|
648
|
+
let model_ref = model.borrow();
|
|
649
|
+
|
|
650
|
+
let model_type = match &*model_ref {
|
|
651
|
+
ModelType::Mistral(_) => "Mistral",
|
|
652
|
+
ModelType::Llama(_) => "Llama",
|
|
653
|
+
ModelType::Gemma(_) => "Gemma",
|
|
654
|
+
ModelType::Qwen(_) => "Qwen",
|
|
655
|
+
ModelType::Qwen3(_) => "Qwen3",
|
|
656
|
+
ModelType::Phi(_) => "Phi",
|
|
657
|
+
ModelType::Granite(_) => "Granite",
|
|
658
|
+
ModelType::GraniteMoeHybrid(_) => "GraniteMoeHybrid",
|
|
659
|
+
ModelType::Glm4(_) => "Glm4",
|
|
660
|
+
ModelType::QuantizedGGUF(_) => "QuantizedGGUF",
|
|
661
|
+
};
|
|
662
|
+
hash.aset("model_type", model_type)?;
|
|
663
|
+
|
|
664
|
+
// For GGUF models, add architecture info
|
|
665
|
+
if let ModelType::QuantizedGGUF(gguf) = &*model_ref {
|
|
666
|
+
hash.aset("architecture", gguf.architecture.clone())?;
|
|
667
|
+
hash.aset("eos_token_id", gguf.eos_token_id())?;
|
|
668
|
+
}
|
|
669
|
+
|
|
670
|
+
Ok(hash)
|
|
671
|
+
}
|
|
672
|
+
}
|
|
673
|
+
|
|
674
|
+
// Define a standalone function for from_pretrained that handles variable arguments
|
|
675
|
+
fn from_pretrained_wrapper(args: &[Value]) -> Result<LLM> {
|
|
676
|
+
match args.len() {
|
|
677
|
+
1 => {
|
|
678
|
+
let model_id: String = TryConvert::try_convert(args[0])?;
|
|
679
|
+
LLM::from_pretrained(model_id, None)
|
|
680
|
+
},
|
|
681
|
+
2 => {
|
|
682
|
+
let model_id: String = TryConvert::try_convert(args[0])?;
|
|
683
|
+
let device: Device = TryConvert::try_convert(args[1])?;
|
|
684
|
+
LLM::from_pretrained(model_id, Some(device))
|
|
685
|
+
},
|
|
686
|
+
_ => {
|
|
687
|
+
let ruby = Ruby::get().unwrap();
|
|
688
|
+
Err(Error::new(
|
|
689
|
+
ruby.exception_arg_error(),
|
|
690
|
+
"wrong number of arguments (expected 1..2)"
|
|
691
|
+
))
|
|
692
|
+
}
|
|
693
|
+
}
|
|
694
|
+
}
|
|
695
|
+
|
|
696
|
+
pub fn init_llm(rb_candle: RModule) -> Result<()> {
|
|
697
|
+
let ruby = Ruby::get().unwrap();
|
|
698
|
+
let rb_generation_config = rb_candle.define_class("GenerationConfig", ruby.class_object())?;
|
|
699
|
+
rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
|
|
700
|
+
rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
|
|
701
|
+
|
|
702
|
+
rb_generation_config.define_method("max_length", method!(GenerationConfig::max_length, 0))?;
|
|
703
|
+
rb_generation_config.define_method("temperature", method!(GenerationConfig::temperature, 0))?;
|
|
704
|
+
rb_generation_config.define_method("top_p", method!(GenerationConfig::top_p, 0))?;
|
|
705
|
+
rb_generation_config.define_method("top_k", method!(GenerationConfig::top_k, 0))?;
|
|
706
|
+
rb_generation_config.define_method("repetition_penalty", method!(GenerationConfig::repetition_penalty, 0))?;
|
|
707
|
+
rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
|
|
708
|
+
rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
|
|
709
|
+
rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
|
|
710
|
+
rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
|
|
711
|
+
rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
|
|
712
|
+
rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
|
|
713
|
+
rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
|
|
714
|
+
rb_generation_config.define_method("options", method!(GenerationConfig::options, 0))?;
|
|
715
|
+
|
|
716
|
+
let rb_llm = rb_candle.define_class("LLM", ruby.class_object())?;
|
|
717
|
+
rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
|
|
718
|
+
rb_llm.define_method("_generate", method!(LLM::generate, 2))?;
|
|
719
|
+
rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;
|
|
720
|
+
rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
|
|
721
|
+
rb_llm.define_method("device", method!(LLM::device, 0))?;
|
|
722
|
+
rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
|
|
723
|
+
rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
|
|
724
|
+
rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
|
|
725
|
+
rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
|
|
726
|
+
rb_llm.define_method("model_id", method!(LLM::model_id, 0))?;
|
|
727
|
+
rb_llm.define_method("options", method!(LLM::options, 0))?;
|
|
728
|
+
|
|
729
|
+
Ok(())
|
|
730
|
+
}
|