onnx-ruby 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CLAUDE.md +334 -0
- data/Gemfile +5 -0
- data/LICENSE +21 -0
- data/README.md +301 -0
- data/Rakefile +17 -0
- data/examples/classification.rb +35 -0
- data/examples/embedding.rb +35 -0
- data/examples/real_world_demo.rb +170 -0
- data/examples/with_zvec.rb +54 -0
- data/ext/onnx_ruby/extconf.rb +75 -0
- data/ext/onnx_ruby/onnx_ruby_ext.cpp +436 -0
- data/lib/onnx_ruby/classifier.rb +107 -0
- data/lib/onnx_ruby/configuration.rb +16 -0
- data/lib/onnx_ruby/embedder.rb +147 -0
- data/lib/onnx_ruby/hub.rb +73 -0
- data/lib/onnx_ruby/lazy_session.rb +38 -0
- data/lib/onnx_ruby/model.rb +71 -0
- data/lib/onnx_ruby/reranker.rb +91 -0
- data/lib/onnx_ruby/session.rb +89 -0
- data/lib/onnx_ruby/session_pool.rb +75 -0
- data/lib/onnx_ruby/tensor.rb +92 -0
- data/lib/onnx_ruby/version.rb +5 -0
- data/lib/onnx_ruby.rb +45 -0
- data/onnx-ruby.gemspec +37 -0
- metadata +125 -0
|
@@ -0,0 +1,436 @@
|
|
|
1
|
+
#include <rice/rice.hpp>
|
|
2
|
+
#include <rice/stl.hpp>
|
|
3
|
+
#include <onnxruntime_cxx_api.h>
|
|
4
|
+
#ifdef __APPLE__
|
|
5
|
+
#include <coreml_provider_factory.h>
|
|
6
|
+
#endif
|
|
7
|
+
#include <vector>
|
|
8
|
+
#include <string>
|
|
9
|
+
#include <cstring>
|
|
10
|
+
#include <stdexcept>
|
|
11
|
+
#include <memory>
|
|
12
|
+
#include <unordered_map>
|
|
13
|
+
|
|
14
|
+
using namespace Rice;
|
|
15
|
+
|
|
16
|
+
// Global ORT environment (initialized once)
|
|
17
|
+
// Return the process-wide ONNX Runtime environment, constructing it on the
// first call with the given logging level and the logger id "onnx_ruby".
//
// NOTE(review): because the Ort::Env is a function-local static, `log_level`
// is only honored on the *first* call; later calls with a different level
// silently reuse the already-constructed environment. Confirm callers are
// aware of this before relying on per-session log levels here.
static Ort::Env& get_env(int log_level = ORT_LOGGING_LEVEL_WARNING) {
  static Ort::Env env(static_cast<OrtLoggingLevel>(log_level), "onnx_ruby");
  return env;
}
|
|
21
|
+
|
|
22
|
+
// Map ORT element type to Ruby symbol name
|
|
23
|
+
// Translate an ONNX Runtime tensor element type into the dtype name exposed
// to the Ruby side (e.g. "float32"). Types without a mapping come back as
// "unknown".
static std::string ort_type_to_string(ONNXTensorElementDataType type) {
  const char* name = "unknown";
  switch (type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:   name = "float32"; break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:  name = "float64"; break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:   name = "int32";   break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:   name = "int64";   break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:  name = "string";  break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:    name = "bool";    break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:   name = "uint8";   break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:    name = "int8";    break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:  name = "uint16";  break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:   name = "int16";   break;
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: name = "float16"; break;
    default: break;  // fall through to "unknown"
  }
  return name;
}
|
|
39
|
+
|
|
40
|
+
// Flatten a nested Ruby array into a flat vector and compute shape
|
|
41
|
+
// Recursively walk a (possibly nested) Ruby array, appending every leaf
// VALUE to `flat` in row-major order and recording the array length seen the
// first time each nesting depth is reached into `shape`.
//
// NOTE(review): each dimension size is taken from the *first* array seen at
// that depth; jagged input (rows of differing lengths) is not validated and
// would silently yield a flat/shape mismatch — confirm callers only pass
// rectangular arrays.
static void flatten_ruby_array(VALUE arr, std::vector<VALUE>& flat, std::vector<int64_t>& shape, int depth) {
  // Non-array leaf: record the scalar value and stop recursing.
  if (!RB_TYPE_P(arr, T_ARRAY)) {
    flat.push_back(arr);
    return;
  }

  long len = RARRAY_LEN(arr);
  // First visit at this depth: its length defines that dimension's size.
  if (depth == (int)shape.size()) {
    shape.push_back(len);
  }

  for (long i = 0; i < len; i++) {
    flatten_ruby_array(rb_ary_entry(arr, i), flat, shape, depth + 1);
  }
}
|
|
56
|
+
|
|
57
|
+
// Convert an ORT output tensor to a nested Ruby array
|
|
58
|
+
// Convert an ORT output tensor into a Ruby value: a flat Ruby array for rank
// 0/1 tensors, or nested Ruby arrays matching the tensor's shape otherwise.
// Supported element types: float32/float64/int32/int64/bool/string; anything
// else raises (via std::runtime_error, translated by Rice).
static Rice::Object tensor_to_ruby(const Ort::Value& tensor) {
  auto type_info = tensor.GetTensorTypeAndShapeInfo();
  auto shape = type_info.GetShape();
  auto elem_type = type_info.GetElementType();
  size_t total = type_info.GetElementCount();

  // Build flat Ruby array first; it is reshaped into nested arrays below.
  Rice::Array flat;

  switch (elem_type) {
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
      const float* data = tensor.GetTensorData<float>();
      for (size_t i = 0; i < total; i++) {
        // Widen to double: Ruby Float is a C double.
        flat.push(Rice::Object(rb_float_new(static_cast<double>(data[i]))));
      }
      break;
    }
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
      const double* data = tensor.GetTensorData<double>();
      for (size_t i = 0; i < total; i++) {
        flat.push(Rice::Object(rb_float_new(data[i])));
      }
      break;
    }
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
      const int32_t* data = tensor.GetTensorData<int32_t>();
      for (size_t i = 0; i < total; i++) {
        flat.push(Rice::Object(INT2NUM(data[i])));
      }
      break;
    }
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
      const int64_t* data = tensor.GetTensorData<int64_t>();
      for (size_t i = 0; i < total; i++) {
        flat.push(Rice::Object(LONG2NUM(data[i])));
      }
      break;
    }
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
      const bool* data = tensor.GetTensorData<bool>();
      for (size_t i = 0; i < total; i++) {
        flat.push(Rice::Object(data[i] ? Qtrue : Qfalse));
      }
      break;
    }
    case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: {
      size_t count = total;
      // NOTE(review): `strings` is allocated but never used; the content is
      // read directly from `buffer` via `offsets` below. Candidate for removal.
      std::vector<std::string> strings(count);
      // GetStringTensorContent approach: one contiguous buffer plus per-string
      // start offsets; string i spans [offsets[i], offsets[i+1]) and the last
      // string runs to the end of the buffer.
      size_t total_len = tensor.GetStringTensorDataLength();
      std::vector<char> buffer(total_len);
      std::vector<size_t> offsets(count);
      tensor.GetStringTensorContent(buffer.data(), total_len, offsets.data(), count);
      for (size_t i = 0; i < count; i++) {
        size_t start = offsets[i];
        size_t end = (i + 1 < count) ? offsets[i + 1] : total_len;
        flat.push(Rice::Object(rb_str_new(buffer.data() + start, end - start)));
      }
      break;
    }
    default:
      // Covers float16/uint8/int8/uint16/int16 and anything unmapped: no
      // conversion path implemented, so fail loudly with the dtype name.
      throw std::runtime_error("Unsupported output tensor element type: " + ort_type_to_string(elem_type));
  }

  // Reshape flat array into nested arrays according to shape.
  // Rank 0 or 1: the flat array already has the right structure.
  if (shape.empty() || shape.size() == 1) {
    return flat;
  }

  // Reshape from innermost dimension outward: each pass groups `current`
  // into chunks of shape[d], so after processing dims (rank-1..1) the
  // remaining top-level entries implicitly form dimension 0.
  Rice::Array current = flat;
  for (int d = (int)shape.size() - 1; d >= 1; d--) {
    int64_t dim_size = shape[d];
    Rice::Array reshaped;
    long total_items = RARRAY_LEN(current.value());
    for (long i = 0; i < total_items; i += dim_size) {
      Rice::Array slice;
      for (int64_t j = 0; j < dim_size; j++) {
        slice.push(Rice::Object(rb_ary_entry(current.value(), i + j)));
      }
      reshaped.push(slice);
    }
    current = reshaped;
  }

  return current;
}
|
|
145
|
+
|
|
146
|
+
// Map Ruby optimization level symbol to ORT enum
|
|
147
|
+
// Map a user-supplied optimization level string to the corresponding ORT
// graph optimization enum. Unrecognized strings (including "all") fall back
// to full optimization.
static GraphOptimizationLevel parse_opt_level(const std::string& level) {
  GraphOptimizationLevel parsed = GraphOptimizationLevel::ORT_ENABLE_ALL;  // "all" or default
  if (level == "none" || level == "disabled") {
    parsed = GraphOptimizationLevel::ORT_DISABLE_ALL;
  } else if (level == "basic") {
    parsed = GraphOptimizationLevel::ORT_ENABLE_BASIC;
  } else if (level == "extended") {
    parsed = GraphOptimizationLevel::ORT_ENABLE_EXTENDED;
  }
  return parsed;
}
|
|
153
|
+
|
|
154
|
+
// Optimize a model and save to disk
|
|
155
|
+
// Run ONNX Runtime graph optimization on the model at `input_path` and save
// the optimized graph to `output_path`. `opt_level` is one of the strings
// accepted by parse_opt_level. Returns Qtrue on success; ORT failures
// propagate as exceptions from the Ort::Session constructor.
static Rice::Object optimize_model(const std::string& input_path, const std::string& output_path,
                                   const std::string& opt_level) {
  Ort::SessionOptions opts;
  opts.SetGraphOptimizationLevel(parse_opt_level(opt_level));
  opts.SetOptimizedModelFilePath(output_path.c_str());

  // Creating the session triggers optimization and saves the optimized model
  // as a side effect of SetOptimizedModelFilePath; the session object itself
  // is intentionally unused and discarded when it goes out of scope.
  Ort::Session session(get_env(), input_path.c_str(), opts);
  return Rice::Object(Qtrue);
}
|
|
165
|
+
|
|
166
|
+
// Get available execution providers
|
|
167
|
+
// List the execution providers compiled into this ONNX Runtime build
// (e.g. "CPUExecutionProvider") as a Ruby array of strings.
static Rice::Array available_providers() {
  const std::vector<std::string> provider_names = Ort::GetAvailableProviders();
  Rice::Array list;
  for (size_t i = 0; i < provider_names.size(); i++) {
    list.push(Rice::String(provider_names[i]));
  }
  return list;
}
|
|
175
|
+
|
|
176
|
+
// Thin C++ wrapper around Ort::Session exposed to Ruby. Owns the session and
// a default allocator; converts Ruby input specs into ORT tensors, runs
// inference, and converts outputs back to nested Ruby arrays.
class SessionWrapper {
public:
  // Build a session for the model at `model_path`.
  //   log_level      — ORT logging level (only effective on the first session;
  //                    see get_env's static-init caveat)
  //   intra_threads  — intra-op thread count; <= 0 leaves the ORT default
  //   inter_threads  — inter-op thread count; <= 0 leaves the ORT default
  //   opt_level      — graph optimization level string (see parse_opt_level)
  //   memory_pattern — false disables ORT memory-pattern optimization
  //   cpu_mem_arena  — false disables the CPU memory arena
  //   execution_mode — "parallel" selects ORT_PARALLEL, anything else sequential
  //   providers      — Ruby array of provider names: "coreml", "cuda",
  //                    "tensorrt", or "cpu" (no-op); unknown names raise
  SessionWrapper(const std::string& model_path, int log_level, int intra_threads, int inter_threads,
                 const std::string& opt_level, bool memory_pattern, bool cpu_mem_arena,
                 const std::string& execution_mode, Rice::Array providers) {
    Ort::SessionOptions opts;

    if (intra_threads > 0) opts.SetIntraOpNumThreads(intra_threads);
    if (inter_threads > 0) opts.SetInterOpNumThreads(inter_threads);
    opts.SetGraphOptimizationLevel(parse_opt_level(opt_level));

    if (!memory_pattern) opts.DisableMemPattern();
    if (!cpu_mem_arena) opts.DisableCpuMemArena();

    if (execution_mode == "parallel") {
      opts.SetExecutionMode(ExecutionMode::ORT_PARALLEL);
    } else {
      opts.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
    }

    // Append execution providers in the order given; ORT tries them in
    // registration order and falls back to CPU.
    for (long i = 0; i < RARRAY_LEN(providers.value()); i++) {
      std::string provider = Rice::detail::From_Ruby<std::string>().convert(
          rb_ary_entry(providers.value(), i));

      if (provider == "coreml") {
#ifdef __APPLE__
        uint32_t coreml_flags = COREML_FLAG_USE_NONE;
        // C-API call returns an OrtStatus* on failure; we own and must
        // release it after extracting the message.
        auto status = OrtSessionOptionsAppendExecutionProvider_CoreML(opts, coreml_flags);
        if (status) {
          std::string msg = Ort::GetApi().GetErrorMessage(status);
          Ort::GetApi().ReleaseStatus(status);
          throw std::runtime_error("CoreML provider error: " + msg);
        }
#else
        throw std::runtime_error("CoreML provider is only available on macOS/iOS");
#endif
      } else if (provider == "cuda") {
        // Zeroed options struct = CUDA provider defaults (device 0).
        OrtCUDAProviderOptions cuda_opts;
        memset(&cuda_opts, 0, sizeof(cuda_opts));
        opts.AppendExecutionProvider_CUDA(cuda_opts);
      } else if (provider == "tensorrt") {
        OrtTensorRTProviderOptions trt_opts;
        memset(&trt_opts, 0, sizeof(trt_opts));
        opts.AppendExecutionProvider_TensorRT(trt_opts);
      } else if (provider == "cpu") {
        // CPU is always available as fallback, no-op
      } else {
        throw std::runtime_error("Unknown execution provider: " + provider);
      }
    }

    session_ = std::make_unique<Ort::Session>(get_env(log_level), model_path.c_str(), opts);
    allocator_ = Ort::AllocatorWithDefaultOptions();
  }

  // Get input metadata: array of { name:, type:, shape: } hashes, one per
  // model input. Dynamic dimensions appear as negative values in :shape.
  Rice::Array input_info() {
    Rice::Array result;
    size_t count = session_->GetInputCount();

    for (size_t i = 0; i < count; i++) {
      auto name = session_->GetInputNameAllocated(i, allocator_);
      auto type_info = session_->GetInputTypeInfo(i);
      auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
      auto shape = tensor_info.GetShape();
      auto elem_type = tensor_info.GetElementType();

      Rice::Hash info;
      info[Rice::Symbol("name")] = Rice::String(name.get());
      info[Rice::Symbol("type")] = Rice::Symbol(ort_type_to_string(elem_type));

      Rice::Array rb_shape;
      for (auto dim : shape) {
        rb_shape.push(Rice::Object(LONG2NUM(dim)));
      }
      info[Rice::Symbol("shape")] = rb_shape;

      result.push(Rice::Object(info.value()));
    }
    return result;
  }

  // Get output metadata: same hash layout as input_info, one per model output.
  Rice::Array output_info() {
    Rice::Array result;
    size_t count = session_->GetOutputCount();

    for (size_t i = 0; i < count; i++) {
      auto name = session_->GetOutputNameAllocated(i, allocator_);
      auto type_info = session_->GetOutputTypeInfo(i);
      auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
      auto shape = tensor_info.GetShape();
      auto elem_type = tensor_info.GetElementType();

      Rice::Hash info;
      info[Rice::Symbol("name")] = Rice::String(name.get());
      info[Rice::Symbol("type")] = Rice::Symbol(ort_type_to_string(elem_type));

      Rice::Array rb_shape;
      for (auto dim : shape) {
        rb_shape.push(Rice::Object(LONG2NUM(dim)));
      }
      info[Rice::Symbol("shape")] = rb_shape;

      result.push(Rice::Object(info.value()));
    }
    return result;
  }

  // Run inference.
  //   input_specs         — Ruby array of hashes { name:, data:, shape:, dtype: }
  //                         where :data is a flat array in row-major order and
  //                         :dtype is one of "float"/"double"/"int32"/"int64"/"bool"
  //   output_names_filter — Ruby array of output names to fetch; empty means all
  // Returns a Ruby hash mapping output name (String) to nested-array values.
  //
  // NOTE(review): total_elements is computed from :shape, and :data is read
  // element-by-element up to that count — a :data array shorter than the
  // declared shape would read Qnil entries; confirm the Ruby Session layer
  // always supplies matching data/shape.
  Rice::Object run(Rice::Array input_specs, Rice::Array output_names_filter) {
    std::vector<const char*> input_names;
    std::vector<Ort::Value> input_tensors;
    std::vector<std::string> input_name_strs;

    // Storage for tensor data (must outlive the Run call). CreateTensor does
    // not copy — the Ort::Value holds a pointer into these buffers. Growing
    // the outer vectors is safe: vector-of-vector reallocation moves the
    // inner vectors, preserving their heap buffers.
    std::vector<std::vector<float>> float_buffers;
    std::vector<std::vector<double>> double_buffers;
    std::vector<std::vector<int32_t>> int32_buffers;
    std::vector<std::vector<int64_t>> int64_buffers;
    // Use uint8_t instead of bool because std::vector<bool> is bit-packed and has no .data()
    std::vector<std::vector<uint8_t>> bool_buffers;

    auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

    for (long idx = 0; idx < RARRAY_LEN(input_specs.value()); idx++) {
      Rice::Hash spec(rb_ary_entry(input_specs.value(), idx));

      std::string name = Rice::detail::From_Ruby<std::string>().convert(
          spec[Rice::Symbol("name")].value());
      input_name_strs.push_back(name);

      Rice::Array data(spec[Rice::Symbol("data")].value());
      Rice::Array rb_shape(spec[Rice::Symbol("shape")].value());
      std::string dtype = Rice::detail::From_Ruby<std::string>().convert(
          spec[Rice::Symbol("dtype")].value());

      std::vector<int64_t> shape;
      for (long s = 0; s < RARRAY_LEN(rb_shape.value()); s++) {
        shape.push_back(NUM2LONG(rb_ary_entry(rb_shape.value(), s)));
      }

      size_t total_elements = 1;
      for (auto dim : shape) total_elements *= dim;

      if (dtype == "float") {
        float_buffers.emplace_back(total_elements);
        auto& buf = float_buffers.back();
        for (size_t i = 0; i < total_elements; i++) {
          buf[i] = static_cast<float>(NUM2DBL(rb_ary_entry(data.value(), i)));
        }
        input_tensors.push_back(
            Ort::Value::CreateTensor<float>(memory_info, buf.data(), total_elements,
                                            shape.data(), shape.size()));
      } else if (dtype == "double") {
        double_buffers.emplace_back(total_elements);
        auto& buf = double_buffers.back();
        for (size_t i = 0; i < total_elements; i++) {
          buf[i] = NUM2DBL(rb_ary_entry(data.value(), i));
        }
        input_tensors.push_back(
            Ort::Value::CreateTensor<double>(memory_info, buf.data(), total_elements,
                                             shape.data(), shape.size()));
      } else if (dtype == "int32") {
        int32_buffers.emplace_back(total_elements);
        auto& buf = int32_buffers.back();
        for (size_t i = 0; i < total_elements; i++) {
          buf[i] = static_cast<int32_t>(NUM2INT(rb_ary_entry(data.value(), i)));
        }
        input_tensors.push_back(
            Ort::Value::CreateTensor<int32_t>(memory_info, buf.data(), total_elements,
                                              shape.data(), shape.size()));
      } else if (dtype == "int64") {
        int64_buffers.emplace_back(total_elements);
        auto& buf = int64_buffers.back();
        for (size_t i = 0; i < total_elements; i++) {
          buf[i] = NUM2LONG(rb_ary_entry(data.value(), i));
        }
        input_tensors.push_back(
            Ort::Value::CreateTensor<int64_t>(memory_info, buf.data(), total_elements,
                                              shape.data(), shape.size()));
      } else if (dtype == "bool") {
        bool_buffers.emplace_back(total_elements);
        auto& buf = bool_buffers.back();
        for (size_t i = 0; i < total_elements; i++) {
          buf[i] = RTEST(rb_ary_entry(data.value(), i)) ? 1 : 0;
        }
        // Non-template CreateTensor overload: third argument is the byte
        // count, which equals the element count for 1-byte bools.
        input_tensors.push_back(
            Ort::Value::CreateTensor(memory_info, reinterpret_cast<bool*>(buf.data()), total_elements,
                                     shape.data(), shape.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL));
      } else {
        throw std::runtime_error("Unsupported input dtype: " + dtype);
      }
    }

    // Collect c_str pointers only after input_name_strs stops growing, so
    // reallocation cannot invalidate them.
    for (auto& s : input_name_strs) {
      input_names.push_back(s.c_str());
    }

    // Determine output names: the caller's filter if non-empty, otherwise
    // every output the model declares.
    std::vector<std::string> output_name_strs;
    std::vector<const char*> output_names;

    if (RARRAY_LEN(output_names_filter.value()) > 0) {
      for (long i = 0; i < RARRAY_LEN(output_names_filter.value()); i++) {
        output_name_strs.push_back(
            Rice::detail::From_Ruby<std::string>().convert(
                rb_ary_entry(output_names_filter.value(), i)));
      }
    } else {
      size_t output_count = session_->GetOutputCount();
      for (size_t i = 0; i < output_count; i++) {
        auto name = session_->GetOutputNameAllocated(i, allocator_);
        output_name_strs.push_back(name.get());
      }
    }

    for (auto& s : output_name_strs) {
      output_names.push_back(s.c_str());
    }

    // Run inference
    auto results = session_->Run(
        Ort::RunOptions{nullptr},
        input_names.data(), input_tensors.data(), input_names.size(),
        output_names.data(), output_names.size());

    // Convert results to Ruby Hash keyed by output name (String keys).
    Rice::Hash output;
    for (size_t i = 0; i < results.size(); i++) {
      Rice::Object rb_tensor = tensor_to_ruby(results[i]);
      output[Rice::String(output_name_strs[i])] = rb_tensor;
    }

    return output;
  }

private:
  std::unique_ptr<Ort::Session> session_;  // owning; destroyed with the wrapper
  Ort::AllocatorWithDefaultOptions allocator_;
};
|
|
418
|
+
|
|
419
|
+
// Ruby extension entry point (called by Ruby when `onnx_ruby_ext` is
// required). Registers OnnxRuby::Ext::SessionWrapper and the module-level
// helper functions optimize_model / available_providers.
extern "C" void Init_onnx_ruby_ext() {
  Module rb_mOnnxRuby = define_module("OnnxRuby");
  Module rb_mExt = define_module_under(rb_mOnnxRuby, "Ext");

  // Constructor argument order must match the Ruby Session layer's call.
  define_class_under<SessionWrapper>(rb_mExt, "SessionWrapper")
    .define_constructor(Constructor<SessionWrapper, const std::string&, int, int, int,
                        const std::string&, bool, bool, const std::string&, Rice::Array>(),
                        Arg("model_path"), Arg("log_level"), Arg("intra_threads"), Arg("inter_threads"),
                        Arg("opt_level"), Arg("memory_pattern"), Arg("cpu_mem_arena"),
                        Arg("execution_mode"), Arg("providers"))
    .define_method("input_info", &SessionWrapper::input_info)
    .define_method("output_info", &SessionWrapper::output_info)
    .define_method("run", &SessionWrapper::run);

  rb_mExt.define_module_function("optimize_model", &optimize_model,
                                 Arg("input_path"), Arg("output_path"), Arg("opt_level"));
  rb_mExt.define_module_function("available_providers", &available_providers);
}
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # High-level classification wrapper around Session. Accepts either raw text
  # (tokenized via an optional tokenizer) or pre-computed feature vectors,
  # runs the model, and applies softmax to produce per-class probabilities.
  class Classifier
    attr_reader :session, :labels

    # @param model_path [String] path to the .onnx model
    # @param tokenizer [Object, String, nil] object responding to #encode, or a
    #   pretrained tokenizer name loaded via the tokenizers gem; nil for
    #   feature-vector-only use
    # @param labels [Array, nil] index -> label mapping; when nil, predictions
    #   return the raw class index as the label
    # @param session_opts [Hash] forwarded to Session.new
    def initialize(model_path, tokenizer: nil, labels: nil, **session_opts)
      @session = Session.new(model_path, **session_opts)
      @labels = labels
      @tokenizer = resolve_tokenizer(tokenizer)
    end

    # Classify a single input
    # @param input [String, Array<Float>] text (requires tokenizer) or feature vector
    # @return [Hash] { label:, score:, scores: }
    def predict(input)
      predict_batch([input]).first
    end

    # Classify a batch of inputs
    # @param inputs [Array<String>, Array<Array<Float>>] batch of texts or feature vectors
    # @return [Array<Hash>] array of { label:, score:, scores: }
    def predict_batch(inputs)
      feed = prepare_inputs(inputs)
      result = @session.run(feed)

      # Probe common output names; falls back to the first output otherwise.
      logits = find_output(result, %w[logits output probabilities scores])

      logits.map { |row| format_prediction(row) }
    end

    private

    # Accepts a ready tokenizer object (anything responding to #encode) or a
    # name/path resolved through the optional tokenizers gem.
    def resolve_tokenizer(tokenizer)
      return nil if tokenizer.nil?

      if tokenizer.respond_to?(:encode)
        tokenizer
      else
        begin
          require "tokenizers"
          Tokenizers::Tokenizer.from_pretrained(tokenizer.to_s)
        rescue LoadError
          raise Error, "tokenizer-ruby gem is required for text tokenization. " \
                       "Install with: gem install tokenizers"
        end
      end
    end

    # Dispatch on the type of the first element: String batches are tokenized,
    # Array batches are treated as raw feature vectors fed to the first input.
    def prepare_inputs(inputs)
      if inputs.first.is_a?(String)
        raise Error, "tokenizer is required for text inputs" unless @tokenizer

        tokenize_batch(inputs)
      elsif inputs.first.is_a?(Array)
        # Raw feature vectors
        input_name = @session.inputs.first[:name]
        { input_name => inputs }
      else
        raise Error, "inputs must be Strings or Arrays"
      end
    end

    # Tokenize texts, zero-pad all rows to the batch's max length, and build
    # the feed hash keyed by the model's actual input names.
    def tokenize_batch(texts)
      if @tokenizer.respond_to?(:encode_batch)
        encodings = @tokenizer.encode_batch(texts)
        ids = encodings.map(&:ids)
        masks = encodings.map(&:attention_mask)
      else
        encodings = texts.map { |t| @tokenizer.encode(t) }
        ids = encodings.map(&:ids)
        masks = encodings.map(&:attention_mask)
      end

      # Pad ids and masks with zeros to the longest sequence in the batch.
      max_len = ids.map(&:length).max
      ids = ids.map { |row| row + Array.new(max_len - row.length, 0) }
      masks = masks.map { |row| row + Array.new(max_len - row.length, 0) }

      # Match feed keys to the model's input names by substring; fall back to
      # the first declared input for the token ids.
      input_names = @session.inputs.map { |i| i[:name] }
      feed = {}
      feed[input_names.find { |n| n.include?("input_id") } || input_names[0]] = ids
      mask_name = input_names.find { |n| n.include?("mask") || n.include?("attention") }
      feed[mask_name] = masks if mask_name
      feed
    end

    # Return the first candidate output present in the result hash, or the
    # first output value when none of the candidates match.
    def find_output(result, candidate_names)
      candidate_names.each { |name| return result[name] if result.key?(name) }
      result.values.first
    end

    # Turn one logits row into { label:, score:, scores: } via softmax and
    # argmax; label is the raw index when no labels array was supplied.
    def format_prediction(logits_row)
      probs = softmax(logits_row)
      max_idx = probs.each_with_index.max_by(&:first).last
      label = @labels ? @labels[max_idx] : max_idx

      { label: label, score: probs[max_idx], scores: probs }
    end

    # Numerically stable softmax: subtract the max before exponentiating to
    # avoid overflow on large logits.
    def softmax(logits)
      max_val = logits.max
      exps = logits.map { |v| Math.exp(v - max_val) }
      sum = exps.sum
      exps.map { |v| v / sum }
    end
  end
end
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # Library-wide settings with sensible defaults; typically mutated through
  # an OnnxRuby.configure-style block.
  class Configuration
    # Filesystem directory searched for .onnx models by default.
    attr_accessor :models_path
    # Execution providers used when a session does not specify its own.
    attr_accessor :default_providers
    # Default ONNX Runtime logging verbosity.
    attr_accessor :default_log_level
    # Number of sessions kept by the session pool.
    attr_accessor :pool_size
    # Seconds to wait when checking a session out of the pool.
    attr_accessor :pool_timeout

    def initialize
      @models_path       = "app/models/onnx"
      @default_providers = [:cpu]
      @default_log_level = :warning
      @pool_size         = 5
      @pool_timeout      = 5
    end
  end
end
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module OnnxRuby
  # High-level embedding wrapper around Session. Tokenizes text (or accepts
  # pre-tokenized tensor hashes), runs the model, mean-pools 3D outputs, and
  # optionally L2-normalizes the resulting vectors.
  #
  # NOTE(review): attention masks travel from prepare_inputs to mean_pool via
  # the @_masks instance variable, which makes embed_batch non-reentrant and
  # not thread-safe per instance — confirm instances are not shared across
  # threads without external locking.
  class Embedder
    attr_reader :session

    # @param model_path [String] path to the .onnx model
    # @param tokenizer [Object, String, nil] object responding to #encode /
    #   #encode_batch, or a pretrained tokenizer name for the tokenizers gem
    # @param normalize [Boolean] L2-normalize output vectors (default true)
    # @param session_opts [Hash] forwarded to Session.new
    def initialize(model_path, tokenizer: nil, normalize: true, **session_opts)
      @session = Session.new(model_path, **session_opts)
      @normalize = normalize
      @tokenizer = resolve_tokenizer(tokenizer)
    end

    # Embed a single text or pre-tokenized input
    # @param input [String, Hash] text string (requires tokenizer) or hash of input tensors
    # @return [Array<Float>] embedding vector
    def embed(input)
      embed_batch([input]).first
    end

    # Embed a batch of texts or pre-tokenized inputs
    # @param inputs [Array<String>, Array<Hash>] batch of texts or tensor hashes
    # @return [Array<Array<Float>>] array of embedding vectors
    def embed_batch(inputs)
      # Clear stale masks from a previous call before prepare_inputs re-stashes them.
      @_masks = nil
      feed = prepare_inputs(inputs)
      result = @session.run(feed)

      # Probe common embedding output names; falls back to the first output.
      raw = find_output(result, %w[embeddings sentence_embedding output last_hidden_state])

      # If output is 3D (batch, seq_len, dim) — do mean pooling
      embeddings = if raw.first.is_a?(Array) && raw.first.first.is_a?(Array)
        mean_pool(raw, @_masks)
      else
        raw
      end

      embeddings.map { |vec| @normalize ? l2_normalize(vec) : vec }
    end

    private

    # Accepts a ready tokenizer (responding to #encode or #encode_batch) or a
    # name/path resolved through the optional tokenizers gem.
    def resolve_tokenizer(tokenizer)
      return nil if tokenizer.nil?

      if tokenizer.respond_to?(:encode) || tokenizer.respond_to?(:encode_batch)
        tokenizer
      else
        begin
          require "tokenizers"
          Tokenizers::Tokenizer.from_pretrained(tokenizer.to_s)
        rescue LoadError
          raise Error, "tokenizer-ruby gem is required for text tokenization. " \
                       "Install with: gem install tokenizers"
        end
      end
    end

    # Dispatch on the first element's type: Strings are tokenized, Hashes of
    # per-sample tensors are merged into batched arrays.
    def prepare_inputs(inputs)
      if inputs.first.is_a?(String)
        raise Error, "tokenizer is required for text inputs" unless @tokenizer

        tokenize_batch(inputs)
      elsif inputs.first.is_a?(Hash)
        # Merge hash inputs into batched arrays
        merge_input_hashes(inputs)
      else
        raise Error, "inputs must be Strings or Hashes"
      end
    end

    # Tokenize texts, zero-pad to the batch max length, then build the feed.
    def tokenize_batch(texts)
      if @tokenizer.respond_to?(:encode_batch)
        encodings = @tokenizer.encode_batch(texts)
        ids = encodings.map(&:ids)
        masks = encodings.map(&:attention_mask)
      else
        encodings = texts.map { |t| @tokenizer.encode(t) }
        ids = encodings.map(&:ids)
        masks = encodings.map(&:attention_mask)
      end

      # Pad to max length
      max_len = ids.map(&:length).max
      ids = ids.map { |row| row + Array.new(max_len - row.length, 0) }
      masks = masks.map { |row| row + Array.new(max_len - row.length, 0) }

      build_feed(ids, masks)
    end

    # Combine per-sample tensor hashes into one feed hash of batched arrays,
    # assuming every hash has the same keys as the first.
    def merge_input_hashes(hashes)
      result = {}
      hashes.first.each_key do |key|
        result[key] = hashes.map { |h| h[key] }
      end
      # Stash masks for mean pooling
      mask_key = result.keys.find { |k| k.to_s.include?("mask") || k.to_s.include?("attention") }
      @_masks = result[mask_key] if mask_key
      result
    end

    # Build the feed hash keyed by the model's actual input names, matched by
    # substring; supplies zero token_type_ids when the model expects them.
    def build_feed(ids, masks)
      input_names = @session.inputs.map { |i| i[:name] }
      feed = {}
      feed[input_names.find { |n| n.include?("input_id") } || input_names[0]] = ids
      mask_name = input_names.find { |n| n.include?("mask") || n.include?("attention") }
      feed[mask_name] = masks if mask_name
      # Supply token_type_ids (zeros) if the model expects it
      tti_name = input_names.find { |n| n.include?("token_type") }
      feed[tti_name] = ids.map { |row| Array.new(row.length, 0) } if tti_name
      @_masks = masks # stash for mean pooling
      feed
    end

    # Return the first candidate output present in the result hash, or the
    # first output value when none match.
    def find_output(result, candidate_names)
      candidate_names.each { |name| return result[name] if result.key?(name) }
      result.values.first
    end

    # Mean pooling over token embeddings, masked by attention_mask.
    # With no mask (nil), every token gets weight 1.0. An all-zero mask is
    # guarded by forcing the divisor to 1.0 (yields a zero vector).
    def mean_pool(hidden_states, masks)
      hidden_states.each_with_index.map do |tokens, batch_idx|
        mask = masks && masks[batch_idx]
        dim = tokens.first.length
        sum = Array.new(dim, 0.0)
        count = 0.0

        tokens.each_with_index do |token_vec, tok_idx|
          w = (mask && mask[tok_idx]) ? mask[tok_idx].to_f : 1.0
          next if w.zero?

          count += w
          token_vec.each_with_index { |v, d| sum[d] += v * w }
        end

        count = 1.0 if count.zero?
        sum.map { |v| v / count }
      end
    end

    # Scale a vector to unit L2 norm; zero vectors are returned unchanged to
    # avoid division by zero.
    def l2_normalize(vec)
      norm = Math.sqrt(vec.sum { |v| v * v })
      return vec if norm.zero?

      vec.map { |v| v / norm }
    end
  end
end
|