red-candle 1.5.0 → 1.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: a15236dd78a04cfcc6d7978ca088934987b84a599371c1508383d147172abcc3
-   data.tar.gz: f2661ad1220f566dd8a7d9eab1d74bd30007ed022c9381daf2697d12d88fac96
+   metadata.gz: 00c0870d599db76ba556ab880f69219bc32d72698c3570cd308df3923b1332f9
+   data.tar.gz: bffbd02fa1fe11813a304ed16771e6a72ee01f85ccb445eb67c5da6bf696c670
  SHA512:
-   metadata.gz: d74f26e2c0527f25424245442b5b74cf1575505a42a2c1c823be31606433a6c9e4736afc6d8062bd54642cbc8373deaae197bfc7163cbaa392d457a844e67628
-   data.tar.gz: 1ffc6d400ba4ffe9c23603fb42adae088142446f1d0bc03d0365a2bfcde3c1d1beab51eec01481d1a6e89f686c235e2ec58f0703daed73fabcf5cfd7f7e22d95
+   metadata.gz: c9ab39f01e4777c3d9906592cf62b8b1d2d03fd04be4ba82d8ddee837678ac477be88b1a0a48fa898e443f7469681e26f25297d4d310c39fb5222137d3b17d87
+   data.tar.gz: 22fe87f0fb64754a0a19bcc632dff72b0cc43838d73797aa991c961b65c9ecb7d9bd7aacbde51baf4063011081f4111a64f53726b51c87aca3e48214304e0119
data/README.md CHANGED
@@ -273,6 +273,85 @@ See [STRUCTURED_GENERATION.md](docs/STRUCTURED_GENERATION.md) for detailed docum
 
  **Note on Reliability**: Structured generation constrains the model's output tokens, but success rates vary by model size and schema complexity. Smaller models (< 7B parameters) may occasionally produce incomplete or invalid JSON, especially with complex schemas. Consider implementing retry logic or fallback strategies in production applications. Larger models generally perform much better with structured generation.
 
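The reliability note above recommends retry logic for small models. A minimal sketch of such a wrapper (the `llm.generate` call here is illustrative, standing in for whichever structured-generation call you use; it is assumed to return a JSON string):

```ruby
require "json"

# Retry wrapper for structured generation. `llm.generate` is a placeholder for
# the actual red-candle call; swap in your own invocation.
def generate_json_with_retry(llm, prompt, attempts: 3)
  attempts.times do
    text = llm.generate(prompt) # assumed to return a JSON string
    begin
      return JSON.parse(text)
    rescue JSON::ParserError
      next # small models occasionally emit invalid JSON, so try again
    end
  end
  nil # caller can fall back to an unconstrained strategy
end
```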
+ ## Tool Calling
+
+ Red-candle supports tool/function calling, enabling models to invoke external functions during generation. This works best with models fine-tuned for tool calling, such as Qwen3.
+
+ ### Defining Tools
+
+ ```ruby
+ get_weather = Candle::Tool.new(
+   name: "get_weather",
+   description: "Get the current weather for a city",
+   parameters: {
+     type: "object",
+     properties: { city: { type: "string", description: "City name" } },
+     required: ["city"]
+   }
+ ) { |args| { city: args["city"], temperature: 72, condition: "sunny" } }
+ ```
+
+ ### Extracting Tool Calls
+
+ `chat_with_tools` injects tool definitions into the system prompt, generates a response, and parses any `<tool_call>` tags from the output. It does **not** feed results back to the model — it just tells you what the model wants to call. You decide what to do with it:
+
+ ```ruby
+ llm = Candle::LLM.from_pretrained("Qwen/Qwen3-0.6B")
+
+ messages = [{ role: "user", content: "What's the weather in San Francisco?" }]
+ result = llm.chat_with_tools(messages, tools: [get_weather],
+                              config: Candle::GenerationConfig.deterministic(max_length: 500))
+
+ if result.has_tool_calls?
+   result.tool_calls.each do |tc|
+     puts "#{tc.name}(#{tc.arguments})"
+     output = get_weather.call(tc.arguments)
+     puts "=> #{output}"
+   end
+ else
+   puts result.text_response
+ end
+ ```
+
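The extraction step described above can be sketched in a few lines, assuming the Qwen-style convention of a JSON object (`{"name": ..., "arguments": ...}`) inside each `<tool_call>` tag; the gem's own `ToolCallParser` is the authoritative implementation:

```ruby
require "json"

# Minimal <tool_call> extraction sketch. Scans for Qwen-style tags and parses
# the JSON body of each, silently skipping malformed calls.
def extract_tool_calls(text)
  text.scan(%r{<tool_call>(.*?)</tool_call>}m).filter_map do |(body)|
    JSON.parse(body.strip) rescue nil # skip malformed calls
  end
end
```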
+ Pass `execute: true` to automatically run the tools (but still no round-trip back to the model):
+
+ ```ruby
+ result = llm.chat_with_tools(messages, tools: [get_weather], execute: true,
+                              config: Candle::GenerationConfig.deterministic(max_length: 500))
+
+ result.tool_results.each do |tr|
+   puts "#{tr[:tool_call].name} => #{tr[:result]}"
+ end
+ ```
+
+ ### Agent (Multi-Turn Tool Loop)
+
+ `Candle::Agent` completes the round-trip: generate → parse tool calls → execute → feed results back to the model → repeat until the model produces a final text answer or hits `max_iterations`. This is a convenience wrapper for quick prototyping — for production use, frameworks like [RubyLLM](https://github.com/crmne/ruby_llm) manage this loop for you via the [ruby_llm-red_candle](https://github.com/scientist-labs/ruby_llm-red_candle) plugin:
+
+ ```ruby
+ agent = Candle::Agent.new(llm, tools: [get_weather, lookup_price], max_iterations: 5)
+ result = agent.run("What's the weather in Paris, and how much does a widget cost?",
+                    config: Candle::GenerationConfig.deterministic(max_length: 1000))
+
+ puts result.response        # Final text answer from the model
+ puts result.iterations      # Number of generate cycles
+ puts result.tool_calls_made # Number of tools invoked
+ ```
+
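Between iterations, the agent feeds each tool's output back to the model as a `"tool"`-role chat message tagged with the tool's name. That formatting step, mirrored from the `Candle::Agent` implementation shipped in this release, looks like:

```ruby
require "json"

# Format one tool result as a chat message the agent appends before the next
# generate cycle (errors are reported inline instead of raising).
def tool_result_message(tool_name, result, error: nil)
  output = error ? "Error: #{error}" : JSON.generate(result)
  { role: "tool", content: "[#{tool_name}] #{output}" }
end
```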
+ ### Model Recommendations
+
+ Tool calling quality depends heavily on model size:
+
+ | Model | Tool Calling Quality |
+ |-------|---------------------|
+ | **Qwen3-8B GGUF** (~5 GB) | Calls correct tools, self-corrects errors, but may hallucinate values from tool results |
+ | **Qwen3-4B GGUF** (~2.5 GB) | Calls correct tools, occasional reasoning errors |
+ | **Qwen3-0.6B** (~1.2 GB) | Single-turn works, needs `max_length: 500+` for thinking |
+ | SmolLM2-360M | Does not work |
+ | TinyLlama-1.1B | Does not work (not fine-tuned for tool calling) |
+
+ **Tip:** Qwen3 models use a `<think>` reasoning block before producing tool calls. Set `max_length` high enough (500+ for 0.6B, 1000+ for larger models) to allow room for both thinking and the tool call.
+
  ## ⚠️ Model Format Requirements
 
  ### EmbeddingModels and Rerankers: Safetensors Only
@@ -355,12 +434,26 @@ embedding = model.embedding("Hi there!")
 
  Red-Candle includes support for cross-encoder reranking models, which can be used to reorder documents by relevance to a query. This is particularly useful for improving search results or implementing retrieval-augmented generation (RAG) systems.
 
+ ### Supported Models
+
+ Red-candle supports both **BERT** and **XLM-RoBERTa** reranker architectures. The model type is auto-detected from `config.json`.
+
+ | Model | Architecture | Params | Notes |
+ |-------|-------------|--------|-------|
+ | `BAAI/bge-reranker-base` | XLM-RoBERTa | 278M | Recommended — strong quality, multilingual |
+ | `BAAI/bge-reranker-large` | XLM-RoBERTa | 560M | Best quality, higher resource usage |
+ | `BAAI/bge-reranker-v2-m3` | XLM-RoBERTa | 278M | Multilingual, very strong |
+ | `cross-encoder/ms-marco-MiniLM-L-12-v2` | BERT | 33M | Lightweight, English only |
+
  ### Basic Usage
 
  ```ruby
  require 'candle'
 
- # Initialize the reranker with a cross-encoder model
+ # Initialize the reranker (BGE reranker recommended for quality)
+ reranker = Candle::Reranker.from_pretrained("BAAI/bge-reranker-base")
+
+ # Or use the lighter BERT-based model
  reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
 
  # Or with custom max_length for truncation (default is 512)
@@ -449,7 +542,7 @@ results = reranker.rerank_with_pooling(query, documents, "cls")
  results = reranker.rerank_sigmoid_with_pooling(query, documents, "cls")
  ```
 
- Note: The default "pooler" method is recommended as it matches how cross-encoder models are trained. Other pooling methods may produce different ranking results.
+ Note: Pooling methods only apply to BERT-based models. XLM-RoBERTa models (e.g., BGE rerankers) have a built-in classification head and ignore the `pooling_method` parameter. For BERT models, the default "pooler" method is recommended as it matches how cross-encoder models are trained.
 
  ### CUDA Support
 
467
560
  - **Bi-encoders** (like the embedding models above) encode queries and documents separately into dense vectors
468
561
  - **Cross-encoders** process the query and document together, allowing for more nuanced relevance scoring
469
562
 
470
- The reranker uses a BERT-based architecture that:
471
- 1. Concatenates the query and document with special tokens: `[CLS] query [SEP] document [SEP]`
472
- 2. Processes them jointly through BERT layers
473
- 3. Applies a pooler layer (dense + tanh) to the [CLS] token
474
- 4. Uses a classifier layer to produce a single relevance score
563
+ The reranker concatenates the query and document with special tokens and processes them jointly through transformer layers to produce a single relevance score.
564
+
565
+ **BERT models** (e.g., MiniLM): Use a pooler layer (dense + tanh) on the [CLS] token, then a classifier layer. Pooling method is configurable (`pooler`, `cls`, `mean`).
566
+
567
+ **XLM-RoBERTa models** (e.g., BGE rerankers): Use a built-in classification head that returns logits directly. The `pooling_method` parameter is ignored — the model handles its own pooling internally.
475
568
 
476
569
  This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
477
570
 
@@ -319,9 +319,9 @@ impl Llama {
          if i == 1 || (i == 0 && system_message.is_empty()) {
              // First user message
              if !system_message.is_empty() {
-                 prompt.push_str(&format!("<s>[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
+                 prompt.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
              } else {
-                 prompt.push_str(&format!("<s>[INST] {} [/INST]", content));
+                 prompt.push_str(&format!("[INST] {} [/INST]", content));
              }
          } else {
              prompt.push_str(&format!(" [INST] {} [/INST]", content));
@@ -340,7 +340,7 @@ impl Llama {
      fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
          let mut prompt = String::new();
 
-         prompt.push_str("<|begin_of_text|>");
+         // BOS token is added by the tokenizer's encode(prompt, add_special_tokens=true)
 
          for message in messages {
              let role = message["role"].as_str().unwrap_or("");
@@ -386,9 +386,9 @@ impl QuantizedGGUF {
          "user" => {
              if i == 1 || (i == 0 && system_message.is_empty()) {
                  if !system_message.is_empty() {
-                     prompt.push_str(&format!("<s>[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
+                     prompt.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
                  } else {
-                     prompt.push_str(&format!("<s>[INST] {} [/INST]", content));
+                     prompt.push_str(&format!("[INST] {} [/INST]", content));
                  }
              } else {
                  prompt.push_str(&format!(" [INST] {} [/INST]", content));
@@ -406,7 +406,7 @@ impl QuantizedGGUF {
 
      fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
          let mut prompt = String::new();
-         prompt.push_str("<|begin_of_text|>");
+         // BOS token is added by the tokenizer's encode(prompt, add_special_tokens=true)
 
          for message in messages {
              let role = message["role"].as_str().unwrap_or("");
@@ -481,15 +481,18 @@ impl QuantizedGGUF {
              "assistant" => {
                  prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
              }
+             "tool" => {
+                 prompt.push_str(&format!("<|im_start|>tool\n{}<|im_end|>\n", content));
+             }
              _ => {}
          }
      }
-
+
      // Add generation prompt
      prompt.push_str("<|im_start|>assistant\n");
      Ok(prompt)
  }
-
+
      fn apply_phi_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
          let mut prompt = String::new();
 
@@ -128,6 +128,9 @@ impl Qwen {
              "assistant" => {
                  prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
              }
+             "tool" => {
+                 prompt.push_str(&format!("<|im_start|>tool\n{}<|im_end|>\n", content));
+             }
              _ => {}
          }
      }
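The ChatML prompt these Rust templates assemble, including the newly supported `"tool"` role, can be sketched in Ruby as:

```ruby
# Mirror of the ChatML rendering in the Rust templates above: each message is
# wrapped in <|im_start|>role ... <|im_end|>, then a generation prompt is added.
def chatml_prompt(messages)
  rendered = messages.map do |m|
    "<|im_start|>#{m[:role]}\n#{m[:content]}<|im_end|>\n"
  end.join
  rendered + "<|im_start|>assistant\n" # generation prompt
end
```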
@@ -1,18 +1,59 @@
  use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
- use candle_transformers::models::bert::{BertModel, Config};
+ use candle_transformers::models::bert::{BertModel, Config as BertConfig};
+ use candle_transformers::models::xlm_roberta::{
+     XLMRobertaForSequenceClassification, Config as XLMRobertaConfig,
+ };
+ use candle_transformers::models::debertav2::{
+     DebertaV2Model, DebertaV2ContextPooler, Config as DebertaV2Config,
+ };
+ use candle_transformers::models::modernbert::{
+     ModernBert, Config as ModernBertConfig,
+ };
+ use candle_transformers::models::qwen3::{
+     ModelForCausalLM as Qwen3Model, Config as Qwen3Config,
+ };
  use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
  use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
  use hf_hub::{api::sync::Api, Repo, RepoType};
  use tokenizers::{EncodeInput, Tokenizer};
+ use std::cell::RefCell;
  use crate::ruby::{Device, Result};
  use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
+ enum RerankerModel {
+     Bert {
+         model: BertModel,
+         pooler: Linear,
+         classifier: Linear,
+     },
+     XLMRoberta {
+         model: XLMRobertaForSequenceClassification,
+         pad_token_id: u32,
+     },
+     DeBERTa {
+         model: DebertaV2Model,
+         pooler: DebertaV2ContextPooler,
+         classifier: Linear,
+         pad_token_id: u32,
+     },
+     ModernBert {
+         model: ModernBert,
+         head_dense: Linear,
+         head_norm: candle_nn::LayerNorm,
+         classifier: Linear,
+         pad_token_id: u32,
+     },
+     Qwen3 {
+         model: RefCell<Qwen3Model>,
+         yes_token_id: u32,
+         no_token_id: u32,
+     },
+ }
+
  #[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
  pub struct Reranker {
-     model: BertModel,
+     model: RerankerModel,
      tokenizer: TokenizerWrapper,
-     pooler: Linear,
-     classifier: Linear,
      device: CoreDevice,
      model_id: String,
  }
@@ -28,7 +69,7 @@ impl Reranker {
      let ruby = Ruby::get().unwrap();
      let runtime_error = ruby.exception_runtime_error();
 
-     let result = (|| -> std::result::Result<(BertModel, TokenizerWrapper, Linear, Linear), Box<dyn std::error::Error + Send + Sync>> {
+     let result = (|| -> std::result::Result<(RerankerModel, TokenizerWrapper), Box<dyn std::error::Error + Send + Sync>> {
          let api = Api::new()?;
          let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 
@@ -37,9 +78,10 @@ impl Reranker {
          let tokenizer_filename = repo.get("tokenizer.json")?;
          let weights_filename = repo.get("model.safetensors")?;
 
-         // Load config
-         let config = std::fs::read_to_string(config_filename)?;
-         let config: Config = serde_json::from_str(&config)?;
+         // Read raw config to detect model type
+         let config_str = std::fs::read_to_string(&config_filename)?;
+         let raw_config: serde_json::Value = serde_json::from_str(&config_str)?;
+         let model_type = raw_config["model_type"].as_str().unwrap_or("bert");
 
          // Setup tokenizer with padding AND truncation
          let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
@@ -51,21 +93,71 @@ impl Reranker {
              VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
          };
 
-         // Load BERT model
-         let model = BertModel::load(vb.pp("bert"), &config)?;
-
-         // Load pooler layer (dense + tanh activation)
-         let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
-
-         // Load classifier layer for cross-encoder (single output score)
-         let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
+         let model = match model_type {
+             "xlm-roberta" => {
+                 let config: XLMRobertaConfig = serde_json::from_str(&config_str)?;
+                 let pad_token_id = config.pad_token_id;
+                 let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
+                 RerankerModel::XLMRoberta { model, pad_token_id }
+             }
+             "deberta-v2" => {
+                 let config: DebertaV2Config = serde_json::from_str(&config_str)?;
+                 let pad_token_id = config.pad_token_id.unwrap_or(0) as u32;
+                 let model = DebertaV2Model::load(vb.pp("deberta"), &config)?;
+                 let pooler = DebertaV2ContextPooler::load(vb.clone(), &config)?;
+                 let pooler_hidden_size = config.pooler_hidden_size.unwrap_or(config.hidden_size);
+                 let num_labels = config.id2label.as_ref().map_or(1, |m| m.len());
+                 let classifier = candle_nn::linear(pooler_hidden_size, num_labels, vb.pp("classifier"))?;
+                 RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id }
+             }
+             "qwen3" => {
+                 let config: Qwen3Config = serde_json::from_str(&config_str)?;
+                 let model = Qwen3Model::new(&config, vb)?;
+
+                 // Look up "yes" and "no" token IDs from the tokenizer
+                 let yes_token_id: u32 = tokenizer
+                     .encode("yes", false)
+                     .ok()
+                     .and_then(|enc| enc.get_ids().first().copied())
+                     .unwrap_or(9693);
+                 let no_token_id: u32 = tokenizer
+                     .encode("no", false)
+                     .ok()
+                     .and_then(|enc| enc.get_ids().first().copied())
+                     .unwrap_or(2152);
+
+                 RerankerModel::Qwen3 {
+                     model: RefCell::new(model),
+                     yes_token_id,
+                     no_token_id,
+                 }
+             }
+             "modernbert" => {
+                 let config: ModernBertConfig = serde_json::from_str(&config_str)?;
+                 let pad_token_id = config.pad_token_id;
+                 let model = ModernBert::load(vb.clone(), &config)?;
+                 // ModernBertHead::load is private, so load the head layers manually
+                 let head_vb = vb.pp("head");
+                 let head_dense = candle_nn::linear_no_bias(config.hidden_size, config.hidden_size, head_vb.pp("dense"))?;
+                 let head_norm = candle_nn::layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, head_vb.pp("norm"))?;
+                 let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
+                 RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id }
+             }
+             _ => {
+                 let config: BertConfig = serde_json::from_str(&config_str)?;
+                 let model = BertModel::load(vb.pp("bert"), &config)?;
+                 let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
+                 let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
+                 RerankerModel::Bert { model, pooler, classifier }
+             }
+         };
 
-         Ok((model, TokenizerWrapper::new(tokenizer), pooler, classifier))
+         Ok((model, TokenizerWrapper::new(tokenizer)))
      })();
 
      match result {
-         Ok((model, tokenizer, pooler, classifier)) => {
-             Ok(Self { model, tokenizer, pooler, classifier, device, model_id })
+         Ok((model, tokenizer)) => {
+             Ok(Self { model, tokenizer, device, model_id })
          }
          Err(e) => Err(Error::new(runtime_error, format!("Failed to load model: {}", e))),
      }
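The dispatch above keys off the `model_type` field of `config.json`, falling back to `"bert"` when the field is absent. The same detection logic in Ruby:

```ruby
require "json"

# Read a config.json string and return the architecture key the loader
# dispatches on, defaulting to "bert" when model_type is missing.
def detect_model_type(config_json)
  JSON.parse(config_json).fetch("model_type", "bert")
end
```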
@@ -164,49 +256,161 @@ impl Reranker {
          .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?;
      let token_type_ids = Tensor::new(token_type_ids, &self.device)
          .map_err(|e| Error::new(runtime_error, format!("Failed to create token type ids tensor: {}", e)))?;
-     let attention_mask = token_ids.ne(0u32)
-         .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
-
-     // Forward pass through BERT
-     let embeddings = self.model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
-         .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
-
-     // Apply pooling based on the specified method
-     let pooled_embeddings = match pooling_method.as_str() {
-         "pooler" => {
-             // Extract [CLS] token and apply pooler (dense + tanh)
-             let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
-             let pooled = self.pooler.forward(&cls_embeddings)
+
+     // Compute scores based on model type
+     let scores = match &self.model {
+         RerankerModel::Bert { model, pooler, classifier } => {
+             let attention_mask = token_ids.ne(0u32)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+             // Forward pass through BERT
+             let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
+                 .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+             // Apply pooling based on the specified method
+             let pooled_embeddings = match pooling_method.as_str() {
+                 "pooler" => {
+                     let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
+                     let pooled = pooler.forward(&cls_embeddings)
+                         .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
+                     pooled.tanh()
+                         .map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
+                 },
+                 "cls" => {
+                     self.extract_cls_embeddings(&embeddings)?
+                 },
+                 "mean" => {
+                     let (_batch, seq_len, _hidden) = embeddings.dims3()
+                         .map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
+                     let sum = embeddings.sum(1)
+                         .map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
+                     (sum / (seq_len as f64))
+                         .map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
+                 },
+                 _ => return Err(Error::new(runtime_error,
+                     format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
+             };
+
+             let pooled_embeddings = pooled_embeddings.contiguous()
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
+             let logits = classifier.forward(&pooled_embeddings)
+                 .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+             logits.squeeze(1)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+         }
+         RerankerModel::XLMRoberta { model, pad_token_id } => {
+             let attention_mask = token_ids.ne(*pad_token_id)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+             // XLMRobertaForSequenceClassification returns logits directly
+             let logits = model.forward(&token_ids, &attention_mask, &token_type_ids)
+                 .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+             logits.squeeze(1)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+         }
+         RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id } => {
+             let attention_mask = token_ids.ne(*pad_token_id)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+             // Forward through DeBERTa encoder
+             let encoder_output = model.forward(&token_ids, Some(token_type_ids.clone()), Some(attention_mask))
+                 .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+             // Pool and classify
+             let pooled = pooler.forward(&encoder_output)
                  .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
-             pooled.tanh()
-                 .map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
-         },
-         "cls" => {
-             // Just use the [CLS] token embeddings directly (no pooler layer)
-             self.extract_cls_embeddings(&embeddings)?
-         },
-         "mean" => {
-             // Mean pooling across all tokens
-             let (_batch, seq_len, _hidden) = embeddings.dims3()
-                 .map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
-             let sum = embeddings.sum(1)
-                 .map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
-             (sum / (seq_len as f64))
-                 .map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
-         },
-         _ => return Err(Error::new(runtime_error,
-             format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
+             let logits = classifier.forward(&pooled)
+                 .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+             logits.squeeze(1)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+         }
+         RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id } => {
+             let attention_mask = token_ids.ne(*pad_token_id)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+             let attention_mask_f32 = attention_mask.to_dtype(DType::F32)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to convert attention mask: {}", e)))?;
+
+             // Forward through ModernBERT encoder
+             let encoder_output = model.forward(&token_ids, &attention_mask_f32)
+                 .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+             // CLS pooling, then head (dense + GELU + norm) + classifier
+             let cls = encoder_output.i((.., 0, ..))
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS: {}", e)))?
+                 .contiguous()
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to make contiguous: {}", e)))?;
+             let hidden = head_dense.forward(&cls)
+                 .map_err(|e| Error::new(runtime_error, format!("Head dense failed: {}", e)))?;
+             let hidden = hidden.gelu_erf()
+                 .map_err(|e| Error::new(runtime_error, format!("GELU activation failed: {}", e)))?;
+             let hidden = head_norm.forward(&hidden)
+                 .map_err(|e| Error::new(runtime_error, format!("Head norm failed: {}", e)))?;
+             let logits = classifier.forward(&hidden)
+                 .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+             logits.squeeze(1)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+         }
+         RerankerModel::Qwen3 { model, yes_token_id, no_token_id } => {
+             // Qwen3 reranker: decoder-based yes/no scoring
+             // Process each document individually (causal LM, not batch encoder)
+             let mut scores_vec: Vec<f32> = Vec::with_capacity(documents.len());
+             let mut model = model.borrow_mut();
+
+             for doc in &documents {
+                 // Build the Qwen3 reranker prompt
+                 let prompt = format!(
+                     "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {}\n<Document>: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
+                     query, doc
+                 );
+
+                 // Tokenize the prompt
+                 let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
+                     .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
+                 let input_ids: Vec<u32> = encoding.get_ids().to_vec();
+
+                 // Clear KV cache for each document
+                 model.clear_kv_cache();
+
+                 // Forward pass — get logits for the last token position
+                 let input_tensor = Tensor::new(&input_ids[..], &self.device)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?
+                     .unsqueeze(0)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze: {}", e)))?;
+
+                 let logits = model.forward(&input_tensor, 0)
+                     .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+                 // logits shape: [1, 1, vocab_size] → flatten to [vocab_size]
+                 let logits = logits.flatten_all()
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to flatten: {}", e)))?
+                     .to_dtype(DType::F32)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to convert dtype: {}", e)))?;
+
+                 // Extract yes/no logits and compute score
+                 let yes_logit: f32 = logits.i(*yes_token_id as usize)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to get yes logit: {}", e)))?
+                     .to_scalar()
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to convert yes logit: {}", e)))?;
+                 let no_logit: f32 = logits.i(*no_token_id as usize)
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to get no logit: {}", e)))?
+                     .to_scalar()
+                     .map_err(|e| Error::new(runtime_error, format!("Failed to convert no logit: {}", e)))?;
+
+                 // softmax over [yes, no] → P(yes)
+                 let max_logit = yes_logit.max(no_logit);
+                 let yes_exp = (yes_logit - max_logit).exp();
+                 let no_exp = (no_logit - max_logit).exp();
+                 let score = yes_exp / (yes_exp + no_exp);
+
+                 scores_vec.push(score);
+             }
+
+             // Build scores tensor for uniform handling below
+             Tensor::new(scores_vec.as_slice(), &self.device)
+                 .map_err(|e| Error::new(runtime_error, format!("Failed to create scores tensor: {}", e)))?
+         }
      };
 
-     // Apply classifier to get relevance scores (raw logits)
-     // Ensure tensor is contiguous before linear layer
-     let pooled_embeddings = pooled_embeddings.contiguous()
-         .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
-     let logits = self.classifier.forward(&pooled_embeddings)
-         .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
-     let scores = logits.squeeze(1)
-         .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?;
-
      // Optionally apply sigmoid activation
      let scores = if apply_sigmoid {
          sigmoid(&scores)
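The Qwen3 branch scores each document as P("yes") via a two-way softmax over the yes/no logits, with the usual max-subtraction for numerical stability. The same arithmetic in Ruby:

```ruby
# Two-way softmax over the yes/no logits: returns P(yes) in (0, 1).
def yes_no_score(yes_logit, no_logit)
  m = [yes_logit, no_logit].max # subtract the max for numerical stability
  yes_exp = Math.exp(yes_logit - m)
  no_exp = Math.exp(no_logit - m)
  yes_exp / (yes_exp + no_exp)
end
```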
data/lib/candle/agent.rb ADDED
@@ -0,0 +1,68 @@
+ # frozen_string_literal: true
+
+ require "json"
+
+ module Candle
+   class Agent
+     MAX_ITERATIONS = 10
+
+     attr_reader :llm, :tools, :system_prompt, :max_iterations
+
+     def initialize(llm, tools:, system_prompt: nil, max_iterations: MAX_ITERATIONS)
+       @llm = llm
+       @tools = tools
+       @system_prompt = system_prompt
+       @max_iterations = max_iterations
+     end
+
+     def run(user_message, **options)
+       messages = []
+       messages << { role: "system", content: @system_prompt } if @system_prompt
+       messages << { role: "user", content: user_message }
+
+       iterations = 0
+       loop do
+         iterations += 1
+         if iterations > @max_iterations
+           raise AgentMaxIterationsError,
+                 "Agent exceeded maximum iterations (#{@max_iterations})"
+         end
+
+         result = @llm.chat_with_tools(messages, tools: @tools, execute: true, **options)
+
+         if result.has_tool_calls?
+           # If the model produced a substantial text answer alongside tool calls,
+           # treat it as a final response (model is done, trailing tool calls are noise).
+           # Strip <think> blocks so they don't count toward the length check.
+           text_without_thinking = result.text_response&.gsub(/<think>.*?<\/think>/m, "")&.strip
+           if text_without_thinking && text_without_thinking.length > 50
+             return AgentResult.new(
+               response: result.text_response,
+               messages: messages,
+               iterations: iterations,
+               tool_calls_made: messages.count { |m| m[:role] == "tool" }
+             )
+           end
+
+           messages << { role: "assistant", content: result.raw_response }
+
+           result.tool_results.each do |tr|
+             tool_name = tr[:tool_call]&.name || "unknown"
+             tool_output = tr[:error] ? "Error: #{tr[:error]}" : JSON.generate(tr[:result])
+             messages << { role: "tool", content: "[#{tool_name}] #{tool_output}" }
+           end
+         else
+           return AgentResult.new(
+             response: result.text_response || result.raw_response,
+             messages: messages,
+             iterations: iterations,
+             tool_calls_made: messages.count { |m| m[:role] == "tool" }
+           )
+         end
+       end
+     end
+   end
+
+   AgentResult = Struct.new(:response, :messages, :iterations, :tool_calls_made, keyword_init: true)
+   AgentMaxIterationsError = Class.new(StandardError)
+ end
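The agent loop above stops early when the model emits a substantial text answer alongside tool calls. That "substantial" heuristic (strip `<think>` blocks, then require more than 50 characters) can be checked in isolation; the sketch below reproduces just that logic as a standalone helper (`substantial_answer?` is a made-up name, not part of the gem), with made-up sample responses:

```ruby
# Standalone sketch of the Agent's "substantial text" heuristic from the
# loop above: <think> blocks are stripped before the 50-character test,
# so pure chain-of-thought does not end the loop early.
def substantial_answer?(text_response)
  stripped = text_response&.gsub(/<think>.*?<\/think>/m, "")&.strip
  !!(stripped && stripped.length > 50)
end

thinking_only = "<think>I should call the weather tool first.</think>"
real_answer   = "<think>done</think>The weather in Berlin is 21 C today, clear skies, light wind."

puts substantial_answer?(thinking_only)  # false
puts substantial_answer?(real_answer)    # true
puts substantial_answer?(nil)            # false
```

Note the safe-navigation calls (`&.gsub`, `&.strip`): `text_response` can be `nil` when the model emits only tool calls, and the heuristic must not raise in that case.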
data/lib/candle/llm.rb CHANGED
@@ -252,7 +252,7 @@ module Candle
      base_model
    end
 
-   # Simple chat interface for instruction models
+   # Chat interface always returns a String
    def chat(messages, **options)
      prompt = apply_chat_template(messages)
      generate(prompt, **options)
@@ -263,7 +263,48 @@
      prompt = apply_chat_template(messages)
      generate_stream(prompt, **options, &block)
    end
-
+
+   # Chat with tool calling — always returns a ToolCallResult
+   # Set execute: true to automatically run the tools (default: false)
+   def chat_with_tools(messages, tools:, execute: false, **options)
+     tool_prompt = build_tool_system_prompt(tools)
+     augmented = inject_tool_instructions(messages, tool_prompt)
+
+     raw_response = chat(augmented, **options)
+
+     result = ToolCallParser.parse(raw_response, available_tools: tools)
+
+     if result.has_tool_calls? && execute
+       tool_results = result.tool_calls.map do |tool_call|
+         tool = tools.find { |t| t.name == tool_call.name }
+         unless tool
+           next { tool_call: tool_call, result: nil, error: "Unknown tool: #{tool_call.name}" }
+         end
+
+         begin
+           output = tool.call(tool_call.arguments)
+           { tool_call: tool_call, result: output, error: nil }
+         rescue Exception => e
+           { tool_call: tool_call, result: nil, error: e.message }
+         end
+       end
+
+       ToolCallResult.new(
+         tool_calls: result.tool_calls,
+         tool_results: tool_results,
+         text_response: result.text_response,
+         raw_response: raw_response
+       )
+     else
+       ToolCallResult.new(
+         tool_calls: result.tool_calls,
+         tool_results: [],
+         text_response: result.has_tool_calls? ? result.text_response : raw_response,
+         raw_response: raw_response
+       )
+     end
+   end
+
    # Inspect method for debugging and exploration
    def inspect
      opts = options rescue {}
@@ -354,6 +395,27 @@ module Candle
 
    private
 
+   def build_tool_system_prompt(tools)
+     tool_defs = tools.map { |t| JSON.generate(t.to_tool_definition) }.join("\n\n")
+     "You are a helpful assistant with access to the following tools:\n\n" \
+     "#{tool_defs}\n\n" \
+     "When you need to use a tool, respond with a tool call in the following format:\n" \
+     "<tool_call>\n" \
+     "{\"name\": \"tool_name\", \"arguments\": {\"arg1\": \"value1\"}}\n" \
+     "</tool_call>\n\n" \
+     "If you don't need to use a tool, respond normally with text."
+   end
+
+   def inject_tool_instructions(messages, tool_prompt)
+     msgs = messages.map { |m| m.dup }
+     if msgs.first && msgs.first[:role] == "system"
+       msgs.first[:content] = "#{tool_prompt}\n\n#{msgs.first[:content]}"
+     else
+       msgs.unshift({ role: "system", content: tool_prompt })
+     end
+     msgs
+   end
+
    # Extract JSON content from generated text, handling stop tokens and extra content
    def extract_json_content(text)
      # Remove any content after common stop tokens
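The message-injection step is worth seeing in isolation: `inject_tool_instructions` merges the tool prompt into an existing system message rather than adding a second one. The sketch below mirrors that private helper as a standalone function (the `prompt` string is a stand-in for what `build_tool_system_prompt` would produce):

```ruby
# Standalone mirror of inject_tool_instructions from the diff above:
# prepend the tool prompt to an existing system message, or add one.
def inject_tool_instructions(messages, tool_prompt)
  msgs = messages.map { |m| m.dup }
  if msgs.first && msgs.first[:role] == "system"
    msgs.first[:content] = "#{tool_prompt}\n\n#{msgs.first[:content]}"
  else
    msgs.unshift({ role: "system", content: tool_prompt })
  end
  msgs
end

prompt = "You have access to tools."  # stand-in for build_tool_system_prompt output

# Case 1: no system message -> one is prepended
out = inject_tool_instructions([{ role: "user", content: "hi" }], prompt)
puts out.length        # 2
puts out.first[:role]  # system

# Case 2: existing system message -> tool prompt is merged in front
out = inject_tool_instructions([{ role: "system", content: "Be terse." }], prompt)
puts out.length                               # 1
puts out.first[:content].start_with?(prompt)  # true
```

Duplicating the hashes with `m.dup` keeps the helper from mutating the caller's message array, which matters because the agent loop reuses and appends to the same `messages` object across iterations.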
data/lib/candle/tool.rb ADDED
@@ -0,0 +1,47 @@
+ # frozen_string_literal: true
+
+ module Candle
+   class Tool
+     attr_reader :name, :description, :parameters
+
+     def initialize(name:, description:, parameters: {}, &block)
+       @name = name
+       @description = description
+       @parameters = parameters
+       @callable = block
+     end
+
+     def call(arguments)
+       @callable.call(arguments)
+     end
+
+     def to_tool_definition
+       {
+         "type" => "function",
+         "function" => {
+           "name" => @name,
+           "description" => @description,
+           "parameters" => @parameters
+         }
+       }
+     end
+   end
+
+   ToolCall = Struct.new(:name, :arguments, keyword_init: true)
+
+   ToolCallResult = Struct.new(
+     :tool_calls,
+     :tool_results,
+     :text_response,
+     :raw_response,
+     keyword_init: true
+   ) do
+     def has_tool_calls?
+       tool_calls && !tool_calls.empty?
+     end
+
+     def success?
+       tool_results.all? { |r| r[:error].nil? }
+     end
+   end
+ end
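A tool pairs a JSON-Schema-style parameter spec with a Ruby block, and `to_tool_definition` serializes it in the OpenAI-style function format the system prompt embeds. The sketch below re-declares a minimal copy of the `Tool` class from the file above so the output can be inspected without installing the gem; the `get_weather` tool and its return values are made up:

```ruby
# Minimal standalone mirror of Candle::Tool from the file above.
class Tool
  attr_reader :name, :description, :parameters

  def initialize(name:, description:, parameters: {}, &block)
    @name = name
    @description = description
    @parameters = parameters
    @callable = block  # the Ruby block is the tool's implementation
  end

  def call(arguments)
    @callable.call(arguments)
  end

  def to_tool_definition
    { "type" => "function",
      "function" => { "name" => @name, "description" => @description, "parameters" => @parameters } }
  end
end

# A made-up weather tool with a JSON-Schema style parameter spec
weather = Tool.new(
  name: "get_weather",
  description: "Look up current weather for a city",
  parameters: {
    "type" => "object",
    "properties" => { "city" => { "type" => "string" } },
    "required" => ["city"]
  }
) { |args| { "city" => args["city"], "temp_c" => 21 } }

puts weather.to_tool_definition["function"]["name"]  # get_weather
puts weather.call({ "city" => "Berlin" })["temp_c"]  # 21
```

String keys are used throughout the definition hash so `JSON.generate` produces exactly the key names the model is prompted to echo back.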
data/lib/candle/tool_call_parser.rb ADDED
@@ -0,0 +1,57 @@
+ # frozen_string_literal: true
+
+ require "json"
+
+ module Candle
+   class ToolCallParser
+     DEFAULT_PATTERN = /<tool_call>\s*(.*?)\s*<\/tool_call>/m
+
+     attr_reader :pattern
+
+     def initialize(pattern: DEFAULT_PATTERN)
+       @pattern = pattern
+     end
+
+     ParseResult = Struct.new(:text_response, :tool_calls, keyword_init: true) do
+       def has_tool_calls?
+         tool_calls && !tool_calls.empty?
+       end
+     end
+
+     def parse(text, available_tools: [])
+       tool_calls = []
+
+       text.scan(@pattern) do |match|
+         json_str = match[0].strip
+         begin
+           parsed = JSON.parse(json_str)
+           name = parsed["name"]
+           arguments = parsed["arguments"] || parsed["parameters"] || {}
+
+           next unless name
+           if available_tools.empty? || available_tools.any? { |t| t.name == name }
+             tool_calls << ToolCall.new(name: name, arguments: arguments)
+           end
+         rescue JSON::ParserError
+           # Skip malformed tool calls
+         end
+       end
+
+       # Deduplicate identical tool calls (models sometimes repeat the same call)
+       tool_calls.uniq! { |tc| [tc.name, tc.arguments] }
+
+       remaining_text = text.gsub(@pattern, "").strip
+       remaining_text = nil if remaining_text.empty?
+
+       ParseResult.new(
+         text_response: remaining_text,
+         tool_calls: tool_calls
+       )
+     end
+
+     # Convenience class method using the default pattern
+     def self.parse(text, available_tools: [])
+       new.parse(text, available_tools: available_tools)
+     end
+   end
+ end
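The parser's behavior on a typical model response (text plus `<tool_call>` blocks, including a duplicated call) can be exercised standalone. This sketch reproduces the regex, skip-on-malformed-JSON, and dedup logic from the file above in a plain function (`parse_tool_calls` is a made-up name; the sample response text is invented):

```ruby
require "json"

# Same non-greedy, multiline pattern as ToolCallParser::DEFAULT_PATTERN
PATTERN = /<tool_call>\s*(.*?)\s*<\/tool_call>/m

# Minimal standalone re-implementation of ToolCallParser#parse:
# returns [tool_calls, remaining_text].
def parse_tool_calls(text)
  calls = []
  text.scan(PATTERN) do |match|
    begin
      parsed = JSON.parse(match[0].strip)
      calls << { name: parsed["name"], arguments: parsed["arguments"] || {} } if parsed["name"]
    rescue JSON::ParserError
      # malformed tool calls are skipped, mirroring the library
    end
  end
  calls.uniq!  # models sometimes repeat the same call
  remaining = text.gsub(PATTERN, "").strip
  [calls, remaining.empty? ? nil : remaining]
end

# A made-up model response: surrounding text plus one duplicated call
response = <<~TEXT
  Let me check the weather.
  <tool_call>
  {"name": "get_weather", "arguments": {"city": "Berlin"}}
  </tool_call>
  <tool_call>
  {"name": "get_weather", "arguments": {"city": "Berlin"}}
  </tool_call>
TEXT

calls, text = parse_tool_calls(response)
puts calls.length        # 1 (duplicates collapse to one call)
puts calls.first[:name]  # get_weather
puts text                # Let me check the weather.
```

The non-greedy `(.*?)` matters: with two adjacent `<tool_call>` blocks, a greedy match would swallow everything from the first opening tag to the last closing tag as a single malformed call.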
data/lib/candle/version.rb CHANGED
@@ -1,5 +1,5 @@
  # :nocov:
  module Candle
-   VERSION = "1.5.0"
+   VERSION = "1.6.1"
  end
  # :nocov:
data/lib/candle.rb CHANGED
@@ -5,7 +5,10 @@ require_relative "candle/device_utils"
  require_relative "candle/embedding_model_type"
  require_relative "candle/embedding_model"
  require_relative "candle/reranker"
+ require_relative "candle/tool"
+ require_relative "candle/tool_call_parser"
  require_relative "candle/llm"
+ require_relative "candle/agent"
  require_relative "candle/tokenizer"
  require_relative "candle/ner"
  require_relative "candle/build_info"
metadata CHANGED
@@ -1,7 +1,7 @@
  --- !ruby/object:Gem::Specification
  name: red-candle
  version: !ruby/object:Gem::Version
-   version: 1.5.0
+   version: 1.6.1
  platform: ruby
  authors:
  - Christopher Petersen
@@ -254,6 +254,7 @@ files:
  - ext/candle/tests/device_tests.rs
  - ext/candle/tests/tensor_tests.rs
  - lib/candle.rb
+ - lib/candle/agent.rb
  - lib/candle/build_info.rb
  - lib/candle/device_utils.rb
  - lib/candle/embedding_model.rb
@@ -264,6 +265,8 @@ files:
  - lib/candle/reranker.rb
  - lib/candle/tensor.rb
  - lib/candle/tokenizer.rb
+ - lib/candle/tool.rb
+ - lib/candle/tool_call_parser.rb
  - lib/candle/version.rb
  - lib/red-candle.rb
  homepage: https://github.com/scientist-labs/red-candle