onnx-ruby 0.1.1 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/MILESTONES.md +12 -0
- data/ext/onnx_ruby/onnx_ruby_ext.cpp +59 -2
- data/lib/onnx_ruby/classifier.rb +7 -1
- data/lib/onnx_ruby/session.rb +13 -1
- data/lib/onnx_ruby/session_pool.rb +4 -1
- data/lib/onnx_ruby/tensor.rb +8 -0
- data/lib/onnx_ruby/version.rb +1 -1
- metadata +2 -2
- data/CLAUDE.md +0 -334
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 16caafd6c79615589d18c10f16be77816875f823f0024b78d32357b025645d30
|
|
4
|
+
data.tar.gz: 915a8bbf683808063b752e81da2a5629157baf73959f99bee7a61d1c1871fad3
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: b3fcc3eb5eb0605ff0866af424b0436f2bd294675a5eb32c0348182ad65bea9b64997d19c3bb026c0b7b8290c8fe59038693f3737ca7984ffb4712cd0e88c2d3
|
|
7
|
+
data.tar.gz: d22068ed40991670d3dee5780339cea9232bc7627a07f796d9b0a7bc870b648250622e66b8eb3337a186eb9ff7b092e7e6bc6dd1da7c1ac4cfc99840b2b4a9ae
|
data/ext/onnx_ruby/onnx_ruby_ext.cpp
CHANGED
|
@@ -94,12 +94,63 @@ static Rice::Object tensor_to_ruby(const Ort::Value& tensor) {
|
|
|
94
94
|
break;
|
|
95
95
|
}
|
|
96
96
|
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
|
|
97
|
-
const
|
|
97
|
+
const uint8_t* data = reinterpret_cast<const uint8_t*>(tensor.GetTensorData<bool>());
|
|
98
98
|
for (size_t i = 0; i < total; i++) {
|
|
99
99
|
flat.push(Rice::Object(data[i] ? Qtrue : Qfalse));
|
|
100
100
|
}
|
|
101
101
|
break;
|
|
102
102
|
}
|
|
103
|
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
|
|
104
|
+
const uint8_t* data = tensor.GetTensorData<uint8_t>();
|
|
105
|
+
for (size_t i = 0; i < total; i++) {
|
|
106
|
+
flat.push(Rice::Object(INT2NUM(data[i])));
|
|
107
|
+
}
|
|
108
|
+
break;
|
|
109
|
+
}
|
|
110
|
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
|
|
111
|
+
const int8_t* data = tensor.GetTensorData<int8_t>();
|
|
112
|
+
for (size_t i = 0; i < total; i++) {
|
|
113
|
+
flat.push(Rice::Object(INT2NUM(data[i])));
|
|
114
|
+
}
|
|
115
|
+
break;
|
|
116
|
+
}
|
|
117
|
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
|
|
118
|
+
const uint16_t* data = tensor.GetTensorData<uint16_t>();
|
|
119
|
+
for (size_t i = 0; i < total; i++) {
|
|
120
|
+
flat.push(Rice::Object(INT2NUM(data[i])));
|
|
121
|
+
}
|
|
122
|
+
break;
|
|
123
|
+
}
|
|
124
|
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: {
|
|
125
|
+
const int16_t* data = tensor.GetTensorData<int16_t>();
|
|
126
|
+
for (size_t i = 0; i < total; i++) {
|
|
127
|
+
flat.push(Rice::Object(INT2NUM(data[i])));
|
|
128
|
+
}
|
|
129
|
+
break;
|
|
130
|
+
}
|
|
131
|
+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
|
|
132
|
+
// float16 is stored as uint16_t; convert to float for Ruby
|
|
133
|
+
const uint16_t* data = tensor.GetTensorData<uint16_t>();
|
|
134
|
+
for (size_t i = 0; i < total; i++) {
|
|
135
|
+
// IEEE 754 half-precision to single-precision conversion
|
|
136
|
+
uint16_t h = data[i];
|
|
137
|
+
uint32_t sign = (h & 0x8000u) << 16;
|
|
138
|
+
uint32_t exponent = (h >> 10) & 0x1F;
|
|
139
|
+
uint32_t mantissa = h & 0x03FF;
|
|
140
|
+
uint32_t f;
|
|
141
|
+
if (exponent == 0) {
|
|
142
|
+
f = sign; // zero or subnormal (treat as zero for simplicity)
|
|
143
|
+
} else if (exponent == 31) {
|
|
144
|
+
f = sign | 0x7F800000u | (mantissa << 13); // inf or nan
|
|
145
|
+
} else {
|
|
146
|
+
f = sign | ((exponent + 112) << 23) | (mantissa << 13);
|
|
147
|
+
}
|
|
148
|
+
float val;
|
|
149
|
+
memcpy(&val, &f, sizeof(float));
|
|
150
|
+
flat.push(Rice::Object(rb_float_new(static_cast<double>(val))));
|
|
151
|
+
}
|
|
152
|
+
break;
|
|
153
|
+
}
|
|
103
154
|
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
|
|
104
155
|
size_t count = total;
|
|
105
156
|
std::vector<std::string> strings(count);
|
|
@@ -317,7 +368,13 @@ public:
|
|
|
317
368
|
}
|
|
318
369
|
|
|
319
370
|
size_t total_elements = 1;
|
|
320
|
-
for (auto dim : shape)
|
|
371
|
+
for (auto dim : shape) {
|
|
372
|
+
if (dim < 0) throw std::runtime_error("Negative shape dimension: " + std::to_string(dim));
|
|
373
|
+
if (dim > 0 && total_elements > SIZE_MAX / static_cast<size_t>(dim)) {
|
|
374
|
+
throw std::runtime_error("Shape dimension overflow");
|
|
375
|
+
}
|
|
376
|
+
total_elements *= static_cast<size_t>(dim);
|
|
377
|
+
}
|
|
321
378
|
|
|
322
379
|
if (dtype == "float") {
|
|
323
380
|
float_buffers.emplace_back(total_elements);
|
data/lib/onnx_ruby/classifier.rb
CHANGED
|
@@ -86,9 +86,15 @@ module OnnxRuby
|
|
|
86
86
|
end
|
|
87
87
|
|
|
88
88
|
def softmax(logits)
|
|
89
|
+
# Clamp extreme values to prevent overflow
|
|
89
90
|
max_val = logits.max
|
|
90
|
-
exps = logits.map
|
|
91
|
+
exps = logits.map do |v|
|
|
92
|
+
clamped = v - max_val
|
|
93
|
+
clamped = -500.0 if clamped < -500.0
|
|
94
|
+
Math.exp(clamped)
|
|
95
|
+
end
|
|
91
96
|
sum = exps.sum
|
|
97
|
+
sum = Float::MIN if sum.zero?
|
|
92
98
|
exps.map { |v| v / sum }
|
|
93
99
|
end
|
|
94
100
|
end
|
data/lib/onnx_ruby/session.rb
CHANGED
|
@@ -42,8 +42,13 @@ module OnnxRuby
|
|
|
42
42
|
if data.is_a?(Tensor)
|
|
43
43
|
{ name: name, data: data.flat_data, shape: data.shape, dtype: data.dtype.to_s }
|
|
44
44
|
else
|
|
45
|
-
flat = data.flatten
|
|
46
45
|
shape = infer_shape(data)
|
|
46
|
+
flat = data.flatten
|
|
47
|
+
expected_size = shape.reduce(1, :*)
|
|
48
|
+
if flat.length != expected_size
|
|
49
|
+
raise TensorError,
|
|
50
|
+
"input '#{name}' data size #{flat.length} does not match shape #{shape} (expected #{expected_size})"
|
|
51
|
+
end
|
|
47
52
|
dtype = infer_dtype(flat)
|
|
48
53
|
{ name: name, data: flat, shape: shape, dtype: dtype }
|
|
49
54
|
end
|
|
@@ -70,6 +75,13 @@ module OnnxRuby
|
|
|
70
75
|
current = data
|
|
71
76
|
while current.is_a?(Array)
|
|
72
77
|
shape << current.length
|
|
78
|
+
if current.length > 1 && current.all? { |el| el.is_a?(Array) }
|
|
79
|
+
lengths = current.map(&:length).uniq
|
|
80
|
+
if lengths.size > 1
|
|
81
|
+
raise TensorError,
|
|
82
|
+
"jagged array detected: sub-arrays have lengths #{lengths.sort.join(', ')} at dimension #{shape.size - 1}"
|
|
83
|
+
end
|
|
84
|
+
end
|
|
73
85
|
current = current.first
|
|
74
86
|
end
|
|
75
87
|
shape
|
|
data/lib/onnx_ruby/session_pool.rb
CHANGED
|
@@ -64,12 +64,15 @@ module OnnxRuby
|
|
|
64
64
|
def checkin(session)
|
|
65
65
|
@mutex.synchronize do
|
|
66
66
|
@pool.push(session)
|
|
67
|
-
@condition.
|
|
67
|
+
@condition.broadcast
|
|
68
68
|
end
|
|
69
69
|
end
|
|
70
70
|
|
|
71
71
|
def create_session
|
|
72
72
|
Session.new(@model_path, **@session_opts)
|
|
73
|
+
rescue => e
|
|
74
|
+
@created -= 1
|
|
75
|
+
raise
|
|
73
76
|
end
|
|
74
77
|
end
|
|
75
78
|
end
|
data/lib/onnx_ruby/tensor.rb
CHANGED
|
@@ -59,6 +59,14 @@ module OnnxRuby
|
|
|
59
59
|
current = data
|
|
60
60
|
while current.is_a?(Array)
|
|
61
61
|
shape << current.length
|
|
62
|
+
# Check for jagged arrays: all sub-arrays at this level must have the same length
|
|
63
|
+
if current.length > 1 && current.all? { |el| el.is_a?(Array) }
|
|
64
|
+
lengths = current.map(&:length).uniq
|
|
65
|
+
if lengths.size > 1
|
|
66
|
+
raise TensorError,
|
|
67
|
+
"jagged array detected: sub-arrays have lengths #{lengths.sort.join(', ')} at dimension #{shape.size - 1}"
|
|
68
|
+
end
|
|
69
|
+
end
|
|
62
70
|
current = current.first
|
|
63
71
|
end
|
|
64
72
|
shape
|
data/lib/onnx_ruby/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: onnx-ruby
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.
|
|
4
|
+
version: 0.2.0
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Johannes Dwi Cahyo
|
|
@@ -74,9 +74,9 @@ extensions:
|
|
|
74
74
|
- ext/onnx_ruby/extconf.rb
|
|
75
75
|
extra_rdoc_files: []
|
|
76
76
|
files:
|
|
77
|
-
- CLAUDE.md
|
|
78
77
|
- Gemfile
|
|
79
78
|
- LICENSE
|
|
79
|
+
- MILESTONES.md
|
|
80
80
|
- README.md
|
|
81
81
|
- Rakefile
|
|
82
82
|
- examples/classification.rb
|
data/CLAUDE.md
DELETED
|
@@ -1,334 +0,0 @@
|
|
|
1
|
-
# onnx-ruby
|
|
2
|
-
|
|
3
|
-
## Project Overview
|
|
4
|
-
|
|
5
|
-
Ruby bindings for [ONNX Runtime](https://github.com/microsoft/onnxruntime), Microsoft's high-performance inference engine for ONNX models. This gem wraps the ONNX Runtime C++ API using **Rice** (same approach as zvec-ruby) to give Ruby developers fast local model inference.
|
|
6
|
-
|
|
7
|
-
This unlocks: local embeddings, text classification, named entity recognition, sentiment analysis, reranking, and any other ML model exported to ONNX format — all without Python or API calls.
|
|
8
|
-
|
|
9
|
-
## Author
|
|
10
|
-
|
|
11
|
-
- Name: Johannes Dwi Cahyo
|
|
12
|
-
- GitHub: johannesdwicahyo
|
|
13
|
-
- Repo: git@github.com:johannesdwicahyo/onnx-ruby.git
|
|
14
|
-
|
|
15
|
-
## Technical Approach
|
|
16
|
-
|
|
17
|
-
### Binding Strategy: Rice 4.x (C++ → Ruby)
|
|
18
|
-
|
|
19
|
-
ONNX Runtime has a C++ API (`onnxruntime_cxx_api.h`). We wrap it using Rice, exactly like zvec-ruby.
|
|
20
|
-
|
|
21
|
-
**Important lessons from zvec-ruby to apply here:**
|
|
22
|
-
- Use `require "mkmf-rice"` (not `require "rice/extconf"`) for Rice 4.x
|
|
23
|
-
- Use `define_module_under()` and `define_enum_under()` for Rice 4.x
|
|
24
|
-
- Wrap all raw `VALUE` returns in `Rice::Object()` when pushing to Arrays
|
|
25
|
-
- Use `std::make_shared` when C++ API expects shared_ptr
|
|
26
|
-
- Extract results to Ruby Hashes/Arrays in C++ before returning (avoid dangling pointers)
|
|
27
|
-
- Ship precompiled gems — ONNX Runtime is a large C++ library, nobody wants to build it
|
|
28
|
-
- On macOS use `-force_load` for static archives with static initializers
|
|
29
|
-
- Default to safe options (like mmap=true in zvec-ruby)
|
|
30
|
-
|
|
31
|
-
### ONNX Runtime Linking
|
|
32
|
-
|
|
33
|
-
ONNX Runtime provides **prebuilt shared libraries** (`.so`/`.dylib`) for all platforms. Unlike zvec (which required building from source), we can download the official release and link against it. This is much simpler.
|
|
34
|
-
|
|
35
|
-
Download from: https://github.com/microsoft/onnxruntime/releases
|
|
36
|
-
|
|
37
|
-
The `extconf.rb` should:
|
|
38
|
-
1. Check for `ONNX_RUNTIME_DIR` env var
|
|
39
|
-
2. Check for system-installed onnxruntime via pkg-config
|
|
40
|
-
3. Auto-download the correct prebuilt release if neither found
|
|
41
|
-
|
|
42
|
-
### Precompiled Gems
|
|
43
|
-
|
|
44
|
-
For precompiled gems, statically link or bundle the ONNX Runtime `.dylib`/`.so` inside the gem. The gem will be ~50-80MB but users get zero-install experience.
|
|
45
|
-
|
|
46
|
-
## Core API Design
|
|
47
|
-
|
|
48
|
-
```ruby
|
|
49
|
-
require "onnx_ruby"
|
|
50
|
-
|
|
51
|
-
# --- Session (model loading) ---
|
|
52
|
-
|
|
53
|
-
# Load a model
|
|
54
|
-
session = OnnxRuby::Session.new("model.onnx")
|
|
55
|
-
|
|
56
|
-
# With options
|
|
57
|
-
session = OnnxRuby::Session.new("model.onnx",
|
|
58
|
-
providers: [:cpu], # :cpu, :cuda, :coreml, :tensorrt
|
|
59
|
-
inter_threads: 4,
|
|
60
|
-
intra_threads: 2,
|
|
61
|
-
log_level: :warning
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
# Model info
|
|
65
|
-
session.inputs # => [{ name: "input_ids", type: :int64, shape: [-1, 512] }]
|
|
66
|
-
session.outputs # => [{ name: "embeddings", type: :float, shape: [-1, 384] }]
|
|
67
|
-
|
|
68
|
-
# --- Inference ---
|
|
69
|
-
|
|
70
|
-
# Run inference
|
|
71
|
-
result = session.run(
|
|
72
|
-
{ "input_ids" => [[101, 2023, 2003, 1037, 3231, 102]] },
|
|
73
|
-
)
|
|
74
|
-
result["embeddings"] # => [[0.0123, -0.0456, ...]]
|
|
75
|
-
|
|
76
|
-
# With output names
|
|
77
|
-
result = session.run(inputs, output_names: ["embeddings"])
|
|
78
|
-
|
|
79
|
-
# --- Tensor ---
|
|
80
|
-
|
|
81
|
-
# Create tensors explicitly
|
|
82
|
-
tensor = OnnxRuby::Tensor.new([1, 2, 3, 4], shape: [2, 2], dtype: :int64)
|
|
83
|
-
tensor.to_a # => [[1, 2], [3, 4]]
|
|
84
|
-
tensor.shape # => [2, 2]
|
|
85
|
-
tensor.dtype # => :int64
|
|
86
|
-
|
|
87
|
-
# From flat array
|
|
88
|
-
tensor = OnnxRuby::Tensor.float([0.1, 0.2, 0.3], shape: [1, 3])
|
|
89
|
-
|
|
90
|
-
# --- High-Level Helpers ---
|
|
91
|
-
|
|
92
|
-
# Embedding model (wraps session with pre/post processing)
|
|
93
|
-
embedder = OnnxRuby::Embedder.new("all-MiniLM-L6-v2.onnx",
|
|
94
|
-
tokenizer: "sentence-transformers/all-MiniLM-L6-v2" # requires tokenizer-ruby
|
|
95
|
-
)
|
|
96
|
-
embeddings = embedder.embed("Hello world") # => [0.0123, ...]
|
|
97
|
-
embeddings = embedder.embed_batch(["Hello", "World"]) # => [[...], [...]]
|
|
98
|
-
|
|
99
|
-
# Classifier
|
|
100
|
-
classifier = OnnxRuby::Classifier.new("intent_model.onnx",
|
|
101
|
-
tokenizer: "bert-base-uncased",
|
|
102
|
-
labels: ["greeting", "farewell", "question", "command"]
|
|
103
|
-
)
|
|
104
|
-
classifier.predict("Hello there!") # => { label: "greeting", score: 0.95 }
|
|
105
|
-
```
|
|
106
|
-
|
|
107
|
-
## Features to Implement
|
|
108
|
-
|
|
109
|
-
### Phase 1 — Core (MVP)
|
|
110
|
-
- [ ] `Session.new(path, options)` — load ONNX model
|
|
111
|
-
- [ ] `session.run(inputs)` — run inference, return outputs
|
|
112
|
-
- [ ] `session.inputs` / `session.outputs` — model metadata
|
|
113
|
-
- [ ] `Tensor` class — create and manipulate tensors
|
|
114
|
-
- [ ] Support dtypes: float32, float64, int32, int64, string, bool
|
|
115
|
-
- [ ] Support shapes: 1D, 2D, 3D, 4D tensors
|
|
116
|
-
- [ ] CPU execution provider
|
|
117
|
-
|
|
118
|
-
### Phase 2 — Providers & Options
|
|
119
|
-
- [ ] CoreML provider (macOS acceleration)
|
|
120
|
-
- [ ] CUDA provider (NVIDIA GPU)
|
|
121
|
-
- [ ] Session options: threading, memory, optimization level
|
|
122
|
-
- [ ] Model optimization: `OnnxRuby.optimize("model.onnx", "optimized.onnx")`
|
|
123
|
-
- [ ] Dynamic shapes (batching)
|
|
124
|
-
|
|
125
|
-
### Phase 3 — High-Level API
|
|
126
|
-
- [ ] `Embedder` — embedding model wrapper (tokenize → infer → normalize)
|
|
127
|
-
- [ ] `Classifier` — text classification wrapper
|
|
128
|
-
- [ ] `Reranker` — cross-encoder reranking wrapper
|
|
129
|
-
- [ ] Integration with tokenizer-ruby for text preprocessing
|
|
130
|
-
- [ ] Model hub: `OnnxRuby::Hub.download("sentence-transformers/all-MiniLM-L6-v2")`
|
|
131
|
-
|
|
132
|
-
### Phase 4 — Rails Integration
|
|
133
|
-
- [ ] `OnnxRuby.configure { |c| c.models_path = "app/models/onnx" }`
|
|
134
|
-
- [ ] Lazy model loading (load on first inference)
|
|
135
|
-
- [ ] Connection pool for thread-safe concurrent inference
|
|
136
|
-
- [ ] ActiveModel integration for embedding generation
|
|
137
|
-
|
|
138
|
-
## Project Structure
|
|
139
|
-
|
|
140
|
-
```
|
|
141
|
-
onnx-ruby/
|
|
142
|
-
├── CLAUDE.md
|
|
143
|
-
├── Gemfile
|
|
144
|
-
├── Rakefile
|
|
145
|
-
├── LICENSE # MIT
|
|
146
|
-
├── README.md
|
|
147
|
-
├── onnx-ruby.gemspec
|
|
148
|
-
├── lib/
|
|
149
|
-
│ ├── onnx_ruby.rb
|
|
150
|
-
│ └── onnx_ruby/
|
|
151
|
-
│ ├── version.rb
|
|
152
|
-
│ ├── session.rb
|
|
153
|
-
│ ├── tensor.rb
|
|
154
|
-
│ ├── embedder.rb
|
|
155
|
-
│ ├── classifier.rb
|
|
156
|
-
│ └── reranker.rb
|
|
157
|
-
├── ext/
|
|
158
|
-
│ └── onnx_ruby/
|
|
159
|
-
│ ├── extconf.rb
|
|
160
|
-
│ └── onnx_ruby_ext.cpp
|
|
161
|
-
├── test/
|
|
162
|
-
│ ├── test_helper.rb
|
|
163
|
-
│ ├── test_session.rb
|
|
164
|
-
│ ├── test_tensor.rb
|
|
165
|
-
│ ├── test_inference.rb
|
|
166
|
-
│ └── models/ # small test ONNX models
|
|
167
|
-
│ └── .gitkeep
|
|
168
|
-
├── script/
|
|
169
|
-
│ ├── download_onnxruntime.sh
|
|
170
|
-
│ └── package_native_gem.rb
|
|
171
|
-
└── examples/
|
|
172
|
-
├── embedding.rb
|
|
173
|
-
├── classification.rb
|
|
174
|
-
└── with_zvec.rb # full RAG example with zvec-ruby
|
|
175
|
-
```
|
|
176
|
-
|
|
177
|
-
## Dependencies
|
|
178
|
-
|
|
179
|
-
### Runtime
|
|
180
|
-
- `rice` (>= 4.0) — C++ to Ruby bindings
|
|
181
|
-
- ONNX Runtime shared library (bundled in precompiled gems)
|
|
182
|
-
|
|
183
|
-
### Optional
|
|
184
|
-
- `tokenizer-ruby` — for Embedder/Classifier text preprocessing
|
|
185
|
-
|
|
186
|
-
### Development
|
|
187
|
-
- `rake-compiler` for building native extensions
|
|
188
|
-
- `rake-compiler-dock` for cross-compilation
|
|
189
|
-
- `minitest` for testing
|
|
190
|
-
- `rake` for tasks
|
|
191
|
-
|
|
192
|
-
## Key C++ Binding Details
|
|
193
|
-
|
|
194
|
-
### ONNX Runtime C++ API Structure
|
|
195
|
-
|
|
196
|
-
```cpp
|
|
197
|
-
#include <onnxruntime_cxx_api.h>
|
|
198
|
-
|
|
199
|
-
// Key classes to wrap:
|
|
200
|
-
Ort::Env // Runtime environment (singleton)
|
|
201
|
-
Ort::Session // Model session
|
|
202
|
-
Ort::SessionOptions
|
|
203
|
-
Ort::Value // Tensor (input/output)
|
|
204
|
-
Ort::MemoryInfo // Memory allocation info
|
|
205
|
-
Ort::TypeInfo // Model input/output type info
|
|
206
|
-
Ort::TensorTypeAndShapeInfo
|
|
207
|
-
```
|
|
208
|
-
|
|
209
|
-
### extconf.rb approach
|
|
210
|
-
|
|
211
|
-
```ruby
|
|
212
|
-
require "mkmf-rice"
|
|
213
|
-
|
|
214
|
-
# Try to find ONNX Runtime
|
|
215
|
-
ort_dir = ENV["ONNX_RUNTIME_DIR"]
|
|
216
|
-
|
|
217
|
-
unless ort_dir
|
|
218
|
-
# Auto-download prebuilt ONNX Runtime for the current platform
|
|
219
|
-
# from https://github.com/microsoft/onnxruntime/releases
|
|
220
|
-
ort_dir = download_onnxruntime() # helper function
|
|
221
|
-
end
|
|
222
|
-
|
|
223
|
-
dir_config("onnxruntime", "#{ort_dir}/include", "#{ort_dir}/lib")
|
|
224
|
-
$INCFLAGS << " -I#{ort_dir}/include"
|
|
225
|
-
$LDFLAGS << " -L#{ort_dir}/lib"
|
|
226
|
-
$libs << " -lonnxruntime"
|
|
227
|
-
|
|
228
|
-
have_header("onnxruntime_cxx_api.h") or
|
|
229
|
-
abort "Cannot find ONNX Runtime headers"
|
|
230
|
-
|
|
231
|
-
create_makefile("onnx_ruby/onnx_ruby_ext")
|
|
232
|
-
```
|
|
233
|
-
|
|
234
|
-
### C++ Extension Skeleton
|
|
235
|
-
|
|
236
|
-
```cpp
|
|
237
|
-
#include <rice/rice.hpp>
|
|
238
|
-
#include <onnxruntime_cxx_api.h>
|
|
239
|
-
|
|
240
|
-
using namespace Rice;
|
|
241
|
-
|
|
242
|
-
// Global ORT environment (initialized once)
|
|
243
|
-
static Ort::Env& get_env() {
|
|
244
|
-
static Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "onnx_ruby");
|
|
245
|
-
return env;
|
|
246
|
-
}
|
|
247
|
-
|
|
248
|
-
// Wrap Ort::Session
|
|
249
|
-
// Wrap Ort::Value (Tensor)
|
|
250
|
-
// Handle type conversion: Ruby Array ↔ ORT Tensor
|
|
251
|
-
// Map ORT errors to Ruby exceptions
|
|
252
|
-
|
|
253
|
-
void Init_onnx_ruby_ext() {
|
|
254
|
-
Module rb_mOnnxRuby = define_module("OnnxRuby");
|
|
255
|
-
Module rb_mExt = define_module_under(rb_mOnnxRuby, "Ext");
|
|
256
|
-
|
|
257
|
-
// Define Session, Tensor, etc.
|
|
258
|
-
}
|
|
259
|
-
```
|
|
260
|
-
|
|
261
|
-
### Critical: Tensor ↔ Ruby Array Conversion
|
|
262
|
-
|
|
263
|
-
The most complex part. Need to handle:
|
|
264
|
-
- Ruby Array of floats → ORT float32 tensor (most common for embeddings)
|
|
265
|
-
- Ruby Array of integers → ORT int64 tensor (for token IDs)
|
|
266
|
-
- Nested Ruby Arrays → multi-dimensional tensors
|
|
267
|
-
- ORT output tensors → Ruby Arrays (with proper Float wrapping via `Rice::Object(rb_float_new())`)
|
|
268
|
-
|
|
269
|
-
```cpp
|
|
270
|
-
// Ruby Array → ORT Tensor
|
|
271
|
-
Ort::Value array_to_tensor(Rice::Array arr, const std::vector<int64_t>& shape) {
|
|
272
|
-
// Flatten nested arrays
|
|
273
|
-
// Detect dtype from Ruby values
|
|
274
|
-
// Create ORT tensor with proper memory allocation
|
|
275
|
-
}
|
|
276
|
-
|
|
277
|
-
// ORT Tensor → Ruby Array
|
|
278
|
-
Rice::Object tensor_to_array(const Ort::Value& tensor) {
|
|
279
|
-
// Read shape
|
|
280
|
-
// Read dtype
|
|
281
|
-
// Copy data to Ruby Array (with Rice::Object wrapping!)
|
|
282
|
-
}
|
|
283
|
-
```
|
|
284
|
-
|
|
285
|
-
## Testing Strategy
|
|
286
|
-
|
|
287
|
-
- Test with small ONNX models (generate test models with Python: `torch.onnx.export`)
|
|
288
|
-
- Test model loading and metadata inspection
|
|
289
|
-
- Test inference with known inputs/outputs
|
|
290
|
-
- Test all supported dtypes
|
|
291
|
-
- Test batch inference
|
|
292
|
-
- Test error handling (invalid model, wrong input shape, etc.)
|
|
293
|
-
- Benchmark against Python's onnxruntime for correctness
|
|
294
|
-
|
|
295
|
-
### Create test models (Python script to include):
|
|
296
|
-
```python
|
|
297
|
-
# script/create_test_models.py
|
|
298
|
-
import torch
|
|
299
|
-
import torch.nn as nn
|
|
300
|
-
|
|
301
|
-
# Simple linear model for testing
|
|
302
|
-
class SimpleModel(nn.Module):
|
|
303
|
-
def __init__(self):
|
|
304
|
-
super().__init__()
|
|
305
|
-
self.linear = nn.Linear(4, 3)
|
|
306
|
-
def forward(self, x):
|
|
307
|
-
return self.linear(x)
|
|
308
|
-
|
|
309
|
-
model = SimpleModel()
|
|
310
|
-
dummy = torch.randn(1, 4)
|
|
311
|
-
torch.onnx.export(model, dummy, "test/models/simple.onnx",
|
|
312
|
-
input_names=["input"], output_names=["output"])
|
|
313
|
-
```
|
|
314
|
-
|
|
315
|
-
## Publishing
|
|
316
|
-
|
|
317
|
-
- RubyGems.org: `GEM_HOST_API_KEY=[REDACTED — a live RubyGems API key was published in this file in v0.1.1; revoke and rotate it immediately] gem push onnx-ruby-*.gem`
|
|
318
|
-
- gem.coop: `GEM_HOST_API_KEY=[REDACTED — a live gem.coop API key was published in this file in v0.1.1; revoke and rotate it immediately] gem push onnx-ruby-*.gem --host https://beta.gem.coop/@johannesdwicahyo`
|
|
319
|
-
|
|
320
|
-
## Notes from zvec-ruby Experience
|
|
321
|
-
|
|
322
|
-
- **Rice 4.x API**: `define_module_under()`, `define_enum_under()`, not the 3.x syntax
|
|
323
|
-
- **Rice::Object wrapping**: ALWAYS wrap `rb_float_new()`, `Qtrue`, `Qnil` in `Rice::Object()` when pushing to Arrays
|
|
324
|
-
- **shared_ptr**: Use `make_shared` when C++ expects `shared_ptr`, accept by `const T&` in bindings
|
|
325
|
-
- **Extract results in C++**: Don't try to push C++ objects directly into Ruby arrays. Extract to Hashes/Arrays first.
|
|
326
|
-
- **Precompiled gems**: Essential. Use `script/package_native_gem.rb` for macOS, `rake-compiler-dock` for Linux.
|
|
327
|
-
- **ONNX Runtime ships prebuilt binaries**: Much easier than zvec. Download from GitHub releases, link against `.dylib`/`.so`.
|
|
328
|
-
- **Static initializers**: May need `-force_load` if ONNX Runtime uses static registration patterns.
|
|
329
|
-
- **mmap/memory**: ONNX Runtime manages its own memory via allocators. Let it handle memory, don't fight it.
|
|
330
|
-
|
|
331
|
-
## Existing Ruby ONNX Solutions
|
|
332
|
-
|
|
333
|
-
- `onnxruntime` gem by ankane — exists but is FFI-based and limited. We can provide better performance and API with Rice + additional high-level features (Embedder, Classifier, Reranker).
|
|
334
|
-
- Differentiate by: better API, precompiled gems, high-level wrappers, tokenizer-ruby integration
|