red-arrow 0.12.0 → 0.13.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of red-arrow may be problematic; see the registry's advisory page for details.

Files changed (50)
  1. checksums.yaml +4 -4
  2. data/Rakefile +49 -4
  3. data/ext/arrow/arrow.cpp +43 -0
  4. data/ext/arrow/extconf.rb +52 -0
  5. data/ext/arrow/record-batch.cpp +756 -0
  6. data/ext/arrow/red-arrow.hpp +60 -0
  7. data/lib/arrow.rb +2 -1
  8. data/lib/arrow/array-builder.rb +4 -0
  9. data/lib/arrow/array.rb +11 -1
  10. data/lib/arrow/bigdecimal-extension.rb +24 -0
  11. data/lib/arrow/binary-array-builder.rb +36 -0
  12. data/lib/arrow/block-closable.rb +5 -1
  13. data/lib/arrow/csv-loader.rb +28 -6
  14. data/lib/arrow/data-type.rb +8 -4
  15. data/lib/arrow/decimal128-array-builder.rb +2 -2
  16. data/lib/arrow/decimal128.rb +42 -0
  17. data/lib/arrow/list-array-builder.rb +1 -1
  18. data/lib/arrow/loader.rb +8 -0
  19. data/lib/arrow/null-array-builder.rb +26 -0
  20. data/lib/arrow/record-batch-builder.rb +8 -9
  21. data/lib/arrow/struct-array-builder.rb +3 -3
  22. data/lib/arrow/struct-array.rb +15 -7
  23. data/lib/arrow/struct.rb +11 -0
  24. data/lib/arrow/table-loader.rb +14 -14
  25. data/lib/arrow/version.rb +1 -1
  26. data/red-arrow.gemspec +8 -4
  27. data/test/raw-records/record-batch/test-basic-arrays.rb +349 -0
  28. data/test/raw-records/record-batch/test-dense-union-array.rb +486 -0
  29. data/test/raw-records/record-batch/test-list-array.rb +498 -0
  30. data/test/raw-records/record-batch/test-multiple-columns.rb +49 -0
  31. data/test/raw-records/record-batch/test-sparse-union-array.rb +474 -0
  32. data/test/raw-records/record-batch/test-struct-array.rb +426 -0
  33. data/test/run-test.rb +25 -2
  34. data/test/test-array.rb +38 -9
  35. data/test/test-bigdecimal.rb +23 -0
  36. data/{dependency-check/Rakefile → test/test-buffer.rb} +15 -20
  37. data/test/test-chunked-array.rb +22 -0
  38. data/test/test-column.rb +24 -0
  39. data/test/test-csv-loader.rb +30 -0
  40. data/test/test-data-type.rb +25 -0
  41. data/test/test-decimal128.rb +64 -0
  42. data/test/test-field.rb +20 -0
  43. data/test/test-group.rb +2 -2
  44. data/test/test-record-batch-builder.rb +9 -0
  45. data/test/test-record-batch.rb +14 -0
  46. data/test/test-schema.rb +14 -0
  47. data/test/test-struct-array.rb +16 -3
  48. data/test/test-table.rb +14 -0
  49. data/test/test-tensor.rb +56 -0
  50. metadata +117 -47
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 847a4994cc15fb50df335c7231c942d49392f1ec528b85647b2bbe6fb4e82f7b
4
- data.tar.gz: 4337680e47dea67107a1fef863d66e936b9c5d20b2cb6698e879b774eada3c88
3
+ metadata.gz: a362c35e8eb6b5b93cb69e1ee077acf572af51d89b54e0e9a41282cbdb546cbb
4
+ data.tar.gz: db53bb7327021bbb8e45e1b8dbf741becf4ad98d85366fe6c3a475da3650ef4e
5
5
  SHA512:
6
- metadata.gz: c72b26d9b4f488c4d00184ea3243056c69a601fe59fba80a7423075b19034b0db7ad12dbbf5e210ea4b66607acd2ad7123a85057a27266b21665e143961bccb7
7
- data.tar.gz: 1fb8c82007a25cac5e99d41f3bb879ed86ad1f9deab320f45470252513fada6c6f972f7ddb0ed3f38ca1dcdf7de3f5662a3fa28536a4981b9e5993dd5bc666a8
6
+ metadata.gz: 153a2ab1e7ccefe6fbe3d3481d8d86cf90f6d583825543b4ae541d3e3704b1aab4debcc2d5aa7148ded9eefbed828354697514e22664ba63aa83e3ad74e41180
7
+ data.tar.gz: a8ab623a40a47a073865a13c0eaa1e7c380574d57307068cb5258884cbf04113030b137b00a83347c0c5a3ccc6ca597db25c89297ccf38b20e03f71e9345b1aa
data/Rakefile CHANGED
@@ -17,27 +17,72 @@
17
17
  # specific language governing permissions and limitations
18
18
  # under the License.
19
19
 
20
- require "rubygems"
21
20
  require "bundler/gem_helper"
21
+ require "rake/clean"
22
22
  require "yard"
23
23
 
24
24
  base_dir = File.join(__dir__)
25
25
 
26
26
  helper = Bundler::GemHelper.new(base_dir)
27
27
  helper.install
28
+ spec = helper.gemspec
28
29
 
29
30
  release_task = Rake::Task["release"]
30
31
  release_task.prerequisites.replace(["build", "release:rubygem_push"])
31
32
 
33
+ def run_extconf(extension_dir, *arguments)
34
+ cd(extension_dir) do
35
+ ruby("extconf.rb", *arguments)
36
+ end
37
+ end
38
+
39
+ spec.extensions.each do |extension|
40
+ extension_dir = File.dirname(extension)
41
+ CLOBBER << File.join(extension_dir, "Makefile")
42
+ CLOBBER << File.join(extension_dir, "mkmf.log")
43
+
44
+ makefile = File.join(extension_dir, "Makefile")
45
+ file makefile do
46
+ run_extconf(extension_dir)
47
+ end
48
+
49
+ desc "Configure"
50
+ task :configure do
51
+ run_extconf(extension_dir)
52
+ end
53
+
54
+ desc "Compile"
55
+ task :compile => makefile do
56
+ cd(extension_dir) do
57
+ sh("make")
58
+ end
59
+ end
60
+
61
+ task :clean do
62
+ cd(extension_dir) do
63
+ sh("make", "clean") if File.exist?("Makefile")
64
+ end
65
+ end
66
+ end
67
+
32
68
  desc "Run tests"
33
69
  task :test do
34
- cd("dependency-check") do
35
- ruby("-S", "rake")
36
- end
37
70
  ruby("test/run-test.rb")
38
71
  end
39
72
 
40
73
  task default: :test
41
74
 
75
+ desc "Run benchmarks"
76
+ task :benchmark do
77
+ benchmarks = if ENV["BENCHMARKS"]
78
+ ENV["BENCHMARKS"].split
79
+ else
80
+ FileList["benchmark/{,*/**/}*.yml"]
81
+ end
82
+ benchmarks.each do |benchmark|
83
+ sh("benchmark-driver", benchmark)
84
+ end
85
+ end
86
+
42
87
  YARD::Rake::YardocTask.new do |task|
43
88
  end
@@ -0,0 +1,43 @@
1
+ /*
2
+ * Licensed to the Apache Software Foundation (ASF) under one
3
+ * or more contributor license agreements. See the NOTICE file
4
+ * distributed with this work for additional information
5
+ * regarding copyright ownership. The ASF licenses this file
6
+ * to you under the Apache License, Version 2.0 (the
7
+ * "License"); you may not use this file except in compliance
8
+ * with the License. You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing,
13
+ * software distributed under the License is distributed on an
14
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ * KIND, either express or implied. See the License for the
16
+ * specific language governing permissions and limitations
17
+ * under the License.
18
+ */
19
+
20
+ #include "red-arrow.hpp"
21
+
22
+ #include <ruby.hpp>
23
+
24
+ namespace red_arrow {
25
+ VALUE cDate;
26
+ ID id_BigDecimal;
27
+ ID id_jd;
28
+ ID id_to_datetime;
29
+ }
30
+
31
+ extern "C" void Init_arrow() {
32
+ auto mArrow = rb_const_get_at(rb_cObject, rb_intern("Arrow"));
33
+ auto cArrowRecordBatch = rb_const_get_at(mArrow, rb_intern("RecordBatch"));
34
+ rb_define_method(cArrowRecordBatch, "raw_records",
35
+ reinterpret_cast<rb::RawMethod>(red_arrow::record_batch_raw_records),
36
+ 0);
37
+
38
+ red_arrow::cDate = rb_const_get(rb_cObject, rb_intern("Date"));
39
+
40
+ red_arrow::id_BigDecimal = rb_intern("BigDecimal");
41
+ red_arrow::id_jd = rb_intern("jd");
42
+ red_arrow::id_to_datetime = rb_intern("to_datetime");
43
+ }
@@ -0,0 +1,52 @@
1
+ # Licensed to the Apache Software Foundation (ASF) under one
2
+ # or more contributor license agreements. See the NOTICE file
3
+ # distributed with this work for additional information
4
+ # regarding copyright ownership. The ASF licenses this file
5
+ # to you under the Apache License, Version 2.0 (the
6
+ # "License"); you may not use this file except in compliance
7
+ # with the License. You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
18
+ require "extpp"
19
+ require "mkmf-gnome2"
20
+
21
+ arrow_pkg_config_path = ENV["ARROW_PKG_CONFIG_PATH"]
22
+ if arrow_pkg_config_path
23
+ pkg_config_paths = [arrow_pkg_config_path, ENV["PKG_CONFIG_PATH"]].compact
24
+ ENV["PKG_CONFIG_PATH"] = pkg_config_paths.join(File::PATH_SEPARATOR)
25
+ end
26
+
27
+ unless required_pkg_config_package("arrow",
28
+ debian: "libarrow-dev",
29
+ redhat: "arrow-devel",
30
+ homebrew: "apache-arrow",
31
+ msys2: "arrow")
32
+ exit(false)
33
+ end
34
+
35
+ unless required_pkg_config_package("arrow-glib",
36
+ debian: "libarrow-glib-dev",
37
+ redhat: "arrow-glib-devel",
38
+ homebrew: "apache-arrow-glib",
39
+ msys2: "arrow")
40
+ exit(false)
41
+ end
42
+
43
+ [
44
+ ["glib2", "ext/glib2"],
45
+ ].each do |name, relative_source_dir|
46
+ spec = find_gem_spec(name)
47
+ source_dir = File.join(spec.full_gem_path, relative_source_dir)
48
+ build_dir = source_dir
49
+ add_depend_package_path(name, source_dir, build_dir)
50
+ end
51
+
52
+ create_makefile("arrow")
@@ -0,0 +1,756 @@
1
+ /*
2
+ * Licensed to the Apache Software Foundation (ASF) under one
3
+ * or more contributor license agreements. See the NOTICE file
4
+ * distributed with this work for additional information
5
+ * regarding copyright ownership. The ASF licenses this file
6
+ * to you under the Apache License, Version 2.0 (the
7
+ * "License"); you may not use this file except in compliance
8
+ * with the License. You may obtain a copy of the License at
9
+ *
10
+ * http://www.apache.org/licenses/LICENSE-2.0
11
+ *
12
+ * Unless required by applicable law or agreed to in writing,
13
+ * software distributed under the License is distributed on an
14
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15
+ * KIND, either express or implied. See the License for the
16
+ * specific language governing permissions and limitations
17
+ * under the License.
18
+ */
19
+
20
+ #include "red-arrow.hpp"
21
+
22
+ #include <ruby.hpp>
23
+ #include <ruby/encoding.h>
24
+
25
+ #include <arrow-glib/error.hpp>
26
+
27
+ #include <arrow/util/logging.h>
28
+
29
+ namespace red_arrow {
30
+ namespace {
31
+ using Status = arrow::Status;
32
+
33
+ void check_status(const Status&& status, const char* context) {
34
+ GError* error = nullptr;
35
+ if (!garrow_error_check(&error, status, context)) {
36
+ RG_RAISE_ERROR(error);
37
+ }
38
+ }
39
+
40
+ class ListArrayValueConverter;
41
+ class StructArrayValueConverter;
42
+ class UnionArrayValueConverter;
43
+ class DictionaryArrayValueConverter;
44
+
45
+ class ArrayValueConverter {
46
+ public:
47
+ ArrayValueConverter()
48
+ : decimal_buffer_(),
49
+ list_array_value_converter_(nullptr),
50
+ struct_array_value_converter_(nullptr),
51
+ union_array_value_converter_(nullptr),
52
+ dictionary_array_value_converter_(nullptr) {
53
+ }
54
+
55
+ void set_sub_value_converters(ListArrayValueConverter* list_array_value_converter,
56
+ StructArrayValueConverter* struct_array_value_converter,
57
+ UnionArrayValueConverter* union_array_value_converter,
58
+ DictionaryArrayValueConverter* dictionary_array_value_converter) {
59
+ list_array_value_converter_ = list_array_value_converter;
60
+ struct_array_value_converter_ = struct_array_value_converter;
61
+ union_array_value_converter_ = union_array_value_converter;
62
+ dictionary_array_value_converter_ = dictionary_array_value_converter;
63
+ }
64
+
65
+ inline VALUE convert(const arrow::NullArray& array,
66
+ const int64_t i) {
67
+ return Qnil;
68
+ }
69
+
70
+ inline VALUE convert(const arrow::BooleanArray& array,
71
+ const int64_t i) {
72
+ return array.Value(i) ? Qtrue : Qfalse;
73
+ }
74
+
75
+ inline VALUE convert(const arrow::Int8Array& array,
76
+ const int64_t i) {
77
+ return INT2NUM(array.Value(i));
78
+ }
79
+
80
+ inline VALUE convert(const arrow::Int16Array& array,
81
+ const int64_t i) {
82
+ return INT2NUM(array.Value(i));
83
+ }
84
+
85
+ inline VALUE convert(const arrow::Int32Array& array,
86
+ const int64_t i) {
87
+ return INT2NUM(array.Value(i));
88
+ }
89
+
90
+ inline VALUE convert(const arrow::Int64Array& array,
91
+ const int64_t i) {
92
+ return LL2NUM(array.Value(i));
93
+ }
94
+
95
+ inline VALUE convert(const arrow::UInt8Array& array,
96
+ const int64_t i) {
97
+ return UINT2NUM(array.Value(i));
98
+ }
99
+
100
+ inline VALUE convert(const arrow::UInt16Array& array,
101
+ const int64_t i) {
102
+ return UINT2NUM(array.Value(i));
103
+ }
104
+
105
+ inline VALUE convert(const arrow::UInt32Array& array,
106
+ const int64_t i) {
107
+ return UINT2NUM(array.Value(i));
108
+ }
109
+
110
+ inline VALUE convert(const arrow::UInt64Array& array,
111
+ const int64_t i) {
112
+ return ULL2NUM(array.Value(i));
113
+ }
114
+
115
+ // TODO
116
+ // inline VALUE convert(const arrow::HalfFloatArray& array,
117
+ // const int64_t i) {
118
+ // }
119
+
120
+ inline VALUE convert(const arrow::FloatArray& array,
121
+ const int64_t i) {
122
+ return DBL2NUM(array.Value(i));
123
+ }
124
+
125
+ inline VALUE convert(const arrow::DoubleArray& array,
126
+ const int64_t i) {
127
+ return DBL2NUM(array.Value(i));
128
+ }
129
+
130
+ inline VALUE convert(const arrow::BinaryArray& array,
131
+ const int64_t i) {
132
+ int32_t length;
133
+ const auto value = array.GetValue(i, &length);
134
+ // TODO: encoding support
135
+ return rb_enc_str_new(reinterpret_cast<const char*>(value),
136
+ length,
137
+ rb_ascii8bit_encoding());
138
+ }
139
+
140
+ inline VALUE convert(const arrow::StringArray& array,
141
+ const int64_t i) {
142
+ int32_t length;
143
+ const auto value = array.GetValue(i, &length);
144
+ return rb_utf8_str_new(reinterpret_cast<const char*>(value),
145
+ length);
146
+ }
147
+
148
+ inline VALUE convert(const arrow::FixedSizeBinaryArray& array,
149
+ const int64_t i) {
150
+ return rb_enc_str_new(reinterpret_cast<const char*>(array.Value(i)),
151
+ array.byte_width(),
152
+ rb_ascii8bit_encoding());
153
+ }
154
+
155
+ constexpr static int32_t JULIAN_DATE_UNIX_EPOCH = 2440588;
156
+ inline VALUE convert(const arrow::Date32Array& array,
157
+ const int64_t i) {
158
+ const auto value = array.Value(i);
159
+ const auto days_in_julian = value + JULIAN_DATE_UNIX_EPOCH;
160
+ return rb_funcall(cDate, id_jd, 1, LONG2NUM(days_in_julian));
161
+ }
162
+
163
+ inline VALUE convert(const arrow::Date64Array& array,
164
+ const int64_t i) {
165
+ const auto value = array.Value(i);
166
+ auto msec = LL2NUM(value);
167
+ auto sec = rb_rational_new(msec, INT2NUM(1000));
168
+ auto time_value = rb_time_num_new(sec, Qnil);
169
+ return rb_funcall(time_value, id_to_datetime, 0, 0);
170
+ }
171
+
172
+ inline VALUE convert(const arrow::Time32Array& array,
173
+ const int64_t i) {
174
+ // TODO: unit treatment
175
+ const auto value = array.Value(i);
176
+ return INT2NUM(value);
177
+ }
178
+
179
+ inline VALUE convert(const arrow::Time64Array& array,
180
+ const int64_t i) {
181
+ // TODO: unit treatment
182
+ const auto value = array.Value(i);
183
+ return LL2NUM(value);
184
+ }
185
+
186
+ inline VALUE convert(const arrow::TimestampArray& array,
187
+ const int64_t i) {
188
+ const auto type =
189
+ arrow::internal::checked_cast<const arrow::TimestampType*>(array.type().get());
190
+ auto scale = time_unit_to_scale(type->unit());
191
+ if (NIL_P(scale)) {
192
+ rb_raise(rb_eArgError, "Invalid TimeUnit");
193
+ }
194
+ auto value = array.Value(i);
195
+ auto sec = rb_rational_new(LL2NUM(value), scale);
196
+ return rb_time_num_new(sec, Qnil);
197
+ }
198
+
199
+ // TODO
200
+ // inline VALUE convert(const arrow::IntervalArray& array,
201
+ // const int64_t i) {
202
+ // };
203
+
204
+ VALUE convert(const arrow::ListArray& array,
205
+ const int64_t i);
206
+
207
+ VALUE convert(const arrow::StructArray& array,
208
+ const int64_t i);
209
+
210
+ VALUE convert(const arrow::UnionArray& array,
211
+ const int64_t i);
212
+
213
+ VALUE convert(const arrow::DictionaryArray& array,
214
+ const int64_t i);
215
+
216
+ inline VALUE convert(const arrow::Decimal128Array& array,
217
+ const int64_t i) {
218
+ decimal_buffer_ = array.FormatValue(i);
219
+ return rb_funcall(rb_cObject,
220
+ id_BigDecimal,
221
+ 1,
222
+ rb_enc_str_new(decimal_buffer_.data(),
223
+ decimal_buffer_.length(),
224
+ rb_ascii8bit_encoding()));
225
+ }
226
+
227
+ private:
228
+ std::string decimal_buffer_;
229
+ ListArrayValueConverter* list_array_value_converter_;
230
+ StructArrayValueConverter* struct_array_value_converter_;
231
+ UnionArrayValueConverter* union_array_value_converter_;
232
+ DictionaryArrayValueConverter* dictionary_array_value_converter_;
233
+ };
234
+
235
+ class ListArrayValueConverter : public arrow::ArrayVisitor {
236
+ public:
237
+ explicit ListArrayValueConverter(ArrayValueConverter* converter)
238
+ : array_value_converter_(converter),
239
+ offset_(0),
240
+ length_(0),
241
+ result_(Qnil) {}
242
+
243
+ VALUE convert(const arrow::ListArray& array, const int64_t index) {
244
+ auto values = array.values().get();
245
+ auto offset_keep = offset_;
246
+ auto length_keep = length_;
247
+ offset_ = array.value_offset(index);
248
+ length_ = array.value_length(index);
249
+ auto result_keep = result_;
250
+ result_ = rb_ary_new_capa(length_);
251
+ check_status(values->Accept(this),
252
+ "[raw-records][list-array]");
253
+ offset_ = offset_keep;
254
+ length_ = length_keep;
255
+ auto result_return = result_;
256
+ result_ = result_keep;
257
+ return result_return;
258
+ }
259
+
260
+ #define VISIT(TYPE) \
261
+ Status Visit(const arrow::TYPE ## Array& array) override { \
262
+ return visit_value(array); \
263
+ }
264
+
265
+ VISIT(Null)
266
+ VISIT(Boolean)
267
+ VISIT(Int8)
268
+ VISIT(Int16)
269
+ VISIT(Int32)
270
+ VISIT(Int64)
271
+ VISIT(UInt8)
272
+ VISIT(UInt16)
273
+ VISIT(UInt32)
274
+ VISIT(UInt64)
275
+ // TODO
276
+ // VISIT(HalfFloat)
277
+ VISIT(Float)
278
+ VISIT(Double)
279
+ VISIT(Binary)
280
+ VISIT(String)
281
+ VISIT(FixedSizeBinary)
282
+ VISIT(Date32)
283
+ VISIT(Date64)
284
+ VISIT(Time32)
285
+ VISIT(Time64)
286
+ VISIT(Timestamp)
287
+ // TODO
288
+ // VISIT(Interval)
289
+ VISIT(List)
290
+ VISIT(Struct)
291
+ VISIT(Union)
292
+ VISIT(Dictionary)
293
+ VISIT(Decimal128)
294
+ // TODO
295
+ // VISIT(Extension)
296
+
297
+ #undef VISIT
298
+
299
+ private:
300
+ template <typename ArrayType>
301
+ inline VALUE convert_value(const ArrayType& array,
302
+ const int64_t i) {
303
+ return array_value_converter_->convert(array, i);
304
+ }
305
+
306
+ template <typename ArrayType>
307
+ Status visit_value(const ArrayType& array) {
308
+ if (array.null_count() > 0) {
309
+ for (int64_t i = 0; i < length_; ++i) {
310
+ auto value = Qnil;
311
+ if (!array.IsNull(i + offset_)) {
312
+ value = convert_value(array, i + offset_);
313
+ }
314
+ rb_ary_push(result_, value);
315
+ }
316
+ } else {
317
+ for (int64_t i = 0; i < length_; ++i) {
318
+ rb_ary_push(result_, convert_value(array, i + offset_));
319
+ }
320
+ }
321
+ return Status::OK();
322
+ }
323
+
324
+ ArrayValueConverter* array_value_converter_;
325
+ int32_t offset_;
326
+ int32_t length_;
327
+ VALUE result_;
328
+ };
329
+
330
+ class StructArrayValueConverter : public arrow::ArrayVisitor {
331
+ public:
332
+ explicit StructArrayValueConverter(ArrayValueConverter* converter)
333
+ : array_value_converter_(converter),
334
+ key_(Qnil),
335
+ index_(0),
336
+ result_(Qnil) {}
337
+
338
+ VALUE convert(const arrow::StructArray& array,
339
+ const int64_t index) {
340
+ auto index_keep = index_;
341
+ auto result_keep = result_;
342
+ index_ = index;
343
+ result_ = rb_hash_new();
344
+ const auto struct_type = array.struct_type();
345
+ const auto n = struct_type->num_children();
346
+ for (int i = 0; i < n; ++i) {
347
+ const auto field_type = struct_type->child(i).get();
348
+ const auto& field_name = field_type->name();
349
+ auto key_keep = key_;
350
+ key_ = rb_utf8_str_new(field_name.data(), field_name.length());
351
+ const auto field_array = array.field(i).get();
352
+ check_status(field_array->Accept(this),
353
+ "[raw-records][struct-array]");
354
+ key_ = key_keep;
355
+ }
356
+ auto result_return = result_;
357
+ result_ = result_keep;
358
+ index_ = index_keep;
359
+ return result_return;
360
+ }
361
+
362
+ #define VISIT(TYPE) \
363
+ Status Visit(const arrow::TYPE ## Array& array) override { \
364
+ fill_field(array); \
365
+ return Status::OK(); \
366
+ }
367
+
368
+ VISIT(Null)
369
+ VISIT(Boolean)
370
+ VISIT(Int8)
371
+ VISIT(Int16)
372
+ VISIT(Int32)
373
+ VISIT(Int64)
374
+ VISIT(UInt8)
375
+ VISIT(UInt16)
376
+ VISIT(UInt32)
377
+ VISIT(UInt64)
378
+ // TODO
379
+ // VISIT(HalfFloat)
380
+ VISIT(Float)
381
+ VISIT(Double)
382
+ VISIT(Binary)
383
+ VISIT(String)
384
+ VISIT(FixedSizeBinary)
385
+ VISIT(Date32)
386
+ VISIT(Date64)
387
+ VISIT(Time32)
388
+ VISIT(Time64)
389
+ VISIT(Timestamp)
390
+ // TODO
391
+ // VISIT(Interval)
392
+ VISIT(List)
393
+ VISIT(Struct)
394
+ VISIT(Union)
395
+ VISIT(Dictionary)
396
+ VISIT(Decimal128)
397
+ // TODO
398
+ // VISIT(Extension)
399
+
400
+ #undef VISIT
401
+
402
+ private:
403
+ template <typename ArrayType>
404
+ inline VALUE convert_value(const ArrayType& array,
405
+ const int64_t i) {
406
+ return array_value_converter_->convert(array, i);
407
+ }
408
+
409
+ template <typename ArrayType>
410
+ void fill_field(const ArrayType& array) {
411
+ if (array.IsNull(index_)) {
412
+ rb_hash_aset(result_, key_, Qnil);
413
+ } else {
414
+ rb_hash_aset(result_, key_, convert_value(array, index_));
415
+ }
416
+ }
417
+
418
+ ArrayValueConverter* array_value_converter_;
419
+ VALUE key_;
420
+ int64_t index_;
421
+ VALUE result_;
422
+ };
423
+
424
+ class UnionArrayValueConverter : public arrow::ArrayVisitor {
425
+ public:
426
+ explicit UnionArrayValueConverter(ArrayValueConverter* converter)
427
+ : array_value_converter_(converter),
428
+ index_(0),
429
+ result_(Qnil) {}
430
+
431
+ VALUE convert(const arrow::UnionArray& array,
432
+ const int64_t index) {
433
+ const auto index_keep = index_;
434
+ const auto result_keep = result_;
435
+ index_ = index;
436
+ switch (array.mode()) {
437
+ case arrow::UnionMode::SPARSE:
438
+ convert_sparse(array);
439
+ break;
440
+ case arrow::UnionMode::DENSE:
441
+ convert_dense(array);
442
+ break;
443
+ default:
444
+ rb_raise(rb_eArgError, "Invalid union mode");
445
+ break;
446
+ }
447
+ auto result_return = result_;
448
+ index_ = index_keep;
449
+ result_ = result_keep;
450
+ return result_return;
451
+ }
452
+
453
+ #define VISIT(TYPE) \
454
+ Status Visit(const arrow::TYPE ## Array& array) override { \
455
+ convert_value(array); \
456
+ return Status::OK(); \
457
+ }
458
+
459
+ VISIT(Null)
460
+ VISIT(Boolean)
461
+ VISIT(Int8)
462
+ VISIT(Int16)
463
+ VISIT(Int32)
464
+ VISIT(Int64)
465
+ VISIT(UInt8)
466
+ VISIT(UInt16)
467
+ VISIT(UInt32)
468
+ VISIT(UInt64)
469
+ // TODO
470
+ // VISIT(HalfFloat)
471
+ VISIT(Float)
472
+ VISIT(Double)
473
+ VISIT(Binary)
474
+ VISIT(String)
475
+ VISIT(FixedSizeBinary)
476
+ VISIT(Date32)
477
+ VISIT(Date64)
478
+ VISIT(Time32)
479
+ VISIT(Time64)
480
+ VISIT(Timestamp)
481
+ // TODO
482
+ // VISIT(Interval)
483
+ VISIT(List)
484
+ VISIT(Struct)
485
+ VISIT(Union)
486
+ VISIT(Dictionary)
487
+ VISIT(Decimal128)
488
+ // TODO
489
+ // VISIT(Extension)
490
+
491
+ #undef VISIT
492
+ private:
493
+ template <typename ArrayType>
494
+ inline void convert_value(const ArrayType& array) {
495
+ auto result = rb_hash_new();
496
+ if (array.IsNull(index_)) {
497
+ rb_hash_aset(result, field_name_, Qnil);
498
+ } else {
499
+ rb_hash_aset(result,
500
+ field_name_,
501
+ array_value_converter_->convert(array, index_));
502
+ }
503
+ result_ = result;
504
+ }
505
+
506
+ uint8_t compute_child_index(const arrow::UnionArray& array,
507
+ arrow::UnionType* type,
508
+ const char* tag) {
509
+ const auto type_id = array.raw_type_ids()[index_];
510
+ const auto& type_codes = type->type_codes();
511
+ for (uint8_t i = 0; i < type_codes.size(); ++i) {
512
+ if (type_codes[i] == type_id) {
513
+ return i;
514
+ }
515
+ }
516
+ check_status(Status::Invalid("Unknown type ID: ", type_id),
517
+ tag);
518
+ return 0;
519
+ }
520
+
521
+ void convert_sparse(const arrow::UnionArray& array) {
522
+ const auto type =
523
+ std::static_pointer_cast<arrow::UnionType>(array.type()).get();
524
+ const auto tag = "[raw-records][union-sparse-array]";
525
+ const auto child_index = compute_child_index(array, type, tag);
526
+ const auto child_field = type->child(child_index).get();
527
+ const auto& field_name = child_field->name();
528
+ const auto field_name_keep = field_name_;
529
+ field_name_ = rb_utf8_str_new(field_name.data(), field_name.length());
530
+ const auto child_array = array.child(child_index).get();
531
+ check_status(child_array->Accept(this), tag);
532
+ field_name_ = field_name_keep;
533
+ }
534
+
535
+ void convert_dense(const arrow::UnionArray& array) {
536
+ const auto type =
537
+ std::static_pointer_cast<arrow::UnionType>(array.type()).get();
538
+ const auto tag = "[raw-records][union-dense-array]";
539
+ const auto child_index = compute_child_index(array, type, tag);
540
+ const auto child_field = type->child(child_index).get();
541
+ const auto& field_name = child_field->name();
542
+ const auto field_name_keep = field_name_;
543
+ field_name_ = rb_utf8_str_new(field_name.data(), field_name.length());
544
+ const auto child_array = array.child(child_index);
545
+ const auto index_keep = index_;
546
+ index_ = array.value_offset(index_);
547
+ check_status(child_array->Accept(this), tag);
548
+ index_ = index_keep;
549
+ field_name_ = field_name_keep;
550
+ }
551
+
552
+ ArrayValueConverter* array_value_converter_;
553
+ int64_t index_;
554
+ VALUE field_name_;
555
+ VALUE result_;
556
+ };
557
+
558
+ class DictionaryArrayValueConverter : public arrow::ArrayVisitor {
559
+ public:
560
+ explicit DictionaryArrayValueConverter(ArrayValueConverter* converter)
561
+ : array_value_converter_(converter),
562
+ index_(0),
563
+ result_(Qnil) {
564
+ }
565
+
566
+ VALUE convert(const arrow::DictionaryArray& array,
567
+ const int64_t index) {
568
+ index_ = index;
569
+ auto indices = array.indices().get();
570
+ check_status(indices->Accept(this),
571
+ "[raw-records][dictionary-array]");
572
+ return result_;
573
+ }
574
+
575
+ // TODO: Convert to real value.
576
+ #define VISIT(TYPE) \
577
+ Status Visit(const arrow::TYPE ## Array& array) override { \
578
+ result_ = convert_value(array, index_); \
579
+ return Status::OK(); \
580
+ }
581
+
582
+ VISIT(Int8)
583
+ VISIT(Int16)
584
+ VISIT(Int32)
585
+ VISIT(Int64)
586
+
587
+ #undef VISIT
588
+
589
+ private:
590
+ template <typename ArrayType>
591
+ inline VALUE convert_value(const ArrayType& array,
592
+ const int64_t i) {
593
+ return array_value_converter_->convert(array, i);
594
+ }
595
+
596
+ ArrayValueConverter* array_value_converter_;
597
+ int64_t index_;
598
+ VALUE result_;
599
+ };
600
+
601
+ VALUE ArrayValueConverter::convert(const arrow::ListArray& array,
602
+ const int64_t i) {
603
+ return list_array_value_converter_->convert(array, i);
604
+ }
605
+
606
+ VALUE ArrayValueConverter::convert(const arrow::StructArray& array,
607
+ const int64_t i) {
608
+ return struct_array_value_converter_->convert(array, i);
609
+ }
610
+
611
+ VALUE ArrayValueConverter::convert(const arrow::UnionArray& array,
612
+ const int64_t i) {
613
+ return union_array_value_converter_->convert(array, i);
614
+ }
615
+
616
+ VALUE ArrayValueConverter::convert(const arrow::DictionaryArray& array,
617
+ const int64_t i) {
618
+ return dictionary_array_value_converter_->convert(array, i);
619
+ }
620
+
621
+ class RawRecordsBuilder : public arrow::ArrayVisitor {
622
+ public:
623
+ explicit RawRecordsBuilder(VALUE records, int n_columns)
624
+ : array_value_converter_(),
625
+ list_array_value_converter_(&array_value_converter_),
626
+ struct_array_value_converter_(&array_value_converter_),
627
+ union_array_value_converter_(&array_value_converter_),
628
+ dictionary_array_value_converter_(&array_value_converter_),
629
+ records_(records),
630
+ n_columns_(n_columns) {
631
+ array_value_converter_.
632
+ set_sub_value_converters(&list_array_value_converter_,
633
+ &struct_array_value_converter_,
634
+ &union_array_value_converter_,
635
+ &dictionary_array_value_converter_);
636
+ }
637
+
638
+ void build(const arrow::RecordBatch& record_batch) {
639
+ rb::protect([&] {
640
+ const auto n_rows = record_batch.num_rows();
641
+ for (int64_t i = 0; i < n_rows; ++i) {
642
+ auto record = rb_ary_new_capa(n_columns_);
643
+ rb_ary_push(records_, record);
644
+ }
645
+ for (int i = 0; i < n_columns_; ++i) {
646
+ const auto array = record_batch.column(i).get();
647
+ column_index_ = i;
648
+ check_status(array->Accept(this),
649
+ "[raw-records]");
650
+ }
651
+ return Qnil;
652
+ });
653
+ }
654
+
655
+ #define VISIT(TYPE) \
656
+ Status Visit(const arrow::TYPE ## Array& array) override { \
657
+ convert(array); \
658
+ return Status::OK(); \
659
+ }
660
+
661
+ VISIT(Null)
662
+ VISIT(Boolean)
663
+ VISIT(Int8)
664
+ VISIT(Int16)
665
+ VISIT(Int32)
666
+ VISIT(Int64)
667
+ VISIT(UInt8)
668
+ VISIT(UInt16)
669
+ VISIT(UInt32)
670
+ VISIT(UInt64)
671
+ // TODO
672
+ // VISIT(HalfFloat)
673
+ VISIT(Float)
674
+ VISIT(Double)
675
+ VISIT(Binary)
676
+ VISIT(String)
677
+ VISIT(FixedSizeBinary)
678
+ VISIT(Date32)
679
+ VISIT(Date64)
680
+ VISIT(Time32)
681
+ VISIT(Time64)
682
+ VISIT(Timestamp)
683
+ // TODO
684
+ // VISIT(Interval)
685
+ VISIT(List)
686
+ VISIT(Struct)
687
+ VISIT(Union)
688
+ VISIT(Dictionary)
689
+ VISIT(Decimal128)
690
+ // TODO
691
+ // VISIT(Extension)
692
+
693
+ #undef VISIT
694
+
695
+ private:
696
+ template <typename ArrayType>
697
+ inline VALUE convert_value(const ArrayType& array,
698
+ const int64_t i) {
699
+ return array_value_converter_.convert(array, i);
700
+ }
701
+
702
+ template <typename ArrayType>
703
+ void convert(const ArrayType& array) {
704
+ const auto n = array.length();
705
+ if (array.null_count() > 0) {
706
+ for (int64_t i = 0; i < n; ++i) {
707
+ auto value = Qnil;
708
+ if (!array.IsNull(i)) {
709
+ value = convert_value(array, i);
710
+ }
711
+ auto record = rb_ary_entry(records_, i);
712
+ rb_ary_store(record, column_index_, value);
713
+ }
714
+ } else {
715
+ for (int64_t i = 0; i < n; ++i) {
716
+ auto record = rb_ary_entry(records_, i);
717
+ rb_ary_store(record, column_index_, convert_value(array, i));
718
+ }
719
+ }
720
+ }
721
+
722
+ ArrayValueConverter array_value_converter_;
723
+ ListArrayValueConverter list_array_value_converter_;
724
+ StructArrayValueConverter struct_array_value_converter_;
725
+ UnionArrayValueConverter union_array_value_converter_;
726
+ DictionaryArrayValueConverter dictionary_array_value_converter_;
727
+
728
+ // Destination for converted records.
729
+ VALUE records_;
730
+
731
+ // The current column index.
732
+ int column_index_;
733
+
734
+ // The number of columns.
735
+ const int n_columns_;
736
+ };
737
+ }
738
+
739
+ VALUE
740
+ record_batch_raw_records(VALUE rb_record_batch) {
741
+ auto garrow_record_batch = GARROW_RECORD_BATCH(RVAL2GOBJ(rb_record_batch));
742
+ auto record_batch = garrow_record_batch_get_raw(garrow_record_batch).get();
743
+ const auto n_rows = record_batch->num_rows();
744
+ const auto n_columns = record_batch->num_columns();
745
+ auto records = rb_ary_new_capa(n_rows);
746
+
747
+ try {
748
+ RawRecordsBuilder builder(records, n_columns);
749
+ builder.build(*record_batch);
750
+ } catch (rb::State& state) {
751
+ state.jump();
752
+ }
753
+
754
+ return records;
755
+ }
756
+ }