red-candle 1.0.0.pre.6 → 1.0.0
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/Gemfile +1 -10
- data/README.md +481 -4
- data/Rakefile +1 -3
- data/ext/candle/src/lib.rs +6 -3
- data/ext/candle/src/llm/gemma.rs +21 -79
- data/ext/candle/src/llm/generation_config.rs +3 -0
- data/ext/candle/src/llm/llama.rs +21 -79
- data/ext/candle/src/llm/mistral.rs +21 -89
- data/ext/candle/src/llm/mod.rs +3 -33
- data/ext/candle/src/llm/quantized_gguf.rs +501 -0
- data/ext/candle/src/llm/text_generation.rs +0 -4
- data/ext/candle/src/ner.rs +423 -0
- data/ext/candle/src/reranker.rs +24 -21
- data/ext/candle/src/ruby/device.rs +6 -6
- data/ext/candle/src/ruby/dtype.rs +4 -4
- data/ext/candle/src/ruby/embedding_model.rs +36 -34
- data/ext/candle/src/ruby/llm.rs +110 -49
- data/ext/candle/src/ruby/mod.rs +1 -2
- data/ext/candle/src/ruby/tensor.rs +66 -66
- data/ext/candle/src/ruby/tokenizer.rs +269 -0
- data/ext/candle/src/ruby/utils.rs +6 -24
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +103 -0
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/candle/build_info.rb +2 -0
- data/lib/candle/device_utils.rb +2 -0
- data/lib/candle/llm.rb +91 -2
- data/lib/candle/ner.rb +345 -0
- data/lib/candle/reranker.rb +1 -1
- data/lib/candle/tensor.rb +2 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/version.rb +4 -2
- data/lib/candle.rb +2 -0
- metadata +127 -3
- data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/ruby/embedding_model.rs
CHANGED
@@ -3,7 +3,8 @@
 use crate::ruby::{
     errors::{wrap_candle_err, wrap_hf_err, wrap_std_err},
 };
-use crate::ruby::{Tensor, Device, Result
+use crate::ruby::{Tensor, Device, Result};
+use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 use candle_core::{DType as CoreDType, Device as CoreDevice, Module, Tensor as CoreTensor};
 use safetensors::tensor::SafeTensors;
 use candle_nn::VarBuilder;
@@ -14,7 +15,6 @@ use candle_transformers::models::{
 };
 use magnus::{class, function, method, prelude::*, Error, RModule};
 use std::path::Path;
-use tokenizers::Tokenizer;
 use serde_json;
 
 
@@ -70,12 +70,12 @@ pub struct EmbeddingModelInner {
     model_path: Option<String>,
     embedding_model_type: Option<EmbeddingModelType>,
     model: Option<EmbeddingModelVariant>,
-    tokenizer: Option<
+    tokenizer: Option<TokenizerWrapper>,
     embedding_size: Option<usize>,
 }
 
 impl EmbeddingModel {
-    pub fn new(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>, embedding_model_type: Option<String>, embedding_size: Option<usize>) ->
+    pub fn new(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>, embedding_model_type: Option<String>, embedding_size: Option<usize>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
         let embedding_model_type = embedding_model_type
             .and_then(|mt| EmbeddingModelType::from_string(&mt))
@@ -102,7 +102,7 @@ impl EmbeddingModel {
     /// Generates an embedding vector for the input text using the specified pooling method.
     /// &RETURNS&: Tensor
     /// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
-    pub fn embedding(&self, input: String, pooling_method: String) ->
+    pub fn embedding(&self, input: String, pooling_method: String) -> Result<Tensor> {
         match &self.0.model {
             Some(model) => {
                 match &self.0.tokenizer {
@@ -116,7 +116,7 @@ impl EmbeddingModel {
 
     /// Returns the unpooled embedding tensor ([1, SEQLENGTH, DIM]) for the input text
     /// &RETURNS&: Tensor
-    pub fn embeddings(&self, input: String) ->
+    pub fn embeddings(&self, input: String) -> Result<Tensor> {
         match &self.0.model {
             Some(model) => {
                 match &self.0.tokenizer {
@@ -130,27 +130,27 @@ impl EmbeddingModel {
 
     /// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
     /// &RETURNS&: Tensor
-    pub fn pool_embedding(&self, tensor: &Tensor) ->
+    pub fn pool_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }
 
     /// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
     /// &RETURNS&: Tensor
-    pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) ->
+    pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_normalized_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }
 
     /// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
     /// &RETURNS&: Tensor
-    pub fn pool_cls_embedding(&self, tensor: &Tensor) ->
+    pub fn pool_cls_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_cls_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }
 
     /// Infers and validates the embedding size from a safetensors file
-    fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> Result<usize, magnus::Error> {
+    fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> std::result::Result<usize, magnus::Error> {
         match embedding_size {
             Some(user_dim) => {
                 Ok(user_dim)
@@ -170,7 +170,7 @@ impl EmbeddingModel {
         }
     }
 
-    fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) ->
+    fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) -> Result<EmbeddingModelVariant> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let api = Api::new().map_err(wrap_hf_err)?;
         let repo = Repo::new(model_path.to_str().unwrap().to_string(), RepoType::Model);
@@ -257,7 +257,7 @@ impl EmbeddingModel {
         }
     }
 
-    fn build_tokenizer(tokenizer_path: String) ->
+    fn build_tokenizer(tokenizer_path: String) -> Result<TokenizerWrapper> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let tokenizer_path = Api::new()
             .map_err(wrap_hf_err)?
@@ -267,20 +267,16 @@ impl EmbeddingModel {
             ))
             .get("tokenizer.json")
             .map_err(wrap_hf_err)?;
-        let
+        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
             .map_err(wrap_std_err)?;
-
-
-
-        };
-        tokenizer.with_padding(Some(pp));
-
-        Ok(tokenizer)
+
+        let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
+        Ok(TokenizerWrapper::new(tokenizer))
     }
 
     /// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
     /// &RETURNS&: Tensor
-    fn pooled_cls_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_cls_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         // 1) sanity-check that we have a 3D tensor
         let (_batch, _seq_len, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
 
@@ -298,14 +294,14 @@ impl EmbeddingModel {
         Ok(cls)
     }
 
-    fn pooled_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         let (_n_sentence, n_tokens, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
         let sum = result.sum(1).map_err(wrap_candle_err)?;
         let mean = (sum / (n_tokens as f64)).map_err(wrap_candle_err)?;
         Ok(mean)
     }
 
-    fn pooled_normalized_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_normalized_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         let mean = Self::pooled_embedding(result)?;
         let norm = Self::normalize_l2(&mean).map_err(wrap_candle_err)?;
         Ok(norm)
@@ -315,13 +311,11 @@ impl EmbeddingModel {
         &self,
         prompt: String,
         model: &EmbeddingModelVariant,
-        tokenizer: &
-    ) -> Result<CoreTensor, Error> {
+        tokenizer: &TokenizerWrapper,
+    ) -> std::result::Result<CoreTensor, Error> {
         let tokens = tokenizer
-            .encode(prompt, true)
-            .map_err(
-            .get_ids()
-            .to_vec();
+            .encode(&prompt, true)
+            .map_err(wrap_candle_err)?;
         let token_ids = CoreTensor::new(&tokens[..], &self.0.device)
             .map_err(wrap_candle_err)?
             .unsqueeze(0)
@@ -355,9 +349,9 @@ impl EmbeddingModel {
         &self,
         prompt: String,
         model: &EmbeddingModelVariant,
-        tokenizer: &
+        tokenizer: &TokenizerWrapper,
         pooling_method: &str,
-    ) -> Result<CoreTensor, Error> {
+    ) -> std::result::Result<CoreTensor, Error> {
         let result = self.compute_embeddings(prompt, model, tokenizer)?;
         match pooling_method {
             "pooled" => Self::pooled_embedding(&result),
@@ -367,8 +361,7 @@ impl EmbeddingModel {
         }
     }
 
-
-    fn normalize_l2(v: &CoreTensor) -> Result<CoreTensor, candle_core::Error> {
+    fn normalize_l2(v: &CoreTensor) -> candle_core::Result<CoreTensor> {
         v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
     }
 
@@ -392,9 +385,17 @@ impl EmbeddingModel {
     pub fn __str__(&self) -> String {
         self.__repr__()
     }
+
+    /// Get the tokenizer used by this model
+    pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
+        match &self.0.tokenizer {
+            Some(tokenizer) => Ok(crate::ruby::tokenizer::Tokenizer(tokenizer.clone())),
+            None => Err(magnus::Error::new(magnus::exception::runtime_error(), "No tokenizer loaded for this model"))
+        }
+    }
 }
 
-pub fn init(rb_candle: RModule) -> Result<()
+pub fn init(rb_candle: RModule) -> Result<()> {
     let rb_embedding_model = rb_candle.define_class("EmbeddingModel", class::object())?;
     rb_embedding_model.define_singleton_method("_create", function!(EmbeddingModel::new, 5))?;
     // Expose embedding with an optional pooling_method argument (default: "pooled")
@@ -406,5 +407,6 @@ pub fn init(rb_candle: RModule) -> Result<(), Error> {
     rb_embedding_model.define_method("embedding_model_type", method!(EmbeddingModel::embedding_model_type, 0))?;
     rb_embedding_model.define_method("to_s", method!(EmbeddingModel::__str__, 0))?;
     rb_embedding_model.define_method("inspect", method!(EmbeddingModel::__repr__, 0))?;
+    rb_embedding_model.define_method("tokenizer", method!(EmbeddingModel::tokenizer, 0))?;
     Ok(())
 }
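The pooling helpers changed above all reduce a `[1, SEQ_LEN, DIM]` sequence tensor to `[1, DIM]`: `pooled` is a token-wise mean, `pooled_normalized` additionally L2-normalizes the result, and `cls` keeps only the first token. A minimal, self-contained sketch of the same arithmetic against `candle_core` directly (function and variable names here are illustrative, not part of the gem's API):

```rust
use candle_core::{Device, Result, Tensor};

// Mean-pool a [batch, seq_len, dim] tensor to [batch, dim], then L2-normalize,
// mirroring pooled_embedding + normalize_l2 in the diff above.
fn pool_and_normalize(hidden: &Tensor) -> Result<Tensor> {
    let (_batch, seq_len, _dim) = hidden.dims3()?;
    let mean = (hidden.sum(1)? / (seq_len as f64))?; // mean over the token axis
    let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;  // per-row L2 norm
    mean.broadcast_div(&norm)                        // unit-length rows
}

fn main() -> Result<()> {
    let hidden = Tensor::rand(0f32, 1f32, (1, 4, 8), &Device::Cpu)?;
    let pooled = pool_and_normalize(&hidden)?;
    println!("{:?}", pooled.dims()); // [1, 8]
    Ok(())
}
```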
data/ext/candle/src/ruby/llm.rs
CHANGED
@@ -1,8 +1,8 @@
 use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
 use std::cell::RefCell;
 
-use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma};
-use crate::ruby::{Result
+use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
+use crate::ruby::{Result, Device};
 
 // Use an enum to handle different model types instead of trait objects
 #[derive(Debug)]
@@ -10,6 +10,7 @@ enum ModelType {
     Mistral(RustMistral),
     Llama(RustLlama),
     Gemma(RustGemma),
+    QuantizedGGUF(RustQuantizedGGUF),
 }
 
 impl ModelType {
@@ -18,6 +19,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.generate(prompt, config),
             ModelType::Llama(m) => m.generate(prompt, config),
             ModelType::Gemma(m) => m.generate(prompt, config),
+            ModelType::QuantizedGGUF(m) => m.generate(prompt, config),
         }
     }
 
@@ -31,15 +33,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
             ModelType::Llama(m) => m.generate_stream(prompt, config, callback),
             ModelType::Gemma(m) => m.generate_stream(prompt, config, callback),
-
-        }
-
-    #[allow(dead_code)]
-    fn model_name(&self) -> &str {
-        match self {
-            ModelType::Mistral(m) => m.model_name(),
-            ModelType::Llama(m) => m.model_name(),
-            ModelType::Gemma(m) => m.model_name(),
+            ModelType::QuantizedGGUF(m) => m.generate_stream(prompt, config, callback),
         }
     }
 
@@ -48,6 +42,7 @@ impl ModelType {
             ModelType::Mistral(m) => m.clear_cache(),
             ModelType::Llama(m) => m.clear_cache(),
             ModelType::Gemma(m) => m.clear_cache(),
+            ModelType::QuantizedGGUF(m) => m.clear_cache(),
         }
     }
 
@@ -72,6 +67,7 @@ impl ModelType {
             },
             ModelType::Llama(m) => m.apply_chat_template(messages),
             ModelType::Gemma(m) => m.apply_chat_template(messages),
+            ModelType::QuantizedGGUF(m) => m.apply_chat_template(messages),
         }
     }
 }
@@ -83,7 +79,7 @@ pub struct GenerationConfig {
 }
 
 impl GenerationConfig {
-    pub fn new(kwargs: RHash) ->
+    pub fn new(kwargs: RHash) -> Result<Self> {
         let mut config = RustGenerationConfig::default();
 
         // Extract values from kwargs manually
@@ -144,6 +140,12 @@
             }
         }
 
+        if let Some(value) = kwargs.get(magnus::Symbol::new("debug_tokens")) {
+            if let Ok(v) = TryConvert::try_convert(value) {
+                config.debug_tokens = v;
+            }
+        }
+
         Ok(Self { inner: config })
     }
 
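The new `debug_tokens` option is read with the same tolerant kwarg pattern as the existing options: fetch the key as a symbol, attempt a conversion, and keep the default on a miss or type mismatch. A sketch of that pattern in isolation (assuming the `magnus` crate; `Config` here is a stand-in for the real `RustGenerationConfig`):

```rust
use magnus::{RHash, Symbol, TryConvert};

// Stand-in for RustGenerationConfig; only the field relevant here.
#[derive(Default)]
struct Config {
    debug_tokens: bool, // defaults to false
}

// A missing key or a failed conversion leaves the default untouched,
// exactly like the kwarg blocks in GenerationConfig::new above.
fn read_debug_tokens(kwargs: RHash, config: &mut Config) {
    if let Some(value) = kwargs.get(Symbol::new("debug_tokens")) {
        if let Ok(v) = bool::try_convert(value) {
            config.debug_tokens = v;
        }
    }
}
```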
@@ -185,6 +187,10 @@
     pub fn include_prompt(&self) -> bool {
         self.inner.include_prompt
     }
+
+    pub fn debug_tokens(&self) -> bool {
+        self.inner.debug_tokens
+    }
 }
 
 #[derive(Clone, Debug)]
@@ -192,13 +198,13 @@ impl GenerationConfig {
 pub struct LLM {
     model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
     model_id: String,
-    device:
+    device: Device,
 }
 
 impl LLM {
     /// Create a new LLM from a pretrained model
-    pub fn from_pretrained(model_id: String, device: Option<
-        let device = device.unwrap_or(
+    pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
+        let device = device.unwrap_or(Device::Cpu);
         let candle_device = device.as_device()?;
 
         // For now, we'll use tokio runtime directly
@@ -206,31 +212,51 @@
         let rt = tokio::runtime::Runtime::new()
             .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;
 
-        // Determine model type from ID and
+        // Determine model type from ID and whether it's quantized
        let model_lower = model_id.to_lowercase();
-        let
-
-
-
-
-
-
-
-
-        }
-
-
-
-
-            RustGemma::from_pretrained(&model_id, candle_device).await
+        let is_quantized = model_lower.contains("gguf") || model_lower.contains("-q4") || model_lower.contains("-q5") || model_lower.contains("-q8");
+
+        let model = if is_quantized {
+            // Extract tokenizer source if provided in model_id
+            let (model_id_clean, tokenizer_source) = if let Some(pos) = model_id.find("@@") {
+                let (id, _tok) = model_id.split_at(pos);
+                (id.to_string(), Some(&model_id[pos+2..]))
+            } else {
+                (model_id.clone(), None)
+            };
+
+            // Use unified GGUF loader for all quantized models
+            let gguf_model = rt.block_on(async {
+                RustQuantizedGGUF::from_pretrained(&model_id_clean, candle_device, tokenizer_source).await
             })
-            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
-        ModelType::
+            .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load GGUF model: {}", e)))?;
+            ModelType::QuantizedGGUF(gguf_model)
         } else {
-
-
-
-
+            // Load non-quantized models
+            if model_lower.contains("mistral") {
+                let mistral = rt.block_on(async {
+                    RustMistral::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Mistral(mistral)
+            } else if model_lower.contains("llama") || model_lower.contains("meta-llama") || model_lower.contains("tinyllama") {
+                let llama = rt.block_on(async {
+                    RustLlama::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Llama(llama)
+            } else if model_lower.contains("gemma") || model_lower.contains("google/gemma") {
+                let gemma = rt.block_on(async {
+                    RustGemma::from_pretrained(&model_id, candle_device).await
+                })
+                .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
+                ModelType::Gemma(gemma)
+            } else {
+                return Err(Error::new(
+                    magnus::exception::runtime_error(),
+                    format!("Unsupported model type: {}. Currently Mistral, Llama, and Gemma models are supported.", model_id),
+                ));
+            }
         };
 
         Ok(Self {
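One detail worth calling out in the hunk above: for quantized models the id may carry an explicit tokenizer source after an `@@` separator, which `from_pretrained` strips off before loading. A dependency-free sketch of that parsing rule (the ids below are made up for illustration):

```rust
// Split "repo@@tokenizer" into the repo id and an optional tokenizer source,
// mirroring the model_id.find("@@") logic in the hunk above.
fn split_tokenizer_source(model_id: &str) -> (String, Option<&str>) {
    match model_id.find("@@") {
        Some(pos) => (model_id[..pos].to_string(), Some(&model_id[pos + 2..])),
        None => (model_id.to_string(), None),
    }
}

fn main() {
    // Hypothetical ids, purely for illustration.
    let (repo, tok) = split_tokenizer_source("TheBloke/some-model-GGUF@@some/tokenizer");
    assert_eq!(repo, "TheBloke/some-model-GGUF");
    assert_eq!(tok, Some("some/tokenizer"));
    assert_eq!(split_tokenizer_source("plain-model").1, None);
}
```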
@@ -241,12 +267,15 @@
     }
 
     /// Generate text from a prompt
-    pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) ->
+    pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
         let config = config
             .map(|c| c.inner.clone())
             .unwrap_or_default();
 
-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();
 
         model_ref.generate(&prompt, &config)
@@ -254,7 +283,7 @@
     }
 
     /// Generate text with streaming output
-    pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) ->
+    pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
         let config = config
             .map(|c| c.inner.clone())
             .unwrap_or_default();
@@ -266,7 +295,10 @@
         }
         let block = block.unwrap();
 
-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let mut model_ref = model.borrow_mut();
 
         let result = model_ref.generate_stream(&prompt, &config, |token| {
@@ -283,20 +315,44 @@
     }
 
     /// Get the device the model is running on
-    pub fn device(&self) ->
+    pub fn device(&self) -> Device {
         self.device
     }
+
+    /// Get the tokenizer used by this model
+    pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
+        let model_ref = model.borrow();
+
+        // Clone the tokenizer from the model
+        match &*model_ref {
+            ModelType::Mistral(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::Llama(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::Gemma(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+        }
+    }
 
     /// Clear the model's cache (e.g., KV cache for transformers)
-    pub fn clear_cache(&self) ->
-        let model = self.model.lock()
+    pub fn clear_cache(&self) -> Result<()> {
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => {
+                // If the mutex is poisoned, we can still recover the data
+                // This happens when another thread panicked while holding the lock
+                poisoned.into_inner()
+            }
+        };
         let mut model_ref = model.borrow_mut();
         model_ref.clear_cache();
         Ok(())
     }
 
     /// Apply chat template to messages
-    pub fn apply_chat_template(&self, messages: RArray) ->
+    pub fn apply_chat_template(&self, messages: RArray) -> Result<String> {
         // Convert Ruby array to JSON values
         let json_messages: Vec<serde_json::Value> = messages
             .into_iter()
@@ -323,7 +379,10 @@
         })
         .collect();
 
-        let model = self.model.lock()
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
         let model_ref = model.borrow();
 
         model_ref.apply_chat_template(&json_messages)
@@ -332,7 +391,7 @@
 }
 
 // Define a standalone function for from_pretrained that handles variable arguments
-fn from_pretrained_wrapper(args: &[Value]) ->
+fn from_pretrained_wrapper(args: &[Value]) -> Result<LLM> {
     match args.len() {
         1 => {
             let model_id: String = TryConvert::try_convert(args[0])?;
@@ -340,7 +399,7 @@ fn from_pretrained_wrapper(args: &[Value]) -> RbResult<LLM> {
         },
         2 => {
             let model_id: String = TryConvert::try_convert(args[0])?;
-            let device:
+            let device: Device = TryConvert::try_convert(args[1])?;
             LLM::from_pretrained(model_id, Some(device))
         },
         _ => Err(Error::new(
@@ -350,7 +409,7 @@
     }
 }
 
-pub fn init_llm(rb_candle: RModule) ->
+pub fn init_llm(rb_candle: RModule) -> Result<()> {
     let rb_generation_config = rb_candle.define_class("GenerationConfig", magnus::class::object())?;
     rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
     rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
@@ -363,6 +422,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
     rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
     rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
+    rb_generation_config.define_method("debug_tokens", method!(GenerationConfig::debug_tokens, 0))?;
 
     let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
     rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
@@ -370,6 +430,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;
     rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
     rb_llm.define_method("device", method!(LLM::device, 0))?;
+    rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
     rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
    rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
 
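A pattern that recurs through this file: every `self.model.lock()` now recovers from mutex poisoning via `into_inner()` rather than propagating a panic from another thread. A self-contained sketch of the behavior being handled (standard library only):

```rust
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let data = Arc::new(Mutex::new(0u32));

    // A thread panics while holding the lock, poisoning the mutex.
    let poisoner = Arc::clone(&data);
    let _ = thread::spawn(move || {
        let _guard = poisoner.lock().unwrap();
        panic!("boom");
    })
    .join();

    // lock() now returns Err(PoisonError), but the data is still intact;
    // into_inner() recovers the guard, as the llm.rs methods above do.
    let guard = match data.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    assert_eq!(*guard, 0);
}
```

This matches the comment kept in `clear_cache` above: the data behind a poisoned lock is still intact, so the wrapper opts to keep serving requests instead of panicking.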
data/ext/candle/src/ruby/mod.rs
CHANGED
@@ -2,17 +2,16 @@ pub mod embedding_model;
 pub mod tensor;
 pub mod device;
 pub mod dtype;
-pub mod qtensor;
 pub mod result;
 pub mod errors;
 pub mod utils;
 pub mod llm;
+pub mod tokenizer;
 
 pub use embedding_model::{EmbeddingModel, EmbeddingModelInner};
 pub use tensor::Tensor;
 pub use device::Device;
 pub use dtype::DType;
-pub use qtensor::QTensor;
 pub use result::Result;
 
 // Re-export for convenience