red-candle 1.8.0.pre3-aarch64-linux

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. checksums.yaml +7 -0
  2. data/Cargo.lock +5021 -0
  3. data/Cargo.toml +6 -0
  4. data/Gemfile +3 -0
  5. data/LICENSE +22 -0
  6. data/README.md +1171 -0
  7. data/Rakefile +167 -0
  8. data/bin/console +11 -0
  9. data/bin/setup +17 -0
  10. data/ext/candle/Cargo.toml +38 -0
  11. data/ext/candle/build.rs +117 -0
  12. data/ext/candle/extconf.rb +79 -0
  13. data/ext/candle/rustfmt.toml +63 -0
  14. data/ext/candle/src/gvl.rs +58 -0
  15. data/ext/candle/src/lib.rs +59 -0
  16. data/ext/candle/src/llm/constrained_generation_test.rs +395 -0
  17. data/ext/candle/src/llm/gemma.rs +313 -0
  18. data/ext/candle/src/llm/generation_config.rs +63 -0
  19. data/ext/candle/src/llm/glm4.rs +236 -0
  20. data/ext/candle/src/llm/granite.rs +308 -0
  21. data/ext/candle/src/llm/granitemoehybrid.rs +315 -0
  22. data/ext/candle/src/llm/llama.rs +396 -0
  23. data/ext/candle/src/llm/mistral.rs +309 -0
  24. data/ext/candle/src/llm/mod.rs +49 -0
  25. data/ext/candle/src/llm/phi.rs +369 -0
  26. data/ext/candle/src/llm/quantized_gguf.rs +734 -0
  27. data/ext/candle/src/llm/qwen.rs +261 -0
  28. data/ext/candle/src/llm/qwen3.rs +257 -0
  29. data/ext/candle/src/llm/text_generation.rs +284 -0
  30. data/ext/candle/src/ruby/device.rs +234 -0
  31. data/ext/candle/src/ruby/dtype.rs +39 -0
  32. data/ext/candle/src/ruby/embedding_model.rs +477 -0
  33. data/ext/candle/src/ruby/errors.rs +16 -0
  34. data/ext/candle/src/ruby/llm.rs +730 -0
  35. data/ext/candle/src/ruby/mod.rs +24 -0
  36. data/ext/candle/src/ruby/ner.rs +444 -0
  37. data/ext/candle/src/ruby/reranker.rs +488 -0
  38. data/ext/candle/src/ruby/result.rs +3 -0
  39. data/ext/candle/src/ruby/structured.rs +92 -0
  40. data/ext/candle/src/ruby/tensor.rs +731 -0
  41. data/ext/candle/src/ruby/tokenizer.rs +343 -0
  42. data/ext/candle/src/ruby/utils.rs +96 -0
  43. data/ext/candle/src/ruby/vlm.rs +330 -0
  44. data/ext/candle/src/structured/integration_test.rs +130 -0
  45. data/ext/candle/src/structured/mod.rs +31 -0
  46. data/ext/candle/src/structured/schema_processor.rs +215 -0
  47. data/ext/candle/src/structured/vocabulary_adapter.rs +152 -0
  48. data/ext/candle/src/structured/vocabulary_adapter_real_test.rs +66 -0
  49. data/ext/candle/src/structured/vocabulary_adapter_simple_test.rs +70 -0
  50. data/ext/candle/src/tokenizer/loader.rs +108 -0
  51. data/ext/candle/src/tokenizer/mod.rs +104 -0
  52. data/ext/candle/tests/device_tests.rs +43 -0
  53. data/ext/candle/tests/tensor_tests.rs +162 -0
  54. data/lib/candle/3.1/candle.so +0 -0
  55. data/lib/candle/3.2/candle.so +0 -0
  56. data/lib/candle/3.3/candle.so +0 -0
  57. data/lib/candle/3.4/candle.so +0 -0
  58. data/lib/candle/4.0/candle.so +0 -0
  59. data/lib/candle/agent.rb +68 -0
  60. data/lib/candle/build_info.rb +67 -0
  61. data/lib/candle/device_utils.rb +10 -0
  62. data/lib/candle/embedding_model.rb +75 -0
  63. data/lib/candle/embedding_model_type.rb +31 -0
  64. data/lib/candle/llm.rb +595 -0
  65. data/lib/candle/logger.rb +149 -0
  66. data/lib/candle/ner.rb +368 -0
  67. data/lib/candle/reranker.rb +45 -0
  68. data/lib/candle/tensor.rb +99 -0
  69. data/lib/candle/tokenizer.rb +139 -0
  70. data/lib/candle/tool.rb +47 -0
  71. data/lib/candle/tool_call_parser.rb +57 -0
  72. data/lib/candle/version.rb +5 -0
  73. data/lib/candle/vlm.rb +31 -0
  74. data/lib/candle.rb +29 -0
  75. data/lib/red-candle.rb +1 -0
  76. metadata +309 -0
@@ -0,0 +1,313 @@
1
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
2
+ use candle_nn::VarBuilder;
3
+ use candle_transformers::models::gemma::{Config, Model as GemmaModel};
4
+ use hf_hub::{api::tokio::Api, Repo};
5
+ use tokenizers::Tokenizer;
6
+
7
+ use super::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
8
+
9
/// A Gemma causal language model bundled with its tokenizer, the device it
/// runs on, and the token id used to stop generation.
#[derive(Debug)]
pub struct Gemma {
    /// Underlying candle Gemma model; owns the KV cache reset by `clear_kv_cache`.
    model: GemmaModel,
    /// Tokenizer used to encode prompts and decode generated tokens.
    tokenizer: TokenizerWrapper,
    /// Device the input tensors are created on.
    device: Device,
    /// Hub model id (or caller-supplied name), reported by `model_name`.
    model_id: String,
    /// Token id handed to `TextGeneration::set_eos_token_id`; looked up from the
    /// tokenizer vocabulary (`<eos>`, then `<end_of_turn>`, else 1) at construction.
    eos_token_id: u32,
}
17
+
18
+ impl Gemma {
19
+ pub fn eos_token_id(&self) -> u32 {
20
+ self.eos_token_id
21
+ }
22
+
23
+ /// Clear the KV cache between generations
24
+ pub fn clear_kv_cache(&mut self) {
25
+ self.model.clear_kv_cache();
26
+ }
27
+
28
+ /// Get the tokenizer
29
+ pub fn tokenizer(&self) -> &TokenizerWrapper {
30
+ &self.tokenizer
31
+ }
32
+
33
+ /// Load a Gemma model from HuggingFace Hub with optional custom tokenizer
34
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
35
+ let api = Api::new()
36
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
37
+
38
+ let repo = api.repo(Repo::model(model_id.to_string()));
39
+
40
+ // Download model files
41
+ let config_filename = repo
42
+ .get("config.json")
43
+ .await
44
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
45
+
46
+ // Download tokenizer from custom source if provided, otherwise from model repo
47
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
48
+ let tokenizer_repo = api.repo(Repo::model(tokenizer_id.to_string()));
49
+ let tokenizer_filename = tokenizer_repo
50
+ .get("tokenizer.json")
51
+ .await
52
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
53
+ Tokenizer::from_file(tokenizer_filename)
54
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
55
+ } else {
56
+ let tokenizer_filename = repo
57
+ .get("tokenizer.json")
58
+ .await
59
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
60
+ Tokenizer::from_file(tokenizer_filename)
61
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
62
+ };
63
+
64
+ // Try different file patterns for model weights
65
+ let weights_filenames = if let Ok(single_file) = repo.get("model.safetensors").await {
66
+ vec![single_file]
67
+ } else {
68
+ // Try to find sharded model files
69
+ // NOTE: This uses a brute-force approach, trying common shard counts.
70
+ // A better approach would be to read model.safetensors.index.json which
71
+ // contains the exact file list, but this works for most models (≤12 shards).
72
+ let mut sharded_files = Vec::new();
73
+ let mut index = 1;
74
+ loop {
75
+ // Try common shard counts for Gemma models
76
+ let mut found = false;
77
+ for total in [2, 3, 4, 5, 6, 7, 8, 10, 12] {
78
+ let filename = format!("model-{:05}-of-{:05}.safetensors", index, total);
79
+ if let Ok(file) = repo.get(&filename).await {
80
+ sharded_files.push(file);
81
+ found = true;
82
+ break;
83
+ }
84
+ }
85
+ if !found {
86
+ break;
87
+ }
88
+ index += 1;
89
+ }
90
+
91
+ if sharded_files.is_empty() {
92
+ return Err(candle_core::Error::Msg(
93
+ "Could not find model weights. Tried: model.safetensors, model-*-of-*.safetensors".to_string()
94
+ ));
95
+ }
96
+ sharded_files
97
+ };
98
+
99
+ // Load config
100
+ let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)
101
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
102
+
103
+
104
+ // Gemma uses specific tokens
105
+ let eos_token_id = {
106
+ let vocab = tokenizer.get_vocab(true);
107
+ vocab.get("<eos>")
108
+ .or_else(|| vocab.get("<end_of_turn>"))
109
+ .copied()
110
+ .unwrap_or(1) // Default Gemma EOS
111
+ };
112
+
113
+ // Load model weights
114
+ let vb = unsafe {
115
+ VarBuilder::from_mmaped_safetensors(&weights_filenames, DType::F32, &device)?
116
+ };
117
+
118
+ let model = GemmaModel::new(false, &config, vb)?; // Don't use flash attention for now
119
+
120
+ Ok(Self {
121
+ model,
122
+ tokenizer: TokenizerWrapper::new(tokenizer),
123
+ device,
124
+ model_id: model_id.to_string(),
125
+ eos_token_id,
126
+ })
127
+ }
128
+
129
+ /// Load a Gemma model from HuggingFace Hub (backwards compatibility)
130
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
131
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
132
+ }
133
+
134
+ /// Create from existing components (useful for testing)
135
+ pub fn new(
136
+ model: GemmaModel,
137
+ tokenizer: Tokenizer,
138
+ device: Device,
139
+ model_id: String,
140
+ ) -> Self {
141
+ let eos_token_id = {
142
+ let vocab = tokenizer.get_vocab(true);
143
+ vocab.get("<eos>")
144
+ .or_else(|| vocab.get("<end_of_turn>"))
145
+ .copied()
146
+ .unwrap_or(1)
147
+ };
148
+
149
+ Self {
150
+ model,
151
+ tokenizer: TokenizerWrapper::new(tokenizer),
152
+ device,
153
+ model_id,
154
+ eos_token_id,
155
+ }
156
+ }
157
+
158
+ fn generate_tokens(
159
+ &mut self,
160
+ prompt_tokens: Vec<u32>,
161
+ config: &GenerationConfig,
162
+ mut callback: Option<impl FnMut(&str)>,
163
+ ) -> CandleResult<Vec<u32>> {
164
+ let mut text_gen = TextGeneration::new(config);
165
+ text_gen.set_eos_token_id(self.eos_token_id);
166
+ text_gen.set_tokens(prompt_tokens.clone());
167
+
168
+ let mut all_tokens = prompt_tokens.clone();
169
+ let start_gen = all_tokens.len();
170
+
171
+ for index in 0..config.max_length {
172
+ let context_size = if index > 0 { 1 } else { all_tokens.len() };
173
+ let start_pos = all_tokens.len().saturating_sub(context_size);
174
+ let ctxt = &all_tokens[start_pos..];
175
+
176
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
177
+ let input = input.contiguous()?;
178
+ let logits = self.model.forward(&input, start_pos)?;
179
+
180
+ let logits = logits.squeeze(0)?;
181
+ let logits = if logits.dims().len() == 2 {
182
+ let seq_len = logits.dim(0)?;
183
+ logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
184
+ } else {
185
+ logits
186
+ };
187
+
188
+ let logits = logits.to_dtype(DType::F32)?;
189
+
190
+ let next_token = text_gen.sample_next_token(&logits)?;
191
+
192
+ all_tokens.push(next_token);
193
+
194
+ // Stream callback
195
+ if let Some(ref mut cb) = callback {
196
+ if config.debug_tokens {
197
+ // In debug mode, only show debug tokens
198
+ let token_piece = self.tokenizer.token_to_piece(next_token)?;
199
+ cb(&format!("[{}:{}]", next_token, token_piece));
200
+ } else {
201
+ // Normal mode: use incremental decoding for proper text
202
+ let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
203
+ cb(&decoded_text);
204
+ }
205
+ }
206
+
207
+ // Check stop conditions
208
+ if text_gen.should_stop(next_token, config.max_length) {
209
+ break;
210
+ }
211
+
212
+ // Check if constraint is satisfied (early stopping)
213
+ if config.stop_on_constraint_satisfaction {
214
+ let satisfied = if config.stop_on_match {
215
+ text_gen.is_constraint_satisfied_stop_on_match()
216
+ } else {
217
+ text_gen.is_constraint_satisfied()
218
+ };
219
+ if satisfied {
220
+ break;
221
+ }
222
+ }
223
+
224
+ // Check stop sequences
225
+ let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
226
+ if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
227
+ break;
228
+ }
229
+ }
230
+
231
+ Ok(if config.include_prompt {
232
+ all_tokens
233
+ } else {
234
+ all_tokens[start_gen..].to_vec()
235
+ })
236
+ }
237
+
238
+ /// Apply Gemma chat template
239
+ pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
240
+ let mut prompt = String::new();
241
+
242
+ // Gemma uses a specific format:
243
+ // <start_of_turn>user\n{user_message}<end_of_turn>
244
+ // <start_of_turn>model\n{model_message}<end_of_turn>
245
+
246
+ for message in messages {
247
+ let role = message["role"].as_str().unwrap_or("");
248
+ let content = message["content"].as_str().unwrap_or("");
249
+
250
+ match role {
251
+ "system" => {
252
+ // Gemma doesn't have explicit system messages, prepend to first user message
253
+ prompt.push_str(&format!("<start_of_turn>user\nSystem: {}\n", content));
254
+ }
255
+ "user" => {
256
+ if !prompt.contains("<start_of_turn>user") || prompt.ends_with("<end_of_turn>\n") {
257
+ prompt.push_str("<start_of_turn>user\n");
258
+ }
259
+ prompt.push_str(&format!("{}<end_of_turn>\n", content));
260
+ }
261
+ "assistant" | "model" => {
262
+ prompt.push_str(&format!("<start_of_turn>model\n{}<end_of_turn>\n", content));
263
+ }
264
+ _ => {}
265
+ }
266
+ }
267
+
268
+ // Add the model prompt
269
+ prompt.push_str("<start_of_turn>model\n");
270
+
271
+ Ok(prompt)
272
+ }
273
+ }
274
+
275
+ impl TextGenerator for Gemma {
276
+ fn generate(
277
+ &mut self,
278
+ prompt: &str,
279
+ config: &GenerationConfig,
280
+ ) -> CandleResult<String> {
281
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
282
+ let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
283
+
284
+ if config.debug_tokens {
285
+ self.tokenizer.format_tokens_with_debug(&output_tokens)
286
+ } else {
287
+ self.tokenizer.decode(&output_tokens, true)
288
+ }
289
+ }
290
+
291
+ fn generate_stream(
292
+ &mut self,
293
+ prompt: &str,
294
+ config: &GenerationConfig,
295
+ mut callback: impl FnMut(&str),
296
+ ) -> CandleResult<String> {
297
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
298
+ let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
299
+ self.tokenizer.decode(&output_tokens, true)
300
+ }
301
+
302
+ fn model_name(&self) -> &str {
303
+ &self.model_id
304
+ }
305
+
306
+ fn device(&self) -> &Device {
307
+ &self.device
308
+ }
309
+
310
+ fn clear_cache(&mut self) {
311
+ self.clear_kv_cache();
312
+ }
313
+ }
@@ -0,0 +1,63 @@
1
+ use std::time::{SystemTime, UNIX_EPOCH};
2
+ use std::sync::Arc;
3
+ use crate::structured::Index;
4
+
5
/// Configuration for a single text-generation request.
///
/// Consumed by the model-specific `generate_tokens` loops and by `TextGeneration`.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Maximum number of sampling steps, i.e. new tokens, per generation.
    pub max_length: usize,
    /// Sampling temperature (presumably lower means more deterministic; see `TextGeneration`).
    pub temperature: f64,
    /// Nucleus (top-p) sampling threshold; `None` presumably disables top-p filtering.
    pub top_p: Option<f64>,
    /// Top-k sampling cutoff; `None` presumably disables top-k filtering.
    pub top_k: Option<usize>,
    /// Repetition penalty factor applied during sampling.
    pub repetition_penalty: f32,
    /// How many of the most recent tokens the repetition penalty considers.
    pub repetition_penalty_last_n: usize,
    /// Seed for the sampling RNG; `Default` derives it from the system clock.
    pub seed: u64,
    /// Strings that end generation once they appear in the decoded output.
    pub stop_sequences: Vec<String>,
    /// If true, the returned tokens/text include the prompt as well as the completion.
    pub include_prompt: bool,
    /// If true, streaming callbacks receive `[id:piece]` strings instead of decoded
    /// text, and non-streaming output is rendered with token debug information.
    pub debug_tokens: bool,
    /// Optional constraint index for structured generation.
    pub constraint: Option<Arc<Index>>,
    /// Stop generating as soon as the constraint reports itself satisfied.
    pub stop_on_constraint_satisfaction: bool,
    /// Selects the satisfaction check used for early stopping: the
    /// "stop on match" variant when true, the general check otherwise.
    pub stop_on_match: bool,
}
35
+
36
/// Derive a sampling seed from the wall clock.
///
/// Uses the nanoseconds elapsed since the Unix epoch, truncated to the low 64 bits
/// (the truncation is harmless for a seed). If the clock reports a time before the
/// epoch, the fixed seed `42` is returned instead.
fn random_seed() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos() as u64,
        Err(_) => 42,
    }
}
43
+
44
+ impl Default for GenerationConfig {
45
+ fn default() -> Self {
46
+ Self {
47
+ max_length: 512,
48
+ temperature: 0.7,
49
+ top_p: None,
50
+ top_k: None,
51
+ repetition_penalty: 1.1,
52
+ repetition_penalty_last_n: 64,
53
+ seed: random_seed(),
54
+ stop_sequences: vec![],
55
+ include_prompt: false,
56
+ debug_tokens: false,
57
+ constraint: None,
58
+ stop_on_constraint_satisfaction: true,
59
+ stop_on_match: true,
60
+ }
61
+ }
62
+ }
63
+
@@ -0,0 +1,236 @@
1
+ use candle_core::{DType, Device, Result as CandleResult, Tensor};
2
+ use candle_transformers::models::glm4_new::{Config, ModelForCausalLM as Glm4Model};
3
+ use hf_hub::api::tokio::Api;
4
+ use tokenizers::Tokenizer;
5
+
6
+ use crate::llm::{GenerationConfig, TextGeneration, TextGenerator, TokenizerWrapper};
7
+
8
/// A GLM-4 causal language model bundled with its tokenizer, the device it runs
/// on, and the token id used to stop generation.
#[derive(Debug)]
pub struct Glm4 {
    /// Underlying candle GLM-4 model; owns the KV cache reset by `clear_kv_cache`.
    model: Glm4Model,
    /// Tokenizer used to encode prompts and decode generated tokens.
    tokenizer: TokenizerWrapper,
    /// Device the input tensors are created on.
    device: Device,
    /// Hub model id, reported by `model_name`.
    model_id: String,
    /// Token id handed to `TextGeneration::set_eos_token_id`. When loading from the
    /// Hub it is looked up as `<|endoftext|>`, `<|user|>` or `</s>`, else 151329.
    eos_token_id: u32,
}
16
+
17
+ impl Glm4 {
18
+ pub fn eos_token_id(&self) -> u32 {
19
+ self.eos_token_id
20
+ }
21
+
22
+ pub fn tokenizer(&self) -> &TokenizerWrapper {
23
+ &self.tokenizer
24
+ }
25
+
26
+ pub fn clear_kv_cache(&mut self) {
27
+ self.model.clear_kv_cache();
28
+ }
29
+
30
+ pub async fn from_pretrained_with_tokenizer(model_id: &str, device: Device, tokenizer_source: Option<&str>) -> CandleResult<Self> {
31
+ let api = Api::new()
32
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to create HF API: {}", e)))?;
33
+
34
+ let repo = api.model(model_id.to_string());
35
+
36
+ let config_filename = repo.get("config.json").await
37
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download config: {}", e)))?;
38
+ let config_str = std::fs::read_to_string(config_filename)?;
39
+ let config: Config = serde_json::from_str(&config_str)
40
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to parse config: {}", e)))?;
41
+
42
+ let tokenizer = if let Some(tokenizer_id) = tokenizer_source {
43
+ let tokenizer_repo = api.model(tokenizer_id.to_string());
44
+ let tokenizer_filename = tokenizer_repo.get("tokenizer.json").await
45
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer from {}: {}", tokenizer_id, e)))?;
46
+ Tokenizer::from_file(tokenizer_filename)
47
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
48
+ } else {
49
+ let tokenizer_filename = repo.get("tokenizer.json").await
50
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download tokenizer: {}", e)))?;
51
+ Tokenizer::from_file(tokenizer_filename)
52
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to load tokenizer: {}", e)))?
53
+ };
54
+
55
+ let vocab = tokenizer.get_vocab(true);
56
+ let eos_token_id = vocab.get("<|endoftext|>")
57
+ .or_else(|| vocab.get("<|user|>"))
58
+ .or_else(|| vocab.get("</s>"))
59
+ .copied()
60
+ .unwrap_or(151329);
61
+
62
+ let mut filenames = vec![];
63
+ let num_shards = if model_id.contains("9b") || model_id.contains("9B") { 4 } else { 1 };
64
+
65
+ if num_shards == 1 {
66
+ let filename = repo.get("model.safetensors").await
67
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download model weights: {}", e)))?;
68
+ filenames.push(filename);
69
+ } else {
70
+ for shard_idx in 1..=num_shards {
71
+ let filename = repo.get(&format!("model-{:05}-of-{:05}.safetensors", shard_idx, num_shards)).await
72
+ .map_err(|e| candle_core::Error::Msg(format!("Failed to download shard {}: {}", shard_idx, e)))?;
73
+ filenames.push(filename);
74
+ }
75
+ }
76
+
77
+ let vb = unsafe {
78
+ candle_nn::VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)?
79
+ };
80
+
81
+ let model = Glm4Model::new(&config, vb)?;
82
+
83
+ Ok(Self {
84
+ model,
85
+ tokenizer: TokenizerWrapper::new(tokenizer),
86
+ device,
87
+ model_id: model_id.to_string(),
88
+ eos_token_id,
89
+ })
90
+ }
91
+
92
+ pub async fn from_pretrained(model_id: &str, device: Device) -> CandleResult<Self> {
93
+ Self::from_pretrained_with_tokenizer(model_id, device, None).await
94
+ }
95
+
96
+ pub fn apply_chat_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
97
+ let mut prompt = String::new();
98
+
99
+ prompt.push_str("[gMASK]<sop>");
100
+
101
+ for message in messages {
102
+ let role = message["role"].as_str().unwrap_or("");
103
+ let content = message["content"].as_str().unwrap_or("");
104
+
105
+ match role {
106
+ "system" => {
107
+ prompt.push_str(&format!("<|system|>\n{}", content));
108
+ }
109
+ "user" => {
110
+ prompt.push_str(&format!("<|user|>\n{}", content));
111
+ }
112
+ "assistant" => {
113
+ prompt.push_str(&format!("<|assistant|>\n{}", content));
114
+ }
115
+ _ => {}
116
+ }
117
+ }
118
+
119
+ prompt.push_str("<|assistant|>\n");
120
+
121
+ Ok(prompt)
122
+ }
123
+
124
+ fn generate_tokens(
125
+ &mut self,
126
+ prompt_tokens: Vec<u32>,
127
+ config: &GenerationConfig,
128
+ mut callback: Option<impl FnMut(&str)>,
129
+ ) -> CandleResult<Vec<u32>> {
130
+ let mut text_gen = TextGeneration::new(config);
131
+ text_gen.set_eos_token_id(self.eos_token_id);
132
+ text_gen.set_tokens(prompt_tokens.clone());
133
+
134
+ let mut all_tokens = prompt_tokens.clone();
135
+ let start_gen = all_tokens.len();
136
+
137
+ for index in 0..config.max_length {
138
+ let context_size = if index > 0 { 1 } else { all_tokens.len() };
139
+ let start_pos = all_tokens.len().saturating_sub(context_size);
140
+ let ctxt = &all_tokens[start_pos..];
141
+
142
+ let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
143
+ let logits = self.model.forward(&input, start_pos)?;
144
+ let logits = logits.squeeze(0)?;
145
+
146
+ let logits = if logits.dims().len() == 2 {
147
+ let seq_len = logits.dim(0)?;
148
+ logits.narrow(0, seq_len - 1, 1)?.squeeze(0)?
149
+ } else {
150
+ logits
151
+ };
152
+
153
+ let logits = logits.to_dtype(DType::F32)?;
154
+
155
+ let next_token = text_gen.sample_next_token(&logits)?;
156
+
157
+ all_tokens.push(next_token);
158
+
159
+ if let Some(ref mut cb) = callback {
160
+ if config.debug_tokens {
161
+ let token_piece = self.tokenizer.token_to_piece(next_token)?;
162
+ cb(&format!("[{}:{}]", next_token, token_piece));
163
+ } else {
164
+ let decoded_text = self.tokenizer.decode_incremental(&all_tokens, all_tokens.len() - 1)?;
165
+ cb(&decoded_text);
166
+ }
167
+ }
168
+
169
+ if text_gen.should_stop(next_token, config.max_length) {
170
+ break;
171
+ }
172
+
173
+ if config.stop_on_constraint_satisfaction {
174
+ let satisfied = if config.stop_on_match {
175
+ text_gen.is_constraint_satisfied_stop_on_match()
176
+ } else {
177
+ text_gen.is_constraint_satisfied()
178
+ };
179
+ if satisfied {
180
+ break;
181
+ }
182
+ }
183
+
184
+ let generated_text = self.tokenizer.decode(&all_tokens[start_gen..], true)?;
185
+ if text_gen.check_stop_sequences(&generated_text, &config.stop_sequences) {
186
+ break;
187
+ }
188
+ }
189
+
190
+ Ok(if config.include_prompt {
191
+ all_tokens
192
+ } else {
193
+ all_tokens[start_gen..].to_vec()
194
+ })
195
+ }
196
+ }
197
+
198
+ impl TextGenerator for Glm4 {
199
+ fn generate(
200
+ &mut self,
201
+ prompt: &str,
202
+ config: &GenerationConfig,
203
+ ) -> CandleResult<String> {
204
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
205
+ let output_tokens = self.generate_tokens(prompt_tokens, config, None::<fn(&str)>)?;
206
+
207
+ if config.debug_tokens {
208
+ self.tokenizer.format_tokens_with_debug(&output_tokens)
209
+ } else {
210
+ self.tokenizer.decode(&output_tokens, true)
211
+ }
212
+ }
213
+
214
+ fn generate_stream(
215
+ &mut self,
216
+ prompt: &str,
217
+ config: &GenerationConfig,
218
+ mut callback: impl FnMut(&str),
219
+ ) -> CandleResult<String> {
220
+ let prompt_tokens = self.tokenizer.encode(prompt, true)?;
221
+ let output_tokens = self.generate_tokens(prompt_tokens, config, Some(&mut callback))?;
222
+ self.tokenizer.decode(&output_tokens, true)
223
+ }
224
+
225
+ fn model_name(&self) -> &str {
226
+ &self.model_id
227
+ }
228
+
229
+ fn device(&self) -> &Device {
230
+ &self.device
231
+ }
232
+
233
+ fn clear_cache(&mut self) {
234
+ self.clear_kv_cache();
235
+ }
236
+ }