parquet 0.4.2 → 0.5.1

This diff shows the changes between publicly released package versions as they appear in their respective registries. It is provided for informational purposes only.
@@ -1,53 +1,57 @@
1
- use std::{
2
- fs::File,
3
- io::{self, BufReader, BufWriter},
4
- mem,
5
- sync::Arc,
6
- };
1
+ mod write_columns;
2
+ mod write_rows;
7
3
 
8
- use arrow_array::{Array, RecordBatch};
9
- use arrow_schema::{DataType, Field, Schema, TimeUnit};
4
+ use arrow_schema::{DataType, Schema, TimeUnit};
5
+ use itertools::Itertools;
10
6
  use magnus::{
11
7
  scan_args::{get_kwargs, scan_args},
12
8
  value::ReprValue,
13
- Error as MagnusError, RArray, Ruby, TryConvert, Value,
9
+ Error as MagnusError, RArray, RHash, Ruby, Symbol, Value,
14
10
  };
15
11
  use parquet::{
16
12
  arrow::ArrowWriter,
17
13
  basic::{Compression, GzipLevel, ZstdLevel},
18
14
  file::properties::WriterProperties,
19
15
  };
20
- use rand::Rng;
16
+ use std::{
17
+ fs::File,
18
+ io::{self, BufReader, BufWriter},
19
+ sync::Arc,
20
+ };
21
21
  use tempfile::NamedTempFile;
22
+ pub use write_columns::write_columns;
23
+ pub use write_rows::write_rows;
22
24
 
25
+ use crate::{types::PrimitiveType, SchemaNode};
23
26
  use crate::{
24
- convert_ruby_array_to_arrow,
25
- types::{ColumnCollector, ParquetErrorWrapper, WriterOutput},
26
- IoLikeValue, ParquetSchemaType, ParquetWriteArgs, SchemaField, SendableWrite,
27
+ types::{ColumnCollector, ParquetGemError, ParquetSchemaType, WriterOutput},
28
+ utils::parse_string_or_symbol,
29
+ IoLikeValue, ParquetSchemaType as PST, ParquetWriteArgs, SchemaField, SendableWrite,
27
30
  };
28
31
 
29
- const MIN_SAMPLES_FOR_ESTIMATE: usize = 10; // Minimum samples needed for estimation
30
- const SAMPLE_SIZE: usize = 100; // Number of rows to sample for size estimation
31
- const MIN_BATCH_SIZE: usize = 10; // Minimum batch size to maintain efficiency
32
- const INITIAL_BATCH_SIZE: usize = 100; // Initial batch size while sampling
33
-
34
- // Maximum memory usage per batch (64MB by default)
32
+ const MIN_SAMPLES_FOR_ESTIMATE: usize = 10;
33
+ const SAMPLE_SIZE: usize = 100;
34
+ const MIN_BATCH_SIZE: usize = 10;
35
+ const INITIAL_BATCH_SIZE: usize = 100;
35
36
  const DEFAULT_MEMORY_THRESHOLD: usize = 64 * 1024 * 1024;
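
These constants drive the adaptive batch sizing used by the row writer (now housed in write_rows.rs): once at least MIN_SAMPLES_FOR_ESTIMATE row sizes have been sampled, the batch size is recomputed as roughly flush_threshold divided by the average sampled row size, clamped to MIN_BATCH_SIZE. The standalone sketch below mirrors the sizing logic visible in the removed write_rows code further down in this diff; the helper name is illustrative, not part of the gem.

    // Illustrative sketch of the adaptive batch sizing described above.
    fn suggested_batch_size(sampled_row_sizes: &[usize], flush_threshold: usize) -> usize {
        const MIN_SAMPLES_FOR_ESTIMATE: usize = 10;
        const MIN_BATCH_SIZE: usize = 10;
        const INITIAL_BATCH_SIZE: usize = 100;

        if sampled_row_sizes.len() < MIN_SAMPLES_FOR_ESTIMATE {
            // Not enough samples yet: stay on the initial batch size.
            return INITIAL_BATCH_SIZE;
        }
        let total: usize = sampled_row_sizes.iter().sum();
        // Average row size in bytes, floored at 1.0 to avoid division by zero.
        let avg = (total as f64 / sampled_row_sizes.len() as f64).max(1.0);
        ((flush_threshold as f64 / avg).floor() as usize).max(MIN_BATCH_SIZE)
    }

    fn main() {
        // ~640-byte rows against the 64 MiB default threshold => ~104_857 rows per batch.
        let sizes = vec![640usize; 16];
        println!("{}", suggested_batch_size(&sizes, 64 * 1024 * 1024));
    }
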
36
37
 
37
38
  /// Parse arguments for Parquet writing
38
- pub fn parse_parquet_write_args(args: &[Value]) -> Result<ParquetWriteArgs, MagnusError> {
39
- let ruby = unsafe { Ruby::get_unchecked() };
39
+ pub fn parse_parquet_write_args(
40
+ ruby: &Ruby,
41
+ args: &[Value],
42
+ ) -> Result<ParquetWriteArgs, MagnusError> {
40
43
  let parsed_args = scan_args::<(Value,), (), (), (), _, ()>(args)?;
41
44
  let (read_from,) = parsed_args.required;
42
45
 
43
46
  let kwargs = get_kwargs::<
44
47
  _,
45
- (Option<RArray>, Value),
48
+ (Value, Value),
46
49
  (
47
50
  Option<Option<usize>>,
48
51
  Option<Option<usize>>,
49
52
  Option<Option<String>>,
50
53
  Option<Option<usize>>,
54
+ Option<Option<Value>>,
51
55
  ),
52
56
  (),
53
57
  >(
@@ -58,451 +62,358 @@ pub fn parse_parquet_write_args(args: &[Value]) -> Result<ParquetWriteArgs, Magn
58
62
  "flush_threshold",
59
63
  "compression",
60
64
  "sample_size",
65
+ "logger",
61
66
  ],
62
67
  )?;
63
68
 
64
- let schema = if kwargs.required.0.is_none() || kwargs.required.0.unwrap().is_empty() {
65
- // If schema is nil, we need to peek at the first value to determine column count
66
- let first_value = read_from.funcall::<_, _, Value>("peek", ())?;
67
- let array = RArray::from_value(first_value).ok_or_else(|| {
68
- MagnusError::new(
69
- magnus::exception::type_error(),
70
- "First value must be an array when schema is not provided",
71
- )
72
- })?;
73
-
74
- // Generate field names f0, f1, f2, etc.
75
- (0..array.len())
76
- .map(|i| SchemaField {
77
- name: format!("f{}", i),
78
- type_: ParquetSchemaType::String,
79
- format: None,
80
- })
81
- .collect()
69
+ // The schema value could be one of:
70
+ // 1. An array of hashes (legacy format)
71
+ // 2. A hash with type: :struct (new DSL format)
72
+ // 3. nil (infer from data)
73
+ let schema_value = kwargs.required.0;
74
+
75
+ // Check if it's the new DSL format (a hash with type: :struct)
76
+ // We need to handle both direct hash objects and objects created via Parquet::Schema.define
77
+
78
+ // First, try to convert it to a Hash if it's not already a Hash
79
+ // This handles the case where schema_value is a Schema object from Parquet::Schema.define
80
+ let schema_hash = if schema_value.is_kind_of(ruby.class_hash()) {
81
+ RHash::from_value(schema_value).ok_or_else(|| {
82
+ MagnusError::new(magnus::exception::type_error(), "Schema must be a hash")
83
+ })?
82
84
  } else {
83
- let schema_array = kwargs.required.0.unwrap();
84
-
85
- let mut schema = Vec::with_capacity(schema_array.len());
86
-
87
- for (idx, field_hash) in schema_array.into_iter().enumerate() {
88
- if !field_hash.is_kind_of(ruby.class_hash()) {
89
- return Err(MagnusError::new(
90
- magnus::exception::type_error(),
91
- format!("schema[{}] must be a hash", idx),
92
- ));
93
- }
94
-
95
- let entries: Vec<(Value, Value)> = field_hash.funcall("to_a", ())?;
96
- if entries.len() != 1 {
97
- return Err(MagnusError::new(
98
- magnus::exception::type_error(),
99
- format!("schema[{}] must contain exactly one key-value pair", idx),
100
- ));
101
- }
102
-
103
- let (name, type_value) = &entries[0];
104
- let name = String::try_convert(name.clone())?;
105
-
106
- let (type_, format) = if type_value.is_kind_of(ruby.class_hash()) {
107
- let type_hash: Vec<(Value, Value)> = type_value.funcall("to_a", ())?;
108
- let mut type_str = None;
109
- let mut format_str = None;
110
-
111
- for (key, value) in type_hash {
112
- let key = String::try_convert(key)?;
113
- match key.as_str() {
114
- "type" => type_str = Some(value),
115
- "format" => format_str = Some(String::try_convert(value)?),
116
- _ => {
117
- return Err(MagnusError::new(
118
- magnus::exception::type_error(),
119
- format!("Unknown key '{}' in type definition", key),
120
- ))
85
+ // Try to convert the object to a hash with to_h
86
+ match schema_value.respond_to("to_h", false) {
87
+ Ok(true) => {
88
+ match schema_value.funcall::<_, _, Value>("to_h", ()) {
89
+ Ok(hash_val) => match RHash::from_value(hash_val) {
90
+ Some(hash) => hash,
91
+ None => {
92
+ // to_h returned a non-Hash; substitute an empty hash so we fall through to legacy handling
93
+ RHash::new()
121
94
  }
95
+ },
96
+ Err(_) => {
97
+ // to_h raised; substitute an empty hash so we fall through to legacy handling
98
+ RHash::new()
122
99
  }
123
100
  }
101
+ }
102
+ _ => {
103
+ // Doesn't respond to to_h; substitute an empty hash so we fall through to legacy handling
104
+ RHash::new()
105
+ }
106
+ }
107
+ };
108
+
109
+ // Now check if it's a schema hash with a type: :struct field
110
+ let type_val = schema_hash.get(Symbol::new("type"));
111
+
112
+ if let Some(type_val) = type_val {
113
+ // If it has a type: :struct, it's the new DSL format
114
+ // Use parse_string_or_symbol to handle both String and Symbol values
115
+ let ttype = parse_string_or_symbol(&ruby, type_val)?;
116
+ if let Some(ref type_str) = ttype {
117
+ if type_str == "struct" {
118
+ // Parse using the new schema approach
119
+ let schema_node = crate::parse_schema_node(&ruby, schema_value)?;
120
+
121
+ validate_schema_node(&ruby, &schema_node)?;
122
+
123
+ return Ok(ParquetWriteArgs {
124
+ read_from,
125
+ write_to: kwargs.required.1,
126
+ schema: schema_node,
127
+ batch_size: kwargs.optional.0.flatten(),
128
+ flush_threshold: kwargs.optional.1.flatten(),
129
+ compression: kwargs.optional.2.flatten(),
130
+ sample_size: kwargs.optional.3.flatten(),
131
+ logger: kwargs.optional.4.flatten(),
132
+ });
133
+ }
134
+ }
135
+ }
124
136
 
125
- let type_str = type_str.ok_or_else(|| {
137
+ // If it's not a hash with type: :struct, handle as legacy format
138
+ let schema_fields = if schema_value.is_nil()
139
+ || (schema_value.is_kind_of(ruby.class_array())
140
+ && RArray::from_value(schema_value)
141
+ .ok_or_else(|| {
126
142
  MagnusError::new(
127
143
  magnus::exception::type_error(),
128
- "Missing 'type' in type definition",
144
+ "Schema fields must be an array",
129
145
  )
130
- })?;
131
-
132
- (ParquetSchemaType::try_convert(type_str)?, format_str)
133
- } else {
134
- (ParquetSchemaType::try_convert(type_value.clone())?, None)
135
- };
146
+ })?
147
+ .len()
148
+ == 0)
149
+ {
150
+ // If schema is nil or an empty array, we need to peek at the first value to determine column count
151
+ let first_value = read_from.funcall::<_, _, Value>("peek", ())?;
152
+ // Default to nullable:true for auto-inferred fields
153
+ crate::infer_schema_from_first_row(&ruby, first_value, true)?
154
+ } else {
155
+ // Legacy array format - use our centralized parser
156
+ crate::parse_legacy_schema(&ruby, schema_value)?
157
+ };
136
158
 
137
- schema.push(SchemaField {
138
- name,
139
- type_,
140
- format,
141
- });
142
- }
159
+ // Convert the legacy schema fields to SchemaNode (DSL format)
160
+ let schema_node = crate::legacy_schema_to_dsl(&ruby, schema_fields)?;
143
161
 
144
- schema
145
- };
162
+ validate_schema_node(&ruby, &schema_node)?;
146
163
 
147
164
  Ok(ParquetWriteArgs {
148
165
  read_from,
149
166
  write_to: kwargs.required.1,
150
- schema,
167
+ schema: schema_node,
151
168
  batch_size: kwargs.optional.0.flatten(),
152
169
  flush_threshold: kwargs.optional.1.flatten(),
153
170
  compression: kwargs.optional.2.flatten(),
154
171
  sample_size: kwargs.optional.3.flatten(),
172
+ logger: kwargs.optional.4.flatten(),
155
173
  })
156
174
  }
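
To summarize the control flow above: the rewritten parse_parquet_write_args accepts three schema shapes — nil (or an empty array), the legacy array-of-hashes format, and the new DSL hash with type: :struct — and normalizes all of them to a SchemaNode before writing. The magnus-free sketch below illustrates that dispatch; the enum and helper are illustrative stand-ins, not the gem's actual types, and the f0/f1 column naming follows the old inference code purely for illustration.

    // Simplified stand-ins for the three schema inputs the writer accepts.
    enum SchemaInput {
        Nil,                       // nil / empty array: infer from the first row
        LegacyFields(Vec<String>), // legacy array-of-hashes format (names only here)
        DslStruct(Vec<String>),    // Parquet::Schema.define hash with type: :struct
    }

    // Illustrative normalization: every input ends up as one field list,
    // mirroring how the real code converts everything to a SchemaNode.
    fn normalize(input: SchemaInput, first_row_width: usize) -> Vec<String> {
        match input {
            // Peek at the first row and create one inferred column per value.
            SchemaInput::Nil => (0..first_row_width).map(|i| format!("f{}", i)).collect(),
            // Legacy array format: parsed, then converted to the DSL representation.
            SchemaInput::LegacyFields(fields) => fields,
            // DSL format: already structured; validated and used directly.
            SchemaInput::DslStruct(fields) => fields,
        }
    }

    fn main() {
        println!("{:?}", normalize(SchemaInput::Nil, 3)); // ["f0", "f1", "f2"]
    }
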
157
175
 
158
- /// Estimate the size of a row
159
- fn estimate_single_row_size(row: &RArray, schema: &[SchemaField]) -> Result<usize, MagnusError> {
160
- let mut row_size = 0;
161
- for (field, value) in schema.iter().zip(row.into_iter()) {
162
- // Estimate size based on type and value
163
- row_size += match &field.type_ {
164
- // Use reference to avoid moving
165
- ParquetSchemaType::Int8 | ParquetSchemaType::UInt8 => 1,
166
- ParquetSchemaType::Int16 | ParquetSchemaType::UInt16 => 2,
167
- ParquetSchemaType::Int32
168
- | ParquetSchemaType::UInt32
169
- | ParquetSchemaType::Float
170
- | ParquetSchemaType::Date32 => 4,
171
- ParquetSchemaType::Int64
172
- | ParquetSchemaType::UInt64
173
- | ParquetSchemaType::Double
174
- | ParquetSchemaType::TimestampMillis
175
- | ParquetSchemaType::TimestampMicros => 8,
176
- ParquetSchemaType::String => {
177
- if let Ok(s) = String::try_convert(value) {
178
- s.len() + mem::size_of::<usize>() // account for length prefix
179
- } else {
180
- 16 // default estimate for string
181
- }
182
- }
183
- ParquetSchemaType::Binary => {
184
- if let Ok(bytes) = Vec::<u8>::try_convert(value) {
185
- bytes.len() + mem::size_of::<usize>() // account for length prefix
176
+ // -----------------------------------------------------------------------------
177
+ // HELPER to invert an Arrow DataType back to our ParquetSchemaType.
178
+ // This is the reverse of the schema-to-Arrow conversion and is needed
179
+ // when building ColumnCollectors from an Arrow schema and when
180
+ // validating schemas before any data is written.
181
+ // -----------------------------------------------------------------------------
182
+ fn arrow_data_type_to_parquet_schema_type(dt: &DataType) -> Result<ParquetSchemaType, MagnusError> {
183
+ match dt {
184
+ DataType::Boolean => Ok(PST::Primitive(PrimitiveType::Boolean)),
185
+ DataType::Int8 => Ok(PST::Primitive(PrimitiveType::Int8)),
186
+ DataType::Int16 => Ok(PST::Primitive(PrimitiveType::Int16)),
187
+ DataType::Int32 => Ok(PST::Primitive(PrimitiveType::Int32)),
188
+ DataType::Int64 => Ok(PST::Primitive(PrimitiveType::Int64)),
189
+ DataType::UInt8 => Ok(PST::Primitive(PrimitiveType::UInt8)),
190
+ DataType::UInt16 => Ok(PST::Primitive(PrimitiveType::UInt16)),
191
+ DataType::UInt32 => Ok(PST::Primitive(PrimitiveType::UInt32)),
192
+ DataType::UInt64 => Ok(PST::Primitive(PrimitiveType::UInt64)),
193
+ DataType::Float16 => {
194
+ // There is no direct ParquetSchemaType::Float16; widen it to Float32.
195
+ Ok(PST::Primitive(PrimitiveType::Float32))
196
+ }
197
+ DataType::Float32 => Ok(PST::Primitive(PrimitiveType::Float32)),
198
+ DataType::Float64 => Ok(PST::Primitive(PrimitiveType::Float64)),
199
+ DataType::Date32 => Ok(PST::Primitive(PrimitiveType::Date32)),
200
+ DataType::Date64 => {
201
+ // Our schema types cover Date32 and timestamps, but there is no direct
202
+ // equivalent for Arrow's Date64. Silently down-converting to Date32 could
203
+ // lose intent, so we reject Date64 here and ask the caller to adjust the
204
+ // schema (for example by using date32 or a timestamp type) instead.
205
+ Err(MagnusError::new(
206
+ magnus::exception::runtime_error(),
207
+ "Arrow Date64 not directly supported in current ParquetSchemaType (use date32?).",
208
+ ))
209
+ }
210
+ DataType::Timestamp(TimeUnit::Second, _tz) => {
211
+ // There is no second-resolution variant in ParquetSchemaType, so
212
+ // second-precision timestamps are widened to TimestampMillis.
213
+ Ok(PST::Primitive(PrimitiveType::TimestampMillis))
214
+ }
215
+ DataType::Timestamp(TimeUnit::Millisecond, _tz) => {
216
+ Ok(PST::Primitive(PrimitiveType::TimestampMillis))
217
+ }
218
+ DataType::Timestamp(TimeUnit::Microsecond, _tz) => {
219
+ Ok(PST::Primitive(PrimitiveType::TimestampMicros))
220
+ }
221
+ DataType::Timestamp(TimeUnit::Nanosecond, _tz) => {
222
+ // There is no TimestampNanos variant, and silently truncating to
223
+ // microseconds would lose precision, so nanosecond timestamps are rejected.
224
+ Err(MagnusError::new(
225
+ magnus::exception::runtime_error(),
226
+ "TimestampNanos not supported, please adjust your schema or code.",
227
+ ))
228
+ }
229
+ DataType::Utf8 => Ok(PST::Primitive(PrimitiveType::String)),
230
+ DataType::Binary => Ok(PST::Primitive(PrimitiveType::Binary)),
231
+ DataType::LargeUtf8 => {
232
+ // LargeUtf8 has no dedicated variant; treat it as a regular String column.
233
+ Ok(PST::Primitive(PrimitiveType::String))
234
+ }
235
+ DataType::LargeBinary => Ok(PST::Primitive(PrimitiveType::Binary)),
236
+ DataType::List(child_field) => {
237
+ // Recursively handle the item type
238
+ let child_type = arrow_data_type_to_parquet_schema_type(child_field.data_type())?;
239
+ Ok(PST::List(Box::new(crate::types::ListField {
240
+ item_type: child_type,
241
+ format: None,
242
+ nullable: true,
243
+ })))
244
+ }
245
+ DataType::Map(entry_field, _keys_sorted) => {
246
+ // Arrow's Map -> a struct<key, value> inside
247
+ let entry_type = entry_field.data_type();
248
+ if let DataType::Struct(fields) = entry_type {
249
+ if fields.len() == 2 {
250
+ let key_type = arrow_data_type_to_parquet_schema_type(fields[0].data_type())?;
251
+ let value_type = arrow_data_type_to_parquet_schema_type(fields[1].data_type())?;
252
+ Ok(PST::Map(Box::new(crate::types::MapField {
253
+ key_type,
254
+ value_type,
255
+ key_format: None,
256
+ value_format: None,
257
+ value_nullable: true,
258
+ })))
186
259
  } else {
187
- 16 // default estimate for binary
260
+ Err(MagnusError::new(
261
+ magnus::exception::type_error(),
262
+ "Map field must have exactly 2 child fields (key, value)",
263
+ ))
188
264
  }
265
+ } else {
266
+ Err(MagnusError::new(
267
+ magnus::exception::type_error(),
268
+ "Unexpected Arrow schema layout: Map entry type must be a struct",
269
+ ))
189
270
  }
190
- ParquetSchemaType::Boolean => 1,
191
- ParquetSchemaType::List(_) | ParquetSchemaType::Map(_) => {
192
- 32 // rough estimate for complex types
271
+ }
272
+ DataType::Struct(arrow_fields) => {
273
+ // We map this to PST::Struct and recursively convert each subfield,
274
+ // but for top-level collecting the struct is stored as a single column,
275
+ // so the corresponding value in each row must be a Ruby Hash for that field.
276
+ let mut schema_fields = vec![];
277
+ for f in arrow_fields {
278
+ let sub_type = arrow_data_type_to_parquet_schema_type(f.data_type())?;
279
+ schema_fields.push(SchemaField {
280
+ name: f.name().clone(),
281
+ type_: sub_type,
282
+ format: None, // We can't see the 'format' from Arrow
283
+ nullable: f.is_nullable(),
284
+ });
193
285
  }
194
- };
286
+ Ok(PST::Struct(Box::new(crate::types::StructField {
287
+ fields: schema_fields,
288
+ })))
289
+ }
290
+ _ => Err(MagnusError::new(
291
+ magnus::exception::runtime_error(),
292
+ format!("Unsupported or unhandled Arrow DataType: {:?}", dt),
293
+ )),
195
294
  }
196
- Ok(row_size)
197
295
  }
198
296
 
199
- #[inline]
200
- pub fn write_rows(args: &[Value]) -> Result<(), MagnusError> {
201
- let ruby = unsafe { Ruby::get_unchecked() };
202
-
203
- let ParquetWriteArgs {
204
- read_from,
205
- write_to,
206
- schema,
207
- batch_size: user_batch_size,
208
- compression,
209
- flush_threshold,
210
- sample_size: user_sample_size,
211
- } = parse_parquet_write_args(args)?;
212
-
213
- let flush_threshold = flush_threshold.unwrap_or(DEFAULT_MEMORY_THRESHOLD);
214
-
215
- // Convert schema to Arrow schema
216
- let arrow_fields: Vec<Field> = schema
217
- .iter()
218
- .map(|field| {
219
- Field::new(
220
- &field.name,
221
- match field.type_ {
222
- ParquetSchemaType::Int8 => DataType::Int8,
223
- ParquetSchemaType::Int16 => DataType::Int16,
224
- ParquetSchemaType::Int32 => DataType::Int32,
225
- ParquetSchemaType::Int64 => DataType::Int64,
226
- ParquetSchemaType::UInt8 => DataType::UInt8,
227
- ParquetSchemaType::UInt16 => DataType::UInt16,
228
- ParquetSchemaType::UInt32 => DataType::UInt32,
229
- ParquetSchemaType::UInt64 => DataType::UInt64,
230
- ParquetSchemaType::Float => DataType::Float32,
231
- ParquetSchemaType::Double => DataType::Float64,
232
- ParquetSchemaType::String => DataType::Utf8,
233
- ParquetSchemaType::Binary => DataType::Binary,
234
- ParquetSchemaType::Boolean => DataType::Boolean,
235
- ParquetSchemaType::Date32 => DataType::Date32,
236
- ParquetSchemaType::TimestampMillis => {
237
- DataType::Timestamp(TimeUnit::Millisecond, None)
238
- }
239
- ParquetSchemaType::TimestampMicros => {
240
- DataType::Timestamp(TimeUnit::Microsecond, None)
241
- }
242
- ParquetSchemaType::List(_) => unimplemented!("List type not yet supported"),
243
- ParquetSchemaType::Map(_) => unimplemented!("Map type not yet supported"),
244
- },
245
- true,
246
- )
247
- })
248
- .collect();
249
- let arrow_schema = Arc::new(Schema::new(arrow_fields));
250
-
251
- // Create the writer
252
- let mut writer = create_writer(&ruby, &write_to, arrow_schema.clone(), compression)?;
253
-
254
- if read_from.is_kind_of(ruby.class_enumerator()) {
255
- // Create collectors for each column
256
- let mut column_collectors: Vec<ColumnCollector> = schema
257
- .iter()
258
- .map(|field| {
259
- // Clone the type to avoid moving from a reference
260
- let type_clone = field.type_.clone();
261
- ColumnCollector::new(field.name.clone(), type_clone, field.format.clone())
262
- })
263
- .collect();
264
-
265
- let mut rows_in_batch = 0;
266
- let mut total_rows = 0;
267
- let mut rng = rand::rng();
268
- let sample_size = user_sample_size.unwrap_or(SAMPLE_SIZE);
269
- let mut size_samples = Vec::with_capacity(sample_size);
270
- let mut current_batch_size = user_batch_size.unwrap_or(INITIAL_BATCH_SIZE);
271
-
272
- loop {
273
- match read_from.funcall::<_, _, Value>("next", ()) {
274
- Ok(row) => {
275
- let row_array = RArray::from_value(row).ok_or_else(|| {
276
- MagnusError::new(ruby.exception_type_error(), "Row must be an array")
277
- })?;
278
-
279
- // Validate row length matches schema
280
- if row_array.len() != column_collectors.len() {
281
- return Err(MagnusError::new(
282
- magnus::exception::type_error(),
283
- format!(
284
- "Row length ({}) does not match schema length ({}). Schema expects columns: {:?}",
285
- row_array.len(),
286
- column_collectors.len(),
287
- column_collectors.iter().map(|c| c.name.as_str()).collect::<Vec<_>>()
288
- ),
289
- ));
290
- }
291
-
292
- // Sample row sizes using reservoir sampling
293
- if size_samples.len() < sample_size {
294
- size_samples.push(estimate_single_row_size(&row_array, &schema)?);
295
- } else if rng.random_range(0..=total_rows) < sample_size {
296
- let idx = rng.random_range(0..sample_size);
297
- size_samples[idx] = estimate_single_row_size(&row_array, &schema)?;
298
- }
299
-
300
- // Process each value in the row
301
- for (collector, value) in column_collectors.iter_mut().zip(row_array) {
302
- collector.push_value(value)?;
303
- }
304
-
305
- rows_in_batch += 1;
306
- total_rows += 1;
307
-
308
- // Calculate batch size progressively once we have minimum samples
309
- if size_samples.len() >= MIN_SAMPLES_FOR_ESTIMATE && user_batch_size.is_none() {
310
- let total_size = size_samples.iter().sum::<usize>();
311
- // Safe because we know we have at least MIN_SAMPLES_FOR_ESTIMATE samples
312
- let avg_row_size = total_size as f64 / size_samples.len() as f64;
313
- let avg_row_size = avg_row_size.max(1.0); // Ensure we don't divide by zero
314
- let suggested_batch_size =
315
- (flush_threshold as f64 / avg_row_size).floor() as usize;
316
- current_batch_size = suggested_batch_size.max(MIN_BATCH_SIZE);
317
- }
318
-
319
- // When we reach batch size, write the batch
320
- if rows_in_batch >= current_batch_size {
321
- write_batch(&mut writer, &mut column_collectors, flush_threshold)?;
322
- rows_in_batch = 0;
323
- }
324
- }
325
- Err(e) => {
326
- if e.is_kind_of(ruby.exception_stop_iteration()) {
327
- // Write any remaining rows
328
- if rows_in_batch > 0 {
329
- write_batch(&mut writer, &mut column_collectors, flush_threshold)?;
330
- }
331
- break;
332
- }
333
- return Err(e);
334
- }
335
- }
297
+ // -----------------------------------------------------------------------------
298
+ // HELPER to build ColumnCollectors for the DSL variant
299
+ // This function converts a SchemaNode (from our DSL) into a collection of ColumnCollectors
300
+ // that can accumulate values for each column in the schema.
301
+ // - arrow_schema: The Arrow schema corresponding to our DSL schema
302
+ // - root_node: The root SchemaNode (expected to be a Struct node) from which to build collectors
303
+ // -----------------------------------------------------------------------------
304
+ fn build_column_collectors_from_dsl<'a>(
305
+ ruby: &'a Ruby,
306
+ arrow_schema: &'a Arc<Schema>,
307
+ root_node: &'a SchemaNode,
308
+ ) -> Result<Vec<ColumnCollector<'a>>, MagnusError> {
309
+ // We expect the top-level schema node to be a Struct so that arrow_schema
310
+ // lines up with root_node.fields. A top-level primitive would yield a single
311
+ // field, but build_arrow_schema assumes the top level is always a Struct.
312
+ let fields = match root_node {
313
+ SchemaNode::Struct { fields, .. } => fields,
314
+ _ => {
315
+ return Err(MagnusError::new(
316
+ ruby.exception_runtime_error(),
317
+ "Top-level schema for DSL must be a struct",
318
+ ))
336
319
  }
337
- } else {
320
+ };
321
+
322
+ if fields.len() != arrow_schema.fields().len() {
338
323
  return Err(MagnusError::new(
339
- magnus::exception::type_error(),
340
- "read_from must be an Enumerator",
324
+ ruby.exception_runtime_error(),
325
+ format!(
326
+ "Mismatch between DSL field count ({}) and Arrow fields ({})",
327
+ fields.len(),
328
+ arrow_schema.fields().len()
329
+ ),
341
330
  ));
342
331
  }
343
332
 
344
- // Ensure everything is written and get the temp file if it exists
345
- if let Some(temp_file) = writer.close().map_err(|e| ParquetErrorWrapper(e))? {
346
- // If we got a temp file back, we need to copy its contents to the IO-like object
347
- copy_temp_file_to_io_like(temp_file, IoLikeValue(write_to))?;
333
+ let mut collectors = Vec::with_capacity(fields.len());
334
+ for (arrow_field, schema_field_node) in arrow_schema.fields().iter().zip(fields) {
335
+ let name = arrow_field.name().clone();
336
+ let parquet_type = arrow_data_type_to_parquet_schema_type(arrow_field.data_type())?;
337
+
338
+ // Extract the optional format from the schema node
339
+ let format = extract_format_from_schema_node(schema_field_node);
340
+
341
+ // Build the ColumnCollector
342
+ collectors.push(ColumnCollector::new(
343
+ ruby,
344
+ name,
345
+ parquet_type,
346
+ format,
347
+ arrow_field.is_nullable(),
348
+ ));
348
349
  }
349
-
350
- Ok(())
350
+ Ok(collectors)
351
351
  }
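
The collectors built above implement a straightforward columnar-accumulation pattern: each incoming row is split across per-column buffers, and a full buffer is later drained into an Arrow array for a RecordBatch. The dependency-free sketch below shows just that pattern; MiniCollector is a hypothetical stand-in for the gem's ColumnCollector, which holds typed Ruby values rather than strings.

    // Hypothetical, simplified stand-in for ColumnCollector: accumulate values
    // per column, then drain the buffer when a batch is written.
    struct MiniCollector {
        name: String,
        values: Vec<String>,
    }

    impl MiniCollector {
        fn push_value(&mut self, v: &str) {
            self.values.push(v.to_string());
        }
        // take_array() in the real code converts the buffer into an Arrow array;
        // here we just hand back the accumulated values and reset the buffer.
        fn take_values(&mut self) -> Vec<String> {
            std::mem::take(&mut self.values)
        }
    }

    fn main() {
        let mut cols = vec![
            MiniCollector { name: "id".into(), values: vec![] },
            MiniCollector { name: "name".into(), values: vec![] },
        ];
        // One row per iteration, one value per collector.
        for row in [["1", "alice"], ["2", "bob"]] {
            for (collector, value) in cols.iter_mut().zip(row) {
                collector.push_value(value);
            }
        }
        for c in &mut cols {
            println!("{}: {:?}", c.name, c.take_values());
        }
    }
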
352
352
 
353
- #[inline]
354
- pub fn write_columns(args: &[Value]) -> Result<(), MagnusError> {
355
- let ruby = unsafe { Ruby::get_unchecked() };
356
-
357
- let ParquetWriteArgs {
358
- read_from,
359
- write_to,
360
- schema,
361
- batch_size: _,
362
- compression,
363
- flush_threshold,
364
- sample_size: _,
365
- } = parse_parquet_write_args(args)?;
366
-
367
- let flush_threshold = flush_threshold.unwrap_or(DEFAULT_MEMORY_THRESHOLD);
368
-
369
- // Convert schema to Arrow schema
370
- let arrow_fields: Vec<Field> = schema
371
- .iter()
372
- .map(|field| {
373
- Field::new(
374
- &field.name,
375
- match field.type_ {
376
- ParquetSchemaType::Int8 => DataType::Int8,
377
- ParquetSchemaType::Int16 => DataType::Int16,
378
- ParquetSchemaType::Int32 => DataType::Int32,
379
- ParquetSchemaType::Int64 => DataType::Int64,
380
- ParquetSchemaType::UInt8 => DataType::UInt8,
381
- ParquetSchemaType::UInt16 => DataType::UInt16,
382
- ParquetSchemaType::UInt32 => DataType::UInt32,
383
- ParquetSchemaType::UInt64 => DataType::UInt64,
384
- ParquetSchemaType::Float => DataType::Float32,
385
- ParquetSchemaType::Double => DataType::Float64,
386
- ParquetSchemaType::String => DataType::Utf8,
387
- ParquetSchemaType::Binary => DataType::Binary,
388
- ParquetSchemaType::Boolean => DataType::Boolean,
389
- ParquetSchemaType::Date32 => DataType::Date32,
390
- ParquetSchemaType::TimestampMillis => {
391
- DataType::Timestamp(TimeUnit::Millisecond, None)
392
- }
393
- ParquetSchemaType::TimestampMicros => {
394
- DataType::Timestamp(TimeUnit::Microsecond, None)
395
- }
396
- ParquetSchemaType::List(_) => unimplemented!("List type not yet supported"),
397
- ParquetSchemaType::Map(_) => unimplemented!("Map type not yet supported"),
398
- },
399
- true,
400
- )
401
- })
402
- .collect();
403
- let arrow_schema = Arc::new(Schema::new(arrow_fields));
404
-
405
- // Create the writer
406
- let mut writer = create_writer(&ruby, &write_to, arrow_schema.clone(), compression)?;
407
-
408
- if read_from.is_kind_of(ruby.class_enumerator()) {
409
- loop {
410
- match read_from.funcall::<_, _, Value>("next", ()) {
411
- Ok(batch) => {
412
- let batch_array = RArray::from_value(batch).ok_or_else(|| {
413
- MagnusError::new(ruby.exception_type_error(), "Batch must be an array")
414
- })?;
415
-
416
- // Batch array must be an array of arrays. Check that the first value in `batch_array` is an array.
417
- batch_array.entry::<RArray>(0).map_err(|_| {
418
- MagnusError::new(
419
- ruby.exception_type_error(),
420
- "When writing columns, data must be formatted as batches of columns: [[batch1_col1, batch1_col2], [batch2_col1, batch2_col2]].",
421
- )
422
- })?;
423
-
424
- // Validate batch length matches schema
425
- if batch_array.len() != schema.len() {
426
- return Err(MagnusError::new(
427
- magnus::exception::type_error(),
428
- format!(
429
- "Batch column count ({}) does not match schema length ({}). Schema expects columns: {:?}",
430
- batch_array.len(),
431
- schema.len(),
432
- schema.iter().map(|f| f.name.as_str()).collect::<Vec<_>>()
433
- ),
434
- ));
435
- }
436
-
437
- // Convert each column in the batch to Arrow arrays
438
- let arrow_arrays: Vec<(String, Arc<dyn Array>)> = schema
439
- .iter()
440
- .zip(batch_array)
441
- .map(|(field, column)| {
442
- let column_array = RArray::from_value(column).ok_or_else(|| {
443
- MagnusError::new(
444
- magnus::exception::type_error(),
445
- format!("Column '{}' must be an array", field.name),
446
- )
447
- })?;
448
-
449
- Ok((
450
- field.name.clone(),
451
- convert_ruby_array_to_arrow(column_array, &field.type_)?,
452
- ))
453
- })
454
- .collect::<Result<_, MagnusError>>()?;
455
-
456
- // Create and write record batch
457
- let record_batch = RecordBatch::try_from_iter(arrow_arrays).map_err(|e| {
458
- MagnusError::new(
459
- magnus::exception::runtime_error(),
460
- format!("Failed to create record batch: {}", e),
461
- )
462
- })?;
463
-
464
- writer
465
- .write(&record_batch)
466
- .map_err(|e| ParquetErrorWrapper(e))?;
467
-
468
- match &mut writer {
469
- WriterOutput::File(w) | WriterOutput::TempFile(w, _) => {
470
- if w.in_progress_size() >= flush_threshold {
471
- w.flush().map_err(|e| ParquetErrorWrapper(e))?;
472
- }
473
- }
474
- }
475
- }
476
- Err(e) => {
477
- if e.is_kind_of(ruby.exception_stop_iteration()) {
478
- break;
479
- }
480
- return Err(e);
481
- }
482
- }
483
- }
484
- } else {
485
- return Err(MagnusError::new(
486
- magnus::exception::type_error(),
487
- "read_from must be an Enumerator",
488
- ));
353
+ // Helper to extract the format from a SchemaNode if available
354
+ fn extract_format_from_schema_node(node: &SchemaNode) -> Option<String> {
355
+ match node {
356
+ SchemaNode::Primitive {
357
+ format: f,
358
+ parquet_type: _,
359
+ ..
360
+ } => f.clone(),
361
+ // For struct, list, map, etc. there's no single "format." We ignore it.
362
+ _ => None,
489
363
  }
364
+ }
490
365
 
491
- // Ensure everything is written and get the temp file if it exists
492
- if let Some(temp_file) = writer.close().map_err(|e| ParquetErrorWrapper(e))? {
493
- // If we got a temp file back, we need to copy its contents to the IO-like object
494
- copy_temp_file_to_io_like(temp_file, IoLikeValue(write_to))?;
366
+ // Validates a SchemaNode to ensure it meets Parquet schema requirements
367
+ // Currently checks for duplicate field names at the root level, which would
368
+ // cause problems when writing Parquet files. Additional validation rules
369
+ // could be added here in the future.
370
+ //
371
+ // This validation is important because schema errors are difficult to debug
372
+ // once they reach the Parquet/Arrow layer, so we check proactively before
373
+ // any data processing begins.
374
+ fn validate_schema_node(ruby: &Ruby, schema_node: &SchemaNode) -> Result<(), MagnusError> {
375
+ if let SchemaNode::Struct { fields, .. } = &schema_node {
376
+ // if any root level schema fields have the same name, we raise an error
377
+ let field_names = fields
378
+ .iter()
379
+ .map(|f| match f {
380
+ SchemaNode::Struct { name, .. } => name.as_str(),
381
+ SchemaNode::List { name, .. } => name.as_str(),
382
+ SchemaNode::Map { name, .. } => name.as_str(),
383
+ SchemaNode::Primitive { name, .. } => name.as_str(),
384
+ })
385
+ .collect::<Vec<_>>();
386
+ let unique_field_names = field_names.iter().unique().collect::<Vec<_>>();
387
+ if field_names.len() != unique_field_names.len() {
388
+ return Err(MagnusError::new(
389
+ ruby.exception_arg_error(),
390
+ format!(
391
+ "Duplicate field names in root level schema: {:?}",
392
+ field_names
393
+ ),
394
+ ));
395
+ }
495
396
  }
496
-
497
397
  Ok(())
498
398
  }
499
399
 
400
+ // Creates an appropriate Parquet writer based on the output target and compression settings
401
+ // This function handles two main output scenarios:
402
+ // 1. Writing directly to a file path (string)
403
+ // 2. Writing to a Ruby IO-like object (using a temporary file as an intermediate buffer)
404
+ //
405
+ // For IO-like objects, the function creates a temporary file that is later copied to the
406
+ // IO object when writing is complete. This approach is necessary because Parquet requires
407
+ // random file access to write its footer after the data.
408
+ //
409
+ // The function also configures compression based on the user's preferences, with
410
+ // several options available (none, snappy, gzip, lz4, zstd).
500
411
  fn create_writer(
501
412
  ruby: &Ruby,
502
413
  write_to: &Value,
503
414
  schema: Arc<Schema>,
504
415
  compression: Option<String>,
505
- ) -> Result<WriterOutput, MagnusError> {
416
+ ) -> Result<WriterOutput, ParquetGemError> {
506
417
  // Create writer properties with compression based on the option
507
418
  let props = WriterProperties::builder()
508
419
  .set_compression(match compression.as_deref() {
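
The compression match arms themselves fall in the gap between hunks. For context, the sketch below shows what such a string-to-codec mapping typically looks like with the parquet crate API; the accepted names and the snappy fallback are assumptions for illustration, not necessarily the gem's exact behavior.

    use parquet::basic::{Compression, GzipLevel, ZstdLevel};
    use parquet::file::properties::WriterProperties;

    // Illustrative mapping from a user-supplied compression name to writer properties.
    fn writer_props(compression: Option<&str>) -> WriterProperties {
        let codec = match compression {
            Some("none") | Some("uncompressed") => Compression::UNCOMPRESSED,
            Some("gzip") => Compression::GZIP(GzipLevel::default()),
            Some("lz4") => Compression::LZ4,
            Some("zstd") => Compression::ZSTD(ZstdLevel::default()),
            // Fallback shown here as snappy purely for illustration.
            _ => Compression::SNAPPY,
        };
        WriterProperties::builder().set_compression(codec).build()
    }
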
@@ -517,9 +428,8 @@ fn create_writer(
517
428
 
518
429
  if write_to.is_kind_of(ruby.class_string()) {
519
430
  let path = write_to.to_r_string()?.to_string()?;
520
- let file: Box<dyn SendableWrite> = Box::new(File::create(path).unwrap());
521
- let writer =
522
- ArrowWriter::try_new(file, schema, Some(props)).map_err(|e| ParquetErrorWrapper(e))?;
431
+ let file: Box<dyn SendableWrite> = Box::new(File::create(path)?);
432
+ let writer = ArrowWriter::try_new(file, schema, Some(props))?;
523
433
  Ok(WriterOutput::File(writer))
524
434
  } else {
525
435
  // Create a temporary file to write to instead of directly to the IoLikeValue
@@ -535,13 +445,22 @@ fn create_writer(
535
445
  format!("Failed to reopen temporary file: {}", e),
536
446
  )
537
447
  })?);
538
- let writer =
539
- ArrowWriter::try_new(file, schema, Some(props)).map_err(|e| ParquetErrorWrapper(e))?;
448
+ let writer = ArrowWriter::try_new(file, schema, Some(props))?;
540
449
  Ok(WriterOutput::TempFile(writer, temp_file))
541
450
  }
542
451
  }
543
452
 
544
- // Helper function to copy temp file contents to IoLikeValue
453
+ // Copies the contents of a temporary file to a Ruby IO-like object
454
+ // This function is necessary because Parquet writing requires random file access
455
+ // (especially for writing the footer after all data), but Ruby IO objects may not
456
+ // support seeking. The solution is to:
457
+ //
458
+ // 1. Write the entire Parquet file to a temporary file first
459
+ // 2. Once writing is complete, copy the entire contents to the Ruby IO object
460
+ //
461
+ // This approach enables support for a wide range of Ruby IO objects like StringIO,
462
+ // network streams, etc., but does require enough disk space for the temporary file
463
+ // and involves a second full-file read/write operation at the end.
545
464
  fn copy_temp_file_to_io_like(
546
465
  temp_file: NamedTempFile,
547
466
  io_like: IoLikeValue,
@@ -564,37 +483,3 @@ fn copy_temp_file_to_io_like(
564
483
 
565
484
  Ok(())
566
485
  }
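
The pattern described above — finish the Parquet file in a seekable temporary file, then stream the whole thing into the non-seekable Ruby IO object — can be sketched with nothing but std and tempfile. The helper below is a generic illustration, not the gem's copy_temp_file_to_io_like, whose body mostly falls outside this hunk.

    use std::fs::File;
    use std::io::{self, BufReader, BufWriter, Write};
    use tempfile::NamedTempFile;

    // Generic "temp file -> non-seekable sink" copy, mirroring the approach above.
    fn copy_temp_file_to_sink<W: Write>(temp_file: NamedTempFile, sink: W) -> io::Result<u64> {
        // Reopen the finished temp file so we read it from the beginning...
        let file: File = temp_file.reopen()?;
        let mut reader = BufReader::new(file);
        // ...and stream the complete Parquet file into the destination in one pass.
        let mut writer = BufWriter::new(sink);
        let bytes = io::copy(&mut reader, &mut writer)?;
        writer.flush()?;
        Ok(bytes)
    }
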
567
-
568
- fn write_batch(
569
- writer: &mut WriterOutput,
570
- collectors: &mut [ColumnCollector],
571
- flush_threshold: usize,
572
- ) -> Result<(), MagnusError> {
573
- // Convert columns to Arrow arrays
574
- let arrow_arrays: Vec<(String, Arc<dyn Array>)> = collectors
575
- .iter_mut()
576
- .map(|collector| Ok((collector.name.clone(), collector.take_array()?)))
577
- .collect::<Result<_, MagnusError>>()?;
578
-
579
- // Create and write record batch
580
- let record_batch = RecordBatch::try_from_iter(arrow_arrays).map_err(|e| {
581
- MagnusError::new(
582
- magnus::exception::runtime_error(),
583
- format!("Failed to create record batch: {}", e),
584
- )
585
- })?;
586
-
587
- writer
588
- .write(&record_batch)
589
- .map_err(|e| ParquetErrorWrapper(e))?;
590
-
591
- match writer {
592
- WriterOutput::File(w) | WriterOutput::TempFile(w, _) => {
593
- if w.in_progress_size() >= flush_threshold || w.memory_size() >= flush_threshold {
594
- w.flush().map_err(|e| ParquetErrorWrapper(e))?;
595
- }
596
- }
597
- }
598
-
599
- Ok(())
600
- }