red-candle 1.0.0.pre.7 → 1.0.1
This diff shows the changes between publicly released versions of the package as they appear in their respective registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/Gemfile +1 -10
- data/README.md +399 -18
- data/ext/candle/src/lib.rs +6 -3
- data/ext/candle/src/llm/gemma.rs +5 -0
- data/ext/candle/src/llm/llama.rs +5 -0
- data/ext/candle/src/llm/mistral.rs +5 -0
- data/ext/candle/src/llm/mod.rs +1 -89
- data/ext/candle/src/llm/quantized_gguf.rs +5 -0
- data/ext/candle/src/ner.rs +423 -0
- data/ext/candle/src/reranker.rs +24 -21
- data/ext/candle/src/ruby/device.rs +6 -6
- data/ext/candle/src/ruby/dtype.rs +4 -4
- data/ext/candle/src/ruby/embedding_model.rs +36 -33
- data/ext/candle/src/ruby/llm.rs +31 -13
- data/ext/candle/src/ruby/mod.rs +1 -2
- data/ext/candle/src/ruby/tensor.rs +66 -66
- data/ext/candle/src/ruby/tokenizer.rs +269 -0
- data/ext/candle/src/ruby/utils.rs +6 -24
- data/ext/candle/src/tokenizer/loader.rs +108 -0
- data/ext/candle/src/tokenizer/mod.rs +103 -0
- data/ext/candle/target/release/build/bindgen-0f89ba23b9ca1395/out/host-target.txt +1 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/common.rs +355 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/dynamic.rs +276 -0
- data/ext/candle/target/release/build/clang-sys-cac31d63c4694603/out/macros.rs +49 -0
- data/ext/candle/target/release/build/pulp-1b95cfe377eede97/out/x86_64_asm.rs +2748 -0
- data/ext/candle/target/release/build/rb-sys-f8ac4edc30ab3e53/out/bindings-0.9.116-mri-arm64-darwin24-3.3.0.rs +8902 -0
- data/lib/candle/build_info.rb +2 -0
- data/lib/candle/device_utils.rb +2 -0
- data/lib/candle/ner.rb +345 -0
- data/lib/candle/reranker.rb +1 -1
- data/lib/candle/tensor.rb +2 -0
- data/lib/candle/tokenizer.rb +139 -0
- data/lib/candle/version.rb +4 -2
- data/lib/candle.rb +2 -0
- metadata +128 -5
- data/ext/candle/src/ruby/qtensor.rs +0 -69
data/ext/candle/src/ruby/device.rs
CHANGED
```diff
@@ -2,7 +2,7 @@ use magnus::Error;
 use magnus::{function, method, class, RModule, Module, Object};
 
 use ::candle_core::Device as CoreDevice;
-use crate::ruby::Result
+use crate::ruby::Result;
 
 #[cfg(any(feature = "cuda", feature = "metal"))]
 use crate::ruby::errors::wrap_candle_err;
@@ -68,7 +68,7 @@ impl Device {
     }
 
     /// Create a CUDA device (GPU)
-    pub fn cuda() ->
+    pub fn cuda() -> Result<Self> {
         #[cfg(not(feature = "cuda"))]
         {
             return Err(Error::new(
@@ -82,7 +82,7 @@
     }
 
     /// Create a Metal device (Apple GPU)
-    pub fn metal() ->
+    pub fn metal() -> Result<Self> {
         #[cfg(not(feature = "metal"))]
         {
             return Err(Error::new(
@@ -103,7 +103,7 @@
         }
     }
 
-    pub fn as_device(&self) ->
+    pub fn as_device(&self) -> Result<CoreDevice> {
         match self {
             Self::Cpu => Ok(CoreDevice::Cpu),
             Self::Cuda => {
@@ -165,7 +165,7 @@
 }
 
 impl magnus::TryConvert for Device {
-    fn try_convert(val: magnus::Value) ->
+    fn try_convert(val: magnus::Value) -> Result<Self> {
         // First check if it's already a wrapped Device object
         if let Ok(device) = <magnus::typed_data::Obj<Device> as magnus::TryConvert>::try_convert(val) {
             return Ok(*device);
@@ -184,7 +184,7 @@ impl magnus::TryConvert for Device {
     }
 }
 
-pub fn init(rb_candle: RModule) -> Result<()
+pub fn init(rb_candle: RModule) -> Result<()> {
     let rb_device = rb_candle.define_class("Device", class::object())?;
     rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
     rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
```
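This pattern repeats throughout the release: return types converge on the `Result` alias imported from `crate::ruby`. A minimal sketch of what that alias presumably looks like (the actual definition lives in `data/ext/candle/src/ruby/result.rs`, which this diff does not show):

```rust
// Assumed shape of the crate-local alias; the real definition in
// ruby/result.rs may differ in detail.
pub type Result<T> = std::result::Result<T, magnus::Error>;
```

Because the alias shadows `std::result::Result` wherever it is imported, the helpers later in this diff that pair a different success type with `magnus::Error`, or that return candle errors, spell out `std::result::Result<...>` or `candle_core::Result<...>` in full.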
data/ext/candle/src/ruby/dtype.rs
CHANGED
```diff
@@ -1,8 +1,8 @@
 use magnus::value::ReprValue;
-use magnus::{method, class, RModule,
+use magnus::{method, class, RModule, Module};
 
 use ::candle_core::DType as CoreDType;
-use crate::ruby::Result
+use crate::ruby::Result;
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
 #[magnus::wrap(class = "Candle::DType", free_immediately, size)]
@@ -21,7 +21,7 @@ impl DType {
 }
 
 impl DType {
-    pub fn from_rbobject(dtype: magnus::Symbol) ->
+    pub fn from_rbobject(dtype: magnus::Symbol) -> Result<Self> {
         let dtype = unsafe { dtype.to_s() }.unwrap().into_owned();
         use std::str::FromStr;
         let dtype = CoreDType::from_str(&dtype).unwrap();
@@ -29,7 +29,7 @@
     }
 }
 
-pub fn init(rb_candle: RModule) -> Result<()
+pub fn init(rb_candle: RModule) -> Result<()> {
     let rb_dtype = rb_candle.define_class("DType", class::object())?;
     rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
     rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
```
data/ext/candle/src/ruby/embedding_model.rs
CHANGED
```diff
@@ -3,7 +3,8 @@
 use crate::ruby::{
     errors::{wrap_candle_err, wrap_hf_err, wrap_std_err},
 };
-use crate::ruby::{Tensor, Device, Result
+use crate::ruby::{Tensor, Device, Result};
+use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 use candle_core::{DType as CoreDType, Device as CoreDevice, Module, Tensor as CoreTensor};
 use safetensors::tensor::SafeTensors;
 use candle_nn::VarBuilder;
@@ -14,7 +15,6 @@ use candle_transformers::models::{
 };
 use magnus::{class, function, method, prelude::*, Error, RModule};
 use std::path::Path;
-use tokenizers::Tokenizer;
 use serde_json;
 
 
@@ -70,12 +70,12 @@ pub struct EmbeddingModelInner {
     model_path: Option<String>,
     embedding_model_type: Option<EmbeddingModelType>,
     model: Option<EmbeddingModelVariant>,
-    tokenizer: Option<
+    tokenizer: Option<TokenizerWrapper>,
     embedding_size: Option<usize>,
 }
 
 impl EmbeddingModel {
-    pub fn new(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>, embedding_model_type: Option<String>, embedding_size: Option<usize>) ->
+    pub fn new(model_path: Option<String>, tokenizer_path: Option<String>, device: Option<Device>, embedding_model_type: Option<String>, embedding_size: Option<usize>) -> Result<Self> {
         let device = device.unwrap_or(Device::Cpu).as_device()?;
         let embedding_model_type = embedding_model_type
             .and_then(|mt| EmbeddingModelType::from_string(&mt))
@@ -102,7 +102,7 @@
     /// Generates an embedding vector for the input text using the specified pooling method.
     /// &RETURNS&: Tensor
     /// pooling_method: "pooled", "pooled_normalized", or "cls" (default: "pooled")
-    pub fn embedding(&self, input: String, pooling_method: String) ->
+    pub fn embedding(&self, input: String, pooling_method: String) -> Result<Tensor> {
         match &self.0.model {
             Some(model) => {
                 match &self.0.tokenizer {
@@ -116,7 +116,7 @@
 
     /// Returns the unpooled embedding tensor ([1, SEQLENGTH, DIM]) for the input text
     /// &RETURNS&: Tensor
-    pub fn embeddings(&self, input: String) ->
+    pub fn embeddings(&self, input: String) -> Result<Tensor> {
         match &self.0.model {
             Some(model) => {
                 match &self.0.tokenizer {
@@ -130,27 +130,27 @@
 
     /// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
     /// &RETURNS&: Tensor
-    pub fn pool_embedding(&self, tensor: &Tensor) ->
+    pub fn pool_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }
 
     /// Pools and normalizes a sequence embedding tensor ([1, SEQLENGTH, DIM]) to [1, DIM]
     /// &RETURNS&: Tensor
-    pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) ->
+    pub fn pool_and_normalize_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_normalized_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }
 
     /// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
     /// &RETURNS&: Tensor
-    pub fn pool_cls_embedding(&self, tensor: &Tensor) ->
+    pub fn pool_cls_embedding(&self, tensor: &Tensor) -> Result<Tensor> {
         let pooled = Self::pooled_cls_embedding(&tensor.0)?;
         Ok(Tensor(pooled))
     }
 
     /// Infers and validates the embedding size from a safetensors file
-    fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> Result<usize, magnus::Error> {
+    fn resolve_embedding_size(model_path: &Path, embedding_size: Option<usize>) -> std::result::Result<usize, magnus::Error> {
         match embedding_size {
             Some(user_dim) => {
                 Ok(user_dim)
@@ -170,7 +170,7 @@
         }
     }
 
-    fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) ->
+    fn build_embedding_model(model_path: &Path, device: CoreDevice, embedding_model_type: EmbeddingModelType, embedding_size: Option<usize>) -> Result<EmbeddingModelVariant> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let api = Api::new().map_err(wrap_hf_err)?;
         let repo = Repo::new(model_path.to_str().unwrap().to_string(), RepoType::Model);
@@ -257,7 +257,7 @@
         }
     }
 
-    fn build_tokenizer(tokenizer_path: String) ->
+    fn build_tokenizer(tokenizer_path: String) -> Result<TokenizerWrapper> {
         use hf_hub::{api::sync::Api, Repo, RepoType};
         let tokenizer_path = Api::new()
             .map_err(wrap_hf_err)?
@@ -267,20 +267,16 @@
             ))
             .get("tokenizer.json")
             .map_err(wrap_hf_err)?;
-        let
+        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
             .map_err(wrap_std_err)?;
-
-
-
-        };
-        tokenizer.with_padding(Some(pp));
-
-        Ok(tokenizer)
+
+        let tokenizer = TokenizerLoader::with_padding(tokenizer, None);
+        Ok(TokenizerWrapper::new(tokenizer))
     }
 
     /// Pools the embedding tensor by extracting the CLS token ([1, SEQLENGTH, DIM] -> [1, DIM])
     /// &RETURNS&: Tensor
-    fn pooled_cls_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_cls_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         // 1) sanity-check that we have a 3D tensor
         let (_batch, _seq_len, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
 
```
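`build_tokenizer` no longer configures padding inline; it delegates to the `TokenizerLoader` added in `data/ext/candle/src/tokenizer/loader.rs` and returns the new `TokenizerWrapper` instead of a raw `tokenizers::Tokenizer`. The loader itself is not shown in this diff; a hedged sketch of its likely shape, inferred only from the call site `TokenizerLoader::with_padding(tokenizer, None)`:

```rust
use tokenizers::{PaddingParams, Tokenizer};

pub struct TokenizerLoader;

impl TokenizerLoader {
    /// Hypothetical helper matching the call site above: attach the given
    /// padding params to the tokenizer, falling back to the tokenizers
    /// crate's defaults when none are supplied.
    pub fn with_padding(mut tokenizer: Tokenizer, params: Option<PaddingParams>) -> Tokenizer {
        tokenizer.with_padding(Some(params.unwrap_or_default()));
        tokenizer
    }
}
```

The remaining embedding_model.rs hunks continue below.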
```diff
@@ -298,14 +294,14 @@
         Ok(cls)
     }
 
-    fn pooled_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         let (_n_sentence, n_tokens, _hidden_size) = result.dims3().map_err(wrap_candle_err)?;
         let sum = result.sum(1).map_err(wrap_candle_err)?;
         let mean = (sum / (n_tokens as f64)).map_err(wrap_candle_err)?;
         Ok(mean)
     }
 
-    fn pooled_normalized_embedding(result: &CoreTensor) -> Result<CoreTensor, Error> {
+    fn pooled_normalized_embedding(result: &CoreTensor) -> std::result::Result<CoreTensor, Error> {
         let mean = Self::pooled_embedding(result)?;
         let norm = Self::normalize_l2(&mean).map_err(wrap_candle_err)?;
         Ok(norm)
```
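For reference, the three pooling strategies reduce a `[1, T, D]` sequence tensor to `[1, D]`. Writing $h_{t,j}$ for hidden dimension $j$ of token $t$, the helpers `pooled_embedding`, `pooled_normalized_embedding` (via `normalize_l2`, shown further down), and `pooled_cls_embedding` compute:

$$
\text{pooled}_j = \frac{1}{T}\sum_{t=1}^{T} h_{t,j},
\qquad
\text{normalized} = \frac{\text{pooled}}{\lVert \text{pooled} \rVert_2},
\qquad
\text{cls}_j = h_{1,j}
$$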
```diff
@@ -315,13 +311,11 @@
         &self,
         prompt: String,
         model: &EmbeddingModelVariant,
-        tokenizer: &
-    ) -> Result<CoreTensor, Error> {
+        tokenizer: &TokenizerWrapper,
+    ) -> std::result::Result<CoreTensor, Error> {
         let tokens = tokenizer
-            .encode(prompt, true)
-            .map_err(
-            .get_ids()
-            .to_vec();
+            .encode(&prompt, true)
+            .map_err(wrap_candle_err)?;
         let token_ids = CoreTensor::new(&tokens[..], &self.0.device)
             .map_err(wrap_candle_err)?
             .unsqueeze(0)
```
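The old call chain unpacked a `tokenizers::Encoding` at the call site (`.get_ids().to_vec()`); the wrapper now hands back the ids directly, with the error already converted to a candle error (hence the single `.map_err(wrap_candle_err)?`). A sketch consistent with that call site; the real method lives in `data/ext/candle/src/tokenizer/mod.rs`, and the field name and exact error mapping here are assumptions:

```rust
pub struct TokenizerWrapper {
    tokenizer: tokenizers::Tokenizer, // assumed field name
}

impl TokenizerWrapper {
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self { tokenizer }
    }

    /// Encode text straight to token ids, hiding the intermediate Encoding.
    pub fn encode(&self, text: &str, add_special_tokens: bool) -> candle_core::Result<Vec<u32>> {
        let encoding = self
            .tokenizer
            .encode(text, add_special_tokens)
            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
        Ok(encoding.get_ids().to_vec())
    }
}
```

The same wrapper type backs the `tokenizer` accessors added to both models below.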
```diff
@@ -355,9 +349,9 @@
         &self,
         prompt: String,
         model: &EmbeddingModelVariant,
-        tokenizer: &
+        tokenizer: &TokenizerWrapper,
         pooling_method: &str,
-    ) -> Result<CoreTensor, Error> {
+    ) -> std::result::Result<CoreTensor, Error> {
         let result = self.compute_embeddings(prompt, model, tokenizer)?;
         match pooling_method {
             "pooled" => Self::pooled_embedding(&result),
@@ -367,7 +361,7 @@
         }
     }
 
-    fn normalize_l2(v: &CoreTensor) -> Result<CoreTensor
+    fn normalize_l2(v: &CoreTensor) -> candle_core::Result<CoreTensor> {
         v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
     }
 
@@ -391,9 +385,17 @@
     pub fn __str__(&self) -> String {
         self.__repr__()
     }
+
+    /// Get the tokenizer used by this model
+    pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
+        match &self.0.tokenizer {
+            Some(tokenizer) => Ok(crate::ruby::tokenizer::Tokenizer(tokenizer.clone())),
+            None => Err(magnus::Error::new(magnus::exception::runtime_error(), "No tokenizer loaded for this model"))
+        }
+    }
 }
 
-pub fn init(rb_candle: RModule) -> Result<()
+pub fn init(rb_candle: RModule) -> Result<()> {
     let rb_embedding_model = rb_candle.define_class("EmbeddingModel", class::object())?;
     rb_embedding_model.define_singleton_method("_create", function!(EmbeddingModel::new, 5))?;
     // Expose embedding with an optional pooling_method argument (default: "pooled")
@@ -405,5 +407,6 @@ pub fn init(rb_candle: RModule) -> Result<(), Error> {
     rb_embedding_model.define_method("embedding_model_type", method!(EmbeddingModel::embedding_model_type, 0))?;
     rb_embedding_model.define_method("to_s", method!(EmbeddingModel::__str__, 0))?;
     rb_embedding_model.define_method("inspect", method!(EmbeddingModel::__repr__, 0))?;
+    rb_embedding_model.define_method("tokenizer", method!(EmbeddingModel::tokenizer, 0))?;
     Ok(())
 }
```
data/ext/candle/src/ruby/llm.rs
CHANGED
```diff
@@ -2,7 +2,7 @@ use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule
 use std::cell::RefCell;
 
 use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral, llama::Llama as RustLlama, gemma::Gemma as RustGemma, QuantizedGGUF as RustQuantizedGGUF};
-use crate::ruby::{Result
+use crate::ruby::{Result, Device};
 
 // Use an enum to handle different model types instead of trait objects
 #[derive(Debug)]
@@ -79,7 +79,7 @@ pub struct GenerationConfig {
 }
 
 impl GenerationConfig {
-    pub fn new(kwargs: RHash) ->
+    pub fn new(kwargs: RHash) -> Result<Self> {
         let mut config = RustGenerationConfig::default();
 
         // Extract values from kwargs manually
```
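`GenerationConfig::new` accepts a single Ruby hash and pulls the supported keys out by hand rather than taking positional arguments. A hedged sketch of one such extraction step using magnus's conversion traits; `max_length` is an illustrative key, and the real set of supported keys is not visible in this diff:

```rust
use magnus::{RHash, Symbol, TryConvert};

type Result<T> = std::result::Result<T, magnus::Error>;

/// Illustrative kwargs probe: an absent key leaves the Rust-side default
/// untouched; a present key must convert cleanly or an error is raised.
fn extract_usize(kwargs: RHash, key: &str) -> Result<Option<usize>> {
    match kwargs.get(Symbol::new(key)) {
        Some(value) => Ok(Some(usize::try_convert(value)?)),
        None => Ok(None),
    }
}
```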
```diff
@@ -198,13 +198,13 @@
 pub struct LLM {
     model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
     model_id: String,
-    device:
+    device: Device,
 }
 
 impl LLM {
     /// Create a new LLM from a pretrained model
-    pub fn from_pretrained(model_id: String, device: Option<
-        let device = device.unwrap_or(
+    pub fn from_pretrained(model_id: String, device: Option<Device>) -> Result<Self> {
+        let device = device.unwrap_or(Device::Cpu);
         let candle_device = device.as_device()?;
 
         // For now, we'll use tokio runtime directly
@@ -267,7 +267,7 @@
     }
 
     /// Generate text from a prompt
-    pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) ->
+    pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
         let config = config
             .map(|c| c.inner.clone())
             .unwrap_or_default();
@@ -283,7 +283,7 @@
     }
 
     /// Generate text with streaming output
-    pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) ->
+    pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) -> Result<String> {
         let config = config
             .map(|c| c.inner.clone())
             .unwrap_or_default();
@@ -315,12 +315,29 @@
     }
 
     /// Get the device the model is running on
-    pub fn device(&self) ->
+    pub fn device(&self) -> Device {
         self.device
     }
+
+    /// Get the tokenizer used by this model
+    pub fn tokenizer(&self) -> Result<crate::ruby::tokenizer::Tokenizer> {
+        let model = match self.model.lock() {
+            Ok(guard) => guard,
+            Err(poisoned) => poisoned.into_inner(),
+        };
+        let model_ref = model.borrow();
+
+        // Clone the tokenizer from the model
+        match &*model_ref {
+            ModelType::Mistral(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::Llama(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::Gemma(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+            ModelType::QuantizedGGUF(m) => Ok(crate::ruby::tokenizer::Tokenizer(m.tokenizer().clone())),
+        }
+    }
 
     /// Clear the model's cache (e.g., KV cache for transformers)
-    pub fn clear_cache(&self) ->
+    pub fn clear_cache(&self) -> Result<()> {
         let model = match self.model.lock() {
             Ok(guard) => guard,
             Err(poisoned) => {
```
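Both the new `tokenizer` accessor and `clear_cache` recover from a poisoned lock (one whose holder panicked) instead of propagating the panic to Ruby. The pattern in isolation, as a standalone sketch rather than red-candle code:

```rust
use std::sync::{Mutex, MutexGuard};

/// A poisoned Mutex still contains valid data; `PoisonError::into_inner`
/// hands back the guard so the caller can proceed with the value anyway.
fn lock_or_recover<T>(m: &Mutex<T>) -> MutexGuard<'_, T> {
    match m.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
```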
```diff
@@ -335,7 +352,7 @@
     }
 
     /// Apply chat template to messages
-    pub fn apply_chat_template(&self, messages: RArray) ->
+    pub fn apply_chat_template(&self, messages: RArray) -> Result<String> {
         // Convert Ruby array to JSON values
         let json_messages: Vec<serde_json::Value> = messages
             .into_iter()
@@ -374,7 +391,7 @@
 }
 
 // Define a standalone function for from_pretrained that handles variable arguments
-fn from_pretrained_wrapper(args: &[Value]) ->
+fn from_pretrained_wrapper(args: &[Value]) -> Result<LLM> {
     match args.len() {
         1 => {
             let model_id: String = TryConvert::try_convert(args[0])?;
@@ -382,7 +399,7 @@ fn from_pretrained_wrapper(args: &[Value]) -> RbResult<LLM> {
         },
         2 => {
             let model_id: String = TryConvert::try_convert(args[0])?;
-            let device:
+            let device: Device = TryConvert::try_convert(args[1])?;
             LLM::from_pretrained(model_id, Some(device))
         },
         _ => Err(Error::new(
@@ -392,7 +409,7 @@ fn from_pretrained_wrapper(args: &[Value]) -> RbResult<LLM> {
     }
 }
 
-pub fn init_llm(rb_candle: RModule) ->
+pub fn init_llm(rb_candle: RModule) -> Result<()> {
     let rb_generation_config = rb_candle.define_class("GenerationConfig", magnus::class::object())?;
     rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
     rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
@@ -413,6 +430,7 @@ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
     rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;
     rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
     rb_llm.define_method("device", method!(LLM::device, 0))?;
+    rb_llm.define_method("tokenizer", method!(LLM::tokenizer, 0))?;
     rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
     rb_llm.define_method("apply_chat_template", method!(LLM::apply_chat_template, 1))?;
 
```
data/ext/candle/src/ruby/mod.rs
CHANGED
```diff
@@ -2,17 +2,16 @@ pub mod embedding_model;
 pub mod tensor;
 pub mod device;
 pub mod dtype;
-pub mod qtensor;
 pub mod result;
 pub mod errors;
 pub mod utils;
 pub mod llm;
+pub mod tokenizer;
 
 pub use embedding_model::{EmbeddingModel, EmbeddingModelInner};
 pub use tensor::Tensor;
 pub use device::Device;
 pub use dtype::DType;
-pub use qtensor::QTensor;
 pub use result::Result;
 
 // Re-export for convenience
```