red-candle 1.0.0.pre.1 → 1.0.0.pre.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 34ca4771af8508ace5ee8df5d5111a407c79fea307f9fbc2b72d9e46d6dd2099
- data.tar.gz: 39a316c2990f8766257e07c5ae8803ba852aebc5c93b14349548b1da82b6ae4a
+ metadata.gz: 9cdd24afacff8070c3011dad25e975cb4bff9420d7a760ec105d840480bd0a4f
+ data.tar.gz: df1d8df277e5c8447c0a6598c2f7ed1389e74b356b42b3bb8f51b2f4c7863884
  SHA512:
- metadata.gz: 51fcc58936f25818485bacd5dfc86911ca1ecbc0a8b87f6fdcf2d015a91faccd6aeb889df33a931eae6d864984254ab4a81ab7050fb00caf14d090820cf4d86b
- data.tar.gz: f6e58cc4e723c12e3d290fcb45b1cab5923594bfd011be223ec327c65c7c91cb7d37f9330ed82eadd4b7eb3085ef0ba9b874a6b98c15221b21c60cd01de11418
+ metadata.gz: b1a4c0df852b01a3cdff7f78dbe0a6553cd39ff51027b35cca4f8ff6c63d4d89b411493d644fe4d4da29d516bbfa48c1b58a5a3b317556041848004909b63edc
+ data.tar.gz: facd28b47751a365e5ef1fdf900f55100392eeee4d5528c3f46e340ec8d69c913ea5698aa4c2900c926947b7e444bc9f876cc79644c81fe731d71f0970715ff3
@@ -0,0 +1,116 @@
+ use std::env;
+ use std::path::Path;
+
+ fn main() {
+     // Register our custom cfg flags with rustc
+     println!("cargo::rustc-check-cfg=cfg(force_cpu)");
+     println!("cargo::rustc-check-cfg=cfg(has_cuda)");
+     println!("cargo::rustc-check-cfg=cfg(has_metal)");
+     println!("cargo::rustc-check-cfg=cfg(has_mkl)");
+     println!("cargo::rustc-check-cfg=cfg(has_accelerate)");
+
+     println!("cargo:rerun-if-changed=build.rs");
+     println!("cargo:rerun-if-env-changed=CANDLE_FORCE_CPU");
+     println!("cargo:rerun-if-env-changed=CANDLE_CUDA_PATH");
+     println!("cargo:rerun-if-env-changed=CUDA_ROOT");
+     println!("cargo:rerun-if-env-changed=CUDA_PATH");
+     println!("cargo:rerun-if-env-changed=CANDLE_FEATURES");
+     println!("cargo:rerun-if-env-changed=CANDLE_ENABLE_CUDA");
+
+     // Check if we should force CPU only
+     if env::var("CANDLE_FORCE_CPU").is_ok() {
+         println!("cargo:rustc-cfg=force_cpu");
+         println!("cargo:warning=CANDLE_FORCE_CPU is set, disabling all acceleration");
+         return;
+     }
+
+     // Detect CUDA availability
+     let cuda_available = detect_cuda();
+     let cuda_enabled = env::var("CANDLE_ENABLE_CUDA").is_ok();
+
+     if cuda_available && cuda_enabled {
+         println!("cargo:rustc-cfg=has_cuda");
+         println!("cargo:warning=CUDA detected and enabled via CANDLE_ENABLE_CUDA");
+     } else if cuda_available && !cuda_enabled {
+         println!("cargo:warning=CUDA detected but not enabled. To enable CUDA support (coming soon), set CANDLE_ENABLE_CUDA=1");
+     }
+
+     // Detect Metal availability (macOS only)
+     #[cfg(target_os = "macos")]
+     {
+         println!("cargo:rustc-cfg=has_metal");
+         println!("cargo:warning=Metal detected (macOS), Metal acceleration will be available");
+     }
+
+     // Detect MKL availability
+     if detect_mkl() {
+         println!("cargo:rustc-cfg=has_mkl");
+         println!("cargo:warning=Intel MKL detected, MKL acceleration will be available");
+     }
+
+     // Detect Accelerate framework (macOS)
+     #[cfg(target_os = "macos")]
+     {
+         println!("cargo:rustc-cfg=has_accelerate");
+         println!("cargo:warning=Accelerate framework detected (macOS)");
+     }
+ }
+
+ fn detect_cuda() -> bool {
+     // Check environment variables first
+     if env::var("CANDLE_CUDA_PATH").is_ok() {
+         return true;
+     }
+
+     if env::var("CUDA_ROOT").is_ok() || env::var("CUDA_PATH").is_ok() {
+         return true;
+     }
+
+     // Check common CUDA installation paths
+     let cuda_paths = [
+         "/usr/local/cuda",
+         "/opt/cuda",
+         "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA",
+         "C:\\CUDA",
+     ];
+
+     for path in &cuda_paths {
+         if Path::new(path).exists() {
+             return true;
+         }
+     }
+
+     // Check if nvcc is in PATH
+     if let Ok(path_var) = env::var("PATH") {
+         for path in env::split_paths(&path_var) {
+             if path.join("nvcc").exists() || path.join("nvcc.exe").exists() {
+                 return true;
+             }
+         }
+     }
+
+     false
+ }
+
+ fn detect_mkl() -> bool {
+     // Check environment variables
+     if env::var("MKLROOT").is_ok() || env::var("MKL_ROOT").is_ok() {
+         return true;
+     }
+
+     // Check common MKL installation paths
+     let mkl_paths = [
+         "/opt/intel/mkl",
+         "/opt/intel/oneapi/mkl/latest",
+         "C:\\Program Files (x86)\\Intel\\oneAPI\\mkl\\latest",
+         "C:\\Program Files\\Intel\\oneAPI\\mkl\\latest",
+     ];
+
+     for path in &mkl_paths {
+         if Path::new(path).exists() {
+             return true;
+         }
+     }
+
+     false
+ }
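
The cfg flags set above (force_cpu, has_cuda, has_metal, has_mkl, has_accelerate) are meant to be consumed with ordinary conditional compilation elsewhere in the extension. A minimal, hypothetical sketch of such a consumer (not a file from this diff; it assumes candle_core's Device constructors return Err at runtime when a backend is not compiled in):

    // Illustrative device selection driven by the cfg flags that the build
    // script above emits via `cargo:rustc-cfg=...`. Hypothetical code.
    use candle_core::Device;

    fn pick_device() -> Device {
        if cfg!(force_cpu) {
            return Device::Cpu;
        }
        if cfg!(has_metal) {
            // Assumed to return Err if Metal support is not available.
            if let Ok(device) = Device::new_metal(0) {
                return device;
            }
        }
        if cfg!(has_cuda) {
            if let Ok(device) = Device::new_cuda(0) {
                return device;
            }
        }
        Device::Cpu
    }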
@@ -1,7 +1,7 @@
- use magnus::{function, method, prelude::*, Ruby};
+ use magnus::{function, prelude::*, Ruby};

  use crate::ruby::candle_utils;
- use crate::ruby::{DType, Device, QTensor, Result as RbResult, Tensor};
+ use crate::ruby::Result as RbResult;

  pub mod llm;
  pub mod reranker;
@@ -42,101 +42,11 @@ fn init(ruby: &Ruby) -> RbResult<()> {
  ruby::init_embedding_model(rb_candle)?;
  ruby::init_llm(rb_candle)?;
  reranker::init(rb_candle)?;
+ ruby::dtype::init(rb_candle)?;
+ ruby::qtensor::init(rb_candle)?;
+ ruby::device::init(rb_candle)?;
+ ruby::tensor::init(rb_candle)?;
  candle_utils(rb_candle)?;
- let rb_tensor = rb_candle.define_class("Tensor", Ruby::class_object(ruby))?;
- rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
- // rb_tensor.define_singleton_method("cat", function!(Tensor::cat, 2))?;
- // rb_tensor.define_singleton_method("stack", function!(Tensor::stack, 2))?;
- rb_tensor.define_singleton_method("rand", function!(Tensor::rand, 2))?;
- rb_tensor.define_singleton_method("randn", function!(Tensor::randn, 2))?;
- rb_tensor.define_singleton_method("ones", function!(Tensor::ones, 2))?;
- rb_tensor.define_singleton_method("zeros", function!(Tensor::zeros, 2))?;
- rb_tensor.define_method("values", method!(Tensor::values, 0))?;
- rb_tensor.define_method("values_f32", method!(Tensor::values_f32, 0))?;
- rb_tensor.define_method("item", method!(Tensor::item, 0))?;
- rb_tensor.define_method("shape", method!(Tensor::shape, 0))?;
- rb_tensor.define_method("stride", method!(Tensor::stride, 0))?;
- rb_tensor.define_method("dtype", method!(Tensor::dtype, 0))?;
- rb_tensor.define_method("device", method!(Tensor::device, 0))?;
- rb_tensor.define_method("rank", method!(Tensor::rank, 0))?;
- rb_tensor.define_method("elem_count", method!(Tensor::elem_count, 0))?;
- rb_tensor.define_method("sin", method!(Tensor::sin, 0))?;
- rb_tensor.define_method("cos", method!(Tensor::cos, 0))?;
- rb_tensor.define_method("log", method!(Tensor::log, 0))?;
- rb_tensor.define_method("sqr", method!(Tensor::sqr, 0))?;
- rb_tensor.define_method("mean", method!(Tensor::mean, 1))?;
- rb_tensor.define_method("sum", method!(Tensor::sum, 1))?;
- rb_tensor.define_method("sqrt", method!(Tensor::sqrt, 0))?;
- rb_tensor.define_method("/", method!(Tensor::__truediv__, 1))?; // Accepts Tensor, Float, or Integer
- rb_tensor.define_method("recip", method!(Tensor::recip, 0))?;
- rb_tensor.define_method("exp", method!(Tensor::exp, 0))?;
- rb_tensor.define_method("powf", method!(Tensor::powf, 1))?;
- rb_tensor.define_method("index_select", method!(Tensor::index_select, 2))?;
- rb_tensor.define_method("matmul", method!(Tensor::matmul, 1))?;
- rb_tensor.define_method("broadcast_add", method!(Tensor::broadcast_add, 1))?;
- rb_tensor.define_method("broadcast_sub", method!(Tensor::broadcast_sub, 1))?;
- rb_tensor.define_method("broadcast_mul", method!(Tensor::broadcast_mul, 1))?;
- rb_tensor.define_method("broadcast_div", method!(Tensor::broadcast_div, 1))?;
- rb_tensor.define_method("where_cond", method!(Tensor::where_cond, 2))?;
- rb_tensor.define_method("+", method!(Tensor::__add__, 1))?;
- rb_tensor.define_method("*", method!(Tensor::__mul__, 1))?;
- rb_tensor.define_method("-", method!(Tensor::__sub__, 1))?;
- rb_tensor.define_method("reshape", method!(Tensor::reshape, 1))?;
- rb_tensor.define_method("broadcast_as", method!(Tensor::broadcast_as, 1))?;
- rb_tensor.define_method("broadcast_left", method!(Tensor::broadcast_left, 1))?;
- rb_tensor.define_method("squeeze", method!(Tensor::squeeze, 1))?;
- rb_tensor.define_method("unsqueeze", method!(Tensor::unsqueeze, 1))?;
- rb_tensor.define_method("get", method!(Tensor::get, 1))?;
- rb_tensor.define_method("[]", method!(Tensor::get, 1))?;
- rb_tensor.define_method("transpose", method!(Tensor::transpose, 2))?;
- rb_tensor.define_method("narrow", method!(Tensor::narrow, 3))?;
- rb_tensor.define_method("argmax_keepdim", method!(Tensor::argmax_keepdim, 1))?;
- rb_tensor.define_method("argmin_keepdim", method!(Tensor::argmin_keepdim, 1))?;
- rb_tensor.define_method("max_keepdim", method!(Tensor::max_keepdim, 1))?;
- rb_tensor.define_method("min_keepdim", method!(Tensor::min_keepdim, 1))?;
- // rb_tensor.define_method("eq", method!(Tensor::eq, 1))?;
- // rb_tensor.define_method("ne", method!(Tensor::ne, 1))?;
- // rb_tensor.define_method("lt", method!(Tensor::lt, 1))?;
- // rb_tensor.define_method("gt", method!(Tensor::gt, 1))?;
- // rb_tensor.define_method("ge", method!(Tensor::ge, 1))?;
- // rb_tensor.define_method("le", method!(Tensor::le, 1))?;
- rb_tensor.define_method("sum_all", method!(Tensor::sum_all, 0))?;
- rb_tensor.define_method("mean_all", method!(Tensor::mean_all, 0))?;
- rb_tensor.define_method("flatten_from", method!(Tensor::flatten_from, 1))?;
- rb_tensor.define_method("flatten_to", method!(Tensor::flatten_to, 1))?;
- rb_tensor.define_method("flatten_all", method!(Tensor::flatten_all, 0))?;
- rb_tensor.define_method("t", method!(Tensor::t, 0))?;
- rb_tensor.define_method("contiguous", method!(Tensor::contiguous, 0))?;
- rb_tensor.define_method("is_contiguous", method!(Tensor::is_contiguous, 0))?;
- rb_tensor.define_method(
-     "is_fortran_contiguous",
-     method!(Tensor::is_fortran_contiguous, 0),
- )?;
- rb_tensor.define_method("detach", method!(Tensor::detach, 0))?;
- rb_tensor.define_method("copy", method!(Tensor::copy, 0))?;
- rb_tensor.define_method("to_dtype", method!(Tensor::to_dtype, 1))?;
- rb_tensor.define_method("to_device", method!(Tensor::to_device, 1))?;
- rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
- rb_tensor.define_method("inspect", method!(Tensor::__repr__, 0))?;
-
- let rb_dtype = rb_candle.define_class("DType", Ruby::class_object(ruby))?;
- rb_dtype.define_method("to_s", method!(DType::__str__, 0))?;
- rb_dtype.define_method("inspect", method!(DType::__repr__, 0))?;
-
- let rb_device = rb_candle.define_class("Device", Ruby::class_object(ruby))?;
- rb_device.define_singleton_method("cpu", function!(Device::cpu, 0))?;
- rb_device.define_singleton_method("cuda", function!(Device::cuda, 0))?;
- rb_device.define_singleton_method("metal", function!(Device::metal, 0))?;
- rb_device.define_singleton_method("available_devices", function!(ruby::device::available_devices, 0))?;
- rb_device.define_singleton_method("default", function!(ruby::device::default_device, 0))?;
- rb_device.define_method("to_s", method!(Device::__str__, 0))?;
- rb_device.define_method("inspect", method!(Device::__repr__, 0))?;
-
- let rb_qtensor = rb_candle.define_class("QTensor", Ruby::class_object(ruby))?;
- rb_qtensor.define_method("ggml_dtype", method!(QTensor::ggml_dtype, 0))?;
- rb_qtensor.define_method("rank", method!(QTensor::rank, 0))?;
- rb_qtensor.define_method("shape", method!(QTensor::shape, 0))?;
- rb_qtensor.define_method("dequantize", method!(QTensor::dequantize, 0))?;

  Ok(())
  }
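
The class registrations deleted above now live in per-type init functions under crate::ruby (dtype, qtensor, device, tensor). Those modules are not shown in this diff; the following is only a hedged sketch of what the relocated tensor registration presumably looks like, reusing the same magnus calls that were removed:

    // Hypothetical shape of crate::ruby::tensor::init; the real module is not
    // part of this diff and is assumed to re-register the methods removed above.
    use magnus::{class, function, method, prelude::*, Error, RModule};

    use crate::ruby::Tensor;

    pub fn init(rb_candle: RModule) -> Result<(), Error> {
        let rb_tensor = rb_candle.define_class("Tensor", class::object())?;
        rb_tensor.define_singleton_method("new", function!(Tensor::new, 3))?;
        rb_tensor.define_method("shape", method!(Tensor::shape, 0))?;
        rb_tensor.define_method("to_s", method!(Tensor::__str__, 0))?;
        // ...plus the rest of the Tensor methods listed in the removed block.
        Ok(())
    }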
@@ -0,0 +1,49 @@
+ use std::time::{SystemTime, UNIX_EPOCH};
+
+ /// Configuration for text generation
+ #[derive(Debug, Clone)]
+ pub struct GenerationConfig {
+     /// The maximum number of tokens to generate
+     pub max_length: usize,
+     /// The temperature for sampling
+     pub temperature: f64,
+     /// The top-p value for nucleus sampling
+     pub top_p: Option<f64>,
+     /// The top-k value for top-k sampling
+     pub top_k: Option<usize>,
+     /// The repetition penalty
+     pub repetition_penalty: f32,
+     /// The repetition penalty range
+     pub repetition_penalty_last_n: usize,
+     /// Random seed for sampling
+     pub seed: u64,
+     /// Stop sequences
+     pub stop_sequences: Vec<String>,
+     /// Whether to return the prompt in the output
+     pub include_prompt: bool,
+ }
+
+ /// Generate a random seed based on current time
+ fn random_seed() -> u64 {
+     SystemTime::now()
+         .duration_since(UNIX_EPOCH)
+         .map(|d| d.as_nanos() as u64)
+         .unwrap_or(42)
+ }
+
+ impl Default for GenerationConfig {
+     fn default() -> Self {
+         Self {
+             max_length: 512,
+             temperature: 0.7,
+             top_p: None,
+             top_k: None,
+             repetition_penalty: 1.1,
+             repetition_penalty_last_n: 64,
+             seed: random_seed(),
+             stop_sequences: vec![],
+             include_prompt: false,
+         }
+     }
+ }
+
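
Since GenerationConfig implements Default (with a time-based random seed), callers only need to override the fields they care about. A small illustrative sketch, not taken from this diff:

    // Override a few fields; everything else, including the random seed,
    // falls back to the Default impl shown above.
    let config = GenerationConfig {
        max_length: 128,
        temperature: 0.9,
        stop_sequences: vec!["\n\n".to_string()],
        ..Default::default()
    };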
@@ -0,0 +1,325 @@
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
+ use candle_nn::VarBuilder;
+ use candle_transformers::models::mistral::{Config, Model as MistralModel};
+ use hf_hub::{api::tokio::Api, Repo};
+ use tokenizers::Tokenizer;
+
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
+
+ #[derive(Debug)]
+ pub struct Mistral {
+     model: MistralModel,
+     tokenizer: TokenizerWrapper,
+     device: Device,
+     model_id: String,
+     eos_token_id: u32,
+ }
+
+ impl Mistral {
+     /// Clear the KV cache between generations
+     pub fn clear_kv_cache(&mut self) {
+         self.model.clear_kv_cache();
+     }
+
+     /// Load a Mistral model from HuggingFace Hub
+     pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
+         let api = Api::new()
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
+
+         let repo = api.repo(Repo::model(model_id.to_string()));
+
+         // Download model files
+         let config_filename = repo
+             .get("config.json")
+             .await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
+
+         let tokenizer_filename = repo
+             .get("tokenizer.json")
+             .await
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
+
+         // Try different file patterns for model weights
+         let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
+             vec![single_file]
+         } else if let Ok(consolidated_file) = repo.get("consolidated.safetensors").await {
+             // Some Mistral models use consolidated.safetensors
+             vec![consolidated_file]
+         } else {
+             // Try to find sharded model files
+             let mut sharded_files = Vec::new();
+             let mut index = 1;
+             loop {
+                 // Try common shard counts
+                 let mut found = false;
+                 for total in [2, 3, 4, 5, 6, 7, 8] {
+                     let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
+                     if let Ok(file) = repo.get(&filename).await {
+                         sharded_files.push(file);
+                         found = true;
+                         break;
+                     }
+                 }
+                 if !found {
+                     break;
+                 }
+                 index += 1;
+             }
+
+             if sharded_files.is_empty() {
+                 // Try single pytorch_model.bin as last resort (though we prefer safetensors)
+                 if let Ok(_pytorch_file) = repo.get("pytorch_model.bin").await {
+                     return Err(candle_core::Error::Msg(
+                         "Only safetensors format is supported. This model uses pytorch_model.bin format.".to_string()
+                     ));
+                 } else {
+                     return Err(candle_core::Error::Msg(
+                         "Could not find model weights. Tried: model.safetensors, consolidated.safetensors, model-*-of-*.safetensors".to_string()
+                     ));
+                 }
+             }
+             sharded_files
+         };
+
+         // Load config
+         let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
+
+         // Load tokenizer
+         let tokenizer = Tokenizer::from_file(tokenizer_filename)
+             .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?;
+
+         let eos_token_id = tokenizer
+             .get_vocab(true)
+             .get("</s>")
+             .copied()
+             .unwrap_or(2);
+
+         // Load model weights
+         let vb = unsafe {
+             VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
+         };
+
+         let model = MistralModel::new(&config, vb)?;
+
+         Ok(Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id: model_id.to_string(),
+             eos_token_id,
+         })
+     }
+
+     /// Create from existing components (useful for testing)
+     pub fn new(
+         model: MistralModel,
+         tokenizer: Tokenizer,
+         device: Device,
+         model_id: String,
+     ) -> Self {
+         let eos_token_id = tokenizer
+             .get_vocab(true)
+             .get("</s>")
+             .copied()
+             .unwrap_or(2);
+
+         Self {
+             model,
+             tokenizer: TokenizerWrapper::new(tokenizer),
+             device,
+             model_id,
+             eos_token_id,
+         }
+     }
+
+     fn generate_tokens(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             // Ensure input tensor is contiguous for Metal backend
+             let input = input.contiguous()?;
+             let logits = self.model.forward(&input, start_pos)?;
+
+             // The model returns logits of shape [batch_size, seq_len, vocab_size]
+             // We need to get the logits for the last token only
+             let logits = logits.squeeze(0)?; // Remove batch dimension
+             let logits = if logits.dims().len() == 2 {
+                 // If we still have [seq_len, vocab_size], take the last token
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 // Already [vocab_size]
+                 logits
+             };
+
+             // Convert to F32 for sampling if needed
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback
+             if let Some(ref mut cb) = callback {
+                 let token_text = self.tokenizer.token_to_piece(next_token)?;
+                 cb(&token_text);
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+
+     fn generate_tokens_decoded(
+         &mut self,
+         prompt_tokens: Vec<u32>,
+         config: &GenerationConfig,
+         mut callback: Option<impl FnMut(&str)>,
+     ) -> CandleResult<Vec<u32>> {
+         let mut text_gen = TextGeneration::from_config(config);
+         text_gen.set_eos_token_id(self.eos_token_id);
+         text_gen.set_tokens(prompt_tokens.clone());
+
+         let mut all_tokens = prompt_tokens.clone();
+         let start_gen = all_tokens.len();
+
+         // For incremental decoding
+         let mut previously_decoded = String::new();
+
+         for index in 0..config.max_length {
+             let context_size = if index > 0 { 1 } else { all_tokens.len() };
+             let start_pos = all_tokens.len().saturating_sub(context_size);
+             let ctxt = &all_tokens[start_pos..];
+
+             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
+             // Ensure input tensor is contiguous for Metal backend
+             let input = input.contiguous()?;
+             let logits = self.model.forward(&input, start_pos)?;
+
+             // The model returns logits of shape [batch_size, seq_len, vocab_size]
+             // We need to get the logits for the last token only
+             let logits = logits.squeeze(0)?; // Remove batch dimension
+             let logits = if logits.dims().len() == 2 {
+                 // If we still have [seq_len, vocab_size], take the last token
+                 let seq_len = logits.dim(0)?;
+                 logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
+             } else {
+                 // Already [vocab_size]
+                 logits
+             };
+
+             // Convert to F32 for sampling if needed
+             let logits = logits.to_dtype(DType::F32)?;
+
+             let next_token = text_gen.sample_next_token(
+                 &logits,
+                 Some((config.repetition_penalty, config.repetition_penalty_last_n)),
+             )?;
+
+             all_tokens.push(next_token);
+
+             // Stream callback with incremental decoding
+             if let Some(ref mut cb) = callback {
+                 // Decode all generated tokens so far
+                 let current_decoded = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
+
+                 // Only emit the new text since last callback
+                 if current_decoded.len() > previously_decoded.len() {
+                     let new_text = &current_decoded[previously_decoded.len()..];
+                     cb(new_text);
+                     previously_decoded = current_decoded;
+                 }
+             }
+
+             // Check stop conditions
+             if text_gen.should_stop(next_token, config.max_length) {
+                 break;
+             }
+
+             // Check stop sequences
+             let generated_text = if callback.is_some() {
+                 previously_decoded.clone()
+             } else {
+                 self.tokenizer.decode(&all_tokens[start_gen..], true)?
+             };
+
+             if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
+                 break;
+             }
+         }
+
+         Ok(if config.include_prompt {
+             all_tokens
+         } else {
+             all_tokens[start_gen..].to_vec()
+         })
+     }
+ }
+
+ impl TextGenerator for Mistral {
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         mut callback: impl FnMut(&str),
+     ) -> CandleResult<String> {
+         let prompt_tokens = self.tokenizer.encode(prompt, true)?;
+         let output_tokens = self.generate_tokens_decoded(prompt_tokens, config, Some(&mut callback))?;
+         self.tokenizer.decode(&output_tokens, true)
+     }
+
+     fn model_name(&self) -> &str {
+         &self.model_id
+     }
+
+     fn device(&self) -> &Device {
+         &self.device
+     }
+
+     fn clear_cache(&mut self) {
+         self.clear_kv_cache();
+     }
+ }
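
For orientation, a hedged sketch of how this struct is driven through the TextGenerator trait. The model id is only an example, and from_pretrained needs an async (tokio) context plus access to the Hub:

    // Hypothetical caller, not part of this diff.
    let mut model = Mistral::from_pretrained("mistralai/Mistral-7B-v0.1", Device::Cpu).await?;
    let config = GenerationConfig { max_length: 64, ..Default::default() };
    let text = model.generate_stream("Tell me a story", &config, |piece| print!("{piece}"))?;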
@@ -0,0 +1,68 @@
+ use candle_core::{Device, Result as CandleResult};
+ use tokenizers::Tokenizer;
+
+ pub mod mistral;
+ pub mod generation_config;
+ pub mod text_generation;
+
+ pub use generation_config::GenerationConfig;
+ pub use text_generation::TextGeneration;
+
+ /// Trait for text generation models
+ pub trait TextGenerator: Send + Sync {
+     /// Generate text from a prompt
+     fn generate(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+     ) -> CandleResult<String>;
+
+     /// Generate text with streaming callback
+     fn generate_stream(
+         &mut self,
+         prompt: &str,
+         config: &GenerationConfig,
+         callback: impl FnMut(&str),
+     ) -> CandleResult<String>;
+
+     /// Get the model's name
+     fn model_name(&self) -> &str;
+
+     /// Get the device the model is running on
+     fn device(&self) -> &Device;
+
+     /// Clear any cached state (like KV cache)
+     fn clear_cache(&mut self);
+ }
+
+ /// Common structure for managing tokenizer
+ #[derive(Debug)]
+ pub struct TokenizerWrapper {
+     tokenizer: Tokenizer,
+ }
+
+ impl TokenizerWrapper {
+     pub fn new(tokenizer: Tokenizer) -> Self {
+         Self { tokenizer }
+     }
+
+     pub fn encode(&self, text: &str, add_special_tokens: bool) -> CandleResult<Vec<u32>> {
+         let encoding = self.tokenizer
+             .encode(text, add_special_tokens)
+             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer error: {}", e)))?;
+         Ok(encoding.get_ids().to_vec())
+     }
+
+     pub fn decode(&self, tokens: &[u32], skip_special_tokens: bool) -> CandleResult<String> {
+         self.tokenizer
+             .decode(tokens, skip_special_tokens)
+             .map_err(|e| candle_core::Error::Msg(format!("Tokenizer decode error: {}", e)))
+     }
+
+     pub fn token_to_piece(&self, token: u32) -> CandleResult<String> {
+         self.tokenizer
+             .id_to_token(token)
+             .map(|s| s.to_string())
+             .ok_or_else(|| candle_core::Error::Msg(format!("Unknown token id: {}", token)))
+     }
+ }
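
TokenizerWrapper is a thin adapter that maps tokenizers errors into candle_core::Error, so the generation code above can stay on CandleResult throughout. A minimal illustrative round trip (not part of the diff):

    // Encode a prompt to ids and decode it back; failures surface as
    // candle_core::Error through CandleResult.
    let wrapper = TokenizerWrapper::new(tokenizer); // `tokenizer` loaded elsewhere
    let ids = wrapper.encode("Hello, world!", true)?;
    let text = wrapper.decode(&ids, true)?;
    assert!(text.contains("Hello"));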