dcs-sdk 1.6.5 (dcs_sdk-1.6.5-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. data_diff/__init__.py +221 -0
  2. data_diff/__main__.py +517 -0
  3. data_diff/abcs/__init__.py +13 -0
  4. data_diff/abcs/compiler.py +27 -0
  5. data_diff/abcs/database_types.py +402 -0
  6. data_diff/config.py +141 -0
  7. data_diff/databases/__init__.py +38 -0
  8. data_diff/databases/_connect.py +323 -0
  9. data_diff/databases/base.py +1417 -0
  10. data_diff/databases/bigquery.py +376 -0
  11. data_diff/databases/clickhouse.py +217 -0
  12. data_diff/databases/databricks.py +262 -0
  13. data_diff/databases/duckdb.py +207 -0
  14. data_diff/databases/mssql.py +343 -0
  15. data_diff/databases/mysql.py +189 -0
  16. data_diff/databases/oracle.py +238 -0
  17. data_diff/databases/postgresql.py +293 -0
  18. data_diff/databases/presto.py +222 -0
  19. data_diff/databases/redis.py +93 -0
  20. data_diff/databases/redshift.py +233 -0
  21. data_diff/databases/snowflake.py +222 -0
  22. data_diff/databases/sybase.py +720 -0
  23. data_diff/databases/trino.py +73 -0
  24. data_diff/databases/vertica.py +174 -0
  25. data_diff/diff_tables.py +489 -0
  26. data_diff/errors.py +17 -0
  27. data_diff/format.py +369 -0
  28. data_diff/hashdiff_tables.py +1026 -0
  29. data_diff/info_tree.py +76 -0
  30. data_diff/joindiff_tables.py +434 -0
  31. data_diff/lexicographic_space.py +253 -0
  32. data_diff/parse_time.py +88 -0
  33. data_diff/py.typed +0 -0
  34. data_diff/queries/__init__.py +13 -0
  35. data_diff/queries/api.py +213 -0
  36. data_diff/queries/ast_classes.py +811 -0
  37. data_diff/queries/base.py +38 -0
  38. data_diff/queries/extras.py +43 -0
  39. data_diff/query_utils.py +70 -0
  40. data_diff/schema.py +67 -0
  41. data_diff/table_segment.py +583 -0
  42. data_diff/thread_utils.py +112 -0
  43. data_diff/utils.py +1022 -0
  44. data_diff/version.py +15 -0
  45. dcs_core/__init__.py +13 -0
  46. dcs_core/__main__.py +17 -0
  47. dcs_core/__version__.py +15 -0
  48. dcs_core/cli/__init__.py +13 -0
  49. dcs_core/cli/cli.py +165 -0
  50. dcs_core/core/__init__.py +19 -0
  51. dcs_core/core/common/__init__.py +13 -0
  52. dcs_core/core/common/errors.py +50 -0
  53. dcs_core/core/common/models/__init__.py +13 -0
  54. dcs_core/core/common/models/configuration.py +284 -0
  55. dcs_core/core/common/models/dashboard.py +24 -0
  56. dcs_core/core/common/models/data_source_resource.py +75 -0
  57. dcs_core/core/common/models/metric.py +160 -0
  58. dcs_core/core/common/models/profile.py +75 -0
  59. dcs_core/core/common/models/validation.py +216 -0
  60. dcs_core/core/common/models/widget.py +44 -0
  61. dcs_core/core/configuration/__init__.py +13 -0
  62. dcs_core/core/configuration/config_loader.py +139 -0
  63. dcs_core/core/configuration/configuration_parser.py +262 -0
  64. dcs_core/core/configuration/configuration_parser_arc.py +328 -0
  65. dcs_core/core/datasource/__init__.py +13 -0
  66. dcs_core/core/datasource/base.py +62 -0
  67. dcs_core/core/datasource/manager.py +112 -0
  68. dcs_core/core/datasource/search_datasource.py +421 -0
  69. dcs_core/core/datasource/sql_datasource.py +1094 -0
  70. dcs_core/core/inspect.py +163 -0
  71. dcs_core/core/logger/__init__.py +13 -0
  72. dcs_core/core/logger/base.py +32 -0
  73. dcs_core/core/logger/default_logger.py +94 -0
  74. dcs_core/core/metric/__init__.py +13 -0
  75. dcs_core/core/metric/base.py +220 -0
  76. dcs_core/core/metric/combined_metric.py +98 -0
  77. dcs_core/core/metric/custom_metric.py +34 -0
  78. dcs_core/core/metric/manager.py +137 -0
  79. dcs_core/core/metric/numeric_metric.py +403 -0
  80. dcs_core/core/metric/reliability_metric.py +90 -0
  81. dcs_core/core/profiling/__init__.py +13 -0
  82. dcs_core/core/profiling/datasource_profiling.py +136 -0
  83. dcs_core/core/profiling/numeric_field_profiling.py +72 -0
  84. dcs_core/core/profiling/text_field_profiling.py +67 -0
  85. dcs_core/core/repository/__init__.py +13 -0
  86. dcs_core/core/repository/metric_repository.py +77 -0
  87. dcs_core/core/utils/__init__.py +13 -0
  88. dcs_core/core/utils/log.py +29 -0
  89. dcs_core/core/utils/tracking.py +105 -0
  90. dcs_core/core/utils/utils.py +44 -0
  91. dcs_core/core/validation/__init__.py +13 -0
  92. dcs_core/core/validation/base.py +230 -0
  93. dcs_core/core/validation/completeness_validation.py +153 -0
  94. dcs_core/core/validation/custom_query_validation.py +24 -0
  95. dcs_core/core/validation/manager.py +282 -0
  96. dcs_core/core/validation/numeric_validation.py +276 -0
  97. dcs_core/core/validation/reliability_validation.py +91 -0
  98. dcs_core/core/validation/uniqueness_validation.py +61 -0
  99. dcs_core/core/validation/validity_validation.py +738 -0
  100. dcs_core/integrations/__init__.py +13 -0
  101. dcs_core/integrations/databases/__init__.py +13 -0
  102. dcs_core/integrations/databases/bigquery.py +187 -0
  103. dcs_core/integrations/databases/databricks.py +51 -0
  104. dcs_core/integrations/databases/db2.py +652 -0
  105. dcs_core/integrations/databases/elasticsearch.py +61 -0
  106. dcs_core/integrations/databases/mssql.py +829 -0
  107. dcs_core/integrations/databases/mysql.py +409 -0
  108. dcs_core/integrations/databases/opensearch.py +64 -0
  109. dcs_core/integrations/databases/oracle.py +719 -0
  110. dcs_core/integrations/databases/postgres.py +482 -0
  111. dcs_core/integrations/databases/redshift.py +53 -0
  112. dcs_core/integrations/databases/snowflake.py +48 -0
  113. dcs_core/integrations/databases/spark_df.py +111 -0
  114. dcs_core/integrations/databases/sybase.py +1069 -0
  115. dcs_core/integrations/storage/__init__.py +13 -0
  116. dcs_core/integrations/storage/local_file.py +149 -0
  117. dcs_core/integrations/utils/__init__.py +13 -0
  118. dcs_core/integrations/utils/utils.py +36 -0
  119. dcs_core/report/__init__.py +13 -0
  120. dcs_core/report/dashboard.py +211 -0
  121. dcs_core/report/models.py +88 -0
  122. dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
  123. dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
  124. dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
  125. dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
  126. dcs_core/report/static/assets/images/docs.svg +6 -0
  127. dcs_core/report/static/assets/images/github.svg +4 -0
  128. dcs_core/report/static/assets/images/logo.svg +7 -0
  129. dcs_core/report/static/assets/images/slack.svg +13 -0
  130. dcs_core/report/static/index.js +2 -0
  131. dcs_core/report/static/index.js.LICENSE.txt +3971 -0
  132. dcs_sdk/__init__.py +13 -0
  133. dcs_sdk/__main__.py +18 -0
  134. dcs_sdk/__version__.py +15 -0
  135. dcs_sdk/cli/__init__.py +13 -0
  136. dcs_sdk/cli/cli.py +163 -0
  137. dcs_sdk/sdk/__init__.py +58 -0
  138. dcs_sdk/sdk/config/__init__.py +13 -0
  139. dcs_sdk/sdk/config/config_loader.py +491 -0
  140. dcs_sdk/sdk/data_diff/__init__.py +13 -0
  141. dcs_sdk/sdk/data_diff/data_differ.py +821 -0
  142. dcs_sdk/sdk/rules/__init__.py +15 -0
  143. dcs_sdk/sdk/rules/rules_mappping.py +31 -0
  144. dcs_sdk/sdk/rules/rules_repository.py +214 -0
  145. dcs_sdk/sdk/rules/schema_rules.py +65 -0
  146. dcs_sdk/sdk/utils/__init__.py +13 -0
  147. dcs_sdk/sdk/utils/serializer.py +25 -0
  148. dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
  149. dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
  150. dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
  151. dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
  152. dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
  153. dcs_sdk/sdk/utils/table.py +475 -0
  154. dcs_sdk/sdk/utils/themes.py +40 -0
  155. dcs_sdk/sdk/utils/utils.py +349 -0
  156. dcs_sdk-1.6.5.dist-info/METADATA +150 -0
  157. dcs_sdk-1.6.5.dist-info/RECORD +159 -0
  158. dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
  159. dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
data_diff/databases/presto.py
@@ -0,0 +1,222 @@
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from functools import partial
+ from typing import Any, ClassVar, Type
+
+ import attrs
+
+ from data_diff.abcs.database_types import (
+     Boolean,
+     ColType,
+     ColType_UUID,
+     DbPath,
+     DbTime,
+     Decimal,
+     Float,
+     FractionalType,
+     Integer,
+     Native_UUID,
+     TemporalType,
+     Text,
+     Timestamp,
+     TimestampTZ,
+ )
+ from data_diff.databases.base import (
+     CHECKSUM_HEXDIGITS,
+     CHECKSUM_OFFSET,
+     MD5_HEXDIGITS,
+     TIMESTAMP_PRECISION_POS,
+     BaseDialect,
+     Database,
+     QueryResult,
+     ThreadLocalInterpreter,
+     import_helper,
+ )
+ from data_diff.schema import RawColumnInfo
+ from data_diff.utils import match_regexps
+
+
+ def query_cursor(c, sql_code):
+     try:
+         c.execute(sql_code)
+         if sql_code.lower().startswith(("select", "explain", "show")):
+             columns = c.description and [col[0] for col in c.description]
+             return QueryResult(c.fetchall(), columns)
+         elif sql_code.lower().startswith(("create", "drop")):
+             try:
+                 c.connection.commit()
+             except AttributeError:
+                 ...
+     except Exception as _e:
+         try:
+             c.connection.rollback()
+         except Exception as rollback_error:
+             print("Rollback failed:", rollback_error)
+         raise
+
+
+ @import_helper("presto")
+ def import_presto():
+     import prestodb
+
+     return prestodb
+
+
+ class Dialect(BaseDialect):
+     name = "Presto"
+     ROUNDS_ON_PREC_LOSS = True
+     TYPE_CLASSES = {
+         # Timestamps
+         "timestamp with time zone": TimestampTZ,
+         "timestamp without time zone": Timestamp,
+         "timestamp": Timestamp,
+         # Numbers
+         "integer": Integer,
+         "bigint": Integer,
+         "real": Float,
+         "double": Float,
+         # Text
+         "varchar": Text,
+         # Boolean
+         "boolean": Boolean,
+         # UUID
+         "uuid": Native_UUID,
+     }
+
+     def explain_as_text(self, query: str) -> str:
+         return f"EXPLAIN (FORMAT TEXT) {query}"
+
+     def type_repr(self, t) -> str:
+         if isinstance(t, TimestampTZ):
+             return f"timestamp with time zone"
+
+         try:
+             return {float: "REAL"}[t]
+         except KeyError:
+             return super().type_repr(t)
+
+     def timestamp_value(self, t: DbTime) -> str:
+         return f"timestamp '{t.isoformat(' ')}'"
+
+     def quote(self, s: str, is_table: bool = False):
+         return f'"{s}"'
+
+     def to_string(self, s: str):
+         return f"cast({s} as varchar)"
+
+     def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
+         timestamp_regexps = {
+             r"timestamp\((\d)\)": Timestamp,
+             r"timestamp\((\d)\) with time zone": TimestampTZ,
+         }
+         for m, t_cls in match_regexps(timestamp_regexps, info.data_type):
+             precision = int(m.group(1))
+             return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)
+
+         number_regexps = {r"decimal\((\d+),(\d+)\)": Decimal}
+         for m, n_cls in match_regexps(number_regexps, info.data_type):
+             _prec, scale = map(int, m.groups())
+             return n_cls(scale)
+
+         string_regexps = {r"varchar\((\d+)\)": Text, r"char\((\d+)\)": Text}
+         for m, n_cls in match_regexps(string_regexps, info.data_type):
+             return n_cls()
+
+         return super().parse_type(table_path, info)
+
+     def set_timezone_to_utc(self) -> str:
+         raise NotImplementedError()
+
+     def current_timestamp(self) -> str:
+         return "current_timestamp"
+
+     def md5_as_int(self, s: str) -> str:
+         return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0)) - {CHECKSUM_OFFSET}"
+
+     def md5_as_hex(self, s: str) -> str:
+         return f"to_hex(md5(to_utf8({s})))"
+
+     def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
+         # Trim doesn't work on CHAR type
+         return f"TRIM(CAST({value} AS VARCHAR))"
+
+     def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
+         # TODO rounds
+         if coltype.rounds:
+             s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
+         else:
+             s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')"
+
+         return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')"
+
+     def normalize_number(self, value: str, coltype: FractionalType) -> str:
+         return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))")
+
+     def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
+         return self.to_string(f"cast ({value} as int)")
+
+
+ @attrs.define(frozen=False, init=False, kw_only=True)
+ class Presto(Database):
+     DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
+     CONNECT_URI_HELP = "presto://<user>@<host>/<catalog>/<schema>"
+     CONNECT_URI_PARAMS = ["catalog", "schema"]
+
+     _conn: Any
+
+     def __init__(self, **kw) -> None:
+         super().__init__()
+         self.default_schema = "public"
+         prestodb = import_presto()
+
+         if kw.get("schema"):
+             self.default_schema = kw.get("schema")
+
+         if kw.get("auth") == "basic":  # if auth=basic, add basic authenticator for Presto
+             kw["auth"] = prestodb.auth.BasicAuthentication(kw["user"], kw.pop("password"))
+
+         if "cert" in kw:  # if a certificate was specified in URI, verify session with cert
+             cert = kw.pop("cert")
+             self._conn = prestodb.dbapi.connect(**kw)
+             self._conn._http_session.verify = cert
+         else:
+             self._conn = prestodb.dbapi.connect(**kw)
+
+     def _query(self, sql_code: str) -> list:
+         "Uses the standard SQL cursor interface"
+         c = self._conn.cursor()
+
+         if isinstance(sql_code, ThreadLocalInterpreter):
+             return sql_code.apply_queries(partial(query_cursor, c))
+
+         return query_cursor(c, sql_code)
+
+     def close(self):
+         super().close()
+         self._conn.close()
+
+     def select_table_schema(self, path: DbPath) -> str:
+         schema, table = self._normalize_table_path(path)
+
+         return (
+             "SELECT column_name, data_type, 3 as datetime_precision, 3 as numeric_precision, NULL as numeric_scale,"
+             "NULL as collation_name, NULL as character_maximum_length "
+             "FROM INFORMATION_SCHEMA.COLUMNS "
+             f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
+         )
+
+     @property
+     def is_autocommit(self) -> bool:
+         return False
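
For orientation, the Presto adapter above declares CONNECT_URI_HELP = "presto://<user>@<host>/<catalog>/<schema>" and forwards its connection keywords to prestodb.dbapi.connect(). The following is a minimal sketch, assuming the vendored data_diff package keeps upstream data-diff's connect_to_table()/diff_tables() helpers; the host, catalog, schema, and table names are placeholders:

    from data_diff import connect_to_table, diff_tables

    # URI shape taken from Presto.CONNECT_URI_HELP above; all names are hypothetical.
    uri = "presto://analyst@presto.example.com/hive/analytics"

    orders_a = connect_to_table(uri, "orders", key_columns=("id",))
    orders_b = connect_to_table(uri, "orders_snapshot", key_columns=("id",))

    # Hash-based comparison: yields ("-", row) for rows only in orders_a and
    # ("+", row) for rows only in orders_b.
    for sign, row in diff_tables(orders_a, orders_b):
        print(sign, row)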
data_diff/databases/redis.py
@@ -0,0 +1,93 @@
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import threading
+ import time
+
+ from loguru import logger
+ from redis import ConnectionPool, Redis, RedisError
+
+
+ class RedisBackend:
+     _INSTANCE = None
+     _lock = threading.Lock()
+
+     @classmethod
+     def get_instance(cls, force_recreate: bool = False):
+         with cls._lock:
+             if cls._INSTANCE is None or force_recreate:
+                 logger.info("Creating a new RedisBackend instance (force_recreate=%s)", force_recreate)
+                 cls._INSTANCE = cls._create_instance()
+             return cls._INSTANCE
+
+     @classmethod
+     def _create_instance(cls):
+         redis_url = (
+             os.getenv("DCS_REDIS_URL", None)
+             or os.getenv("DCS_RABBIT_URL", None)
+             or os.getenv("REDIS_URL", None)
+             or os.getenv("RABBIT_URL", None)
+         )
+         if not redis_url:
+             logger.warning("environment variable is not configured for redis")
+             return None
+         try:
+             pool = ConnectionPool.from_url(
+                 redis_url,
+                 max_connections=int(os.getenv("DCS_REDIS_MAX_CONNECTIONS", "100")),
+                 health_check_interval=int(os.getenv("DCS_REDIS_HEALTH_CHECK_INTERVAL", "30")),
+                 socket_connect_timeout=float(os.getenv("DCS_REDIS_SOCKET_CONNECT_TIMEOUT", "5")),
+                 socket_timeout=float(os.getenv("DCS_REDIS_SOCKET_TIMEOUT", "5")),
+                 socket_keepalive=True,
+             )
+             client = Redis(connection_pool=pool, decode_responses=False, retry_on_timeout=True)
+             # ping to ensure connection and raise early if config wrong
+             client.ping()
+             return cls(client)
+         except RedisError as exc:
+             logger.exception("Failed to create Redis client: %s", exc)
+             raise
+
+     def __init__(self, client: Redis):
+         self._client = client
+
+     @property
+     def client(self) -> Redis:
+         return self._client
+
+     def ensure_connected(self) -> bool:
+         try:
+             self._client.ping()
+             return True
+         except RedisError:
+             logger.warning("Redis ping failed; trying to recreate connection/pool")
+             try:
+                 time.sleep(0.2)
+                 new_instance = self.__class__._create_instance()
+                 self.__class__._INSTANCE = new_instance
+                 self._client = new_instance.client
+                 self._client.ping()
+                 logger.info("Recreated Redis client/pool successfully")
+                 return True
+             except Exception as exc:
+                 logger.exception("Failed to recreate Redis client: %s", exc)
+                 return False
+
+     def close(self):
+         try:
+             self._client.connection_pool.disconnect()
+             logger.info("Closed Redis connection pool")
+         except Exception as e:
+             logger.exception("Error closing Redis connection pool: %s", e)
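
RedisBackend above is a lock-guarded singleton: get_instance() builds a ConnectionPool from the DCS_REDIS_URL/DCS_RABBIT_URL/REDIS_URL/RABBIT_URL environment variables (returning None when none is set), and ensure_connected() pings the server and rebuilds the pool on failure. A minimal usage sketch, assuming the module path shown in the file list and using a made-up URL and key name:

    import os

    # Hypothetical local Redis; the backend reads this variable itself.
    os.environ["DCS_REDIS_URL"] = "redis://localhost:6379/0"

    from data_diff.databases.redis import RedisBackend

    backend = RedisBackend.get_instance()
    if backend is not None and backend.ensure_connected():
        backend.client.set("dcs:healthcheck", "ok", ex=60)  # made-up key, 60 s TTL
        print(backend.client.get("dcs:healthcheck"))        # bytes, since decode_responses=False
        backend.close()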
data_diff/databases/redshift.py
@@ -0,0 +1,233 @@
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, ClassVar, Dict, Iterable, List, Tuple, Type
+
+ import attrs
+
+ from data_diff.abcs.database_types import (
+     JSON,
+     ColType,
+     DbPath,
+     Float,
+     FractionalType,
+     Integer,
+     TemporalType,
+     TimestampTZ,
+ )
+ from data_diff.databases.postgresql import (
+     CHECKSUM_HEXDIGITS,
+     CHECKSUM_OFFSET,
+     MD5_HEXDIGITS,
+     TIMESTAMP_PRECISION_POS,
+     BaseDialect,
+     PostgreSQL,
+     PostgresqlDialect,
+ )
+ from data_diff.schema import RawColumnInfo
+
+
+ @attrs.define(frozen=False)
+ class Dialect(PostgresqlDialect):
+     name = "Redshift"
+     TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
+         **PostgresqlDialect.TYPE_CLASSES,
+         "double": Float,
+         "real": Float,
+         "super": JSON,
+         "int": Integer,  # Redshift Spectrum
+         "float": Float,  # Redshift Spectrum
+     }
+     SUPPORTS_INDEXES = False
+
+     def concat(self, items: List[str]) -> str:
+         joined_exprs = " || ".join(items)
+         return f"({joined_exprs})"
+
+     def is_distinct_from(self, a: str, b: str) -> str:
+         return f"({a} IS NULL != {b} IS NULL) OR ({a}!={b})"
+
+     def type_repr(self, t) -> str:
+         if isinstance(t, TimestampTZ):
+             return f"timestamptz"
+         return super().type_repr(t)
+
+     def md5_as_int(self, s: str) -> str:
+         return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38) - {CHECKSUM_OFFSET}"
+
+     def md5_as_hex(self, s: str) -> str:
+         return f"md5({s})"
+
+     def normalize_number(self, value: str, coltype: FractionalType) -> str:
+         return self.to_string(f"{value}::decimal(38,{coltype.precision})")
+
+     def normalize_json(self, value: str, _coltype: JSON) -> str:
+         return f"nvl2({value}, json_serialize({value}), NULL)"
+
+
+ @attrs.define(frozen=False, init=False, kw_only=True)
+ class Redshift(PostgreSQL):
+     DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
+     CONNECT_URI_HELP = "redshift://<user>:<password>@<host>/<database>"
+     CONNECT_URI_PARAMS = ["database?"]
+
+     def select_table_schema(self, path: DbPath) -> str:
+         database, schema, table = self._normalize_table_path(path)
+
+         info_schema_path = ["information_schema", "columns"]
+         if database:
+             info_schema_path.insert(0, database)
+
+         return (
+             f"SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM {'.'.join(info_schema_path)} "
+             f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'"
+         )
+
+     def select_external_table_schema(self, path: DbPath) -> str:
+         database, schema, table = self._normalize_table_path(path)
+
+         db_clause = ""
+         if database:
+             db_clause = f" AND redshift_database_name = '{database.lower()}'"
+
+         return (
+             f"""SELECT
+                 columnname AS column_name
+                 , CASE WHEN external_type = 'string' THEN 'varchar' ELSE external_type END AS data_type
+                 , NULL AS datetime_precision
+                 , NULL AS numeric_precision
+                 , NULL AS numeric_scale
+             FROM svv_external_columns
+             WHERE tablename = '{table.lower()}' AND schemaname = '{schema.lower()}'
+             """
+             + db_clause
+         )
+
+     def query_external_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
+         rows = self.query(self.select_external_table_schema(path), list)
+         if not rows:
+             raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
+
+         schema_dict = self._normalize_schema_info(rows)
+         return schema_dict
+
+     def select_view_columns(self, path: DbPath) -> str:
+         _, schema, table = self._normalize_table_path(path)
+
+         return """select * from pg_get_cols('{}.{}')
+             cols(col_name name, col_type varchar)
+             """.format(
+             schema, table
+         )
+
+     def query_pg_get_cols(self, path: DbPath) -> Dict[str, RawColumnInfo]:
+         rows = self.query(self.select_view_columns(path), list)
+         if not rows:
+             raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns")
+
+         schema_dict = self._normalize_schema_info(rows)
+         return schema_dict
+
+     def select_svv_columns_schema(self, path: DbPath) -> Dict[str, tuple]:
+         database, schema, table = self._normalize_table_path(path)
+
+         db_clause = ""
+         if database:
+             db_clause = f" AND table_catalog = '{database.lower()}'"
+
+         return (
+             f"""
+             select
+                 distinct
+                 column_name,
+                 data_type,
+                 datetime_precision,
+                 numeric_precision,
+                 numeric_scale
+             from
+                 svv_columns
+             where table_name = '{table.lower()}' and table_schema = '{schema.lower()}'
+             """
+             + db_clause
+         )
+
+     def query_svv_columns(self, path: DbPath) -> Dict[str, RawColumnInfo]:
+         rows = self.query(self.select_svv_columns_schema(path), list)
+         if not rows:
+             raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
+
+         d = {
+             r[0]: RawColumnInfo(
+                 column_name=r[0],
+                 data_type=r[1],
+                 datetime_precision=r[2],
+                 numeric_precision=r[3],
+                 numeric_scale=r[4],
+                 collation_name=r[5] if len(r) > 5 else None,
+             )
+             for r in rows
+         }
+         assert len(d) == len(rows)
+         return d
+
+     # when using a non-information_schema source, strip (N) from type(N) etc. to match
+     # typical information_schema output
+     def _normalize_schema_info(self, rows: Iterable[Tuple[Any]]) -> Dict[str, RawColumnInfo]:
+         schema_dict: Dict[str, RawColumnInfo] = {}
+         for r in rows:
+             col_name = r[0]
+             type_info = r[1].split("(")
+             base_type = type_info[0]
+             precision = None
+             scale = None
+
+             if len(type_info) > 1:
+                 if base_type == "numeric":
+                     precision, scale = type_info[1][:-1].split(",")
+                     precision = int(precision)
+                     scale = int(scale)
+
+             schema_dict[col_name] = RawColumnInfo(
+                 column_name=col_name,
+                 data_type=base_type,
+                 datetime_precision=None,
+                 numeric_precision=precision,
+                 numeric_scale=scale,
+                 collation_name=None,
+             )
+         return schema_dict
+
+     def query_table_schema(self, path: DbPath) -> Dict[str, RawColumnInfo]:
+         try:
+             return super().query_table_schema(path)
+         except RuntimeError:
+             try:
+                 return self.query_external_table_schema(path)
+             except RuntimeError:
+                 try:
+                     return self.query_pg_get_cols(path)
+                 except Exception:
+                     return self.query_svv_columns(path)
+
+     def _normalize_table_path(self, path: DbPath) -> DbPath:
+         if len(path) == 1:
+             return None, self.default_schema, path[0]
+         elif len(path) == 2:
+             return None, path[0], path[1]
+         elif len(path) == 3:
+             return path
+
+         raise ValueError(
+             f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
+         )
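
The Redshift adapter's query_table_schema() above tries information_schema first, then falls back to svv_external_columns (Spectrum tables), pg_get_cols() (late-binding views), and finally svv_columns, while _normalize_table_path() pads one- and two-part paths with the default schema and a NULL database. A minimal sketch of exercising that chain, assuming the vendored data_diff keeps upstream's connect() helper in data_diff.databases; the connection details and table names are placeholders:

    from data_diff.databases import connect

    # URI shape taken from Redshift.CONNECT_URI_HELP above; credentials are hypothetical.
    db = connect("redshift://analyst:secret@redshift.example.com/dev")

    # Three-part path -> (database, schema, table); shorter paths are padded by
    # _normalize_table_path() as shown above.
    schema = db.query_table_schema(("dev", "sales", "orders"))
    for name, info in schema.items():
        print(name, info.data_type, info.numeric_precision, info.numeric_scale)

    db.close()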