red-candle 1.8.0.pre2-x86_64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5193 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +33 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,730 @@
1
+ use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
2
+ use std::cell::RefCell;
3
+ use std::sync::Arc;
4
+
5
+ use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, qwen::Qwen as RustQwen, qwen3::Qwen3 as RustQwen3, phi::Phi as RustPhi, granite::Granite as RustGranite, granitemoehybrid::GraniteMoeHybrid as RustGraniteMoeHybrid, glm4::Glm4 as RustGlm4, QuantizedGGUF as RustQuantizedGGUF};
6
+ use crate::ruby::{Result, Device};
7
+ use crate::ruby::structured::StructuredConstraint;
8
+ use crate::gvl;
9
+
10
+ // Use an enum to handle different model types instead of trait objects
11
+ enum ModelType {
12
+ Mistral(RustMistral),
13
+ Llama(RustLlama),
14
+ Gemma(RustGemma),
15
+ Qwen(RustQwen),
16
+ Qwen3(RustQwen3),
17
+ Phi(RustPhi),
18
+ Granite(RustGranite),
19
+ GraniteMoeHybrid(RustGraniteMoeHybrid),
20
+ Glm4(RustGlm4),
21
+ QuantizedGGUF(RustQuantizedGGUF),
22
+ }
23
+
24
+ impl ModelType {
25
+ fn generate(&mut self, prompt: &str, config: &RustGenerationConfig) -> candle_core::Result<String> {
26
+ match self {
27
+ ModelType::Mistral(m) => m.generate(prompt, config),
28
+ ModelType::Llama(m) => m.generate(prompt, config),
29
+ ModelType::Gemma(m) => m.generate(prompt, config),
30
+ ModelType::Qwen(m) => m.generate(prompt, config),
31
+ ModelType::Qwen3(m) => m.generate(prompt, config),
32
+ ModelType::Phi(m) => m.generate(prompt, config),
33
+ ModelType::Granite(m) => m.generate(prompt, config),
34
+ ModelType::GraniteMoeHybrid(m) => m.generate(prompt, config),
35
+ ModelType::Glm4(m) => m.generate(prompt, config),
36
+ ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
37
+ }
38
+ }
39
+
40
+ fn generate_stream(
41
+ &mut self,
42
+ prompt: &str,
43
+ config: &RustGenerationConfig,
44
+ callback: impl FnMut(&str),
45
+ ) -> candle_core::Result<String> {
46
+ match self {
47
+ ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
48
+ ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
49
+ ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
50
+ ModelType::Qwen(m) => m.generate_stream(prompt, config, callback),
51
+ ModelType::Qwen3(m) => m.generate_stream(prompt, config, callback),
52
+ ModelType::Phi(m) => m.generate_stream(prompt, config, callback),
53
+ ModelType::Granite(m) => m.generate_stream(prompt, config, callback),
54
+ ModelType::GraniteMoeHybrid(m) => m.generate_stream(prompt, config, callback),
55
+ ModelType::Glm4(m) => m.generate_stream(prompt, config, callback),
56
+ ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
57
+ }
58
+ }
59
+
60
+ fn clear_cache(&mut self) {
61
+ match self {
62
+ ModelType::Mistral(m) => m.clear_cache(),
63
+ ModelType::Llama(m) => m.clear_cache(),
64
+ ModelType::Gemma(m) => m.clear_cache(),
65
+ ModelType::Qwen(m) => m.clear_cache(),
66
+ ModelType::Qwen3(m) => m.clear_cache(),
67
+ ModelType::Phi(m) => m.clear_cache(),
68
+ ModelType::Granite(m) => m.clear_cache(),
69
+ ModelType::GraniteMoeHybrid(m) => m.clear_cache(),
70
+ ModelType::Glm4(m) => m.clear_cache(),
71
+ ModelType::QuantizedGGUF(m) => m.clear_cache(),
72
+ }
73
+ }
74
+
75
+ fn apply_chat_template(&self, messages: &[serde_json::Value]) -> candle_core::Result<String> {
76
+ match self {
77
+ ModelType::Mistral(_) => {
78
+ // For now, use a simple template for Mistral
79
+ // In the future, we could implement proper Mistral chat templating
80
+ let mut prompt = String::new();
81
+ for message in messages {
82
+ let role = message["role"].as_str().unwrap_or("");
83
+ let content = message["content"].as_str().unwrap_or("");
84
+ match role {
85
+ "system" => prompt.push_str(&format!("System: {}\n\n", content)),
86
+ "user" => prompt.push_str(&format!("User: {}\n\n", content)),
87
+ "assistant" => prompt.push_str(&format!("Assistant: {}\n\n", content)),
88
+ _ => {}
89
+ }
90
+ }
91
+ prompt.push_str("Assistant: ");
92
+ Ok(prompt)
93
+ },
94
+ ModelType::Llama(m) => m.apply_chat_template(messages),
95
+ ModelType::Gemma(m) => m.apply_chat_template(messages),
96
+ ModelType::Qwen(m) => m.apply_chat_template(messages),
97
+ ModelType::Qwen3(m) => m.apply_chat_template(messages),
98
+ ModelType::Phi(m) => m.apply_chat_template(messages),
99
+ ModelType::Granite(m) => m.apply_chat_template(messages),
100
+ ModelType::GraniteMoeHybrid(m) => m.apply_chat_template(messages),
101
+ ModelType::Glm4(m) => m.apply_chat_template(messages),
102
+ ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
103
+ }
104
+ }
105
+ }
106
+
107
// Copies one keyword argument from a Ruby hash into the config field of the
// same name. The key is looked up as a symbol; a missing key or a value that
// cannot be converted to the field's type leaves the field untouched.
macro_rules! extract_param {
    // Plain field: store the converted value directly.
    ($ruby:expr, $kwargs:expr, $config:expr, $param:ident) => {
        if let Some(converted) = $kwargs
            .get($ruby.to_symbol(stringify!($param)))
            .and_then(|raw| TryConvert::try_convert(raw).ok())
        {
            $config.$param = converted;
        }
    };
    // `Option` field: store the converted value wrapped in `Some`.
    ($ruby:expr, $kwargs:expr, $config:expr, $param:ident, optional) => {
        if let Some(converted) = $kwargs
            .get($ruby.to_symbol(stringify!($param)))
            .and_then(|raw| TryConvert::try_convert(raw).ok())
        {
            $config.$param = Some(converted);
        }
    };
}
126
+
127
+ #[derive(Clone, Debug)]
128
+ #[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
129
+ pub struct GenerationConfig {
130
+ inner: RustGenerationConfig,
131
+ }
132
+
133
+ impl GenerationConfig {
134
+ pub fn new(kwargs: RHash) -> Result<Self> {
135
+ let ruby = Ruby::get().unwrap();
136
+ let mut config = RustGenerationConfig::default();
137
+
138
+ // Extract basic parameters using macro
139
+ extract_param!(ruby, kwargs, config, max_length);
140
+ extract_param!(ruby, kwargs, config, temperature);
141
+ extract_param!(ruby, kwargs, config, top_p, optional);
142
+ extract_param!(ruby, kwargs, config, top_k, optional);
143
+ extract_param!(ruby, kwargs, config, repetition_penalty);
144
+ extract_param!(ruby, kwargs, config, repetition_penalty_last_n);
145
+ extract_param!(ruby, kwargs, config, seed);
146
+ extract_param!(ruby, kwargs, config, include_prompt);
147
+ extract_param!(ruby, kwargs, config, debug_tokens);
148
+ extract_param!(ruby, kwargs, config, stop_on_constraint_satisfaction);
149
+ extract_param!(ruby, kwargs, config, stop_on_match);
150
+
151
+ // Handle special cases that need custom logic
152
+ if let Some(value) = kwargs.get(ruby.to_symbol("stop_sequences")) {
153
+ if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
154
+ config.stop_sequences = arr
155
+ .into_iter()
156
+ .filter_map(|v| <String as TryConvert>::try_convert(v).ok())
157
+ .collect();
158
+ }
159
+ }
160
+
161
+ if let Some(value) = kwargs.get(ruby.to_symbol("constraint")) {
162
+ if let Ok(constraint) = <&StructuredConstraint as TryConvert>::try_convert(value) {
163
+ config.constraint = Some(Arc::clone(&constraint.index));
164
+ }
165
+ }
166
+
167
+ Ok(Self { inner: config })
168
+ }
169
+
170
+ pub fn default() -> Self {
171
+ Self {
172
+ inner: RustGenerationConfig::default(),
173
+ }
174
+ }
175
+
176
+ // Getters
177
+ pub fn max_length(&self) -> usize {
178
+ self.inner.max_length
179
+ }
180
+
181
+ pub fn temperature(&self) -> f64 {
182
+ self.inner.temperature
183
+ }
184
+
185
+ pub fn top_p(&self) -> Option<f64> {
186
+ self.inner.top_p
187
+ }
188
+
189
+ pub fn top_k(&self) -> Option<usize> {
190
+ self.inner.top_k
191
+ }
192
+
193
+ pub fn repetition_penalty(&self) -> f32 {
194
+ self.inner.repetition_penalty
195
+ }
196
+
197
+ pub fn seed(&self) -> u64 {
198
+ self.inner.seed
199
+ }
200
+
201
+ pub fn stop_sequences(&self) -> Vec<String> {
202
+ self.inner.stop_sequences.clone()
203
+ }
204
+
205
+ pub fn include_prompt(&self) -> bool {
206
+ self.inner.include_prompt
207
+ }
208
+
209
+ pub fn debug_tokens(&self) -> bool {
210
+ self.inner.debug_tokens
211
+ }
212
+
213
+ pub fn stop_on_constraint_satisfaction(&self) -> bool {
214
+ self.inner.stop_on_constraint_satisfaction
215
+ }
216
+
217
+ pub fn stop_on_match(&self) -> bool {
218
+ self.inner.stop_on_match
219
+ }
220
+
221
+ pub fn constraint(&self) -> Option<StructuredConstraint> {
222
+ self.inner.constraint.as_ref().map(|c| StructuredConstraint {
223
+ index: Arc::clone(c),
224
+ })
225
+ }
226
+
227
+ /// Get all options as a hash
228
+ pub fn options(&self) -> Result<RHash> {
229
+ let ruby = Ruby::get().unwrap();
230
+ let hash = ruby.hash_new();
231
+
232
+ hash.aset("max_length", self.inner.max_length)?;
233
+ hash.aset("temperature", self.inner.temperature)?;
234
+
235
+ if let Some(top_p) = self.inner.top_p {
236
+ hash.aset("top_p", top_p)?;
237
+ }
238
+
239
+ if let Some(top_k) = self.inner.top_k {
240
+ hash.aset("top_k", top_k)?;
241
+ }
242
+
243
+ hash.aset("repetition_penalty", self.inner.repetition_penalty)?;
244
+ hash.aset("repetition_penalty_last_n", self.inner.repetition_penalty_last_n)?;
245
+ hash.aset("seed", self.inner.seed)?;
246
+ hash.aset("stop_sequences", self.inner.stop_sequences.clone())?;
247
+ hash.aset("include_prompt", self.inner.include_prompt)?;
248
+ hash.aset("debug_tokens", self.inner.debug_tokens)?;
249
+ hash.aset("stop_on_constraint_satisfaction", self.inner.stop_on_constraint_satisfaction)?;
250
+ hash.aset("stop_on_match", self.inner.stop_on_match)?;
251
+
252
+ if self.inner.constraint.is_some() {
253
+ hash.aset("has_constraint", true)?;
254
+ }
255
+
256
+ Ok(hash)
257
+ }
258
+ }
259
+
260
+ #[derive(Clone)]
261
+ #[magnus::wrap(class = "Candle::LLM", mark, free_immediately)]
262
+ pub struct LLM {
263
+ model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
264
+ model_id: String,
265
+ device: Device,
266
+ }
267
+
268
+ impl LLM {
269
+ /// Create a new LLM from a pretrained model
270
+ pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
271
+ let ruby = Ruby::get().unwrap();
272
+ let runtime_error = ruby.exception_runtime_error();
273
+ let device = device.unwrap_or(Device::best());
274
+ let candle_device = device.as_device()?;
275
+
276
+ let rt = tokio::runtime::Runtime::new()
277
+ .map_err(|e| Error::new(runtime_error, format!("Failed to create runtime: {}", e)))?;
278
+
279
+ // Determine model type from ID and whether it's quantized
280
+ let model_lower = model_id.to_lowercase();
281
+ let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
282
+
283
+ // Extract tokenizer source if provided in model_id (for both GGUF and regular models)
284
+ let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
285
+ let (id, _tok) = model_id.split_at(pos);
286
+ (id.to_string(), Some(&model_id[pos+2..]))
287
+ } else {
288
+ (model_id.clone(), None)
289
+ };
290
+
291
+ let model = if is_quantized {
292
+
293
+ // Use unified GGUF loader for all quantized models
294
+ let gguf_model = rt.block_on(async {
295
+ RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
296
+ })
297
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load GGUF model: {}", e)))?;
298
+ ModelType::QuantizedGGUF(gguf_model)
299
+ } else {
300
+ // Load non-quantized models based on type
301
+ let model_lower_clean = model_id_clean.to_lowercase();
302
+
303
+ if model_lower_clean.contains("mistral") {
304
+ let mistral = if tokenizer_source.is_some() {
305
+ rt.block_on(async {
306
+ RustMistral::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
307
+ })
308
+ } else {
309
+ rt.block_on(async {
310
+ RustMistral::from_pretrained(&model_id_clean, candle_device).await
311
+ })
312
+ }
313
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
314
+ ModelType::Mistral(mistral)
315
+ } else if model_lower_clean.contains("llama") || model_lower_clean.contains("meta-llama") || model_lower_clean.contains("tinyllama") || model_lower_clean.contains("smollm") || model_lower_clean.contains("/yi-") || model_lower_clean.contains("01-ai") {
316
+ let llama = if tokenizer_source.is_some() {
317
+ rt.block_on(async {
318
+ RustLlama::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
319
+ })
320
+ } else {
321
+ rt.block_on(async {
322
+ RustLlama::from_pretrained(&model_id_clean, candle_device).await
323
+ })
324
+ }
325
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
326
+ ModelType::Llama(llama)
327
+ } else if model_lower_clean.contains("gemma") || model_lower_clean.contains("google/gemma") {
328
+ let gemma = if tokenizer_source.is_some() {
329
+ rt.block_on(async {
330
+ RustGemma::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
331
+ })
332
+ } else {
333
+ rt.block_on(async {
334
+ RustGemma::from_pretrained(&model_id_clean, candle_device).await
335
+ })
336
+ }
337
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
338
+ ModelType::Gemma(gemma)
339
+ } else if model_lower_clean.contains("qwen3") {
340
+ let qwen3 = if tokenizer_source.is_some() {
341
+ rt.block_on(async {
342
+ RustQwen3::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
343
+ })
344
+ } else {
345
+ rt.block_on(async {
346
+ RustQwen3::from_pretrained(&model_id_clean, candle_device).await
347
+ })
348
+ }
349
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
350
+ ModelType::Qwen3(qwen3)
351
+ } else if model_lower_clean.contains("qwen") {
352
+ let qwen = if tokenizer_source.is_some() {
353
+ rt.block_on(async {
354
+ RustQwen::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
355
+ })
356
+ } else {
357
+ rt.block_on(async {
358
+ RustQwen::from_pretrained(&model_id_clean, candle_device).await
359
+ })
360
+ }
361
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
362
+ ModelType::Qwen(qwen)
363
+ } else if model_lower_clean.contains("phi") {
364
+ let phi = if tokenizer_source.is_some() {
365
+ rt.block_on(async {
366
+ RustPhi::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
367
+ })
368
+ } else {
369
+ rt.block_on(async {
370
+ RustPhi::from_pretrained(&model_id_clean, candle_device).await
371
+ })
372
+ }
373
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
374
+ ModelType::Phi(phi)
375
+ } else if model_lower_clean.contains("granite-4") || model_lower_clean.contains("granitemoehybrid") {
376
+ let granite_moe = if tokenizer_source.is_some() {
377
+ rt.block_on(async {
378
+ RustGraniteMoeHybrid::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
379
+ })
380
+ } else {
381
+ rt.block_on(async {
382
+ RustGraniteMoeHybrid::from_pretrained(&model_id_clean, candle_device).await
383
+ })
384
+ }
385
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
386
+ ModelType::GraniteMoeHybrid(granite_moe)
387
+ } else if model_lower_clean.contains("granite") {
388
+ let granite = if tokenizer_source.is_some() {
389
+ rt.block_on(async {
390
+ RustGranite::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
391
+ })
392
+ } else {
393
+ rt.block_on(async {
394
+ RustGranite::from_pretrained(&model_id_clean, candle_device).await
395
+ })
396
+ }
397
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
398
+ ModelType::Granite(granite)
399
+ } else if model_lower_clean.contains("glm") {
400
+ let glm4 = if tokenizer_source.is_some() {
401
+ rt.block_on(async {
402
+ RustGlm4::from_pretrained_with_tokenizer(&model_id_clean, candle_device, tokenizer_source).await
403
+ })
404
+ } else {
405
+ rt.block_on(async {
406
+ RustGlm4::from_pretrained(&model_id_clean, candle_device).await
407
+ })
408
+ }
409
+ .map_err(|e| Error::new(runtime_error, format!("Failed to load model: {}", e)))?;
410
+ ModelType::Glm4(glm4)
411
+ } else {
412
+ return Err(Error::new(
413
+ runtime_error,
414
+ format!("Unsupported model type: {}. Currently Mistral, Llama, Gemma, Qwen, Phi, Granite, and GLM-4 models are supported.", model_id_clean),
415
+ ));
416
+ }
417
+ };
418
+
419
+ Ok(Self {
420
+ model: std::sync::Arc::new(std::sync::Mutex::new(RefCell::new(model))),
421
+ model_id,
422
+ device,
423
+ })
424
+ }
425
+
426
+ /// Generate text from a prompt (releases GVL during inference)
427
+ pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
428
+ let ruby = Ruby::get().unwrap();
429
+ let config = config
430
+ .map(|c| c.inner.clone())
431
+ .unwrap_or_default();
432
+
433
+ let model = match self.model.lock() {
434
+ Ok(guard) => guard,
435
+ Err(poisoned) => poisoned.into_inner(),
436
+ };
437
+ let mut model_ref = model.borrow_mut();
438
+
439
+ // Release the GVL during inference so other Ruby threads can run
440
+ // (e.g., TUI render loops, HTTP servers, etc.)
441
+ let result = gvl::without_gvl(|| {
442
+ model_ref.generate(&prompt, &config)
443
+ });
444
+
445
+ result.map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Generation failed: {}", e)))
446
+ }
447
+
448
+ /// Generate text with streaming output
449
+ pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
450
+ let config = config
451
+ .map(|c| c.inner.clone())
452
+ .unwrap_or_default();
453
+
454
+ let ruby = Ruby::get().unwrap();
455
+ let runtime_error = ruby.exception_runtime_error();
456
+ let block = ruby.block_proc();
457
+ if let Err(_) = block {
458
+ return Err(Error::new(runtime_error, "No block given"));
459
+ }
460
+ let block = block.unwrap();
461
+
462
+ let model = match self.model.lock() {
463
+ Ok(guard) => guard,
464
+ Err(poisoned) => poisoned.into_inner(),
465
+ };
466
+ let mut model_ref = model.borrow_mut();
467
+
468
+ let result = model_ref.generate_stream(&prompt, &config, |token| {
469
+ // Call the Ruby block with each token
470
+ let _ = block.call::<(String,), Value>((token.to_string(),));
471
+ });
472
+
473
+ result.map_err(|e| Error::new(runtime_error, format!("Generation failed: {}", e)))
474
+ }
475
+
476
+ /// Get the model name
477
+ pub fn model_name(&self) -> String {
478
+ self.model_id.clone()
479
+ }
480
+
481
+ /// Get the device the model is running on
482
+ pub fn device(&self) -> Device {
483
+ self.device
484
+ }
485
+
486
+ /// Get the tokenizer used by this model
487
+ pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
488
+ let model = match self.model.lock() {
489
+ Ok(guard) => guard,
490
+ Err(poisoned) => poisoned.into_inner(),
491
+ };
492
+ let model_ref = model.borrow();
493
+
494
+ // Clone the tokenizer from the model
495
+ match &*model_ref {
496
+ ModelType::Mistral(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
497
+ ModelType::Llama(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
498
+ ModelType::Gemma(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
499
+ ModelType::Qwen(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
500
+ ModelType::Qwen3(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
501
+ ModelType::Phi(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
502
+ ModelType::Granite(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
503
+ ModelType::GraniteMoeHybrid(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
504
+ ModelType::Glm4(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
505
+ ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
506
+ }
507
+ }
508
+
509
+ /// Get the EOS token string for this model
510
+ pub fn eos_token(&self) -> Result<String> {
511
+ let (eos_token_id, tokenizer_clone) = {
512
+ let model = match self.model.lock() {
513
+ Ok(guard) => guard,
514
+ Err(poisoned) => poisoned.into_inner(),
515
+ };
516
+ let model_ref = model.borrow();
517
+
518
+ // Get both EOS token ID and tokenizer clone in one lock scope
519
+ let eos_id = match &*model_ref {
520
+ ModelType::Mistral(m) => m.eos_token_id(),
521
+ ModelType::Llama(m) => m.eos_token_id(),
522
+ ModelType::Gemma(m) => m.eos_token_id(),
523
+ ModelType::Qwen(m) => m.eos_token_id(),
524
+ ModelType::Qwen3(m) => m.eos_token_id(),
525
+ ModelType::Phi(m) => m.eos_token_id(),
526
+ ModelType::Granite(m) => m.eos_token_id(),
527
+ ModelType::GraniteMoeHybrid(m) => m.eos_token_id(),
528
+ ModelType::Glm4(m) => m.eos_token_id(),
529
+ ModelType::QuantizedGGUF(m) => m.eos_token_id(),
530
+ };
531
+
532
+ let tokenizer = match &*model_ref {
533
+ ModelType::Mistral(m) => m.tokenizer().clone(),
534
+ ModelType::Llama(m) => m.tokenizer().clone(),
535
+ ModelType::Gemma(m) => m.tokenizer().clone(),
536
+ ModelType::Qwen(m) => m.tokenizer().clone(),
537
+ ModelType::Qwen3(m) => m.tokenizer().clone(),
538
+ ModelType::Phi(m) => m.tokenizer().clone(),
539
+ ModelType::Granite(m) => m.tokenizer().clone(),
540
+ ModelType::GraniteMoeHybrid(m) => m.tokenizer().clone(),
541
+ ModelType::Glm4(m) => m.tokenizer().clone(),
542
+ ModelType::QuantizedGGUF(m) => m.tokenizer().clone(),
543
+ };
544
+
545
+ (eos_id, tokenizer)
546
+ }; // Lock is released here
547
+
548
+ // Convert ID to string using the tokenizer
549
+ let tokenizer_wrapper = crate::ruby::tokenizer::Tokenizer(tokenizer_clone);
550
+ tokenizer_wrapper.id_to_token(eos_token_id as i64)
551
+ }
552
+
553
+ /// Clear the model's cache (e.g., KV cache for transformers)
554
+ pub fn clear_cache(&self) -> Result<()> {
555
+ let model = match self.model.lock() {
556
+ Ok(guard) => guard,
557
+ Err(poisoned) => {
558
+ // If the mutex is poisoned, we can still recover the data
559
+ // This happens when another thread panicked while holding the lock
560
+ poisoned.into_inner()
561
+ }
562
+ };
563
+ let mut model_ref = model.borrow_mut();
564
+ model_ref.clear_cache();
565
+ Ok(())
566
+ }
567
+
568
+ /// Apply chat template to messages
569
+ pub fn apply_chat_template(&self, messages: RArray) -> Result<String> {
570
+ let ruby = Ruby::get().unwrap();
571
+ // Convert Ruby array to JSON values
572
+ let json_messages: Vec<serde_json::Value> = messages
573
+ .into_iter()
574
+ .filter_map(|msg| {
575
+ if let Ok(hash) = <RHash as TryConvert>::try_convert(msg) {
576
+ let mut json_msg = serde_json::Map::new();
577
+
578
+ if let Some(role) = hash.get(ruby.to_symbol("role")) {
579
+ if let Ok(role_str) = <String as TryConvert>::try_convert(role) {
580
+ json_msg.insert("role".to_string(), serde_json::Value::String(role_str));
581
+ }
582
+ }
583
+
584
+ if let Some(content) = hash.get(ruby.to_symbol("content")) {
585
+ if let Ok(content_str) = <String as TryConvert>::try_convert(content) {
586
+ json_msg.insert("content".to_string(), serde_json::Value::String(content_str));
587
+ }
588
+ }
589
+
590
+ Some(serde_json::Value::Object(json_msg))
591
+ } else {
592
+ None
593
+ }
594
+ })
595
+ .collect();
596
+
597
+ let model = match self.model.lock() {
598
+ Ok(guard) => guard,
599
+ Err(poisoned) => poisoned.into_inner(),
600
+ };
601
+ let model_ref = model.borrow();
602
+
603
+ model_ref.apply_chat_template(&json_messages)
604
+ .map_err(|e| Error::new(ruby.exception_runtime_error(), format!("Failed to apply chat template: {}", e)))
605
+ }
606
+
607
+ /// Get the model ID
608
+ pub fn model_id(&self) -> String {
609
+ self.model_id.clone()
610
+ }
611
+
612
+ /// Get model options
613
+ pub fn options(&self) -> Result<RHash> {
614
+ let ruby = Ruby::get().unwrap();
615
+ let hash = ruby.hash_new();
616
+
617
+ // Basic metadata
618
+ hash.aset("model_id", self.model_id.clone())?;
619
+ let device_str = match self.device {
620
+ Device::Cpu => "cpu",
621
+ Device::Cuda => "cuda",
622
+ Device::Metal => "metal",
623
+ };
624
+ hash.aset("device", device_str)?;
625
+
626
+ // Parse model_id to extract GGUF file if present
627
+ if let Some(at_pos) = self.model_id.find('@') {
628
+ let (base_model, gguf_part) = self.model_id.split_at(at_pos);
629
+ let gguf_part = &gguf_part[1..]; // Skip the @ character
630
+
631
+ // Check for tokenizer (@@)
632
+ if let Some(tokenizer_pos) = gguf_part.find("@@") {
633
+ let (gguf_file, tokenizer) = gguf_part.split_at(tokenizer_pos);
634
+ hash.aset("base_model", base_model)?;
635
+ hash.aset("gguf_file", gguf_file)?;
636
+ hash.aset("tokenizer_source", &tokenizer[2..])?;
637
+ } else {
638
+ hash.aset("base_model", base_model)?;
639
+ hash.aset("gguf_file", gguf_part)?;
640
+ }
641
+ }
642
+
643
+ // Add model type
644
+ let model = match self.model.lock() {
645
+ Ok(guard) => guard,
646
+ Err(poisoned) => poisoned.into_inner(),
647
+ };
648
+ let model_ref = model.borrow();
649
+
650
+ let model_type = match &*model_ref {
651
+ ModelType::Mistral(_) => "Mistral",
652
+ ModelType::Llama(_) => "Llama",
653
+ ModelType::Gemma(_) => "Gemma",
654
+ ModelType::Qwen(_) => "Qwen",
655
+ ModelType::Qwen3(_) => "Qwen3",
656
+ ModelType::Phi(_) => "Phi",
657
+ ModelType::Granite(_) => "Granite",
658
+ ModelType::GraniteMoeHybrid(_) => "GraniteMoeHybrid",
659
+ ModelType::Glm4(_) => "Glm4",
660
+ ModelType::QuantizedGGUF(_) => "QuantizedGGUF",
661
+ };
662
+ hash.aset("model_type", model_type)?;
663
+
664
+ // For GGUF models, add architecture info
665
+ if let ModelType::QuantizedGGUF(gguf) = &*model_ref {
666
+ hash.aset("architecture", gguf.architecture.clone())?;
667
+ hash.aset("eos_token_id", gguf.eos_token_id())?;
668
+ }
669
+
670
+ Ok(hash)
671
+ }
672
+ }
673
+
674
+ // Define a standalone function for from_pretrained that handles variable arguments
675
+ fn from_pretrained_wrapper(args: &[Value]) -> Result<LLM> {
676
+ match args.len() {
677
+ 1 => {
678
+ let model_id: String = TryConvert::try_convert(args[0])?;
679
+ LLM::from_pretrained(model_id, None)
680
+ },
681
+ 2 => {
682
+ let model_id: String = TryConvert::try_convert(args[0])?;
683
+ let device: Device = TryConvert::try_convert(args[1])?;
684
+ LLM::from_pretrained(model_id, Some(device))
685
+ },
686
+ _ => {
687
+ let ruby = Ruby::get().unwrap();
688
+ Err(Error::new(
689
+ ruby.exception_arg_error(),
690
+ "wrong number of arguments (expected 1..2)"
691
+ ))
692
+ }
693
+ }
694
+ }
695
+
696
+ pub fn init_llm(rb_candle: RModule) -> Result<()> {
697
+ let ruby = Ruby::get().unwrap();
698
+ let rb_generation_config = rb_candle.define_class("GenerationConfig", ruby.class_object())?;
699
+ rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
700
+ rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
701
+
702
+ rb_generation_config.define_method("max_length", method!(GenerationConfig::max_length, 0))?;
703
+ rb_generation_config.define_method("temperature", method!(GenerationConfig::temperature, 0))?;
704
+ rb_generation_config.define_method("top_p", method!(GenerationConfig::top_p, 0))?;
705
+ rb_generation_config.define_method("top_k", method!(GenerationConfig::top_k, 0))?;
706
+ rb_generation_config.define_method("repetition_penalty", method!(GenerationConfig::repetition_penalty, 0))?;
707
+ rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
708
+ rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
709
+ rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
710
+ rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
711
+ rb_generation_config.define_method("stop_on_constraint_satisfaction", method!(GenerationConfig::stop_on_constraint_satisfaction, 0))?;
712
+ rb_generation_config.define_method("stop_on_match", method!(GenerationConfig::stop_on_match, 0))?;
713
+ rb_generation_config.define_method("constraint", method!(GenerationConfig::constraint, 0))?;
714
+ rb_generation_config.define_method("options", method!(GenerationConfig::options, 0))?;
715
+
716
+ let rb_llm = rb_candle.define_class("LLM", ruby.class_object())?;
717
+ rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
718
+ rb_llm.define_method("_generate", method!(LLM::generate, 2))?;
719
+ rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;
720
+ rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
721
+ rb_llm.define_method("device", method!(LLM::device, 0))?;
722
+ rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
723
+ rb_llm.define_method("eos_token", method!(LLM::eos_token, 0))?;
724
+ rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
725
+ rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
726
+ rb_llm.define_method("model_id", method!(LLM::model_id, 0))?;
727
+ rb_llm.define_method("options", method!(LLM::options, 0))?;
728
+
729
+ Ok(())
730
+ }