executorch 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +47 -0
- data/LICENSE.txt +176 -0
- data/README.md +198 -0
- data/ext/executorch/executorch.cpp +582 -0
- data/ext/executorch/extconf.rb +208 -0
- data/ext/executorch/utils.h +140 -0
- data/lib/executorch/version.rb +3 -0
- data/lib/executorch.rb +212 -0
- metadata +66 -0
|
@@ -0,0 +1,582 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* ExecuTorch Ruby Bindings
|
|
3
|
+
*
|
|
4
|
+
* This file provides Ruby bindings for Meta's ExecuTorch library using the Rice gem.
|
|
5
|
+
* It wraps the high-level Module API for loading and executing PyTorch models.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <rice/rice.hpp>
|
|
9
|
+
#include <rice/stl.hpp>
|
|
10
|
+
|
|
11
|
+
#include <executorch/extension/module/module.h>
|
|
12
|
+
#include <executorch/extension/tensor/tensor_ptr.h>
|
|
13
|
+
#include <executorch/runtime/core/evalue.h>
|
|
14
|
+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
|
|
15
|
+
|
|
16
|
+
#include <memory>
|
|
17
|
+
#include <vector>
|
|
18
|
+
#include <string>
|
|
19
|
+
#include <unordered_set>
|
|
20
|
+
|
|
21
|
+
#include "utils.h"
|
|
22
|
+
|
|
23
|
+
using namespace Rice;
|
|
24
|
+
using namespace executorch::runtime;
|
|
25
|
+
using namespace executorch::extension;
|
|
26
|
+
namespace et = executorch;
|
|
27
|
+
|
|
28
|
+
// Forward declarations
|
|
29
|
+
class RubyTensor;
|
|
30
|
+
class RubyEValue;
|
|
31
|
+
|
|
32
|
+
/**
|
|
33
|
+
* Ruby wrapper for executorch::aten::Tensor via TensorPtr
|
|
34
|
+
*
|
|
35
|
+
* This class manages tensor data and provides methods for creating tensors
|
|
36
|
+
* from Ruby arrays and extracting data back to Ruby.
|
|
37
|
+
*/
|
|
38
|
+
class RubyTensor {
|
|
39
|
+
public:
|
|
40
|
+
// Create a tensor from Ruby data array and shape
|
|
41
|
+
static RubyTensor create(Array data, Array shape, Symbol dtype) {
|
|
42
|
+
HANDLE_ET_ERRORS
|
|
43
|
+
|
|
44
|
+
// Convert shape to vector using the proper SizesType
|
|
45
|
+
std::vector<et::aten::SizesType> sizes;
|
|
46
|
+
for (size_t i = 0; i < shape.size(); i++) {
|
|
47
|
+
sizes.push_back(static_cast<et::aten::SizesType>(
|
|
48
|
+
detail::From_Ruby<int64_t>().convert(shape[i].value())));
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
// Get scalar type from symbol
|
|
52
|
+
et::aten::ScalarType scalar_type = executorch_ruby::symbol_to_scalar_type(dtype.value());
|
|
53
|
+
|
|
54
|
+
// Convert data based on dtype and create tensor using templated make_tensor_ptr
|
|
55
|
+
// The templated version takes ownership of the data vector and handles memory management
|
|
56
|
+
if (scalar_type == et::aten::ScalarType::Float) {
|
|
57
|
+
std::vector<float> float_data;
|
|
58
|
+
float_data.reserve(data.size());
|
|
59
|
+
for (size_t i = 0; i < data.size(); i++) {
|
|
60
|
+
float_data.push_back(static_cast<float>(detail::From_Ruby<double>().convert(data[i].value())));
|
|
61
|
+
}
|
|
62
|
+
// Use templated make_tensor_ptr which manages data ownership
|
|
63
|
+
TensorPtr tensor_ptr = make_tensor_ptr<float>(
|
|
64
|
+
std::move(sizes),
|
|
65
|
+
std::move(float_data)
|
|
66
|
+
);
|
|
67
|
+
return RubyTensor(std::move(tensor_ptr));
|
|
68
|
+
} else if (scalar_type == et::aten::ScalarType::Double) {
|
|
69
|
+
std::vector<double> double_data;
|
|
70
|
+
double_data.reserve(data.size());
|
|
71
|
+
for (size_t i = 0; i < data.size(); i++) {
|
|
72
|
+
double_data.push_back(detail::From_Ruby<double>().convert(data[i].value()));
|
|
73
|
+
}
|
|
74
|
+
TensorPtr tensor_ptr = make_tensor_ptr<double>(
|
|
75
|
+
std::move(sizes),
|
|
76
|
+
std::move(double_data)
|
|
77
|
+
);
|
|
78
|
+
return RubyTensor(std::move(tensor_ptr));
|
|
79
|
+
} else if (scalar_type == et::aten::ScalarType::Long) {
|
|
80
|
+
std::vector<int64_t> int_data;
|
|
81
|
+
int_data.reserve(data.size());
|
|
82
|
+
for (size_t i = 0; i < data.size(); i++) {
|
|
83
|
+
int_data.push_back(detail::From_Ruby<int64_t>().convert(data[i].value()));
|
|
84
|
+
}
|
|
85
|
+
TensorPtr tensor_ptr = make_tensor_ptr<int64_t>(
|
|
86
|
+
std::move(sizes),
|
|
87
|
+
std::move(int_data)
|
|
88
|
+
);
|
|
89
|
+
return RubyTensor(std::move(tensor_ptr));
|
|
90
|
+
} else if (scalar_type == et::aten::ScalarType::Int) {
|
|
91
|
+
std::vector<int32_t> int_data;
|
|
92
|
+
int_data.reserve(data.size());
|
|
93
|
+
for (size_t i = 0; i < data.size(); i++) {
|
|
94
|
+
int_data.push_back(static_cast<int32_t>(detail::From_Ruby<int64_t>().convert(data[i].value())));
|
|
95
|
+
}
|
|
96
|
+
TensorPtr tensor_ptr = make_tensor_ptr<int32_t>(
|
|
97
|
+
std::move(sizes),
|
|
98
|
+
std::move(int_data)
|
|
99
|
+
);
|
|
100
|
+
return RubyTensor(std::move(tensor_ptr));
|
|
101
|
+
} else {
|
|
102
|
+
rb_raise(rb_eArgError, "Unsupported dtype. Use :float, :double, :long, or :int");
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
// Should never reach here but compiler needs it
|
|
106
|
+
rb_raise(rb_eRuntimeError, "Unexpected code path in Tensor.create");
|
|
107
|
+
END_HANDLE_ET_ERRORS
|
|
108
|
+
}
|
|
109
|
+
|
|
110
|
+
// Create a float tensor (convenience method)
|
|
111
|
+
static RubyTensor from_array(Array data, Array shape) {
|
|
112
|
+
return create(data, shape, Symbol("float"));
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
// Get shape as Ruby array
|
|
116
|
+
Array shape() const {
|
|
117
|
+
Array result;
|
|
118
|
+
for (int i = 0; i < tensor_ptr_->dim(); i++) {
|
|
119
|
+
result.push(tensor_ptr_->size(i));
|
|
120
|
+
}
|
|
121
|
+
return result;
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
// Get number of dimensions
|
|
125
|
+
int64_t dim() const {
|
|
126
|
+
return tensor_ptr_->dim();
|
|
127
|
+
}
|
|
128
|
+
|
|
129
|
+
// Get total number of elements
|
|
130
|
+
int64_t numel() const {
|
|
131
|
+
return tensor_ptr_->numel();
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
// Get dtype as symbol
|
|
135
|
+
Object dtype() const {
|
|
136
|
+
return Object(executorch_ruby::scalar_type_to_symbol(tensor_ptr_->scalar_type()));
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
// Convert tensor data to Ruby array (flattened)
|
|
140
|
+
Array to_a() const {
|
|
141
|
+
Array result;
|
|
142
|
+
auto scalar_type = tensor_ptr_->scalar_type();
|
|
143
|
+
|
|
144
|
+
if (scalar_type == et::aten::ScalarType::Float) {
|
|
145
|
+
const float* data = tensor_ptr_->const_data_ptr<float>();
|
|
146
|
+
for (int64_t i = 0; i < tensor_ptr_->numel(); i++) {
|
|
147
|
+
result.push(data[i]);
|
|
148
|
+
}
|
|
149
|
+
} else if (scalar_type == et::aten::ScalarType::Double) {
|
|
150
|
+
const double* data = tensor_ptr_->const_data_ptr<double>();
|
|
151
|
+
for (int64_t i = 0; i < tensor_ptr_->numel(); i++) {
|
|
152
|
+
result.push(data[i]);
|
|
153
|
+
}
|
|
154
|
+
} else if (scalar_type == et::aten::ScalarType::Long) {
|
|
155
|
+
const int64_t* data = tensor_ptr_->const_data_ptr<int64_t>();
|
|
156
|
+
for (int64_t i = 0; i < tensor_ptr_->numel(); i++) {
|
|
157
|
+
result.push(data[i]);
|
|
158
|
+
}
|
|
159
|
+
} else if (scalar_type == et::aten::ScalarType::Int) {
|
|
160
|
+
const int32_t* data = tensor_ptr_->const_data_ptr<int32_t>();
|
|
161
|
+
for (int64_t i = 0; i < tensor_ptr_->numel(); i++) {
|
|
162
|
+
result.push(static_cast<int64_t>(data[i]));
|
|
163
|
+
}
|
|
164
|
+
} else {
|
|
165
|
+
rb_raise(rb_eRuntimeError, "Unsupported tensor dtype for to_a");
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
return result;
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
// Get string representation
|
|
172
|
+
std::string to_s() const {
|
|
173
|
+
std::string result = "Tensor(shape=[";
|
|
174
|
+
for (int i = 0; i < tensor_ptr_->dim(); i++) {
|
|
175
|
+
if (i > 0) result += ", ";
|
|
176
|
+
result += std::to_string(tensor_ptr_->size(i));
|
|
177
|
+
}
|
|
178
|
+
result += "], dtype=";
|
|
179
|
+
|
|
180
|
+
auto scalar_type = tensor_ptr_->scalar_type();
|
|
181
|
+
if (scalar_type == et::aten::ScalarType::Float) result += "float";
|
|
182
|
+
else if (scalar_type == et::aten::ScalarType::Double) result += "double";
|
|
183
|
+
else if (scalar_type == et::aten::ScalarType::Long) result += "long";
|
|
184
|
+
else if (scalar_type == et::aten::ScalarType::Int) result += "int";
|
|
185
|
+
else result += "unknown";
|
|
186
|
+
|
|
187
|
+
result += ")";
|
|
188
|
+
return result;
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
// Access the underlying tensor
|
|
192
|
+
et::aten::Tensor& get() { return *tensor_ptr_; }
|
|
193
|
+
const et::aten::Tensor& get() const { return *tensor_ptr_; }
|
|
194
|
+
|
|
195
|
+
// Get the TensorPtr for ownership transfer
|
|
196
|
+
TensorPtr& get_ptr() { return tensor_ptr_; }
|
|
197
|
+
|
|
198
|
+
// Create from existing TensorPtr (used internally)
|
|
199
|
+
explicit RubyTensor(TensorPtr ptr) : tensor_ptr_(std::move(ptr)) {}
|
|
200
|
+
|
|
201
|
+
// Create from Tensor reference (clones the tensor)
|
|
202
|
+
static RubyTensor from_tensor(const et::aten::Tensor& tensor) {
|
|
203
|
+
return RubyTensor(clone_tensor_ptr(tensor));
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
private:
|
|
207
|
+
TensorPtr tensor_ptr_;
|
|
208
|
+
};
|
|
209
|
+
|
|
210
|
+
/**
|
|
211
|
+
* Ruby wrapper for executorch::runtime::EValue
|
|
212
|
+
*
|
|
213
|
+
* EValue is a tagged union that can hold different value types:
|
|
214
|
+
* - None
|
|
215
|
+
* - Int (int64_t)
|
|
216
|
+
* - Double
|
|
217
|
+
* - Bool
|
|
218
|
+
* - String
|
|
219
|
+
* - Tensor
|
|
220
|
+
* - Lists (IntList, DoubleList, BoolList, TensorList)
|
|
221
|
+
*/
|
|
222
|
+
class RubyEValue {
|
|
223
|
+
public:
|
|
224
|
+
// Create from different types
|
|
225
|
+
static RubyEValue from_none() {
|
|
226
|
+
return RubyEValue(EValue());
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
static RubyEValue from_int(int64_t value) {
|
|
230
|
+
return RubyEValue(EValue(value));
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
static RubyEValue from_double(double value) {
|
|
234
|
+
return RubyEValue(EValue(value));
|
|
235
|
+
}
|
|
236
|
+
|
|
237
|
+
static RubyEValue from_bool(bool value) {
|
|
238
|
+
return RubyEValue(EValue(value));
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
static RubyEValue from_tensor(RubyTensor& tensor) {
|
|
242
|
+
return RubyEValue(EValue(tensor.get()), tensor.get_ptr());
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
// Type checking
|
|
246
|
+
bool is_none() const { return evalue_.isNone(); }
|
|
247
|
+
bool is_int() const { return evalue_.isInt(); }
|
|
248
|
+
bool is_double() const { return evalue_.isDouble(); }
|
|
249
|
+
bool is_bool() const { return evalue_.isBool(); }
|
|
250
|
+
bool is_string() const { return evalue_.isString(); }
|
|
251
|
+
bool is_tensor() const { return evalue_.isTensor(); }
|
|
252
|
+
|
|
253
|
+
// Value extraction
|
|
254
|
+
int64_t to_int() const {
|
|
255
|
+
if (!is_int()) {
|
|
256
|
+
rb_raise(rb_eTypeError, "EValue is not an Int");
|
|
257
|
+
}
|
|
258
|
+
return evalue_.toInt();
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
double to_double() const {
|
|
262
|
+
if (!is_double()) {
|
|
263
|
+
rb_raise(rb_eTypeError, "EValue is not a Double");
|
|
264
|
+
}
|
|
265
|
+
return evalue_.toDouble();
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
bool to_bool() const {
|
|
269
|
+
if (!is_bool()) {
|
|
270
|
+
rb_raise(rb_eTypeError, "EValue is not a Bool");
|
|
271
|
+
}
|
|
272
|
+
return evalue_.toBool();
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
RubyTensor to_tensor() const {
|
|
276
|
+
if (!is_tensor()) {
|
|
277
|
+
rb_raise(rb_eTypeError, "EValue is not a Tensor");
|
|
278
|
+
}
|
|
279
|
+
return RubyTensor::from_tensor(evalue_.toTensor());
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
// Convert to Ruby object based on type
|
|
283
|
+
Object to_ruby() const {
|
|
284
|
+
if (is_none()) {
|
|
285
|
+
return Object(Qnil);
|
|
286
|
+
} else if (is_int()) {
|
|
287
|
+
return Object(LONG2NUM(evalue_.toInt()));
|
|
288
|
+
} else if (is_double()) {
|
|
289
|
+
return Object(DBL2NUM(evalue_.toDouble()));
|
|
290
|
+
} else if (is_bool()) {
|
|
291
|
+
return Object(evalue_.toBool() ? Qtrue : Qfalse);
|
|
292
|
+
} else if (is_tensor()) {
|
|
293
|
+
// Return tensor info as hash
|
|
294
|
+
auto tensor = evalue_.toTensor();
|
|
295
|
+
VALUE hash = rb_hash_new();
|
|
296
|
+
rb_hash_aset(hash, ID2SYM(rb_intern("type")), rb_str_new_cstr("tensor"));
|
|
297
|
+
rb_hash_aset(hash, ID2SYM(rb_intern("dtype")),
|
|
298
|
+
executorch_ruby::scalar_type_to_symbol(tensor.scalar_type()));
|
|
299
|
+
|
|
300
|
+
// Build shape array
|
|
301
|
+
VALUE shape = rb_ary_new();
|
|
302
|
+
for (int i = 0; i < tensor.dim(); i++) {
|
|
303
|
+
rb_ary_push(shape, LONG2NUM(tensor.size(i)));
|
|
304
|
+
}
|
|
305
|
+
rb_hash_aset(hash, ID2SYM(rb_intern("shape")), shape);
|
|
306
|
+
|
|
307
|
+
return Object(hash);
|
|
308
|
+
} else {
|
|
309
|
+
return Object(Qnil);
|
|
310
|
+
}
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
// Get type name
|
|
314
|
+
std::string type_name() const {
|
|
315
|
+
if (is_none()) return "None";
|
|
316
|
+
if (is_int()) return "Int";
|
|
317
|
+
if (is_double()) return "Double";
|
|
318
|
+
if (is_bool()) return "Bool";
|
|
319
|
+
if (is_string()) return "String";
|
|
320
|
+
if (is_tensor()) return "Tensor";
|
|
321
|
+
return "Unknown";
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
// Access underlying EValue
|
|
325
|
+
const EValue& get() const { return evalue_; }
|
|
326
|
+
EValue& get() { return evalue_; }
|
|
327
|
+
|
|
328
|
+
// For internal use - create from EValue with tensor ownership
|
|
329
|
+
explicit RubyEValue(EValue evalue, TensorPtr tensor_ptr = nullptr)
|
|
330
|
+
: evalue_(std::move(evalue)), tensor_ptr_(std::move(tensor_ptr)) {}
|
|
331
|
+
|
|
332
|
+
private:
|
|
333
|
+
EValue evalue_;
|
|
334
|
+
TensorPtr tensor_ptr_; // Keep tensor alive if EValue references it
|
|
335
|
+
};
|
|
336
|
+
|
|
337
|
+
/**
|
|
338
|
+
* Ruby wrapper for executorch::extension::Module
|
|
339
|
+
*
|
|
340
|
+
* This class manages the lifecycle of an ExecuTorch model and provides
|
|
341
|
+
* methods for loading and executing inference.
|
|
342
|
+
*/
|
|
343
|
+
class RubyModel {
|
|
344
|
+
public:
|
|
345
|
+
RubyModel(const std::string& path)
|
|
346
|
+
: path_(path), module_(nullptr) {
|
|
347
|
+
HANDLE_ET_ERRORS
|
|
348
|
+
module_ = std::make_unique<et::extension::Module>(path);
|
|
349
|
+
// Auto-load on construction
|
|
350
|
+
auto err = module_->load();
|
|
351
|
+
executorch_ruby::check_error(err);
|
|
352
|
+
END_HANDLE_ET_ERRORS
|
|
353
|
+
}
|
|
354
|
+
|
|
355
|
+
bool is_loaded() const {
|
|
356
|
+
return module_ && module_->is_loaded();
|
|
357
|
+
}
|
|
358
|
+
|
|
359
|
+
Array method_names() {
|
|
360
|
+
HANDLE_ET_ERRORS
|
|
361
|
+
if (!is_loaded()) {
|
|
362
|
+
rb_raise(rb_eRuntimeError, "Module not loaded");
|
|
363
|
+
}
|
|
364
|
+
auto result = module_->method_names();
|
|
365
|
+
auto names_set = executorch_ruby::unwrap_result(std::move(result));
|
|
366
|
+
Array arr;
|
|
367
|
+
for (const auto& name : names_set) {
|
|
368
|
+
arr.push(String(name));
|
|
369
|
+
}
|
|
370
|
+
return arr;
|
|
371
|
+
END_HANDLE_ET_ERRORS
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
// Execute the forward method with tensor inputs
|
|
375
|
+
Array forward(Array inputs) {
|
|
376
|
+
HANDLE_ET_ERRORS
|
|
377
|
+
if (!is_loaded()) {
|
|
378
|
+
rb_raise(rb_eRuntimeError, "Module not loaded");
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
// Load forward method if not already loaded
|
|
382
|
+
if (!module_->is_method_loaded("forward")) {
|
|
383
|
+
auto load_err = module_->load_method("forward");
|
|
384
|
+
if (load_err != executorch::runtime::Error::Ok) {
|
|
385
|
+
const char* load_error_name = "Unknown";
|
|
386
|
+
switch (load_err) {
|
|
387
|
+
case executorch::runtime::Error::Ok: load_error_name = "Ok"; break;
|
|
388
|
+
case executorch::runtime::Error::Internal: load_error_name = "Internal"; break;
|
|
389
|
+
case executorch::runtime::Error::InvalidState: load_error_name = "InvalidState"; break;
|
|
390
|
+
case executorch::runtime::Error::InvalidArgument: load_error_name = "InvalidArgument"; break;
|
|
391
|
+
case executorch::runtime::Error::InvalidType: load_error_name = "InvalidType"; break;
|
|
392
|
+
case executorch::runtime::Error::NotFound: load_error_name = "NotFound"; break;
|
|
393
|
+
case executorch::runtime::Error::MemoryAllocationFailed: load_error_name = "MemoryAllocationFailed"; break;
|
|
394
|
+
case executorch::runtime::Error::AccessFailed: load_error_name = "AccessFailed"; break;
|
|
395
|
+
case executorch::runtime::Error::NotSupported: load_error_name = "NotSupported"; break;
|
|
396
|
+
default: load_error_name = "Unknown"; break;
|
|
397
|
+
}
|
|
398
|
+
rb_raise(rb_eRuntimeError, "Failed to load forward method: %s (%d)", load_error_name, static_cast<int>(load_err));
|
|
399
|
+
}
|
|
400
|
+
}
|
|
401
|
+
|
|
402
|
+
// Convert Ruby inputs to EValues
|
|
403
|
+
// Keep tensors alive during forward execution
|
|
404
|
+
std::vector<TensorPtr> input_tensors;
|
|
405
|
+
std::vector<EValue> input_evalues;
|
|
406
|
+
|
|
407
|
+
for (size_t i = 0; i < inputs.size(); i++) {
|
|
408
|
+
Object input = inputs[i];
|
|
409
|
+
|
|
410
|
+
// Check if it's a RubyTensor
|
|
411
|
+
if (input.is_a(rb_cObject)) {
|
|
412
|
+
try {
|
|
413
|
+
RubyTensor& tensor = detail::From_Ruby<RubyTensor&>().convert(input.value());
|
|
414
|
+
// Clone the tensor to ensure we own the data during forward
|
|
415
|
+
TensorPtr cloned = clone_tensor_ptr(tensor.get());
|
|
416
|
+
input_tensors.push_back(cloned);
|
|
417
|
+
input_evalues.push_back(EValue(*cloned));
|
|
418
|
+
} catch (...) {
|
|
419
|
+
// Try as RubyEValue
|
|
420
|
+
try {
|
|
421
|
+
RubyEValue& evalue = detail::From_Ruby<RubyEValue&>().convert(input.value());
|
|
422
|
+
input_evalues.push_back(evalue.get());
|
|
423
|
+
} catch (...) {
|
|
424
|
+
rb_raise(rb_eTypeError, "Input %zu must be a Tensor or EValue", i);
|
|
425
|
+
}
|
|
426
|
+
}
|
|
427
|
+
}
|
|
428
|
+
}
|
|
429
|
+
|
|
430
|
+
// Execute forward
|
|
431
|
+
auto result = module_->forward(input_evalues);
|
|
432
|
+
if (!result.ok()) {
|
|
433
|
+
auto error = result.error();
|
|
434
|
+
const char* error_name = "Unknown";
|
|
435
|
+
switch (error) {
|
|
436
|
+
case executorch::runtime::Error::Ok: error_name = "Ok"; break;
|
|
437
|
+
case executorch::runtime::Error::Internal: error_name = "Internal"; break;
|
|
438
|
+
case executorch::runtime::Error::InvalidState: error_name = "InvalidState"; break;
|
|
439
|
+
case executorch::runtime::Error::InvalidArgument: error_name = "InvalidArgument"; break;
|
|
440
|
+
case executorch::runtime::Error::InvalidType: error_name = "InvalidType"; break;
|
|
441
|
+
case executorch::runtime::Error::NotFound: error_name = "NotFound"; break;
|
|
442
|
+
case executorch::runtime::Error::MemoryAllocationFailed: error_name = "MemoryAllocationFailed"; break;
|
|
443
|
+
case executorch::runtime::Error::AccessFailed: error_name = "AccessFailed"; break;
|
|
444
|
+
case executorch::runtime::Error::NotSupported: error_name = "NotSupported"; break;
|
|
445
|
+
default: error_name = "Unknown"; break;
|
|
446
|
+
}
|
|
447
|
+
rb_raise(rb_eRuntimeError, "Forward execution failed: %s (%d)", error_name, static_cast<int>(error));
|
|
448
|
+
}
|
|
449
|
+
auto outputs = std::move(result.get());
|
|
450
|
+
|
|
451
|
+
// Convert outputs to Ruby array of RubyTensors
|
|
452
|
+
Array ruby_outputs;
|
|
453
|
+
for (auto& output : outputs) {
|
|
454
|
+
if (output.isTensor()) {
|
|
455
|
+
// Clone the tensor to own the data
|
|
456
|
+
ruby_outputs.push(RubyTensor::from_tensor(output.toTensor()));
|
|
457
|
+
} else if (output.isInt()) {
|
|
458
|
+
ruby_outputs.push(output.toInt());
|
|
459
|
+
} else if (output.isDouble()) {
|
|
460
|
+
ruby_outputs.push(output.toDouble());
|
|
461
|
+
} else if (output.isBool()) {
|
|
462
|
+
ruby_outputs.push(output.toBool() ? Qtrue : Qfalse);
|
|
463
|
+
} else {
|
|
464
|
+
ruby_outputs.push(Qnil);
|
|
465
|
+
}
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
return ruby_outputs;
|
|
469
|
+
END_HANDLE_ET_ERRORS
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
// Execute a named method
|
|
473
|
+
Array execute(const std::string& method_name, Array inputs) {
|
|
474
|
+
HANDLE_ET_ERRORS
|
|
475
|
+
if (!is_loaded()) {
|
|
476
|
+
rb_raise(rb_eRuntimeError, "Module not loaded");
|
|
477
|
+
}
|
|
478
|
+
|
|
479
|
+
// Convert Ruby inputs to EValues
|
|
480
|
+
// Keep tensors alive during execution
|
|
481
|
+
std::vector<TensorPtr> input_tensors;
|
|
482
|
+
std::vector<EValue> input_evalues;
|
|
483
|
+
for (size_t i = 0; i < inputs.size(); i++) {
|
|
484
|
+
Object input = inputs[i];
|
|
485
|
+
|
|
486
|
+
try {
|
|
487
|
+
RubyTensor& tensor = detail::From_Ruby<RubyTensor&>().convert(input.value());
|
|
488
|
+
// Clone the tensor to ensure we own the data during execution
|
|
489
|
+
TensorPtr cloned = clone_tensor_ptr(tensor.get());
|
|
490
|
+
input_tensors.push_back(cloned);
|
|
491
|
+
input_evalues.push_back(EValue(*cloned));
|
|
492
|
+
} catch (...) {
|
|
493
|
+
try {
|
|
494
|
+
RubyEValue& evalue = detail::From_Ruby<RubyEValue&>().convert(input.value());
|
|
495
|
+
input_evalues.push_back(evalue.get());
|
|
496
|
+
} catch (...) {
|
|
497
|
+
rb_raise(rb_eTypeError, "Input %zu must be a Tensor or EValue", i);
|
|
498
|
+
}
|
|
499
|
+
}
|
|
500
|
+
}
|
|
501
|
+
|
|
502
|
+
// Execute method
|
|
503
|
+
auto result = module_->execute(method_name, input_evalues);
|
|
504
|
+
auto outputs = executorch_ruby::unwrap_result(std::move(result));
|
|
505
|
+
|
|
506
|
+
// Convert outputs to Ruby
|
|
507
|
+
Array ruby_outputs;
|
|
508
|
+
for (auto& output : outputs) {
|
|
509
|
+
if (output.isTensor()) {
|
|
510
|
+
ruby_outputs.push(RubyTensor::from_tensor(output.toTensor()));
|
|
511
|
+
} else if (output.isInt()) {
|
|
512
|
+
ruby_outputs.push(output.toInt());
|
|
513
|
+
} else if (output.isDouble()) {
|
|
514
|
+
ruby_outputs.push(output.toDouble());
|
|
515
|
+
} else if (output.isBool()) {
|
|
516
|
+
ruby_outputs.push(output.toBool() ? Qtrue : Qfalse);
|
|
517
|
+
} else {
|
|
518
|
+
ruby_outputs.push(Qnil);
|
|
519
|
+
}
|
|
520
|
+
}
|
|
521
|
+
|
|
522
|
+
return ruby_outputs;
|
|
523
|
+
END_HANDLE_ET_ERRORS
|
|
524
|
+
}
|
|
525
|
+
|
|
526
|
+
// Get the file path
|
|
527
|
+
std::string path() const {
|
|
528
|
+
return path_;
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
// Access the underlying module for advanced use
|
|
532
|
+
et::extension::Module* get_module() {
|
|
533
|
+
return module_.get();
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
private:
|
|
537
|
+
std::string path_;
|
|
538
|
+
std::unique_ptr<et::extension::Module> module_;
|
|
539
|
+
};
|
|
540
|
+
|
|
541
|
+
/**
|
|
542
|
+
* Initialize the Executorch Ruby module
|
|
543
|
+
*/
|
|
544
|
+
extern "C"
|
|
545
|
+
void Init_executorch() {
|
|
546
|
+
Rice::Module m = define_module("Executorch");
|
|
547
|
+
|
|
548
|
+
// Define version constant
|
|
549
|
+
m.const_set("NATIVE_VERSION", String("0.1.0"));
|
|
550
|
+
|
|
551
|
+
// Define error class
|
|
552
|
+
define_class_under<std::runtime_error>(m, "NativeError")
|
|
553
|
+
.define_constructor(Constructor<std::runtime_error, const std::string&>());
|
|
554
|
+
|
|
555
|
+
// Define Tensor class
|
|
556
|
+
define_class_under<RubyTensor>(m, "Tensor")
|
|
557
|
+
.define_singleton_function("create", &RubyTensor::create,
|
|
558
|
+
Arg("data"), Arg("shape"), Arg("dtype"))
|
|
559
|
+
.define_singleton_function("from_array", &RubyTensor::from_array,
|
|
560
|
+
Arg("data"), Arg("shape"))
|
|
561
|
+
.define_method("shape", &RubyTensor::shape)
|
|
562
|
+
.define_method("dim", &RubyTensor::dim)
|
|
563
|
+
.define_method("numel", &RubyTensor::numel)
|
|
564
|
+
.define_method("dtype", &RubyTensor::dtype)
|
|
565
|
+
.define_method("_original_to_a", &RubyTensor::to_a)
|
|
566
|
+
.define_method("to_s", &RubyTensor::to_s)
|
|
567
|
+
.define_method("inspect", &RubyTensor::to_s);
|
|
568
|
+
|
|
569
|
+
// Define Model class
|
|
570
|
+
define_class_under<RubyModel>(m, "Model")
|
|
571
|
+
.define_constructor(Constructor<RubyModel, const std::string&>(),
|
|
572
|
+
Arg("path"))
|
|
573
|
+
.define_method("loaded?", &RubyModel::is_loaded)
|
|
574
|
+
.define_method("method_names", &RubyModel::method_names)
|
|
575
|
+
.define_method("path", &RubyModel::path)
|
|
576
|
+
.define_method("forward", &RubyModel::forward,
|
|
577
|
+
Arg("inputs"))
|
|
578
|
+
.define_method("execute", &RubyModel::execute,
|
|
579
|
+
Arg("method_name"), Arg("inputs"));
|
|
580
|
+
|
|
581
|
+
// Note: EValue is kept internal - users interact with Tensor and native Ruby types
|
|
582
|
+
}
|