datablade 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
datablade/sql/ddl.py CHANGED
@@ -1,13 +1,117 @@
- from typing import Any, List, Optional
+ """Pandas-driven DDL generation for multiple SQL dialects."""
+
+ from typing import Any, List, Mapping, Optional
 
  import pandas as pd
 
  from ..utils.messages import print_verbose
  from .dialects import Dialect
  from .quoting import quote_identifier
+ from .schema_spec import resolve_column_spec, resolve_string_policy
+
+ _VALID_PREFER_LENGTH = {"estimate", "minimum", "maximum"}
+
+
+ def _coerce_positive_int(value: Any, label: str) -> Optional[int]:
+     if value is None:
+         return None
+     if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
+         raise ValueError(f"{label} must be a positive integer")
+     return int(value)
+
+
+ def _coerce_non_negative_int(value: Any, label: str) -> int:
+     if value is None:
+         return 0
+     if isinstance(value, bool) or not isinstance(value, int) or value < 0:
+         raise ValueError(f"{label} must be a non-negative integer")
+     return int(value)
+
+
+ def _coerce_optional_bool(value: Any, label: str) -> Optional[bool]:
+     if value is None:
+         return None
+     if not isinstance(value, bool):
+         raise TypeError(f"{label} must be a boolean")
+     return value
+
+
+ def _normalize_string_policy(policy: Optional[Mapping[str, Any]]) -> dict:
+     policy = {} if policy is None else dict(policy)
+     if "defined_pad" in policy and "pad" not in policy:
+         policy["pad"] = policy["defined_pad"]
+
+     prefer_length = policy.get("prefer_length", "estimate")
+     if prefer_length not in _VALID_PREFER_LENGTH:
+         raise ValueError(
+             "prefer_length must be one of 'estimate', 'minimum', or 'maximum'"
+         )
 
+     min_length = _coerce_positive_int(policy.get("min_length"), "min_length")
+     max_length = _coerce_positive_int(policy.get("max_length"), "max_length")
+     pad = _coerce_non_negative_int(policy.get("pad"), "pad")
+     empty_as_null = (
+         _coerce_optional_bool(policy.get("empty_as_null"), "empty_as_null") or False
+     )
+     allow_null = _coerce_optional_bool(policy.get("allow_null"), "allow_null")
 
- def _infer_sql_type(series: pd.Series, dialect: Dialect) -> str:  # noqa: C901
+     return {
+         "prefer_length": prefer_length,
+         "min_length": min_length,
+         "max_length": max_length,
+         "pad": pad,
+         "empty_as_null": empty_as_null,
+         "allow_null": allow_null,
+     }
+
+
+ def _string_series_stats(series: pd.Series, empty_as_null: bool) -> tuple[int, bool]:
+     non_null = series.dropna()
+     if non_null.empty:
+         return 0, False
+
+     as_str = non_null.astype(str)
+     empty_mask = as_str == ""
+     any_empty = bool(empty_mask.any())
+     if empty_as_null:
+         as_str = as_str[~empty_mask]
+         if as_str.empty:
+             return 0, any_empty
+
+     lengths = as_str.map(len)
+     max_length = int(lengths.max()) if not lengths.empty else 0
+     return max_length, any_empty
+
+
+ def _select_string_length(
+     max_length: int,
+     *,
+     prefer_length: str,
+     pad: int,
+     min_length: Optional[int],
+     max_length_bound: Optional[int],
+ ) -> int:
+     if prefer_length == "minimum" and min_length is not None:
+         length = min_length
+     elif prefer_length == "maximum" and max_length_bound is not None:
+         length = max_length_bound
+     else:
+         length = max_length + pad
+
+     if min_length is not None:
+         length = max(length, min_length)
+     if max_length_bound is not None:
+         length = min(length, max_length_bound)
+
+     return max(1, int(length))
+
+
+ def _infer_sql_type(  # noqa: C901
+     series: pd.Series,
+     dialect: Dialect,
+     *,
+     string_policy: Optional[Mapping[str, Any]] = None,
+ ) -> str:
      """Infer a SQL column type for a pandas Series given a dialect."""
      dtype = series.dtype
 
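A small worked example of the length selection implemented by _select_string_length above; the observed length, pad, and bounds are illustrative values, not defaults shipped with the package.

# Illustrative values: longest observed string is 12 characters,
# with policy pad=4, min_length=10, max_length=50.
observed, pad, min_len, max_len = 12, 4, 10, 50

# prefer_length="estimate": observed + pad, then clamped into [min_len, max_len]
estimate = max(1, min(max(observed + pad, min_len), max_len))  # 16
# prefer_length="minimum": start from the configured minimum
minimum = max(1, min(max(min_len, min_len), max_len))          # 10
# prefer_length="maximum": start from the configured maximum
maximum = max(1, min(max(max_len, min_len), max_len))          # 50
print(estimate, minimum, maximum)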
@@ -25,6 +129,7 @@ def _infer_sql_type(series: pd.Series, dialect: Dialect) -> str: # noqa: C901
          return bool(sample.map(_is_bytes_like).all())
 
      if dialect == Dialect.SQLSERVER:
+         # Use SQL Server's tiered integer sizes for best-fit types.
          if pd.api.types.is_integer_dtype(dtype):
              non_null = series.dropna()
              if non_null.empty:
@@ -47,19 +152,25 @@ def _infer_sql_type(series: pd.Series, dialect: Dialect) -> str: # noqa: C901
          if _is_bytes_like_series(series):
              return "varbinary(max)"
          # strings / objects
-         non_null = series.dropna()
-         if non_null.empty:
-             max_length = 1
-         else:
-             lengths = non_null.astype(str).map(len)
-             max_length = int(lengths.max()) if not lengths.empty else 1
-         if pd.api.types.is_object_dtype(dtype) or isinstance(
-             dtype, pd.CategoricalDtype
+         policy = _normalize_string_policy(string_policy)
+         max_length, _ = _string_series_stats(series, policy["empty_as_null"])
+         max_length = _select_string_length(
+             max_length,
+             prefer_length=policy["prefer_length"],
+             pad=policy["pad"],
+             min_length=policy["min_length"],
+             max_length_bound=policy["max_length"],
+         )
+         if (
+             pd.api.types.is_object_dtype(dtype)
+             or pd.api.types.is_string_dtype(dtype)
+             or isinstance(dtype, pd.CategoricalDtype)
          ):
              return f"nvarchar({max_length if max_length <= 4000 else 'max'})"
          return "nvarchar(max)"
 
      if dialect == Dialect.POSTGRES:
+         # PostgreSQL integer sizes are narrower than SQL Server's tinyint.
          if pd.api.types.is_integer_dtype(dtype):
              non_null = series.dropna()
              if non_null.empty:
@@ -79,15 +190,19 @@ def _infer_sql_type(series: pd.Series, dialect: Dialect) -> str: # noqa: C901
              return "timestamp"
          if _is_bytes_like_series(series):
              return "bytea"
-         non_null = series.dropna()
-         if non_null.empty:
-             max_length = 1
-         else:
-             lengths = non_null.astype(str).map(len)
-             max_length = int(lengths.max()) if not lengths.empty else 1
+         policy = _normalize_string_policy(string_policy)
+         max_length, _ = _string_series_stats(series, policy["empty_as_null"])
+         max_length = _select_string_length(
+             max_length,
+             prefer_length=policy["prefer_length"],
+             pad=policy["pad"],
+             min_length=policy["min_length"],
+             max_length_bound=policy["max_length"],
+         )
          return f"varchar({max_length})" if max_length <= 65535 else "text"
 
      if dialect == Dialect.MYSQL:
+         # Keep MySQL type names consistent with the existing DDL outputs.
          if pd.api.types.is_integer_dtype(dtype):
              non_null = series.dropna()
              if non_null.empty:
@@ -107,15 +222,19 @@ def _infer_sql_type(series: pd.Series, dialect: Dialect) -> str: # noqa: C901
              return "DATETIME"
          if _is_bytes_like_series(series):
              return "LONGBLOB"
-         non_null = series.dropna()
-         if non_null.empty:
-             max_length = 1
-         else:
-             lengths = non_null.astype(str).map(len)
-             max_length = int(lengths.max()) if not lengths.empty else 1
+         policy = _normalize_string_policy(string_policy)
+         max_length, _ = _string_series_stats(series, policy["empty_as_null"])
+         max_length = _select_string_length(
+             max_length,
+             prefer_length=policy["prefer_length"],
+             pad=policy["pad"],
+             min_length=policy["min_length"],
+             max_length_bound=policy["max_length"],
+         )
          return f"VARCHAR({max_length})" if max_length <= 65535 else "TEXT"
 
      if dialect == Dialect.DUCKDB:
+         # DuckDB has simplified type names and distinguishes signed/unsigned.
          if pd.api.types.is_integer_dtype(dtype):
              return (
                  "BIGINT" if pd.api.types.is_signed_integer_dtype(dtype) else "UBIGINT"
@@ -136,6 +255,7 @@ def _infer_sql_type(series: pd.Series, dialect: Dialect) -> str: # noqa: C901
  def _qualify_name(
      catalog: Optional[str], schema: Optional[str], table: str, dialect: Dialect
  ) -> str:
+     """Build a fully-qualified table name for the selected dialect."""
      if dialect == Dialect.SQLSERVER:
          # catalog and schema are both used when provided
          if catalog:
@@ -161,6 +281,8 @@ generate_create_table(
      table: str = "table",
      drop_existing: bool = True,
      dialect: Dialect = Dialect.SQLSERVER,
+     use_go: bool = False,
+     schema_spec: Optional[Mapping[str, Any]] = None,
      verbose: bool = False,
  ) -> str:
      """
@@ -173,6 +295,9 @@
          table: Target table name.
          drop_existing: If True, include a DROP TABLE IF EXISTS stanza.
          dialect: SQL dialect.
+         use_go: If True and dialect is SQL Server, insert a GO batch separator
+             after USE when a catalog is provided.
+         schema_spec: Optional schema overrides for column types and string sizing.
          verbose: If True, prints progress messages.
 
      Returns:
@@ -193,17 +318,62 @@
          raise ValueError("catalog, if provided, must be a non-empty string")
      if schema is not None and (not isinstance(schema, str) or not schema.strip()):
          raise ValueError("schema, if provided, must be a non-empty string")
+     if not isinstance(use_go, bool):
+         raise TypeError("use_go must be a boolean")
 
      qualified_name = _qualify_name(catalog, schema, table, dialect)
      lines: List[str] = []
 
      for column in df.columns:
          series = df[column]
+         column_name = str(column)
+         defaults, column_spec = resolve_column_spec(column_name, schema_spec)
+         string_policy = resolve_string_policy(column_name, defaults, column_spec)
+         normalized_policy = _normalize_string_policy(string_policy)
+
          nullable = series.isnull().any()
-         sql_type = _infer_sql_type(series, dialect)
+         if normalized_policy["empty_as_null"] and (
+             pd.api.types.is_object_dtype(series.dtype)
+             or isinstance(series.dtype, pd.CategoricalDtype)
+             or pd.api.types.is_string_dtype(series.dtype)
+         ):
+             _, any_empty = _string_series_stats(series, True)
+             if any_empty:
+                 nullable = True
+
+         nullable_override = _coerce_optional_bool(
+             column_spec.get("nullable"), "nullable"
+         )
+         if nullable_override is None:
+             nullable_override = _coerce_optional_bool(
+                 column_spec.get("allow_null"), "allow_null"
+             )
+         if nullable_override is None:
+             nullable_override = normalized_policy["allow_null"]
+         if nullable_override is None:
+             nullable_override = _coerce_optional_bool(
+                 defaults.get("nullable"), "defaults.nullable"
+             )
+         if nullable_override is None:
+             nullable_override = _coerce_optional_bool(
+                 defaults.get("allow_null"), "defaults.allow_null"
+             )
+         if nullable_override is not None:
+             nullable = nullable_override
+
+         sql_type_override = column_spec.get("sql_type")
+         if sql_type_override is not None:
+             if not isinstance(sql_type_override, str) or not sql_type_override.strip():
+                 raise ValueError(
+                     f"schema_spec.columns['{column_name}'].sql_type must be a non-empty string"
+                 )
+             sql_type = sql_type_override.strip()
+         else:
+             sql_type = _infer_sql_type(series, dialect, string_policy=normalized_policy)
+
          null_str = "NULL" if nullable else "NOT NULL"
          lines.append(
-             f" {quote_identifier(str(column), dialect)} {sql_type} {null_str}"
+             f" {quote_identifier(column_name, dialect)} {sql_type} {null_str}"
          )
 
      body = ",\n".join(lines)
@@ -211,14 +381,19 @@
      drop_clause = ""
      if drop_existing:
          if dialect == Dialect.SQLSERVER:
+             object_id_name = qualified_name.replace("'", "''")
              if catalog:
+                 batch_sep = "GO\n" if use_go else ""
                  drop_clause = (
                      f"USE {quote_identifier(catalog, dialect)};\n"
-                     f"IF OBJECT_ID('{qualified_name}') IS NOT NULL "
+                     f"{batch_sep}IF OBJECT_ID('{object_id_name}') IS NOT NULL "
                      f"DROP TABLE {qualified_name};\n"
                  )
              else:
-                 drop_clause = f"IF OBJECT_ID('{qualified_name}') IS NOT NULL DROP TABLE {qualified_name};\n"
+                 drop_clause = (
+                     f"IF OBJECT_ID('{object_id_name}') IS NOT NULL "
+                     f"DROP TABLE {qualified_name};\n"
+                 )
          else:
              drop_clause = f"DROP TABLE IF EXISTS {qualified_name};\n"
 
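A minimal sketch of how the new use_go and schema_spec parameters above might be called. This is illustrative only: it assumes the DataFrame is the function's first positional argument and that the module is imported by its path in the wheel (datablade.sql.ddl); the DataFrame, catalog, schema, table, and column names are made up for the example.

import pandas as pd

from datablade.sql.ddl import generate_create_table
from datablade.sql.dialects import Dialect

frame = pd.DataFrame({"code": ["A1", "B2", ""], "qty": [1, 2, None]})
spec = {
    # Defaults apply to every column unless a per-column entry overrides them.
    "defaults": {"string": {"prefer_length": "estimate", "pad": 4, "empty_as_null": True}},
    # Per-column override: force an explicit SQL type and nullability.
    "columns": {"qty": {"sql_type": "int", "nullable": True}},
}
ddl = generate_create_table(
    frame,
    catalog="sales",      # illustrative catalog name
    schema="dbo",
    table="orders",
    dialect=Dialect.SQLSERVER,
    use_go=True,          # emit GO after the USE statement in the DROP stanza
    schema_spec=spec,
)
print(ddl)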
@@ -1,17 +1,50 @@
+ """Parquet schema-driven DDL generation using PyArrow."""
+
  from __future__ import annotations
 
  import logging
- from typing import List, Optional
+ import pathlib
+ from dataclasses import dataclass
+ from typing import Any, List, Mapping, Optional, Union
 
  from ..utils.messages import print_verbose
+ from ..utils.strings import coerce_path
  from .ddl import _qualify_name
  from .dialects import Dialect
  from .quoting import quote_identifier
+ from .schema_spec import resolve_column_spec
 
  logger = logging.getLogger("datablade")
 
 
+ @dataclass(frozen=True)
+ class DroppedColumn:
+     """Metadata about a dropped column during Parquet DDL generation."""
+
+     name: str
+     arrow_type: str
+     reason: str
+
+
+ @dataclass(frozen=True)
+ class FallbackColumn:
+     """Metadata about a column handled via JSON fallback."""
+
+     name: str
+     arrow_type: str
+     sql_type: str
+
+
+ @dataclass(frozen=True)
+ class ParquetDDLMetadata:
+     """Details about columns dropped or handled via fallback."""
+
+     dropped_columns: List[DroppedColumn]
+     fallback_columns: List[FallbackColumn]
+
+
  def _require_pyarrow():
+     """Import pyarrow lazily to keep core dependencies light."""
      try:
          import pyarrow as pa  # type: ignore
          import pyarrow.parquet as pq  # type: ignore
@@ -23,6 +56,30 @@ def _require_pyarrow():
      return pa, pq
 
 
+ def _is_complex_arrow_type(data_type) -> bool:
+     pa, _ = _require_pyarrow()
+     return (
+         pa.types.is_struct(data_type)
+         or pa.types.is_list(data_type)
+         or pa.types.is_large_list(data_type)
+         or pa.types.is_fixed_size_list(data_type)
+         or pa.types.is_map(data_type)
+         or pa.types.is_union(data_type)
+     )
+
+
+ def _json_fallback_sql_type(dialect: Dialect) -> str:
+     if dialect == Dialect.SQLSERVER:
+         return "nvarchar(max)"
+     if dialect == Dialect.POSTGRES:
+         return "text"
+     if dialect == Dialect.MYSQL:
+         return "TEXT"
+     if dialect == Dialect.DUCKDB:
+         return "VARCHAR"
+     raise NotImplementedError(f"Dialect not supported: {dialect}")
+
+
  def _sql_type_from_arrow(data_type, dialect: Dialect) -> Optional[str]:  # noqa: C901
      """Map a pyarrow.DataType to a SQL type string.
 
@@ -153,8 +210,8 @@ def _sql_type_from_arrow(data_type, dialect: Dialect) -> Optional[str]: # noqa:
          if pa.types.is_float64(data_type):
              return "DOUBLE"
          if pa.types.is_decimal(data_type):
-             precision = int(data_type.precision)
-             scale = int(data_type.scale)
+             precision = min(int(data_type.precision), 65)
+             scale = min(int(data_type.scale), 30, precision)
              return f"DECIMAL({precision}, {scale})"
          if pa.types.is_date(data_type):
              return "DATE"
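The clamping above keeps generated DECIMAL types inside MySQL's documented limits (precision at most 65, scale at most 30, and scale never above precision). A small worked example of the arithmetic, with an illustrative out-of-range input:

# Illustrative only: how an oversized Arrow decimal would be clamped.
precision, scale = 76, 38                # e.g. a decimal256(76, 38) field
precision = min(precision, 65)           # MySQL's maximum DECIMAL precision
scale = min(scale, 30, precision)        # MySQL's maximum scale, capped by precision
print(f"DECIMAL({precision}, {scale})")  # DECIMAL(65, 30)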
@@ -204,56 +261,116 @@ def _sql_type_from_arrow(data_type, dialect: Dialect) -> Optional[str]: # noqa:
 
 
  def generate_create_table_from_parquet(
-     parquet_path: str,
+     parquet_path: str | pathlib.Path,
      catalog: Optional[str] = None,
      schema: Optional[str] = None,
      table: str = "table",
      drop_existing: bool = True,
      dialect: Dialect = Dialect.SQLSERVER,
+     use_go: bool = False,
+     schema_spec: Optional[Mapping[str, Any]] = None,
      verbose: bool = False,
- ) -> str:
+     fallback_to_json: bool = False,
+     return_metadata: bool = False,
+ ) -> Union[str, tuple[str, ParquetDDLMetadata]]:
      """Generate a CREATE TABLE statement from a Parquet file schema.
 
      This reads the Parquet schema only (via PyArrow) and does not materialize data.
 
      Columns whose Parquet types have no clean mapping for the chosen dialect are
-     dropped, and a warning is logged under logger name 'datablade'.
+     dropped, and a warning is logged under logger name 'datablade'. If
+     fallback_to_json is enabled, complex types are instead mapped to a text
+     column intended to store JSON-encoded values. Use return_metadata to receive
+     details about dropped and fallback-mapped columns.
+
+     When dialect is SQL Server and use_go is True, a GO batch separator is
+     inserted after a USE statement when a catalog is provided.
+
+     schema_spec may provide per-column sql_type/nullable overrides.
      """
 
-     if (
-         parquet_path is None
-         or not isinstance(parquet_path, str)
-         or not parquet_path.strip()
-     ):
-         raise ValueError("parquet_path must be a non-empty string")
+     path_obj = coerce_path(
+         parquet_path,
+         must_exist=True,
+         verbose=verbose,
+         label="parquet_path",
+     )
      if not isinstance(table, str) or not table.strip():
          raise ValueError("table must be a non-empty string")
      if catalog is not None and (not isinstance(catalog, str) or not catalog.strip()):
          raise ValueError("catalog, if provided, must be a non-empty string")
      if schema is not None and (not isinstance(schema, str) or not schema.strip()):
          raise ValueError("schema, if provided, must be a non-empty string")
+     if not isinstance(use_go, bool):
+         raise TypeError("use_go must be a boolean")
 
      _, pq = _require_pyarrow()
 
-     arrow_schema = pq.ParquetFile(parquet_path).schema_arrow
+     # Read Parquet metadata only; this does not load row data.
+     arrow_schema = pq.ParquetFile(path_obj).schema_arrow
 
      qualified_name = _qualify_name(catalog, schema, table, dialect)
      lines: List[str] = []
+     dropped_columns: List[DroppedColumn] = []
+     fallback_columns: List[FallbackColumn] = []
 
      for field in arrow_schema:
-         sql_type = _sql_type_from_arrow(field.type, dialect)
-         if sql_type is None:
-             logger.warning(
-                 "Dropping Parquet column %r (type=%s) for dialect=%s: unsupported type",
-                 field.name,
-                 str(field.type),
-                 dialect.value,
-             )
-             continue
+         column_name = str(field.name)
+         defaults, column_spec = resolve_column_spec(column_name, schema_spec)
+         sql_type_override = column_spec.get("sql_type")
+         if sql_type_override is not None:
+             if not isinstance(sql_type_override, str) or not sql_type_override.strip():
+                 raise ValueError(
+                     f"schema_spec.columns['{column_name}'].sql_type must be a non-empty string"
+                 )
+             sql_type = sql_type_override.strip()
+         else:
+             sql_type = _sql_type_from_arrow(field.type, dialect)
 
-         null_str = "NULL" if field.nullable else "NOT NULL"
+         if sql_type is None:
+             if fallback_to_json and _is_complex_arrow_type(field.type):
+                 fallback_sql_type = _json_fallback_sql_type(dialect)
+                 fallback_columns.append(
+                     FallbackColumn(
+                         name=str(field.name),
+                         arrow_type=str(field.type),
+                         sql_type=fallback_sql_type,
+                     )
+                 )
+                 sql_type = fallback_sql_type
+             else:
+                 dropped_columns.append(
+                     DroppedColumn(
+                         name=str(field.name),
+                         arrow_type=str(field.type),
+                         reason="unsupported type",
+                     )
+                 )
+                 logger.warning(
+                     "Dropping Parquet column %r (type=%s) for dialect=%s: unsupported type",
+                     field.name,
+                     str(field.type),
+                     dialect.value,
+                 )
+                 continue
+
+         nullable = field.nullable
+         for label, value in (
+             ("nullable", column_spec.get("nullable")),
+             ("allow_null", column_spec.get("allow_null")),
+             ("defaults.nullable", defaults.get("nullable")),
+             ("defaults.allow_null", defaults.get("allow_null")),
+         ):
+             if value is None:
+                 continue
+             if not isinstance(value, bool):
+                 raise TypeError(f"{label} must be a boolean")
+             nullable = value
+             break
+
+         null_str = "NULL" if nullable else "NOT NULL"
          lines.append(
-             f" {quote_identifier(str(field.name), dialect)} {sql_type} {null_str}"
+             f" {quote_identifier(column_name, dialect)} {sql_type} {null_str}"
          )
 
      if not lines:
@@ -266,15 +383,17 @@ def generate_create_table_from_parquet(
      drop_clause = ""
      if drop_existing:
          if dialect == Dialect.SQLSERVER:
+             object_id_name = qualified_name.replace("'", "''")
              if catalog:
+                 batch_sep = "GO\n" if use_go else ""
                  drop_clause = (
                      f"USE {quote_identifier(catalog, dialect)};\n"
-                     f"IF OBJECT_ID('{qualified_name}') IS NOT NULL "
+                     f"{batch_sep}IF OBJECT_ID('{object_id_name}') IS NOT NULL "
                      f"DROP TABLE {qualified_name};\n"
                  )
              else:
                  drop_clause = (
-                     f"IF OBJECT_ID('{qualified_name}') IS NOT NULL "
+                     f"IF OBJECT_ID('{object_id_name}') IS NOT NULL "
                      f"DROP TABLE {qualified_name};\n"
                  )
          else:
@@ -284,4 +403,9 @@ def generate_create_table_from_parquet(
      print_verbose(
          f"Generated CREATE TABLE from Parquet schema for {qualified_name}", verbose
      )
+     if return_metadata:
+         metadata = ParquetDDLMetadata(
+             dropped_columns=dropped_columns, fallback_columns=fallback_columns
+         )
+         return statement, metadata
      return statement
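A minimal sketch of the new fallback_to_json and return_metadata behaviour described in the docstring above. The Parquet path, schema, and table names are illustrative, and the exact import path for generate_create_table_from_parquet depends on where this module lives in the package (its filename is not shown in this diff):

from datablade.sql.dialects import Dialect
# from datablade.sql... import generate_create_table_from_parquet  # module path not shown in this diff

ddl, meta = generate_create_table_from_parquet(
    "data/events.parquet",   # illustrative path
    schema="analytics",
    table="events",
    dialect=Dialect.POSTGRES,
    fallback_to_json=True,   # map struct/list/map columns to text instead of dropping
    return_metadata=True,    # also return ParquetDDLMetadata
)
for col in meta.fallback_columns:
    print(f"{col.name}: {col.arrow_type} -> {col.sql_type}")
for col in meta.dropped_columns:
    print(f"dropped {col.name} ({col.arrow_type}): {col.reason}")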
datablade/sql/dialects.py CHANGED
@@ -1,3 +1,5 @@
+ """Enumeration of SQL dialects supported by datablade."""
+
  from enum import Enum
 
 
datablade/sql/quoting.py CHANGED
@@ -1,3 +1,5 @@
+ """Identifier quoting for supported SQL dialects."""
+
  from typing import Optional
 
  from .dialects import Dialect
datablade/sql/schema_spec.py ADDED
@@ -0,0 +1,65 @@
+ """Schema specification helpers for DDL generation."""
+
+ from __future__ import annotations
+
+ from collections.abc import Mapping
+ from typing import Any, Optional, Tuple
+
+
+ def _as_mapping(value: Any, label: str) -> dict:
+     if value is None:
+         return {}
+     if not isinstance(value, Mapping):
+         raise TypeError(f"{label} must be a mapping")
+     return dict(value)
+
+
+ def resolve_schema_spec(
+     schema_spec: Optional[Mapping[str, Any]],
+ ) -> Tuple[dict, dict]:
+     """Return (defaults, columns) mappings for a schema spec."""
+     if schema_spec is None:
+         return {}, {}
+     if not isinstance(schema_spec, Mapping):
+         raise TypeError("schema_spec must be a mapping")
+
+     defaults = _as_mapping(schema_spec.get("defaults"), "schema_spec.defaults")
+     columns = _as_mapping(schema_spec.get("columns"), "schema_spec.columns")
+     return defaults, columns
+
+
+ def resolve_column_spec(
+     column_name: str,
+     schema_spec: Optional[Mapping[str, Any]],
+ ) -> Tuple[dict, dict]:
+     """Return (defaults, column_spec) for a column name."""
+     defaults, columns = resolve_schema_spec(schema_spec)
+     if not columns:
+         return defaults, {}
+
+     column_spec = columns.get(column_name)
+     if column_spec is None:
+         column_spec = columns.get(str(column_name))
+
+     if column_spec is None:
+         return defaults, {}
+     if not isinstance(column_spec, Mapping):
+         raise TypeError(f"schema_spec.columns['{column_name}'] must be a mapping")
+     return defaults, dict(column_spec)
+
+
+ def resolve_string_policy(
+     column_name: str,
+     defaults: dict,
+     column_spec: dict,
+ ) -> dict:
+     """Merge defaults + column string policy overrides."""
+     string_defaults = _as_mapping(defaults.get("string"), "schema_spec.defaults.string")
+     string_overrides = _as_mapping(
+         column_spec.get("string"),
+         f"schema_spec.columns['{column_name}'].string",
+     )
+     policy = {**string_defaults, **string_overrides}
+     if "defined_pad" in policy and "pad" not in policy:
+         policy["pad"] = policy["defined_pad"]
+     return policy
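To show how these helpers compose, a short sketch of the merge performed by resolve_string_policy; the spec contents and column name are illustrative:

from datablade.sql.schema_spec import resolve_column_spec, resolve_string_policy

spec = {
    "defaults": {"string": {"prefer_length": "estimate", "pad": 8}},
    "columns": {"email": {"string": {"max_length": 320, "pad": 0}}},
}

defaults, column_spec = resolve_column_spec("email", spec)
policy = resolve_string_policy("email", defaults, column_spec)
# Per-column settings win over the defaults:
print(policy)  # {'prefer_length': 'estimate', 'pad': 0, 'max_length': 320}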