tracdap-runtime 0.6.5__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracdap/rt/_exec/context.py +272 -105
- tracdap/rt/_exec/dev_mode.py +231 -138
- tracdap/rt/_exec/engine.py +217 -59
- tracdap/rt/_exec/functions.py +25 -1
- tracdap/rt/_exec/graph.py +9 -0
- tracdap/rt/_exec/graph_builder.py +295 -198
- tracdap/rt/_exec/runtime.py +7 -5
- tracdap/rt/_impl/config_parser.py +11 -4
- tracdap/rt/_impl/data.py +278 -167
- tracdap/rt/_impl/ext/__init__.py +13 -0
- tracdap/rt/_impl/ext/sql.py +116 -0
- tracdap/rt/_impl/ext/storage.py +57 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +62 -54
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +37 -2
- tracdap/rt/_impl/static_api.py +24 -11
- tracdap/rt/_impl/storage.py +2 -2
- tracdap/rt/_impl/util.py +10 -0
- tracdap/rt/_impl/validation.py +66 -13
- tracdap/rt/_plugins/storage_sql.py +417 -0
- tracdap/rt/_plugins/storage_sql_dialects.py +117 -0
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/experimental.py +79 -32
- tracdap/rt/api/hook.py +10 -0
- tracdap/rt/metadata/__init__.py +4 -0
- tracdap/rt/metadata/job.py +45 -0
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/METADATA +3 -1
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/RECORD +30 -25
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/WHEEL +1 -1
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/top_level.txt +0 -0
tracdap/rt/_impl/validation.py
CHANGED
@@ -44,10 +44,25 @@ def check_type(expected_type: tp.Type, value: tp.Any) -> bool:
|
|
44
44
|
return _TypeValidator.check_type(expected_type, value)
|
45
45
|
|
46
46
|
|
47
|
+
def type_name(type_: tp.Type, qualified: bool) -> str:
    """Return a human-readable name for ``type_``.

    Thin public wrapper over the private naming logic of ``_TypeValidator``.

    :param type_: The type to describe
    :param qualified: Forwarded unchanged to ``_TypeValidator._type_name``
    :return: The type name produced by the validator
    """
    name = _TypeValidator._type_name(type_, qualified)  # noqa
    return name
|
49
|
+
|
50
|
+
|
47
51
|
def quick_validate_model_def(model_def: meta.ModelDefinition):
    """Run the static quick validation checks on a model definition.

    Delegates to ``StaticValidator.quick_validate_model_def``; any exception it raises propagates.
    """
    validator = StaticValidator
    validator.quick_validate_model_def(model_def)
|
49
53
|
|
50
54
|
|
55
|
+
def is_primitive_type(basic_type: meta.BasicType) -> bool:
    """Report whether ``basic_type`` is one of the primitive types known to ``StaticValidator``."""
    primitive = StaticValidator.is_primitive_type(basic_type)
    return primitive
|
57
|
+
|
58
|
+
|
59
|
+
T_SKIP_VAL = tp.TypeVar("T_SKIP_VAL")


class SkipValidation(tp.Generic[T_SKIP_VAL]):

    """Marker value that stands in for an argument whose type check should be skipped.

    Passing ``SkipValidation(some_type)`` in place of a real argument allows a partial set
    of arguments to be validated; the type validator accepts the marker when ``skip_type``
    equals the expected parameter type.
    """

    def __init__(self, skip_type: tp.Type[T_SKIP_VAL]):
        # The type that this marker stands in for
        self.skip_type = skip_type
|
64
|
+
|
65
|
+
|
51
66
|
class _TypeValidator:
|
52
67
|
|
53
68
|
# The metaclass for generic types varies between versions of the typing library
|
@@ -56,28 +71,28 @@ class _TypeValidator:
|
|
56
71
|
|
57
72
|
# Cache method signatures to avoid inspection on every call
|
58
73
|
# Inspecting a function signature can take ~ half a second in Python 3.7
|
59
|
-
__method_cache: tp.Dict[str, inspect.Signature] = dict()
|
74
|
+
__method_cache: tp.Dict[str, tp.Tuple[inspect.Signature, tp.Any]] = dict()
|
60
75
|
|
61
76
|
_log: logging.Logger = util.logger_for_namespace(__name__)
|
62
77
|
|
63
78
|
@classmethod
|
64
79
|
def validate_signature(cls, method: tp.Callable, *args, **kwargs):
|
65
80
|
|
66
|
-
if method.
|
67
|
-
signature = cls.__method_cache[method.
|
81
|
+
if method.__qualname__ in cls.__method_cache:
|
82
|
+
signature, hints = cls.__method_cache[method.__qualname__]
|
68
83
|
else:
|
69
84
|
signature = inspect.signature(method)
|
70
|
-
|
71
|
-
|
72
|
-
hints = tp.get_type_hints(method)
|
85
|
+
hints = tp.get_type_hints(method)
|
86
|
+
cls.__method_cache[method.__qualname__] = signature, hints
|
73
87
|
|
88
|
+
named_params = list(signature.parameters.keys())
|
74
89
|
positional_index = 0
|
75
90
|
|
76
91
|
for param_name, param in signature.parameters.items():
|
77
92
|
|
78
93
|
param_type = hints.get(param_name)
|
79
94
|
|
80
|
-
values = cls._select_arg(method.__name__, param, positional_index, *args, **kwargs)
|
95
|
+
values = cls._select_arg(method.__name__, param, positional_index, named_params, *args, **kwargs)
|
81
96
|
positional_index += len(values)
|
82
97
|
|
83
98
|
for value in values:
|
@@ -86,11 +101,12 @@ class _TypeValidator:
|
|
86
101
|
@classmethod
|
87
102
|
def validate_return_type(cls, method: tp.Callable, value: tp.Any):
|
88
103
|
|
89
|
-
if method.
|
90
|
-
signature = cls.__method_cache[method.
|
104
|
+
if method.__qualname__ in cls.__method_cache:
|
105
|
+
signature, hints = cls.__method_cache[method.__qualname__]
|
91
106
|
else:
|
92
107
|
signature = inspect.signature(method)
|
93
|
-
|
108
|
+
hints = tp.get_type_hints(method)
|
109
|
+
cls.__method_cache[method.__qualname__] = signature, hints
|
94
110
|
|
95
111
|
correct_type = cls._validate_type(signature.return_annotation, value)
|
96
112
|
|
@@ -107,7 +123,7 @@ class _TypeValidator:
|
|
107
123
|
|
108
124
|
@classmethod
|
109
125
|
def _select_arg(
|
110
|
-
cls, method_name: str, parameter: inspect.Parameter, positional_index,
|
126
|
+
cls, method_name: str, parameter: inspect.Parameter, positional_index, named_params,
|
111
127
|
*args, **kwargs) -> tp.List[tp.Any]:
|
112
128
|
|
113
129
|
if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
|
@@ -152,7 +168,7 @@ class _TypeValidator:
|
|
152
168
|
|
153
169
|
if parameter.kind == inspect.Parameter.VAR_KEYWORD:
|
154
170
|
|
155
|
-
|
171
|
+
return [arg for kw, arg in kwargs.items() if kw not in named_params]
|
156
172
|
|
157
173
|
raise ex.EUnexpected("Invalid method signature in runtime API (this is a bug)")
|
158
174
|
|
@@ -180,6 +196,12 @@ class _TypeValidator:
|
|
180
196
|
if expected_type == tp.Any:
|
181
197
|
return True
|
182
198
|
|
199
|
+
# Sometimes we need to validate a partial set of arguments
|
200
|
+
# Explicitly passing a SkipValidation value allows for this
|
201
|
+
if isinstance(value, SkipValidation):
|
202
|
+
if value.skip_type == expected_type:
|
203
|
+
return True
|
204
|
+
|
183
205
|
if isinstance(expected_type, cls.__generic_metaclass):
|
184
206
|
|
185
207
|
origin = util.get_origin(expected_type)
|
@@ -216,8 +238,31 @@ class _TypeValidator:
|
|
216
238
|
all(map(lambda k: cls._validate_type(key_type, k), value.keys())) and \
|
217
239
|
all(map(lambda v: cls._validate_type(value_type, v), value.values()))
|
218
240
|
|
241
|
+
if origin.__module__.startswith("tracdap.rt.api."):
|
242
|
+
return isinstance(value, origin)
|
243
|
+
|
219
244
|
raise ex.ETracInternal(f"Validation of [{origin.__name__}] generic parameters is not supported yet")
|
220
245
|
|
246
|
+
# Support for generic type variables
|
247
|
+
if isinstance(expected_type, tp.TypeVar):
|
248
|
+
|
249
|
+
# If there are any constraints or a bound, those must be honoured
|
250
|
+
|
251
|
+
constraints = util.get_constraints(expected_type)
|
252
|
+
bound = util.get_bound(expected_type)
|
253
|
+
|
254
|
+
if constraints:
|
255
|
+
if not any(map(lambda c: type(value) == c, constraints)):
|
256
|
+
return False
|
257
|
+
|
258
|
+
if bound:
|
259
|
+
if not isinstance(value, bound):
|
260
|
+
return False
|
261
|
+
|
262
|
+
# So long as constraints / bound are ok, any type matches a generic type var
|
263
|
+
return True
|
264
|
+
|
265
|
+
|
221
266
|
# Validate everything else as a concrete type
|
222
267
|
|
223
268
|
# TODO: Recursive validation of types for class members using field annotations
|
@@ -237,7 +282,10 @@ class _TypeValidator:
|
|
237
282
|
return f"Named[{named_type}]"
|
238
283
|
|
239
284
|
if origin is tp.Union:
|
240
|
-
|
285
|
+
if len(args) == 2 and args[1] == type(None):
|
286
|
+
return f"Optional[{cls._type_name(args[0])}]"
|
287
|
+
else:
|
288
|
+
return "|".join(map(cls._type_name, args))
|
241
289
|
|
242
290
|
if origin is list:
|
243
291
|
list_type = cls._type_name(args[0])
|
@@ -274,6 +322,11 @@ class StaticValidator:
|
|
274
322
|
|
275
323
|
_log: logging.Logger = util.logger_for_namespace(__name__)
|
276
324
|
|
325
|
+
@classmethod
def is_primitive_type(cls, basic_type: meta.BasicType) -> bool:
    """Return True if ``basic_type`` is in the validator's set of primitive types."""
    primitive_types = cls.__PRIMITIVE_TYPES
    return basic_type in primitive_types
|
329
|
+
|
277
330
|
@classmethod
|
278
331
|
def quick_validate_model_def(cls, model_def: meta.ModelDefinition):
|
279
332
|
|
@@ -0,0 +1,417 @@
|
|
1
|
+
# Copyright 2024 Accenture Global Solutions Limited
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import contextlib
|
16
|
+
import typing as tp
|
17
|
+
import urllib.parse as urlp
|
18
|
+
|
19
|
+
import pyarrow as pa
|
20
|
+
|
21
|
+
import tracdap.rt.config as cfg
|
22
|
+
import tracdap.rt.exceptions as ex
|
23
|
+
import tracdap.rt.ext.plugins as plugins
|
24
|
+
|
25
|
+
# Import storage interfaces (private extension API)
|
26
|
+
from tracdap.rt._impl.ext.storage import * # noqa
|
27
|
+
from tracdap.rt._impl.ext.sql import * # noqa
|
28
|
+
|
29
|
+
import tracdap.rt._plugins._helpers as _helpers
|
30
|
+
|
31
|
+
# TODO: Remove internal references
|
32
|
+
import tracdap.rt._impl.data as _data
|
33
|
+
|
34
|
+
|
35
|
+
class SqlDataStorage(IDataStorageBase[pa.Table, pa.Schema]):

    """Data storage for tabular data held in a SQL database, exchanged as Arrow tables.

    Configured from plugin properties: ``dialect`` selects the SQL dialect plugin and
    ``driver.python`` selects the SQL driver plugin (defaulting to the lower-cased dialect name).
    Properties prefixed with ``<driver name>.`` are passed to the driver with the prefix removed.
    """

    DIALECT_PROPERTY = "dialect"
    DRIVER_PROPERTY = "driver.python"

    # Keywords used by the read-query sanity check in _check_read_query()
    __DQL_KEYWORDS = ["select"]
    __DML_KEYWORDS = ["insert", "update", "delete", "merge"]
    __DDL_KEYWORDS = ["create", "alter", "drop", "grant"]

    def __init__(self, properties: tp.Dict[str, str]):

        """Load the SQL dialect and driver plugins, then test connectivity.

        :raises EConfigLoad: The dialect property is missing
        :raises EPluginNotAvailable: The requested dialect or driver plugin is not available
        """

        self._log = _helpers.logger_for_object(self)
        self._properties = properties

        dialect_name = _helpers.get_plugin_property(self._properties, self.DIALECT_PROPERTY)

        if dialect_name is None:
            raise ex.EConfigLoad(f"Missing required property [{self.DIALECT_PROPERTY}]")

        if not plugins.PluginManager.is_plugin_available(ISqlDialect, dialect_name.lower()):
            raise ex.EPluginNotAvailable(f"SQL dialect [{dialect_name}] is not supported")

        driver_name = _helpers.get_plugin_property(self._properties, self.DRIVER_PROPERTY)
        if driver_name is None:
            driver_name = dialect_name.lower()

        if not plugins.PluginManager.is_plugin_available(ISqlDriver, driver_name):
            raise ex.EPluginNotAvailable(f"SQL driver [{driver_name}] is not available")

        driver_props = self._driver_props(driver_name)
        driver_cfg = cfg.PluginConfig(protocol=driver_name.lower(), properties=driver_props)
        dialect_cfg = cfg.PluginConfig(protocol=dialect_name.lower(), properties={})

        self._log.info(f"Loading SQL driver [{driver_name}] for dialect [{dialect_name}]")

        self._driver = plugins.PluginManager.load_plugin(ISqlDriver, driver_cfg)
        self._dialect = plugins.PluginManager.load_plugin(ISqlDialect, dialect_cfg)

        # Test connectivity
        with self._connection():
            pass

    def _driver_props(self, driver_name: str) -> tp.Dict[str, str]:

        """Collect the properties prefixed with ``<driver_name>.``, with the prefix stripped."""

        driver_props = dict()
        driver_filter = f"{driver_name}."

        for key, value in self._properties.items():
            if key.startswith(driver_filter):
                driver_key = key[len(driver_filter):]
                driver_props[driver_key] = value

        return driver_props

    def _connection(self) -> DbApiWrapper.Connection:

        # contextlib.closing() guarantees the connection is closed when the with-block exits
        return contextlib.closing(self._driver.connect())  # noqa

    def _cursor(self, conn: DbApiWrapper.Connection) -> DbApiWrapper.Cursor:

        # contextlib.closing() guarantees the cursor is closed when the with-block exits
        return contextlib.closing(conn.cursor())  # noqa

    def data_type(self) -> tp.Type[pa.Table]:
        return pa.Table

    def schema_type(self) -> tp.Type[pa.Schema]:
        return pa.Schema

    def has_table(self, table_name: str):

        with self._driver.error_handling():
            return self._driver.has_table(table_name)

    def list_tables(self):

        with self._driver.error_handling():
            return self._driver.list_tables()

    def create_table(self, table_name: str, schema: pa.Schema):

        """Create a table whose columns are derived from ``schema`` using the dialect type mapping."""

        with self._driver.error_handling():

            def type_decl(field: pa.Field):
                sql_type = self._dialect.arrow_to_sql_type(field.type)
                null_qualifier = " NULL" if field.nullable else " NOT NULL"
                return f"{field.name} {sql_type}{null_qualifier}"

            create_fields = [type_decl(field) for field in schema]
            create_stmt = f"create table {table_name} (" + ", ".join(create_fields) + ")"

            with self._connection() as conn, self._cursor(conn) as cur:
                cur.execute(create_stmt, [])
                conn.commit()  # Some drivers / dialects (Postgres) require commit for create table

    def read_table(self, table_name: str) -> pa.Table:

        select_stmt = f"select * from {table_name}"  # noqa

        return self.native_read_query(select_stmt)

    def native_read_query(self, query: str, **parameters) -> pa.Table:

        """Run a read-only query and return the complete result set as an Arrow table.

        :raises EStorageRequest: The query does not look like a read query
        :raises EStorage: The query did not return a result set
        """

        # Real restrictions are enforced in deployment, by permissions granted to service accounts
        # This is a sanity check to catch common errors before sending a query to the backend
        self._check_read_query(query)

        with self._driver.error_handling():

            with self._connection() as conn, self._cursor(conn) as cur:

                cur.execute(query, parameters)

                # Read queries should always return a result set, even if it is empty
                # Check before fetching: fetching from a statement with no result set
                # is itself an error for many drivers, which would hide this message
                if not cur.description:
                    raise ex.EStorage(f"Query did not return a result set: {query}")

                arrow_schema = self._decode_sql_schema(cur.description)
                arrow_batches: tp.List[pa.RecordBatch] = []

                sql_batch = cur.fetchmany()

                while len(sql_batch) > 0:

                    arrow_batch = self._decode_sql_batch(arrow_schema, sql_batch)
                    arrow_batches.append(arrow_batch)

                    # Sometimes the schema is not fully defined up front (because cur.description is not sufficient)
                    # If type information has been inferred from the batch, update the schema accordingly
                    arrow_schema = arrow_batch.schema

                    sql_batch = cur.fetchmany()

                # NOTE(review): if a column is all-null in an early batch and a type is inferred
                # in a later batch, the batch schemas differ and from_batches() will reject them
                return pa.Table.from_batches(arrow_batches, arrow_schema)  # noqa

    def write_table(self, table_name: str, table: pa.Table):

        """Insert every row of ``table`` into an existing table, using named bind parameters."""

        with self._driver.error_handling():

            insert_fields = ", ".join(table.schema.names)
            insert_markers = ", ".join(f":{name}" for name in table.schema.names)
            insert_stmt = f"insert into {table_name}({insert_fields}) values ({insert_markers})"  # noqa

            with self._connection() as conn:

                # Use execute many to perform a batch write
                # Do not try to insert if there are no rows to bind
                with self._cursor(conn) as cur:
                    if table.num_rows > 0:
                        # Provider converts rows on demand, to optimize for memory
                        row_provider = self._encode_sql_rows_dict(table)
                        cur.executemany(insert_stmt, row_provider)

                conn.commit()

    def _check_read_query(self, query):

        """Reject queries that do not look like read-only queries.

        :raises EStorageRequest: No DQL keyword is present, or a DML / DDL keyword is present
        """

        import re

        # Compare whole words only, so identifiers such as "created_at" or "last_update_time"
        # are not mistaken for the DDL / DML keywords they happen to contain
        query_words = set(re.findall(r"\w+", query.lower()))

        if not query_words.intersection(self.__DQL_KEYWORDS):
            raise ex.EStorageRequest(f"Query is not a read query: {query}")

        if query_words.intersection(self.__DML_KEYWORDS):
            raise ex.EStorageRequest(f"Query is not a read query: {query}")

        if query_words.intersection(self.__DDL_KEYWORDS):
            raise ex.EStorageRequest(f"Query is not a read query: {query}")

    @staticmethod
    def _decode_sql_schema(description: tp.List[tp.Tuple]):

        """Build a provisional Arrow schema from a DB API cursor description.

        Every field starts out null-typed; concrete types are inferred later from the data.
        """

        # TODO: Infer Python / Arrow type using DB API type code
        # These codes are db-specific so decoding would probably be on a best effort basis
        # However the information is public for many popular db engines
        # The current logic can be kept as a fallback (set type info on reading first non-null value)

        def _decode_sql_field(field_desc: tp.Tuple):
            field_name, _type_code, _, _, _precision, _scale, null_ok = field_desc
            # DB API allows null_ok to be None when nullability is unknown
            # Treat unknown as nullable explicitly, instead of relying on how pyarrow interprets None
            nullable = null_ok is None or bool(null_ok)
            return pa.field(field_name, pa.null(), nullable)

        return pa.schema(map(_decode_sql_field, description))

    def _decode_sql_batch(self, schema: pa.Schema, sql_batch: tp.List[tp.Tuple]) -> pa.RecordBatch:

        """Convert one batch of DB API rows into an Arrow record batch.

        Columns still null-typed in ``schema`` take their type from the first non-null value
        in this batch, if any; the returned batch carries the (possibly updated) schema.
        """

        py_dict: tp.Dict[str, pa.Array] = {}

        for i, col in enumerate(schema.names):

            arrow_type = schema.types[i]

            if pa.types.is_null(arrow_type):
                values = [row[i] for row in sql_batch]
                # Default of None avoids StopIteration when every value in this batch is null
                concrete_value = next((v for v in values if v is not None), None)
                if concrete_value is not None:
                    arrow_type = _data.DataMapping.python_to_arrow_type(type(concrete_value))
                    arrow_field = pa.field(schema.names[i], arrow_type, nullable=True)
                    schema = schema.remove(i).insert(i, arrow_field)
            else:
                python_type = _data.DataMapping.arrow_to_python_type(arrow_type)
                values = [self._driver.decode_sql_value(row[i], python_type) for row in sql_batch]

            py_dict[col] = pa.array(values, type=arrow_type)

        return pa.RecordBatch.from_pydict(py_dict, schema)

    def _encode_sql_rows_tuple(self, table: pa.Table) -> tp.Iterator[tp.Tuple]:

        """Yield each row of ``table`` as a tuple of driver-encoded values, in column order."""

        for row in range(0, table.num_rows):
            row_values = map(lambda col: self._driver.encode_sql_value(col[row].as_py()), table.columns)
            yield tuple(row_values)

    def _encode_sql_rows_dict(self, table: pa.Table) -> tp.Iterator[tp.Dict[str, tp.Any]]:

        """Yield each row of ``table`` as a dict of column name -> driver-encoded value."""

        for row in range(0, table.num_rows):
            row_values = map(lambda col: self._driver.encode_sql_value(col[row].as_py()), table.columns)
            yield dict(zip(table.column_names, row_values))
|
251
|
+
|
252
|
+
|
253
|
+
class SqlStorageProvider(IStorageProvider):

    """Storage provider for the "SQL" protocol; it supplies data storage only."""

    def __init__(self, properties: tp.Dict[str, str]):
        # Handed unchanged to every SqlDataStorage created by this provider
        self._properties = properties

    def has_data_storage(self) -> bool:
        return True

    def get_data_storage(self) -> IDataStorageBase:
        # Each call builds a fresh SqlDataStorage, which loads plugins and tests connectivity
        storage = SqlDataStorage(self._properties)
        return storage


# Register with the plugin manager
plugins.PluginManager.register_plugin(IStorageProvider, SqlStorageProvider, ["SQL"])
|
267
|
+
|
268
|
+
|
269
|
+
try:

    import sqlalchemy as sqla  # noqa
    import sqlalchemy.exc as sqla_exc  # noqa

    class SqlAlchemyDriver(ISqlDriver):

        """SQL driver implemented on top of a SQLAlchemy engine, registered as "alchemy".

        Requires the "url" property. Credentials found in the properties are applied to the URL
        via the plugin helpers; the remaining properties are passed to ``sqlalchemy.create_engine()``.
        """

        def __init__(self, properties: tp.Dict[str, str]):

            self._log = _helpers.logger_for_object(self)

            raw_url = properties.get('url')

            if raw_url is None or raw_url.strip() == '':
                raise ex.EConfigLoad("Missing required property [url] for SQL driver [alchemy]")

            url = urlp.urlparse(raw_url)
            credentials = _helpers.get_http_credentials(url, properties)
            url = _helpers.apply_http_credentials(url, credentials)

            # URL and credential properties are not engine options, keep them out of create_engine()
            filtered_keys = ["url", "username", "password", "token"]
            filtered_props = dict(kv for kv in properties.items() if kv[0] not in filtered_keys)

            self._log.info("Connecting: %s", _helpers.log_safe_url(url))

            try:
                self.__engine = sqla.create_engine(url.geturl(), **filtered_props)
            except ModuleNotFoundError as e:
                # Presumably the DB API module needed for this URL is not installed
                raise ex.EPluginNotAvailable("SQL driver is not available: " + str(e)) from e

        def param_style(self) -> "DbApiWrapper.ParamStyle":
            return DbApiWrapper.ParamStyle.NAMED

        def connect(self, **kwargs) -> "DbApiWrapper.Connection":

            return SqlAlchemyDriver.ConnectionWrapper(self.__engine.connect())

        def has_table(self, table_name: str):

            with self.__engine.connect() as conn:
                inspection = sqla.inspect(conn)
                return inspection.has_table(table_name)

        def list_tables(self):

            with self.__engine.connect() as conn:
                inspection = sqla.inspect(conn)
                return inspection.get_table_names()

        def encode_sql_value(self, py_value: tp.Any) -> tp.Any:

            # Values are passed through unchanged
            return py_value

        def decode_sql_value(self, sql_value: tp.Any, python_type: tp.Type) -> tp.Any:

            # Values are passed through unchanged, python_type is not used
            return sql_value

        @contextlib.contextmanager
        def error_handling(self) -> tp.Iterator[None]:

            """Translate SQLAlchemy errors raised inside the with-block into TRAC storage errors."""

            try:
                yield
            except (sqla_exc.OperationalError, sqla_exc.ProgrammingError, sqla_exc.StatementError) as e:
                raise ex.EStorageRequest(*e.args) from e
            except sqla_exc.SQLAlchemyError as e:
                raise ex.EStorage() from e

        class ConnectionWrapper(DbApiWrapper.Connection):

            """Adapts a SQLAlchemy Connection to the DB API connection interface."""

            def __init__(self, conn: sqla.Connection):
                self.__conn = conn

            def close(self):
                self.__conn.close()

            def commit(self):
                self.__conn.commit()

            def rollback(self):
                self.__conn.rollback()

            def cursor(self) -> "DbApiWrapper.Cursor":
                return SqlAlchemyDriver.CursorWrapper(self.__conn)

        class CursorWrapper(DbApiWrapper.Cursor):

            """Adapts SQLAlchemy statement execution and results to the DB API cursor interface."""

            arraysize: int = 1000

            def __init__(self, conn: sqla.Connection):
                self.__conn = conn
                self.__result: tp.Optional[sqla.CursorResult] = None

            @property
            def description(self):

                # DB API: description is None until an operation has been executed
                if self.__result is None:
                    return None

                # Prefer description from the underlying cursor if available
                if self.__result.cursor is not None and self.__result.cursor.description:
                    return self.__result.cursor.description

                if not self.__result.returns_rows:
                    return None

                # SQL Alchemy sometimes closes the cursor and the description is lost
                # Fall back on using the Result API to generate a description with field names only

                def name_only_field_desc(field_name):
                    return field_name, None, None, None, None, None, None

                return list(map(name_only_field_desc, self.__result.keys()))

            @property
            def rowcount(self) -> int:

                # DB API: rowcount is -1 when no operation has been executed
                if self.__result is None:
                    return -1

                # Prefer the value from the underlying cursor if it is available
                if self.__result.cursor is not None:
                    return self.__result.cursor.rowcount

                return self.__result.rowcount  # noqa

            def execute(self, statement: str, parameters: tp.Union[tp.Dict, tp.Sequence]):

                self.__result = self.__conn.execute(sqla.text(statement), parameters)

            def executemany(self, statement: str, parameters: tp.Iterable[tp.Union[tp.Dict, tp.Sequence]]):

                # SQLAlchemy needs a list of parameter sets to run an executemany-style batch
                if not isinstance(parameters, list):
                    parameters = list(parameters)

                self.__result = self.__conn.execute(sqla.text(statement), parameters)

            def fetchone(self) -> tp.Optional[tp.Tuple]:

                row = self.__result.fetchone()
                return row.tuple() if row is not None else None

            def fetchmany(self, size: tp.Optional[int] = None) -> tp.Sequence[tp.Tuple]:

                # DB API: the number of rows to fetch defaults to the cursor's arraysize
                if size is None:
                    size = self.arraysize

                sqla_rows = self.__result.fetchmany(size)
                return list(map(sqla.Row.tuple, sqla_rows))  # noqa

            def close(self):

                if self.__result is not None:
                    self.__result.close()

    plugins.PluginManager.register_plugin(ISqlDriver, SqlAlchemyDriver, ["alchemy"])

except ModuleNotFoundError:
    # SQLAlchemy is optional - the "alchemy" driver is simply not registered without it
    pass
|
@@ -0,0 +1,117 @@
|
|
1
|
+
# Copyright 2024 Accenture Global Solutions Limited
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import typing as tp
|
16
|
+
|
17
|
+
import pyarrow as pa
|
18
|
+
|
19
|
+
import tracdap.rt.exceptions as ex
|
20
|
+
import tracdap.rt.ext.plugins as plugins
|
21
|
+
|
22
|
+
from tracdap.rt._impl.ext.sql import * # noqa
|
23
|
+
|
24
|
+
|
25
|
+
|
26
|
+
class AnsiStandardDialect(ISqlDialect):

    """ANSI standard SQL type mapping, used as the base for the vendor-specific dialects."""

    def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:

        """Map an Arrow data type to the SQL column type used in CREATE TABLE statements.

        :raises ETracInternal: The Arrow type has no mapping in this dialect
        """

        if pa.types.is_boolean(arrow_type):
            return "boolean"

        # Every integer width (signed or unsigned) maps to bigint
        if pa.types.is_integer(arrow_type):
            return "bigint"

        if pa.types.is_floating(arrow_type):
            return "double precision"

        # Fixed precision / scale, independent of the Arrow decimal's own precision / scale
        if pa.types.is_decimal(arrow_type):
            return "decimal (31, 10)"

        # Regular and large (64-bit offset) strings hold the same values, map them the same way
        if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
            return "varchar(4096)"

        if pa.types.is_date(arrow_type):
            return "date"

        if pa.types.is_timestamp(arrow_type):
            return "timestamp (6)"

        raise ex.ETracInternal(f"Unsupported data type [{str(arrow_type)}] in SQL dialect [{self.__class__.__name__}]")
|
52
|
+
|
53
|
+
|
54
|
+
class MySqlDialect(AnsiStandardDialect):

    """MySQL type mapping: overrides the ANSI mapping for floating point and string types."""

    def __init__(self, properties: tp.Dict[str, str]):
        self._properties = properties

    def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:

        """Map an Arrow data type to a MySQL column type, deferring to the ANSI mapping otherwise."""

        if pa.types.is_floating(arrow_type):
            return "double"

        # Regular and large (64-bit offset) strings map to the same column type
        if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
            return "varchar(8192)"

        return super().arrow_to_sql_type(arrow_type)
|
68
|
+
|
69
|
+
|
70
|
+
class MariaDbDialect(MySqlDialect):

    """MariaDB dialect: reuses the MySQL type mapping without changes."""

    def __init__(self, properties: tp.Dict[str, str]):
        super().__init__(properties)
|
77
|
+
|
78
|
+
|
79
|
+
class PostgresqlDialect(AnsiStandardDialect):

    """PostgreSQL type mapping: strings use an unbounded varchar, everything else is ANSI."""

    def __init__(self, properties: tp.Dict[str, str]):
        self._properties = properties

    def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:

        """Map an Arrow data type to a PostgreSQL column type, deferring to the ANSI mapping otherwise."""

        # Regular and large (64-bit offset) strings map to the same column type
        if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
            return "varchar"

        return super().arrow_to_sql_type(arrow_type)
|
90
|
+
|
91
|
+
|
92
|
+
class SqlServerDialect(AnsiStandardDialect):

    """SQL Server type mapping: overrides boolean, floating point, string and timestamp types."""

    def __init__(self, properties: tp.Dict[str, str]):
        self._properties = properties

    def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:

        """Map an Arrow data type to a SQL Server column type, deferring to the ANSI mapping otherwise."""

        if pa.types.is_boolean(arrow_type):
            return "bit"

        if pa.types.is_floating(arrow_type):
            return "float(53)"

        # Regular and large (64-bit offset) strings map to the same column type
        if pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type):
            return "varchar(8000)"

        if pa.types.is_timestamp(arrow_type):
            return "datetime2"

        return super().arrow_to_sql_type(arrow_type)
|
112
|
+
|
113
|
+
|
114
|
+
# Register each concrete dialect under its protocol name, in a fixed order
for _dialect_class, _dialect_names in (
        (MySqlDialect, ["mysql"]),
        (MariaDbDialect, ["mariadb"]),
        (PostgresqlDialect, ["postgresql"]),
        (SqlServerDialect, ["sqlserver"])):

    plugins.PluginManager.register_plugin(ISqlDialect, _dialect_class, _dialect_names)

# Do not leave the loop variables behind as module attributes
del _dialect_class, _dialect_names
|
tracdap/rt/_version.py
CHANGED