colnade_dask-0.3.1.tar.gz

--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,24 @@
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.egg-info/
+ dist/
+ build/
+ *.egg
+ .eggs/
+ *.so
+ .mypy_cache/
+ .ruff_cache/
+ .pytest_cache/
+ .coverage
+ htmlcov/
+ *.parquet
+ *.csv
+ !tests/fixtures/*.csv
+ !tests/fixtures/*.parquet
+ .nox/
+ .venv/
+ venv/
+ env/
+ .env
+ site/
--- /dev/null
+++ b/PKG-INFO
@@ -0,0 +1,72 @@
+ Metadata-Version: 2.4
+ Name: colnade-dask
+ Version: 0.3.1
+ Summary: Dask backend adapter for Colnade
+ Project-URL: Homepage, https://colnade.com
+ Project-URL: Documentation, https://colnade.com
+ Project-URL: Repository, https://github.com/jwde/colnade
+ Author: jwde
+ License-Expression: MIT
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Programming Language :: Python :: 3.13
+ Classifier: Topic :: Software Development :: Libraries
+ Classifier: Typing :: Typed
+ Requires-Python: >=3.10
+ Requires-Dist: colnade-pandas>=0.3.0
+ Requires-Dist: colnade>=0.5.0
+ Requires-Dist: dask[dataframe]>=2024.1
+ Requires-Dist: pyarrow>=12.0
+ Description-Content-Type: text/markdown
+
+ # colnade-dask
+
+ Dask backend adapter for [Colnade](https://github.com/jwde/colnade). Supports lazy evaluation and distributed computation.
+
+ ## Installation
+
+ ```bash
+ pip install colnade-dask
+ ```
+
+ ## Usage
+
+ ```python
+ from colnade import Column, Schema, UInt64, Float64, Utf8
+ from colnade_dask import scan_parquet
+
+ class Users(Schema):
+     id: Column[UInt64]
+     name: Column[Utf8]
+     age: Column[UInt64]
+     score: Column[Float64]
+
+ lf = scan_parquet("users.parquet", Users)
+ result = lf.filter(Users.age > 25).sort(Users.score.desc()).collect()
+ ```
+
+ ## I/O Functions
+
+ - `read_parquet` / `write_parquet` (eager)
+ - `scan_parquet` / `scan_csv` (lazy)
+ - `read_csv` / `write_csv` (eager)
+
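+ A minimal eager-mode sketch using the `Users` schema from above (file names are illustrative); `read_parquet` validates data against the schema on load when validation is enabled, and `from_dict` coerces plain Python values to the schema's dtypes:
+
+ ```python
+ from colnade_dask import from_dict, read_parquet, write_parquet
+
+ df = read_parquet("users.parquet", Users)
+ write_parquet(df, "users_copy.parquet")
+
+ small = from_dict(Users, {"id": [1, 2], "name": ["a", "b"], "age": [30, 40], "score": [1.0, 2.0]})
+ ```
+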
+ ## Documentation
+
+ Full documentation at [colnade.com](https://colnade.com/).
--- /dev/null
+++ b/README.md
@@ -0,0 +1,46 @@
+ # colnade-dask
+
+ Dask backend adapter for [Colnade](https://github.com/jwde/colnade). Supports lazy evaluation and distributed computation.
+
+ ## Installation
+
+ ```bash
+ pip install colnade-dask
+ ```
+
+ ## Usage
+
+ ```python
+ from colnade import Column, Schema, UInt64, Float64, Utf8
+ from colnade_dask import scan_parquet
+
+ class Users(Schema):
+     id: Column[UInt64]
+     name: Column[Utf8]
+     age: Column[UInt64]
+     score: Column[Float64]
+
+ lf = scan_parquet("users.parquet", Users)
+ result = lf.filter(Users.age > 25).sort(Users.score.desc()).collect()
+ ```
+
+ ## I/O Functions
+
+ - `read_parquet` / `write_parquet` (eager)
+ - `scan_parquet` / `scan_csv` (lazy)
+ - `read_csv` / `write_csv` (eager)
+
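+ A minimal eager-mode sketch using the `Users` schema from above (file names are illustrative); `read_parquet` validates data against the schema on load when validation is enabled, and `from_dict` coerces plain Python values to the schema's dtypes:
+
+ ```python
+ from colnade_dask import from_dict, read_parquet, write_parquet
+
+ df = read_parquet("users.parquet", Users)
+ write_parquet(df, "users_copy.parquet")
+
+ small = from_dict(Users, {"id": [1, 2], "name": ["a", "b"], "age": [30, 40], "score": [1.0, 2.0]})
+ ```
+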
+ ## Documentation
+
+ Full documentation at [colnade.com](https://colnade.com/).
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,40 @@
+ [project]
+ name = "colnade-dask"
+ version = "0.3.1"
+ description = "Dask backend adapter for Colnade"
+ requires-python = ">=3.10"
+ license = "MIT"
+ readme = "README.md"
+ authors = [
+     { name = "jwde" },
+ ]
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Intended Audience :: Developers",
+     "License :: OSI Approved :: MIT License",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "Programming Language :: Python :: 3.13",
+     "Typing :: Typed",
+     "Topic :: Software Development :: Libraries",
+ ]
+ dependencies = [
+     "colnade>=0.5.0",
+     "colnade-pandas>=0.3.0",
+     "dask[dataframe]>=2024.1",
+     "pyarrow>=12.0",
+ ]
+
+ [project.urls]
+ Homepage = "https://colnade.com"
+ Documentation = "https://colnade.com"
+ Repository = "https://github.com/jwde/colnade"
+
+ [build-system]
+ requires = ["hatchling"]
+ build-backend = "hatchling.build"
+
+ [tool.hatch.build.targets.wheel]
+ packages = ["src/colnade_dask"]
--- /dev/null
+++ b/src/colnade_dask/__init__.py
@@ -0,0 +1,25 @@
+ """Colnade Dask backend adapter."""
+
+ from colnade_dask.adapter import DaskBackend
+ from colnade_dask.io import (
+     from_dict,
+     from_rows,
+     read_csv,
+     read_parquet,
+     scan_csv,
+     scan_parquet,
+     write_csv,
+     write_parquet,
+ )
+
+ __all__ = [
+     "DaskBackend",
+     "from_dict",
+     "from_rows",
+     "read_csv",
+     "read_parquet",
+     "scan_csv",
+     "scan_parquet",
+     "write_csv",
+     "write_parquet",
+ ]
--- /dev/null
+++ b/src/colnade_dask/adapter.py
@@ -0,0 +1,671 @@
+ """DaskBackend — translates Colnade expression trees and executes operations on Dask DataFrames."""
+
+ from __future__ import annotations
+
+ import types as _types
+ from collections.abc import Iterator, Sequence
+ from typing import Any
+
+ import dask.dataframe as dd
+ import pandas as pd
+
+ from colnade.expr import (
+     Agg,
+     AliasedExpr,
+     BinOp,
+     ColumnRef,
+     Expr,
+     FunctionCall,
+     ListOp,
+     Literal,
+     SortExpr,
+     StructFieldAccess,
+     UnaryOp,
+ )
+ from colnade.schema import Column, Schema, SchemaError
+ from colnade_pandas.conversion import map_colnade_dtype, map_pandas_dtype
+
+ # ---------------------------------------------------------------------------
+ # BinOp operator dispatch
+ # ---------------------------------------------------------------------------
+
+ _BINOP_MAP: dict[str, str] = {
+     "+": "__add__",
+     "-": "__sub__",
+     "*": "__mul__",
+     "/": "__truediv__",
+     "%": "__mod__",
+     ">": "__gt__",
+     "<": "__lt__",
+     ">=": "__ge__",
+     "<=": "__le__",
+     "==": "__eq__",
+     "!=": "__ne__",
+     "&": "__and__",
+     "|": "__or__",
+ }
+
+ # ---------------------------------------------------------------------------
+ # Agg function name mapping (Colnade → Pandas/Dask)
+ # ---------------------------------------------------------------------------
+
+ _AGG_MAP: dict[str, str] = {
+     "sum": "sum",
+     "mean": "mean",
+     "min": "min",
+     "max": "max",
+     "count": "count",
+     "first": "first",
+     "last": "last",
+ }
+
+ # ---------------------------------------------------------------------------
+ # DaskBackend
+ # ---------------------------------------------------------------------------
+
+
+ class DaskBackend:
+     """Colnade backend adapter for Dask.
+
+     Expression translation produces callables ``(df) -> Series | scalar``
+     since Dask, like Pandas, has no standalone lazy expression API. The
+     callables build lazy Dask task graphs instead of executing immediately.
+     """
+
+     # --- Expression translation ---
+
+     def translate_expr(self, expr: Expr[Any]) -> Any:
+         """Recursively translate a Colnade AST node to a callable (df -> result)."""
+         if isinstance(expr, AliasedExpr):
+             inner_fn = self.translate_expr(expr.expr)
+             target_name = expr.target.name
+             return (inner_fn, target_name)
+
+         if isinstance(expr, ColumnRef):
+             col_name = expr.column.name
+             return lambda df, _cn=col_name: df[_cn]
+
+         if isinstance(expr, Literal):
+             val = expr.value
+             return lambda df, _v=val: _v
+
+         if isinstance(expr, BinOp):
+             left_fn = self._ensure_callable(self.translate_expr(expr.left))
+             right_fn = self._ensure_callable(self.translate_expr(expr.right))
+             method = _BINOP_MAP.get(expr.op)
+             if method is None:
+                 msg = f"Unsupported BinOp operator: {expr.op}"
+                 raise ValueError(msg)
+             return lambda df, _l=left_fn, _r=right_fn, _m=method: getattr(_l(df), _m)(_r(df))
+
+         if isinstance(expr, UnaryOp):
+             operand_fn = self._ensure_callable(self.translate_expr(expr.operand))
+             if expr.op == "-":
+                 return lambda df, _o=operand_fn: -_o(df)
+             if expr.op == "~":
+                 return lambda df, _o=operand_fn: ~_o(df)
+             if expr.op == "is_null":
+                 return lambda df, _o=operand_fn: _o(df).isna()
+             if expr.op == "is_not_null":
+                 return lambda df, _o=operand_fn: _o(df).notna()
+             if expr.op == "is_nan":
+                 return lambda df, _o=operand_fn: _o(df).isna()
+             msg = f"Unsupported UnaryOp: {expr.op}"
+             raise ValueError(msg)
+
+         if isinstance(expr, Agg):
+             source_fn = self._ensure_callable(self.translate_expr(expr.source))
+             agg_name = _AGG_MAP.get(expr.agg_type)
+             if agg_name is None:
+                 msg = f"Unsupported aggregation: {expr.agg_type}"
+                 raise ValueError(msg)
+             return (source_fn, agg_name)
+
+         if isinstance(expr, FunctionCall):
+             return self._translate_function_call(expr)
+
+         if isinstance(expr, StructFieldAccess):
+             struct_fn = self._ensure_callable(self.translate_expr(expr.struct_expr))
+             field_name = expr.field.name
+             return lambda df, _s=struct_fn, _f=field_name: _s(df).apply(lambda x: x.get(_f))
+
+         if isinstance(expr, ListOp):
+             return self._translate_list_op(expr)
+
+         msg = f"Unsupported expression type: {type(expr).__name__}"
+         raise TypeError(msg)
+
+     def _ensure_callable(self, translated: Any) -> Any:
+         """Unwrap (fn, alias) tuples from AliasedExpr to get the callable."""
+         if isinstance(translated, tuple):
+             return translated[0]
+         return translated
+
+     def _translate_function_call(self, expr: FunctionCall[Any]) -> Any:
+         """Translate a FunctionCall node to a Dask callable."""
+         name = expr.name
+
+         # String methods
+         if name == "str_contains":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             pattern = expr.args[1]
+             return lambda df, _s=source_fn, _p=pattern: _s(df).str.contains(_p, regex=False)
+         if name == "str_starts_with":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             pattern = expr.args[1]
+             return lambda df, _s=source_fn, _p=pattern: _s(df).str.startswith(_p)
+         if name == "str_ends_with":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             pattern = expr.args[1]
+             return lambda df, _s=source_fn, _p=pattern: _s(df).str.endswith(_p)
+         if name == "str_len":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).str.len()
+         if name == "str_to_lowercase":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).str.lower()
+         if name == "str_to_uppercase":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).str.upper()
+         if name == "str_strip":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).str.strip()
+         if name == "str_replace":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             old, new = expr.args[1], expr.args[2]
+             return lambda df, _s=source_fn, _o=old, _n=new: _s(df).str.replace(_o, _n, regex=False)
+
+         # Temporal methods
+         if name == "dt_year":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).dt.year
+         if name == "dt_month":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).dt.month
+         if name == "dt_day":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).dt.day
+         if name == "dt_hour":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).dt.hour
+         if name == "dt_minute":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).dt.minute
+         if name == "dt_second":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             return lambda df, _s=source_fn: _s(df).dt.second
+         if name == "dt_truncate":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             unit = expr.args[1]
+             return lambda df, _s=source_fn, _u=unit: _s(df).dt.floor(_u)
+
+         # Null/NaN handling
+         if name == "fill_null":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             fill_fn = self._ensure_callable(self.translate_expr(expr.args[1]))
+             return lambda df, _s=source_fn, _f=fill_fn: _s(df).fillna(_f(df))
+         if name == "fill_nan":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             fill_fn = self._ensure_callable(self.translate_expr(expr.args[1]))
+             return lambda df, _s=source_fn, _f=fill_fn: _s(df).fillna(_f(df))
+         if name == "assert_non_null":
+             return self._ensure_callable(self.translate_expr(expr.args[0]))
+
+         # Cast
+         if name == "cast":
+             source_fn = self._ensure_callable(self.translate_expr(expr.args[0]))
+             target_dtype = map_colnade_dtype(expr.kwargs["dtype"])
+             return lambda df, _s=source_fn, _t=target_dtype: _s(df).astype(_t)
+
+         # Window function (over)
+         if name == "over":
+             translated = self.translate_expr(expr.args[0])
+             partition_fns = [self._ensure_callable(self.translate_expr(a)) for a in expr.args[1:]]
+             # Agg sources arrive as (source_fn, agg_name) tuples; apply the agg
+             # per partition group. A bare column source broadcasts unchanged.
+             source_fn, agg_name = translated if isinstance(translated, tuple) else (translated, None)
+             return lambda df, _s=source_fn, _p=partition_fns, _a=agg_name: df.groupby(
+                 [p(df).name for p in _p]
+             )[_s(df).name].transform(_a or (lambda x: x))
+
+         msg = f"Unsupported FunctionCall: {name}"
+         raise ValueError(msg)
+
+     def _translate_list_op(self, expr: ListOp[Any]) -> Any:
+         """Translate a ListOp node to a Dask callable."""
+         list_fn = self._ensure_callable(self.translate_expr(expr.list_expr))
+         op = expr.op
+
+         if op == "len":
+             return lambda df, _l=list_fn: _l(df).apply(len, meta=(_l(df).name, "int64"))
+         if op == "get":
+             idx = expr.args[0]
+             return lambda df, _l=list_fn, _i=idx: _l(df).apply(
+                 lambda x: x[_i], meta=(_l(df).name, "object")
+             )
+         if op == "contains":
+             val = expr.args[0]
+             return lambda df, _l=list_fn, _v=val: _l(df).apply(
+                 lambda x: _v in x, meta=(_l(df).name, "bool")
+             )
+         if op == "sum":
+             return lambda df, _l=list_fn: _l(df).apply(sum, meta=(_l(df).name, "float64"))
+         if op == "mean":
+             return lambda df, _l=list_fn: _l(df).apply(
+                 lambda x: sum(x) / len(x) if len(x) > 0 else None,
+                 meta=(_l(df).name, "float64"),
+             )
+         if op == "min":
+             return lambda df, _l=list_fn: _l(df).apply(min, meta=(_l(df).name, "object"))
+         if op == "max":
+             return lambda df, _l=list_fn: _l(df).apply(max, meta=(_l(df).name, "object"))
+
+         msg = f"Unsupported ListOp: {op}"
+         raise ValueError(msg)
+
+     # --- Execution methods ---
+
+     def filter(self, source: Any, predicate: Expr[Any]) -> Any:
+         pred_fn = self._ensure_callable(self.translate_expr(predicate))
+         mask = pred_fn(source)
+         return source.loc[mask].reset_index(drop=True)
+
+     def sort(
+         self,
+         source: Any,
+         by: Sequence[Column[Any] | SortExpr],
+         descending: bool,
+     ) -> Any:
+         col_names: list[str] = []
+         ascending: list[bool] = []
+         for item in by:
+             if isinstance(item, SortExpr):
+                 col_names.append(item.expr.column.name)
+                 ascending.append(not item.descending)
+             else:
+                 col_names.append(item.name)
+                 ascending.append(not descending)
+         return source.sort_values(by=col_names, ascending=ascending).reset_index(drop=True)
+
+     def limit(self, source: Any, n: int) -> Any:
+         return source.head(n, npartitions=-1, compute=False)
+
+     def head(self, source: Any, n: int) -> Any:
+         return source.head(n, npartitions=-1, compute=False)
+
+     def tail(self, source: Any, n: int) -> Any:
+         # Dask tail() returns a Pandas DF — re-wrap in Dask
+         return dd.from_pandas(source.tail(n), npartitions=1)
+
+     def sample(self, source: Any, n: int) -> Any:
+         # Dask doesn't support sample(n=...), only sample(frac=...)
+         # Compute to Pandas, sample, and re-wrap
+         computed = source.compute()
+         sampled = computed.sample(n).reset_index(drop=True)
+         return dd.from_pandas(sampled, npartitions=1)
+
+     def unique(self, source: Any, columns: Sequence[Column[Any]]) -> Any:
+         return source.drop_duplicates(subset=[c.name for c in columns]).reset_index(drop=True)
+
+     def drop_nulls(self, source: Any, columns: Sequence[Column[Any]]) -> Any:
+         return source.dropna(subset=[c.name for c in columns]).reset_index(drop=True)
+
+     def with_columns(self, source: Any, exprs: Sequence[AliasedExpr[Any] | Expr[Any]]) -> Any:
+         result = source
+         for expr in exprs:
+             translated = self.translate_expr(expr)
+             if isinstance(translated, tuple):
+                 fn, alias = translated
+                 result = result.assign(**{alias: fn(result)})
+             else:
+                 msg = "with_columns requires aliased expressions"
+                 raise ValueError(msg)
+         return result
+
+     def select(self, source: Any, columns: Sequence[Column[Any]]) -> Any:
+         return source[[c.name for c in columns]]
+
+     def group_by_agg(
+         self,
+         source: Any,
+         keys: Sequence[Column[Any]],
+         aggs: Sequence[AliasedExpr[Any]],
+     ) -> Any:
+         key_names = [k.name for k in keys]
+         agg_dict: dict[str, str] = {}
+         rename_map: dict[str, str] = {}
+
+         for agg_expr in aggs:
+             translated = self.translate_expr(agg_expr)
+             if not isinstance(translated, tuple):
+                 msg = "group_by_agg requires aliased aggregation expressions"
+                 raise ValueError(msg)
+
+             inner, alias = translated
+             if isinstance(inner, tuple):
+                 source_fn, agg_name = inner
+                 col_name = self._extract_col_name(source_fn, source)
+                 agg_dict[col_name] = agg_name
+                 rename_map[col_name] = alias
+             else:
+                 msg = "group_by_agg requires aggregation expressions (e.g., .sum(), .mean())"
+                 raise ValueError(msg)
+
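+         # e.g. a mean over "score" aliased to "avg_score" yields
+         # agg_dict={"score": "mean"} and rename_map={"score": "avg_score"}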
+         grouped = source.groupby(key_names).agg(agg_dict).reset_index()
+         for old_name, new_name in rename_map.items():
+             if old_name != new_name and old_name not in key_names:
+                 grouped = grouped.rename(columns={old_name: new_name})
+         return grouped
+
+     def agg(self, source: Any, aggs: Sequence[AliasedExpr[Any]]) -> Any:
+         result: dict[str, Any] = {}
+         for agg_expr in aggs:
+             translated = self.translate_expr(agg_expr)
+             inner, alias = translated
+             source_fn, agg_name = inner
+             col_name = self._extract_col_name(source_fn, source)
+             val = getattr(source[col_name], agg_name)()
+             result[alias] = val.compute() if hasattr(val, "compute") else val
+         return dd.from_pandas(pd.DataFrame([result]), npartitions=1)
+
+     def _extract_col_name(self, fn: Any, df: Any) -> str:
+         """Extract column name from a translated ColumnRef function."""
+         series = fn(df)
+         if hasattr(series, "name"):
+             return series.name
+         msg = "Cannot extract column name from expression"
+         raise ValueError(msg)
+
+     @staticmethod
+     def _dtypes_compatible(actual: Any, expected: Any) -> bool:
+         """Compare dtypes accounting for storage backend variations.
+
+         Dask may use different storage backends (e.g., pyarrow vs python)
+         for the same logical dtype, so strict equality can fail.
+         """
+         if actual == expected:
+             return True
+         # StringDtype with different storage backends (python vs pyarrow)
+         return isinstance(actual, pd.StringDtype) and isinstance(expected, pd.StringDtype)
+
+     def join(self, left: Any, right: Any, on: Any, how: str) -> Any:
+         return left.merge(right, left_on=on.left.name, right_on=on.right.name, how=how)
+
+     def cast_schema(self, source: Any, column_mapping: dict[str, str]) -> Any:
+         rename_map = {src: tgt for tgt, src in column_mapping.items()}
+         result = source.rename(columns=rename_map)
+         return result[list(column_mapping.keys())]
+
+     def lazy(self, source: Any) -> Any:
+         # Dask is inherently lazy — passthrough
+         return source
+
+     def collect(self, source: Any) -> Any:
+         # Materialize the Dask task graph and re-wrap so DaskBackend
+         # can continue operating on the result.
+         return dd.from_pandas(source.compute(), npartitions=1)
+
+     def validate_schema(self, source: Any, schema: type[Schema]) -> None:
+         """Validate that a Dask DataFrame matches the schema."""
+         expected_columns = schema._columns
+         actual_names = set(source.columns)
+
+         missing = [n for n in expected_columns if n not in actual_names]
+         type_mismatches: dict[str, tuple[str, str]] = {}
+
+         for col_name, col in expected_columns.items():
+             if col_name not in actual_names:
+                 continue
+             actual_pd_dtype = source[col_name].dtype
+             expected_pd_dtype = map_colnade_dtype(col.dtype)
+             if not self._dtypes_compatible(actual_pd_dtype, expected_pd_dtype):
+                 try:
+                     actual_colnade = map_pandas_dtype(actual_pd_dtype).__name__
+                 except TypeError:
+                     actual_colnade = str(actual_pd_dtype)
+                 expected_name = (
+                     col.dtype.__name__ if hasattr(col.dtype, "__name__") else str(col.dtype)
+                 )
+                 type_mismatches[col_name] = (expected_name, actual_colnade)
+
+         # Null checks require computation in Dask
+         null_violations: list[str] = []
+         for col_name, col in expected_columns.items():
+             if col_name not in actual_names:
+                 continue
+             if isinstance(col.dtype, _types.UnionType):
+                 continue
+             if source[col_name].isna().any().compute():
+                 null_violations.append(col_name)
+
+         if missing or type_mismatches or null_violations:
+             raise SchemaError(
+                 missing_columns=missing if missing else None,
+                 type_mismatches=type_mismatches if type_mismatches else None,
+                 null_violations=null_violations if null_violations else None,
+             )
+
+     def validate_field_constraints(self, source: Any, schema: type[Schema]) -> None:
+         """Validate value-level constraints (Field(), @schema_check) on data."""
+         from colnade.constraints import ValueViolation, get_column_constraints, get_schema_checks
+
+         constraints = get_column_constraints(schema)
+         checks = get_schema_checks(schema)
+         if not constraints and not checks:
+             return
+
+         # Materialize Dask DataFrame for value checks
+         pdf = source.compute() if hasattr(source, "compute") else source
+         violations: list[ValueViolation] = []
+
+         for col_name, field_info in constraints.items():
+             if col_name not in pdf.columns:
+                 continue
+             series = pdf[col_name].dropna()
+
+             if field_info.ge is not None:
+                 mask = series < field_info.ge
+                 count = int(mask.sum())
+                 if count > 0:
+                     samples = series[mask].head(5).tolist()
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint=f"ge={field_info.ge!r}",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+             if field_info.gt is not None:
+                 mask = series <= field_info.gt
+                 count = int(mask.sum())
+                 if count > 0:
+                     samples = series[mask].head(5).tolist()
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint=f"gt={field_info.gt!r}",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+             if field_info.le is not None:
+                 mask = series > field_info.le
+                 count = int(mask.sum())
+                 if count > 0:
+                     samples = series[mask].head(5).tolist()
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint=f"le={field_info.le!r}",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+             if field_info.lt is not None:
+                 mask = series >= field_info.lt
+                 count = int(mask.sum())
+                 if count > 0:
+                     samples = series[mask].head(5).tolist()
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint=f"lt={field_info.lt!r}",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+             if field_info.min_length is not None:
+                 lengths = series.str.len()
+                 mask = lengths < field_info.min_length
+                 count = int(mask.sum())
+                 if count > 0:
+                     samples = series[mask].head(5).tolist()
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint=f"min_length={field_info.min_length}",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+             if field_info.max_length is not None:
+                 lengths = series.str.len()
+                 mask = lengths > field_info.max_length
+                 count = int(mask.sum())
+                 if count > 0:
+                     samples = series[mask].head(5).tolist()
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint=f"max_length={field_info.max_length}",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+             if field_info.pattern is not None:
+                 matches = series.str.contains(field_info.pattern, regex=True, na=False)
+                 mask = ~matches
+                 count = int(mask.sum())
+                 if count > 0:
+                     samples = series[mask].head(5).tolist()
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint=f"pattern={field_info.pattern!r}",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+             if field_info.unique:
+                 dup_mask = series.duplicated(keep=False)
+                 count = int(dup_mask.sum())
+                 if count > 0:
+                     samples = series[dup_mask].unique().tolist()[:5]
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint="unique",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+             if field_info.isin is not None:
+                 mask = ~series.isin(field_info.isin)
+                 count = int(mask.sum())
+                 if count > 0:
+                     samples = series[mask].unique().tolist()[:5]
+                     violations.append(
+                         ValueViolation(
+                             column=col_name,
+                             constraint=f"isin={list(field_info.isin)!r}",
+                             got_count=count,
+                             sample_values=samples,
+                         )
+                     )
+
+         # Cross-column schema checks
+         for check in checks:
+             expr = check.fn(schema)
+             pd_mask = self.translate_expr(expr)
+             violation_count = int((~pd_mask(pdf)).sum())
+             if violation_count > 0:
+                 violations.append(
+                     ValueViolation(
+                         column="(cross-column)",
+                         constraint=f"schema_check:{check.name}",
+                         got_count=violation_count,
+                         sample_values=[],
+                     )
+                 )
+
+         if violations:
+             raise SchemaError(value_violations=violations)
+
+     # --- Introspection ---
+
+     def row_count(self, source: Any) -> int:
+         return len(source)
+
+     def iter_row_dicts(self, source: Any) -> Iterator[dict[str, Any]]:
+         # to_dict() returns a list; wrap in iter() to honor the declared Iterator type
+         return iter(source.compute().to_dict(orient="records"))
+
+     # --- Arrow boundary ---
+
+     def to_arrow_batches(
+         self,
+         source: Any,
+         batch_size: int | None,
+     ) -> Iterator[Any]:
+         """Convert a Dask DataFrame to an iterator of Arrow RecordBatches."""
+         import pyarrow as pa
+
+         # Iterate partitions for memory-efficient conversion
+         for partition in source.partitions:
+             pdf: pd.DataFrame = partition.compute()
+             table: pa.Table = pa.Table.from_pandas(pdf)
+             if batch_size is not None:
+                 yield from table.to_batches(max_chunksize=batch_size)
+             else:
+                 yield from table.to_batches()
+
+     def from_arrow_batches(
+         self,
+         batches: Iterator[Any],
+         schema: type[Any],
+     ) -> Any:
+         """Reconstruct a Dask DataFrame from Arrow RecordBatches."""
+         import pyarrow as pa
+
+         batch_list = list(batches)
+         if not batch_list:
+             return dd.from_pandas(pd.DataFrame(), npartitions=1)
+         table = pa.Table.from_batches(batch_list)
+         pdf = table.to_pandas()
+         return dd.from_pandas(pdf, npartitions=1)
+
+     def from_dict(
+         self,
+         data: dict[str, Sequence[Any]],
+         schema: type[Schema],
+     ) -> Any:
+         """Create a Dask DataFrame from a columnar dict with schema-driven dtypes."""
+         pd_schema = {name: map_colnade_dtype(col.dtype) for name, col in schema._columns.items()}
+         pdf = pd.DataFrame(data).astype(pd_schema)
+         return dd.from_pandas(pdf, npartitions=1)
--- /dev/null
+++ b/src/colnade_dask/io.py
@@ -0,0 +1,102 @@
+ """Read/write operations for Dask backend."""
+
+ from __future__ import annotations
+
+ from collections.abc import Sequence
+ from typing import Any, TypeVar
+
+ import dask.dataframe as dd
+
+ from colnade import DataFrame, LazyFrame, Row, Schema
+ from colnade.dataframe import rows_to_dict
+ from colnade.validation import ValidationLevel, get_validation_level, is_validation_enabled
+ from colnade_dask.adapter import DaskBackend
+ from colnade_pandas.conversion import map_colnade_dtype
+
+ S = TypeVar("S", bound=Schema)
+
+
+ def _build_pandas_schema(schema: type[S]) -> dict[str, Any]:
+     """Build a Pandas dtype dict from a Colnade schema."""
+     return {name: map_colnade_dtype(col.dtype) for name, col in schema._columns.items()}
+
+
+ def read_parquet(path: str, schema: type[S], **kwargs: Any) -> DataFrame[S]:
+     """Read a Parquet file into a typed DataFrame backed by Dask."""
+     backend = DaskBackend()
+     data = dd.read_parquet(path, **kwargs)
+     if is_validation_enabled():
+         backend.validate_schema(data, schema)
+         if get_validation_level() is ValidationLevel.FULL:
+             backend.validate_field_constraints(data, schema)
+     return DataFrame(_data=data, _schema=schema, _backend=backend)
+
+
+ def read_csv(path: str, schema: type[S], **kwargs: Any) -> DataFrame[S]:
+     """Read a CSV file into a typed DataFrame backed by Dask.
+
+     Applies the schema's dtype mapping to ensure correct column types.
+     """
+     backend = DaskBackend()
+     pd_schema = _build_pandas_schema(schema)
+     data = dd.read_csv(path, dtype=pd_schema, **kwargs)
+     if is_validation_enabled():
+         backend.validate_schema(data, schema)
+         if get_validation_level() is ValidationLevel.FULL:
+             backend.validate_field_constraints(data, schema)
+     return DataFrame(_data=data, _schema=schema, _backend=backend)
+
+
+ def scan_parquet(path: str, schema: type[S], **kwargs: Any) -> LazyFrame[S]:
+     """Lazily scan a Parquet file into a typed LazyFrame backed by Dask."""
+     backend = DaskBackend()
+     data = dd.read_parquet(path, **kwargs)
+     return LazyFrame(_data=data, _schema=schema, _backend=backend)
+
+
+ def scan_csv(path: str, schema: type[S], **kwargs: Any) -> LazyFrame[S]:
+     """Lazily scan a CSV file into a typed LazyFrame backed by Dask.
+
+     Applies the schema's dtype mapping to ensure correct column types.
+     """
+     backend = DaskBackend()
+     pd_schema = _build_pandas_schema(schema)
+     data = dd.read_csv(path, dtype=pd_schema, **kwargs)
+     return LazyFrame(_data=data, _schema=schema, _backend=backend)
+
+
+ def from_dict(
+     schema: type[S],
+     data: dict[str, Sequence[Any]],
+ ) -> DataFrame[S]:
+     """Create a typed DataFrame from a columnar dict.
+
+     The schema drives dtype coercion — plain Python values (``[1, 2, 3]``)
+     are cast to the correct native types (e.g. ``UInt64``).
+     """
+     backend = DaskBackend()
+     return DataFrame.from_dict(data, schema, backend)
+
+
+ def from_rows(
+     schema: type[S],
+     rows: Sequence[Row[S]],
+ ) -> DataFrame[S]:
+     """Create a typed DataFrame from ``Row[S]`` instances.
+
+     The type checker verifies that rows match the schema — passing
+     ``Orders.Row`` where ``Users.Row`` is expected is a static error.
+     """
+     data = rows_to_dict(rows, schema)
+     return from_dict(schema, data)
+
+
+ def write_parquet(df: DataFrame[Any], path: str, **kwargs: Any) -> None:
+     """Write a DataFrame to a Parquet file."""
+     df._data.to_parquet(path, **kwargs)
+
+
+ def write_csv(df: DataFrame[Any], path: str, **kwargs: Any) -> None:
+     """Write a DataFrame to a CSV file."""
+     df._data.to_csv(path, index=False, single_file=True, **kwargs)