red-candle 1.0.0.pre.1 → 1.0.0.pre.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/candle/build.rs +116 -0
- data/ext/candle/src/lib.rs +6 -96
- data/ext/candle/src/llm/generation_config.rs +49 -0
- data/ext/candle/src/llm/mistral.rs +325 -0
- data/ext/candle/src/llm/mod.rs +68 -0
- data/ext/candle/src/llm/text_generation.rs +141 -0
- data/ext/candle/src/reranker.rs +267 -0
- data/ext/candle/src/ruby/device.rs +197 -0
- data/ext/candle/src/ruby/dtype.rs +37 -0
- data/ext/candle/src/ruby/embedding_model.rs +410 -0
- data/ext/candle/src/ruby/errors.rs +13 -0
- data/ext/candle/src/ruby/llm.rs +295 -0
- data/ext/candle/src/ruby/mod.rs +21 -0
- data/ext/candle/src/ruby/qtensor.rs +69 -0
- data/ext/candle/src/ruby/result.rs +3 -0
- data/ext/candle/src/ruby/tensor.rs +654 -0
- data/ext/candle/src/ruby/utils.rs +88 -0
- data/lib/candle/version.rb +1 -1
- metadata +17 -1
@@ -0,0 +1,295 @@
|
|
1
|
+
use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
|
2
|
+
use std::cell::RefCell;
|
3
|
+
|
4
|
+
use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral};
|
5
|
+
use crate::ruby::{Result as RbResult, Device as RbDevice};
|
6
|
+
|
7
|
+
// Use an enum to handle different model types instead of trait objects
|
8
|
+
#[derive(Debug)]
|
9
|
+
enum ModelType {
|
10
|
+
Mistral(RustMistral),
|
11
|
+
}
|
12
|
+
|
13
|
+
impl ModelType {
|
14
|
+
fn generate(&mut self, prompt: &str, config: &RustGenerationConfig) -> candle_core::Result<String> {
|
15
|
+
match self {
|
16
|
+
ModelType::Mistral(m) => m.generate(prompt, config),
|
17
|
+
}
|
18
|
+
}
|
19
|
+
|
20
|
+
fn generate_stream(
|
21
|
+
&mut self,
|
22
|
+
prompt: &str,
|
23
|
+
config: &RustGenerationConfig,
|
24
|
+
callback: impl FnMut(&str),
|
25
|
+
) -> candle_core::Result<String> {
|
26
|
+
match self {
|
27
|
+
ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
|
28
|
+
}
|
29
|
+
}
|
30
|
+
|
31
|
+
#[allow(dead_code)]
|
32
|
+
fn model_name(&self) -> &str {
|
33
|
+
match self {
|
34
|
+
ModelType::Mistral(m) => m.model_name(),
|
35
|
+
}
|
36
|
+
}
|
37
|
+
|
38
|
+
fn clear_cache(&mut self) {
|
39
|
+
match self {
|
40
|
+
ModelType::Mistral(m) => m.clear_cache(),
|
41
|
+
}
|
42
|
+
}
|
43
|
+
}
|
44
|
+
|
45
|
+
#[derive(Clone, Debug)]
|
46
|
+
#[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
|
47
|
+
pub struct GenerationConfig {
|
48
|
+
inner: RustGenerationConfig,
|
49
|
+
}
|
50
|
+
|
51
|
+
impl GenerationConfig {
|
52
|
+
pub fn new(kwargs: RHash) -> RbResult<Self> {
|
53
|
+
let mut config = RustGenerationConfig::default();
|
54
|
+
|
55
|
+
// Extract values from kwargs manually
|
56
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("max_length")) {
|
57
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
58
|
+
config.max_length = v;
|
59
|
+
}
|
60
|
+
}
|
61
|
+
|
62
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("temperature")) {
|
63
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
64
|
+
config.temperature = v;
|
65
|
+
}
|
66
|
+
}
|
67
|
+
|
68
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("top_p")) {
|
69
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
70
|
+
config.top_p = Some(v);
|
71
|
+
}
|
72
|
+
}
|
73
|
+
|
74
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("top_k")) {
|
75
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
76
|
+
config.top_k = Some(v);
|
77
|
+
}
|
78
|
+
}
|
79
|
+
|
80
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty")) {
|
81
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
82
|
+
config.repetition_penalty = v;
|
83
|
+
}
|
84
|
+
}
|
85
|
+
|
86
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty_last_n")) {
|
87
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
88
|
+
config.repetition_penalty_last_n = v;
|
89
|
+
}
|
90
|
+
}
|
91
|
+
|
92
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("seed")) {
|
93
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
94
|
+
config.seed = v;
|
95
|
+
}
|
96
|
+
}
|
97
|
+
|
98
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("include_prompt")) {
|
99
|
+
if let Ok(v) = TryConvert::try_convert(value) {
|
100
|
+
config.include_prompt = v;
|
101
|
+
}
|
102
|
+
}
|
103
|
+
|
104
|
+
if let Some(value) = kwargs.get(magnus::Symbol::new("stop_sequences")) {
|
105
|
+
if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
|
106
|
+
config.stop_sequences = arr
|
107
|
+
.into_iter()
|
108
|
+
.filter_map(|v| <String as TryConvert>::try_convert(v).ok())
|
109
|
+
.collect();
|
110
|
+
}
|
111
|
+
}
|
112
|
+
|
113
|
+
Ok(Self { inner: config })
|
114
|
+
}
|
115
|
+
|
116
|
+
pub fn default() -> Self {
|
117
|
+
Self {
|
118
|
+
inner: RustGenerationConfig::default(),
|
119
|
+
}
|
120
|
+
}
|
121
|
+
|
122
|
+
// Getters
|
123
|
+
pub fn max_length(&self) -> usize {
|
124
|
+
self.inner.max_length
|
125
|
+
}
|
126
|
+
|
127
|
+
pub fn temperature(&self) -> f64 {
|
128
|
+
self.inner.temperature
|
129
|
+
}
|
130
|
+
|
131
|
+
pub fn top_p(&self) -> Option<f64> {
|
132
|
+
self.inner.top_p
|
133
|
+
}
|
134
|
+
|
135
|
+
pub fn top_k(&self) -> Option<usize> {
|
136
|
+
self.inner.top_k
|
137
|
+
}
|
138
|
+
|
139
|
+
pub fn repetition_penalty(&self) -> f32 {
|
140
|
+
self.inner.repetition_penalty
|
141
|
+
}
|
142
|
+
|
143
|
+
pub fn seed(&self) -> u64 {
|
144
|
+
self.inner.seed
|
145
|
+
}
|
146
|
+
|
147
|
+
pub fn stop_sequences(&self) -> Vec<String> {
|
148
|
+
self.inner.stop_sequences.clone()
|
149
|
+
}
|
150
|
+
|
151
|
+
pub fn include_prompt(&self) -> bool {
|
152
|
+
self.inner.include_prompt
|
153
|
+
}
|
154
|
+
}
|
155
|
+
|
156
|
+
#[derive(Clone, Debug)]
|
157
|
+
#[magnus::wrap(class = "Candle::LLM", mark, free_immediately)]
|
158
|
+
pub struct LLM {
|
159
|
+
model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
|
160
|
+
model_id: String,
|
161
|
+
device: RbDevice,
|
162
|
+
}
|
163
|
+
|
164
|
+
impl LLM {
|
165
|
+
/// Create a new LLM from a pretrained model
|
166
|
+
pub fn from_pretrained(model_id: String, device: Option<RbDevice>) -> RbResult<Self> {
|
167
|
+
let device = device.unwrap_or(RbDevice::Cpu);
|
168
|
+
let candle_device = device.as_device()?;
|
169
|
+
|
170
|
+
// For now, we'll use tokio runtime directly
|
171
|
+
// In production, you might want to share a runtime
|
172
|
+
let rt = tokio::runtime::Runtime::new()
|
173
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;
|
174
|
+
|
175
|
+
// Determine model type from ID and load appropriately
|
176
|
+
let model_lower = model_id.to_lowercase();
|
177
|
+
let model = if model_lower.contains("mistral") {
|
178
|
+
let mistral = rt.block_on(async {
|
179
|
+
RustMistral::from_pretrained(&model_id, candle_device).await
|
180
|
+
})
|
181
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
|
182
|
+
ModelType::Mistral(mistral)
|
183
|
+
} else {
|
184
|
+
return Err(Error::new(
|
185
|
+
magnus::exception::runtime_error(),
|
186
|
+
format!("Unsupported model type: {}. Currently only Mistral models are supported.", model_id),
|
187
|
+
));
|
188
|
+
};
|
189
|
+
|
190
|
+
Ok(Self {
|
191
|
+
model: std::sync::Arc::new(std::sync::Mutex::new(RefCell::new(model))),
|
192
|
+
model_id,
|
193
|
+
device,
|
194
|
+
})
|
195
|
+
}
|
196
|
+
|
197
|
+
/// Generate text from a prompt
|
198
|
+
pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> RbResult<String> {
|
199
|
+
let config = config
|
200
|
+
.map(|c| c.inner.clone())
|
201
|
+
.unwrap_or_default();
|
202
|
+
|
203
|
+
let model = self.model.lock().unwrap();
|
204
|
+
let mut model_ref = model.borrow_mut();
|
205
|
+
|
206
|
+
model_ref.generate(&prompt, &config)
|
207
|
+
.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Generation failed: {}", e)))
|
208
|
+
}
|
209
|
+
|
210
|
+
/// Generate text with streaming output
|
211
|
+
pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) -> RbResult<String> {
|
212
|
+
let config = config
|
213
|
+
.map(|c| c.inner.clone())
|
214
|
+
.unwrap_or_default();
|
215
|
+
|
216
|
+
let ruby = Ruby::get().unwrap();
|
217
|
+
let block = ruby.block_proc();
|
218
|
+
if let Err(_) = block {
|
219
|
+
return Err(Error::new(magnus::exception::runtime_error(), "No block given"));
|
220
|
+
}
|
221
|
+
let block = block.unwrap();
|
222
|
+
|
223
|
+
let model = self.model.lock().unwrap();
|
224
|
+
let mut model_ref = model.borrow_mut();
|
225
|
+
|
226
|
+
let result = model_ref.generate_stream(&prompt, &config, |token| {
|
227
|
+
// Call the Ruby block with each token
|
228
|
+
let _ = block.call::<(String,), Value>((token.to_string(),));
|
229
|
+
});
|
230
|
+
|
231
|
+
result.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Generation failed: {}", e)))
|
232
|
+
}
|
233
|
+
|
234
|
+
/// Get the model name
|
235
|
+
pub fn model_name(&self) -> String {
|
236
|
+
self.model_id.clone()
|
237
|
+
}
|
238
|
+
|
239
|
+
/// Get the device the model is running on
|
240
|
+
pub fn device(&self) -> RbDevice {
|
241
|
+
self.device
|
242
|
+
}
|
243
|
+
|
244
|
+
/// Clear the model's cache (e.g., KV cache for transformers)
|
245
|
+
pub fn clear_cache(&self) -> RbResult<()> {
|
246
|
+
let model = self.model.lock().unwrap();
|
247
|
+
let mut model_ref = model.borrow_mut();
|
248
|
+
model_ref.clear_cache();
|
249
|
+
Ok(())
|
250
|
+
}
|
251
|
+
}
|
252
|
+
|
253
|
+
// Define a standalone function for from_pretrained that handles variable arguments
|
254
|
+
fn from_pretrained_wrapper(args: &[Value]) -> RbResult<LLM> {
|
255
|
+
match args.len() {
|
256
|
+
1 => {
|
257
|
+
let model_id: String = TryConvert::try_convert(args[0])?;
|
258
|
+
LLM::from_pretrained(model_id, None)
|
259
|
+
},
|
260
|
+
2 => {
|
261
|
+
let model_id: String = TryConvert::try_convert(args[0])?;
|
262
|
+
let device: RbDevice = TryConvert::try_convert(args[1])?;
|
263
|
+
LLM::from_pretrained(model_id, Some(device))
|
264
|
+
},
|
265
|
+
_ => Err(Error::new(
|
266
|
+
magnus::exception::arg_error(),
|
267
|
+
"wrong number of arguments (expected 1..2)"
|
268
|
+
))
|
269
|
+
}
|
270
|
+
}
|
271
|
+
|
272
|
+
pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
|
273
|
+
let rb_generation_config = rb_candle.define_class("GenerationConfig", magnus::class::object())?;
|
274
|
+
rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
|
275
|
+
rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
|
276
|
+
|
277
|
+
rb_generation_config.define_method("max_length", method!(GenerationConfig::max_length, 0))?;
|
278
|
+
rb_generation_config.define_method("temperature", method!(GenerationConfig::temperature, 0))?;
|
279
|
+
rb_generation_config.define_method("top_p", method!(GenerationConfig::top_p, 0))?;
|
280
|
+
rb_generation_config.define_method("top_k", method!(GenerationConfig::top_k, 0))?;
|
281
|
+
rb_generation_config.define_method("repetition_penalty", method!(GenerationConfig::repetition_penalty, 0))?;
|
282
|
+
rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
|
283
|
+
rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
|
284
|
+
rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
|
285
|
+
|
286
|
+
let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
|
287
|
+
rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
|
288
|
+
rb_llm.define_method("_generate", method!(LLM::generate, 2))?;
|
289
|
+
rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;
|
290
|
+
rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
|
291
|
+
rb_llm.define_method("device", method!(LLM::device, 0))?;
|
292
|
+
rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
|
293
|
+
|
294
|
+
Ok(())
|
295
|
+
}
|
@@ -0,0 +1,21 @@
|
|
1
|
+
pub mod embedding_model;
|
2
|
+
pub mod tensor;
|
3
|
+
pub mod device;
|
4
|
+
pub mod dtype;
|
5
|
+
pub mod qtensor;
|
6
|
+
pub mod result;
|
7
|
+
pub mod errors;
|
8
|
+
pub mod utils;
|
9
|
+
pub mod llm;
|
10
|
+
|
11
|
+
pub use embedding_model::{EmbeddingModel, EmbeddingModelInner};
|
12
|
+
pub use tensor::Tensor;
|
13
|
+
pub use device::Device;
|
14
|
+
pub use dtype::DType;
|
15
|
+
pub use qtensor::QTensor;
|
16
|
+
pub use result::Result;
|
17
|
+
|
18
|
+
// Re-export for convenience
|
19
|
+
pub use embedding_model::init as init_embedding_model;
|
20
|
+
pub use utils::candle_utils;
|
21
|
+
pub use llm::init_llm;
|
@@ -0,0 +1,69 @@
|
|
1
|
+
use std::sync::Arc;
|
2
|
+
use magnus::{method, class, RModule, Error, Module};
|
3
|
+
|
4
|
+
use crate::ruby::errors::wrap_candle_err;
|
5
|
+
use crate::ruby::{Tensor, Result as RbResult};
|
6
|
+
use ::candle_core::{quantized::QTensor as CoreQTensor, Device as CoreDevice};
|
7
|
+
|
8
|
+
#[derive(Debug)]
|
9
|
+
#[magnus::wrap(class = "Candle::QTensor", free_immediately, size)]
|
10
|
+
/// A quantized tensor.
|
11
|
+
pub struct QTensor(Arc<CoreQTensor>);
|
12
|
+
|
13
|
+
impl std::ops::Deref for QTensor {
|
14
|
+
type Target = CoreQTensor;
|
15
|
+
|
16
|
+
fn deref(&self) -> &Self::Target {
|
17
|
+
self.0.as_ref()
|
18
|
+
}
|
19
|
+
}
|
20
|
+
|
21
|
+
impl QTensor {
|
22
|
+
///Gets the tensors quantized dtype.
|
23
|
+
/// &RETURNS&: str
|
24
|
+
pub fn ggml_dtype(&self) -> String {
|
25
|
+
format!("{:?}", self.0.dtype())
|
26
|
+
}
|
27
|
+
|
28
|
+
///Gets the rank of the tensor.
|
29
|
+
/// &RETURNS&: int
|
30
|
+
pub fn rank(&self) -> usize {
|
31
|
+
self.0.rank()
|
32
|
+
}
|
33
|
+
|
34
|
+
///Gets the shape of the tensor.
|
35
|
+
/// &RETURNS&: Tuple[int]
|
36
|
+
pub fn shape(&self) -> Vec<usize> {
|
37
|
+
self.0.shape().dims().to_vec()
|
38
|
+
}
|
39
|
+
|
40
|
+
pub fn __repr__(&self) -> String {
|
41
|
+
format!("{:?}", self.0)
|
42
|
+
}
|
43
|
+
|
44
|
+
pub fn __str__(&self) -> String {
|
45
|
+
self.__repr__()
|
46
|
+
}
|
47
|
+
|
48
|
+
/// Dequantizes the tensor.
|
49
|
+
/// &RETURNS&: Tensor
|
50
|
+
pub fn dequantize(&self) -> RbResult<Tensor> {
|
51
|
+
let tensor = self.0.dequantize(&CoreDevice::Cpu).map_err(wrap_candle_err)?;
|
52
|
+
Ok(Tensor(tensor))
|
53
|
+
}
|
54
|
+
|
55
|
+
// fn matmul_t(&self, lhs: &Tensor) -> RbResult<Tensor> {
|
56
|
+
// let qmatmul = ::candle_core::quantized::QMatMul::from_arc(self.0.clone());
|
57
|
+
// let res = qmatmul.forward(lhs).map_err(wrap_candle_err)?;
|
58
|
+
// Ok(Tensor(res))
|
59
|
+
// }
|
60
|
+
}
|
61
|
+
|
62
|
+
pub fn init(rb_candle: RModule) -> Result<(), Error> {
|
63
|
+
let rb_qtensor = rb_candle.define_class("QTensor", class::object())?;
|
64
|
+
rb_qtensor.define_method("ggml_dtype", method!(QTensor::ggml_dtype, 0))?;
|
65
|
+
rb_qtensor.define_method("rank", method!(QTensor::rank, 0))?;
|
66
|
+
rb_qtensor.define_method("shape", method!(QTensor::shape, 0))?;
|
67
|
+
rb_qtensor.define_method("dequantize", method!(QTensor::dequantize, 0))?;
|
68
|
+
Ok(())
|
69
|
+
}
|