red-candle 1.0.0.pre.1 → 1.0.0.pre.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,295 @@
1
+ use magnus::{function, method, prelude::*, Error, Module, RArray, RHash, RModule, Ruby, TryConvert, Value};
2
+ use std::cell::RefCell;
3
+
4
+ use crate::llm::{GenerationConfig as RustGenerationConfig, TextGenerator, mistral::Mistral as RustMistral};
5
+ use crate::ruby::{Result as RbResult, Device as RbDevice};
6
+
7
+ // Use an enum to handle different model types instead of trait objects
8
+ #[derive(Debug)]
9
+ enum ModelType {
10
+ Mistral(RustMistral),
11
+ }
12
+
13
+ impl ModelType {
14
+ fn generate(&mut self, prompt: &str, config: &RustGenerationConfig) -> candle_core::Result<String> {
15
+ match self {
16
+ ModelType::Mistral(m) => m.generate(prompt, config),
17
+ }
18
+ }
19
+
20
+ fn generate_stream(
21
+ &mut self,
22
+ prompt: &str,
23
+ config: &RustGenerationConfig,
24
+ callback: impl FnMut(&str),
25
+ ) -> candle_core::Result<String> {
26
+ match self {
27
+ ModelType::Mistral(m) => m.generate_stream(prompt, config, callback),
28
+ }
29
+ }
30
+
31
+ #[allow(dead_code)]
32
+ fn model_name(&self) -> &str {
33
+ match self {
34
+ ModelType::Mistral(m) => m.model_name(),
35
+ }
36
+ }
37
+
38
+ fn clear_cache(&mut self) {
39
+ match self {
40
+ ModelType::Mistral(m) => m.clear_cache(),
41
+ }
42
+ }
43
+ }
44
+
45
+ #[derive(Clone, Debug)]
46
+ #[magnus::wrap(class = "Candle::GenerationConfig", mark, free_immediately)]
47
+ pub struct GenerationConfig {
48
+ inner: RustGenerationConfig,
49
+ }
50
+
51
+ impl GenerationConfig {
52
+ pub fn new(kwargs: RHash) -> RbResult<Self> {
53
+ let mut config = RustGenerationConfig::default();
54
+
55
+ // Extract values from kwargs manually
56
+ if let Some(value) = kwargs.get(magnus::Symbol::new("max_length")) {
57
+ if let Ok(v) = TryConvert::try_convert(value) {
58
+ config.max_length = v;
59
+ }
60
+ }
61
+
62
+ if let Some(value) = kwargs.get(magnus::Symbol::new("temperature")) {
63
+ if let Ok(v) = TryConvert::try_convert(value) {
64
+ config.temperature = v;
65
+ }
66
+ }
67
+
68
+ if let Some(value) = kwargs.get(magnus::Symbol::new("top_p")) {
69
+ if let Ok(v) = TryConvert::try_convert(value) {
70
+ config.top_p = Some(v);
71
+ }
72
+ }
73
+
74
+ if let Some(value) = kwargs.get(magnus::Symbol::new("top_k")) {
75
+ if let Ok(v) = TryConvert::try_convert(value) {
76
+ config.top_k = Some(v);
77
+ }
78
+ }
79
+
80
+ if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty")) {
81
+ if let Ok(v) = TryConvert::try_convert(value) {
82
+ config.repetition_penalty = v;
83
+ }
84
+ }
85
+
86
+ if let Some(value) = kwargs.get(magnus::Symbol::new("repetition_penalty_last_n")) {
87
+ if let Ok(v) = TryConvert::try_convert(value) {
88
+ config.repetition_penalty_last_n = v;
89
+ }
90
+ }
91
+
92
+ if let Some(value) = kwargs.get(magnus::Symbol::new("seed")) {
93
+ if let Ok(v) = TryConvert::try_convert(value) {
94
+ config.seed = v;
95
+ }
96
+ }
97
+
98
+ if let Some(value) = kwargs.get(magnus::Symbol::new("include_prompt")) {
99
+ if let Ok(v) = TryConvert::try_convert(value) {
100
+ config.include_prompt = v;
101
+ }
102
+ }
103
+
104
+ if let Some(value) = kwargs.get(magnus::Symbol::new("stop_sequences")) {
105
+ if let Ok(arr) = <RArray as TryConvert>::try_convert(value) {
106
+ config.stop_sequences = arr
107
+ .into_iter()
108
+ .filter_map(|v| <String as TryConvert>::try_convert(v).ok())
109
+ .collect();
110
+ }
111
+ }
112
+
113
+ Ok(Self { inner: config })
114
+ }
115
+
116
+ pub fn default() -> Self {
117
+ Self {
118
+ inner: RustGenerationConfig::default(),
119
+ }
120
+ }
121
+
122
+ // Getters
123
+ pub fn max_length(&self) -> usize {
124
+ self.inner.max_length
125
+ }
126
+
127
+ pub fn temperature(&self) -> f64 {
128
+ self.inner.temperature
129
+ }
130
+
131
+ pub fn top_p(&self) -> Option<f64> {
132
+ self.inner.top_p
133
+ }
134
+
135
+ pub fn top_k(&self) -> Option<usize> {
136
+ self.inner.top_k
137
+ }
138
+
139
+ pub fn repetition_penalty(&self) -> f32 {
140
+ self.inner.repetition_penalty
141
+ }
142
+
143
+ pub fn seed(&self) -> u64 {
144
+ self.inner.seed
145
+ }
146
+
147
+ pub fn stop_sequences(&self) -> Vec<String> {
148
+ self.inner.stop_sequences.clone()
149
+ }
150
+
151
+ pub fn include_prompt(&self) -> bool {
152
+ self.inner.include_prompt
153
+ }
154
+ }
155
+
156
+ #[derive(Clone, Debug)]
157
+ #[magnus::wrap(class = "Candle::LLM", mark, free_immediately)]
158
+ pub struct LLM {
159
+ model: std::sync::Arc<std::sync::Mutex<RefCell<ModelType>>>,
160
+ model_id: String,
161
+ device: RbDevice,
162
+ }
163
+
164
+ impl LLM {
165
+ /// Create a new LLM from a pretrained model
166
+ pub fn from_pretrained(model_id: String, device: Option<RbDevice>) -> RbResult<Self> {
167
+ let device = device.unwrap_or(RbDevice::Cpu);
168
+ let candle_device = device.as_device()?;
169
+
170
+ // For now, we'll use tokio runtime directly
171
+ // In production, you might want to share a runtime
172
+ let rt = tokio::runtime::Runtime::new()
173
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to create runtime: {}", e)))?;
174
+
175
+ // Determine model type from ID and load appropriately
176
+ let model_lower = model_id.to_lowercase();
177
+ let model = if model_lower.contains("mistral") {
178
+ let mistral = rt.block_on(async {
179
+ RustMistral::from_pretrained(&model_id, candle_device).await
180
+ })
181
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Failed to load model: {}", e)))?;
182
+ ModelType::Mistral(mistral)
183
+ } else {
184
+ return Err(Error::new(
185
+ magnus::exception::runtime_error(),
186
+ format!("Unsupported model type: {}. Currently only Mistral models are supported.", model_id),
187
+ ));
188
+ };
189
+
190
+ Ok(Self {
191
+ model: std::sync::Arc::new(std::sync::Mutex::new(RefCell::new(model))),
192
+ model_id,
193
+ device,
194
+ })
195
+ }
196
+
197
+ /// Generate text from a prompt
198
+ pub fn generate(&self, prompt: String, config: Option<&GenerationConfig>) -> RbResult<String> {
199
+ let config = config
200
+ .map(|c| c.inner.clone())
201
+ .unwrap_or_default();
202
+
203
+ let model = self.model.lock().unwrap();
204
+ let mut model_ref = model.borrow_mut();
205
+
206
+ model_ref.generate(&prompt, &config)
207
+ .map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Generation failed: {}", e)))
208
+ }
209
+
210
+ /// Generate text with streaming output
211
+ pub fn generate_stream(&self, prompt: String, config: Option<&GenerationConfig>) -> RbResult<String> {
212
+ let config = config
213
+ .map(|c| c.inner.clone())
214
+ .unwrap_or_default();
215
+
216
+ let ruby = Ruby::get().unwrap();
217
+ let block = ruby.block_proc();
218
+ if let Err(_) = block {
219
+ return Err(Error::new(magnus::exception::runtime_error(), "No block given"));
220
+ }
221
+ let block = block.unwrap();
222
+
223
+ let model = self.model.lock().unwrap();
224
+ let mut model_ref = model.borrow_mut();
225
+
226
+ let result = model_ref.generate_stream(&prompt, &config, |token| {
227
+ // Call the Ruby block with each token
228
+ let _ = block.call::<(String,), Value>((token.to_string(),));
229
+ });
230
+
231
+ result.map_err(|e| Error::new(magnus::exception::runtime_error(), format!("Generation failed: {}", e)))
232
+ }
233
+
234
+ /// Get the model name
235
+ pub fn model_name(&self) -> String {
236
+ self.model_id.clone()
237
+ }
238
+
239
+ /// Get the device the model is running on
240
+ pub fn device(&self) -> RbDevice {
241
+ self.device
242
+ }
243
+
244
+ /// Clear the model's cache (e.g., KV cache for transformers)
245
+ pub fn clear_cache(&self) -> RbResult<()> {
246
+ let model = self.model.lock().unwrap();
247
+ let mut model_ref = model.borrow_mut();
248
+ model_ref.clear_cache();
249
+ Ok(())
250
+ }
251
+ }
252
+
253
+ // Define a standalone function for from_pretrained that handles variable arguments
254
+ fn from_pretrained_wrapper(args: &[Value]) -> RbResult<LLM> {
255
+ match args.len() {
256
+ 1 => {
257
+ let model_id: String = TryConvert::try_convert(args[0])?;
258
+ LLM::from_pretrained(model_id, None)
259
+ },
260
+ 2 => {
261
+ let model_id: String = TryConvert::try_convert(args[0])?;
262
+ let device: RbDevice = TryConvert::try_convert(args[1])?;
263
+ LLM::from_pretrained(model_id, Some(device))
264
+ },
265
+ _ => Err(Error::new(
266
+ magnus::exception::arg_error(),
267
+ "wrong number of arguments (expected 1..2)"
268
+ ))
269
+ }
270
+ }
271
+
272
+ pub fn init_llm(rb_candle: RModule) -> RbResult<()> {
273
+ let rb_generation_config = rb_candle.define_class("GenerationConfig", magnus::class::object())?;
274
+ rb_generation_config.define_singleton_method("new", function!(GenerationConfig::new, 1))?;
275
+ rb_generation_config.define_singleton_method("default", function!(GenerationConfig::default, 0))?;
276
+
277
+ rb_generation_config.define_method("max_length", method!(GenerationConfig::max_length, 0))?;
278
+ rb_generation_config.define_method("temperature", method!(GenerationConfig::temperature, 0))?;
279
+ rb_generation_config.define_method("top_p", method!(GenerationConfig::top_p, 0))?;
280
+ rb_generation_config.define_method("top_k", method!(GenerationConfig::top_k, 0))?;
281
+ rb_generation_config.define_method("repetition_penalty", method!(GenerationConfig::repetition_penalty, 0))?;
282
+ rb_generation_config.define_method("seed", method!(GenerationConfig::seed, 0))?;
283
+ rb_generation_config.define_method("stop_sequences", method!(GenerationConfig::stop_sequences, 0))?;
284
+ rb_generation_config.define_method("include_prompt", method!(GenerationConfig::include_prompt, 0))?;
285
+
286
+ let rb_llm = rb_candle.define_class("LLM", magnus::class::object())?;
287
+ rb_llm.define_singleton_method("_from_pretrained", function!(from_pretrained_wrapper, -1))?;
288
+ rb_llm.define_method("_generate", method!(LLM::generate, 2))?;
289
+ rb_llm.define_method("_generate_stream", method!(LLM::generate_stream, 2))?;
290
+ rb_llm.define_method("model_name", method!(LLM::model_name, 0))?;
291
+ rb_llm.define_method("device", method!(LLM::device, 0))?;
292
+ rb_llm.define_method("clear_cache", method!(LLM::clear_cache, 0))?;
293
+
294
+ Ok(())
295
+ }
@@ -0,0 +1,21 @@
1
+ pub mod embedding_model;
2
+ pub mod tensor;
3
+ pub mod device;
4
+ pub mod dtype;
5
+ pub mod qtensor;
6
+ pub mod result;
7
+ pub mod errors;
8
+ pub mod utils;
9
+ pub mod llm;
10
+
11
+ pub use embedding_model::{EmbeddingModel, EmbeddingModelInner};
12
+ pub use tensor::Tensor;
13
+ pub use device::Device;
14
+ pub use dtype::DType;
15
+ pub use qtensor::QTensor;
16
+ pub use result::Result;
17
+
18
+ // Re-export for convenience
19
+ pub use embedding_model::init as init_embedding_model;
20
+ pub use utils::candle_utils;
21
+ pub use llm::init_llm;
@@ -0,0 +1,69 @@
1
+ use std::sync::Arc;
2
+ use magnus::{method, class, RModule, Error, Module};
3
+
4
+ use crate::ruby::errors::wrap_candle_err;
5
+ use crate::ruby::{Tensor, Result as RbResult};
6
+ use ::candle_core::{quantized::QTensor as CoreQTensor, Device as CoreDevice};
7
+
8
+ #[derive(Debug)]
9
+ #[magnus::wrap(class = "Candle::QTensor", free_immediately, size)]
10
+ /// A quantized tensor.
11
+ pub struct QTensor(Arc<CoreQTensor>);
12
+
13
+ impl std::ops::Deref for QTensor {
14
+ type Target = CoreQTensor;
15
+
16
+ fn deref(&self) -> &Self::Target {
17
+ self.0.as_ref()
18
+ }
19
+ }
20
+
21
+ impl QTensor {
22
+ ///Gets the tensors quantized dtype.
23
+ /// &RETURNS&: str
24
+ pub fn ggml_dtype(&self) -> String {
25
+ format!("{:?}", self.0.dtype())
26
+ }
27
+
28
+ ///Gets the rank of the tensor.
29
+ /// &RETURNS&: int
30
+ pub fn rank(&self) -> usize {
31
+ self.0.rank()
32
+ }
33
+
34
+ ///Gets the shape of the tensor.
35
+ /// &RETURNS&: Tuple[int]
36
+ pub fn shape(&self) -> Vec<usize> {
37
+ self.0.shape().dims().to_vec()
38
+ }
39
+
40
+ pub fn __repr__(&self) -> String {
41
+ format!("{:?}", self.0)
42
+ }
43
+
44
+ pub fn __str__(&self) -> String {
45
+ self.__repr__()
46
+ }
47
+
48
+ /// Dequantizes the tensor.
49
+ /// &RETURNS&: Tensor
50
+ pub fn dequantize(&self) -> RbResult<Tensor> {
51
+ let tensor = self.0.dequantize(&CoreDevice::Cpu).map_err(wrap_candle_err)?;
52
+ Ok(Tensor(tensor))
53
+ }
54
+
55
+ // fn matmul_t(&self, lhs: &Tensor) -> RbResult<Tensor> {
56
+ // let qmatmul = ::candle_core::quantized::QMatMul::from_arc(self.0.clone());
57
+ // let res = qmatmul.forward(lhs).map_err(wrap_candle_err)?;
58
+ // Ok(Tensor(res))
59
+ // }
60
+ }
61
+
62
+ pub fn init(rb_candle: RModule) -> Result<(), Error> {
63
+ let rb_qtensor = rb_candle.define_class("QTensor", class::object())?;
64
+ rb_qtensor.define_method("ggml_dtype", method!(QTensor::ggml_dtype, 0))?;
65
+ rb_qtensor.define_method("rank", method!(QTensor::rank, 0))?;
66
+ rb_qtensor.define_method("shape", method!(QTensor::shape, 0))?;
67
+ rb_qtensor.define_method("dequantize", method!(QTensor::dequantize, 0))?;
68
+ Ok(())
69
+ }
@@ -0,0 +1,3 @@
1
+ use magnus::Error;
2
+
3
/// Shorthand for `std::result::Result` with the error type fixed to
/// [`magnus::Error`].
pub type Result<T> = std::result::Result<T, Error>;