thds.tabularasa 0.13.0 (thds_tabularasa-0.13.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. thds/tabularasa/__init__.py +6 -0
  2. thds/tabularasa/__main__.py +1122 -0
  3. thds/tabularasa/compat.py +33 -0
  4. thds/tabularasa/data_dependencies/__init__.py +0 -0
  5. thds/tabularasa/data_dependencies/adls.py +97 -0
  6. thds/tabularasa/data_dependencies/build.py +573 -0
  7. thds/tabularasa/data_dependencies/sqlite.py +286 -0
  8. thds/tabularasa/data_dependencies/tabular.py +167 -0
  9. thds/tabularasa/data_dependencies/util.py +209 -0
  10. thds/tabularasa/diff/__init__.py +0 -0
  11. thds/tabularasa/diff/data.py +346 -0
  12. thds/tabularasa/diff/schema.py +254 -0
  13. thds/tabularasa/diff/summary.py +249 -0
  14. thds/tabularasa/git_util.py +37 -0
  15. thds/tabularasa/loaders/__init__.py +0 -0
  16. thds/tabularasa/loaders/lazy_adls.py +44 -0
  17. thds/tabularasa/loaders/parquet_util.py +385 -0
  18. thds/tabularasa/loaders/sqlite_util.py +346 -0
  19. thds/tabularasa/loaders/util.py +532 -0
  20. thds/tabularasa/py.typed +0 -0
  21. thds/tabularasa/schema/__init__.py +7 -0
  22. thds/tabularasa/schema/compilation/__init__.py +20 -0
  23. thds/tabularasa/schema/compilation/_format.py +50 -0
  24. thds/tabularasa/schema/compilation/attrs.py +257 -0
  25. thds/tabularasa/schema/compilation/attrs_sqlite.py +278 -0
  26. thds/tabularasa/schema/compilation/io.py +96 -0
  27. thds/tabularasa/schema/compilation/pandas.py +252 -0
  28. thds/tabularasa/schema/compilation/pyarrow.py +93 -0
  29. thds/tabularasa/schema/compilation/sphinx.py +550 -0
  30. thds/tabularasa/schema/compilation/sqlite.py +69 -0
  31. thds/tabularasa/schema/compilation/util.py +117 -0
  32. thds/tabularasa/schema/constraints.py +327 -0
  33. thds/tabularasa/schema/dtypes.py +153 -0
  34. thds/tabularasa/schema/extract_from_parquet.py +132 -0
  35. thds/tabularasa/schema/files.py +215 -0
  36. thds/tabularasa/schema/metaschema.py +1007 -0
  37. thds/tabularasa/schema/util.py +123 -0
  38. thds/tabularasa/schema/validation.py +878 -0
  39. thds/tabularasa/sqlite3_compat.py +41 -0
  40. thds/tabularasa/sqlite_from_parquet.py +34 -0
  41. thds/tabularasa/to_sqlite.py +56 -0
  42. thds_tabularasa-0.13.0.dist-info/METADATA +530 -0
  43. thds_tabularasa-0.13.0.dist-info/RECORD +46 -0
  44. thds_tabularasa-0.13.0.dist-info/WHEEL +5 -0
  45. thds_tabularasa-0.13.0.dist-info/entry_points.txt +2 -0
  46. thds_tabularasa-0.13.0.dist-info/top_level.txt +1 -0
thds/tabularasa/data_dependencies/sqlite.py
@@ -0,0 +1,286 @@
+ import datetime
+ import io
+ import warnings
+ from logging import getLogger
+ from pathlib import Path
+ from typing import Callable, Mapping, Optional, Union
+
+ import pandas as pd
+
+ from thds.tabularasa.loaders.parquet_util import TypeCheckLevel, pandas_maybe
+ from thds.tabularasa.loaders.sqlite_util import bulk_write_connection, sqlite_preprocessor_for_type
+ from thds.tabularasa.loaders.util import PandasParquetLoader
+ from thds.tabularasa.schema.compilation.sqlite import render_sql_index_schema, render_sql_table_schema
+ from thds.tabularasa.schema.metaschema import Schema, Table, is_build_time_package_table
+ from thds.tabularasa.sqlite3_compat import sqlite3
+
+ from .util import hash_file
+
+ TABLE_METADATA_TABLE_NAME = "_table_metadata"
+ TABLE_METADATA_DDL = f"""CREATE TABLE IF NOT EXISTS {TABLE_METADATA_TABLE_NAME}(
+     table_name TEXT PRIMARY KEY,
+     n_rows INTEGER NOT NULL,
+     data_hash TEXT NOT NULL,
+     table_ddl_hash TEXT NOT NULL,
+     index_ddl_hash TEXT
+ );"""
+
+
+ def insert_table(
+     con: sqlite3.Connection,
+     table: Table,
+     package: Optional[str],
+     data_dir: str,
+     filename: Optional[str] = None,
+     validate: bool = False,
+     check_hash: bool = True,
+     type_check: Optional[TypeCheckLevel] = None,
+     cast: bool = False,
+ ):
+     """Insert data for a schema table, but only if the table doesn't exist in the database OR it exists
+     and the hashes of DDL scripts and data recorded there do not match those derived from auto-generated
+     code and data from the schema
+     """
+     _LOGGER = getLogger(__name__)
+     table_ddl = render_sql_table_schema(table)
+     index_ddl = render_sql_index_schema(table)
+
+     table_ddl_hash_expected = hash_file(io.BytesIO(table_ddl.strip().encode()))
+     if index_ddl is None:
+         index_ddl_hash_expected = None
+     else:
+         index_ddl_hash_expected = hash_file(io.BytesIO(index_ddl.strip().encode()))
+
+     loader = PandasParquetLoader.from_schema_table(
+         table,
+         package=package or None,
+         data_dir=data_dir,
+         filename=filename,
+         derive_schema=validate,
+     )
+     data_hash_expected = loader.file_hash()
+
+     _ensure_metadata_table(con)
+
+     if (not check_hash) or (not table_populated(con, table)):
+         insert_data = True
+     else:
+         row = con.execute(
+             f"SELECT data_hash, table_ddl_hash, index_ddl_hash FROM {TABLE_METADATA_TABLE_NAME} "
+             "WHERE table_name = ?",
+             (table.snake_case_name,),
+         ).fetchone()
+         if row is None:
+             insert_data = True
+         else:
+             data_hash, table_ddl_hash, index_ddl_hash = row
+             if table_ddl_hash != table_ddl_hash_expected:
+                 _LOGGER.info(
+                     f"Hash of table DDL doesn't match that recorded in the database; reinserting data "
+                     f"for {table.name}"
+                 )
+                 insert_data = True
+             elif index_ddl_hash != index_ddl_hash_expected:
+                 _LOGGER.info(
+                     f"Hash of index DDL doesn't match that recorded in the database; reinserting data "
+                     f"for {table.name}"
+                 )
+                 insert_data = True
+             elif data_hash != data_hash_expected:
+                 _LOGGER.info(
+                     f"Hash of source data doesn't match that recorded in the database; reinserting data "
+                     f"for {table.name}"
+                 )
+                 insert_data = True
+             else:
+                 _LOGGER.info(
+                     f"Hashes of table DDL, index DDL, and source data match those recorded in the "
+                     f"database; skipping data insert for {table.name}"
+                 )
+                 insert_data = False
+
+     if insert_data:
+         _LOGGER.info(f"Inserting data for {table.name}")
+         with con:
+             _LOGGER.debug(f"Dropping existing table {table.name}")
+             con.execute(f"DROP TABLE IF EXISTS {table.snake_case_name}")
+             con.execute(
+                 f"DELETE FROM {TABLE_METADATA_TABLE_NAME} WHERE table_name = ?",
+                 (table.snake_case_name,),
+             )
+
+         with con:
+             _LOGGER.debug(f"Executing table DDL for {table.name}:\n{table_ddl}")
+             con.execute(table_ddl)
+             # `postprocess=True` implies that all collection-valued columns will have values that are
+             # instances of builtin python types (dicts, lists), and thus that they are handleable by
+             # `cattrs` and thus ultimately JSON-serializable
+             try:
+                 batches = loader.load_batched(
+                     validate=validate, postprocess=True, type_check=type_check, cast=cast
+                 )
+             except KeyError as key_error:
+                 _LOGGER.error(
+                     "Column names in parquet file may not match "
+                     "schema. Try deleting derived files and building again."
+                 )
+                 raise Exception(key_error)
+
+             _LOGGER.debug(f"Inserting data for table {table.name}")
+             for df in batches:
+                 df = _prepare_for_sqlite(df, table)
+
+                 with warnings.catch_warnings():
+                     warnings.filterwarnings("ignore", "pandas only supports SQLAlchemy connectable")
+                     df.to_sql(table.snake_case_name, con, if_exists="append", index=False)
+
+         if index_ddl:
+             with con:
+                 _LOGGER.debug(f"Executing index DDL for table {table.name}:\n{index_ddl}")
+                 con.executescript(index_ddl)
+
+         with con:
+             # only insert hashes once all table and index creation is complete
+             _LOGGER.debug(f"Inserting metadata for {table.name}")
+             num_rows = loader.num_rows()
+             con.execute(
+                 f"INSERT INTO {TABLE_METADATA_TABLE_NAME}"
+                 "(table_name, n_rows, data_hash, table_ddl_hash, index_ddl_hash) VALUES (?, ?, ?, ?, ?)",
+                 (
+                     table.snake_case_name,
+                     num_rows,
+                     data_hash_expected,
+                     table_ddl_hash_expected,
+                     index_ddl_hash_expected,
+                 ),
+             )
+
+
+ def _prepare_for_sqlite(df: pd.DataFrame, table: Table) -> pd.DataFrame:
+     """Only meant for use right before `df` is inserted into a sqlite table. Mutates `df` by converting
+     collection-valued columns to JSON literals"""
+     if table.primary_key:
+         df.reset_index(inplace=True, drop=False)
+
+     for column in table.columns:
+         name = column.snake_case_name
+         pytype = column.type.python
+         if pytype is datetime.date:
+             # pandas has only one type for dates/datetimes and uses a datetime ISO format for it, whereas
+             # our sqlite adapters use ISO formats specific to dates and datetimes; just cast to regular
+             # `datetime.date` here and the sqlite API will handle the conversion to ISO format since
+             # we're using `detect_types=sqlite3.PARSE_DECLTYPES`
+             df[name] = df[name].dt.date
+         else:
+             preproc = sqlite_preprocessor_for_type(pytype)  # type: ignore
+             if preproc is not None:
+                 preproc_ = pandas_maybe(preproc) if column.nullable else preproc
+                 df[name] = df[name].apply(preproc_)
+
+     return df
+
+
+ def _ensure_metadata_table(con: sqlite3.Connection):
+     if not table_exists(con, TABLE_METADATA_TABLE_NAME):
+         con.execute(TABLE_METADATA_DDL)
+
+
+ def table_populated(con: sqlite3.Connection, table: Table) -> bool:
+     """Return True if a table named `table` exists in the database represented by `con` and has any rows,
+     False otherwise"""
+     try:
+         cur = con.execute(
+             f"SELECT {', '.join(c.snake_case_name for c in table.columns)} FROM {table.snake_case_name} limit 1"
+         )
+     except sqlite3.Error:
+         return False
+     else:
+         results = list(cur)
+         return bool(results)
+
+
+ def table_exists(con: sqlite3.Connection, table: Union[str, Table]) -> bool:
+     table_name = table if isinstance(table, str) else table.snake_case_name
+     try:
+         con.execute(f"SELECT * FROM {table_name} limit 1")
+     except sqlite3.Error:
+         return False
+     return True
+
+
+ def populate_sqlite_db(
+     schema: Schema,
+     db_package: Optional[str],
+     db_path: str,
+     data_package: Optional[str],
+     data_dir: str,
+     transient_data_dir: str,
+     validate: bool = False,
+     check_hash: bool = True,
+     type_check: Optional[TypeCheckLevel] = None,
+     cast: bool = False,
+     table_predicate: Callable[[Table], bool] = is_build_time_package_table,
+     data_path_overrides: Optional[Mapping[str, Path]] = None,
+ ):
+     """Populate a sqlite database with data for a set of tables from a `reference_data.schema.Schema`
+
+     :param schema: the `reference_data.schema.Schema` object defining the data to be inserted
+     :param db_package: name of the package where the database file is stored, if any. In case `None` is
+         passed, `db_path` refers to an ordinary file.
+     :param db_path: path to the sqlite database archive file in which the data will be inserted
+     :param data_package: optional package name, to be used if the data files are distributed as package
+         data
+     :param data_dir: path to the directory where the table parquet files are stored (relative to
+         `data_package` if it is passed). Ignored for any tables specified in `data_path_overrides` - see
+         below
+     :param transient_data_dir: path to the directory where transient table parquet files are stored
+         (relative to `data_package` if it is passed). Ignored for any tables specified in
+         `data_path_overrides` - see below
+     :param validate: if True, validate tables against their pandera schemas on load before inserting the
+         data into the database
+     :param check_hash: if True, skip tables whose data has already been inserted into the database, as
+         indicated by a hash of file contents and DDL statements in the _table_metadata table in the
+         database
+     :param type_check: optional `reference_data.loaders.parquet_util.TypeCheckLevel`. If given, the
+         arrow schemas of the files to be loaded will be checked against their expected arrow schemas as
+         derived from `schema` before read. This is a very efficient check as it requires no data to be
+         read. Useful for loading tables at run time as a quick validity check. This is passed to the same
+         keyword argument of `reference_data.loaders.util.PandasParquetLoader.__call__`.
+     :param cast: indicates that a safe cast of the parquet data should be performed on load using
+         `pyarrow`, in case the file arrow schema doesn't match the expected one exactly. This is passed to
+         the same keyword argument of `reference_data.loaders.util.PandasParquetLoader.__call__`.
+     :param table_predicate: Optional predicate indicating which tables from `schema.tables` should be
+         inserted into the database. If not given, all tables in the schema will be inserted.
+     :param data_path_overrides: Optional mapping from table name to the file path where the parquet
+         data for the table is to be loaded from. Any table whose name is a key in this mapping will be
+         loaded from the associated file path as a normal file (`data_package` and `data_dir` will be
+         ignored). This is useful for specifying dynamic run-time-installed tables.
+     """
+     # gather all tables before executing any I/O
+     insert_tables = [table for table in schema.filter_tables(table_predicate) if table.has_indexes]
+
+     with bulk_write_connection(db_path, db_package, close=True) as con:
+         for table in insert_tables:
+             table_filename: Optional[str]
+             table_package: Optional[str]
+             if data_path_overrides and table.name in data_path_overrides:
+                 data_path = Path(data_path_overrides[table.name]).absolute()
+                 table_package = None
+                 table_data_dir = str(data_path.parent)
+                 table_filename = data_path.name
+             else:
+                 table_package = data_package
+                 table_data_dir = transient_data_dir if table.transient else data_dir
+                 table_filename = None
+
+             insert_table(
+                 con=con,
+                 table=table,
+                 package=table_package,
+                 data_dir=table_data_dir,
+                 filename=table_filename,
+                 validate=validate,
+                 check_hash=check_hash,
+                 type_check=type_check,
+                 cast=cast,
+             )
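
For orientation, a hypothetical build-time invocation of populate_sqlite_db as defined above is sketched below. The schema object, paths, and flag values are illustrative assumptions only; `my_schema` stands in for a thds.tabularasa.schema.metaschema.Schema built elsewhere and is not defined in this sketch.

    from thds.tabularasa.data_dependencies.sqlite import populate_sqlite_db

    # `my_schema` is a placeholder for a metaschema.Schema constructed elsewhere (not part of this sketch)
    populate_sqlite_db(
        schema=my_schema,
        db_package=None,                       # db_path is an ordinary file, not package data
        db_path="build/reference.sqlite",      # hypothetical output database path
        data_package=None,
        data_dir="build/parquet",              # hypothetical directory of table parquet files
        transient_data_dir="build/transient",  # hypothetical directory of transient parquet files
        validate=False,
        check_hash=True,                       # reuse tables whose recorded hashes already match
    )

With check_hash=True, insert_table consults the _table_metadata table and skips any table whose recorded table DDL, index DDL, and data hashes match the freshly computed ones, so repeated builds only rewrite tables that actually changed.
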
thds/tabularasa/data_dependencies/tabular.py
@@ -0,0 +1,167 @@
+ from typing import AbstractSet, Any, Callable, Dict, List, Optional, TypeVar, cast
+
+ import pandas as pd
+ import pandera as pa
+
+ from thds.tabularasa.loaders.sqlite_util import sqlite_postprocessor_for_type
+ from thds.tabularasa.schema import metaschema
+ from thds.tabularasa.schema.dtypes import DType, PyType
+ from thds.tabularasa.schema.files import TabularFileSource
+
+ from .util import check_categorical_values
+
+ T = TypeVar("T")
+ K = TypeVar("K", bound=PyType)
+ V = TypeVar("V", bound=PyType)
+
+ BOOL_CONSTANTS = {
+     "true": True,
+     "false": False,
+     "t": True,
+     "f": False,
+     "yes": True,
+     "no": False,
+     "y": True,
+     "n": False,
+     "1": True,
+     "0": False,
+ }
+ JSON_NULL = "null"
+
+
+ class PandasCSVLoader:
+     """Base interface for loading package data CSV files as pandas.DataFrames.
+     This is only for use at build time"""
+
+     def __init__(self, table: metaschema.Table, schema: Optional[pa.DataFrameSchema] = None):
+         if not isinstance(table.dependencies, TabularFileSource):
+             raise ValueError(
+                 f"Table '{table.name}' has no single tabular text file source of truth; it depends on "
+                 f"{table.dependencies}"
+             )
+
+         self.schema = schema
+         self.table = table
+         self.header = [column.header_name for column in self.table.columns]
+         self.rename = {column.header_name: column.name for column in self.table.columns}
+         # set the primary key as the index
+         self.index_cols: List[str] = list(self.table.primary_key) if self.table.primary_key else []
+         self.cols = [c.name for c in self.table.columns if c.name not in self.index_cols]
+         # pass these to the csv parser as parse_dates
+         self.parse_date_cols = [
+             column.header_name
+             for column in self.table.columns
+             if column.dtype in (DType.DATE, DType.DATETIME)
+         ]
+         self.na_values = table.csv_na_values
+         self.dtypes = {}
+         self.dtypes_for_csv_read = {}
+         self.converters: Dict[str, Callable[[str], Any]] = {}
+         for column in self.table.columns:
+             dtype = column.pandas(index=column.name in self.index_cols)
+             self.dtypes[column.name] = dtype
+             na_values_for_col = self.na_values.get(column.header_name)
+             if column.dtype == DType.BOOL:
+                 # we have a custom converter (parser) for boolean values.
+                 # converters override na_values in pandas.read_csv, so we have to specify them here.
+                 self.converters[column.header_name] = (
+                     parse_optional(parse_bool, na_values_for_col) if na_values_for_col else parse_bool
+                 )
+             elif isinstance(column.dtype, (metaschema.ArrayType, metaschema.MappingType)):
+                 # parse as json - again a custom converter which overrides na_values
+                 converter = sqlite_postprocessor_for_type(column.dtype.python)
+                 assert converter is not None  # converter for a structured type will not be None
+                 self.converters[column.header_name] = cast(
+                     Callable[[str], Any],
+                     parse_optional(converter, na_values_for_col) if na_values_for_col else converter,
+                 )
+             elif column.header_name not in self.parse_date_cols:
+                 # read_csv requires passing `parse_dates` for determining date-typed columns
+                 # also do NOT tell pandas.read_csv you want an enum; it will mangle unknown values to null!
+                 self.dtypes_for_csv_read[column.header_name] = (
+                     dtype
+                     if column.type.enum is None
+                     else column.dtype.pandas(
+                         nullable=column.nullable, index=column.name in self.index_cols
+                     )
+                 )
+
+     def __call__(self, validate: bool = False) -> pd.DataFrame:
+         if validate and self.schema is None:
+             raise ValueError(f"Can't validate table {self.table.name} with no schema")
+
+         df = self.read()
+         df = self.postprocess(df)
+         # schema not None only to make mypy happy - error thrown above in case it's required
+         return self.schema.validate(df) if (validate and self.schema is not None) else df
+
+     def read(self):
+         # make mypy happy; this is checked in __init__
+         deps = self.table.dependencies
+         assert isinstance(deps, TabularFileSource)
+         with deps.file_handle as f:
+             df = pd.read_csv(  # type: ignore
+                 f,
+                 usecols=self.header,
+                 dtype=self.dtypes_for_csv_read,
+                 parse_dates=self.parse_date_cols,
+                 converters=self.converters,
+                 skiprows=deps.skiprows or 0,
+                 encoding=deps.encoding,
+                 dialect=deps.csv_dialect,
+                 na_values={k: sorted(v) for k, v in self.na_values.items()},
+                 # without this, pandas adds in its own extensive set of strings to interpret as null.
+                 # we force the user to be explicit about the values they want to parse as null.
+                 keep_default_na=False,
+             )
+
+         return df
+
+     def postprocess(
+         self,
+         df: pd.DataFrame,
+     ):
+         """Ensure correct column names, column order, dtypes and index. Mutates `df` in-place"""
+         df.rename(columns=self.rename, inplace=True)
+
+         # pandas silently nullifies values not matching a categorical dtype!
+         # so we have to do this ourselves before we coerce with .astype below
+         for col in self.table.columns:
+             name = col.name
+             dtype = self.dtypes[name]
+             if isinstance(dtype, pd.CategoricalDtype):
+                 check_categorical_values(df[name], dtype)
+
+         df = df.astype(self.dtypes, copy=False)
+         if self.index_cols:
+             df.set_index(self.index_cols, inplace=True)
+
+         if list(df.columns) != self.cols:
+             df = df[self.cols]
+
+         return df
+
+
+ # CSV parsing for complex types
+
+
+ def identity(x):
+     return x
+
+
+ def parse_optional(
+     func: Callable[[str], V], null_values: AbstractSet[str] = frozenset([""])
+ ) -> Callable[[str], Optional[V]]:
+     """Turn a csv parser for a type V into a parser for Optional[V] by treating the empty string as a
+     null value"""
+
+     def parse(s: str) -> Optional[V]:
+         if s in null_values:
+             return None
+         return func(s)
+
+     return parse
+
+
+ def parse_bool(s: str) -> bool:
+     return BOOL_CONSTANTS[s.lower()]
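
Based on the definitions above, the boolean and optional parsers behave as in this small runnable sketch (the "NA" null marker is an illustrative choice, not one taken from the package):

    from thds.tabularasa.data_dependencies.tabular import parse_bool, parse_optional

    parse_bool("Yes")   # True  -- lookup via BOOL_CONSTANTS is case-insensitive
    parse_bool("0")     # False

    # wrap any str parser so that configured null strings map to None
    parse_maybe_bool = parse_optional(parse_bool, null_values=frozenset(["", "NA"]))
    parse_maybe_bool("NA")  # None
    parse_maybe_bool("t")   # True

PandasCSVLoader wires parsers like these into pandas.read_csv as converters for DType.BOOL and collection-typed columns, because converters bypass the na_values handling that read_csv applies to dtype-parsed columns.
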
thds/tabularasa/data_dependencies/util.py
@@ -0,0 +1,209 @@
+ import hashlib
+ import multiprocessing
+ import multiprocessing.connection
+ import os
+ import warnings
+ from functools import partial, wraps
+ from logging import getLogger
+ from pathlib import Path
+ from typing import IO, Callable, Dict, List, Optional, TypeVar, Union, cast
+
+ import pandas as pd
+ import pkg_resources
+ import pyarrow
+
+ from thds.tabularasa.data_dependencies.adls import ADLSDownloadResult
+ from thds.tabularasa.loaders.parquet_util import pandas_maybe, preprocessor_for_pyarrow_type
+ from thds.tabularasa.schema.files import LocalDataSpec
+ from thds.tabularasa.schema.metaschema import Table
+ from thds.tabularasa.schema.util import Identifier, import_func
+
+ PARQUET_FORMAT_VERSION = "2.4"
+ HASH_FILE_BUFFER_SIZE = 2**16
+ FILENAME_YEAR_REGEX = r"(?:[^\d])(20\d{2})(?:[^\d])"
+ FILENAME_QUARTER_REGEX = r"(?:[^\d])([Qq]\d{1})(?:[^\d])"
+
+ MaterializedPackageDataDeps = Dict[Identifier, pd.DataFrame]
+ SyncedADLSDeps = Dict[Identifier, List[ADLSDownloadResult]]
+ RawLocalDataDeps = Dict[Identifier, LocalDataSpec]
+ DataPreprocessor = Callable[
+     [MaterializedPackageDataDeps, SyncedADLSDeps, RawLocalDataDeps], pd.DataFrame
+ ]
+ F = TypeVar("F", bound=Callable)
+
+
+ def package_data_file_size(package: str, path: str) -> int:
+     os_path = pkg_resources.resource_filename(package, path)
+     return os.stat(os_path).st_size
+
+
+ def run_in_subprocess(func: F) -> F:
+     """Decorator to cause a side-effect-producing routine to run in a subprocess. Isolates memory
+     consumption of the function call to the subprocess to ensure that all memory resources are reclaimed
+     at the end of the call. For example, `pyarrow` is known to be quite aggressive with memory allocation
+     and reluctant to free consumed memory. For instance, a routine that simply reads a parquet file using
+     `pyarrow` and then writes the result somewhere else would benefit from use of this decorator.
+     """
+
+     @wraps(func)
+     def subprocess_func(*args, _subprocess: bool = True, **kwargs):
+         # extra _subprocess arg required to avoid trying to send the wrapped `func` to a subprocess,
+         # which in python 3.8+ results in a pickling error since it isn't the same object as the
+         # function importable at its own module/name (it's been replaced by `subprocess_func`)
+         if _subprocess:
+             recv_con, send_con = multiprocessing.Pipe(duplex=False)
+             proc = multiprocessing.Process(
+                 target=SubprocessFunc(subprocess_func),
+                 args=(send_con, *args),
+                 kwargs=dict(_subprocess=False, **kwargs),
+             )
+             proc.start()
+             try:
+                 result, exc = recv_con.recv()
+                 proc.join()
+             except Exception as e:
+                 # communication error, e.g. unpicklable return value
+                 raise e
+             else:
+                 if exc is not None:
+                     raise exc
+             finally:
+                 proc.close()
+         else:
+             result = func(*args, **kwargs)
+
+         return result
+
+     return cast(F, subprocess_func)
+
+
+ class SubprocessFunc:
+     def __init__(self, func):
+         self.func = func
+
+     def __call__(self, con: multiprocessing.connection.Connection, *args, **kwargs):
+         exc: Optional[Exception]
+         try:
+             result = self.func(*args, **kwargs)
+         except Exception as e:
+             exc = e
+             result = None
+         else:
+             exc = None
+         con.send((result, exc))
+
+
+ def hash_file(file: Union[Path, str, IO[bytes]]) -> str:
+     """MD5 hash of the contents of a file (specified by path or passed directly as a handle)"""
+     io: IO[bytes]
+     if isinstance(file, (str, Path)):
+         io = open(file, "rb")
+         close = True
+     else:
+         io = file
+         close = False
+
+     hash_ = hashlib.md5()
+     for bytes_ in iter(partial(io.read, HASH_FILE_BUFFER_SIZE), b""):
+         hash_.update(bytes_)
+
+     if close:
+         io.close()
+
+     return hash_.hexdigest()
+
+
+ def import_data_preprocessor(path: str) -> DataPreprocessor:
+     return import_func(path)
+
+
+ def arrow_table_for_parquet_write(df: pd.DataFrame, table: Table) -> pyarrow.Table:
+     """Preprocess a dataframe with possibly complex object types in preparation to write to a parquet
+     file. Casts dicts to lists since pyarrow expects lists or arrays of key-value tuples as the
+     representation of mapping types. Also casts any types with different kinds - e.g. if a float column is
+     expected but an int column is passed. Possibly mutates input as this should only be called on a table
+     which is about to be saved as a parquet file and then garbage-collected."""
+     logger = getLogger(__name__)
+
+     if any(df.index.names):
+         df.reset_index(inplace=True)
+
+     table_columns = {c.name for c in table.columns}
+     extra_columns = [c for c in df.columns if c not in table_columns]
+     if extra_columns:
+         logger.warning(
+             f"Extra columns {extra_columns!r} in dataframe but not in schema of table {table.name!r} "
+             "will be dropped on parquet write"
+         )
+
+     for column in table.columns:
+         field = column.parquet_field
+         name = column.name
+         pproc = preprocessor_for_pyarrow_type(field.type)
+         if pproc is not None:
+             if field.nullable:
+                 pproc = pandas_maybe(pproc)
+             df[name] = df[name].apply(pproc)
+
+         if (enum_constraint := column.dtype.enum) is not None:
+             try:
+                 check_categorical_values(df[name], pd.CategoricalDtype(enum_constraint.enum))
+             except ValueError as e:
+                 # only warn on write since the data may in fact be correct while only the schema needs
+                 # updating, potentially saving the developer an expensive derivation
+                 warnings.warn(str(e))
+
+     if table.primary_key:
+         df.sort_values(list(table.primary_key), inplace=True)
+
+     arrow = pyarrow.Table.from_pandas(df, table.parquet_schema, safe=True)
+     # we remove the pandas-related metadata to ensure that insignificant changes e.g. to pandas/pyarrow
+     # versions do not affect file hashes. We can safely do this since we don't rely on pandas to infer
+     # data types on load, instead using the parquet/arrow schemas directly on load (pyarrow uses the
+     # 'ARROW:schema' metadata key to document arrow schemas in a serialized binary format so we don't
+     # lose that information by discarding the pandas information)
+     meta = arrow.schema.metadata
+     meta.pop(b"pandas")
+     return arrow.replace_schema_metadata(meta)
+
+
+ def check_categorical_values(col: pd.Series, dtype: pd.CategoricalDtype):
+     """Check that values in a column match an expected categorical dtype prior to a write or cast operation.
+     This exists to preempt the unfortunate behavior of pandas wherein a cast silently nullifies any values
+     which are not in the categories of the target `CategoricalDtype`, resulting in confusing errors (or
+     worse - no errors in case null values are tolerated) downstream.
+
+     :raises TypeError: when the underlying data type of `col` has a different kind than the categories of the `dtype`
+     :raises ValueError: when any values in `col` are outside the expected set of categories of the `dtype`
+     """
+     current_dtype = col.dtype
+     expected_dtype = dtype.categories.dtype
+
+     if isinstance(current_dtype, pd.CategoricalDtype):
+         current_dtype_kind = current_dtype.categories.dtype.kind
+     else:
+         current_dtype_kind = current_dtype.kind
+
+     int_kinds = {"i", "u"}
+     if current_dtype_kind != expected_dtype.kind and not (
+         current_dtype_kind in int_kinds and expected_dtype.kind in int_kinds
+     ):
+         raise TypeError(
+             f"Column {col.name} is expected to be categorical with underlying data type "
+             f"{expected_dtype}, but has incompatible type {current_dtype}"
+         )
+
+     expected_values = dtype.categories
+     actual_values = pd.Series(col.dropna().unique())
+     bad_values = actual_values[~actual_values.isin(expected_values)]
+     if len(bad_values):
+         display_max_values = 20
+         addendum = (
+             f"(truncated to {display_max_values} unique values)"
+             if len(bad_values) > display_max_values
+             else ""
+         )
+         raise ValueError(
+             f"Column {col.name} is expected to have values in the set {expected_values.tolist()}, "
+             f"but also contains values {bad_values[:display_max_values].tolist()}{addendum}"
+         )
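
As a sketch of the intended use of run_in_subprocess described in its docstring, the routine below (the `repack` function and its file paths are hypothetical, not part of the package) does its pyarrow work in a child process so that the parent process's memory footprint stays flat:

    import pyarrow.parquet as pq

    from thds.tabularasa.data_dependencies.util import hash_file, run_in_subprocess

    @run_in_subprocess
    def repack(src: str, dst: str) -> str:
        # pyarrow's allocations live in the child process and are reclaimed when it exits
        pq.write_table(pq.read_table(src), dst)
        return hash_file(dst)  # MD5 hex digest of the rewritten file

    # The return value travels back over a multiprocessing.Pipe, so it must be picklable;
    # exceptions raised in the child are re-raised in the parent. Depending on the
    # multiprocessing start method, the decorated function should live at module scope
    # so it can be pickled by its qualified name.
    digest = repack("data/input.parquet", "data/repacked.parquet")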