tracdap-runtime 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracdap/rt/_exec/context.py +556 -36
- tracdap/rt/_exec/dev_mode.py +320 -198
- tracdap/rt/_exec/engine.py +331 -62
- tracdap/rt/_exec/functions.py +151 -22
- tracdap/rt/_exec/graph.py +47 -13
- tracdap/rt/_exec/graph_builder.py +383 -175
- tracdap/rt/_exec/runtime.py +7 -5
- tracdap/rt/_impl/config_parser.py +11 -4
- tracdap/rt/_impl/data.py +329 -152
- tracdap/rt/_impl/ext/__init__.py +13 -0
- tracdap/rt/_impl/ext/sql.py +116 -0
- tracdap/rt/_impl/ext/storage.py +57 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +82 -30
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +155 -2
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +12 -10
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +14 -2
- tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.py +29 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.pyi +16 -0
- tracdap/rt/_impl/models.py +8 -0
- tracdap/rt/_impl/static_api.py +29 -0
- tracdap/rt/_impl/storage.py +39 -27
- tracdap/rt/_impl/util.py +10 -0
- tracdap/rt/_impl/validation.py +140 -18
- tracdap/rt/_plugins/repo_git.py +1 -1
- tracdap/rt/_plugins/storage_sql.py +417 -0
- tracdap/rt/_plugins/storage_sql_dialects.py +117 -0
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/experimental.py +267 -0
- tracdap/rt/api/hook.py +14 -0
- tracdap/rt/api/model_api.py +48 -6
- tracdap/rt/config/__init__.py +2 -2
- tracdap/rt/config/common.py +6 -0
- tracdap/rt/metadata/__init__.py +29 -20
- tracdap/rt/metadata/job.py +99 -0
- tracdap/rt/metadata/model.py +18 -0
- tracdap/rt/metadata/resource.py +24 -0
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/METADATA +5 -1
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/RECORD +41 -32
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/WHEEL +1 -1
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,417 @@
|
|
1
|
+
# Copyright 2024 Accenture Global Solutions Limited
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import contextlib
|
16
|
+
import typing as tp
|
17
|
+
import urllib.parse as urlp
|
18
|
+
|
19
|
+
import pyarrow as pa
|
20
|
+
|
21
|
+
import tracdap.rt.config as cfg
|
22
|
+
import tracdap.rt.exceptions as ex
|
23
|
+
import tracdap.rt.ext.plugins as plugins
|
24
|
+
|
25
|
+
# Import storage interfaces (private extension API)
|
26
|
+
from tracdap.rt._impl.ext.storage import * # noqa
|
27
|
+
from tracdap.rt._impl.ext.sql import * # noqa
|
28
|
+
|
29
|
+
import tracdap.rt._plugins._helpers as _helpers
|
30
|
+
|
31
|
+
# TODO: Remove internal references
|
32
|
+
import tracdap.rt._impl.data as _data
|
33
|
+
|
34
|
+
|
35
|
+
class SqlDataStorage(IDataStorageBase[pa.Table, pa.Schema]):

    """
    Data storage implementation backed by a SQL database.

    SQL statements are executed through an ISqlDriver plugin (a DB-API style
    wrapper) and type mapping is delegated to an ISqlDialect plugin. Data is
    exchanged with the rest of the runtime as Arrow tables.
    """

    DIALECT_PROPERTY = "dialect"
    DRIVER_PROPERTY = "driver.python"

    # Keyword lists for the sanity check in _check_read_query()
    __DQL_KEYWORDS = ["select"]
    __DML_KEYWORDS = ["insert", "update", "delete", "merge"]
    __DDL_KEYWORDS = ["create", "alter", "drop", "grant"]

    def __init__(self, properties: tp.Dict[str, str]):

        self._log = _helpers.logger_for_object(self)
        self._properties = properties

        dialect_name = _helpers.get_plugin_property(self._properties, self.DIALECT_PROPERTY)

        if dialect_name is None:
            raise ex.EConfigLoad(f"Missing required property [{self.DIALECT_PROPERTY}]")

        if not plugins.PluginManager.is_plugin_available(ISqlDialect, dialect_name.lower()):
            raise ex.EPluginNotAvailable(f"SQL dialect [{dialect_name}] is not supported")

        # If no explicit driver is configured, fall back to a driver named after the dialect
        driver_name = _helpers.get_plugin_property(self._properties, self.DRIVER_PROPERTY)
        if driver_name is None:
            driver_name = dialect_name.lower()

        if not plugins.PluginManager.is_plugin_available(ISqlDriver, driver_name):
            raise ex.EPluginNotAvailable(f"SQL driver [{driver_name}] is not available")

        driver_props = self._driver_props(driver_name)
        driver_cfg = cfg.PluginConfig(protocol=driver_name.lower(), properties=driver_props)
        dialect_cfg = cfg.PluginConfig(protocol=dialect_name.lower(), properties={})

        self._log.info(f"Loading SQL driver [{driver_name}] for dialect [{dialect_name}]")

        self._driver = plugins.PluginManager.load_plugin(ISqlDriver, driver_cfg)
        self._dialect = plugins.PluginManager.load_plugin(ISqlDialect, dialect_cfg)

        # Test connectivity
        with self._connection():
            pass

    def _driver_props(self, driver_name: str) -> tp.Dict[str, str]:

        # Select properties prefixed with "<driver_name>." and strip the prefix

        driver_props = dict()
        driver_filter = f"{driver_name}."

        for key, value in self._properties.items():
            if key.startswith(driver_filter):
                dialect_key = key[len(driver_filter):]
                driver_props[dialect_key] = value

        return driver_props

    def _connection(self) -> DbApiWrapper.Connection:

        # closing() releases the connection when the with-block exits
        return contextlib.closing(self._driver.connect())  # noqa

    def _cursor(self, conn: DbApiWrapper.Connection) -> DbApiWrapper.Cursor:

        return contextlib.closing(conn.cursor())  # noqa

    def data_type(self) -> tp.Type[pa.Table]:
        return pa.Table

    def schema_type(self) -> tp.Type[pa.Schema]:
        return pa.Schema

    def has_table(self, table_name: str):

        with self._driver.error_handling():
            return self._driver.has_table(table_name)

    def list_tables(self):

        with self._driver.error_handling():
            return self._driver.list_tables()

    def create_table(self, table_name: str, schema: pa.Schema):

        with self._driver.error_handling():

            def type_decl(field: pa.Field):
                sql_type = self._dialect.arrow_to_sql_type(field.type)
                null_qualifier = " NULL" if field.nullable else " NOT NULL"
                return f"{field.name} {sql_type}{null_qualifier}"

            create_fields = map(lambda i: type_decl(schema.field(i)), range(len(schema.names)))
            create_stmt = f"create table {table_name} (" + ", ".join(create_fields) + ")"

            with self._connection() as conn, self._cursor(conn) as cur:
                cur.execute(create_stmt, [])
                conn.commit()  # Some drivers / dialects (Postgres) require commit for create table

    def read_table(self, table_name: str) -> pa.Table:

        select_stmt = f"select * from {table_name}"  # noqa

        return self.native_read_query(select_stmt)

    def native_read_query(self, query: str, **parameters) -> pa.Table:

        # Real restrictions are enforced in deployment, by permissions granted to service accounts
        # This is a sanity check to catch common errors before sending a query to the backend
        self._check_read_query(query)

        with self._driver.error_handling():

            with self._connection() as conn, self._cursor(conn) as cur:

                cur.execute(query, parameters)
                sql_batch = cur.fetchmany()

                # Read queries should always return a result set, even if it is empty
                if not cur.description:
                    raise ex.EStorage(f"Query did not return a result set: {query}")

                arrow_schema = self._decode_sql_schema(cur.description)
                arrow_batches: tp.List[pa.RecordBatch] = []

                while len(sql_batch) > 0:

                    arrow_batch = self._decode_sql_batch(arrow_schema, sql_batch)
                    arrow_batches.append(arrow_batch)

                    # Sometimes the schema is not fully defined up front (because cur.description is not sufficient)
                    # If type information has been inferred from the batch, update the schema accordingly
                    arrow_schema = arrow_batch.schema

                    sql_batch = cur.fetchmany()

                return pa.Table.from_batches(arrow_batches, arrow_schema)  # noqa

    def write_table(self, table_name: str, table: pa.Table):

        with self._driver.error_handling():

            # Named parameter markers - assumes the driver supports the NAMED param style
            insert_fields = ", ".join(table.schema.names)
            insert_markers = ", ".join(f":{name}" for name in table.schema.names)
            insert_stmt = f"insert into {table_name}({insert_fields}) values ({insert_markers})"  # noqa

            with self._connection() as conn:

                # Use execute many to perform a batch write
                with self._cursor(conn) as cur:
                    if table.num_rows > 0:
                        # Provider converts rows on demand, to optimize for memory
                        row_provider = self._encode_sql_rows_dict(table)
                        cur.executemany(insert_stmt, row_provider)
                    else:
                        # Do not try to insert if there are no rows to bind
                        pass

                conn.commit()

    def _check_read_query(self, query):

        # Best-effort keyword scan - real enforcement happens via database permissions

        if not any(map(lambda keyword: keyword in query.lower(), self.__DQL_KEYWORDS)):
            raise ex.EStorageRequest(f"Query is not a read query: {query}")

        if any(map(lambda keyword: keyword in query.lower(), self.__DML_KEYWORDS)):
            raise ex.EStorageRequest(f"Query is not a read query: {query}")

        if any(map(lambda keyword: keyword in query.lower(), self.__DDL_KEYWORDS)):
            raise ex.EStorageRequest(f"Query is not a read query: {query}")

    @staticmethod
    def _decode_sql_schema(description: tp.List[tp.Tuple]):

        # TODO: Infer Python / Arrow type using DB API type code
        # These codes are db-specific so decoding would probably be on a best effort basis
        # However the information is public for many popular db engines
        # The current logic can be kept as a fallback (set type info on reading first non-null value)

        def _decode_sql_field(field_desc: tp.Tuple):
            field_name, type_code, _, _, precision, scale, null_ok = field_desc
            # Type starts as null - it is inferred later from data in _decode_sql_batch()
            return pa.field(field_name, pa.null(), null_ok)

        fields = map(_decode_sql_field, description)

        return pa.schema(fields)

    def _decode_sql_batch(self, schema: pa.Schema, sql_batch: tp.List[tp.Tuple]) -> pa.RecordBatch:

        py_dict: tp.Dict[str, pa.Array] = {}

        for i, col in enumerate(schema.names):

            arrow_type = schema.types[i]

            if pa.types.is_null(arrow_type):
                # Type not known up front - infer it from the first non-null value, if any
                values = list(map(lambda row: row[i], sql_batch))
                # Bug fix: use a default so an all-null column does not raise StopIteration
                concrete_value = next((v for v in values if v is not None), None)
                if concrete_value is not None:
                    arrow_type = _data.DataMapping.python_to_arrow_type(type(concrete_value))
                    arrow_field = pa.field(schema.names[i], arrow_type, nullable=True)
                    schema = schema.remove(i).insert(i, arrow_field)
            else:
                python_type = _data.DataMapping.arrow_to_python_type(arrow_type)
                values = map(lambda row: self._driver.decode_sql_value(row[i], python_type), sql_batch)

            py_dict[col] = pa.array(values, type=arrow_type)

        return pa.RecordBatch.from_pydict(py_dict, schema)

    def _encode_sql_rows_tuple(self, table: pa.Table) -> tp.Iterator[tp.Tuple]:

        # Lazily yield one positional tuple per table row

        for row in range(0, table.num_rows):
            row_values = map(lambda col: self._driver.encode_sql_value(col[row].as_py()), table.columns)
            yield tuple(row_values)

    def _encode_sql_rows_dict(self, table: pa.Table) -> tp.Iterator[tp.Tuple]:

        # Lazily yield one {column_name: value} dict per table row

        for row in range(0, table.num_rows):
            row_values = map(lambda col: self._driver.encode_sql_value(col[row].as_py()), table.columns)
            yield dict(zip(table.column_names, row_values))
|
251
|
+
|
252
|
+
|
253
|
+
class SqlStorageProvider(IStorageProvider):

    """Storage provider exposing SQL-backed data storage."""

    def __init__(self, properties: tp.Dict[str, str]):
        self._properties = properties

    def has_data_storage(self) -> bool:
        # SQL storage always provides a data storage implementation
        return True

    def get_data_storage(self) -> IDataStorageBase:
        # A fresh storage instance is built from the configured properties
        return SqlDataStorage(self._properties)
|
263
|
+
|
264
|
+
|
265
|
+
# Register with the plugin manager
plugins.PluginManager.register_plugin(IStorageProvider, SqlStorageProvider, ["SQL"])
|
267
|
+
|
268
|
+
|
269
|
+
try:

    import sqlalchemy as sqla  # noqa
    import sqlalchemy.exc as sqla_exc  # noqa

    class SqlAlchemyDriver(ISqlDriver):

        """
        SQL driver implemented over SQL Alchemy.

        Only registered if the sqlalchemy package is installed (see the
        enclosing try / except ModuleNotFoundError guard).
        """

        def __init__(self, properties: tp.Dict[str, str]):

            self._log = _helpers.logger_for_object(self)

            raw_url = properties.get('url')

            if raw_url is None or raw_url.strip() == '':
                raise ex.EConfigLoad("Missing required property [url] for SQL driver [alchemy]")

            # Credentials are merged into the URL, then stripped from the engine kwargs
            url = urlp.urlparse(raw_url)
            credentials = _helpers.get_http_credentials(url, properties)
            url = _helpers.apply_http_credentials(url, credentials)

            filtered_keys = ["url", "username", "password", "token"]
            filtered_props = dict(kv for kv in properties.items() if kv[0] not in filtered_keys)

            self._log.info("Connecting: %s", _helpers.log_safe_url(url))

            try:
                self.__engine = sqla.create_engine(url.geturl(), **filtered_props)
            except ModuleNotFoundError as e:
                # The DBAPI module for the configured backend is not installed
                raise ex.EPluginNotAvailable("SQL driver is not available: " + str(e)) from e

        def param_style(self) -> "DbApiWrapper.ParamStyle":
            return DbApiWrapper.ParamStyle.NAMED

        def connect(self, **kwargs) -> "DbApiWrapper.Connection":

            return SqlAlchemyDriver.ConnectionWrapper(self.__engine.connect())

        def has_table(self, table_name: str):

            with self.__engine.connect() as conn:
                inspection = sqla.inspect(conn)
                return inspection.has_table(table_name)

        def list_tables(self):

            with self.__engine.connect() as conn:
                inspection = sqla.inspect(conn)
                return inspection.get_table_names()

        def encode_sql_value(self, py_value: tp.Any) -> tp.Any:

            # SQL Alchemy binds native Python values directly
            return py_value

        def decode_sql_value(self, sql_value: tp.Any, python_type: tp.Type) -> tp.Any:

            # SQL Alchemy returns native Python values directly
            return sql_value

        @contextlib.contextmanager
        def error_handling(self) -> contextlib.contextmanager:

            # Map SQL Alchemy errors onto TRAC storage errors
            try:
                yield
            except (sqla_exc.OperationalError, sqla_exc.ProgrammingError, sqla_exc.StatementError) as e:
                raise ex.EStorageRequest(*e.args) from e
            except sqla_exc.SQLAlchemyError as e:
                raise ex.EStorage() from e

        class ConnectionWrapper(DbApiWrapper.Connection):

            # Thin adapter presenting a SQL Alchemy connection as a DB-API connection

            def __init__(self, conn: sqla.Connection):
                self.__conn = conn

            def close(self):
                self.__conn.close()

            def commit(self):
                self.__conn.commit()

            def rollback(self):
                self.__conn.rollback()

            def cursor(self) -> "DbApiWrapper.Cursor":
                return SqlAlchemyDriver.CursorWrapper(self.__conn)

        class CursorWrapper(DbApiWrapper.Cursor):

            # Thin adapter presenting SQL Alchemy results as a DB-API cursor

            arraysize: int = 1000

            def __init__(self, conn: sqla.Connection):
                self.__conn = conn
                # Set by execute() / executemany(); NOTE(review): the description and
                # rowcount properties assume a statement has already been executed
                self.__result: tp.Optional[sqla.CursorResult] = None

            @property
            def description(self):

                # Prefer description from the underlying cursor if available
                if self.__result.cursor is not None and self.__result.cursor.description:
                    return self.__result.cursor.description

                if not self.__result.returns_rows:
                    return None

                # SQL Alchemy sometimes closes the cursor and the description is lost
                # Fall back on using the Result API to generate a description with field names only

                def name_only_field_desc(field_name):
                    return field_name, None, None, None, None, None, None

                return list(map(name_only_field_desc, self.__result.keys()))

            @property
            def rowcount(self) -> int:

                # Prefer the value from the underlying cursor if it is available
                if self.__result.cursor is not None:
                    return self.__result.cursor.rowcount

                return self.__result.rowcount  # noqa

            def execute(self, statement: str, parameters: tp.Union[tp.Dict, tp.Sequence]):

                self.__result = self.__conn.execute(sqla.text(statement), parameters)

            def executemany(self, statement: str, parameters: tp.Iterable[tp.Union[tp.Dict, tp.Sequence]]):

                # SQL Alchemy needs a concrete list for multi-row binding
                if not isinstance(parameters, tp.List):
                    parameters = list(parameters)

                self.__result = self.__conn.execute(sqla.text(statement), parameters)

            def fetchone(self) -> tp.Tuple:

                row = self.__result.fetchone()
                return row.tuple() if row is not None else None

            def fetchmany(self, size: int = arraysize) -> tp.Sequence[tp.Tuple]:

                # Bug fix: honor the requested size (previously the size argument was
                # ignored and self.arraysize was always used; the default is unchanged)
                sqla_rows = self.__result.fetchmany(size)
                return list(map(sqla.Row.tuple, sqla_rows))  # noqa

            def close(self):

                if self.__result is not None:
                    self.__result.close()

    plugins.PluginManager.register_plugin(ISqlDriver, SqlAlchemyDriver, ["alchemy"])

except ModuleNotFoundError:
    pass
|
@@ -0,0 +1,117 @@
|
|
1
|
+
# Copyright 2024 Accenture Global Solutions Limited
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import typing as tp
|
16
|
+
|
17
|
+
import pyarrow as pa
|
18
|
+
|
19
|
+
import tracdap.rt.exceptions as ex
|
20
|
+
import tracdap.rt.ext.plugins as plugins
|
21
|
+
|
22
|
+
from tracdap.rt._impl.ext.sql import * # noqa
|
23
|
+
|
24
|
+
|
25
|
+
|
26
|
+
class AnsiStandardDialect(ISqlDialect):

    """Baseline dialect mapping Arrow types to ANSI-standard SQL type names."""

    def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:

        # Ordered (predicate, SQL type) pairs - first match wins
        mappings = [
            (pa.types.is_boolean, "boolean"),
            (pa.types.is_integer, "bigint"),
            (pa.types.is_floating, "double precision"),
            (pa.types.is_decimal, "decimal (31, 10)"),
            (pa.types.is_string, "varchar(4096)"),
            (pa.types.is_date, "date"),
            (pa.types.is_timestamp, "timestamp (6)"),
        ]

        for type_check, sql_type in mappings:
            if type_check(arrow_type):
                return sql_type

        raise ex.ETracInternal(f"Unsupported data type [{str(arrow_type)}] in SQL dialect [{self.__class__.__name__}]")
|
52
|
+
|
53
|
+
|
54
|
+
class MySqlDialect(AnsiStandardDialect):

    """MySQL dialect - overrides float and string mappings, defers to ANSI otherwise."""

    def __init__(self, properties: tp.Dict[str, str]):
        self._properties = properties

    def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:

        # MySQL-specific overrides, anything else falls through to the ANSI mapping
        if pa.types.is_floating(arrow_type):
            sql_type = "double"
        elif pa.types.is_string(arrow_type):
            sql_type = "varchar(8192)"
        else:
            sql_type = super().arrow_to_sql_type(arrow_type)

        return sql_type
|
68
|
+
|
69
|
+
|
70
|
+
class MariaDbDialect(MySqlDialect):

    """MariaDB dialect - inherits the MySQL type mapping unchanged."""

    def __init__(self, properties: tp.Dict[str, str]):
        super().__init__(properties)
|
77
|
+
|
78
|
+
|
79
|
+
class PostgresqlDialect(AnsiStandardDialect):

    """PostgreSQL dialect - uses unbounded varchar, defers to ANSI otherwise."""

    def __init__(self, properties: tp.Dict[str, str]):
        self._properties = properties

    def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:

        # Postgres supports varchar with no length limit
        if pa.types.is_string(arrow_type):
            sql_type = "varchar"
        else:
            sql_type = super().arrow_to_sql_type(arrow_type)

        return sql_type
|
90
|
+
|
91
|
+
|
92
|
+
class SqlServerDialect(AnsiStandardDialect):

    """SQL Server dialect - overrides boolean, float, string and timestamp mappings."""

    def __init__(self, properties: tp.Dict[str, str]):
        self._properties = properties

    def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:

        # SQL Server-specific overrides, anything else falls through to the ANSI mapping
        if pa.types.is_boolean(arrow_type):
            sql_type = "bit"
        elif pa.types.is_floating(arrow_type):
            sql_type = "float(53)"
        elif pa.types.is_string(arrow_type):
            sql_type = "varchar(8000)"
        elif pa.types.is_timestamp(arrow_type):
            sql_type = "datetime2"
        else:
            sql_type = super().arrow_to_sql_type(arrow_type)

        return sql_type
|
112
|
+
|
113
|
+
|
114
|
+
# Register the available dialects with the plugin manager
plugins.PluginManager.register_plugin(ISqlDialect, MySqlDialect, ["mysql"])
plugins.PluginManager.register_plugin(ISqlDialect, MariaDbDialect, ["mariadb"])
plugins.PluginManager.register_plugin(ISqlDialect, PostgresqlDialect, ["postgresql"])
plugins.PluginManager.register_plugin(ISqlDialect, SqlServerDialect, ["sqlserver"])
|
tracdap/rt/_version.py
CHANGED