red-candle 1.5.0 → 1.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +100 -7
- data/ext/candle/src/llm/llama.rs +3 -3
- data/ext/candle/src/llm/quantized_gguf.rs +8 -5
- data/ext/candle/src/llm/qwen.rs +3 -0
- data/ext/candle/src/ruby/reranker.rs +263 -59
- data/lib/candle/agent.rb +68 -0
- data/lib/candle/llm.rb +64 -2
- data/lib/candle/tool.rb +47 -0
- data/lib/candle/tool_call_parser.rb +57 -0
- data/lib/candle/version.rb +1 -1
- data/lib/candle.rb +3 -0
- metadata +4 -1
checksums.yaml
CHANGED

@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 00c0870d599db76ba556ab880f69219bc32d72698c3570cd308df3923b1332f9
+  data.tar.gz: bffbd02fa1fe11813a304ed16771e6a72ee01f85ccb445eb67c5da6bf696c670
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: c9ab39f01e4777c3d9906592cf62b8b1d2d03fd04be4ba82d8ddee837678ac477be88b1a0a48fa898e443f7469681e26f25297d4d310c39fb5222137d3b17d87
+  data.tar.gz: 22fe87f0fb64754a0a19bcc632dff72b0cc43838d73797aa991c961b65c9ecb7d9bd7aacbde51baf4063011081f4111a64f53726b51c87aca3e48214304e0119
data/README.md
CHANGED

@@ -273,6 +273,85 @@ See [STRUCTURED_GENERATION.md](docs/STRUCTURED_GENERATION.md) for detailed docum
 
 **Note on Reliability**: Structured generation constrains the model's output tokens, but success rates vary by model size and schema complexity. Smaller models (< 7B parameters) may occasionally produce incomplete or invalid JSON, especially with complex schemas. Consider implementing retry logic or fallback strategies in production applications. Larger models generally perform much better with structured generation.
 
+## Tool Calling
+
+Red-candle supports tool/function calling, enabling models to invoke external functions during generation. This works best with models fine-tuned for tool calling, such as Qwen3.
+
+### Defining Tools
+
+```ruby
+get_weather = Candle::Tool.new(
+  name: "get_weather",
+  description: "Get the current weather for a city",
+  parameters: {
+    type: "object",
+    properties: { city: { type: "string", description: "City name" } },
+    required: ["city"]
+  }
+) { |args| { city: args["city"], temperature: 72, condition: "sunny" } }
+```
+
+### Extracting Tool Calls
+
+`chat_with_tools` injects tool definitions into the system prompt, generates a response, and parses any `<tool_call>` tags from the output. It does **not** feed results back to the model — it just tells you what the model wants to call. You decide what to do with it:
+
+```ruby
+llm = Candle::LLM.from_pretrained("Qwen/Qwen3-0.6B")
+
+messages = [{ role: "user", content: "What's the weather in San Francisco?" }]
+result = llm.chat_with_tools(messages, tools: [get_weather],
+                             config: Candle::GenerationConfig.deterministic(max_length: 500))
+
+if result.has_tool_calls?
+  result.tool_calls.each do |tc|
+    puts "#{tc.name}(#{tc.arguments})"
+    output = get_weather.call(tc.arguments)
+    puts "=> #{output}"
+  end
+else
+  puts result.text_response
+end
+```
+
+Pass `execute: true` to automatically run the tools (but still no round-trip back to the model):
+
+```ruby
+result = llm.chat_with_tools(messages, tools: [get_weather], execute: true,
+                             config: Candle::GenerationConfig.deterministic(max_length: 500))
+
+result.tool_results.each do |tr|
+  puts "#{tr[:tool_call].name} => #{tr[:result]}"
+end
+```
+
+### Agent (Multi-Turn Tool Loop)
+
+`Candle::Agent` completes the round-trip: generate → parse tool calls → execute → feed results back to the model → repeat until the model produces a final text answer or hits `max_iterations`. This is a convenience wrapper for quick prototyping — for production use, frameworks like [RubyLLM](https://github.com/crmne/ruby_llm) manage this loop for you via the [ruby_llm-red_candle](https://github.com/scientist-labs/ruby_llm-red_candle) plugin:
+
+```ruby
+agent = Candle::Agent.new(llm, tools: [get_weather, lookup_price], max_iterations: 5)
+result = agent.run("What's the weather in Paris, and how much does a widget cost?",
+                   config: Candle::GenerationConfig.deterministic(max_length: 1000))
+
+puts result.response        # Final text answer from the model
+puts result.iterations      # Number of generate cycles
+puts result.tool_calls_made # Number of tools invoked
+```
+
+### Model Recommendations
+
+Tool calling quality depends heavily on model size:
+
+| Model | Tool Calling Quality |
+|-------|---------------------|
+| **Qwen3-8B GGUF** (~5 GB) | Calls correct tools, self-corrects errors, but may hallucinate values from tool results |
+| **Qwen3-4B GGUF** (~2.5 GB) | Calls correct tools, occasional reasoning errors |
+| **Qwen3-0.6B** (~1.2 GB) | Single-turn works, needs `max_length: 500+` for thinking |
+| SmolLM2-360M | Does not work |
+| TinyLlama-1.1B | Does not work (not fine-tuned for tool calling) |
+
+**Tip:** Qwen3 models use a `<think>` reasoning block before producing tool calls. Set `max_length` high enough (500+ for 0.6B, 1000+ for larger models) to allow room for both thinking and the tool call.
+
 ## ⚠️ Model Format Requirements
 
 ### EmbeddingModels and Rerankers: Safetensors Only

@@ -355,12 +434,26 @@ embedding = model.embedding("Hi there!")
 
 Red-Candle includes support for cross-encoder reranking models, which can be used to reorder documents by relevance to a query. This is particularly useful for improving search results or implementing retrieval-augmented generation (RAG) systems.
 
+### Supported Models
+
+Red-Candle supports both **BERT** and **XLM-RoBERTa** reranker architectures. The model type is auto-detected from `config.json`.
+
+| Model | Architecture | Params | Notes |
+|-------|-------------|--------|-------|
+| `BAAI/bge-reranker-base` | XLM-RoBERTa | 278M | Recommended — strong quality, multilingual |
+| `BAAI/bge-reranker-large` | XLM-RoBERTa | 560M | Best quality, higher resource usage |
+| `BAAI/bge-reranker-v2-m3` | XLM-RoBERTa | 278M | Multilingual, very strong |
+| `cross-encoder/ms-marco-MiniLM-L-12-v2` | BERT | 33M | Lightweight, English only |
+
 ### Basic Usage
 
 ```ruby
 require 'candle'
 
-# Initialize the reranker
+# Initialize the reranker (BGE reranker recommended for quality)
+reranker = Candle::Reranker.from_pretrained("BAAI/bge-reranker-base")
+
+# Or use the lighter BERT-based model
 reranker = Candle::Reranker.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
 
 # Or with custom max_length for truncation (default is 512)

@@ -449,7 +542,7 @@ results = reranker.rerank_with_pooling(query, documents, "cls")
 results = reranker.rerank_sigmoid_with_pooling(query, documents, "cls")
 ```
 
-Note:
+Note: Pooling methods only apply to BERT-based models. XLM-RoBERTa models (e.g., BGE rerankers) have a built-in classification head and ignore the `pooling_method` parameter. For BERT models, the default "pooler" method is recommended as it matches how cross-encoder models are trained.
 
 ### CUDA Support
 

@@ -467,11 +560,11 @@ Cross-encoder reranking models differ from bi-encoder embedding models:
 - **Bi-encoders** (like the embedding models above) encode queries and documents separately into dense vectors
 - **Cross-encoders** process the query and document together, allowing for more nuanced relevance scoring
 
-The reranker
-
-
-
-
+The reranker concatenates the query and document with special tokens and processes them jointly through transformer layers to produce a single relevance score.
+
+**BERT models** (e.g., MiniLM): Use a pooler layer (dense + tanh) on the [CLS] token, then a classifier layer. Pooling method is configurable (`pooler`, `cls`, `mean`).
+
+**XLM-RoBERTa models** (e.g., BGE rerankers): Use a built-in classification head that returns logits directly. The `pooling_method` parameter is ignored — the model handles its own pooling internally.
 
 This joint processing allows cross-encoders to capture subtle semantic relationships between queries and documents, making them more accurate for reranking tasks, though at the cost of higher computational requirements.
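As an aside, the multi-turn loop the Agent section above describes (generate, parse tool calls, execute, feed results back, repeat) can be sketched in plain Ruby. Everything below is illustrative only: `StubModel` and `run_agent` are hypothetical names standing in for an LLM and for what `Candle::Agent` does internally, not part of the gem's API.

```ruby
require "json"

# Stand-in for an LLM: emits a <tool_call> on the first turn,
# then a plain-text final answer once tool results are in the history.
class StubModel
  def initialize
    @turn = 0
  end

  def chat(messages)
    @turn += 1
    if @turn == 1
      '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
    else
      "It is sunny in Paris."
    end
  end
end

# Minimal agent loop: generate -> parse tool call -> execute ->
# append assistant + tool messages -> repeat until a plain answer.
def run_agent(model, tools, user_message, max_iterations: 5)
  messages = [{ role: "user", content: user_message }]
  max_iterations.times do
    reply = model.chat(messages)
    call_json = reply[%r{<tool_call>(.*?)</tool_call>}m, 1]
    return reply unless call_json  # no tool call: this is the final answer

    parsed = JSON.parse(call_json)
    result = tools.fetch(parsed["name"]).call(parsed["arguments"])
    messages << { role: "assistant", content: reply }
    messages << { role: "tool", content: JSON.generate(result) }
  end
  raise "exceeded max iterations"
end

tools = { "get_weather" => ->(args) { { city: args["city"], condition: "sunny" } } }
answer = run_agent(StubModel.new, tools, "Weather in Paris?")
```

The real `Candle::Agent` adds details this sketch omits (system prompts, `<think>` stripping, error capture per tool), but the control flow is the same.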
data/ext/candle/src/llm/llama.rs
CHANGED

@@ -319,9 +319,9 @@ impl Llama {
             if i == 1 || (i == 0 && system_message.is_empty()) {
                 // First user message
                 if !system_message.is_empty() {
-                    prompt.push_str(&format!("
+                    prompt.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
                 } else {
-                    prompt.push_str(&format!("
+                    prompt.push_str(&format!("[INST] {} [/INST]", content));
                 }
             } else {
                 prompt.push_str(&format!(" [INST] {} [/INST]", content));

@@ -340,7 +340,7 @@ impl Llama {
     fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
         let mut prompt = String::new();
 
-        prompt
+        // BOS token is added by the tokenizer's encode(prompt, add_special_tokens=true)
 
         for message in messages {
             let role = message["role"].as_str().unwrap_or("");
data/ext/candle/src/llm/quantized_gguf.rs
CHANGED

@@ -386,9 +386,9 @@ impl QuantizedGGUF {
             "user" => {
                 if i == 1 || (i == 0 && system_message.is_empty()) {
                     if !system_message.is_empty() {
-                        prompt.push_str(&format!("
+                        prompt.push_str(&format!("[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]", system_message, content));
                     } else {
-                        prompt.push_str(&format!("
+                        prompt.push_str(&format!("[INST] {} [/INST]", content));
                     }
                 } else {
                     prompt.push_str(&format!(" [INST] {} [/INST]", content));

@@ -406,7 +406,7 @@ impl QuantizedGGUF {
 
     fn apply_llama3_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
         let mut prompt = String::new();
-        prompt
+        // BOS token is added by the tokenizer's encode(prompt, add_special_tokens=true)
 
         for message in messages {
             let role = message["role"].as_str().unwrap_or("");

@@ -481,15 +481,18 @@ impl QuantizedGGUF {
             "assistant" => {
                 prompt.push_str(&format!("<|im_start|>assistant\n{}<|im_end|>\n", content));
             }
+            "tool" => {
+                prompt.push_str(&format!("<|im_start|>tool\n{}<|im_end|>\n", content));
+            }
             _ => {}
         }
     }
-
+
     // Add generation prompt
     prompt.push_str("<|im_start|>assistant\n");
     Ok(prompt)
 }
-
+
     fn apply_phi_template(&self, messages: &[serde_json::Value]) -> CandleResult<String> {
         let mut prompt = String::new();
 
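The hunks above add a `"tool"` role to the ChatML-style template, so executed tool results can be fed back into the conversation history. The resulting prompt layout can be sketched in Ruby (an illustrative mirror of the Rust `format!` calls above, not gem API):

```ruby
# Sketch of the ChatML layout the template code above emits, including
# the newly supported "tool" role, followed by the generation prompt.
def chatml_prompt(messages)
  body = messages.map { |m| "<|im_start|>#{m[:role]}\n#{m[:content]}<|im_end|>\n" }.join
  body + "<|im_start|>assistant\n"  # generation prompt
end

prompt = chatml_prompt([
  { role: "user", content: "Weather in Paris?" },
  { role: "assistant", content: '<tool_call>{"name":"get_weather"}</tool_call>' },
  { role: "tool", content: '{"condition":"sunny"}' }
])
```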
data/ext/candle/src/ruby/reranker.rs
CHANGED

@@ -1,18 +1,59 @@
 use magnus::{function, method, prelude::*, Error, RModule, RArray, RHash, Ruby};
-use candle_transformers::models::bert::{BertModel, Config};
+use candle_transformers::models::bert::{BertModel, Config as BertConfig};
+use candle_transformers::models::xlm_roberta::{
+    XLMRobertaForSequenceClassification, Config as XLMRobertaConfig,
+};
+use candle_transformers::models::debertav2::{
+    DebertaV2Model, DebertaV2ContextPooler, Config as DebertaV2Config,
+};
+use candle_transformers::models::modernbert::{
+    ModernBert, Config as ModernBertConfig,
+};
+use candle_transformers::models::qwen3::{
+    ModelForCausalLM as Qwen3Model, Config as Qwen3Config,
+};
 use candle_core::{Device as CoreDevice, Tensor, IndexOp, DType};
 use candle_nn::{VarBuilder, Linear, Module, ops::sigmoid};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{EncodeInput, Tokenizer};
+use std::cell::RefCell;
 use crate::ruby::{Device, Result};
 use crate::tokenizer::{TokenizerWrapper, loader::TokenizerLoader};
 
+enum RerankerModel {
+    Bert {
+        model: BertModel,
+        pooler: Linear,
+        classifier: Linear,
+    },
+    XLMRoberta {
+        model: XLMRobertaForSequenceClassification,
+        pad_token_id: u32,
+    },
+    DeBERTa {
+        model: DebertaV2Model,
+        pooler: DebertaV2ContextPooler,
+        classifier: Linear,
+        pad_token_id: u32,
+    },
+    ModernBert {
+        model: ModernBert,
+        head_dense: Linear,
+        head_norm: candle_nn::LayerNorm,
+        classifier: Linear,
+        pad_token_id: u32,
+    },
+    Qwen3 {
+        model: RefCell<Qwen3Model>,
+        yes_token_id: u32,
+        no_token_id: u32,
+    },
+}
+
 #[magnus::wrap(class = "Candle::Reranker", free_immediately, size)]
 pub struct Reranker {
-    model:
+    model: RerankerModel,
     tokenizer: TokenizerWrapper,
-    pooler: Linear,
-    classifier: Linear,
     device: CoreDevice,
     model_id: String,
 }

@@ -28,7 +69,7 @@ impl Reranker {
         let ruby = Ruby::get().unwrap();
         let runtime_error = ruby.exception_runtime_error();
 
-        let result = (|| -> std::result::Result<(
+        let result = (|| -> std::result::Result<(RerankerModel, TokenizerWrapper), Box<dyn std::error::Error + Send + Sync>> {
             let api = Api::new()?;
             let repo = api.repo(Repo::new(model_id.clone(), RepoType::Model));
 

@@ -37,9 +78,10 @@ impl Reranker {
             let tokenizer_filename = repo.get("tokenizer.json")?;
             let weights_filename = repo.get("model.safetensors")?;
 
-            //
-            let
-            let
+            // Read raw config to detect model type
+            let config_str = std::fs::read_to_string(&config_filename)?;
+            let raw_config: serde_json::Value = serde_json::from_str(&config_str)?;
+            let model_type = raw_config["model_type"].as_str().unwrap_or("bert");
 
             // Setup tokenizer with padding AND truncation
             let tokenizer = Tokenizer::from_file(tokenizer_filename)?;

@@ -51,21 +93,71 @@ impl Reranker {
                 VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)?
             };
 
-
-
-
-
-
-
-
-
+            let model = match model_type {
+                "xlm-roberta" => {
+                    let config: XLMRobertaConfig = serde_json::from_str(&config_str)?;
+                    let pad_token_id = config.pad_token_id;
+                    let model = XLMRobertaForSequenceClassification::new(1, &config, vb)?;
+                    RerankerModel::XLMRoberta { model, pad_token_id }
+                }
+                "deberta-v2" => {
+                    let config: DebertaV2Config = serde_json::from_str(&config_str)?;
+                    let pad_token_id = config.pad_token_id.unwrap_or(0) as u32;
+                    let model = DebertaV2Model::load(vb.pp("deberta"), &config)?;
+                    let pooler = DebertaV2ContextPooler::load(vb.clone(), &config)?;
+                    let pooler_hidden_size = config.pooler_hidden_size.unwrap_or(config.hidden_size);
+                    let num_labels = config.id2label.as_ref().map_or(1, |m| m.len());
+                    let classifier = candle_nn::linear(pooler_hidden_size, num_labels, vb.pp("classifier"))?;
+                    RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id }
+                }
+                "qwen3" => {
+                    let config: Qwen3Config = serde_json::from_str(&config_str)?;
+                    let model = Qwen3Model::new(&config, vb)?;
+
+                    // Look up "yes" and "no" token IDs from the tokenizer
+                    let yes_token_id: u32 = tokenizer
+                        .encode("yes", false)
+                        .ok()
+                        .and_then(|enc| enc.get_ids().first().copied())
+                        .unwrap_or(9693);
+                    let no_token_id: u32 = tokenizer
+                        .encode("no", false)
+                        .ok()
+                        .and_then(|enc| enc.get_ids().first().copied())
+                        .unwrap_or(2152);
+
+                    RerankerModel::Qwen3 {
+                        model: RefCell::new(model),
+                        yes_token_id,
+                        no_token_id,
+                    }
+                }
+                "modernbert" => {
+                    let config: ModernBertConfig = serde_json::from_str(&config_str)?;
+                    let pad_token_id = config.pad_token_id;
+                    let model = ModernBert::load(vb.clone(), &config)?;
+                    // ModernBertHead::load is private, so load the head layers manually
+                    let head_vb = vb.pp("head");
+                    let head_dense = candle_nn::linear_no_bias(config.hidden_size, config.hidden_size, head_vb.pp("dense"))?;
+                    let head_norm = candle_nn::layer_norm_no_bias(config.hidden_size, config.layer_norm_eps, head_vb.pp("norm"))?;
+                    let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
+                    RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id }
+                }
+                _ => {
+                    let config: BertConfig = serde_json::from_str(&config_str)?;
+                    let model = BertModel::load(vb.pp("bert"), &config)?;
+                    let pooler = candle_nn::linear(config.hidden_size, config.hidden_size, vb.pp("bert.pooler.dense"))?;
+                    let classifier = candle_nn::linear(config.hidden_size, 1, vb.pp("classifier"))?;
+                    RerankerModel::Bert { model, pooler, classifier }
+                }
+            };
 
-            Ok((model, TokenizerWrapper::new(tokenizer)
+            Ok((model, TokenizerWrapper::new(tokenizer)))
         })();
 
         match result {
-            Ok((model, tokenizer
-            Ok(Self { model, tokenizer,
+            Ok((model, tokenizer)) => {
+                Ok(Self { model, tokenizer, device, model_id })
+            }
             Err(e) => Err(Error::new(runtime_error, format!("Failed to load model: {}", e))),
         }

@@ -164,49 +256,161 @@ impl Reranker {
             .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?;
         let token_type_ids = Tensor::new(token_type_ids, &self.device)
             .map_err(|e| Error::new(runtime_error, format!("Failed to create token type ids tensor: {}", e)))?;
-
-
-
-
-
-
-
-
-
-
-
-
-        let
+
+        // Compute scores based on model type
+        let scores = match &self.model {
+            RerankerModel::Bert { model, pooler, classifier } => {
+                let attention_mask = token_ids.ne(0u32)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+                // Forward pass through BERT
+                let embeddings = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))
+                    .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+                // Apply pooling based on the specified method
+                let pooled_embeddings = match pooling_method.as_str() {
+                    "pooler" => {
+                        let cls_embeddings = self.extract_cls_embeddings(&embeddings)?;
+                        let pooled = pooler.forward(&cls_embeddings)
+                            .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
+                        pooled.tanh()
+                            .map_err(|e| Error::new(runtime_error, format!("Tanh activation failed: {}", e)))?
+                    },
+                    "cls" => {
+                        self.extract_cls_embeddings(&embeddings)?
+                    },
+                    "mean" => {
+                        let (_batch, seq_len, _hidden) = embeddings.dims3()
                            .map_err(|e| Error::new(runtime_error, format!("Failed to get tensor dimensions: {}", e)))?;
+                        let sum = embeddings.sum(1)
+                            .map_err(|e| Error::new(runtime_error, format!("Failed to sum embeddings: {}", e)))?;
+                        (sum / (seq_len as f64))
+                            .map_err(|e| Error::new(runtime_error, format!("Failed to compute mean: {}", e)))?
+                    },
+                    _ => return Err(Error::new(runtime_error,
+                        format!("Unknown pooling method: {}. Use 'pooler', 'cls', or 'mean'", pooling_method)))
+                };
+
+                let pooled_embeddings = pooled_embeddings.contiguous()
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
+                let logits = classifier.forward(&pooled_embeddings)
+                    .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+                logits.squeeze(1)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+            }
+            RerankerModel::XLMRoberta { model, pad_token_id } => {
+                let attention_mask = token_ids.ne(*pad_token_id)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+                // XLMRobertaForSequenceClassification returns logits directly
+                let logits = model.forward(&token_ids, &attention_mask, &token_type_ids)
+                    .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+                logits.squeeze(1)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+            }
+            RerankerModel::DeBERTa { model, pooler, classifier, pad_token_id } => {
+                let attention_mask = token_ids.ne(*pad_token_id)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+
+                // Forward through DeBERTa encoder
+                let encoder_output = model.forward(&token_ids, Some(token_type_ids.clone()), Some(attention_mask))
+                    .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+                // Pool and classify
+                let pooled = pooler.forward(&encoder_output)
                    .map_err(|e| Error::new(runtime_error, format!("Pooler forward failed: {}", e)))?;
+                let logits = classifier.forward(&pooled)
+                    .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+                logits.squeeze(1)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+            }
+            RerankerModel::ModernBert { model, head_dense, head_norm, classifier, pad_token_id } => {
+                let attention_mask = token_ids.ne(*pad_token_id)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to create attention mask: {}", e)))?;
+                let attention_mask_f32 = attention_mask.to_dtype(DType::F32)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to convert attention mask: {}", e)))?;
+
+                // Forward through ModernBERT encoder
+                let encoder_output = model.forward(&token_ids, &attention_mask_f32)
+                    .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+                // CLS pooling, then head (dense + GELU + norm) + classifier
+                let cls = encoder_output.i((.., 0, ..))
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to extract CLS: {}", e)))?
+                    .contiguous()
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to make contiguous: {}", e)))?;
+                let hidden = head_dense.forward(&cls)
+                    .map_err(|e| Error::new(runtime_error, format!("Head dense failed: {}", e)))?;
+                let hidden = hidden.gelu_erf()
+                    .map_err(|e| Error::new(runtime_error, format!("GELU activation failed: {}", e)))?;
+                let hidden = head_norm.forward(&hidden)
+                    .map_err(|e| Error::new(runtime_error, format!("Head norm failed: {}", e)))?;
+                let logits = classifier.forward(&hidden)
+                    .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
+                logits.squeeze(1)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?
+            }
+            RerankerModel::Qwen3 { model, yes_token_id, no_token_id } => {
+                // Qwen3 reranker: decoder-based yes/no scoring
+                // Process each document individually (causal LM, not batch encoder)
+                let mut scores_vec: Vec<f32> = Vec::with_capacity(documents.len());
+                let mut model = model.borrow_mut();
+
+                for doc in &documents {
+                    // Build the Qwen3 reranker prompt
+                    let prompt = format!(
+                        "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n<Instruct>: Given a web search query, retrieve relevant passages that answer the query\n<Query>: {}\n<Document>: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n",
                        query, doc
+                    );
+
+                    // Tokenize the prompt
+                    let encoding = self.tokenizer.inner().encode(prompt.as_str(), false)
+                        .map_err(|e| Error::new(runtime_error, format!("Tokenization failed: {}", e)))?;
+                    let input_ids: Vec<u32> = encoding.get_ids().to_vec();
+
+                    // Clear KV cache for each document
+                    model.clear_kv_cache();
+
+                    // Forward pass — get logits for the last token position
+                    let input_tensor = Tensor::new(&input_ids[..], &self.device)
+                        .map_err(|e| Error::new(runtime_error, format!("Failed to create tensor: {}", e)))?
+                        .unsqueeze(0)
+                        .map_err(|e| Error::new(runtime_error, format!("Failed to unsqueeze: {}", e)))?;
+
+                    let logits = model.forward(&input_tensor, 0)
+                        .map_err(|e| Error::new(runtime_error, format!("Model forward pass failed: {}", e)))?;
+
+                    // logits shape: [1, 1, vocab_size] → flatten to [vocab_size]
+                    let logits = logits.flatten_all()
+                        .map_err(|e| Error::new(runtime_error, format!("Failed to flatten: {}", e)))?
+                        .to_dtype(DType::F32)
+                        .map_err(|e| Error::new(runtime_error, format!("Failed to convert dtype: {}", e)))?;
+
+                    // Extract yes/no logits and compute score
+                    let yes_logit: f32 = logits.i(*yes_token_id as usize)
+                        .map_err(|e| Error::new(runtime_error, format!("Failed to get yes logit: {}", e)))?
+                        .to_scalar()
+                        .map_err(|e| Error::new(runtime_error, format!("Failed to convert yes logit: {}", e)))?;
+                    let no_logit: f32 = logits.i(*no_token_id as usize)
+                        .map_err(|e| Error::new(runtime_error, format!("Failed to get no logit: {}", e)))?
+                        .to_scalar()
+                        .map_err(|e| Error::new(runtime_error, format!("Failed to convert no logit: {}", e)))?;
+
+                    // softmax over [yes, no] → P(yes)
+                    let max_logit = yes_logit.max(no_logit);
+                    let yes_exp = (yes_logit - max_logit).exp();
+                    let no_exp = (no_logit - max_logit).exp();
+                    let score = yes_exp / (yes_exp + no_exp);
+
+                    scores_vec.push(score);
+                }
+
+                // Build scores tensor for uniform handling below
+                Tensor::new(scores_vec.as_slice(), &self.device)
+                    .map_err(|e| Error::new(runtime_error, format!("Failed to create scores tensor: {}", e)))?
+            }
         };
 
-        // Apply classifier to get relevance scores (raw logits)
-        // Ensure tensor is contiguous before linear layer
-        let pooled_embeddings = pooled_embeddings.contiguous()
-            .map_err(|e| Error::new(runtime_error, format!("Failed to make pooled_embeddings contiguous: {}", e)))?;
-        let logits = self.classifier.forward(&pooled_embeddings)
-            .map_err(|e| Error::new(runtime_error, format!("Classifier forward failed: {}", e)))?;
-        let scores = logits.squeeze(1)
-            .map_err(|e| Error::new(runtime_error, format!("Failed to squeeze tensor: {}", e)))?;
-
         // Optionally apply sigmoid activation
         let scores = if apply_sigmoid {
             sigmoid(&scores)
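The Qwen3 reranker branch above scores each document as a two-way softmax over the "yes" and "no" token logits, with the usual max subtraction for numerical stability. The arithmetic in isolation, as an illustrative Ruby sketch (not gem API):

```ruby
# P(yes) from the yes/no logits: a 2-class softmax. Subtracting the max
# logit before exponentiating avoids overflow without changing the result.
def yes_probability(yes_logit, no_logit)
  m = [yes_logit, no_logit].max
  yes_exp = Math.exp(yes_logit - m)
  no_exp  = Math.exp(no_logit - m)
  yes_exp / (yes_exp + no_exp)
end

score = yes_probability(2.0, 0.0)  # ~0.88: "yes" clearly favored
```

Equal logits give exactly 0.5, and the score approaches 1.0 as the "yes" logit pulls ahead, so it behaves like the sigmoid-activated relevance scores the other reranker architectures produce.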
data/lib/candle/agent.rb ADDED
@@ -0,0 +1,68 @@
+# frozen_string_literal: true
+
+require "json"
+
+module Candle
+  class Agent
+    MAX_ITERATIONS = 10
+
+    attr_reader :llm, :tools, :system_prompt, :max_iterations
+
+    def initialize(llm, tools:, system_prompt: nil, max_iterations: MAX_ITERATIONS)
+      @llm = llm
+      @tools = tools
+      @system_prompt = system_prompt
+      @max_iterations = max_iterations
+    end
+
+    def run(user_message, **options)
+      messages = []
+      messages << { role: "system", content: @system_prompt } if @system_prompt
+      messages << { role: "user", content: user_message }
+
+      iterations = 0
+      loop do
+        iterations += 1
+        if iterations > @max_iterations
+          raise AgentMaxIterationsError,
+                "Agent exceeded maximum iterations (#{@max_iterations})"
+        end
+
+        result = @llm.chat_with_tools(messages, tools: @tools, execute: true, **options)
+
+        if result.has_tool_calls?
+          # If the model produced a substantial text answer alongside tool calls,
+          # treat it as a final response (model is done, trailing tool calls are noise).
+          # Strip <think> blocks so they don't count toward the length check.
+          text_without_thinking = result.text_response&.gsub(/<think>.*?<\/think>/m, "")&.strip
+          if text_without_thinking && text_without_thinking.length > 50
+            return AgentResult.new(
+              response: result.text_response,
+              messages: messages,
+              iterations: iterations,
+              tool_calls_made: messages.count { |m| m[:role] == "tool" }
+            )
+          end
+
+          messages << { role: "assistant", content: result.raw_response }
+
+          result.tool_results.each do |tr|
+            tool_name = tr[:tool_call]&.name || "unknown"
+            tool_output = tr[:error] ? "Error: #{tr[:error]}" : JSON.generate(tr[:result])
+            messages << { role: "tool", content: "[#{tool_name}] #{tool_output}" }
+          end
+        else
+          return AgentResult.new(
+            response: result.text_response || result.raw_response,
+            messages: messages,
+            iterations: iterations,
+            tool_calls_made: messages.count { |m| m[:role] == "tool" }
+          )
+        end
+      end
+    end
+  end
+
+  AgentResult = Struct.new(:response, :messages, :iterations, :tool_calls_made, keyword_init: true)
+  AgentMaxIterationsError = Class.new(StandardError)
+end
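The subtlest part of `Agent#run` is the finality check: a model may emit a real answer and then tack on spurious tool calls, so the agent strips `<think>` blocks and treats any response with more than 50 visible characters as final. That check is plain Ruby and can be exercised on its own; in this sketch the `final_answer?` helper name and the sample responses are our own, only the regex and threshold come from the diff.

```ruby
# Standalone sketch of Agent#run's finality heuristic (regex and threshold
# copied from the diff; helper name and sample strings are invented).
FINAL_LENGTH = 50

def final_answer?(text_response)
  # Strip <think> blocks so hidden reasoning doesn't count toward the length check.
  visible = text_response&.gsub(/<think>.*?<\/think>/m, "")&.strip
  !!(visible && visible.length > FINAL_LENGTH)
end

thinking_only = "<think>I should call the weather tool first.</think>"
real_answer   = "<think>done</think>The weather in Paris is 18C with light rain expected through the evening."

final_answer?(thinking_only)  # strips to empty text, so not final
final_answer?(real_answer)    # visible answer exceeds 50 chars, so final
```

The safe navigation (`&.`) matters because `text_response` is `nil` whenever the model emitted only tool calls, in which case the loop proceeds to execute them.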
data/lib/candle/llm.rb CHANGED
@@ -252,7 +252,7 @@ module Candle
       base_model
     end
 
-    #
+    # Chat interface — always returns a String
     def chat(messages, **options)
       prompt = apply_chat_template(messages)
       generate(prompt, **options)
@@ -263,7 +263,48 @@ module Candle
       prompt = apply_chat_template(messages)
       generate_stream(prompt, **options, &block)
     end
-
+
+    # Chat with tool calling — always returns a ToolCallResult
+    # Set execute: true to automatically run the tools (default: false)
+    def chat_with_tools(messages, tools:, execute: false, **options)
+      tool_prompt = build_tool_system_prompt(tools)
+      augmented = inject_tool_instructions(messages, tool_prompt)
+
+      raw_response = chat(augmented, **options)
+
+      result = ToolCallParser.parse(raw_response, available_tools: tools)
+
+      if result.has_tool_calls? && execute
+        tool_results = result.tool_calls.map do |tool_call|
+          tool = tools.find { |t| t.name == tool_call.name }
+          unless tool
+            next { tool_call: tool_call, result: nil, error: "Unknown tool: #{tool_call.name}" }
+          end
+
+          begin
+            output = tool.call(tool_call.arguments)
+            { tool_call: tool_call, result: output, error: nil }
+          rescue Exception => e
+            { tool_call: tool_call, result: nil, error: e.message }
+          end
+        end
+
+        ToolCallResult.new(
+          tool_calls: result.tool_calls,
+          tool_results: tool_results,
+          text_response: result.text_response,
+          raw_response: raw_response
+        )
+      else
+        ToolCallResult.new(
+          tool_calls: result.tool_calls,
+          tool_results: [],
+          text_response: result.has_tool_calls? ? result.text_response : raw_response,
+          raw_response: raw_response
+        )
+      end
+    end
+
     # Inspect method for debugging and exploration
     def inspect
       opts = options rescue {}
@@ -354,6 +395,27 @@ module Candle
 
     private
 
+    def build_tool_system_prompt(tools)
+      tool_defs = tools.map { |t| JSON.generate(t.to_tool_definition) }.join("\n\n")
+      "You are a helpful assistant with access to the following tools:\n\n" \
+        "#{tool_defs}\n\n" \
+        "When you need to use a tool, respond with a tool call in the following format:\n" \
+        "<tool_call>\n" \
+        "{\"name\": \"tool_name\", \"arguments\": {\"arg1\": \"value1\"}}\n" \
+        "</tool_call>\n\n" \
+        "If you don't need to use a tool, respond normally with text."
+    end
+
+    def inject_tool_instructions(messages, tool_prompt)
+      msgs = messages.map { |m| m.dup }
+      if msgs.first && msgs.first[:role] == "system"
+        msgs.first[:content] = "#{tool_prompt}\n\n#{msgs.first[:content]}"
+      else
+        msgs.unshift({ role: "system", content: tool_prompt })
+      end
+      msgs
+    end
+
     # Extract JSON content from generated text, handling stop tokens and extra content
     def extract_json_content(text)
       # Remove any content after common stop tokens
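`inject_tool_instructions` is the piece that decides where the tool prompt lands: prepended to an existing system message, or pushed as a new first message. The helper is self-contained Ruby, so this sketch copies its logic out of the diff and drives both branches with invented messages.

```ruby
# Standalone copy of LLM#inject_tool_instructions from the diff: prepend the
# tool prompt to an existing system message, or unshift a new one.
def inject_tool_instructions(messages, tool_prompt)
  msgs = messages.map { |m| m.dup }
  if msgs.first && msgs.first[:role] == "system"
    msgs.first[:content] = "#{tool_prompt}\n\n#{msgs.first[:content]}"
  else
    msgs.unshift({ role: "system", content: tool_prompt })
  end
  msgs
end

tool_prompt = "You have access to tools."

# Branch 1: caller already supplied a system message.
with_system = inject_tool_instructions(
  [{ role: "system", content: "Be terse." }, { role: "user", content: "Hi" }],
  tool_prompt
)

# Branch 2: no system message, one gets added up front.
without_system = inject_tool_instructions(
  [{ role: "user", content: "Hi" }],
  tool_prompt
)
```

Note the `dup` on each message: the caller's array is left untouched, which matters because `Agent#run` reuses the same `messages` array across iterations.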
data/lib/candle/tool.rb ADDED
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+module Candle
+  class Tool
+    attr_reader :name, :description, :parameters
+
+    def initialize(name:, description:, parameters: {}, &block)
+      @name = name
+      @description = description
+      @parameters = parameters
+      @callable = block
+    end
+
+    def call(arguments)
+      @callable.call(arguments)
+    end
+
+    def to_tool_definition
+      {
+        "type" => "function",
+        "function" => {
+          "name" => @name,
+          "description" => @description,
+          "parameters" => @parameters
+        }
+      }
+    end
+  end
+
+  ToolCall = Struct.new(:name, :arguments, keyword_init: true)
+
+  ToolCallResult = Struct.new(
+    :tool_calls,
+    :tool_results,
+    :text_response,
+    :raw_response,
+    keyword_init: true
+  ) do
+    def has_tool_calls?
+      tool_calls && !tool_calls.empty?
+    end
+
+    def success?
+      tool_results.all? { |r| r[:error].nil? }
+    end
+  end
+end
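A `Candle::Tool` is just a name, a description, a JSON-Schema-style parameters hash, and a block to execute. The sketch below re-creates the class outside the gem so the call/definition round trip can be run directly; the `add` tool itself is a hypothetical example.

```ruby
require "json"

# Minimal standalone re-creation of Candle::Tool from the diff (no gem needed).
class Tool
  attr_reader :name, :description, :parameters

  def initialize(name:, description:, parameters: {}, &block)
    @name = name
    @description = description
    @parameters = parameters
    @callable = block
  end

  def call(arguments)
    @callable.call(arguments)
  end

  def to_tool_definition
    { "type" => "function",
      "function" => { "name" => @name, "description" => @description, "parameters" => @parameters } }
  end
end

# Hypothetical calculator tool for illustration.
add = Tool.new(
  name: "add",
  description: "Add two integers",
  parameters: { "type" => "object",
                "properties" => { "a" => { "type" => "integer" }, "b" => { "type" => "integer" } } }
) { |args| args["a"] + args["b"] }

add.call("a" => 2, "b" => 3)              # => 5
JSON.generate(add.to_tool_definition)      # the JSON that build_tool_system_prompt embeds
```

The string keys in `to_tool_definition` are deliberate: the hash is serialized with `JSON.generate` into the system prompt, so the definition already looks like the OpenAI-style function schema the model is asked to echo back.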
data/lib/candle/tool_call_parser.rb ADDED
@@ -0,0 +1,57 @@
+# frozen_string_literal: true
+
+require "json"
+
+module Candle
+  class ToolCallParser
+    DEFAULT_PATTERN = /<tool_call>\s*(.*?)\s*<\/tool_call>/m
+
+    attr_reader :pattern
+
+    def initialize(pattern: DEFAULT_PATTERN)
+      @pattern = pattern
+    end
+
+    ParseResult = Struct.new(:text_response, :tool_calls, keyword_init: true) do
+      def has_tool_calls?
+        tool_calls && !tool_calls.empty?
+      end
+    end
+
+    def parse(text, available_tools: [])
+      tool_calls = []
+
+      text.scan(@pattern) do |match|
+        json_str = match[0].strip
+        begin
+          parsed = JSON.parse(json_str)
+          name = parsed["name"]
+          arguments = parsed["arguments"] || parsed["parameters"] || {}
+
+          next unless name
+          if available_tools.empty? || available_tools.any? { |t| t.name == name }
+            tool_calls << ToolCall.new(name: name, arguments: arguments)
+          end
+        rescue JSON::ParserError
+          # Skip malformed tool calls
+        end
+      end
+
+      # Deduplicate identical tool calls (models sometimes repeat the same call)
+      tool_calls.uniq! { |tc| [tc.name, tc.arguments] }
+
+      remaining_text = text.gsub(@pattern, "").strip
+      remaining_text = nil if remaining_text.empty?
+
+      ParseResult.new(
+        text_response: remaining_text,
+        tool_calls: tool_calls
+      )
+    end
+
+    # Convenience class method using the default pattern
+    def self.parse(text, available_tools: [])
+      new.parse(text, available_tools: available_tools)
+    end
+  end
+end
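The parser's behavior on messy model output (repeated calls, malformed JSON, surrounding prose) is easiest to see in isolation. This sketch copies the regex, the `arguments`/`parameters` fallback, and the dedup step out of the diff into a standalone `parse` function; the sample response text is invented.

```ruby
require "json"

# Standalone sketch of ToolCallParser#parse's core loop (regex, fallbacks, and
# dedup copied from the diff; ToolCall simplified to a local Struct).
PATTERN = /<tool_call>\s*(.*?)\s*<\/tool_call>/m
ToolCall = Struct.new(:name, :arguments, keyword_init: true)

def parse(text)
  calls = []
  text.scan(PATTERN) do |match|
    begin
      parsed = JSON.parse(match[0].strip)
      name = parsed["name"]
      next unless name
      calls << ToolCall.new(name: name, arguments: parsed["arguments"] || parsed["parameters"] || {})
    rescue JSON::ParserError
      # malformed call: skip it, exactly as the gem does
    end
  end
  # Drop repeated identical calls, keeping the first occurrence.
  calls.uniq! { |tc| [tc.name, tc.arguments] }
  [text.gsub(PATTERN, "").strip, calls]
end

text = <<~RESPONSE
  Let me check that.
  <tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>
  <tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>
  <tool_call>not json</tool_call>
RESPONSE

remaining, calls = parse(text)
calls.length   # duplicate and malformed calls dropped, one call left
remaining      # "Let me check that."
```

The non-greedy `(.*?)` with the `/m` modifier is what keeps each match confined to one `<tool_call>` pair even when several appear in a row, and the `rescue JSON::ParserError` makes a single garbled call harmless rather than fatal.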
data/lib/candle/version.rb CHANGED
data/lib/candle.rb CHANGED
@@ -5,7 +5,10 @@ require_relative "candle/device_utils"
 require_relative "candle/embedding_model_type"
 require_relative "candle/embedding_model"
 require_relative "candle/reranker"
+require_relative "candle/tool"
+require_relative "candle/tool_call_parser"
 require_relative "candle/llm"
+require_relative "candle/agent"
 require_relative "candle/tokenizer"
 require_relative "candle/ner"
 require_relative "candle/build_info"
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: red-candle
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 1.
|
|
4
|
+
version: 1.6.1
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Christopher Petersen
|
|
@@ -254,6 +254,7 @@ files:
|
|
|
254
254
|
- ext/candle/tests/device_tests.rs
|
|
255
255
|
- ext/candle/tests/tensor_tests.rs
|
|
256
256
|
- lib/candle.rb
|
|
257
|
+
- lib/candle/agent.rb
|
|
257
258
|
- lib/candle/build_info.rb
|
|
258
259
|
- lib/candle/device_utils.rb
|
|
259
260
|
- lib/candle/embedding_model.rb
|
|
@@ -264,6 +265,8 @@ files:
|
|
|
264
265
|
- lib/candle/reranker.rb
|
|
265
266
|
- lib/candle/tensor.rb
|
|
266
267
|
- lib/candle/tokenizer.rb
|
|
268
|
+
- lib/candle/tool.rb
|
|
269
|
+
- lib/candle/tool_call_parser.rb
|
|
267
270
|
- lib/candle/version.rb
|
|
268
271
|
- lib/red-candle.rb
|
|
269
272
|
homepage: https://github.com/scientist-labs/red-candle
|