red-candle 1.0.0.pre.6 → 1.0.0

This diff shows the publicly released contents of the two package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (41)
  1. checksums.yaml +4 -4
  2. data/Gemfile +1 -10
  3. data/README.md +481 -4
  4. data/Rakefile +1 -3
  5. data/ext/candle/src/lib.rs +6 -3
  6. data/ext/candle/src/llm/gemma.rs +21 -79
  7. data/ext/candle/src/llm/generation_config.rs +3 -0
  8. data/ext/candle/src/llm/llama.rs +21 -79
  9. data/ext/candle/src/llm/mistral.rs +21 -89
  10. data/ext/candle/src/llm/mod.rs +3 -33
  11. data/ext/candle/src/llm/quantized_gguf.rs +501 -0
  12. data/ext/candle/src/llm/text_generation.rs +0 -4
  13. data/ext/candle/src/ner.rs +423 -0
  14. data/ext/candle/src/reranker.rs +24 -21
  15. data/ext/candle/src/ruby/device.rs +6 -6
  16. data/ext/candle/src/ruby/dtype.rs +4 -4
  17. data/ext/candle/src/ruby/embedding_model.rs +36 -34
  18. data/ext/candle/src/ruby/llm.rs +110 -49
  19. data/ext/candle/src/ruby/mod.rs +1 -2
  20. data/ext/candle/src/ruby/tensor.rs +66 -66
  21. data/ext/candle/src/ruby/tokenizer.rs +269 -0
  22. data/ext/candle/src/ruby/utils.rs +6 -24
  23. data/ext/candle/src/tokenizer/loader.rs +108 -0
  24. data/ext/candle/src/tokenizer/mod.rs +103 -0
  25. data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
  26. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
  27. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
  28. data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
  29. data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
  30. data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
  31. data/lib/candle/build_info.rb +2 -0
  32. data/lib/candle/device_utils.rb +2 -0
  33. data/lib/candle/llm.rb +91 -2
  34. data/lib/candle/ner.rb +345 -0
  35. data/lib/candle/reranker.rb +1 -1
  36. data/lib/candle/tensor.rb +2 -0
  37. data/lib/candle/tokenizer.rb +139 -0
  38. data/lib/candle/version.rb +4 -2
  39. data/lib/candle.rb +2 -0
  40. metadata +127 -3
  41. data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/ruby/embedding_model.rs

@@ -3,7 +3,8 @@
 use crate::ruby::{
     errors::{wrap_candle_err, wrap_hf_err, wrap_std_err},
 };
-use crate::ruby::{Tensor, Device, Result as RbResult};
+use crate::ruby::{Tensor, Device, Result};
+use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 use candle_core::{DType as CoreDType, Device as CoreDevice, Module, Tensor as CoreTensor};
 use safetensors::tensor::SafeTensors;
 use candle_nn::VarBuilder;
@@ -14,7 +15,6 @@ use candle_transformers::models::{
 };
 use magnus::{class, function, method, prelude::*, Error, RModule};
 use std::path::Path;
-use tokenizers::Tokenizer;
 use serde_json;


@@ -70,12 +70,12 @@ pub struct EmbeddingModelInner {
     model_path: Option<String>,
     embedding_model_type: Option<EmbeddingModelType>,
     model: Option<EmbeddingModelVariant>,
-    tokenizer: Option<Tokenizer>,
+    tokenizer: Option<TokenizerWrapper>,
     embedding_size: Option<usize>,
 }

 impl EmbeddingModel {
-    pub fn new(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>, embedding_model_type: Option<String>, embedding_size: Option<usize>) -> RbResult<Self> {
+    pub fn new(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>, embedding_model_type: Option<String>, embedding_size: Option<usize>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
         let embedding_model_type = embedding_model_type
             .and_then(|mt| EmbeddingModelType::from_string(&mt))
@@ -102,7 +102,7 @@ impl EmbeddingModel {
     /// Generates an embedding vector for the input text using the specified pooling method.
     /// &RETURNS&: Tensor
     /// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
-    pub fn embedding(&self, input: String, pooling_method: String) -> RbResult<Tensor> {
+    pub fn embedding(&self, input: String, pooling_method: String) -> Result<Tensor> {
         match &self.0.model {
             Some(model) => {
                 match &self.0.tokenizer {
@@ -116,7 +116,7 @@ impl EmbeddingModel {

     /// Returns the unpooled embedding tensor ([1, SEQLENGTH, DIM]) for the input text
     /// &RETURNS&: Tensor
-    pub fn embeddings(&self, input: String) -> RbResult<Tensor> {
+    pub fn embeddings(&self, input: String) -> Result<Tensor> {
         match &self.0.model {
             Some(model) => {
                 match &self.0.tokenizer {
@@ -130,27 +130,27 @@ impl EmbeddingModel {

     /// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
     /// &RETURNS&: Tensor
-    pub fn pool_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
+    pub fn pool_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }

     /// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
     /// &RETURNS&: Tensor
-    pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
+    pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_normalized_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }

     /// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
     /// &RETURNS&: Tensor
-    pub fn pool_cls_embedding(&self, tensor: &Tensor) -> RbResult<Tensor> {
+    pub fn pool_cls_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_cls_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }

     /// Infers and validates the embedding size from a safetensors file
-    fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> Result<usize, magnus::Error> {
+    fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> std::result::Result<usize, magnus::Error> {
         match embedding_size {
             Some(user_dim) => {
                 Ok(user_dim)
@@ -170,7 +170,7 @@ impl EmbeddingModel {
         }
     }

-    fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) -> RbResult<EmbeddingModelVariant> {
+    fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) -> Result<EmbeddingModelVariant> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let api = Api::new().map_err(wrap_hf_err)?;
         let repo = Repo::new(model_path.to_str().unwrap().to_string(), RepoType::Model);
@@ -257,7 +257,7 @@ impl EmbeddingModel {
         }
     }

-    fn build_tokenizer(tokenizer_path: String) -> RbResult<Tokenizer> {
+    fn build_tokenizer(tokenizer_path: String) -> Result<TokenizerWrapper> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let tokenizer_path = Api::new()
             .map_err(wrap_hf_err)?
@@ -267,20 +267,16 @@ impl EmbeddingModel {
             ))
             .get("tokenizer.json")
             .map_err(wrap_hf_err)?;
-        let mut tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
+        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
             .map_err(wrap_std_err)?;
-        let pp = tokenizers::PaddingParams {
-            strategy: tokenizers::PaddingStrategy::BatchLongest,
-            ..Default::default()
-        };
-        tokenizer.with_padding(Some(pp));
-
-        Ok(tokenizer)
+
+        let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
+        Ok(TokenizerWrapper::new(tokenizer))
     }

     /// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
     /// &RETURNS&: Tensor
-    fn pooled_cls_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_cls_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         // 1) sanity-check that we have a 3D tensor
         let (_batch, _seq_len, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;

@@ -298,14 +294,14 @@ impl EmbeddingModel {
         Ok(cls)
     }

-    fn pooled_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         let (_n_sentence, n_tokens, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
         let sum = result.sum(1).map_err(wrap_candle_err)?;
         let mean = (sum / (n_tokens as f64)).map_err(wrap_candle_err)?;
         Ok(mean)
     }

-    fn pooled_normalized_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_normalized_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         let mean = Self::pooled_embedding(result)?;
         let norm = Self::normalize_l2(&mean).map_err(wrap_candle_err)?;
         Ok(norm)
@@ -315,13 +311,11 @@ impl EmbeddingModel {
         &self,
         prompt: String,
         model: &EmbeddingModelVariant,
-        tokenizer: &Tokenizer,
-    ) -> Result<CoreTensor, Error> {
+        tokenizer: &TokenizerWrapper,
+    ) -> std::result::Result<CoreTensor, Error> {
         let tokens = tokenizer
-            .encode(prompt, true)
-            .map_err(wrap_std_err)?
-            .get_ids()
-            .to_vec();
+            .encode(&prompt, true)
+            .map_err(wrap_candle_err)?;
         let token_ids = CoreTensor::new(&tokens[..], &self.0.device)
             .map_err(wrap_candle_err)?
             .unsqueeze(0)
@@ -355,9 +349,9 @@ impl EmbeddingModel {
         &self,
         prompt: String,
         model: &EmbeddingModelVariant,
-        tokenizer: &Tokenizer,
+        tokenizer: &TokenizerWrapper,
         pooling_method: &str,
-    ) -> Result<CoreTensor, Error> {
+    ) -> std::result::Result<CoreTensor, Error> {
         let result = self.compute_embeddings(prompt, model, tokenizer)?;
         match pooling_method {
             "pooled" => Self::pooled_embedding(&result),
@@ -367,8 +361,7 @@ impl EmbeddingModel {
         }
     }

-    #[allow(dead_code)]
-    fn normalize_l2(v: &CoreTensor) -> Result<CoreTensor, candle_core::Error> {
+    fn normalize_l2(v: &CoreTensor) -> candle_core::Result<CoreTensor> {
         v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
     }

@@ -392,9 +385,17 @@ impl EmbeddingModel {
     pub fn __str__(&self) -> String {
         self.__repr__()
     }
+
+    /// Get the tokenizer used by this model
+    pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
+        match &self.0.tokenizer {
+            Some(tokenizer) => Ok(crate::ruby::tokenizer::Tokenizer(tokenizer.clone())),
+            None => Err(magnus::Error::new(magnus::exception::runtime_error(), "No tokenizer loaded for this model"))
+        }
+    }
 }

-pub fn init(rb_candle: RModule) -> Result<(), Error> {
+pub fn init(rb_candle: RModule) -> Result<()> {
     let rb_embedding_model = rb_candle.define_class("EmbeddingModel", class::object())?;
     rb_embedding_model.define_singleton_method("_create", function!(EmbeddingModel::new, 5))?;
     // Expose embedding with an optional pooling_method argument (default: "pooled")
@@ -406,5 +407,6 @@ pub fn init(rb_candle: RModule) -> Result<(), Error> {
     rb_embedding_model.define_method("embedding_model_type", method!(EmbeddingModel::embedding_model_type, 0))?;
     rb_embedding_model.define_method("to_s", method!(EmbeddingModel::__str__, 0))?;
     rb_embedding_model.define_method("inspect", method!(EmbeddingModel::__repr__, 0))?;
+    rb_embedding_model.define_method("tokenizer", method!(EmbeddingModel::tokenizer, 0))?;
     Ok(())
 }
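The embedding model now stores a TokenizerWrapper instead of a bare tokenizers::Tokenizer and exposes it to Ruby through the new "tokenizer" binding registered above. A minimal usage sketch follows; it assumes the Ruby-side class is Candle::EmbeddingModel and that lib/candle wraps the native _create constructor (the constructor arguments and model id below are illustrative, not taken from this diff):

    # Hypothetical sketch; the real constructor lives in lib/candle and may differ.
    model = Candle::EmbeddingModel.new("sentence-transformers/all-MiniLM-L6-v2")
    tokenizer = model.tokenizer   # new in 1.0.0; raises RuntimeError ("No tokenizer loaded for this model") if none is loaded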
data/ext/candle/src/ruby/llm.rs

@@ -1,8 +1,8 @@
 use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
 use std::cell::RefCell;

-use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma};
-use crate::ruby::{Result as RbResult, Device as RbDevice};
+use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
+use crate::ruby::{Result, Device};

 // Use an enum to handle different model types instead of trait objects
 #[derive(Debug)]
@@ -10,6 +10,7 @@ enum ModelType {
     Mistral(RustMistral),
     Llama(RustLlama),
     Gemma(RustGemma),
+    QuantizedGGUF(RustQuantizedGGUF),
 }

 impl ModelType {
@@ -18,6 +19,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.generate(prompt, config),
             ModelType::Llama(m) => m.generate(prompt, config),
             ModelType::Gemma(m) => m.generate(prompt, config),
+            ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
         }
     }

@@ -31,15 +33,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
             ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
             ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
-        }
-    }
-
-    #[allow(dead_code)]
-    fn model_name(&self) -> &str {
-        match self {
-            ModelType::Mistral(m) => m.model_name(),
-            ModelType::Llama(m) => m.model_name(),
-            ModelType::Gemma(m) => m.model_name(),
+            ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
         }
     }

@@ -48,6 +42,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.clear_cache(),
             ModelType::Llama(m) => m.clear_cache(),
             ModelType::Gemma(m) => m.clear_cache(),
+            ModelType::QuantizedGGUF(m) => m.clear_cache(),
         }
     }

@@ -72,6 +67,7 @@ impl ModelType {
             },
             ModelType::Llama(m) => m.apply_chat_template(messages),
             ModelType::Gemma(m) => m.apply_chat_template(messages),
+            ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
         }
     }
 }
@@ -83,7 +79,7 @@ pub struct GenerationConfig {
 }

 impl GenerationConfig {
-    pub fn new(kwargs: RHash) -> RbResult<Self> {
+    pub fn new(kwargs: RHash) -> Result<Self> {
         let mut config = RustGenerationConfig::default();

         // Extract values from kwargs manually
@@ -144,6 +140,12 @@ impl GenerationConfig {
             }
         }

+        if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                config.debug_tokens = v;
+            }
+        }
+
         Ok(Self { inner: config })
     }

@@ -185,6 +187,10 @@ impl GenerationConfig {
     pub fn include_prompt(&self) -> bool {
         self.inner.include_prompt
     }
+
+    pub fn debug_tokens(&self) -> bool {
+        self.inner.debug_tokens
+    }
 }

 #[derive(Clone, Debug)]
@@ -192,13 +198,13 @@ impl GenerationConfig {
 pub struct LLM {
     model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
     model_id: String,
-    device: RbDevice,
+    device: Device,
 }

 impl LLM {
     /// Create a new LLM from a pretrained model
-    pub fn from_pretrained(model_id: String, device: Option<RbDevice>) -> RbResult<Self> {
-        let device = device.unwrap_or(RbDevice::Cpu);
+    pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
+        let device = device.unwrap_or(Device::Cpu);
         let candle_device = device.as_device()?;

         // For now, we'll use tokio runtime directly
@@ -206,31 +212,51 @@ impl LLM {
         let rt = tokio::runtime::Runtime::new()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;

-        // Determine model type from ID and load appropriately
+        // Determine model type from ID and whether it's quantized
         let model_lower = model_id.to_lowercase();
-        let model = if model_lower.contains("mistral") {
-            let mistral = rt.block_on(async {
-                RustMistral::from_pretrained(&model_id, candle_device).await
-            })
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-            ModelType::Mistral(mistral)
-        } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
-            let llama = rt.block_on(async {
-                RustLlama::from_pretrained(&model_id, candle_device).await
-            })
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-            ModelType::Llama(llama)
-        } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
-            let gemma = rt.block_on(async {
-                RustGemma::from_pretrained(&model_id, candle_device).await
+        let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
+
+        let model = if is_quantized {
+            // Extract tokenizer source if provided in model_id
+            let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
+                let (id, _tok) = model_id.split_at(pos);
+                (id.to_string(), Some(&model_id[pos+2..]))
+            } else {
+                (model_id.clone(), None)
+            };
+
+            // Use unified GGUF loader for all quantized models
+            let gguf_model = rt.block_on(async {
+                RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
             })
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-            ModelType::Gemma(gemma)
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
+            ModelType::QuantizedGGUF(gguf_model)
         } else {
-            return Err(Error::new(
-                magnus::exception::runtime_error(),
-                format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
-            ));
+            // Load non-quantized models
+            if model_lower.contains("mistral") {
+                let mistral = rt.block_on(async {
+                    RustMistral::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Mistral(mistral)
+            } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
+                let llama = rt.block_on(async {
+                    RustLlama::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Llama(llama)
+            } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
+                let gemma = rt.block_on(async {
+                    RustGemma::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Gemma(gemma)
+            } else {
+                return Err(Error::new(
+                    magnus::exception::runtime_error(),
+                    format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
+                ));
+            }
         };

         Ok(Self {
@@ -241,12 +267,15 @@ impl LLM {
     }

     /// Generate text from a prompt
-    pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> RbResult<String> {
+    pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
         let config = config
             .map(|c| c.inner.clone())
             .unwrap_or_default();

-        let model = self.model.lock().unwrap();
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();

         model_ref.generate(&prompt, &config)
@@ -254,7 +283,7 @@ impl LLM {
     }

     /// Generate text with streaming output
-    pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) -> RbResult<String> {
+    pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
         let config = config
             .map(|c| c.inner.clone())
             .unwrap_or_default();
@@ -266,7 +295,10 @@ impl LLM {
         }
         let block = block.unwrap();

-        let model = self.model.lock().unwrap();
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();

         let result = model_ref.generate_stream(&prompt, &config, |token| {
@@ -283,20 +315,44 @@ impl LLM {
     }

     /// Get the device the model is running on
-    pub fn device(&self) -> RbDevice {
+    pub fn device(&self) -> Device {
         self.device
     }
+
+    /// Get the tokenizer used by this model
+    pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
+        let model_ref = model.borrow();
+
+        // Clone the tokenizer from the model
+        match &*model_ref {
+            ModelType::Mistral(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::Llama(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::Gemma(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+        }
+    }

     /// Clear the model's cache (e.g., KV cache for transformers)
-    pub fn clear_cache(&self) -> RbResult<()> {
-        let model = self.model.lock().unwrap();
+    pub fn clear_cache(&self) -> Result<()> {
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => {
+                // If the mutex is poisoned, we can still recover the data
+                // This happens when another thread panicked while holding the lock
+                poisoned.into_inner()
+            }
+        };
         let mut model_ref = model.borrow_mut();
         model_ref.clear_cache();
         Ok(())
     }

     /// Apply chat template to messages
-    pub fn apply_chat_template(&self, messages: RArray) -> RbResult<String> {
+    pub fn apply_chat_template(&self, messages: RArray) -> Result<String> {
         // Convert Ruby array to JSON values
         let json_messages: Vec<serde_json::Value> = messages
             .into_iter()
@@ -323,7 +379,10 @@ impl LLM {
             })
             .collect();

-        let model = self.model.lock().unwrap();
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let model_ref = model.borrow();

         model_ref.apply_chat_template(&json_messages)
@@ -332,7 +391,7 @@ impl LLM {
     }
 }
 // Define a standalone function for from_pretrained that handles variable arguments
-fn from_pretrained_wrapper(args: &[Value]) -> RbResult<LLM> {
+fn from_pretrained_wrapper(args: &[Value]) -> Result<LLM> {
     match args.len() {
         1 => {
             let model_id: String = TryConvert::try_convert(args[0])?;
@@ -340,7 +399,7 @@ fn from_pretrained_wrapper(args: &[Value]) -> RbResult<LLM> {
         },
         2 => {
             let model_id: String = TryConvert::try_convert(args[0])?;
-            let device: RbDevice = TryConvert::try_convert(args[1])?;
+            let device: Device = TryConvert::try_convert(args[1])?;
             LLM::from_pretrained(model_id, Some(device))
         },
         _ => Err(Error::new(
@@ -350,7 +409,7 @@ fn from_pretrained_wrapper(args: &[Value]) -> RbResult<LLM> {
     }
 }

-pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
+pub fn init_llm(rb_candle: RModule) -> Result<()> {
     let rb_generation_config = rb_candle.define_class("GenerationConfig", magnus::class::object())?;
     rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
     rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
@@ -363,6 +422,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
     rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
     rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
+    rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;

     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
     rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
@@ -370,6 +430,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;
     rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
     rb_llm.define_method("device", method!(LLM::device, 0))?;
+    rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
     rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
     rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;

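In this file, LLM::from_pretrained now routes model ids containing "gguf" (or -q4/-q5/-q8) through the unified QuantizedGGUF loader, accepts an optional tokenizer source appended after "@@", recovers from poisoned mutexes instead of panicking, and surfaces the new tokenizer and debug_tokens bindings to Ruby. A usage sketch follows; it assumes the wrapper in data/lib/candle/llm.rb forwards from_pretrained and generate to the native _from_pretrained and _generate bindings shown above, and the repository ids are placeholders:

    # Hypothetical sketch; Ruby-side method names are assumed from the bindings registered above.
    llm = Candle::LLM.from_pretrained(
      "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF@@TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # "@@" overrides the tokenizer source
    )
    config = Candle::GenerationConfig.new(debug_tokens: true)  # debug_tokens is new in this release
    puts llm.generate("Hello!", config)
    tokenizer = llm.tokenizer                                  # tokenizer accessor is new in 1.0.0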
data/ext/candle/src/ruby/mod.rs

@@ -2,17 +2,16 @@ pub mod embedding_model;
 pub mod tensor;
 pub mod device;
 pub mod dtype;
-pub mod qtensor;
 pub mod result;
 pub mod errors;
 pub mod utils;
 pub mod llm;
+pub mod tokenizer;

 pub use embedding_model::{EmbeddingModel, EmbeddingModelInner};
 pub use tensor::Tensor;
 pub use device::Device;
 pub use dtype::DType;
-pub use qtensor::QTensor;
 pub use result::Result;

 // Re-export for convenience