red-arrow 0.12.0 → 0.13.0

Sign up to get free protection for your applications and to get access to all the features.

Potentially problematic release.


This version of red-arrow might be problematic (the original page linked to further details).

Files changed (50)
  1. checksums.yaml +4 -4
  2. data/Rakefile +49 -4
  3. data/ext/arrow/arrow.cpp +43 -0
  4. data/ext/arrow/extconf.rb +52 -0
  5. data/ext/arrow/record-batch.cpp +756 -0
  6. data/ext/arrow/red-arrow.hpp +60 -0
  7. data/lib/arrow.rb +2 -1
  8. data/lib/arrow/array-builder.rb +4 -0
  9. data/lib/arrow/array.rb +11 -1
  10. data/lib/arrow/bigdecimal-extension.rb +24 -0
  11. data/lib/arrow/binary-array-builder.rb +36 -0
  12. data/lib/arrow/block-closable.rb +5 -1
  13. data/lib/arrow/csv-loader.rb +28 -6
  14. data/lib/arrow/data-type.rb +8 -4
  15. data/lib/arrow/decimal128-array-builder.rb +2 -2
  16. data/lib/arrow/decimal128.rb +42 -0
  17. data/lib/arrow/list-array-builder.rb +1 -1
  18. data/lib/arrow/loader.rb +8 -0
  19. data/lib/arrow/null-array-builder.rb +26 -0
  20. data/lib/arrow/record-batch-builder.rb +8 -9
  21. data/lib/arrow/struct-array-builder.rb +3 -3
  22. data/lib/arrow/struct-array.rb +15 -7
  23. data/lib/arrow/struct.rb +11 -0
  24. data/lib/arrow/table-loader.rb +14 -14
  25. data/lib/arrow/version.rb +1 -1
  26. data/red-arrow.gemspec +8 -4
  27. data/test/raw-records/record-batch/test-basic-arrays.rb +349 -0
  28. data/test/raw-records/record-batch/test-dense-union-array.rb +486 -0
  29. data/test/raw-records/record-batch/test-list-array.rb +498 -0
  30. data/test/raw-records/record-batch/test-multiple-columns.rb +49 -0
  31. data/test/raw-records/record-batch/test-sparse-union-array.rb +474 -0
  32. data/test/raw-records/record-batch/test-struct-array.rb +426 -0
  33. data/test/run-test.rb +25 -2
  34. data/test/test-array.rb +38 -9
  35. data/test/test-bigdecimal.rb +23 -0
  36. data/{dependency-check/Rakefile → test/test-buffer.rb} +15 -20
  37. data/test/test-chunked-array.rb +22 -0
  38. data/test/test-column.rb +24 -0
  39. data/test/test-csv-loader.rb +30 -0
  40. data/test/test-data-type.rb +25 -0
  41. data/test/test-decimal128.rb +64 -0
  42. data/test/test-field.rb +20 -0
  43. data/test/test-group.rb +2 -2
  44. data/test/test-record-batch-builder.rb +9 -0
  45. data/test/test-record-batch.rb +14 -0
  46. data/test/test-schema.rb +14 -0
  47. data/test/test-struct-array.rb +16 -3
  48. data/test/test-table.rb +14 -0
  49. data/test/test-tensor.rb +56 -0
  50. metadata +117 -47
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 847a4994cc15fb50df335c7231c942d49392f1ec528b85647b2bbe6fb4e82f7b
4
- data.tar.gz: 4337680e47dea67107a1fef863d66e936b9c5d20b2cb6698e879b774eada3c88
3
+ metadata.gz: a362c35e8eb6b5b93cb69e1ee077acf572af51d89b54e0e9a41282cbdb546cbb
4
+ data.tar.gz: db53bb7327021bbb8e45e1b8dbf741becf4ad98d85366fe6c3a475da3650ef4e
5
5
  SHA512:
6
- metadata.gz: c72b26d9b4f488c4d00184ea3243056c69a601fe59fba80a7423075b19034b0db7ad12dbbf5e210ea4b66607acd2ad7123a85057a27266b21665e143961bccb7
7
- data.tar.gz: 1fb8c82007a25cac5e99d41f3bb879ed86ad1f9deab320f45470252513fada6c6f972f7ddb0ed3f38ca1dcdf7de3f5662a3fa28536a4981b9e5993dd5bc666a8
6
+ metadata.gz: 153a2ab1e7ccefe6fbe3d3481d8d86cf90f6d583825543b4ae541d3e3704b1aab4debcc2d5aa7148ded9eefbed828354697514e22664ba63aa83e3ad74e41180
7
+ data.tar.gz: a8ab623a40a47a073865a13c0eaa1e7c380574d57307068cb5258884cbf04113030b137b00a83347c0c5a3ccc6ca597db25c89297ccf38b20e03f71e9345b1aa
data/Rakefile CHANGED
@@ -17,27 +17,72 @@
17
17
  # specific language governing permissions and limitations
18
18
  # under the License.
19
19
 
20
- require "rubygems"
21
20
  require "bundler/gem_helper"
21
+ require "rake/clean"
22
22
  require "yard"
23
23
 
24
24
  base_dir = File.join(__dir__)
25
25
 
26
26
  helper = Bundler::GemHelper.new(base_dir)
27
27
  helper.install
28
+ spec = helper.gemspec
28
29
 
29
30
  release_task = Rake::Task["release"]
30
31
  release_task.prerequisites.replace(["build", "release:rubygem_push"])
31
32
 
33
# Runs extconf.rb from inside +extension_dir+ (extconf.rb ends by calling
# create_makefile, so this produces the extension's Makefile).
#
# @param extension_dir [String] directory that contains extconf.rb
# @param arguments [Array<String>] extra command-line arguments for extconf.rb
def run_extconf(extension_dir, *arguments)
  extconf_command = ["extconf.rb", *arguments]
  cd(extension_dir) { ruby(*extconf_command) }
end
38
+
39
# Define configure/compile/clean tasks for every native extension listed in
# the gemspec. Each extension is built in the directory holding its extconf.rb.
spec.extensions.each do |extconf_path|
  ext_dir = File.dirname(extconf_path)
  makefile_path = File.join(ext_dir, "Makefile")

  # Generated build files removed by `rake clobber`.
  CLOBBER << makefile_path
  CLOBBER << File.join(ext_dir, "mkmf.log")

  # Generate the Makefile only when it does not exist yet.
  file(makefile_path) { run_extconf(ext_dir) }

  desc "Configure"
  task(:configure) { run_extconf(ext_dir) }

  desc "Compile"
  task(compile: makefile_path) do
    cd(ext_dir) { sh("make") }
  end

  # Hook the extension's own `make clean` into `rake clean`, but only once a
  # Makefile has been generated.
  task(:clean) do
    cd(ext_dir) do
      sh("make", "clean") if File.exist?("Makefile")
    end
  end
end
67
+
32
68
  desc "Run tests"
33
69
  task :test do
34
- cd("dependency-check") do
35
- ruby("-S", "rake")
36
- end
37
70
  ruby("test/run-test.rb")
38
71
  end
39
72
 
40
73
  task default: :test
41
74
 
75
# Run benchmark-driver once per benchmark definition. BENCHMARKS, when set,
# is a whitespace-separated list of files to run; otherwise every *.yml file
# under benchmark/ (recursively) is used.
desc "Run benchmarks"
task :benchmark do
  selected = ENV["BENCHMARKS"]
  targets = selected ? selected.split : FileList["benchmark/{,*/**/}*.yml"]
  targets.each { |path| sh("benchmark-driver", path) }
end
86
+
42
87
  YARD::Rake::YardocTask.new do |task|
43
88
  end
data/ext/arrow/arrow.cpp ADDED
@@ -0,0 +1,43 @@
1
+ /*
2
+ * Licensed to the Apache Software Foundation (ASF) under one
3
+ * or more contributor license agreements. See the NOTICE file
4
+ * distributed with this work for additional information
5
+ * regarding copyright ownership. The ASF licenses this file
6
+ * to you under the Apache License, Version 2.0 (the
7
+ * "License"); you may not use this file except in compliance
8
+ * with the License. You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing,
13
+ * software distributed under the License is distributed on an
14
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ * KIND, either express or implied. See the License for the
16
+ * specific language governing permissions and limitations
17
+ * under the License.
18
+ */
19
+
20
+ #include "red-arrow.hpp"
21
+
22
+ #include <ruby.hpp>
23
+
24
+ namespace red_arrow {
25
+ VALUE cDate;
26
+ ID id_BigDecimal;
27
+ ID id_jd;
28
+ ID id_to_datetime;
29
+ }
30
+
31
+ extern "C" void Init_arrow() {
32
+ auto mArrow = rb_const_get_at(rb_cObject, rb_intern("Arrow"));
33
+ auto cArrowRecordBatch = rb_const_get_at(mArrow, rb_intern("RecordBatch"));
34
+ rb_define_method(cArrowRecordBatch, "raw_records",
35
+ reinterpret_cast<rb::RawMethod>(red_arrow::record_batch_raw_records),
36
+ 0);
37
+
38
+ red_arrow::cDate = rb_const_get(rb_cObject, rb_intern("Date"));
39
+
40
+ red_arrow::id_BigDecimal = rb_intern("BigDecimal");
41
+ red_arrow::id_jd = rb_intern("jd");
42
+ red_arrow::id_to_datetime = rb_intern("to_datetime");
43
+ }
data/ext/arrow/extconf.rb ADDED
@@ -0,0 +1,52 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
18
+ require "extpp"
19
+ require "mkmf-gnome2"
20
+
21
# ARROW_PKG_CONFIG_PATH, when set, is placed in front of any existing
# PKG_CONFIG_PATH so pkg-config searches it first.
if (arrow_pkg_config_path = ENV["ARROW_PKG_CONFIG_PATH"])
  ENV["PKG_CONFIG_PATH"] =
    [arrow_pkg_config_path, ENV["PKG_CONFIG_PATH"]].compact.join(File::PATH_SEPARATOR)
end

# Both pkg-config packages are required; each hash names the native package
# that provides it on each platform. Abort configuration if one is missing.
{
  "arrow" => {
    debian: "libarrow-dev",
    redhat: "arrow-devel",
    homebrew: "apache-arrow",
    msys2: "arrow",
  },
  "arrow-glib" => {
    debian: "libarrow-glib-dev",
    redhat: "arrow-glib-devel",
    homebrew: "apache-arrow-glib",
    msys2: "arrow",
  },
}.each do |package, native_packages|
  exit(false) unless required_pkg_config_package(package, **native_packages)
end

# Register the glib2 gem's ext/glib2 directory (used as both source and build
# directory) as a dependency path for this extension.
[["glib2", "ext/glib2"]].each do |gem_name, relative_source_dir|
  source_dir = File.join(find_gem_spec(gem_name).full_gem_path, relative_source_dir)
  add_depend_package_path(gem_name, source_dir, source_dir)
end

create_makefile("arrow")
data/ext/arrow/record-batch.cpp ADDED
@@ -0,0 +1,756 @@
1
+ /*
2
+ * Licensed to the Apache Software Foundation (ASF) under one
3
+ * or more contributor license agreements. See the NOTICE file
4
+ * distributed with this work for additional information
5
+ * regarding copyright ownership. The ASF licenses this file
6
+ * to you under the Apache License, Version 2.0 (the
7
+ * "License"); you may not use this file except in compliance
8
+ * with the License. You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing,
13
+ * software distributed under the License is distributed on an
14
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ * KIND, either express or implied. See the License for the
16
+ * specific language governing permissions and limitations
17
+ * under the License.
18
+ */
19
+
20
+ #include "red-arrow.hpp"
21
+
22
+ #include <ruby.hpp>
23
+ #include <ruby/encoding.h>
24
+
25
+ #include <arrow-glib/error.hpp>
26
+
27
+ #include <arrow/util/logging.h>
28
+
29
+ namespace red_arrow {
30
+ namespace {
31
+ using Status = arrow::Status;
32
+
33
+ void check_status(const Status&& status, const char* context) {
34
+ GError* error = nullptr;
35
+ if (!garrow_error_check(&error, status, context)) {
36
+ RG_RAISE_ERROR(error);
37
+ }
38
+ }
39
+
40
+ class ListArrayValueConverter;
41
+ class StructArrayValueConverter;
42
+ class UnionArrayValueConverter;
43
+ class DictionaryArrayValueConverter;
44
+
45
+ class ArrayValueConverter {
46
+ public:
47
+ ArrayValueConverter()
48
+ : decimal_buffer_(),
49
+ list_array_value_converter_(nullptr),
50
+ struct_array_value_converter_(nullptr),
51
+ union_array_value_converter_(nullptr),
52
+ dictionary_array_value_converter_(nullptr) {
53
+ }
54
+
55
+ void set_sub_value_converters(ListArrayValueConverter* list_array_value_converter,
56
+ StructArrayValueConverter* struct_array_value_converter,
57
+ UnionArrayValueConverter* union_array_value_converter,
58
+ DictionaryArrayValueConverter* dictionary_array_value_converter) {
59
+ list_array_value_converter_ = list_array_value_converter;
60
+ struct_array_value_converter_ = struct_array_value_converter;
61
+ union_array_value_converter_ = union_array_value_converter;
62
+ dictionary_array_value_converter_ = dictionary_array_value_converter;
63
+ }
64
+
65
+ inline VALUE convert(const arrow::NullArray& array,
66
+ const int64_t i) {
67
+ return Qnil;
68
+ }
69
+
70
+ inline VALUE convert(const arrow::BooleanArray& array,
71
+ const int64_t i) {
72
+ return array.Value(i) ? Qtrue : Qfalse;
73
+ }
74
+
75
+ inline VALUE convert(const arrow::Int8Array& array,
76
+ const int64_t i) {
77
+ return INT2NUM(array.Value(i));
78
+ }
79
+
80
+ inline VALUE convert(const arrow::Int16Array& array,
81
+ const int64_t i) {
82
+ return INT2NUM(array.Value(i));
83
+ }
84
+
85
+ inline VALUE convert(const arrow::Int32Array& array,
86
+ const int64_t i) {
87
+ return INT2NUM(array.Value(i));
88
+ }
89
+
90
+ inline VALUE convert(const arrow::Int64Array& array,
91
+ const int64_t i) {
92
+ return LL2NUM(array.Value(i));
93
+ }
94
+
95
+ inline VALUE convert(const arrow::UInt8Array& array,
96
+ const int64_t i) {
97
+ return UINT2NUM(array.Value(i));
98
+ }
99
+
100
+ inline VALUE convert(const arrow::UInt16Array& array,
101
+ const int64_t i) {
102
+ return UINT2NUM(array.Value(i));
103
+ }
104
+
105
+ inline VALUE convert(const arrow::UInt32Array& array,
106
+ const int64_t i) {
107
+ return UINT2NUM(array.Value(i));
108
+ }
109
+
110
+ inline VALUE convert(const arrow::UInt64Array& array,
111
+ const int64_t i) {
112
+ return ULL2NUM(array.Value(i));
113
+ }
114
+
115
+ // TODO
116
+ // inline VALUE convert(const arrow::HalfFloatArray& array,
117
+ // const int64_t i) {
118
+ // }
119
+
120
+ inline VALUE convert(const arrow::FloatArray& array,
121
+ const int64_t i) {
122
+ return DBL2NUM(array.Value(i));
123
+ }
124
+
125
+ inline VALUE convert(const arrow::DoubleArray& array,
126
+ const int64_t i) {
127
+ return DBL2NUM(array.Value(i));
128
+ }
129
+
130
+ inline VALUE convert(const arrow::BinaryArray& array,
131
+ const int64_t i) {
132
+ int32_t length;
133
+ const auto value = array.GetValue(i, &length);
134
+ // TODO: encoding support
135
+ return rb_enc_str_new(reinterpret_cast<const char*>(value),
136
+ length,
137
+ rb_ascii8bit_encoding());
138
+ }
139
+
140
+ inline VALUE convert(const arrow::StringArray& array,
141
+ const int64_t i) {
142
+ int32_t length;
143
+ const auto value = array.GetValue(i, &length);
144
+ return rb_utf8_str_new(reinterpret_cast<const char*>(value),
145
+ length);
146
+ }
147
+
148
+ inline VALUE convert(const arrow::FixedSizeBinaryArray& array,
149
+ const int64_t i) {
150
+ return rb_enc_str_new(reinterpret_cast<const char*>(array.Value(i)),
151
+ array.byte_width(),
152
+ rb_ascii8bit_encoding());
153
+ }
154
+
155
+ constexpr static int32_t JULIAN_DATE_UNIX_EPOCH = 2440588;
156
+ inline VALUE convert(const arrow::Date32Array& array,
157
+ const int64_t i) {
158
+ const auto value = array.Value(i);
159
+ const auto days_in_julian = value + JULIAN_DATE_UNIX_EPOCH;
160
+ return rb_funcall(cDate, id_jd, 1, LONG2NUM(days_in_julian));
161
+ }
162
+
163
+ inline VALUE convert(const arrow::Date64Array& array,
164
+ const int64_t i) {
165
+ const auto value = array.Value(i);
166
+ auto msec = LL2NUM(value);
167
+ auto sec = rb_rational_new(msec, INT2NUM(1000));
168
+ auto time_value = rb_time_num_new(sec, Qnil);
169
+ return rb_funcall(time_value, id_to_datetime, 0, 0);
170
+ }
171
+
172
+ inline VALUE convert(const arrow::Time32Array& array,
173
+ const int64_t i) {
174
+ // TODO: unit treatment
175
+ const auto value = array.Value(i);
176
+ return INT2NUM(value);
177
+ }
178
+
179
+ inline VALUE convert(const arrow::Time64Array& array,
180
+ const int64_t i) {
181
+ // TODO: unit treatment
182
+ const auto value = array.Value(i);
183
+ return LL2NUM(value);
184
+ }
185
+
186
+ inline VALUE convert(const arrow::TimestampArray& array,
187
+ const int64_t i) {
188
+ const auto type =
189
+ arrow::internal::checked_cast<const arrow::TimestampType*>(array.type().get());
190
+ auto scale = time_unit_to_scale(type->unit());
191
+ if (NIL_P(scale)) {
192
+ rb_raise(rb_eArgError, "Invalid TimeUnit");
193
+ }
194
+ auto value = array.Value(i);
195
+ auto sec = rb_rational_new(LL2NUM(value), scale);
196
+ return rb_time_num_new(sec, Qnil);
197
+ }
198
+
199
+ // TODO
200
+ // inline VALUE convert(const arrow::IntervalArray& array,
201
+ // const int64_t i) {
202
+ // };
203
+
204
+ VALUE convert(const arrow::ListArray& array,
205
+ const int64_t i);
206
+
207
+ VALUE convert(const arrow::StructArray& array,
208
+ const int64_t i);
209
+
210
+ VALUE convert(const arrow::UnionArray& array,
211
+ const int64_t i);
212
+
213
+ VALUE convert(const arrow::DictionaryArray& array,
214
+ const int64_t i);
215
+
216
+ inline VALUE convert(const arrow::Decimal128Array& array,
217
+ const int64_t i) {
218
+ decimal_buffer_ = array.FormatValue(i);
219
+ return rb_funcall(rb_cObject,
220
+ id_BigDecimal,
221
+ 1,
222
+ rb_enc_str_new(decimal_buffer_.data(),
223
+ decimal_buffer_.length(),
224
+ rb_ascii8bit_encoding()));
225
+ }
226
+
227
+ private:
228
+ std::string decimal_buffer_;
229
+ ListArrayValueConverter* list_array_value_converter_;
230
+ StructArrayValueConverter* struct_array_value_converter_;
231
+ UnionArrayValueConverter* union_array_value_converter_;
232
+ DictionaryArrayValueConverter* dictionary_array_value_converter_;
233
+ };
234
+
235
+ class ListArrayValueConverter : public arrow::ArrayVisitor {
236
+ public:
237
+ explicit ListArrayValueConverter(ArrayValueConverter* converter)
238
+ : array_value_converter_(converter),
239
+ offset_(0),
240
+ length_(0),
241
+ result_(Qnil) {}
242
+
243
+ VALUE convert(const arrow::ListArray& array, const int64_t index) {
244
+ auto values = array.values().get();
245
+ auto offset_keep = offset_;
246
+ auto length_keep = length_;
247
+ offset_ = array.value_offset(index);
248
+ length_ = array.value_length(index);
249
+ auto result_keep = result_;
250
+ result_ = rb_ary_new_capa(length_);
251
+ check_status(values->Accept(this),
252
+ "[raw-records][list-array]");
253
+ offset_ = offset_keep;
254
+ length_ = length_keep;
255
+ auto result_return = result_;
256
+ result_ = result_keep;
257
+ return result_return;
258
+ }
259
+
260
+ #define VISIT(TYPE) \
261
+ Status Visit(const arrow::TYPE ## Array& array) override { \
262
+ return visit_value(array); \
263
+ }
264
+
265
+ VISIT(Null)
266
+ VISIT(Boolean)
267
+ VISIT(Int8)
268
+ VISIT(Int16)
269
+ VISIT(Int32)
270
+ VISIT(Int64)
271
+ VISIT(UInt8)
272
+ VISIT(UInt16)
273
+ VISIT(UInt32)
274
+ VISIT(UInt64)
275
+ // TODO
276
+ // VISIT(HalfFloat)
277
+ VISIT(Float)
278
+ VISIT(Double)
279
+ VISIT(Binary)
280
+ VISIT(String)
281
+ VISIT(FixedSizeBinary)
282
+ VISIT(Date32)
283
+ VISIT(Date64)
284
+ VISIT(Time32)
285
+ VISIT(Time64)
286
+ VISIT(Timestamp)
287
+ // TODO
288
+ // VISIT(Interval)
289
+ VISIT(List)
290
+ VISIT(Struct)
291
+ VISIT(Union)
292
+ VISIT(Dictionary)
293
+ VISIT(Decimal128)
294
+ // TODO
295
+ // VISIT(Extension)
296
+
297
+ #undef VISIT
298
+
299
+ private:
300
+ template <typename ArrayType>
301
+ inline VALUE convert_value(const ArrayType& array,
302
+ const int64_t i) {
303
+ return array_value_converter_->convert(array, i);
304
+ }
305
+
306
+ template <typename ArrayType>
307
+ Status visit_value(const ArrayType& array) {
308
+ if (array.null_count() > 0) {
309
+ for (int64_t i = 0; i < length_; ++i) {
310
+ auto value = Qnil;
311
+ if (!array.IsNull(i + offset_)) {
312
+ value = convert_value(array, i + offset_);
313
+ }
314
+ rb_ary_push(result_, value);
315
+ }
316
+ } else {
317
+ for (int64_t i = 0; i < length_; ++i) {
318
+ rb_ary_push(result_, convert_value(array, i + offset_));
319
+ }
320
+ }
321
+ return Status::OK();
322
+ }
323
+
324
+ ArrayValueConverter* array_value_converter_;
325
+ int32_t offset_;
326
+ int32_t length_;
327
+ VALUE result_;
328
+ };
329
+
330
+ class StructArrayValueConverter : public arrow::ArrayVisitor {
331
+ public:
332
+ explicit StructArrayValueConverter(ArrayValueConverter* converter)
333
+ : array_value_converter_(converter),
334
+ key_(Qnil),
335
+ index_(0),
336
+ result_(Qnil) {}
337
+
338
+ VALUE convert(const arrow::StructArray& array,
339
+ const int64_t index) {
340
+ auto index_keep = index_;
341
+ auto result_keep = result_;
342
+ index_ = index;
343
+ result_ = rb_hash_new();
344
+ const auto struct_type = array.struct_type();
345
+ const auto n = struct_type->num_children();
346
+ for (int i = 0; i < n; ++i) {
347
+ const auto field_type = struct_type->child(i).get();
348
+ const auto& field_name = field_type->name();
349
+ auto key_keep = key_;
350
+ key_ = rb_utf8_str_new(field_name.data(), field_name.length());
351
+ const auto field_array = array.field(i).get();
352
+ check_status(field_array->Accept(this),
353
+ "[raw-records][struct-array]");
354
+ key_ = key_keep;
355
+ }
356
+ auto result_return = result_;
357
+ result_ = result_keep;
358
+ index_ = index_keep;
359
+ return result_return;
360
+ }
361
+
362
+ #define VISIT(TYPE) \
363
+ Status Visit(const arrow::TYPE ## Array& array) override { \
364
+ fill_field(array); \
365
+ return Status::OK(); \
366
+ }
367
+
368
+ VISIT(Null)
369
+ VISIT(Boolean)
370
+ VISIT(Int8)
371
+ VISIT(Int16)
372
+ VISIT(Int32)
373
+ VISIT(Int64)
374
+ VISIT(UInt8)
375
+ VISIT(UInt16)
376
+ VISIT(UInt32)
377
+ VISIT(UInt64)
378
+ // TODO
379
+ // VISIT(HalfFloat)
380
+ VISIT(Float)
381
+ VISIT(Double)
382
+ VISIT(Binary)
383
+ VISIT(String)
384
+ VISIT(FixedSizeBinary)
385
+ VISIT(Date32)
386
+ VISIT(Date64)
387
+ VISIT(Time32)
388
+ VISIT(Time64)
389
+ VISIT(Timestamp)
390
+ // TODO
391
+ // VISIT(Interval)
392
+ VISIT(List)
393
+ VISIT(Struct)
394
+ VISIT(Union)
395
+ VISIT(Dictionary)
396
+ VISIT(Decimal128)
397
+ // TODO
398
+ // VISIT(Extension)
399
+
400
+ #undef VISIT
401
+
402
+ private:
403
+ template <typename ArrayType>
404
+ inline VALUE convert_value(const ArrayType& array,
405
+ const int64_t i) {
406
+ return array_value_converter_->convert(array, i);
407
+ }
408
+
409
+ template <typename ArrayType>
410
+ void fill_field(const ArrayType& array) {
411
+ if (array.IsNull(index_)) {
412
+ rb_hash_aset(result_, key_, Qnil);
413
+ } else {
414
+ rb_hash_aset(result_, key_, convert_value(array, index_));
415
+ }
416
+ }
417
+
418
+ ArrayValueConverter* array_value_converter_;
419
+ VALUE key_;
420
+ int64_t index_;
421
+ VALUE result_;
422
+ };
423
+
424
+ class UnionArrayValueConverter : public arrow::ArrayVisitor {
425
+ public:
426
+ explicit UnionArrayValueConverter(ArrayValueConverter* converter)
427
+ : array_value_converter_(converter),
428
+ index_(0),
429
+ result_(Qnil) {}
430
+
431
+ VALUE convert(const arrow::UnionArray& array,
432
+ const int64_t index) {
433
+ const auto index_keep = index_;
434
+ const auto result_keep = result_;
435
+ index_ = index;
436
+ switch (array.mode()) {
437
+ case arrow::UnionMode::SPARSE:
438
+ convert_sparse(array);
439
+ break;
440
+ case arrow::UnionMode::DENSE:
441
+ convert_dense(array);
442
+ break;
443
+ default:
444
+ rb_raise(rb_eArgError, "Invalid union mode");
445
+ break;
446
+ }
447
+ auto result_return = result_;
448
+ index_ = index_keep;
449
+ result_ = result_keep;
450
+ return result_return;
451
+ }
452
+
453
+ #define VISIT(TYPE) \
454
+ Status Visit(const arrow::TYPE ## Array& array) override { \
455
+ convert_value(array); \
456
+ return Status::OK(); \
457
+ }
458
+
459
+ VISIT(Null)
460
+ VISIT(Boolean)
461
+ VISIT(Int8)
462
+ VISIT(Int16)
463
+ VISIT(Int32)
464
+ VISIT(Int64)
465
+ VISIT(UInt8)
466
+ VISIT(UInt16)
467
+ VISIT(UInt32)
468
+ VISIT(UInt64)
469
+ // TODO
470
+ // VISIT(HalfFloat)
471
+ VISIT(Float)
472
+ VISIT(Double)
473
+ VISIT(Binary)
474
+ VISIT(String)
475
+ VISIT(FixedSizeBinary)
476
+ VISIT(Date32)
477
+ VISIT(Date64)
478
+ VISIT(Time32)
479
+ VISIT(Time64)
480
+ VISIT(Timestamp)
481
+ // TODO
482
+ // VISIT(Interval)
483
+ VISIT(List)
484
+ VISIT(Struct)
485
+ VISIT(Union)
486
+ VISIT(Dictionary)
487
+ VISIT(Decimal128)
488
+ // TODO
489
+ // VISIT(Extension)
490
+
491
+ #undef VISIT
492
+ private:
493
+ template <typename ArrayType>
494
+ inline void convert_value(const ArrayType& array) {
495
+ auto result = rb_hash_new();
496
+ if (array.IsNull(index_)) {
497
+ rb_hash_aset(result, field_name_, Qnil);
498
+ } else {
499
+ rb_hash_aset(result,
500
+ field_name_,
501
+ array_value_converter_->convert(array, index_));
502
+ }
503
+ result_ = result;
504
+ }
505
+
506
+ uint8_t compute_child_index(const arrow::UnionArray& array,
507
+ arrow::UnionType* type,
508
+ const char* tag) {
509
+ const auto type_id = array.raw_type_ids()[index_];
510
+ const auto& type_codes = type->type_codes();
511
+ for (uint8_t i = 0; i < type_codes.size(); ++i) {
512
+ if (type_codes[i] == type_id) {
513
+ return i;
514
+ }
515
+ }
516
+ check_status(Status::Invalid("Unknown type ID: ", type_id),
517
+ tag);
518
+ return 0;
519
+ }
520
+
521
+ void convert_sparse(const arrow::UnionArray& array) {
522
+ const auto type =
523
+ std::static_pointer_cast<arrow::UnionType>(array.type()).get();
524
+ const auto tag = "[raw-records][union-sparse-array]";
525
+ const auto child_index = compute_child_index(array, type, tag);
526
+ const auto child_field = type->child(child_index).get();
527
+ const auto& field_name = child_field->name();
528
+ const auto field_name_keep = field_name_;
529
+ field_name_ = rb_utf8_str_new(field_name.data(), field_name.length());
530
+ const auto child_array = array.child(child_index).get();
531
+ check_status(child_array->Accept(this), tag);
532
+ field_name_ = field_name_keep;
533
+ }
534
+
535
+ void convert_dense(const arrow::UnionArray& array) {
536
+ const auto type =
537
+ std::static_pointer_cast<arrow::UnionType>(array.type()).get();
538
+ const auto tag = "[raw-records][union-dense-array]";
539
+ const auto child_index = compute_child_index(array, type, tag);
540
+ const auto child_field = type->child(child_index).get();
541
+ const auto& field_name = child_field->name();
542
+ const auto field_name_keep = field_name_;
543
+ field_name_ = rb_utf8_str_new(field_name.data(), field_name.length());
544
+ const auto child_array = array.child(child_index);
545
+ const auto index_keep = index_;
546
+ index_ = array.value_offset(index_);
547
+ check_status(child_array->Accept(this), tag);
548
+ index_ = index_keep;
549
+ field_name_ = field_name_keep;
550
+ }
551
+
552
+ ArrayValueConverter* array_value_converter_;
553
+ int64_t index_;
554
+ VALUE field_name_;
555
+ VALUE result_;
556
+ };
557
+
558
+ class DictionaryArrayValueConverter : public arrow::ArrayVisitor {
559
+ public:
560
+ explicit DictionaryArrayValueConverter(ArrayValueConverter* converter)
561
+ : array_value_converter_(converter),
562
+ index_(0),
563
+ result_(Qnil) {
564
+ }
565
+
566
+ VALUE convert(const arrow::DictionaryArray& array,
567
+ const int64_t index) {
568
+ index_ = index;
569
+ auto indices = array.indices().get();
570
+ check_status(indices->Accept(this),
571
+ "[raw-records][dictionary-array]");
572
+ return result_;
573
+ }
574
+
575
+ // TODO: Convert to real value.
576
+ #define VISIT(TYPE) \
577
+ Status Visit(const arrow::TYPE ## Array& array) override { \
578
+ result_ = convert_value(array, index_); \
579
+ return Status::OK(); \
580
+ }
581
+
582
+ VISIT(Int8)
583
+ VISIT(Int16)
584
+ VISIT(Int32)
585
+ VISIT(Int64)
586
+
587
+ #undef VISIT
588
+
589
+ private:
590
+ template <typename ArrayType>
591
+ inline VALUE convert_value(const ArrayType& array,
592
+ const int64_t i) {
593
+ return array_value_converter_->convert(array, i);
594
+ }
595
+
596
+ ArrayValueConverter* array_value_converter_;
597
+ int64_t index_;
598
+ VALUE result_;
599
+ };
600
+
601
+ VALUE ArrayValueConverter::convert(const arrow::ListArray& array,
602
+ const int64_t i) {
603
+ return list_array_value_converter_->convert(array, i);
604
+ }
605
+
606
+ VALUE ArrayValueConverter::convert(const arrow::StructArray& array,
607
+ const int64_t i) {
608
+ return struct_array_value_converter_->convert(array, i);
609
+ }
610
+
611
+ VALUE ArrayValueConverter::convert(const arrow::UnionArray& array,
612
+ const int64_t i) {
613
+ return union_array_value_converter_->convert(array, i);
614
+ }
615
+
616
+ VALUE ArrayValueConverter::convert(const arrow::DictionaryArray& array,
617
+ const int64_t i) {
618
+ return dictionary_array_value_converter_->convert(array, i);
619
+ }
620
+
621
+ class RawRecordsBuilder : public arrow::ArrayVisitor {
622
+ public:
623
+ explicit RawRecordsBuilder(VALUE records, int n_columns)
624
+ : array_value_converter_(),
625
+ list_array_value_converter_(&array_value_converter_),
626
+ struct_array_value_converter_(&array_value_converter_),
627
+ union_array_value_converter_(&array_value_converter_),
628
+ dictionary_array_value_converter_(&array_value_converter_),
629
+ records_(records),
630
+ n_columns_(n_columns) {
631
+ array_value_converter_.
632
+ set_sub_value_converters(&list_array_value_converter_,
633
+ &struct_array_value_converter_,
634
+ &union_array_value_converter_,
635
+ &dictionary_array_value_converter_);
636
+ }
637
+
638
+ void build(const arrow::RecordBatch& record_batch) {
639
+ rb::protect([&] {
640
+ const auto n_rows = record_batch.num_rows();
641
+ for (int64_t i = 0; i < n_rows; ++i) {
642
+ auto record = rb_ary_new_capa(n_columns_);
643
+ rb_ary_push(records_, record);
644
+ }
645
+ for (int i = 0; i < n_columns_; ++i) {
646
+ const auto array = record_batch.column(i).get();
647
+ column_index_ = i;
648
+ check_status(array->Accept(this),
649
+ "[raw-records]");
650
+ }
651
+ return Qnil;
652
+ });
653
+ }
654
+
655
+ #define VISIT(TYPE) \
656
+ Status Visit(const arrow::TYPE ## Array& array) override { \
657
+ convert(array); \
658
+ return Status::OK(); \
659
+ }
660
+
661
+ VISIT(Null)
662
+ VISIT(Boolean)
663
+ VISIT(Int8)
664
+ VISIT(Int16)
665
+ VISIT(Int32)
666
+ VISIT(Int64)
667
+ VISIT(UInt8)
668
+ VISIT(UInt16)
669
+ VISIT(UInt32)
670
+ VISIT(UInt64)
671
+ // TODO
672
+ // VISIT(HalfFloat)
673
+ VISIT(Float)
674
+ VISIT(Double)
675
+ VISIT(Binary)
676
+ VISIT(String)
677
+ VISIT(FixedSizeBinary)
678
+ VISIT(Date32)
679
+ VISIT(Date64)
680
+ VISIT(Time32)
681
+ VISIT(Time64)
682
+ VISIT(Timestamp)
683
+ // TODO
684
+ // VISIT(Interval)
685
+ VISIT(List)
686
+ VISIT(Struct)
687
+ VISIT(Union)
688
+ VISIT(Dictionary)
689
+ VISIT(Decimal128)
690
+ // TODO
691
+ // VISIT(Extension)
692
+
693
+ #undef VISIT
694
+
695
+ private:
696
+ template <typename ArrayType>
697
+ inline VALUE convert_value(const ArrayType& array,
698
+ const int64_t i) {
699
+ return array_value_converter_.convert(array, i);
700
+ }
701
+
702
+ template <typename ArrayType>
703
+ void convert(const ArrayType& array) {
704
+ const auto n = array.length();
705
+ if (array.null_count() > 0) {
706
+ for (int64_t i = 0; i < n; ++i) {
707
+ auto value = Qnil;
708
+ if (!array.IsNull(i)) {
709
+ value = convert_value(array, i);
710
+ }
711
+ auto record = rb_ary_entry(records_, i);
712
+ rb_ary_store(record, column_index_, value);
713
+ }
714
+ } else {
715
+ for (int64_t i = 0; i < n; ++i) {
716
+ auto record = rb_ary_entry(records_, i);
717
+ rb_ary_store(record, column_index_, convert_value(array, i));
718
+ }
719
+ }
720
+ }
721
+
722
+ ArrayValueConverter array_value_converter_;
723
+ ListArrayValueConverter list_array_value_converter_;
724
+ StructArrayValueConverter struct_array_value_converter_;
725
+ UnionArrayValueConverter union_array_value_converter_;
726
+ DictionaryArrayValueConverter dictionary_array_value_converter_;
727
+
728
+ // Destination for converted records.
729
+ VALUE records_;
730
+
731
+ // The current column index.
732
+ int column_index_;
733
+
734
+ // The number of columns.
735
+ const int n_columns_;
736
+ };
737
+ }
738
+
739
+ VALUE
740
+ record_batch_raw_records(VALUE rb_record_batch) {
741
+ auto garrow_record_batch = GARROW_RECORD_BATCH(RVAL2GOBJ(rb_record_batch));
742
+ auto record_batch = garrow_record_batch_get_raw(garrow_record_batch).get();
743
+ const auto n_rows = record_batch->num_rows();
744
+ const auto n_columns = record_batch->num_columns();
745
+ auto records = rb_ary_new_capa(n_rows);
746
+
747
+ try {
748
+ RawRecordsBuilder builder(records, n_columns);
749
+ builder.build(*record_batch);
750
+ } catch (rb::State& state) {
751
+ state.jump();
752
+ }
753
+
754
+ return records;
755
+ }
756
+ }