arthur_common-1.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of arthur-common might be problematic.
- arthur_common/__init__.py +0 -0
- arthur_common/__version__.py +1 -0
- arthur_common/aggregations/__init__.py +2 -0
- arthur_common/aggregations/aggregator.py +214 -0
- arthur_common/aggregations/functions/README.md +26 -0
- arthur_common/aggregations/functions/__init__.py +25 -0
- arthur_common/aggregations/functions/categorical_count.py +89 -0
- arthur_common/aggregations/functions/confusion_matrix.py +412 -0
- arthur_common/aggregations/functions/inference_count.py +69 -0
- arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
- arthur_common/aggregations/functions/inference_null_count.py +82 -0
- arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
- arthur_common/aggregations/functions/mean_squared_error.py +110 -0
- arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
- arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
- arthur_common/aggregations/functions/numeric_stats.py +90 -0
- arthur_common/aggregations/functions/numeric_sum.py +87 -0
- arthur_common/aggregations/functions/py.typed +0 -0
- arthur_common/aggregations/functions/shield_aggregations.py +752 -0
- arthur_common/aggregations/py.typed +0 -0
- arthur_common/models/__init__.py +0 -0
- arthur_common/models/connectors.py +41 -0
- arthur_common/models/datasets.py +22 -0
- arthur_common/models/metrics.py +227 -0
- arthur_common/models/py.typed +0 -0
- arthur_common/models/schema_definitions.py +420 -0
- arthur_common/models/shield.py +504 -0
- arthur_common/models/task_job_specs.py +78 -0
- arthur_common/py.typed +0 -0
- arthur_common/tools/__init__.py +0 -0
- arthur_common/tools/aggregation_analyzer.py +243 -0
- arthur_common/tools/aggregation_loader.py +59 -0
- arthur_common/tools/duckdb_data_loader.py +329 -0
- arthur_common/tools/functions.py +46 -0
- arthur_common/tools/py.typed +0 -0
- arthur_common/tools/schema_inferer.py +104 -0
- arthur_common/tools/time_utils.py +33 -0
- arthur_common-1.0.1.dist-info/METADATA +74 -0
- arthur_common-1.0.1.dist-info/RECORD +40 -0
- arthur_common-1.0.1.dist-info/WHEEL +4 -0

arthur_common/tools/aggregation_analyzer.py
@@ -0,0 +1,243 @@
import inspect
import logging
import typing
import uuid
from typing import Any, Callable, get_type_hints

from arthur_common.aggregations import (
    AggregationFunction,
    NumericAggregationFunction,
    SketchAggregationFunction,
)
from arthur_common.models.metrics import (
    AggregationMetricType,
    AggregationSpecSchema,
    DatasetReference,
    MetricsColumnParameterSchema,
    MetricsDatasetParameterSchema,
    MetricsLiteralParameterSchema,
    MetricsParameterSchemaUnion,
)
from arthur_common.models.schema_definitions import (
    DType,
    MetricColumnParameterAnnotation,
    MetricDatasetParameterAnnotation,
    MetricLiteralParameterAnnotation,
    MetricsParameterAnnotationUnion,
)

logger = logging.getLogger(__name__)


class FunctionAnalyzer:
    @staticmethod
    def _python_type_to_scope_dtype(t: Any) -> DType:
        if t is int:
            return DType.INT
        elif t is str:
            return DType.STRING
        elif t is bool:
            return DType.BOOL
        elif t is float:
            return DType.FLOAT
        elif t is uuid.UUID:
            return DType.UUID
        elif t is DatasetReference:
            return DType.UUID
        else:
            raise ValueError(f"Parameter type {t} is not supported.")

    @staticmethod
    def _get_metric_annotation_from_annotated(
        param_name: str,
        annotation: typing.Annotated,  # type: ignore
    ) -> MetricsParameterAnnotationUnion | None:
        arthur_metric_annotations = [
            m
            for m in annotation.__metadata__
            if isinstance(m, MetricsParameterAnnotationUnion)
        ]
        metric_annotation: MetricsParameterAnnotationUnion | None = None
        if len(arthur_metric_annotations) == 1:
            metric_annotation = arthur_metric_annotations[0]
        if len(arthur_metric_annotations) > 1:
            raise ValueError(
                f"Parameter {param_name} defines more than one metric annotation.",
            )
        return metric_annotation

    @staticmethod
    def _get_scope_metric_parameter_from_annotation(
        param_name: str,
        param_dtype: DType,
        optional: bool,
        annotation: typing.Annotated,  # type: ignore
    ) -> MetricsParameterSchemaUnion:
        if annotation is None:
            return MetricsLiteralParameterSchema(
                parameter_key=param_name,
                optional=optional,
                parameter_dtype=param_dtype,
                friendly_name=param_name,
                description=f"A {param_dtype.value} value.",
            )
        elif isinstance(annotation, MetricLiteralParameterAnnotation):
            return MetricsLiteralParameterSchema(
                parameter_key=param_name,
                optional=optional,
                parameter_dtype=param_dtype,
                friendly_name=annotation.friendly_name,
                description=annotation.description,
            )
        elif isinstance(annotation, MetricDatasetParameterAnnotation):
            if param_dtype != DType.UUID:
                raise ValueError(
                    f"Dataset parameter {param_name} has type {param_dtype}, but should be a UUID.",
                )
            return MetricsDatasetParameterSchema(
                parameter_key=param_name,
                optional=optional,
                friendly_name=annotation.friendly_name,
                description=annotation.description,
                model_problem_type=annotation.model_problem_type,
            )
        elif isinstance(annotation, MetricColumnParameterAnnotation):
            if param_dtype != DType.STRING:
                raise ValueError(
                    f"Column parameter {param_name} has type {param_dtype}, but should be a string.",
                )
            return MetricsColumnParameterSchema(
                parameter_key=param_name,
                tag_hints=annotation.tag_hints,
                optional=optional,
                source_dataset_parameter_key=annotation.source_dataset_parameter_key,
                allowed_column_types=annotation.allowed_column_types,
                allow_any_column_type=annotation.allow_any_column_type,
                friendly_name=annotation.friendly_name,
                description=annotation.description,
            )
        else:
            raise ValueError(
                f"Parameter {param_name} has an unsupported annotation {annotation}.",
            )

    """
    Returns a list of parameter names, parameter types, and scope-specific annotations.
    """

    @staticmethod
    def _extract_parameter_metadata(func: Callable) -> list[MetricsParameterSchemaUnion]:  # type: ignore
        parameter_schemas: list[MetricsParameterSchemaUnion] = []
        args = inspect.signature(func).parameters
        for name, param in args.items():
            if name == "self":
                continue
            if name == "ddb_conn":
                continue

            if param.annotation == inspect.Parameter.empty:
                raise ValueError(
                    f"{func.__name__} must provide type annotation for parameter {name}.",
                )
            parameter_schemas.append(
                FunctionAnalyzer._get_scope_metric_parameter_from_annotation(
                    name,
                    FunctionAnalyzer._python_type_to_scope_dtype(
                        get_type_hints(func)[name],
                    ),
                    param.default != inspect.Parameter.empty,
                    (
                        FunctionAnalyzer._get_metric_annotation_from_annotated(
                            name,
                            param.annotation,
                        )
                        if typing.get_origin(param.annotation) is typing.Annotated
                        else None
                    ),
                ),
            )

        return parameter_schemas

    @staticmethod
    def analyze_aggregation_function(agg_func: type) -> AggregationSpecSchema:
        # Check if X is a subclass of AggregationFunction
        if not issubclass(agg_func, AggregationFunction):
            raise TypeError(
                f"Class {agg_func.__name__} is not a subclass of AggregationFunction.",
            )

        if issubclass(agg_func, NumericAggregationFunction):
            metric_type = AggregationMetricType.NUMERIC
        elif issubclass(agg_func, SketchAggregationFunction):
            metric_type = AggregationMetricType.SKETCH
        else:
            raise ValueError(
                f"Class {agg_func.__name__} is not a subclass of SketchAggregationFunction, NumericAggregationFunction.",
            )
        # Check if X implements the required methods
        required_methods = ["aggregate", "id", "description", "display_name"]
        static_methods = ["description", "id", "display_name"]
        for method in required_methods:
            if not hasattr(agg_func, method) or not callable(getattr(agg_func, method)):
                raise AttributeError(
                    f"Class {agg_func.__name__} does not implement {method} method.",
                )

        for method in static_methods:
            if not is_static_method(getattr(agg_func, method)):
                raise AttributeError(f"Method {method} should be a staticmethod.")
        # Check if X passes the ABC implementation:
        try:
            agg_func()
        except TypeError as e:
            if "Can't instantiate abstract class" in str(e):
                logger.error(str(e))
                raise TypeError(
                    f"Class {agg_func.__name__} does not implement all the base class functions.",
                )
            else:
                # This is okay, it just means we didn't supply proper args to the __init__ function. The ABC mismatch would throw before this, so it must have passed
                pass

        aggregation_init_args: list[MetricsParameterSchemaUnion] = []

        # This is necessary because all the way down in the ABC class, some __init__ function is defined which we don't care about. Users should be able to exclude an init function and this allows them to do that.
        if has_custom_init(agg_func):
            aggregation_init_args = FunctionAnalyzer._extract_parameter_metadata(
                agg_func.__init__,
            )
        aggregate_args = FunctionAnalyzer._extract_parameter_metadata(
            agg_func.aggregate,
        )

        aggregation_id = agg_func.id()
        aggregation_description = agg_func.description()

        return AggregationSpecSchema(
            name=agg_func.display_name(),
            id=aggregation_id,
            # TODO: Require description, version
            description=aggregation_description,
            # version=0,
            metric_type=metric_type,
            init_args=aggregation_init_args,
            aggregate_args=aggregate_args,
        )


def has_custom_init(cls: type) -> bool:
    init_method = getattr(cls, "__init__", None)
    base_init_method = (
        getattr(cls.__base__, "__init__", None) if hasattr(cls, "__base__") else None
    )
    return init_method is not base_init_method


def is_static_method(method: type) -> bool:
    if inspect.isfunction(method):
        # Check if the method accepts no arguments or only default arguments
        argspec = inspect.getfullargspec(method)
        if len(argspec.args) == 0 and not argspec.varargs and not argspec.varkw:
            return True
    return False
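For orientation, a minimal usage sketch of the analyzer. It assumes the CategoricalCountAggregationFunction import path quoted in the loader's dedup comment below, and it only prints fields that the returned AggregationSpecSchema is constructed with above; it is a sketch, not part of the package.

from arthur_common.aggregations.functions.categorical_count import (
    CategoricalCountAggregationFunction,
)
from arthur_common.tools.aggregation_analyzer import FunctionAnalyzer

# Introspect the class: base-class checks, staticmethod checks, then parameter
# extraction from type hints and Annotated metadata.
spec = FunctionAnalyzer.analyze_aggregation_function(CategoricalCountAggregationFunction)
print(spec.id, spec.name, spec.metric_type)
for arg in spec.init_args + spec.aggregate_args:
    print(arg.parameter_key, arg.optional)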

arthur_common/tools/aggregation_loader.py
@@ -0,0 +1,59 @@
import inspect
import logging
from types import ModuleType
from typing import Type

import arthur_common.aggregations as agg_module
from arthur_common.aggregations.aggregator import (
    AggregationFunction,
    NumericAggregationFunction,
    SketchAggregationFunction,
)
from arthur_common.models.metrics import AggregationSpecSchema
from arthur_common.tools.aggregation_analyzer import FunctionAnalyzer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class AggregationLoader:
    @staticmethod
    def load_aggregations() -> (
        list[tuple[AggregationSpecSchema, Type[AggregationFunction]]]
    ):
        def find_subclasses(
            module: ModuleType,
            base_classes: tuple[type, ...],
            visited: set[ModuleType] = set(),
        ) -> set[type]:
            subclasses = set()
            visited.add(module)
            for name, obj in inspect.getmembers(module):
                if inspect.isclass(obj) and issubclass(obj, base_classes):
                    subclasses.add(obj)
                elif inspect.ismodule(obj) and obj not in visited:
                    subclasses.update(find_subclasses(obj, base_classes, visited))
            return subclasses

        base_classes = (SketchAggregationFunction, NumericAggregationFunction)
        agg_functions = find_subclasses(agg_module, base_classes)
        aggregation_specs = []
        """
        This seems to pick up duplicate functions somehow in different namespaces, ie:
        <class 'categorical_count.CategoricalCountAggregationFunction'>
        and
        <class 'arthur_common.aggregations.functions.categorical_count.CategoricalCountAggregationFunction'>
        so dedupe by id
        """
        aggregation_ids = set()
        for agg_function in agg_functions:
            try:
                func_spec = FunctionAnalyzer.analyze_aggregation_function(agg_function)
                logger.info(f"Found agg function {agg_function}")
                if func_spec.id in aggregation_ids:
                    continue
                aggregation_specs.append((func_spec, agg_function))
                aggregation_ids.add(func_spec.id)
            except Exception as e:
                logger.error(f"Failed to load aggregation function {agg_function}: {e}")
        return aggregation_specs
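A minimal sketch of driving the loader; the printed fields follow the (spec, class) tuples returned above.

from arthur_common.tools.aggregation_loader import AggregationLoader

# Each entry pairs an AggregationSpecSchema with the class that implements it.
for spec, func_cls in AggregationLoader.load_aggregations():
    print(spec.id, func_cls.__name__)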

arthur_common/tools/duckdb_data_loader.py
@@ -0,0 +1,329 @@
import json
from typing import Any

import duckdb
import pandas as pd
from arthur_common.models.datasets import DatasetJoinKind
from arthur_common.models.schema_definitions import (
    DatasetListType,
    DatasetObjectType,
    DatasetScalarType,
    DatasetSchema,
    DType,
)
from dateutil.parser import parse
from fsspec import filesystem
from pydantic import BaseModel


class ColumnFormat(BaseModel):
    source_name: str
    alias: str
    format: str


class DuckDBOperator:
    """
    Loads data into a DuckDB table.

    If no schema is supplied, the output table will contain columns with names equal to the source names in the data.
    If a schema is applied, the column names in the output table will be aliases equal to the column id from the schema.
    This allows for consistent column naming across different data sources.
    """

    @staticmethod
    def load_data_to_duckdb(
        data: list[dict[str, Any]] | pd.DataFrame,
        preprocess_schema: bool = False,
        table_name: str = "inferences",
        conn: duckdb.DuckDBPyConnection | None = None,
        schema: DatasetSchema | None = None,
    ) -> duckdb.DuckDBPyConnection:
        if not conn:
            conn = duckdb.connect()

        if type(data) == list:
            DuckDBOperator._load_unstructured_data(data, table_name, conn, schema)
        elif type(data) == pd.DataFrame:
            DuckDBOperator._load_structured_data(
                data,
                preprocess_schema,
                table_name,
                conn,
                schema,
            )
        else:
            raise ValueError(f"Unsupported data type: {type(data)}")
        return conn

    """
    Rename columns from ids to their friendly names based on schema.column_names
    """

    @staticmethod
    def apply_alias_mask(
        table_name: str,
        conn: duckdb.DuckDBPyConnection,
        schema: DatasetSchema,
    ) -> None:
        old_new_mask = {
            str(col_id): schema.column_names[col_id] for col_id in schema.column_names
        }
        DuckDBOperator._apply_alias_mask(table_name, conn, old_new_mask)

    @staticmethod
    def _apply_alias_mask(
        table_name: str,
        conn: duckdb.DuckDBPyConnection,
        old_new_mask: dict[str, str],
    ) -> None:
        for old, new in old_new_mask.items():
            # Don't quote the join column names, since they're already quoted as part of escape_identifier's output
            alter_query = f"ALTER TABLE {table_name} RENAME COLUMN {escape_identifier(old)} TO {escape_identifier(new)}"
            conn.sql(alter_query)

    @staticmethod
    def _load_unstructured_data(
        data: list[dict[str, Any]],
        table_name: str,
        conn: duckdb.DuckDBPyConnection,
        schema: DatasetSchema | None,
    ) -> None:
        with filesystem("memory").open(f"inferences.json", "w") as file:
            file.write(json.dumps(data))
        conn.register_filesystem(filesystem("memory"))

        if schema:
            column_formats = make_duckdb_dataset_schema(schema)

            key_value_pairs = [
                f"{escape_identifier(col.source_name)}: '{col.format}'"
                for col in column_formats
            ]
            stringified_schema = ", ".join([f"{kv}" for kv in key_value_pairs])
            stringified_schema = f"{{ {stringified_schema} }}"

            read_stmt = f"read_json('memory://inferences.json', format='array', columns={stringified_schema})"
        else:
            read_stmt = "read_json_auto('memory://inferences.json')"

        conn.sql(
            f"CREATE OR REPLACE TEMP TABLE {table_name} AS SELECT * FROM {read_stmt}",
        )

        if schema:
            old_new_mask = {}
            for col in column_formats:
                old_new_mask[col.source_name] = col.alias
            DuckDBOperator._apply_alias_mask(table_name, conn, old_new_mask)

    @staticmethod
    def _load_structured_data(
        data: pd.DataFrame,
        preprocess_schema: bool,
        table_name: str,
        conn: duckdb.DuckDBPyConnection,
        schema: DatasetSchema | None,
    ) -> None:
        if preprocess_schema:
            data = DuckDBOperator._preprocess_dataframe_schema_inference(data)

        if schema:
            column_formats = make_duckdb_dataset_schema(schema)

            key_value_pairs = [
                f"{escape_identifier(col.source_name)} {col.format}"
                for col in column_formats
            ]
            stringified_schema = ", ".join([f"{kv}" for kv in key_value_pairs])
            create_table_stmt = (
                f"CREATE OR REPLACE TEMP TABLE {table_name} ({stringified_schema});"
            )
            conn.sql(create_table_stmt)
            conn.sql(f"INSERT INTO {table_name} SELECT * FROM data")

            old_new_mask = {}
            for col in column_formats:
                old_new_mask[col.source_name] = col.alias
            DuckDBOperator._apply_alias_mask(table_name, conn, old_new_mask)

        else:
            conn.sql(f"CREATE OR REPLACE TEMP TABLE {table_name} AS SELECT * FROM data")

    """
    Preprocess to make smarter type inferences. Pandas and json recognize very little beyond primitives out of the box. We can support a little more with a little effort like:
    1. Datetimes

    Modifies the input data in place to have smarter types than pandas will natively infer
    """

    @staticmethod
    def _preprocess_dataframe_schema_inference(data: pd.DataFrame) -> pd.DataFrame:
        datetime_columns = _infer_dataframe_datetime_columns(data)
        for column in datetime_columns:
            try:
                data[column] = pd.to_datetime(data[column])
            except Exception:
                # we're using best-effort to infer datetime columns, but just in case we got it wrong, move on
                continue

        return data

    @staticmethod
    def join_tables(
        conn: duckdb.DuckDBPyConnection,
        table_name: str,
        table_1: str,
        table_2: str,
        table_1_join_key: str,
        table_2_join_key: str,
        join_kind: DatasetJoinKind = DatasetJoinKind.INNER,
    ) -> None:
        match join_kind:
            case DatasetJoinKind.INNER:
                join = "INNER"
            case DatasetJoinKind.LEFT_OUTER:
                join = "LEFT"
            case DatasetJoinKind.RIGHT_OUTER:
                join = "RIGHT"
            case DatasetJoinKind.OUTER:
                join = "FULL OUTER"
            case _:
                raise NotImplementedError(f"Join kind {join_kind} is not supported.")

        # Don't quote the join column names, since they're already quoted as part of escape_identifier's output
        join_query = f"""
            CREATE TABLE {table_name} AS
            SELECT *
            FROM {table_1} a
            {join} JOIN {table_2} b
            ON a.{escape_identifier(table_1_join_key)} = b.{escape_identifier(table_2_join_key)}
        """

        conn.sql(join_query)
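Before the module-level helpers, a minimal usage sketch of the operator. The DataFrames, table names, and the "inference_id" key are illustrative only and not taken from the package.

import pandas as pd

from arthur_common.models.datasets import DatasetJoinKind
from arthur_common.tools.duckdb_data_loader import DuckDBOperator

# Illustrative data: two frames sharing a hypothetical "inference_id" key.
preds = pd.DataFrame({"inference_id": [1, 2], "score": [0.4, 0.9]})
labels = pd.DataFrame({"inference_id": [1, 2], "label": [0, 1]})

# Load both frames into the same connection, then join them into a new table.
conn = DuckDBOperator.load_data_to_duckdb(preds, table_name="preds")
DuckDBOperator.load_data_to_duckdb(labels, table_name="labels", conn=conn)
DuckDBOperator.join_tables(
    conn,
    table_name="joined",
    table_1="preds",
    table_2="labels",
    table_1_join_key="inference_id",
    table_2_join_key="inference_id",
    join_kind=DatasetJoinKind.INNER,
)
print(conn.sql("SELECT * FROM joined").fetchall())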
arthur_common/tools/duckdb_data_loader.py (continued)

def _infer_dataframe_datetime_columns(df: pd.DataFrame, n: int = 100) -> list[str]:
    """
    Infer datetime columns in a pandas DataFrame by parsing non-null values in the first n rows. Return the column names believed to be datetime type.

    Parameters:
        df (pandas.DataFrame): Input DataFrame.
        n (int): Number of non-null rows to consider for each column. Default is 100.

    Returns:
        datetime_columns (list): List of column names inferred to be datetime.
    """
    datetime_columns = []

    for column in df.columns:
        non_null_values = df[column].dropna().head(n)
        if non_null_values.empty:
            continue

        # Try parsing each non-null value in the column
        try:
            parsed_values = non_null_values.apply(lambda x: parse(x))

            # If parsing succeeds for all values, consider the column as datetime
            if parsed_values.notnull().all():
                datetime_columns.append(column)
        except:
            # If parsing fails for any value, move to the next column
            continue

    return datetime_columns


"""
Returns a list of ColumnFormat. Depending on structured / unstructured data, we need to format the root columns differently, so return the raw forms.

See the subtle differences between

CREATE TABLE users (
    userID BIGINT,
    userName VARCHAR,
    hobbies ARRAY<VARCHAR>
);

and

SELECT *
FROM read_json('todos.json',
               format = 'array',
               columns = {userId: 'UBIGINT',
                          userName: 'VARCHAR',
                          hobbies: 'ARRAY<VARCHAR>'});

"""


def make_duckdb_dataset_schema(schema: DatasetSchema) -> list[ColumnFormat]:
    details = []
    for col in schema.columns:
        format = _make_schema(col.definition)
        details.append(
            ColumnFormat(source_name=col.source_name, alias=str(col.id), format=format),
        )

    return details


def _make_schema(
    schema_node: DatasetObjectType | DatasetListType | DatasetScalarType,
) -> str:
    if isinstance(schema_node, DatasetObjectType):
        details = {}
        for col, value in schema_node.object.items():
            details[col] = _make_schema(value)
        key_value_pairs = [
            f"{escape_identifier(col)} {value}" for col, value in details.items()
        ]
        return f"STRUCT({', '.join(key_value_pairs)})"

    elif isinstance(schema_node, DatasetListType):
        return f"{_make_schema(schema_node.items)}[]"
    elif isinstance(schema_node, DatasetScalarType):
        match schema_node.dtype:
            case DType.INT:
                return "BIGINT"
            case DType.FLOAT:
                return "DOUBLE"
            case DType.BOOL:
                return "BOOLEAN"
            case DType.STRING:
                return "VARCHAR"
            case DType.UUID:
                return "UUID"
            case DType.TIMESTAMP:
                return "TIMESTAMP"
            case DType.JSON:
                return "JSON"
            case _:
                raise ValueError(f"Unknown mapping for DType {schema_node.dtype}")
    else:
        raise NotImplementedError(
            f"Schema conversion not implemented for node type {type(schema_node)}",
        )


def escape_identifier(identifier: str) -> str:
    """
    Escape an identifier (e.g., column name) for use in a SQL query.
    This method handles special characters and ensures proper quoting.
    """
    # Replace any double quotes with two double quotes
    escaped = identifier.replace('"', '""')
    # Wrap the entire identifier in double quotes and return
    return f'"{escaped}"'


def escape_str_literal(literal: str) -> str:
    """
    Escape a duckDB string literal for use in a SQL query.
    https://duckdb.org/docs/stable/sql/data_types/literal_types.html#escape-string-literals
    """
    # replace any single quotes with two single quotes
    escaped = literal.replace("'", "''")
    # Wrap the entire literal in single quotes and return
    return f"'{escaped}'"
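A tiny sketch of the quoting helpers' behavior; the input strings are illustrative.

from arthur_common.tools.duckdb_data_loader import escape_identifier, escape_str_literal

print(escape_identifier('total "count"'))  # "total ""count""" - safe to use as a column name
print(escape_str_literal("it's"))          # 'it''s'           - safe to use as a string literal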

arthur_common/tools/functions.py
@@ -0,0 +1,46 @@
import hashlib
from uuid import UUID

"""
Convert a uuid to a 12 character, all lowercase string deterministically
"""


def uuid_to_base26(uuid: str | UUID) -> str:
    # Remove hyphens from the UUID
    if isinstance(uuid, UUID):
        uuid_str = str(uuid)
    else:
        uuid_str = uuid

    no_hyphens = uuid_str.replace("-", "")

    # Convert the hex string to an integer
    num = int(no_hyphens, 16)

    # Define the alphabet for base-26
    alphabet = "abcdefghijklmnopqrstuvwxyz"
    base = len(alphabet)

    # Encode the integer into a base-26 string
    base26_str = ""
    while num > 0:
        num, rem = divmod(num, base)
        base26_str = alphabet[rem] + base26_str

    # Ensure the string is 12 characters long (pad with 'a' if necessary)
    base26_str = base26_str.rjust(12, "a")

    # If the resulting string is longer than 12 characters, take the last 12 characters
    return base26_str[-12:]


"""
Hash a string
"""


def hash_nonce(nonce: str) -> str:
    md5_hash = hashlib.md5()
    md5_hash.update(nonce.encode("utf-8"))
    return md5_hash.hexdigest()
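A minimal sketch of the two helpers; the UUID and nonce values are arbitrary.

from uuid import UUID

from arthur_common.tools.functions import hash_nonce, uuid_to_base26

u = UUID("12345678-1234-5678-1234-567812345678")
print(uuid_to_base26(u))       # deterministic 12-character lowercase a-z string
print(uuid_to_base26(str(u)))  # same output for the string form of the same UUID
print(hash_nonce("example"))   # 32-character hex MD5 digest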