tracdap-runtime 0.6.5__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -44,10 +44,25 @@ def check_type(expected_type: tp.Type, value: tp.Any) -> bool:
44
44
  return _TypeValidator.check_type(expected_type, value)
45
45
 
46
46
 
47
+ def type_name(type_: tp.Type, qualified: bool) -> str:
48
+ return _TypeValidator._type_name(type_, qualified) # noqa
49
+
50
+
47
51
  def quick_validate_model_def(model_def: meta.ModelDefinition):
48
52
  StaticValidator.quick_validate_model_def(model_def)
49
53
 
50
54
 
55
+ def is_primitive_type(basic_type: meta.BasicType) -> bool:
56
+ return StaticValidator.is_primitive_type(basic_type)
57
+
58
+
59
+ T_SKIP_VAL = tp.TypeVar("T_SKIP_VAL")
60
+
61
+ class SkipValidation(tp.Generic[T_SKIP_VAL]):
62
+ def __init__(self, skip_type: tp.Type[T_SKIP_VAL]):
63
+ self.skip_type = skip_type
64
+
65
+
51
66
  class _TypeValidator:
52
67
 
53
68
  # The metaclass for generic types varies between versions of the typing library
@@ -56,28 +71,28 @@ class _TypeValidator:
56
71
 
57
72
  # Cache method signatures to avoid inspection on every call
58
73
  # Inspecting a function signature can take ~ half a second in Python 3.7
59
- __method_cache: tp.Dict[str, inspect.Signature] = dict()
74
+ __method_cache: tp.Dict[str, tp.Tuple[inspect.Signature, tp.Any]] = dict()
60
75
 
61
76
  _log: logging.Logger = util.logger_for_namespace(__name__)
62
77
 
63
78
  @classmethod
64
79
  def validate_signature(cls, method: tp.Callable, *args, **kwargs):
65
80
 
66
- if method.__name__ in cls.__method_cache:
67
- signature = cls.__method_cache[method.__name__]
81
+ if method.__qualname__ in cls.__method_cache:
82
+ signature, hints = cls.__method_cache[method.__qualname__]
68
83
  else:
69
84
  signature = inspect.signature(method)
70
- cls.__method_cache[method.__name__] = signature
71
-
72
- hints = tp.get_type_hints(method)
85
+ hints = tp.get_type_hints(method)
86
+ cls.__method_cache[method.__qualname__] = signature, hints
73
87
 
88
+ named_params = list(signature.parameters.keys())
74
89
  positional_index = 0
75
90
 
76
91
  for param_name, param in signature.parameters.items():
77
92
 
78
93
  param_type = hints.get(param_name)
79
94
 
80
- values = cls._select_arg(method.__name__, param, positional_index, *args, **kwargs)
95
+ values = cls._select_arg(method.__name__, param, positional_index, named_params, *args, **kwargs)
81
96
  positional_index += len(values)
82
97
 
83
98
  for value in values:
@@ -86,11 +101,12 @@ class _TypeValidator:
86
101
  @classmethod
87
102
  def validate_return_type(cls, method: tp.Callable, value: tp.Any):
88
103
 
89
- if method.__name__ in cls.__method_cache:
90
- signature = cls.__method_cache[method.__name__]
104
+ if method.__qualname__ in cls.__method_cache:
105
+ signature, hints = cls.__method_cache[method.__qualname__]
91
106
  else:
92
107
  signature = inspect.signature(method)
93
- cls.__method_cache[method.__name__] = signature
108
+ hints = tp.get_type_hints(method)
109
+ cls.__method_cache[method.__qualname__] = signature, hints
94
110
 
95
111
  correct_type = cls._validate_type(signature.return_annotation, value)
96
112
 
@@ -107,7 +123,7 @@ class _TypeValidator:
107
123
 
108
124
  @classmethod
109
125
  def _select_arg(
110
- cls, method_name: str, parameter: inspect.Parameter, positional_index,
126
+ cls, method_name: str, parameter: inspect.Parameter, positional_index, named_params,
111
127
  *args, **kwargs) -> tp.List[tp.Any]:
112
128
 
113
129
  if parameter.kind == inspect.Parameter.POSITIONAL_ONLY:
@@ -152,7 +168,7 @@ class _TypeValidator:
152
168
 
153
169
  if parameter.kind == inspect.Parameter.VAR_KEYWORD:
154
170
 
155
- raise ex.ETracInternal("Validation of VAR_KEYWORD params is not supported yet")
171
+ return [arg for kw, arg in kwargs.items() if kw not in named_params]
156
172
 
157
173
  raise ex.EUnexpected("Invalid method signature in runtime API (this is a bug)")
158
174
 
@@ -180,6 +196,12 @@ class _TypeValidator:
180
196
  if expected_type == tp.Any:
181
197
  return True
182
198
 
199
+ # Sometimes we need to validate a partial set of arguments
200
+ # Explicitly passing a SkipValidation value allows for this
201
+ if isinstance(value, SkipValidation):
202
+ if value.skip_type == expected_type:
203
+ return True
204
+
183
205
  if isinstance(expected_type, cls.__generic_metaclass):
184
206
 
185
207
  origin = util.get_origin(expected_type)
@@ -216,8 +238,31 @@ class _TypeValidator:
216
238
  all(map(lambda k: cls._validate_type(key_type, k), value.keys())) and \
217
239
  all(map(lambda v: cls._validate_type(value_type, v), value.values()))
218
240
 
241
+ if origin.__module__.startswith("tracdap.rt.api."):
242
+ return isinstance(value, origin)
243
+
219
244
  raise ex.ETracInternal(f"Validation of [{origin.__name__}] generic parameters is not supported yet")
220
245
 
246
+ # Support for generic type variables
247
+ if isinstance(expected_type, tp.TypeVar):
248
+
249
+ # If there are any constraints or a bound, those must be honoured
250
+
251
+ constraints = util.get_constraints(expected_type)
252
+ bound = util.get_bound(expected_type)
253
+
254
+ if constraints:
255
+ if not any(map(lambda c: type(value) == c, constraints)):
256
+ return False
257
+
258
+ if bound:
259
+ if not isinstance(value, bound):
260
+ return False
261
+
262
+ # So long as constraints / bound are ok, any type matches a generic type var
263
+ return True
264
+
265
+
221
266
  # Validate everything else as a concrete type
222
267
 
223
268
  # TODO: Recursive validation of types for class members using field annotations
@@ -237,7 +282,10 @@ class _TypeValidator:
237
282
  return f"Named[{named_type}]"
238
283
 
239
284
  if origin is tp.Union:
240
- return "|".join(map(cls._type_name, args))
285
+ if len(args) == 2 and args[1] == type(None):
286
+ return f"Optional[{cls._type_name(args[0])}]"
287
+ else:
288
+ return "|".join(map(cls._type_name, args))
241
289
 
242
290
  if origin is list:
243
291
  list_type = cls._type_name(args[0])
@@ -274,6 +322,11 @@ class StaticValidator:
274
322
 
275
323
  _log: logging.Logger = util.logger_for_namespace(__name__)
276
324
 
325
+ @classmethod
326
+ def is_primitive_type(cls, basic_type: meta.BasicType) -> bool:
327
+
328
+ return basic_type in cls.__PRIMITIVE_TYPES
329
+
277
330
  @classmethod
278
331
  def quick_validate_model_def(cls, model_def: meta.ModelDefinition):
279
332
 
@@ -0,0 +1,417 @@
1
+ # Copyright 2024 Accenture Global Solutions Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import contextlib
16
+ import typing as tp
17
+ import urllib.parse as urlp
18
+
19
+ import pyarrow as pa
20
+
21
+ import tracdap.rt.config as cfg
22
+ import tracdap.rt.exceptions as ex
23
+ import tracdap.rt.ext.plugins as plugins
24
+
25
+ # Import storage interfaces (private extension API)
26
+ from tracdap.rt._impl.ext.storage import * # noqa
27
+ from tracdap.rt._impl.ext.sql import * # noqa
28
+
29
+ import tracdap.rt._plugins._helpers as _helpers
30
+
31
+ # TODO: Remove internal references
32
+ import tracdap.rt._impl.data as _data
33
+
34
+
35
+ class SqlDataStorage(IDataStorageBase[pa.Table, pa.Schema]):
36
+
37
+ DIALECT_PROPERTY = "dialect"
38
+ DRIVER_PROPERTY = "driver.python"
39
+
40
+ __DQL_KEYWORDS = ["select"]
41
+ __DML_KEYWORDS = ["insert", "update", "delete", "merge"]
42
+ __DDL_KEYWORDS = ["create", "alter", "drop", "grant"]
43
+
44
+ def __init__(self, properties: tp.Dict[str, str]):
45
+
46
+ self._log = _helpers.logger_for_object(self)
47
+ self._properties = properties
48
+
49
+ dialect_name = _helpers.get_plugin_property(self._properties, self.DIALECT_PROPERTY)
50
+
51
+ if dialect_name is None:
52
+ raise ex.EConfigLoad(f"Missing required property [{self.DIALECT_PROPERTY}]")
53
+
54
+ if not plugins.PluginManager.is_plugin_available(ISqlDialect, dialect_name.lower()):
55
+ raise ex.EPluginNotAvailable(f"SQL dialect [{dialect_name}] is not supported")
56
+
57
+ driver_name = _helpers.get_plugin_property(self._properties, self.DRIVER_PROPERTY)
58
+ if driver_name is None:
59
+ driver_name = dialect_name.lower()
60
+
61
+ if not plugins.PluginManager.is_plugin_available(ISqlDriver, driver_name):
62
+ raise ex.EPluginNotAvailable(f"SQL driver [{driver_name}] is not available")
63
+
64
+ driver_props = self._driver_props(driver_name)
65
+ driver_cfg = cfg.PluginConfig(protocol=driver_name.lower(), properties=driver_props)
66
+ dialect_cfg = cfg.PluginConfig(protocol=dialect_name.lower(), properties={})
67
+
68
+ self._log.info(f"Loading SQL driver [{driver_name}] for dialect [{dialect_name}]")
69
+
70
+ self._driver = plugins.PluginManager.load_plugin(ISqlDriver, driver_cfg)
71
+ self._dialect = plugins.PluginManager.load_plugin(ISqlDialect, dialect_cfg)
72
+
73
+ # Test connectivity
74
+ with self._connection():
75
+ pass
76
+
77
+ def _driver_props(self, driver_name: str) -> tp.Dict[str, str]:
78
+
79
+ driver_props = dict()
80
+ driver_filter = f"{driver_name}."
81
+
82
+ for key, value in self._properties.items():
83
+ if key.startswith(driver_filter):
84
+ dialect_key = key[len(driver_filter):]
85
+ driver_props[dialect_key] = value
86
+
87
+ return driver_props
88
+
89
+ def _connection(self) -> DbApiWrapper.Connection:
90
+
91
+ return contextlib.closing(self._driver.connect()) # noqa
92
+
93
+ def _cursor(self, conn: DbApiWrapper.Connection) -> DbApiWrapper.Cursor:
94
+
95
+ return contextlib.closing(conn.cursor()) # noqa
96
+
97
+ def data_type(self) -> tp.Type[pa.Table]:
98
+ return pa.Table
99
+
100
+ def schema_type(self) -> tp.Type[pa.Schema]:
101
+ return pa.Schema
102
+
103
+ def has_table(self, table_name: str):
104
+
105
+ with self._driver.error_handling():
106
+ return self._driver.has_table(table_name)
107
+
108
+ def list_tables(self):
109
+
110
+ with self._driver.error_handling():
111
+ return self._driver.list_tables()
112
+
113
+ def create_table(self, table_name: str, schema: pa.Schema):
114
+
115
+ with self._driver.error_handling():
116
+
117
+ def type_decl(field: pa.Field):
118
+ sql_type = self._dialect.arrow_to_sql_type(field.type)
119
+ null_qualifier = " NULL" if field.nullable else " NOT NULL"
120
+ return f"{field.name} {sql_type}{null_qualifier}"
121
+
122
+ create_fields = map(lambda i: type_decl(schema.field(i)), range(len(schema.names)))
123
+ create_stmt = f"create table {table_name} (" + ", ".join(create_fields) + ")"
124
+
125
+ with self._connection() as conn, self._cursor(conn) as cur:
126
+ cur.execute(create_stmt, [])
127
+ conn.commit() # Some drivers / dialects (Postgres) require commit for create table
128
+
129
+ def read_table(self, table_name: str) -> pa.Table:
130
+
131
+ select_stmt = f"select * from {table_name}" # noqa
132
+
133
+ return self.native_read_query(select_stmt)
134
+
135
+ def native_read_query(self, query: str, **parameters) -> pa.Table:
136
+
137
+ # Real restrictions are enforced in deployment, by permissions granted to service accounts
138
+ # This is a sanity check to catch common errors before sending a query to the backend
139
+ self._check_read_query(query)
140
+
141
+ with self._driver.error_handling():
142
+
143
+ with self._connection() as conn, self._cursor(conn) as cur:
144
+
145
+ cur.execute(query, parameters)
146
+ sql_batch = cur.fetchmany()
147
+
148
+ # Read queries should always return a result set, even if it is empty
149
+ if not cur.description:
150
+ raise ex.EStorage(f"Query did not return a result set: {query}")
151
+
152
+ arrow_schema = self._decode_sql_schema(cur.description)
153
+ arrow_batches: tp.List[pa.RecordBatch] = []
154
+
155
+ while len(sql_batch) > 0:
156
+
157
+ arrow_batch = self._decode_sql_batch(arrow_schema, sql_batch)
158
+ arrow_batches.append(arrow_batch)
159
+
160
+ # Sometimes the schema is not fully defined up front (because cur.description is not sufficient)
161
+ # If type information has been inferred from the batch, update the schema accordingly
162
+ arrow_schema = arrow_batch.schema
163
+
164
+ sql_batch = cur.fetchmany()
165
+
166
+ return pa.Table.from_batches(arrow_batches, arrow_schema) # noqa
167
+
168
+ def write_table(self, table_name: str, table: pa.Table):
169
+
170
+ with self._driver.error_handling():
171
+
172
+ insert_fields = ", ".join(table.schema.names)
173
+ insert_markers = ", ".join(f":{name}" for name in table.schema.names)
174
+ insert_stmt = f"insert into {table_name}({insert_fields}) values ({insert_markers})" # noqa
175
+
176
+ with self._connection() as conn:
177
+
178
+ # Use execute many to perform a batch write
179
+ with self._cursor(conn) as cur:
180
+ if table.num_rows > 0:
181
+ # Provider converts rows on demand, to optimize for memory
182
+ row_provider = self._encode_sql_rows_dict(table)
183
+ cur.executemany(insert_stmt, row_provider)
184
+ else:
185
+ # Do not try to insert if there are now rows to bind
186
+ pass
187
+
188
+ conn.commit()
189
+
190
+ def _check_read_query(self, query):
191
+
192
+ if not any(map(lambda keyword: keyword in query.lower(), self.__DQL_KEYWORDS)):
193
+ raise ex.EStorageRequest(f"Query is not a read query: {query}")
194
+
195
+ if any(map(lambda keyword: keyword in query.lower(), self.__DML_KEYWORDS)):
196
+ raise ex.EStorageRequest(f"Query is not a read query: {query}")
197
+
198
+ if any(map(lambda keyword: keyword in query.lower(), self.__DDL_KEYWORDS)):
199
+ raise ex.EStorageRequest(f"Query is not a read query: {query}")
200
+
201
+ @staticmethod
202
+ def _decode_sql_schema(description: tp.List[tp.Tuple]):
203
+
204
+ # TODO: Infer Python / Arrow type using DB API type code
205
+ # These codes are db-specific so decoding would probably be on a best effort basis
206
+ # However the information is public for many popular db engines
207
+ # The current logic can be kept as a fallback (set type info on reading first non-null value)
208
+
209
+ def _decode_sql_field(field_desc: tp.Tuple):
210
+ field_name, type_code, _, _, precision, scale, null_ok = field_desc
211
+ return pa.field(field_name, pa.null(), null_ok)
212
+
213
+ fields = map(_decode_sql_field, description)
214
+
215
+ return pa.schema(fields)
216
+
217
+ def _decode_sql_batch(self, schema: pa.Schema, sql_batch: tp.List[tp.Tuple]) -> pa.RecordBatch:
218
+
219
+ py_dict: tp.Dict[str, pa.Array] = {}
220
+
221
+ for i, col in enumerate(schema.names):
222
+
223
+ arrow_type = schema.types[i]
224
+
225
+ if pa.types.is_null(arrow_type):
226
+ values = list(map(lambda row: row[i], sql_batch))
227
+ concrete_value = next(v for v in values if v is not None)
228
+ if concrete_value is not None:
229
+ arrow_type = _data.DataMapping.python_to_arrow_type(type(concrete_value))
230
+ arrow_field = pa.field(schema.names[i], arrow_type, nullable=True)
231
+ schema = schema.remove(i).insert(i, arrow_field)
232
+ else:
233
+ python_type = _data.DataMapping.arrow_to_python_type(arrow_type)
234
+ values = map(lambda row: self._driver.decode_sql_value(row[i], python_type), sql_batch)
235
+
236
+ py_dict[col] = pa.array(values, type=arrow_type)
237
+
238
+ return pa.RecordBatch.from_pydict(py_dict, schema)
239
+
240
+ def _encode_sql_rows_tuple(self, table: pa.Table) -> tp.Iterator[tp.Tuple]:
241
+
242
+ for row in range(0, table.num_rows):
243
+ row_values = map(lambda col: self._driver.encode_sql_value(col[row].as_py()), table.columns)
244
+ yield tuple(row_values)
245
+
246
+ def _encode_sql_rows_dict(self, table: pa.Table) -> tp.Iterator[tp.Tuple]:
247
+
248
+ for row in range(0, table.num_rows):
249
+ row_values = map(lambda col: self._driver.encode_sql_value(col[row].as_py()), table.columns)
250
+ yield dict(zip(table.column_names, row_values))
251
+
252
+
253
+ class SqlStorageProvider(IStorageProvider):
254
+
255
+ def __init__(self, properties: tp.Dict[str, str]):
256
+ self._properties = properties
257
+
258
+ def has_data_storage(self) -> bool:
259
+ return True
260
+
261
+ def get_data_storage(self) -> IDataStorageBase:
262
+ return SqlDataStorage(self._properties)
263
+
264
+
265
+ # Register with the plugin manager
266
+ plugins.PluginManager.register_plugin(IStorageProvider, SqlStorageProvider, ["SQL"])
267
+
268
+
269
+ try:
270
+
271
+ import sqlalchemy as sqla # noqa
272
+ import sqlalchemy.exc as sqla_exc # noqa
273
+
274
+ class SqlAlchemyDriver(ISqlDriver):
275
+
276
+ def __init__(self, properties: tp.Dict[str, str]):
277
+
278
+ self._log = _helpers.logger_for_object(self)
279
+
280
+ raw_url = properties.get('url')
281
+
282
+ if raw_url is None or raw_url.strip() == '':
283
+ raise ex.EConfigLoad("Missing required property [url] for SQL driver [alchemy]")
284
+
285
+ url = urlp.urlparse(raw_url)
286
+ credentials = _helpers.get_http_credentials(url, properties)
287
+ url = _helpers.apply_http_credentials(url, credentials)
288
+
289
+ filtered_keys = ["url", "username", "password", "token"]
290
+ filtered_props = dict(kv for kv in properties.items() if kv[0] not in filtered_keys)
291
+
292
+ self._log.info("Connecting: %s", _helpers.log_safe_url(url))
293
+
294
+ try:
295
+ self.__engine = sqla.create_engine(url.geturl(), **filtered_props)
296
+ except ModuleNotFoundError as e:
297
+ raise ex.EPluginNotAvailable("SQL driver is not available: " + str(e)) from e
298
+
299
+ def param_style(self) -> "DbApiWrapper.ParamStyle":
300
+ return DbApiWrapper.ParamStyle.NAMED
301
+
302
+ def connect(self, **kwargs) -> "DbApiWrapper.Connection":
303
+
304
+ return SqlAlchemyDriver.ConnectionWrapper(self.__engine.connect())
305
+
306
+ def has_table(self, table_name: str):
307
+
308
+ with self.__engine.connect() as conn:
309
+ inspection = sqla.inspect(conn)
310
+ return inspection.has_table(table_name)
311
+
312
+ def list_tables(self):
313
+
314
+ with self.__engine.connect() as conn:
315
+ inspection = sqla.inspect(conn)
316
+ return inspection.get_table_names()
317
+
318
+ def encode_sql_value(self, py_value: tp.Any) -> tp.Any:
319
+
320
+ return py_value
321
+
322
+ def decode_sql_value(self, sql_value: tp.Any, python_type: tp.Type) -> tp.Any:
323
+
324
+ return sql_value
325
+
326
+ @contextlib.contextmanager
327
+ def error_handling(self) -> contextlib.contextmanager:
328
+
329
+ try:
330
+ yield
331
+ except (sqla_exc.OperationalError, sqla_exc.ProgrammingError, sqla_exc.StatementError) as e:
332
+ raise ex.EStorageRequest(*e.args) from e
333
+ except sqla_exc.SQLAlchemyError as e:
334
+ raise ex.EStorage() from e
335
+
336
+ class ConnectionWrapper(DbApiWrapper.Connection):
337
+
338
+ def __init__(self, conn: sqla.Connection):
339
+ self.__conn = conn
340
+
341
+ def close(self):
342
+ self.__conn.close()
343
+
344
+ def commit(self):
345
+ self.__conn.commit()
346
+
347
+ def rollback(self):
348
+ self.__conn.rollback()
349
+
350
+ def cursor(self) -> "DbApiWrapper.Cursor":
351
+ return SqlAlchemyDriver.CursorWrapper(self.__conn)
352
+
353
+ class CursorWrapper(DbApiWrapper.Cursor):
354
+
355
+ arraysize: int = 1000
356
+
357
+ def __init__(self, conn: sqla.Connection):
358
+ self.__conn = conn
359
+ self.__result: tp.Optional[sqla.CursorResult] = None
360
+
361
+ @property
362
+ def description(self):
363
+
364
+ # Prefer description from the underlying cursor if available
365
+ if self.__result.cursor is not None and self.__result.cursor.description:
366
+ return self.__result.cursor.description
367
+
368
+ if not self.__result.returns_rows:
369
+ return None
370
+
371
+ # SQL Alchemy sometimes closes the cursor and the description is lost
372
+ # Fall back on using the Result API to generate a description with field names only
373
+
374
+ def name_only_field_desc(field_name):
375
+ return field_name, None, None, None, None, None, None
376
+
377
+ return list(map(name_only_field_desc, self.__result.keys()))
378
+
379
+ @property
380
+ def rowcount(self) -> int:
381
+
382
+ # Prefer the value from the underlying cursor if it is available
383
+ if self.__result.cursor is not None:
384
+ return self.__result.cursor.rowcount
385
+
386
+ return self.__result.rowcount # noqa
387
+
388
+ def execute(self, statement: str, parameters: tp.Union[tp.Dict, tp.Sequence]):
389
+
390
+ self.__result = self.__conn.execute(sqla.text(statement), parameters)
391
+
392
+ def executemany(self, statement: str, parameters: tp.Iterable[tp.Union[tp.Dict, tp.Sequence]]):
393
+
394
+ if not isinstance(parameters, tp.List):
395
+ parameters = list(parameters)
396
+
397
+ self.__result = self.__conn.execute(sqla.text(statement), parameters)
398
+
399
+ def fetchone(self) -> tp.Tuple:
400
+
401
+ row = self.__result.fetchone()
402
+ return row.tuple() if row is not None else None
403
+
404
+ def fetchmany(self, size: int = arraysize) -> tp.Sequence[tp.Tuple]:
405
+
406
+ sqla_rows = self.__result.fetchmany(self.arraysize)
407
+ return list(map(sqla.Row.tuple, sqla_rows)) # noqa
408
+
409
+ def close(self):
410
+
411
+ if self.__result is not None:
412
+ self.__result.close()
413
+
414
+ plugins.PluginManager.register_plugin(ISqlDriver, SqlAlchemyDriver, ["alchemy"])
415
+
416
+ except ModuleNotFoundError:
417
+ pass
@@ -0,0 +1,117 @@
1
+ # Copyright 2024 Accenture Global Solutions Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import typing as tp
16
+
17
+ import pyarrow as pa
18
+
19
+ import tracdap.rt.exceptions as ex
20
+ import tracdap.rt.ext.plugins as plugins
21
+
22
+ from tracdap.rt._impl.ext.sql import * # noqa
23
+
24
+
25
+
26
+ class AnsiStandardDialect(ISqlDialect):
27
+
28
+ def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:
29
+
30
+ if pa.types.is_boolean(arrow_type):
31
+ return "boolean"
32
+
33
+ if pa.types.is_integer(arrow_type):
34
+ return "bigint"
35
+
36
+ if pa.types.is_floating(arrow_type):
37
+ return "double precision"
38
+
39
+ if pa.types.is_decimal(arrow_type):
40
+ return "decimal (31, 10)"
41
+
42
+ if pa.types.is_string(arrow_type):
43
+ return "varchar(4096)"
44
+
45
+ if pa.types.is_date(arrow_type):
46
+ return "date"
47
+
48
+ if pa.types.is_timestamp(arrow_type):
49
+ return "timestamp (6)"
50
+
51
+ raise ex.ETracInternal(f"Unsupported data type [{str(arrow_type)}] in SQL dialect [{self.__class__.__name__}]")
52
+
53
+
54
+ class MySqlDialect(AnsiStandardDialect):
55
+
56
+ def __init__(self, properties: tp.Dict[str, str]):
57
+ self._properties = properties
58
+
59
+ def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:
60
+
61
+ if pa.types.is_floating(arrow_type):
62
+ return "double"
63
+
64
+ if pa.types.is_string(arrow_type):
65
+ return "varchar(8192)"
66
+
67
+ return super().arrow_to_sql_type(arrow_type)
68
+
69
+
70
+ class MariaDbDialect(MySqlDialect):
71
+
72
+ def __init__(self, properties: tp.Dict[str, str]):
73
+ super().__init__(properties)
74
+
75
+ # Inherit MySQL implementation
76
+ pass
77
+
78
+
79
+ class PostgresqlDialect(AnsiStandardDialect):
80
+
81
+ def __init__(self, properties: tp.Dict[str, str]):
82
+ self._properties = properties
83
+
84
+ def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:
85
+
86
+ if pa.types.is_string(arrow_type):
87
+ return "varchar"
88
+
89
+ return super().arrow_to_sql_type(arrow_type)
90
+
91
+
92
+ class SqlServerDialect(AnsiStandardDialect):
93
+
94
+ def __init__(self, properties: tp.Dict[str, str]):
95
+ self._properties = properties
96
+
97
+ def arrow_to_sql_type(self, arrow_type: pa.DataType) -> str:
98
+
99
+ if pa.types.is_boolean(arrow_type):
100
+ return "bit"
101
+
102
+ if pa.types.is_floating(arrow_type):
103
+ return "float(53)"
104
+
105
+ if pa.types.is_string(arrow_type):
106
+ return "varchar(8000)"
107
+
108
+ if pa.types.is_timestamp(arrow_type):
109
+ return "datetime2"
110
+
111
+ return super().arrow_to_sql_type(arrow_type)
112
+
113
+
114
+ plugins.PluginManager.register_plugin(ISqlDialect, MySqlDialect, ["mysql"])
115
+ plugins.PluginManager.register_plugin(ISqlDialect, MariaDbDialect, ["mariadb"])
116
+ plugins.PluginManager.register_plugin(ISqlDialect, PostgresqlDialect, ["postgresql"])
117
+ plugins.PluginManager.register_plugin(ISqlDialect, SqlServerDialect, ["sqlserver"])
tracdap/rt/_version.py CHANGED
@@ -12,4 +12,4 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- __version__ = "0.6.5"
15
+ __version__ = "0.6.6"