red-candle 0.0.5 → 1.0.0.pre.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,66 @@
1
module Candle
  # Introspection helpers describing how the native extension was built.
  #
  # Every method here reads the hash returned by +Candle.build_info+, which is
  # indexed with string keys such as "cuda_available" and "default_device".
  module BuildInfo
    # Environment variables whose presence suggests a CUDA toolkit is installed.
    CUDA_ENV_VARS = %w[CUDA_ROOT CUDA_PATH].freeze
    # Filesystem locations whose presence suggests a CUDA toolkit is installed.
    CUDA_INSTALL_DIRS = %w[/usr/local/cuda /opt/cuda].freeze
    private_constant :CUDA_ENV_VARS, :CUDA_INSTALL_DIRS

    # Prints a hint to stderr when CUDA appears to be installed on the machine
    # but the extension was built without it.
    #
    # The hint is opt-in: nothing happens unless CANDLE_VERBOSE or CANDLE_DEBUG
    # is set in the environment, or Ruby runs with $DEBUG enabled.
    #
    # @return [nil]
    def self.display_cuda_info
      # Check the opt-in switches first so the native build_info lookup is
      # skipped entirely in the (default) quiet case.
      return unless ENV['CANDLE_VERBOSE'] || ENV['CANDLE_DEBUG'] || $DEBUG

      # Only an explicit false means "built without CUDA"; a missing (nil)
      # entry must not trigger the hint.
      return unless Candle.build_info["cuda_available"] == false
      return unless cuda_toolkit_detected?

      banner = "=" * 80
      warn banner
      warn "Red Candle: CUDA detected on system but not enabled in build."
      warn "To enable CUDA support (experimental), reinstall with:"
      warn " CANDLE_ENABLE_CUDA=1 gem install red-candle"
      warn banner
    end

    # @return the "cuda_available" entry of Candle.build_info
    def self.cuda_available?
      Candle.build_info["cuda_available"]
    end

    # @return the "metal_available" entry of Candle.build_info
    def self.metal_available?
      Candle.build_info["metal_available"]
    end

    # @return the "mkl_available" entry of Candle.build_info
    def self.mkl_available?
      Candle.build_info["mkl_available"]
    end

    # @return the "accelerate_available" entry of Candle.build_info
    def self.accelerate_available?
      Candle.build_info["accelerate_available"]
    end

    # @return the "cudnn_available" entry of Candle.build_info
    def self.cudnn_available?
      Candle.build_info["cudnn_available"]
    end

    # Collects the build information into a single symbol-keyed hash.
    #
    # :available_backends lists "Metal" and/or "CUDA" when the corresponding
    # flag is truthy and always ends with "CPU".
    #
    # @return [Hash{Symbol => Object}]
    def self.summary
      info = Candle.build_info

      available_backends = []
      available_backends << "Metal" if info["metal_available"]
      available_backends << "CUDA" if info["cuda_available"]
      available_backends << "CPU"

      {
        default_device: info["default_device"],
        available_backends: available_backends,
        cuda_available: info["cuda_available"],
        metal_available: info["metal_available"],
        mkl_available: info["mkl_available"],
        accelerate_available: info["accelerate_available"],
        cudnn_available: info["cudnn_available"]
      }
    end

    # True when an environment variable or a conventional install directory
    # suggests that a CUDA toolkit is present on this machine.
    def self.cuda_toolkit_detected?
      CUDA_ENV_VARS.any? { |name| ENV[name] } ||
        CUDA_INSTALL_DIRS.any? { |dir| File.exist?(dir) }
    end
    private_class_method :cuda_toolkit_detected?
  end
end

# Display CUDA info on load (silent unless one of the opt-in switches is set)
Candle::BuildInfo.display_cuda_info
@@ -0,0 +1,20 @@
1
module Candle
  # Helpers for picking a compute device at runtime.
  module DeviceUtils
    # Returns the first device that can be constructed, trying Metal, then
    # CUDA, and finally falling back to the CPU.
    #
    # Any StandardError raised while constructing an accelerator device is
    # treated as "not available here" and the next candidate is tried. An
    # error from the CPU fallback itself is not rescued.
    def self.best_device
      %i[metal cuda].each do |accelerator|
        begin
          return Device.public_send(accelerator)
        rescue StandardError
          next
        end
      end

      Device.cpu
    end
  end
end
@@ -0,0 +1,32 @@
1
module Candle
  # Ruby-side conveniences for EmbeddingModel: keyword-argument wrappers
  # around the underscore-prefixed methods (+_create+, +_embedding+), which
  # are defined outside this file.
  class EmbeddingModel
    # Hugging Face path of the model used when none is given (Jina BERT).
    DEFAULT_MODEL_PATH = "jinaai/jina-embeddings-v2-base-en"

    # Hugging Face path of the tokenizer paired with the default model.
    DEFAULT_TOKENIZER_PATH = "sentence-transformers/all-MiniLM-L6-v2"

    # Model type identifier used when none is given.
    DEFAULT_EMBEDDING_MODEL_TYPE = "jina_bert"

    class << self
      # Builds an EmbeddingModel; every argument is optional.
      #
      # @param model_path [String] model path on Hugging Face
      # @param tokenizer_path [String] tokenizer path on Hugging Face
      # @param device [Candle::Device] device to compute on (defaults to Candle::Device.cpu)
      # @param model_type [String] type of embedding model to load
      # @param embedding_size [Integer, nil] optional override for the embedding size
      # @return whatever +_create+ returns
      def new(model_path: DEFAULT_MODEL_PATH,
              tokenizer_path: DEFAULT_TOKENIZER_PATH,
              device: Candle::Device.cpu,
              model_type: DEFAULT_EMBEDDING_MODEL_TYPE,
              embedding_size: nil)
        _create(model_path, tokenizer_path, device, model_type, embedding_size)
      end
    end

    # Computes the embedding of a piece of text.
    #
    # @param str [String] the input text
    # @param pooling_method [String] "pooled", "pooled_normalized" or "cls"
    # @return whatever +_embedding+ returns for the text and pooling method
    def embedding(str, pooling_method: "pooled_normalized")
      _embedding(str, pooling_method)
    end
  end
end
@@ -0,0 +1,31 @@
1
module Candle
  # Enum of the supported embedding model types.
  #
  # The identifiers are frozen so that code holding a reference to one of
  # these shared constants cannot mutate it in place for every other caller.
  module EmbeddingModelType
    # Jina Bert embedding models (e.g., jina-embeddings-v2-base-en)
    JINA_BERT = "jina_bert".freeze

    # Standard BERT embedding models (e.g., bert-base-uncased)
    STANDARD_BERT = "standard_bert".freeze

    # MiniLM embedding models (e.g., all-MiniLM-L6-v2)
    MINILM = "minilm".freeze

    # DistilBERT models which can be used for embeddings
    DISTILBERT = "distilbert".freeze

    # Lists every supported model type.
    #
    # A new array is built on each call, so callers may modify the result.
    #
    # @return [Array<String>]
    def self.all
      [JINA_BERT, STANDARD_BERT, DISTILBERT, MINILM]
    end

    # Maps each model type to a suggested Hugging Face model path.
    #
    # A new hash is built on each call, so callers may modify the result.
    #
    # @return [Hash{String => String}]
    def self.suggested_model_paths
      {
        JINA_BERT => "jinaai/jina-embeddings-v2-base-en",
        STANDARD_BERT => "scientistcom/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
        MINILM => "sentence-transformers/all-MiniLM-L6-v2",
        DISTILBERT => "scientistcom/distilbert-base-uncased-finetuned-sst-2-english",
      }
    end
  end
end
data/lib/candle/llm.rb ADDED
@@ -0,0 +1,107 @@
1
+ module Candle
2
+ class LLM
3
+ # Simple chat interface for instruction models
4
+ def chat(messages, **options)
5
+ prompt = format_messages(messages)
6
+ generate(prompt, **options)
7
+ end
8
+
9
+ # Streaming chat interface
10
+ def chat_stream(messages, **options, &block)
11
+ prompt = format_messages(messages)
12
+ generate_stream(prompt, **options, &block)
13
+ end
14
+
15
+ def generate(prompt, config: GenerationConfig.balanced, reset_cache: true)
16
+ begin
17
+ _generate(prompt, config)
18
+ ensure
19
+ clear_cache if reset_cache
20
+ end
21
+ end
22
+
23
+ def generate_stream(prompt, config: GenerationConfig.balanced, reset_cache: true, &block)
24
+ begin
25
+ _generate_stream(prompt, config, &block)
26
+ ensure
27
+ clear_cache if reset_cache
28
+ end
29
+ end
30
+
31
+ def self.from_pretrained(model_id, device: Candle::Device.cpu)
32
+ _from_pretrained(model_id, device)
33
+ end
34
+
35
+ private
36
+
37
+ # Format messages into a prompt string
38
+ # This is a simple implementation - model-specific formatting should be added
39
+ def format_messages(messages)
40
+ formatted = messages.map do |msg|
41
+ case msg[:role]
42
+ when "system"
43
+ "System: #{msg[:content]}"
44
+ when "user"
45
+ "User: #{msg[:content]}"
46
+ when "assistant"
47
+ "Assistant: #{msg[:content]}"
48
+ else
49
+ msg[:content]
50
+ end
51
+ end.join("\n\n")
52
+
53
+ # Add a prompt for the assistant to respond
54
+ formatted + "\n\nAssistant:"
55
+ end
56
+ end
57
+
58
+ class GenerationConfig
59
+ # Convenience method to create config with overrides
60
+ def with(**overrides)
61
+ current_config = {
62
+ max_length: max_length,
63
+ temperature: temperature,
64
+ top_p: top_p,
65
+ top_k: top_k,
66
+ repetition_penalty: repetition_penalty,
67
+ seed: seed,
68
+ stop_sequences: stop_sequences,
69
+ include_prompt: include_prompt
70
+ }.compact
71
+
72
+ self.class.new(current_config.merge(overrides))
73
+ end
74
+
75
+ # Create a deterministic configuration (temperature = 0, fixed seed)
76
+ def self.deterministic(**opts)
77
+ defaults = {
78
+ temperature: 0.0,
79
+ top_p: nil,
80
+ top_k: 1,
81
+ seed: 42
82
+ }
83
+ new(defaults.merge(opts))
84
+ end
85
+
86
+ # Create a creative configuration (higher temperature, random seed)
87
+ def self.creative(**opts)
88
+ defaults = {
89
+ temperature: 1.0,
90
+ top_p: 0.95,
91
+ top_k: 50,
92
+ repetition_penalty: 1.2
93
+ }
94
+ new(defaults.merge(opts))
95
+ end
96
+
97
+ # Create a balanced configuration (moderate temperature, random seed)
98
+ def self.balanced(**opts)
99
+ defaults = {
100
+ temperature: 0.7,
101
+ top_p: 0.9,
102
+ top_k: 40
103
+ }
104
+ new(defaults.merge(opts))
105
+ end
106
+ end
107
+ end
@@ -0,0 +1,24 @@
1
module Candle
  # Ruby-side conveniences for Reranker. +_create+ and +rerank_with_options+
  # are defined outside this file.
  class Reranker
    # Default model path for cross-encoder/ms-marco-MiniLM-L-12-v2
    DEFAULT_MODEL_PATH = "cross-encoder/ms-marco-MiniLM-L-12-v2"

    # Builds a Reranker; every argument is optional.
    #
    # @param model_path [String] The path to the model on Hugging Face
    # @param device [Candle::Device] The device to use for computation
    #   (defaults to Candle::Device.cpu)
    # @return whatever +_create+ returns
    def self.new(model_path: DEFAULT_MODEL_PATH, device: Candle::Device.cpu)
      _create(model_path, device)
    end

    # Scores documents against a query.
    #
    # Each (text, score, doc_id) triple produced by +rerank_with_options+ is
    # converted into a hash; the order of the results is left exactly as that
    # method returned it.
    #
    # @param query [String] The query text
    # @param documents [Array<String>] The list of documents to compare against
    # @param pooling_method [String] Pooling method: "pooler", "cls", or "mean". Default: "pooler"
    # @param apply_sigmoid [Boolean] Whether to apply sigmoid to the scores. Default: true
    # @return [Array<Hash>] one hash per result with :doc_id, :score and :text
    def rerank(query, documents, pooling_method: "pooler", apply_sigmoid: true)
      ranked = rerank_with_options(query, documents, pooling_method, apply_sigmoid)
      ranked.map do |text, score, doc_id|
        { doc_id: doc_id, score: score, text: text }
      end
    end
  end
end
data/lib/candle/tensor.rb CHANGED
@@ -3,15 +3,80 @@ module Candle
3
3
  include Enumerable
4
4
 
5
5
  def each
6
- if self.rank == 1
7
- self.values.each do |value|
8
- yield value
6
+ case self.rank
7
+ when 0
8
+ # Scalar tensor - yield the single value
9
+ yield self.item
10
+ when 1
11
+ # 1D tensor - yield each value
12
+ # Check if we can use f32 values to avoid conversion
13
+ if dtype.to_s.downcase == "f32"
14
+ begin
15
+ values_f32.each { |value| yield value }
16
+ rescue NoMethodError
17
+ # If values_f32 isn't available yet (not recompiled), fall back
18
+ if device.to_s != "cpu"
19
+ # Move to CPU to avoid Metal F32->F64 conversion issue
20
+ to_device(Candle::Device.cpu).values.each { |value| yield value }
21
+ else
22
+ values.each { |value| yield value }
23
+ end
24
+ end
25
+ else
26
+ # For non-F32 dtypes, use regular values
27
+ values.each { |value| yield value }
9
28
  end
10
29
  else
30
+ # Multi-dimensional tensor - yield each sub-tensor
11
31
  shape.first.times do |i|
12
32
  yield self[i]
13
33
  end
14
34
  end
15
35
  end
36
+
37
+ # Convert scalar tensor to float
38
+ def to_f
39
+ if rank == 0
40
+ # Use item method which handles dtype conversion properly
41
+ item
42
+ else
43
+ raise ArgumentError, "to_f can only be called on scalar tensors (rank 0), but this tensor has rank #{rank}"
44
+ end
45
+ end
46
+
47
+ # Convert scalar tensor to integer
48
+ def to_i
49
+ to_f.to_i
50
+ end
51
+
52
+
53
+ # Override class methods to support keyword arguments for device
54
+ class << self
55
+ alias_method :_original_new, :new
56
+ alias_method :_original_ones, :ones
57
+ alias_method :_original_zeros, :zeros
58
+ alias_method :_original_rand, :rand
59
+ alias_method :_original_randn, :randn
60
+
61
+ def new(data, dtype = nil, device: nil)
62
+ _original_new(data, dtype, device)
63
+ end
64
+
65
+ def ones(shape, device: nil)
66
+ _original_ones(shape, device)
67
+ end
68
+
69
+ def zeros(shape, device: nil)
70
+ _original_zeros(shape, device)
71
+ end
72
+
73
+ def rand(shape, device: nil)
74
+ _original_rand(shape, device)
75
+ end
76
+
77
+ def randn(shape, device: nil)
78
+ _original_randn(shape, device)
79
+ end
80
+ end
16
81
  end
17
82
  end
@@ -1,3 +1,3 @@
1
1
module Candle
  # Version of the red-candle gem; frozen so the shared constant cannot be
  # mutated in place.
  VERSION = "1.0.0.pre.1".freeze
end
data/lib/candle.rb CHANGED
@@ -1,2 +1,8 @@
1
1
  require_relative "candle/candle"
2
2
  require_relative "candle/tensor"
3
+ require_relative "candle/device_utils"
4
+ require_relative "candle/embedding_model_type"
5
+ require_relative "candle/embedding_model"
6
+ require_relative "candle/reranker"
7
+ require_relative "candle/llm"
8
+ require_relative "candle/build_info"
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: red-candle
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.5
4
+ version: 1.0.0.pre.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Christopher Petersen
@@ -9,7 +9,7 @@ authors:
9
9
  autorequire:
10
10
  bindir: bin
11
11
  cert_chain: []
12
- date: 2025-03-18 00:00:00.000000000 Z
12
+ date: 2025-07-09 00:00:00.000000000 Z
13
13
  dependencies:
14
14
  - !ruby/object:Gem::Dependency
15
15
  name: rb_sys
@@ -27,7 +27,7 @@ dependencies:
27
27
  version: '0'
28
28
  description: huggingface/candle for Ruby
29
29
  email:
30
- - christopher.petersen@gmail.com
30
+ - chris@petersen.io
31
31
  - 2xijok@gmail.com
32
32
  executables: []
33
33
  extensions:
@@ -41,6 +41,12 @@ files:
41
41
  - ext/candle/extconf.rb
42
42
  - ext/candle/src/lib.rs
43
43
  - lib/candle.rb
44
+ - lib/candle/build_info.rb
45
+ - lib/candle/device_utils.rb
46
+ - lib/candle/embedding_model.rb
47
+ - lib/candle/embedding_model_type.rb
48
+ - lib/candle/llm.rb
49
+ - lib/candle/reranker.rb
44
50
  - lib/candle/tensor.rb
45
51
  - lib/candle/version.rb
46
52
  homepage: https://github.com/assaydepot/red-candle