red-candle 1.8.0-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,308 @@
1
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
2
+ use candle_nn::VarBuilder;
3
+ use candle_transformers::models::granite::{
4
+ Cache, Config, Granite as GraniteModel, GraniteConfig,
5
+ };
6
+ use hf_hub::{api::tokio::Api, Repo};
7
+ use tokenizers::Tokenizer;
8
+
9
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
10
+
11
+ #[derive(Debug)]
12
+ pub struct Granite {
13
+ model: GraniteModel,
14
+ tokenizer: TokenizerWrapper,
15
+ device: Device,
16
+ model_id: String,
17
+ eos_token_id: u32,
18
+ cache: Cache,
19
+ config: Config,
20
+ }
21
+
22
+ impl Granite {
23
+ pub fn eos_token_id(&self) -> u32 {
24
+ self.eos_token_id
25
+ }
26
+
27
+ pub fn clear_kv_cache(&mut self) {
28
+ if let Ok(new_cache) =
29
+ Cache::new(self.cache.use_kv_cache, DType::F32, &self.config, &self.device)
30
+ {
31
+ self.cache = new_cache;
32
+ }
33
+ }
34
+
35
+ pub fn tokenizer(&self) -> &TokenizerWrapper {
36
+ &self.tokenizer
37
+ }
38
+
39
+ pub async fn from_pretrained_with_tokenizer(
40
+ model_id: &str,
41
+ device: Device,
42
+ tokenizer_source: Option<&str>,
43
+ ) -> CandleResult<Self> {
44
+ let api = Api::new()
45
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
46
+
47
+ let repo = api.repo(Repo::model(model_id.to_string()));
48
+
49
+ let config_filename = repo
50
+ .get("config.json")
51
+ .await
52
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
53
+
54
+ let config_str = std::fs::read_to_string(config_filename)?;
55
+ let granite_config: GraniteConfig = serde_json::from_str(&config_str)
56
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
57
+ let config = granite_config.into_config(false);
58
+
59
+ let config_json: serde_json::Value = serde_json::from_str(&config_str)
60
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
61
+ let tie_word_embeddings = config_json
62
+ .get("tie_word_embeddings")
63
+ .and_then(|v| v.as_bool())
64
+ .unwrap_or(false);
65
+
66
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
67
+ let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
68
+ let tokenizer_filename = tokenizer_repo
69
+ .get("tokenizer.json")
70
+ .await
71
+ .map_err(|e| {
72
+ candle_core::Error::Msg(format!(
73
+ "Failed to download tokenizer from {}: {}",
74
+ tokenizer_id, e
75
+ ))
76
+ })?;
77
+ Tokenizer::from_file(tokenizer_filename)
78
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
79
+ } else {
80
+ let tokenizer_filename = repo.get("tokenizer.json").await.map_err(|e| {
81
+ candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e))
82
+ })?;
83
+ Tokenizer::from_file(tokenizer_filename)
84
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
85
+ };
86
+
87
+ let vocab = tokenizer.get_vocab(true);
88
+ let eos_token_id = vocab
89
+ .get("<|end_of_text|>")
90
+ .or_else(|| vocab.get("<|endoftext|>"))
91
+ .or_else(|| vocab.get("</s>"))
92
+ .copied()
93
+ .unwrap_or(0);
94
+
95
+ let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
96
+ vec![single_file]
97
+ } else {
98
+ let mut sharded_files = Vec::new();
99
+ let mut index = 1;
100
+ loop {
101
+ let mut found = false;
102
+ for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
103
+ let filename =
104
+ format!("model-{:05}-of-{:05}.safetensors", index, total);
105
+ if let Ok(file) = repo.get(&filename).await {
106
+ sharded_files.push(file);
107
+ found = true;
108
+ break;
109
+ }
110
+ }
111
+ if !found {
112
+ break;
113
+ }
114
+ index += 1;
115
+ }
116
+
117
+ if sharded_files.is_empty() {
118
+ return Err(candle_core::Error::Msg(
119
+ "Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string(),
120
+ ));
121
+ }
122
+ sharded_files
123
+ };
124
+
125
+ let vb = unsafe {
126
+ VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
127
+ };
128
+
129
+ let vb = if tie_word_embeddings {
130
+ vb.rename_f(|name: &str| {
131
+ if name == "lm_head.weight" {
132
+ "model.embed_tokens.weight".to_string()
133
+ } else {
134
+ name.to_string()
135
+ }
136
+ })
137
+ } else {
138
+ vb
139
+ };
140
+
141
+ let model = GraniteModel::load(vb, &config)?;
142
+ let cache = Cache::new(true, DType::F32, &config, &device)?;
143
+
144
+ Ok(Self {
145
+ model,
146
+ tokenizer: TokenizerWrapper::new(tokenizer),
147
+ device,
148
+ model_id: model_id.to_string(),
149
+ eos_token_id,
150
+ cache,
151
+ config,
152
+ })
153
+ }
154
+
155
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
156
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
157
+ }
158
+
159
+ pub fn apply_chat_template(
160
+ &self,
161
+ messages: &[serde_json::Value],
162
+ ) -> CandleResult<String> {
163
+ let mut prompt = String::new();
164
+
165
+ for message in messages {
166
+ let role = message["role"].as_str().unwrap_or("");
167
+ let content = message["content"].as_str().unwrap_or("");
168
+
169
+ match role {
170
+ "system" => {
171
+ prompt.push_str(&format!(
172
+ "<|start_of_role|>system<|end_of_role|>{}<|end_of_text|>\n",
173
+ content
174
+ ));
175
+ }
176
+ "user" => {
177
+ prompt.push_str(&format!(
178
+ "<|start_of_role|>user<|end_of_role|>{}<|end_of_text|>\n",
179
+ content
180
+ ));
181
+ }
182
+ "assistant" => {
183
+ prompt.push_str(&format!(
184
+ "<|start_of_role|>assistant<|end_of_role|>{}<|end_of_text|>\n",
185
+ content
186
+ ));
187
+ }
188
+ _ => {}
189
+ }
190
+ }
191
+
192
+ prompt.push_str("<|start_of_role|>assistant<|end_of_role|>");
193
+
194
+ Ok(prompt)
195
+ }
196
+
197
+ fn generate_tokens(
198
+ &mut self,
199
+ prompt_tokens: Vec<u32>,
200
+ config: &GenerationConfig,
201
+ mut callback: Option<impl FnMut(&str)>,
202
+ ) -> CandleResult<Vec<u32>> {
203
+ let mut text_gen = TextGeneration::new(config);
204
+ text_gen.set_eos_token_id(self.eos_token_id);
205
+ text_gen.set_tokens(prompt_tokens.clone());
206
+
207
+ let mut all_tokens = prompt_tokens.clone();
208
+ let start_gen = all_tokens.len();
209
+
210
+ for index in 0..config.max_length {
211
+ let context_size = if index > 0 { 1 } else { all_tokens.len() };
212
+ let start_pos = all_tokens.len().saturating_sub(context_size);
213
+ let ctxt = &all_tokens[start_pos..];
214
+
215
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
216
+ let input = input.contiguous()?;
217
+ let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
218
+
219
+ let logits = logits.squeeze(0)?;
220
+ let logits = if logits.dims().len() == 2 {
221
+ let seq_len = logits.dim(0)?;
222
+ logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
223
+ } else {
224
+ logits
225
+ };
226
+
227
+ let logits = logits.to_dtype(DType::F32)?;
228
+
229
+ let next_token = text_gen.sample_next_token(&logits)?;
230
+
231
+ all_tokens.push(next_token);
232
+
233
+ if let Some(ref mut cb) = callback {
234
+ if config.debug_tokens {
235
+ let token_piece = self.tokenizer.token_to_piece(next_token)?;
236
+ cb(&format!("[{}:{}]", next_token, token_piece));
237
+ } else {
238
+ let decoded_text =
239
+ self.tokenizer
240
+ .decode_incremental(&all_tokens, all_tokens.len() - 1)?;
241
+ cb(&decoded_text);
242
+ }
243
+ }
244
+
245
+ if text_gen.should_stop(next_token, config.max_length) {
246
+ break;
247
+ }
248
+
249
+ if config.stop_on_constraint_satisfaction {
250
+ let satisfied = if config.stop_on_match {
251
+ text_gen.is_constraint_satisfied_stop_on_match()
252
+ } else {
253
+ text_gen.is_constraint_satisfied()
254
+ };
255
+ if satisfied {
256
+ break;
257
+ }
258
+ }
259
+
260
+ let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
261
+ if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
262
+ break;
263
+ }
264
+ }
265
+
266
+ Ok(if config.include_prompt {
267
+ all_tokens
268
+ } else {
269
+ all_tokens[start_gen..].to_vec()
270
+ })
271
+ }
272
+ }
273
+
274
+ impl TextGenerator for Granite {
275
+ fn generate(&mut self, prompt: &str, config: &GenerationConfig) -> CandleResult<String> {
276
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
277
+ let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
278
+
279
+ if config.debug_tokens {
280
+ self.tokenizer.format_tokens_with_debug(&output_tokens)
281
+ } else {
282
+ self.tokenizer.decode(&output_tokens, true)
283
+ }
284
+ }
285
+
286
+ fn generate_stream(
287
+ &mut self,
288
+ prompt: &str,
289
+ config: &GenerationConfig,
290
+ mut callback: impl FnMut(&str),
291
+ ) -> CandleResult<String> {
292
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
293
+ let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
294
+ self.tokenizer.decode(&output_tokens, true)
295
+ }
296
+
297
+ fn model_name(&self) -> &str {
298
+ &self.model_id
299
+ }
300
+
301
+ fn device(&self) -> &Device {
302
+ &self.device
303
+ }
304
+
305
+ fn clear_cache(&mut self) {
306
+ self.clear_kv_cache();
307
+ }
308
+ }
@@ -0,0 +1,315 @@
1
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
2
+ use candle_nn::VarBuilder;
3
+ use candle_transformers::models::granitemoehybrid::{
4
+ GraniteMoeHybrid as GraniteMoeHybridModel, GraniteMoeHybridCache,
5
+ GraniteMoeHybridConfig, GraniteMoeHybridInternalConfig,
6
+ };
7
+ use hf_hub::{api::tokio::Api, Repo};
8
+ use tokenizers::Tokenizer;
9
+
10
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
11
+
12
+ #[derive(Debug)]
13
+ pub struct GraniteMoeHybrid {
14
+ model: GraniteMoeHybridModel,
15
+ tokenizer: TokenizerWrapper,
16
+ device: Device,
17
+ model_id: String,
18
+ eos_token_id: u32,
19
+ cache: GraniteMoeHybridCache,
20
+ config: GraniteMoeHybridInternalConfig,
21
+ }
22
+
23
+ impl GraniteMoeHybrid {
24
+ pub fn eos_token_id(&self) -> u32 {
25
+ self.eos_token_id
26
+ }
27
+
28
+ pub fn clear_kv_cache(&mut self) {
29
+ if let Ok(new_cache) =
30
+ GraniteMoeHybridCache::new(self.cache.use_kv_cache, DType::F32, &self.config, &self.device)
31
+ {
32
+ self.cache = new_cache;
33
+ }
34
+ }
35
+
36
+ pub fn tokenizer(&self) -> &TokenizerWrapper {
37
+ &self.tokenizer
38
+ }
39
+
40
+ pub async fn from_pretrained_with_tokenizer(
41
+ model_id: &str,
42
+ device: Device,
43
+ tokenizer_source: Option<&str>,
44
+ ) -> CandleResult<Self> {
45
+ let api = Api::new()
46
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
47
+
48
+ let repo = api.repo(Repo::model(model_id.to_string()));
49
+
50
+ let config_filename = repo
51
+ .get("config.json")
52
+ .await
53
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
54
+
55
+ let config_str = std::fs::read_to_string(config_filename)?;
56
+ let granite_config: GraniteMoeHybridConfig = serde_json::from_str(&config_str)
57
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
58
+ let config = granite_config.into_config(false);
59
+
60
+ let config_json: serde_json::Value = serde_json::from_str(&config_str)
61
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config JSON: {}", e)))?;
62
+ let tie_word_embeddings = config_json
63
+ .get("tie_word_embeddings")
64
+ .and_then(|v| v.as_bool())
65
+ .unwrap_or(false);
66
+
67
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
68
+ let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
69
+ let tokenizer_filename = tokenizer_repo
70
+ .get("tokenizer.json")
71
+ .await
72
+ .map_err(|e| {
73
+ candle_core::Error::Msg(format!(
74
+ "Failed to download tokenizer from {}: {}",
75
+ tokenizer_id, e
76
+ ))
77
+ })?;
78
+ Tokenizer::from_file(tokenizer_filename)
79
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
80
+ } else {
81
+ let tokenizer_filename = repo.get("tokenizer.json").await.map_err(|e| {
82
+ candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e))
83
+ })?;
84
+ Tokenizer::from_file(tokenizer_filename)
85
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
86
+ };
87
+
88
+ let vocab = tokenizer.get_vocab(true);
89
+ let eos_token_id = vocab
90
+ .get("<|end_of_text|>")
91
+ .or_else(|| vocab.get("<|endoftext|>"))
92
+ .or_else(|| vocab.get("</s>"))
93
+ .copied()
94
+ .unwrap_or(0);
95
+
96
+ let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
97
+ vec![single_file]
98
+ } else {
99
+ let mut sharded_files = Vec::new();
100
+ let mut index = 1;
101
+ loop {
102
+ let mut found = false;
103
+ for total in [2, 3, 4, 5, 6, 7, 8, 10, 15, 20, 30] {
104
+ let filename =
105
+ format!("model-{:05}-of-{:05}.safetensors", index, total);
106
+ if let Ok(file) = repo.get(&filename).await {
107
+ sharded_files.push(file);
108
+ found = true;
109
+ break;
110
+ }
111
+ }
112
+ if !found {
113
+ break;
114
+ }
115
+ index += 1;
116
+ }
117
+
118
+ if sharded_files.is_empty() {
119
+ return Err(candle_core::Error::Msg(
120
+ "Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string(),
121
+ ));
122
+ }
123
+ sharded_files
124
+ };
125
+
126
+ let vb = unsafe {
127
+ VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
128
+ };
129
+
130
+ let vb = if tie_word_embeddings {
131
+ vb.rename_f(|name: &str| {
132
+ if name == "lm_head.weight" {
133
+ "model.embed_tokens.weight".to_string()
134
+ } else {
135
+ name.to_string()
136
+ }
137
+ })
138
+ } else {
139
+ vb
140
+ };
141
+
142
+ let model = GraniteMoeHybridModel::load(vb, &config)?;
143
+ let cache = GraniteMoeHybridCache::new(true, DType::F32, &config, &device)?;
144
+
145
+ Ok(Self {
146
+ model,
147
+ tokenizer: TokenizerWrapper::new(tokenizer),
148
+ device,
149
+ model_id: model_id.to_string(),
150
+ eos_token_id,
151
+ cache,
152
+ config,
153
+ })
154
+ }
155
+
156
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
157
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
158
+ }
159
+
160
+ pub fn apply_chat_template(
161
+ &self,
162
+ messages: &[serde_json::Value],
163
+ ) -> CandleResult<String> {
164
+ let mut prompt = String::new();
165
+
166
+ for message in messages {
167
+ let role = message["role"].as_str().unwrap_or("");
168
+ let content = message["content"].as_str().unwrap_or("");
169
+
170
+ match role {
171
+ "system" => {
172
+ prompt.push_str(&format!(
173
+ "<|start_of_role|>system<|end_of_role|>{}<|end_of_text|>\n",
174
+ content
175
+ ));
176
+ }
177
+ "user" => {
178
+ prompt.push_str(&format!(
179
+ "<|start_of_role|>user<|end_of_role|>{}<|end_of_text|>\n",
180
+ content
181
+ ));
182
+ }
183
+ "assistant" => {
184
+ prompt.push_str(&format!(
185
+ "<|start_of_role|>assistant<|end_of_role|>{}<|end_of_text|>\n",
186
+ content
187
+ ));
188
+ }
189
+ "tool" => {
190
+ prompt.push_str(&format!(
191
+ "<|start_of_role|>tool<|end_of_role|>{}<|end_of_text|>\n",
192
+ content
193
+ ));
194
+ }
195
+ _ => {}
196
+ }
197
+ }
198
+
199
+ prompt.push_str("<|start_of_role|>assistant<|end_of_role|>");
200
+
201
+ Ok(prompt)
202
+ }
203
+
204
+ fn generate_tokens(
205
+ &mut self,
206
+ prompt_tokens: Vec<u32>,
207
+ config: &GenerationConfig,
208
+ mut callback: Option<impl FnMut(&str)>,
209
+ ) -> CandleResult<Vec<u32>> {
210
+ let mut text_gen = TextGeneration::new(config);
211
+ text_gen.set_eos_token_id(self.eos_token_id);
212
+ text_gen.set_tokens(prompt_tokens.clone());
213
+
214
+ let mut all_tokens = prompt_tokens.clone();
215
+ let start_gen = all_tokens.len();
216
+
217
+ for index in 0..config.max_length {
218
+ let context_size = if index > 0 { 1 } else { all_tokens.len() };
219
+ let start_pos = all_tokens.len().saturating_sub(context_size);
220
+ let ctxt = &all_tokens[start_pos..];
221
+
222
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
223
+ let input = input.contiguous()?;
224
+ let logits = self.model.forward(&input, start_pos, &mut self.cache)?;
225
+
226
+ let logits = logits.squeeze(0)?;
227
+ let logits = if logits.dims().len() == 2 {
228
+ let seq_len = logits.dim(0)?;
229
+ logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
230
+ } else {
231
+ logits
232
+ };
233
+
234
+ let logits = logits.to_dtype(DType::F32)?;
235
+
236
+ let next_token = text_gen.sample_next_token(&logits)?;
237
+
238
+ all_tokens.push(next_token);
239
+
240
+ if let Some(ref mut cb) = callback {
241
+ if config.debug_tokens {
242
+ let token_piece = self.tokenizer.token_to_piece(next_token)?;
243
+ cb(&format!("[{}:{}]", next_token, token_piece));
244
+ } else {
245
+ let decoded_text =
246
+ self.tokenizer
247
+ .decode_incremental(&all_tokens, all_tokens.len() - 1)?;
248
+ cb(&decoded_text);
249
+ }
250
+ }
251
+
252
+ if text_gen.should_stop(next_token, config.max_length) {
253
+ break;
254
+ }
255
+
256
+ if config.stop_on_constraint_satisfaction {
257
+ let satisfied = if config.stop_on_match {
258
+ text_gen.is_constraint_satisfied_stop_on_match()
259
+ } else {
260
+ text_gen.is_constraint_satisfied()
261
+ };
262
+ if satisfied {
263
+ break;
264
+ }
265
+ }
266
+
267
+ let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
268
+ if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
269
+ break;
270
+ }
271
+ }
272
+
273
+ Ok(if config.include_prompt {
274
+ all_tokens
275
+ } else {
276
+ all_tokens[start_gen..].to_vec()
277
+ })
278
+ }
279
+ }
280
+
281
+ impl TextGenerator for GraniteMoeHybrid {
282
+ fn generate(&mut self, prompt: &str, config: &GenerationConfig) -> CandleResult<String> {
283
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
284
+ let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
285
+
286
+ if config.debug_tokens {
287
+ self.tokenizer.format_tokens_with_debug(&output_tokens)
288
+ } else {
289
+ self.tokenizer.decode(&output_tokens, true)
290
+ }
291
+ }
292
+
293
+ fn generate_stream(
294
+ &mut self,
295
+ prompt: &str,
296
+ config: &GenerationConfig,
297
+ mut callback: impl FnMut(&str),
298
+ ) -> CandleResult<String> {
299
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
300
+ let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
301
+ self.tokenizer.decode(&output_tokens, true)
302
+ }
303
+
304
+ fn model_name(&self) -> &str {
305
+ &self.model_id
306
+ }
307
+
308
+ fn device(&self) -> &Device {
309
+ &self.device
310
+ }
311
+
312
+ fn clear_cache(&mut self) {
313
+ self.clear_kv_cache();
314
+ }
315
+ }