arthur-common 1.0.1 (arthur_common-1.0.1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of arthur-common might be problematic.

Files changed (40)
  1. arthur_common/__init__.py +0 -0
  2. arthur_common/__version__.py +1 -0
  3. arthur_common/aggregations/__init__.py +2 -0
  4. arthur_common/aggregations/aggregator.py +214 -0
  5. arthur_common/aggregations/functions/README.md +26 -0
  6. arthur_common/aggregations/functions/__init__.py +25 -0
  7. arthur_common/aggregations/functions/categorical_count.py +89 -0
  8. arthur_common/aggregations/functions/confusion_matrix.py +412 -0
  9. arthur_common/aggregations/functions/inference_count.py +69 -0
  10. arthur_common/aggregations/functions/inference_count_by_class.py +206 -0
  11. arthur_common/aggregations/functions/inference_null_count.py +82 -0
  12. arthur_common/aggregations/functions/mean_absolute_error.py +110 -0
  13. arthur_common/aggregations/functions/mean_squared_error.py +110 -0
  14. arthur_common/aggregations/functions/multiclass_confusion_matrix.py +205 -0
  15. arthur_common/aggregations/functions/multiclass_inference_count_by_class.py +90 -0
  16. arthur_common/aggregations/functions/numeric_stats.py +90 -0
  17. arthur_common/aggregations/functions/numeric_sum.py +87 -0
  18. arthur_common/aggregations/functions/py.typed +0 -0
  19. arthur_common/aggregations/functions/shield_aggregations.py +752 -0
  20. arthur_common/aggregations/py.typed +0 -0
  21. arthur_common/models/__init__.py +0 -0
  22. arthur_common/models/connectors.py +41 -0
  23. arthur_common/models/datasets.py +22 -0
  24. arthur_common/models/metrics.py +227 -0
  25. arthur_common/models/py.typed +0 -0
  26. arthur_common/models/schema_definitions.py +420 -0
  27. arthur_common/models/shield.py +504 -0
  28. arthur_common/models/task_job_specs.py +78 -0
  29. arthur_common/py.typed +0 -0
  30. arthur_common/tools/__init__.py +0 -0
  31. arthur_common/tools/aggregation_analyzer.py +243 -0
  32. arthur_common/tools/aggregation_loader.py +59 -0
  33. arthur_common/tools/duckdb_data_loader.py +329 -0
  34. arthur_common/tools/functions.py +46 -0
  35. arthur_common/tools/py.typed +0 -0
  36. arthur_common/tools/schema_inferer.py +104 -0
  37. arthur_common/tools/time_utils.py +33 -0
  38. arthur_common-1.0.1.dist-info/METADATA +74 -0
  39. arthur_common-1.0.1.dist-info/RECORD +40 -0
  40. arthur_common-1.0.1.dist-info/WHEEL +4 -0
arthur_common/tools/aggregation_analyzer.py
@@ -0,0 +1,243 @@
+ import inspect
+ import logging
+ import typing
+ import uuid
+ from typing import Any, Callable, get_type_hints
+
+ from arthur_common.aggregations import (
+     AggregationFunction,
+     NumericAggregationFunction,
+     SketchAggregationFunction,
+ )
+ from arthur_common.models.metrics import (
+     AggregationMetricType,
+     AggregationSpecSchema,
+     DatasetReference,
+     MetricsColumnParameterSchema,
+     MetricsDatasetParameterSchema,
+     MetricsLiteralParameterSchema,
+     MetricsParameterSchemaUnion,
+ )
+ from arthur_common.models.schema_definitions import (
+     DType,
+     MetricColumnParameterAnnotation,
+     MetricDatasetParameterAnnotation,
+     MetricLiteralParameterAnnotation,
+     MetricsParameterAnnotationUnion,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class FunctionAnalyzer:
+     @staticmethod
+     def _python_type_to_scope_dtype(t: Any) -> DType:
+         if t is int:
+             return DType.INT
+         elif t is str:
+             return DType.STRING
+         elif t is bool:
+             return DType.BOOL
+         elif t is float:
+             return DType.FLOAT
+         elif t is uuid.UUID:
+             return DType.UUID
+         elif t is DatasetReference:
+             return DType.UUID
+         else:
+             raise ValueError(f"Parameter type {t} is not supported.")
+
+     @staticmethod
+     def _get_metric_annotation_from_annotated(
+         param_name: str,
+         annotation: typing.Annotated,  # type: ignore
+     ) -> MetricsParameterAnnotationUnion | None:
+         arthur_metric_annotations = [
+             m
+             for m in annotation.__metadata__
+             if isinstance(m, MetricsParameterAnnotationUnion)
+         ]
+         metric_annotation: MetricsParameterAnnotationUnion | None = None
+         if len(arthur_metric_annotations) == 1:
+             metric_annotation = arthur_metric_annotations[0]
+         if len(arthur_metric_annotations) > 1:
+             raise ValueError(
+                 f"Parameter {param_name} defines more than one metric annotation.",
+             )
+         return metric_annotation
+
+     @staticmethod
+     def _get_scope_metric_parameter_from_annotation(
+         param_name: str,
+         param_dtype: DType,
+         optional: bool,
+         annotation: typing.Annotated,  # type: ignore
+     ) -> MetricsParameterSchemaUnion:
+         if annotation is None:
+             return MetricsLiteralParameterSchema(
+                 parameter_key=param_name,
+                 optional=optional,
+                 parameter_dtype=param_dtype,
+                 friendly_name=param_name,
+                 description=f"A {param_dtype.value} value.",
+             )
+         elif isinstance(annotation, MetricLiteralParameterAnnotation):
+             return MetricsLiteralParameterSchema(
+                 parameter_key=param_name,
+                 optional=optional,
+                 parameter_dtype=param_dtype,
+                 friendly_name=annotation.friendly_name,
+                 description=annotation.description,
+             )
+         elif isinstance(annotation, MetricDatasetParameterAnnotation):
+             if param_dtype != DType.UUID:
+                 raise ValueError(
+                     f"Dataset parameter {param_name} has type {param_dtype}, but should be a UUID.",
+                 )
+             return MetricsDatasetParameterSchema(
+                 parameter_key=param_name,
+                 optional=optional,
+                 friendly_name=annotation.friendly_name,
+                 description=annotation.description,
+                 model_problem_type=annotation.model_problem_type,
+             )
+         elif isinstance(annotation, MetricColumnParameterAnnotation):
+             if param_dtype != DType.STRING:
+                 raise ValueError(
+                     f"Column parameter {param_name} has type {param_dtype}, but should be a string.",
+                 )
+             return MetricsColumnParameterSchema(
+                 parameter_key=param_name,
+                 tag_hints=annotation.tag_hints,
+                 optional=optional,
+                 source_dataset_parameter_key=annotation.source_dataset_parameter_key,
+                 allowed_column_types=annotation.allowed_column_types,
+                 allow_any_column_type=annotation.allow_any_column_type,
+                 friendly_name=annotation.friendly_name,
+                 description=annotation.description,
+             )
+         else:
+             raise ValueError(
+                 f"Parameter {param_name} has an unsupported annotation {annotation}.",
+             )
+
+     """
+     Returns a list of parameter names, parameter types, scope-specific annotations.
+     """
+
+     @staticmethod
+     def _extract_parameter_metadata(func: Callable) -> list[MetricsParameterSchemaUnion]:  # type: ignore
+         parameter_schemas: list[MetricsParameterSchemaUnion] = []
+         args = inspect.signature(func).parameters
+         for name, param in args.items():
+             if name == "self":
+                 continue
+             if name == "ddb_conn":
+                 continue
+
+             if param.annotation == inspect.Parameter.empty:
+                 raise ValueError(
+                     f"{func.__name__} must provide type annotation for parameter {name}.",
+                 )
+             parameter_schemas.append(
+                 FunctionAnalyzer._get_scope_metric_parameter_from_annotation(
+                     name,
+                     FunctionAnalyzer._python_type_to_scope_dtype(
+                         get_type_hints(func)[name],
+                     ),
+                     param.default != inspect.Parameter.empty,
+                     (
+                         FunctionAnalyzer._get_metric_annotation_from_annotated(
+                             name,
+                             param.annotation,
+                         )
+                         if typing.get_origin(param.annotation) is typing.Annotated
+                         else None
+                     ),
+                 ),
+             )
+
+         return parameter_schemas
+
+     @staticmethod
+     def analyze_aggregation_function(agg_func: type) -> AggregationSpecSchema:
+         # Check if X is a subclass of AggregationFunction
+         if not issubclass(agg_func, AggregationFunction):
+             raise TypeError(
+                 f"Class {agg_func.__name__} is not a subclass of AggregationFunction.",
+             )
+
+         if issubclass(agg_func, NumericAggregationFunction):
+             metric_type = AggregationMetricType.NUMERIC
+         elif issubclass(agg_func, SketchAggregationFunction):
+             metric_type = AggregationMetricType.SKETCH
+         else:
+             raise ValueError(
+                 f"Class {agg_func.__name__} is not a subclass of SketchAggregationFunction, NumericAggregationFunction.",
+             )
+         # Check if X implements the required methods
+         required_methods = ["aggregate", "id", "description", "display_name"]
+         static_methods = ["description", "id", "display_name"]
+         for method in required_methods:
+             if not hasattr(agg_func, method) or not callable(getattr(agg_func, method)):
+                 raise AttributeError(
+                     f"Class {agg_func.__name__} does not implement {method} method.",
+                 )
+
+         for method in static_methods:
+             if not is_static_method(getattr(agg_func, method)):
+                 raise AttributeError(f"Method {method} should be a staticmethod.")
+         # Check if X passes the ABC implementation:
+         try:
+             agg_func()
+         except TypeError as e:
+             if "Can't instantiate abstract class" in str(e):
+                 logger.error(str(e))
+                 raise TypeError(
+                     f"Class {agg_func.__name__} does not implement all the base class functions.",
+                 )
+             else:
+                 # This is okay, it just means we didn't supply proper args to the __init__ function. The ABC mismatch would throw before this, so it must have passed
+                 pass
+
+         aggregation_init_args: list[MetricsParameterSchemaUnion] = []
+
+         # This is necessary because all the way down in the ABC class, some __init__ function is defined which we don't care about. Users should be able to exclude an init function and this allows them to do that.
+         if has_custom_init(agg_func):
+             aggregation_init_args = FunctionAnalyzer._extract_parameter_metadata(
+                 agg_func.__init__,
+             )
+         aggregate_args = FunctionAnalyzer._extract_parameter_metadata(
+             agg_func.aggregate,
+         )
+
+         aggregation_id = agg_func.id()
+         aggregation_description = agg_func.description()
+
+         return AggregationSpecSchema(
+             name=agg_func.display_name(),
+             id=aggregation_id,
+             # TODO: Require description, version
+             description=aggregation_description,
+             # version=0,
+             metric_type=metric_type,
+             init_args=aggregation_init_args,
+             aggregate_args=aggregate_args,
+         )
+
+
+ def has_custom_init(cls: type) -> bool:
+     init_method = getattr(cls, "__init__", None)
+     base_init_method = (
+         getattr(cls.__base__, "__init__", None) if hasattr(cls, "__base__") else None
+     )
+     return init_method is not base_init_method
+
+
+ def is_static_method(method: type) -> bool:
+     if inspect.isfunction(method):
+         # Check if the method accepts no arguments or only default arguments
+         argspec = inspect.getfullargspec(method)
+         if len(argspec.args) == 0 and not argspec.varargs and not argspec.varkw:
+             return True
+     return False
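
To make the two module-level helpers at the end of aggregation_analyzer.py concrete, here is a minimal, self-contained sketch of how they classify classes and methods; the Base, Child, and Plain classes are hypothetical stand-ins for illustration and are not part of arthur-common.

from arthur_common.tools.aggregation_analyzer import has_custom_init, is_static_method

class Base:
    def __init__(self, window: int = 60) -> None:
        self.window = window

class Child(Base):
    def __init__(self) -> None:  # defines its own __init__
        super().__init__()

    @staticmethod
    def display_name() -> str:  # zero-argument staticmethod
        return "child"

    def aggregate(self) -> None:  # instance method, takes self
        ...

class Plain(Base):  # inherits Base.__init__ unchanged
    pass

print(has_custom_init(Child))  # True: Child overrides __init__
print(has_custom_init(Plain))  # False: Plain shares Base.__init__
print(is_static_method(getattr(Child, "display_name")))  # True: plain function with no parameters
print(is_static_method(getattr(Child, "aggregate")))     # False: signature includes self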
arthur_common/tools/aggregation_loader.py
@@ -0,0 +1,59 @@
+ import inspect
+ import logging
+ from types import ModuleType
+ from typing import Type
+
+ import arthur_common.aggregations as agg_module
+ from arthur_common.aggregations.aggregator import (
+     AggregationFunction,
+     NumericAggregationFunction,
+     SketchAggregationFunction,
+ )
+ from arthur_common.models.metrics import AggregationSpecSchema
+ from arthur_common.tools.aggregation_analyzer import FunctionAnalyzer
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ class AggregationLoader:
+     @staticmethod
+     def load_aggregations() -> (
+         list[tuple[AggregationSpecSchema, Type[AggregationFunction]]]
+     ):
+         def find_subclasses(
+             module: ModuleType,
+             base_classes: tuple[type, ...],
+             visited: set[ModuleType] = set(),
+         ) -> set[type]:
+             subclasses = set()
+             visited.add(module)
+             for name, obj in inspect.getmembers(module):
+                 if inspect.isclass(obj) and issubclass(obj, base_classes):
+                     subclasses.add(obj)
+                 elif inspect.ismodule(obj) and obj not in visited:
+                     subclasses.update(find_subclasses(obj, base_classes, visited))
+             return subclasses
+
+         base_classes = (SketchAggregationFunction, NumericAggregationFunction)
+         agg_functions = find_subclasses(agg_module, base_classes)
+         aggregation_specs = []
+         """
+         This seems to pick up duplicate functions somehow in different namespaces, ie:
+         <class 'categorical_count.CategoricalCountAggregationFunction'>
+         and
+         <class 'arthur_common.aggregations.functions.categorical_count.CategoricalCountAggregationFunction'>
+         so dedupe by id
+         """
+         aggregation_ids = set()
+         for agg_function in agg_functions:
+             try:
+                 func_spec = FunctionAnalyzer.analyze_aggregation_function(agg_function)
+                 logger.info(f"Found agg function {agg_function}")
+                 if func_spec.id in aggregation_ids:
+                     continue
+                 aggregation_specs.append((func_spec, agg_function))
+                 aggregation_ids.add(func_spec.id)
+             except Exception as e:
+                 logger.error(f"Failed to load aggregation function {agg_function}: {e}")
+         return aggregation_specs
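
As a usage sketch (assuming the package and its bundled aggregation functions are importable), the loader's output can be enumerated to see which aggregations were discovered and which parameters each declares; the attribute names used below (id, name, metric_type, aggregate_args, parameter_key) are the fields populated by FunctionAnalyzer above.

from arthur_common.tools.aggregation_loader import AggregationLoader

# Each entry pairs an AggregationSpecSchema with the class that implements it.
for spec, agg_cls in AggregationLoader.load_aggregations():
    print(spec.id, spec.name, spec.metric_type, agg_cls.__name__)
    for param in spec.aggregate_args:
        # Every parameter schema variant carries a parameter_key.
        print("  aggregate arg:", param.parameter_key)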
arthur_common/tools/duckdb_data_loader.py
@@ -0,0 +1,329 @@
+ import json
+ from typing import Any
+
+ import duckdb
+ import pandas as pd
+ from arthur_common.models.datasets import DatasetJoinKind
+ from arthur_common.models.schema_definitions import (
+     DatasetListType,
+     DatasetObjectType,
+     DatasetScalarType,
+     DatasetSchema,
+     DType,
+ )
+ from dateutil.parser import parse
+ from fsspec import filesystem
+ from pydantic import BaseModel
+
+
+ class ColumnFormat(BaseModel):
+     source_name: str
+     alias: str
+     format: str
+
+
+ class DuckDBOperator:
+     """
+     Loads data into a DuckDB table.
+
+     If no schema is supplied, the output table will contain columns with names equal to the source names in the data.
+     If a schema is applied, the column names in the output table will be aliases equal to the column id from the schema.
+     This allows for consistent column naming across different data sources.
+     """
+
+     @staticmethod
+     def load_data_to_duckdb(
+         data: list[dict[str, Any]] | pd.DataFrame,
+         preprocess_schema: bool = False,
+         table_name: str = "inferences",
+         conn: duckdb.DuckDBPyConnection | None = None,
+         schema: DatasetSchema | None = None,
+     ) -> duckdb.DuckDBPyConnection:
+         if not conn:
+             conn = duckdb.connect()
+
+         if type(data) == list:
+             DuckDBOperator._load_unstructured_data(data, table_name, conn, schema)
+         elif type(data) == pd.DataFrame:
+             DuckDBOperator._load_structured_data(
+                 data,
+                 preprocess_schema,
+                 table_name,
+                 conn,
+                 schema,
+             )
+         else:
+             raise ValueError(f"Unsupported data type: {type(data)}")
+         return conn
+
+     """
+     Rename columns from ids to their friendly names based on schema.column_names
+     """
+
+     @staticmethod
+     def apply_alias_mask(
+         table_name: str,
+         conn: duckdb.DuckDBPyConnection,
+         schema: DatasetSchema,
+     ) -> None:
+         old_new_mask = {
+             str(col_id): schema.column_names[col_id] for col_id in schema.column_names
+         }
+         DuckDBOperator._apply_alias_mask(table_name, conn, old_new_mask)
+
+     @staticmethod
+     def _apply_alias_mask(
+         table_name: str,
+         conn: duckdb.DuckDBPyConnection,
+         old_new_mask: dict[str, str],
+     ) -> None:
+         for old, new in old_new_mask.items():
+             # Don't quote the join column names, since they're already quoted as part of escape_identifier's output
+             alter_query = f"ALTER TABLE {table_name} RENAME COLUMN {escape_identifier(old)} TO {escape_identifier(new)}"
+             conn.sql(alter_query)
+
+     @staticmethod
+     def _load_unstructured_data(
+         data: list[dict[str, Any]],
+         table_name: str,
+         conn: duckdb.DuckDBPyConnection,
+         schema: DatasetSchema | None,
+     ) -> None:
+         with filesystem("memory").open(f"inferences.json", "w") as file:
+             file.write(json.dumps(data))
+         conn.register_filesystem(filesystem("memory"))
+
+         if schema:
+             column_formats = make_duckdb_dataset_schema(schema)
+
+             key_value_pairs = [
+                 f"{escape_identifier(col.source_name)}: '{col.format}'"
+                 for col in column_formats
+             ]
+             stringified_schema = ", ".join([f"{kv}" for kv in key_value_pairs])
+             stringified_schema = f"{{ {stringified_schema} }}"
+
+             read_stmt = f"read_json('memory://inferences.json', format='array', columns={stringified_schema})"
+         else:
+             read_stmt = "read_json_auto('memory://inferences.json')"
+
+         conn.sql(
+             f"CREATE OR REPLACE TEMP TABLE {table_name} AS SELECT * FROM {read_stmt}",
+         )
+
+         if schema:
+             old_new_mask = {}
+             for col in column_formats:
+                 old_new_mask[col.source_name] = col.alias
+             DuckDBOperator._apply_alias_mask(table_name, conn, old_new_mask)
+
+     @staticmethod
+     def _load_structured_data(
+         data: pd.DataFrame,
+         preprocess_schema: bool,
+         table_name: str,
+         conn: duckdb.DuckDBPyConnection,
+         schema: DatasetSchema | None,
+     ) -> None:
+         if preprocess_schema:
+             data = DuckDBOperator._preprocess_dataframe_schema_inference(data)
+
+         if schema:
+             column_formats = make_duckdb_dataset_schema(schema)
+
+             key_value_pairs = [
+                 f"{escape_identifier(col.source_name)} {col.format}"
+                 for col in column_formats
+             ]
+             stringified_schema = ", ".join([f"{kv}" for kv in key_value_pairs])
+             create_table_stmt = (
+                 f"CREATE OR REPLACE TEMP TABLE {table_name} ({stringified_schema});"
+             )
+             conn.sql(create_table_stmt)
+             conn.sql(f"INSERT INTO {table_name} SELECT * FROM data")
+
+             old_new_mask = {}
+             for col in column_formats:
+                 old_new_mask[col.source_name] = col.alias
+             DuckDBOperator._apply_alias_mask(table_name, conn, old_new_mask)
+
+         else:
+             conn.sql(f"CREATE OR REPLACE TEMP TABLE {table_name} AS SELECT * FROM data")
+
+     """
+     Preprocess to make smarter type inferences. Pandas and json recognize very little beyond primitives out of the box. We can support a little more with a little effort like:
+     1. Datetimes
+
+     Modifies the input data in place to have smarter types than pandas will natively infer
+     """
+
+     @staticmethod
+     def _preprocess_dataframe_schema_inference(data: pd.DataFrame) -> pd.DataFrame:
+         datetime_columns = _infer_dataframe_datetime_columns(data)
+         for column in datetime_columns:
+             try:
+                 data[column] = pd.to_datetime(data[column])
+             except Exception:
+                 # we're using best-effort to infer datetime columns, but just in case we got it wrong, move on
+                 continue
+
+         return data
+
+     @staticmethod
+     def join_tables(
+         conn: duckdb.DuckDBPyConnection,
+         table_name: str,
+         table_1: str,
+         table_2: str,
+         table_1_join_key: str,
+         table_2_join_key: str,
+         join_kind: DatasetJoinKind = DatasetJoinKind.INNER,
+     ) -> None:
+         match join_kind:
+             case DatasetJoinKind.INNER:
+                 join = "INNER"
+             case DatasetJoinKind.LEFT_OUTER:
+                 join = "LEFT"
+             case DatasetJoinKind.RIGHT_OUTER:
+                 join = "RIGHT"
+             case DatasetJoinKind.OUTER:
+                 join = "FULL OUTER"
+             case _:
+                 raise NotImplementedError(f"Join kind {join_kind} is not supported.")
+
+         # Don't quote the join column names, since they're already quoted as part of escape_identifier's output
+         join_query = f"""
+             CREATE TABLE {table_name} AS
+             SELECT *
+             FROM {table_1} a
+             {join} JOIN {table_2} b
+             ON a.{escape_identifier(table_1_join_key)} = b.{escape_identifier(table_2_join_key)}
+         """
+
+         conn.sql(join_query)
+
+
+ def _infer_dataframe_datetime_columns(df: pd.DataFrame, n: int = 100) -> list[str]:
+     """
+     Infer datetime columns in a pandas DataFrame by parsing non-null values in the first n rows. Return the column names believed to be datetime type
+
+     Parameters:
+         df (pandas.DataFrame): Input DataFrame.
+         n (int): Number of non-null rows to consider for each column. Default is 100.
+
+     Returns:
+         datetime_columns (list): List of column names inferred to be datetime.
+     """
+     datetime_columns = []
+
+     for column in df.columns:
+         non_null_values = df[column].dropna().head(n)
+         if non_null_values.empty:
+             continue
+
+         # Try parsing each non-null value in the column
+         try:
+             parsed_values = non_null_values.apply(lambda x: parse(x))
+
+             # If parsing succeeds for all values, consider the column as datetime
+             if parsed_values.notnull().all():
+                 datetime_columns.append(column)
+         except:
+             # If parsing fails for any value, move to the next column
+             continue
+
+     return datetime_columns
+
+
+ """
+ Returns a list of ColumnFormat. Depending on structure / unstructured data, we need to format the root columns differently, so return the raw forms.
+
+ See the subtle differences between
+
+ CREATE TABLE users (
+     userID BIGINT,
+     userName VARCHAR,
+     hobbies ARRAY<VARCHAR>
+ );
+
+ and
+
+ SELECT *
+ FROM read_json('todos.json',
+                format = 'array',
+                columns = {userId: 'UBIGINT',
+                           userName: 'VARCHAR',
+                           hobbies: 'ARRAY<VARCHAR>'});
+
+ """
+
+
+ def make_duckdb_dataset_schema(schema: DatasetSchema) -> list[ColumnFormat]:
+     details = []
+     for col in schema.columns:
+         format = _make_schema(col.definition)
+         details.append(
+             ColumnFormat(source_name=col.source_name, alias=str(col.id), format=format),
+         )
+
+     return details
+
+
+ def _make_schema(
+     schema_node: DatasetObjectType | DatasetListType | DatasetScalarType,
+ ) -> str:
+     if isinstance(schema_node, DatasetObjectType):
+         details = {}
+         for col, value in schema_node.object.items():
+             details[col] = _make_schema(value)
+         key_value_pairs = [
+             f"{escape_identifier(col)} {value}" for col, value in details.items()
+         ]
+         return f"STRUCT({', '.join(key_value_pairs)})"
+
+     elif isinstance(schema_node, DatasetListType):
+         return f"{_make_schema(schema_node.items)}[]"
+     elif isinstance(schema_node, DatasetScalarType):
+         match schema_node.dtype:
+             case DType.INT:
+                 return "BIGINT"
+             case DType.FLOAT:
+                 return "DOUBLE"
+             case DType.BOOL:
+                 return "BOOLEAN"
+             case DType.STRING:
+                 return "VARCHAR"
+             case DType.UUID:
+                 return "UUID"
+             case DType.TIMESTAMP:
+                 return "TIMESTAMP"
+             case DType.JSON:
+                 return "JSON"
+             case _:
+                 raise ValueError(f"Unknown mapping for DType {schema_node.dtype}")
+     else:
+         raise NotImplementedError(
+             f"Schema conversion not implemented for node type {type(schema_node)}",
+         )
+
+
+ def escape_identifier(identifier: str) -> str:
+     """
+     Escape an identifier (e.g., column name) for use in a SQL query.
+     This method handles special characters and ensures proper quoting.
+     """
+     # Replace any double quotes with two double quotes
+     escaped = identifier.replace('"', '""')
+     # Wrap the entire identifier in double quotes and return
+     return f'"{escaped}"'
+
+
+ def escape_str_literal(literal: str) -> str:
+     """
+     Escape a duckDB string literal for use in a SQL query.
+     https://duckdb.org/docs/stable/sql/data_types/literal_types.html#escape-string-literals
+     """
+     # replace any single quotes with two single quotes
+     escaped = literal.replace("'", "''")
+     # Wrap the entire identifier in single quotes and return
+     return f"'{escaped}'"
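
A minimal usage sketch of the structured-data path (the DataFrame contents are invented for illustration, and it assumes arthur-common, duckdb, pandas, and python-dateutil are installed): passing preprocess_schema=True lets the loader parse the string timestamp column before the table is created, since no DatasetSchema is supplied.

import pandas as pd
from arthur_common.tools.duckdb_data_loader import DuckDBOperator

# Hypothetical inference records; created_at is deliberately stored as strings.
df = pd.DataFrame(
    {
        "id": [1, 2, 3],
        "score": [0.9, 0.4, 0.7],
        "created_at": ["2024-01-01T00:00:00Z", "2024-01-02T00:00:00Z", "2024-01-03T00:00:00Z"],
    }
)

# Without a schema, the table keeps the source column names.
conn = DuckDBOperator.load_data_to_duckdb(df, preprocess_schema=True, table_name="inferences")
print(conn.sql("SELECT COUNT(*) AS n, MIN(created_at) AS first_seen FROM inferences").df())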
arthur_common/tools/functions.py
@@ -0,0 +1,46 @@
+ import hashlib
+ from uuid import UUID
+
+ """
+ Convert a uuid to a 12 character, all lowercase string deterministically
+ """
+
+
+ def uuid_to_base26(uuid: str | UUID) -> str:
+     # Remove hyphens from the UUID
+     if isinstance(uuid, UUID):
+         uuid_str = str(uuid)
+     else:
+         uuid_str = uuid
+
+     no_hyphens = uuid_str.replace("-", "")
+
+     # Convert the hex string to an integer
+     num = int(no_hyphens, 16)
+
+     # Define the alphabet for base-26
+     alphabet = "abcdefghijklmnopqrstuvwxyz"
+     base = len(alphabet)
+
+     # Encode the integer into a base-26 string
+     base26_str = ""
+     while num > 0:
+         num, rem = divmod(num, base)
+         base26_str = alphabet[rem] + base26_str
+
+     # Ensure the string is 12 characters long (pad with 'a' if necessary)
+     base26_str = base26_str.rjust(12, "a")
+
+     # If the resulting string is longer than 12 characters, take the last 12 characters
+     return base26_str[-12:]
+
+
+ """
+ Hash a string
+ """
+
+
+ def hash_nonce(nonce: str) -> str:
+     md5_hash = hashlib.md5()
+     md5_hash.update(nonce.encode("utf-8"))
+     return md5_hash.hexdigest()
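
A short sketch of the two helpers above (the UUID and nonce values are arbitrary examples): the same input always yields the same 12-character lowercase string or the same MD5 hex digest.

import uuid

from arthur_common.tools.functions import hash_nonce, uuid_to_base26

agg_id = uuid.UUID("00000000-0000-0000-0000-000000000001")
print(uuid_to_base26(agg_id))       # deterministic 12-character, lowercase a-z string
print(uuid_to_base26(str(agg_id)))  # str and UUID inputs produce the same output
print(hash_nonce("example-nonce"))  # 32-character MD5 hex digest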