red-candle 0.0.5 → 1.0.0.pre.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +2627 -603
- data/Cargo.toml +4 -0
- data/README.md +224 -6
- data/ext/candle/Cargo.toml +20 -9
- data/ext/candle/extconf.rb +77 -1
- data/ext/candle/src/lib.rs +121 -82
- data/lib/candle/build_info.rb +66 -0
- data/lib/candle/device_utils.rb +20 -0
- data/lib/candle/embedding_model.rb +32 -0
- data/lib/candle/embedding_model_type.rb +31 -0
- data/lib/candle/llm.rb +107 -0
- data/lib/candle/reranker.rb +24 -0
- data/lib/candle/tensor.rb +68 -3
- data/lib/candle/version.rb +2 -2
- data/lib/candle.rb +6 -0
- metadata +9 -3
data/lib/candle/build_info.rb
ADDED

```diff
@@ -0,0 +1,66 @@
+module Candle
+  module BuildInfo
+    def self.display_cuda_info
+      info = Candle.build_info
+
+      # Only display CUDA info if running in development or if CANDLE_VERBOSE is set
+      return unless ENV['CANDLE_VERBOSE'] || ENV['CANDLE_DEBUG'] || $DEBUG
+
+      if info["cuda_available"] == false
+        # Check if CUDA could be available on the system
+        cuda_potentially_available = ENV['CUDA_ROOT'] || ENV['CUDA_PATH'] ||
+                                     File.exist?('/usr/local/cuda') || File.exist?('/opt/cuda')
+
+        if cuda_potentially_available
+          warn "=" * 80
+          warn "Red Candle: CUDA detected on system but not enabled in build."
+          warn "To enable CUDA support (experimental), reinstall with:"
+          warn "  CANDLE_ENABLE_CUDA=1 gem install red-candle"
+          warn "=" * 80
+        end
+      end
+    end
+
+    def self.cuda_available?
+      Candle.build_info["cuda_available"]
+    end
+
+    def self.metal_available?
+      Candle.build_info["metal_available"]
+    end
+
+    def self.mkl_available?
+      Candle.build_info["mkl_available"]
+    end
+
+    def self.accelerate_available?
+      Candle.build_info["accelerate_available"]
+    end
+
+    def self.cudnn_available?
+      Candle.build_info["cudnn_available"]
+    end
+
+    def self.summary
+      info = Candle.build_info
+
+      available_backends = []
+      available_backends << "Metal" if info["metal_available"]
+      available_backends << "CUDA" if info["cuda_available"]
+      available_backends << "CPU"
+
+      {
+        default_device: info["default_device"],
+        available_backends: available_backends,
+        cuda_available: info["cuda_available"],
+        metal_available: info["metal_available"],
+        mkl_available: info["mkl_available"],
+        accelerate_available: info["accelerate_available"],
+        cudnn_available: info["cudnn_available"]
+      }
+    end
+  end
+end
+
+# Display CUDA info on load
+Candle::BuildInfo.display_cuda_info
```
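Taken together, `BuildInfo` exposes the compiled backend flags to Ruby code. A minimal usage sketch (the printed hash is illustrative; actual values depend on how the extension was built):

```ruby
require "candle"

# Summarize the backends this build was compiled with.
puts Candle::BuildInfo.summary
# e.g. { default_device: "cpu", available_backends: ["CPU"], cuda_available: false, ... }

# Or branch on an individual backend flag before constructing a device.
device = Candle::BuildInfo.metal_available? ? Candle::Device.metal : Candle::Device.cpu
```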
data/lib/candle/device_utils.rb
ADDED

```diff
@@ -0,0 +1,20 @@
+module Candle
+  module DeviceUtils
+    # Get the best available device (Metal > CUDA > CPU)
+    def self.best_device
+      # Try devices in order of preference
+      begin
+        # Try Metal first (for Mac users)
+        Device.metal
+      rescue
+        begin
+          # Try CUDA next (for NVIDIA GPU users)
+          Device.cuda
+        rescue
+          # Fall back to CPU
+          Device.cpu
+        end
+      end
+    end
+  end
+end
```
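`DeviceUtils.best_device` saves callers from writing that rescue chain themselves. A short sketch of the intended use, pairing it with the model classes added later in this diff:

```ruby
require "candle"

# Resolves to Metal, then CUDA, then CPU, depending on what this build supports.
device = Candle::DeviceUtils.best_device

# All of the new model constructors accept the device keyword.
model = Candle::EmbeddingModel.new(device: device)
```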
data/lib/candle/embedding_model.rb
ADDED

```diff
@@ -0,0 +1,32 @@
+module Candle
+  class EmbeddingModel
+    # Default model path for Jina BERT embedding model
+    DEFAULT_MODEL_PATH = "jinaai/jina-embeddings-v2-base-en"
+
+    # Default tokenizer path that works well with the default model
+    DEFAULT_TOKENIZER_PATH = "sentence-transformers/all-MiniLM-L6-v2"
+
+    # Default embedding model type
+    DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"
+
+    # Constructor for creating a new EmbeddingModel with optional parameters
+    # @param model_path [String, nil] The path to the model on Hugging Face
+    # @param tokenizer_path [String, nil] The path to the tokenizer on Hugging Face
+    # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
+    # @param model_type [String, nil] The type of embedding model to use
+    # @param embedding_size [Integer, nil] Override for the embedding size (optional)
+    def self.new(model_path: DEFAULT_MODEL_PATH,
+                 tokenizer_path: DEFAULT_TOKENIZER_PATH,
+                 device: Candle::Device.cpu,
+                 model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
+                 embedding_size: nil)
+      _create(model_path, tokenizer_path, device, model_type, embedding_size)
+    end
+    # Returns the embedding for a string using the specified pooling method.
+    # @param str [String] The input text
+    # @param pooling_method [String] Pooling method: "pooled", "pooled_normalized", or "cls". Default: "pooled_normalized"
+    def embedding(str, pooling_method: "pooled_normalized")
+      _embedding(str, pooling_method)
+    end
+  end
+end
```
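Because `self.new` forwards keyword arguments to the native `_create`, the common case needs no arguments at all. A sketch (first use downloads the default model from Hugging Face, so network access is assumed):

```ruby
require "candle"

# Uses the defaults: jinaai/jina-embeddings-v2-base-en, jina_bert, CPU.
model = Candle::EmbeddingModel.new

# "pooled_normalized" is the default; "pooled" and "cls" are also accepted.
vector = model.embedding("Ruby bindings for candle")
```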
data/lib/candle/embedding_model_type.rb
ADDED

```diff
@@ -0,0 +1,31 @@
+module Candle
+  # Enum for the supported embedding model types
+  module EmbeddingModelType
+    # Jina Bert embedding models (e.g., jina-embeddings-v2-base-en)
+    JINA_BERT = "jina_bert"
+
+    # Standard BERT embedding models (e.g., bert-base-uncased)
+    STANDARD_BERT = "standard_bert"
+
+    # MiniLM embedding models (e.g., all-MiniLM-L6-v2)
+    MINILM = "minilm"
+
+    # DistilBERT models which can be used for embeddings
+    DISTILBERT = "distilbert"
+
+    # Returns a list of all supported model types
+    def self.all
+      [JINA_BERT, STANDARD_BERT, DISTILBERT, MINILM]
+    end
+
+    # Returns suggested model paths for each model type
+    def self.suggested_model_paths
+      {
+        JINA_BERT => "jinaai/jina-embeddings-v2-base-en",
+        STANDARD_BERT => "scientistcom/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
+        MINILM => "sentence-transformers/all-MiniLM-L6-v2",
+        DISTILBERT => "scientistcom/distilbert-base-uncased-finetuned-sst-2-english",
+      }
+    end
+  end
+end
```
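Since the enum values are plain strings, they compose directly with the `EmbeddingModel` constructor. For example, loading the suggested MiniLM checkpoint (a sketch; downloading assumes network access):

```ruby
require "candle"

type = Candle::EmbeddingModelType::MINILM
path = Candle::EmbeddingModelType.suggested_model_paths[type]

model = Candle::EmbeddingModel.new(
  model_path: path,
  tokenizer_path: path,
  model_type: type
)
```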
data/lib/candle/llm.rb
ADDED
```diff
@@ -0,0 +1,107 @@
+module Candle
+  class LLM
+    # Simple chat interface for instruction models
+    def chat(messages, **options)
+      prompt = format_messages(messages)
+      generate(prompt, **options)
+    end
+
+    # Streaming chat interface
+    def chat_stream(messages, **options, &block)
+      prompt = format_messages(messages)
+      generate_stream(prompt, **options, &block)
+    end
+
+    def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
+      begin
+        _generate(prompt, config)
+      ensure
+        clear_cache if reset_cache
+      end
+    end
+
+    def generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block)
+      begin
+        _generate_stream(prompt, config, &block)
+      ensure
+        clear_cache if reset_cache
+      end
+    end
+
+    def self.from_pretrained(model_id, device: Candle::Device.cpu)
+      _from_pretrained(model_id, device)
+    end
+
+    private
+
+    # Format messages into a prompt string
+    # This is a simple implementation - model-specific formatting should be added
+    def format_messages(messages)
+      formatted = messages.map do |msg|
+        case msg[:role]
+        when "system"
+          "System: #{msg[:content]}"
+        when "user"
+          "User: #{msg[:content]}"
+        when "assistant"
+          "Assistant: #{msg[:content]}"
+        else
+          msg[:content]
+        end
+      end.join("\n\n")
+
+      # Add a prompt for the assistant to respond
+      formatted + "\n\nAssistant:"
+    end
+  end
+
+  class GenerationConfig
+    # Convenience method to create config with overrides
+    def with(**overrides)
+      current_config = {
+        max_length: max_length,
+        temperature: temperature,
+        top_p: top_p,
+        top_k: top_k,
+        repetition_penalty: repetition_penalty,
+        seed: seed,
+        stop_sequences: stop_sequences,
+        include_prompt: include_prompt
+      }.compact
+
+      self.class.new(current_config.merge(overrides))
+    end
+
+    # Create a deterministic configuration (temperature = 0, fixed seed)
+    def self.deterministic(**opts)
+      defaults = {
+        temperature: 0.0,
+        top_p: nil,
+        top_k: 1,
+        seed: 42
+      }
+      new(defaults.merge(opts))
+    end
+
+    # Create a creative configuration (higher temperature, random seed)
+    def self.creative(**opts)
+      defaults = {
+        temperature: 1.0,
+        top_p: 0.95,
+        top_k: 50,
+        repetition_penalty: 1.2
+      }
+      new(defaults.merge(opts))
+    end
+
+    # Create a balanced configuration (moderate temperature, random seed)
+    def self.balanced(**opts)
+      defaults = {
+        temperature: 0.7,
+        top_p: 0.9,
+        top_k: 40
+      }
+      new(defaults.merge(opts))
+    end
+  end
+end
```
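The `LLM` and `GenerationConfig` pieces compose as follows; the model id below is illustrative, and `with` shows the preset-override pattern:

```ruby
require "candle"

# Model id is illustrative; pass any Hugging Face model id the gem supports.
llm = Candle::LLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user",   content: "What is a tensor?" }
]

# chat forwards its options to generate, so presets and overrides both work.
config = Candle::GenerationConfig.balanced.with(max_length: 256)
puts llm.chat(messages, config: config)

# Streaming variant: tokens are yielded to the block as they are generated.
llm.chat_stream(messages, config: config) { |token| print token }
```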
data/lib/candle/reranker.rb
ADDED

```diff
@@ -0,0 +1,24 @@
+module Candle
+  class Reranker
+    # Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
+    DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"
+
+    # Constructor for creating a new Reranker with optional parameters
+    # @param model_path [String, nil] The path to the model on Hugging Face
+    # @param device [Candle::Device, Candle::Device.cpu] The device to use for computation
+    def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.cpu)
+      _create(model_path, device)
+    end
+
+    # Returns relevance scores for each document against the query, using the specified pooling method.
+    # @param query [String] The input text
+    # @param documents [Array<String>] The list of documents to compare against
+    # @param pooling_method [String] Pooling method: "pooler", "cls", or "mean". Default: "pooler"
+    # @param apply_sigmoid [Boolean] Whether to apply sigmoid to the scores. Default: true
+    def rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
+      rerank_with_options(query, documents, pooling_method, apply_sigmoid).collect { |doc, score, doc_id|
+        { doc_id: doc_id, score: score, text: doc }
+      }
+    end
+  end
+end
```
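`rerank` wraps the native `rerank_with_options` and returns one hash per document. A usage sketch (first call downloads the default cross-encoder, so network access is assumed):

```ruby
require "candle"

reranker = Candle::Reranker.new # cross-encoder/ms-marco-MiniLM-L-12-v2 on CPU

results = reranker.rerank(
  "how do I install a ruby gem?",
  ["Use the gem install command.", "Tensors are n-dimensional arrays."]
)

# Each result is { doc_id:, score:, text: }; with apply_sigmoid: true
# the scores fall in (0, 1).
results.each { |r| puts "#{r[:score].round(3)}  #{r[:text]}" }
```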
data/lib/candle/tensor.rb
CHANGED
```diff
@@ -3,15 +3,80 @@ module Candle
     include Enumerable
 
     def each
-      if self.rank == 1
-        self.values.each do |value|
-          yield value
+      case self.rank
+      when 0
+        # Scalar tensor - yield the single value
+        yield self.item
+      when 1
+        # 1D tensor - yield each value
+        # Check if we can use f32 values to avoid conversion
+        if dtype.to_s.downcase == "f32"
+          begin
+            values_f32.each { |value| yield value }
+          rescue NoMethodError
+            # If values_f32 isn't available yet (not recompiled), fall back
+            if device.to_s != "cpu"
+              # Move to CPU to avoid Metal F32->F64 conversion issue
+              to_device(Candle::Device.cpu).values.each { |value| yield value }
+            else
+              values.each { |value| yield value }
+            end
+          end
+        else
+          # For non-F32 dtypes, use regular values
+          values.each { |value| yield value }
         end
       else
+        # Multi-dimensional tensor - yield each sub-tensor
         shape.first.times do |i|
           yield self[i]
         end
       end
     end
+
+    # Convert scalar tensor to float
+    def to_f
+      if rank == 0
+        # Use item method which handles dtype conversion properly
+        item
+      else
+        raise ArgumentError, "to_f can only be called on scalar tensors (rank 0), but this tensor has rank #{rank}"
+      end
+    end
+
+    # Convert scalar tensor to integer
+    def to_i
+      to_f.to_i
+    end
+
+
+    # Override class methods to support keyword arguments for device
+    class << self
+      alias_method :_original_new, :new
+      alias_method :_original_ones, :ones
+      alias_method :_original_zeros, :zeros
+      alias_method :_original_rand, :rand
+      alias_method :_original_randn, :randn
+
+      def new(data, dtype = nil, device: nil)
+        _original_new(data, dtype, device)
+      end
+
+      def ones(shape, device: nil)
+        _original_ones(shape, device)
+      end
+
+      def zeros(shape, device: nil)
+        _original_zeros(shape, device)
+      end
+
+      def rand(shape, device: nil)
+        _original_rand(shape, device)
+      end
+
+      def randn(shape, device: nil)
+        _original_randn(shape, device)
+      end
+    end
   end
 end
```
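The net effect of the tensor.rb changes: `each` now works for every rank, scalars coerce via `to_f`/`to_i`, and the constructors take a `device:` keyword. A brief sketch using only methods visible in the diff:

```ruby
require "candle"

# The device: keyword comes from the class << self overrides above.
t = Candle::Tensor.rand([2, 3], device: Candle::Device.cpu)

# each recurses by rank: a 2-D tensor yields 1-D sub-tensor rows,
# which are themselves Enumerable.
t.each { |row| row.each { |value| puts value } }

# Enumerable methods compose on top of each.
first_row = t.first
puts first_row.shape.inspect # => [3]
```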
data/lib/candle/version.rb
CHANGED
```diff
@@ -1,3 +1,3 @@
 module Candle
-  VERSION = "0.0.5"
-end
+  VERSION = "1.0.0.pre.1"
+end
```
data/lib/candle.rb
CHANGED
```diff
@@ -1,2 +1,8 @@
 require_relative "candle/candle"
 require_relative "candle/tensor"
+require_relative "candle/device_utils"
+require_relative "candle/embedding_model_type"
+require_relative "candle/embedding_model"
+require_relative "candle/reranker"
+require_relative "candle/llm"
+require_relative "candle/build_info"
```
metadata
CHANGED
```diff
@@ -1,7 +1,7 @@
 --- !ruby/object:Gem::Specification
 name: red-candle
 version: !ruby/object:Gem::Version
-  version: 0.0.5
+  version: 1.0.0.pre.1
 platform: ruby
 authors:
 - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2025-
+date: 2025-07-09 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rb_sys
@@ -27,7 +27,7 @@ dependencies:
   version: '0'
 description: huggingface/candle for Ruby
 email:
--
+- chris@petersen.io
 - 2xijok@gmail.com
 executables: []
 extensions:
@@ -41,6 +41,12 @@ files:
 - ext/candle/extconf.rb
 - ext/candle/src/lib.rs
 - lib/candle.rb
+- lib/candle/build_info.rb
+- lib/candle/device_utils.rb
+- lib/candle/embedding_model.rb
+- lib/candle/embedding_model_type.rb
+- lib/candle/llm.rb
+- lib/candle/reranker.rb
 - lib/candle/tensor.rb
 - lib/candle/version.rb
 homepage: https://github.com/assaydepot/red-candle
```
|