dcs-sdk 1.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. data_diff/__init__.py +221 -0
  2. data_diff/__main__.py +517 -0
  3. data_diff/abcs/__init__.py +13 -0
  4. data_diff/abcs/compiler.py +27 -0
  5. data_diff/abcs/database_types.py +402 -0
  6. data_diff/config.py +141 -0
  7. data_diff/databases/__init__.py +38 -0
  8. data_diff/databases/_connect.py +323 -0
  9. data_diff/databases/base.py +1417 -0
  10. data_diff/databases/bigquery.py +376 -0
  11. data_diff/databases/clickhouse.py +217 -0
  12. data_diff/databases/databricks.py +262 -0
  13. data_diff/databases/duckdb.py +207 -0
  14. data_diff/databases/mssql.py +343 -0
  15. data_diff/databases/mysql.py +189 -0
  16. data_diff/databases/oracle.py +238 -0
  17. data_diff/databases/postgresql.py +293 -0
  18. data_diff/databases/presto.py +222 -0
  19. data_diff/databases/redis.py +93 -0
  20. data_diff/databases/redshift.py +233 -0
  21. data_diff/databases/snowflake.py +222 -0
  22. data_diff/databases/sybase.py +720 -0
  23. data_diff/databases/trino.py +73 -0
  24. data_diff/databases/vertica.py +174 -0
  25. data_diff/diff_tables.py +489 -0
  26. data_diff/errors.py +17 -0
  27. data_diff/format.py +369 -0
  28. data_diff/hashdiff_tables.py +1026 -0
  29. data_diff/info_tree.py +76 -0
  30. data_diff/joindiff_tables.py +434 -0
  31. data_diff/lexicographic_space.py +253 -0
  32. data_diff/parse_time.py +88 -0
  33. data_diff/py.typed +0 -0
  34. data_diff/queries/__init__.py +13 -0
  35. data_diff/queries/api.py +213 -0
  36. data_diff/queries/ast_classes.py +811 -0
  37. data_diff/queries/base.py +38 -0
  38. data_diff/queries/extras.py +43 -0
  39. data_diff/query_utils.py +70 -0
  40. data_diff/schema.py +67 -0
  41. data_diff/table_segment.py +583 -0
  42. data_diff/thread_utils.py +112 -0
  43. data_diff/utils.py +1022 -0
  44. data_diff/version.py +15 -0
  45. dcs_core/__init__.py +13 -0
  46. dcs_core/__main__.py +17 -0
  47. dcs_core/__version__.py +15 -0
  48. dcs_core/cli/__init__.py +13 -0
  49. dcs_core/cli/cli.py +165 -0
  50. dcs_core/core/__init__.py +19 -0
  51. dcs_core/core/common/__init__.py +13 -0
  52. dcs_core/core/common/errors.py +50 -0
  53. dcs_core/core/common/models/__init__.py +13 -0
  54. dcs_core/core/common/models/configuration.py +284 -0
  55. dcs_core/core/common/models/dashboard.py +24 -0
  56. dcs_core/core/common/models/data_source_resource.py +75 -0
  57. dcs_core/core/common/models/metric.py +160 -0
  58. dcs_core/core/common/models/profile.py +75 -0
  59. dcs_core/core/common/models/validation.py +216 -0
  60. dcs_core/core/common/models/widget.py +44 -0
  61. dcs_core/core/configuration/__init__.py +13 -0
  62. dcs_core/core/configuration/config_loader.py +139 -0
  63. dcs_core/core/configuration/configuration_parser.py +262 -0
  64. dcs_core/core/configuration/configuration_parser_arc.py +328 -0
  65. dcs_core/core/datasource/__init__.py +13 -0
  66. dcs_core/core/datasource/base.py +62 -0
  67. dcs_core/core/datasource/manager.py +112 -0
  68. dcs_core/core/datasource/search_datasource.py +421 -0
  69. dcs_core/core/datasource/sql_datasource.py +1094 -0
  70. dcs_core/core/inspect.py +163 -0
  71. dcs_core/core/logger/__init__.py +13 -0
  72. dcs_core/core/logger/base.py +32 -0
  73. dcs_core/core/logger/default_logger.py +94 -0
  74. dcs_core/core/metric/__init__.py +13 -0
  75. dcs_core/core/metric/base.py +220 -0
  76. dcs_core/core/metric/combined_metric.py +98 -0
  77. dcs_core/core/metric/custom_metric.py +34 -0
  78. dcs_core/core/metric/manager.py +137 -0
  79. dcs_core/core/metric/numeric_metric.py +403 -0
  80. dcs_core/core/metric/reliability_metric.py +90 -0
  81. dcs_core/core/profiling/__init__.py +13 -0
  82. dcs_core/core/profiling/datasource_profiling.py +136 -0
  83. dcs_core/core/profiling/numeric_field_profiling.py +72 -0
  84. dcs_core/core/profiling/text_field_profiling.py +67 -0
  85. dcs_core/core/repository/__init__.py +13 -0
  86. dcs_core/core/repository/metric_repository.py +77 -0
  87. dcs_core/core/utils/__init__.py +13 -0
  88. dcs_core/core/utils/log.py +29 -0
  89. dcs_core/core/utils/tracking.py +105 -0
  90. dcs_core/core/utils/utils.py +44 -0
  91. dcs_core/core/validation/__init__.py +13 -0
  92. dcs_core/core/validation/base.py +230 -0
  93. dcs_core/core/validation/completeness_validation.py +153 -0
  94. dcs_core/core/validation/custom_query_validation.py +24 -0
  95. dcs_core/core/validation/manager.py +282 -0
  96. dcs_core/core/validation/numeric_validation.py +276 -0
  97. dcs_core/core/validation/reliability_validation.py +91 -0
  98. dcs_core/core/validation/uniqueness_validation.py +61 -0
  99. dcs_core/core/validation/validity_validation.py +738 -0
  100. dcs_core/integrations/__init__.py +13 -0
  101. dcs_core/integrations/databases/__init__.py +13 -0
  102. dcs_core/integrations/databases/bigquery.py +187 -0
  103. dcs_core/integrations/databases/databricks.py +51 -0
  104. dcs_core/integrations/databases/db2.py +652 -0
  105. dcs_core/integrations/databases/elasticsearch.py +61 -0
  106. dcs_core/integrations/databases/mssql.py +829 -0
  107. dcs_core/integrations/databases/mysql.py +409 -0
  108. dcs_core/integrations/databases/opensearch.py +64 -0
  109. dcs_core/integrations/databases/oracle.py +719 -0
  110. dcs_core/integrations/databases/postgres.py +482 -0
  111. dcs_core/integrations/databases/redshift.py +53 -0
  112. dcs_core/integrations/databases/snowflake.py +48 -0
  113. dcs_core/integrations/databases/spark_df.py +111 -0
  114. dcs_core/integrations/databases/sybase.py +1069 -0
  115. dcs_core/integrations/storage/__init__.py +13 -0
  116. dcs_core/integrations/storage/local_file.py +149 -0
  117. dcs_core/integrations/utils/__init__.py +13 -0
  118. dcs_core/integrations/utils/utils.py +36 -0
  119. dcs_core/report/__init__.py +13 -0
  120. dcs_core/report/dashboard.py +211 -0
  121. dcs_core/report/models.py +88 -0
  122. dcs_core/report/static/assets/fonts/DMSans-Bold.ttf +0 -0
  123. dcs_core/report/static/assets/fonts/DMSans-Medium.ttf +0 -0
  124. dcs_core/report/static/assets/fonts/DMSans-Regular.ttf +0 -0
  125. dcs_core/report/static/assets/fonts/DMSans-SemiBold.ttf +0 -0
  126. dcs_core/report/static/assets/images/docs.svg +6 -0
  127. dcs_core/report/static/assets/images/github.svg +4 -0
  128. dcs_core/report/static/assets/images/logo.svg +7 -0
  129. dcs_core/report/static/assets/images/slack.svg +13 -0
  130. dcs_core/report/static/index.js +2 -0
  131. dcs_core/report/static/index.js.LICENSE.txt +3971 -0
  132. dcs_sdk/__init__.py +13 -0
  133. dcs_sdk/__main__.py +18 -0
  134. dcs_sdk/__version__.py +15 -0
  135. dcs_sdk/cli/__init__.py +13 -0
  136. dcs_sdk/cli/cli.py +163 -0
  137. dcs_sdk/sdk/__init__.py +58 -0
  138. dcs_sdk/sdk/config/__init__.py +13 -0
  139. dcs_sdk/sdk/config/config_loader.py +491 -0
  140. dcs_sdk/sdk/data_diff/__init__.py +13 -0
  141. dcs_sdk/sdk/data_diff/data_differ.py +821 -0
  142. dcs_sdk/sdk/rules/__init__.py +15 -0
  143. dcs_sdk/sdk/rules/rules_mappping.py +31 -0
  144. dcs_sdk/sdk/rules/rules_repository.py +214 -0
  145. dcs_sdk/sdk/rules/schema_rules.py +65 -0
  146. dcs_sdk/sdk/utils/__init__.py +13 -0
  147. dcs_sdk/sdk/utils/serializer.py +25 -0
  148. dcs_sdk/sdk/utils/similarity_score/__init__.py +13 -0
  149. dcs_sdk/sdk/utils/similarity_score/base_provider.py +153 -0
  150. dcs_sdk/sdk/utils/similarity_score/cosine_similarity_provider.py +39 -0
  151. dcs_sdk/sdk/utils/similarity_score/jaccard_provider.py +24 -0
  152. dcs_sdk/sdk/utils/similarity_score/levenshtein_distance_provider.py +31 -0
  153. dcs_sdk/sdk/utils/table.py +475 -0
  154. dcs_sdk/sdk/utils/themes.py +40 -0
  155. dcs_sdk/sdk/utils/utils.py +349 -0
  156. dcs_sdk-1.6.5.dist-info/METADATA +150 -0
  157. dcs_sdk-1.6.5.dist-info/RECORD +159 -0
  158. dcs_sdk-1.6.5.dist-info/WHEEL +4 -0
  159. dcs_sdk-1.6.5.dist-info/entry_points.txt +4 -0
@@ -0,0 +1,238 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import secrets
16
+ import string
17
+ import time
18
+ from typing import Any, ClassVar, Dict, List, Optional, Type
19
+
20
+ import attrs
21
+
22
+ from data_diff.abcs.database_types import (
23
+ ColType,
24
+ ColType_UUID,
25
+ DbPath,
26
+ DbTime,
27
+ Decimal,
28
+ Float,
29
+ FractionalType,
30
+ TemporalType,
31
+ Text,
32
+ Timestamp,
33
+ TimestampTZ,
34
+ )
35
+ from data_diff.databases.base import (
36
+ CHECKSUM_HEXDIGITS,
37
+ CHECKSUM_OFFSET,
38
+ MD5_HEXDIGITS,
39
+ TIMESTAMP_PRECISION_POS,
40
+ BaseDialect,
41
+ ConnectError,
42
+ QueryError,
43
+ ThreadedDatabase,
44
+ import_helper,
45
+ )
46
+ from data_diff.schema import RawColumnInfo
47
+ from data_diff.utils import match_regexps
48
+
49
# Optional session time zone applied to each new connection when truthy. Changed by the tests.
SESSION_TIME_ZONE = None


@import_helper("oracle")
def import_oracle():
    """Import the ``oracledb`` driver when called and hand back the module object.

    The import happens inside the function, so the driver is only required once a
    connection is actually created. ``import_helper("oracle")`` wraps this function;
    presumably it reports a missing driver in a friendlier way (TODO confirm in
    ``data_diff.databases.base``).
    """
    import oracledb as driver_module

    return driver_module
57
+
58
+
59
@attrs.define(frozen=False)
class Dialect(
    BaseDialect,
):
    """SQL dialect for Oracle: identifier quoting, type mapping and value normalization."""

    name = "Oracle"
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
    SUPPORTS_INDEXES = True
    TYPE_CLASSES: Dict[str, type] = {
        "NUMBER": Decimal,
        "FLOAT": Float,
        # Text
        "CHAR": Text,
        "NCHAR": Text,
        "NVARCHAR2": Text,
        "VARCHAR2": Text,
        # DATE columns are compared as timestamps.
        "DATE": Timestamp,
    }
    ROUNDS_ON_PREC_LOSS = True
    PLACEHOLDER_TABLE = "DUAL"

    def quote(self, s: str, is_table: bool = False) -> str:
        """Double-quote the identifier ``s``.

        When ``is_table`` is set, ``s`` was registered through ``parse_table_name``,
        and a default schema is configured, the name is qualified with that schema.
        """
        if s in self.TABLE_NAMES and self.default_schema and is_table:
            return f'"{self.default_schema}"."{s}"'
        return f'"{s}"'

    def to_string(self, s: str) -> str:
        """Return SQL casting expression ``s`` to ``varchar(1024)``."""
        return f"cast({s} as varchar(1024))"

    def limit_select(
        self,
        select_query: str,
        offset: Optional[int] = None,
        limit: Optional[int] = None,
        has_order_by: Optional[bool] = None,
    ) -> str:
        """Wrap ``select_query`` so that it returns at most ``limit`` rows.

        Args:
            select_query: the SELECT statement to limit.
            offset: unsupported; any truthy value raises.
            limit: maximum number of rows. ``None`` means "no limit" and returns
                ``select_query`` unchanged.
            has_order_by: accepted for interface compatibility; unused here.

        Raises:
            NotImplementedError: if a non-zero ``offset`` is requested.
        """
        if offset:
            raise NotImplementedError("No support for OFFSET in query")

        # Without a limit there is nothing to add. Interpolating None would
        # produce the invalid clause "FETCH NEXT None ROWS ONLY".
        if limit is None:
            return select_query

        return f"SELECT * FROM ({select_query}) FETCH NEXT {limit} ROWS ONLY"

    def concat(self, items: List[str]) -> str:
        """Return SQL concatenating ``items`` with the ``||`` operator."""
        joined_exprs = " || ".join(items)
        return f"({joined_exprs})"

    def timestamp_value(self, t: DbTime) -> str:
        """Return a ``timestamp '...'`` literal for ``t``, using a space as the date/time separator."""
        return "timestamp '%s'" % t.isoformat(" ")

    def random(self) -> str:
        """Return the SQL expression for a random value."""
        return "dbms_random.value"

    def is_distinct_from(self, a: str, b: str) -> str:
        """Return SQL that is true when ``a`` and ``b`` differ.

        DECODE is used because it treats two NULLs as equal.
        """
        return f"DECODE({a}, {b}, 1, 0) = 0"

    def type_repr(self, t) -> str:
        """Return the SQL type name for ``t``; Python ``str`` maps to VARCHAR(1024)."""
        try:
            return {
                str: "VARCHAR(1024)",
            }[t]
        except KeyError:
            return super().type_repr(t)

    def constant_values(self, rows) -> str:
        """Return SQL yielding ``rows`` as a UNION ALL of single-row SELECTs from DUAL."""
        return " UNION ALL ".join(
            "SELECT %s FROM DUAL" % ", ".join(self._constant_value(v) for v in row) for row in rows
        )

    def explain_as_text(self, query: str) -> str:
        """Not supported for Oracle.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Explain not yet implemented in Oracle")

    def parse_type(self, table_path: DbPath, info: RawColumnInfo) -> ColType:
        """Map Oracle TIMESTAMP(n) variants to temporal column types.

        Any other data type is delegated to the base implementation.
        """
        regexps = {
            r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp,
            r"TIMESTAMP\((\d)\) WITH TIME ZONE": TimestampTZ,
            r"TIMESTAMP\((\d)\)": Timestamp,
        }

        for m, t_cls in match_regexps(regexps, info.data_type):
            precision = int(m.group(1))
            return t_cls(precision=precision, rounds=self.ROUNDS_ON_PREC_LOSS)

        return super().parse_type(table_path, info)

    def set_timezone_to_utc(self) -> str:
        """Return the statement that switches the session time zone to UTC."""
        return "ALTER SESSION SET TIME_ZONE = 'UTC'"

    def current_timestamp(self) -> str:
        """Return the SQL expression for the current timestamp."""
        return "LOCALTIMESTAMP"

    def md5_as_int(self, s: str) -> str:
        """Return SQL turning the MD5 of ``s`` into an integer checksum.

        The trailing CHECKSUM_HEXDIGITS hex characters of the hash are parsed as a
        number, then CHECKSUM_OFFSET is subtracted.
        """
        # standard_hash is faster than DBMS_CRYPTO.Hash
        # TODO: Find a way to use UTL_RAW.CAST_TO_BINARY_INTEGER ?
        # The hex format mask must be as long as the substring it parses, so it is
        # derived from the same constant rather than hard-coded.
        hex_mask = "x" * CHECKSUM_HEXDIGITS
        return (
            f"to_number(substr(standard_hash({s}, 'MD5'), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS}), "
            f"'{hex_mask}') - {CHECKSUM_OFFSET}"
        )

    def md5_as_hex(self, s: str) -> str:
        """Return SQL computing the MD5 hash of ``s`` with ``standard_hash``."""
        return f"standard_hash({s}, 'MD5')"

    def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str:
        """Return SQL normalizing a UUID column to a trimmed VARCHAR(36)."""
        # Cast is necessary for correct MD5 (trimming not enough)
        return f"CAST(TRIM({value}) AS VARCHAR(36))"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL rendering a timestamp as 'YYYY-MM-DD HH24:MI:SS.ffffff' text.

        Rounding types cast to ``timestamp(precision)`` first. Otherwise, the value
        is formatted to ``precision`` fractional digits and right-padded with '0'
        to TIMESTAMP_PRECISION_POS + 6 characters.
        """
        if coltype.rounds:
            return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')"

        if coltype.precision > 0:
            truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.FF{coltype.precision}')"
        else:
            truncated = f"to_char({value}, 'YYYY-MM-DD HH24:MI:SS.')"
        return f"RPAD({truncated}, {TIMESTAMP_PRECISION_POS+6}, '0')"

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return SQL formatting a number with exactly ``coltype.precision`` decimals (FM mask)."""
        # FM999.9990
        format_str = "FM" + "9" * (38 - coltype.precision)
        if coltype.precision:
            format_str += "0." + "9" * (coltype.precision - 1) + "0"
        return f"to_char({value}, '{format_str}')"

    def generate_view_name(self, view_name: Optional[str] = None) -> str:
        """Return an upper-cased view name.

        A given ``view_name`` is simply upper-cased. Otherwise a name of the form
        ``VIEW_<unix seconds>_<8 random alphanumerics>`` is generated.
        """
        if view_name is not None:
            return view_name.upper()
        random_string = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))
        timestamp = int(time.time())
        # The whole name is upper-cased at the end, so no separate lower() is needed.
        return f"view_{timestamp}_{random_string}".upper()

    def parse_table_name(self, name: str) -> DbPath:
        "Parse the given table name into a DbPath"
        # Remember the bare table name so quote(..., is_table=True) can qualify it.
        self.TABLE_NAMES.append(name.split(".")[-1])
        return tuple(name.split("."))
187
+
188
+
189
@attrs.define(frozen=False, init=False, kw_only=True)
class Oracle(ThreadedDatabase):
    """Oracle database connection built on the ``oracledb`` driver."""

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = Dialect
    CONNECT_URI_HELP = "oracle://<user>:<password>@<host>/<database>"
    CONNECT_URI_PARAMS = ["database?"]

    kwargs: Dict[str, Any]
    _oracle: Any
    _conn: Any

    def __init__(self, *, host, database, thread_count, **kw) -> None:
        """Store connection arguments, resolve the default schema and connect.

        Args:
            host: host part of the DSN.
            database: optional service/database appended to the DSN as ``host/database``.
            thread_count: forwarded to ``ThreadedDatabase``.
            **kw: extra ``oracledb.connect`` arguments. ``schema`` is consumed here
                and not forwarded. Falsy values are dropped.

        Raises:
            ConnectError: if no schema can be determined (neither ``schema`` nor
                ``user`` given), or if connecting fails.
        """
        super().__init__(thread_count=thread_count)
        self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw)

        # The default schema is the explicit schema, else the upper-cased user name.
        schema = kw.get("schema", None)
        if not schema:
            user = kw.get("user")
            if user is None:
                raise ConnectError(
                    "Oracle: either 'schema' or 'user' must be provided to determine the default schema"
                )
            schema = user.upper()
        self.default_schema = schema
        self.dialect.default_schema = self.default_schema

        # Drop empty arguments and "schema", which is not passed on to oracledb.connect().
        self.kwargs = {k: v for k, v in self.kwargs.items() if v and k != "schema"}
        self._oracle = None
        self._conn = self.create_connection()

    def create_connection(self):
        """Open a new ``oracledb`` connection and return it.

        When SESSION_TIME_ZONE is set, it is applied to the new session.

        Raises:
            ConnectError: wrapping any exception raised while connecting.
        """
        self._oracle = import_oracle()
        try:
            self._conn = self._oracle.connect(**self.kwargs)
            if SESSION_TIME_ZONE:
                # Use a short-lived cursor that is closed right away instead of leaked.
                with self._conn.cursor() as cursor:
                    cursor.execute(f"ALTER SESSION SET TIME_ZONE = '{SESSION_TIME_ZONE}'")
            return self._conn
        except Exception as e:
            raise ConnectError(*e.args) from e

    def _query_cursor(self, c, sql_code: str):
        """Run ``sql_code`` on cursor ``c``, translating driver errors into QueryError.

        Raises:
            QueryError: if the driver raises ``DatabaseError``.
        """
        try:
            return super()._query_cursor(c, sql_code)
        except self._oracle.DatabaseError as e:
            # Chain the driver exception so its traceback stays attached.
            raise QueryError(e) from e

    def select_table_schema(self, path: DbPath) -> str:
        """Return SQL reading column metadata for ``path`` from ALL_TAB_COLUMNS.

        ``datetime_precision`` is reported as the constant 6 and ``collation_name``
        as NULL.
        """
        schema, name = self._normalize_table_path(path)

        return (
            f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, "
            f"data_scale as numeric_scale, NULL as collation_name, char_length as character_maximum_length "
            f"FROM ALL_TAB_COLUMNS WHERE table_name = '{name}' AND owner = '{schema}'"
        )

    def close(self):
        """Close base resources, then the stored connection.

        ``_conn`` is reset to None so a repeated close() does not close it again.
        """
        super().close()
        if self._conn is not None:
            self._conn.close()
            self._conn = None
@@ -0,0 +1,293 @@
1
+ # Copyright 2022-present, the Waterdip Labs Pvt. Ltd.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, ClassVar, Dict, List, Tuple, Type
16
+ from urllib.parse import unquote
17
+
18
+ import attrs
19
+
20
+ from data_diff.abcs.database_types import (
21
+ JSON,
22
+ Boolean,
23
+ ColType,
24
+ Date,
25
+ DbPath,
26
+ Decimal,
27
+ Float,
28
+ FractionalType,
29
+ Integer,
30
+ Native_UUID,
31
+ TemporalType,
32
+ Text,
33
+ Time,
34
+ Timestamp,
35
+ TimestampTZ,
36
+ )
37
+ from data_diff.databases.base import (
38
+ _CHECKSUM_BITSIZE,
39
+ CHECKSUM_HEXDIGITS,
40
+ CHECKSUM_OFFSET,
41
+ MD5_HEXDIGITS,
42
+ TIMESTAMP_PRECISION_POS,
43
+ BaseDialect,
44
+ ConnectError,
45
+ QueryResult,
46
+ ThreadedDatabase,
47
+ import_helper,
48
+ )
49
+
50
# Optional session time zone applied to each new connection when truthy. Changed by the tests.
SESSION_TIME_ZONE = None


@import_helper("postgresql")
def import_postgresql():
    """Import psycopg2, register ``psycopg2.extras.wait_select`` as its wait callback, and return it.

    ``import_helper("postgresql")`` wraps this function; presumably it reports a
    missing driver in a friendlier way (TODO confirm in ``data_diff.databases.base``).
    """
    import psycopg2.extras

    # NOTE(review): wait_select routes libpq waits through select(); presumably
    # this lets long-running queries be interrupted — see psycopg2 docs.
    callback = psycopg2.extras.wait_select
    psycopg2.extensions.set_wait_callback(callback)
    return psycopg2
59
+
60
+
61
@attrs.define(frozen=False)
class PostgresqlDialect(BaseDialect):
    """SQL dialect for PostgreSQL: identifier quoting, type mapping and value normalization."""

    name = "PostgreSQL"
    ROUNDS_ON_PREC_LOSS = True
    SUPPORTS_PRIMARY_KEY: ClassVar[bool] = True
    SUPPORTS_INDEXES = True

    # https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-DECIMAL
    # without any precision or scale creates an “unconstrained numeric” column
    # in which numeric values of any length can be stored, up to the implementation limits.
    # https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-NUMERIC-TABLE
    DEFAULT_NUMERIC_PRECISION = 16383

    TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {
        # Timestamps
        "timestamp with time zone": TimestampTZ,
        "timestamp without time zone": Timestamp,
        "timestamp": Timestamp,
        "date": Date,
        "time with time zone": Time,
        "time without time zone": Time,
        # Numbers
        "double precision": Float,
        "real": Float,
        "decimal": Decimal,
        "smallint": Integer,
        "integer": Integer,
        "numeric": Decimal,
        "bigint": Integer,
        # Text
        "character": Text,
        "character varying": Text,
        "varchar": Text,
        "text": Text,
        "json": JSON,
        "jsonb": JSON,
        "uuid": Native_UUID,
        "boolean": Boolean,
    }

    def quote(self, s: str, is_table: bool = False) -> str:
        """Double-quote the identifier ``s``.

        When ``is_table`` is set, ``s`` was registered through ``parse_table_name``,
        and a default schema is configured, the name is qualified with that schema.
        """
        if s in self.TABLE_NAMES and self.default_schema and is_table:
            return f'"{self.default_schema}"."{s}"'
        return f'"{s}"'

    def to_string(self, s: str) -> str:
        """Return SQL casting expression ``s`` to varchar."""
        return f"{s}::varchar"

    def concat(self, items: List[str]) -> str:
        """Return SQL concatenating ``items`` with the ``||`` operator."""
        joined_exprs = " || ".join(items)
        return f"({joined_exprs})"

    def _convert_db_precision_to_digits(self, p: int) -> int:
        # Subtracting 2 due to weird precision issues in PostgreSQL
        return super()._convert_db_precision_to_digits(p) - 2

    def set_timezone_to_utc(self) -> str:
        """Return the statement that switches the session time zone to UTC."""
        return "SET TIME ZONE 'UTC'"

    def current_timestamp(self) -> str:
        """Return the SQL expression for the current timestamp."""
        return "current_timestamp"

    def type_repr(self, t) -> str:
        """Return the SQL type name for ``t``, spelling out TimestampTZ with its precision."""
        if isinstance(t, TimestampTZ):
            return f"timestamp ({t.precision}) with time zone"
        return super().type_repr(t)

    def md5_as_int(self, s: str) -> str:
        """Return SQL turning the trailing CHECKSUM_HEXDIGITS hex chars of md5(``s``) into a bigint, minus CHECKSUM_OFFSET."""
        return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint - {CHECKSUM_OFFSET}"

    def md5_as_hex(self, s: str) -> str:
        """Return SQL computing the hex MD5 of ``s``."""
        return f"md5({s})"

    def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
        """Return SQL rendering a date/time/timestamp ``value`` as comparable text.

        - Dates are cast to varchar.
        - Times are rounded to ``coltype.precision`` seconds and formatted 'hh24:mi:ss.ff6'.
        - Rounding timestamps are clamped to the supported range, and half a unit of
          the last kept digit is added, before truncating to the precision. NULL
          stays NULL.
        - Other timestamps are formatted with microseconds.

        Timestamp results are cut to TIMESTAMP_PRECISION_POS + precision characters,
        then right-padded with '0' to TIMESTAMP_PRECISION_POS + 6.
        """

        def _add_padding(coltype: TemporalType, timestamp6: str):
            return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')"

        # Not every column type defines is_date / is_time; treat a missing flag as False.
        # (Previously a bare ``except:`` here swallowed *any* exception.)
        is_date = getattr(coltype, "is_date", False)
        is_time = getattr(coltype, "is_time", False)

        if isinstance(coltype, Date) or is_date:
            return f"cast({value} as varchar)"

        if isinstance(coltype, Time) or is_time:
            seconds = f"EXTRACT( epoch from {value})"
            rounded = f"ROUND({seconds}, {coltype.precision})"
            time_value = f"CAST('00:00:00' as time) + make_interval(0, 0, 0, 0, 0, 0, {rounded})"  # 6th arg = seconds
            converted = f"to_char({time_value}, 'hh24:mi:ss.ff6')"
            return converted

        if coltype.rounds:
            # NULL value expected to return NULL after normalization
            null_case_begin = f"CASE WHEN {value} IS NULL THEN NULL ELSE "
            null_case_end = "END"

            # 294277 or 4714 BC would be out of range, make sure we can't round to that
            # TODO test timezones for overflow?
            max_timestamp = "294276-12-31 23:59:59.0000"
            min_timestamp = "4713-01-01 00:00:00.00 BC"
            timestamp = f"least('{max_timestamp}'::timestamp(6), {value}::timestamp(6))"
            timestamp = f"greatest('{min_timestamp}'::timestamp(6), {timestamp})"

            # Half a unit of the last kept fractional digit; added before truncation to round.
            interval = format((0.5 * (10 ** (-coltype.precision))), f".{coltype.precision+1}f")

            # The same to_char expression is used for the text and for its length.
            to_char_expr = (
                f"to_char(least('{max_timestamp}'::timestamp, {timestamp})"
                f"+ interval '{interval}', 'YYYY-mm-dd HH24:MI:SS.US')"
            )
            rounded_timestamp = f"left({to_char_expr},length({to_char_expr}) - (6-{coltype.precision}))"

            padded = _add_padding(coltype, rounded_timestamp)
            return f"{null_case_begin} {padded} {null_case_end}"

        # TODO years with > 4 digits not padded correctly
        # current w/ precision 6: 294276-12-31 23:59:59.0000
        # should be 294276-12-31 23:59:59.000000
        rounded_timestamp = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')"
        padded = _add_padding(coltype, rounded_timestamp)
        return padded

    def normalize_number(self, value: str, coltype: FractionalType) -> str:
        """Return SQL casting ``value`` to decimal(38, p) text, with p capped at 10."""
        precision = min(coltype.precision, 10)
        return self.to_string(f"{value}::decimal(38, {precision})")

    def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
        """Return SQL rendering a boolean as its integer text."""
        return self.to_string(f"{value}::int")

    def normalize_json(self, value: str, _coltype: JSON) -> str:
        """Return SQL casting a JSON value to text."""
        return f"{value}::text"

    def parse_table_name(self, name: str) -> DbPath:
        "Parse the given table name into a DbPath"
        # Remember the bare table name so quote(..., is_table=True) can qualify it.
        self.TABLE_NAMES.append(name.split(".")[-1])
        return tuple(name.split("."))
201
+
202
+
203
@attrs.define(frozen=False, init=False, kw_only=True)
class PostgreSQL(ThreadedDatabase):
    """PostgreSQL database connection built on psycopg2."""

    DIALECT_CLASS: ClassVar[Type[BaseDialect]] = PostgresqlDialect
    SUPPORTS_UNIQUE_CONSTAINT = True
    CONNECT_URI_HELP = "postgresql://<user>:<password>@<host>/<database>"
    CONNECT_URI_PARAMS = ["database?"]

    _args: Dict[str, Any]
    _conn: Any

    def __init__(self, *, thread_count, **kw) -> None:
        """Store connection arguments; the default schema is ``kw["schema"]`` or ``"public"``."""
        super().__init__(thread_count=thread_count)
        self._args = kw
        self.default_schema = self._args.get("schema", "public")
        self.dialect.default_schema = self.default_schema
        # No connection exists until create_connection() runs; close() relies on this.
        self._conn = None

    def create_connection(self):
        """Open a new psycopg2 connection and return it.

        The stored password is URL-unquoted. ``search_path`` is set to the default
        schema, TCP keepalives are enabled, and SESSION_TIME_ZONE is applied when set.

        Raises:
            ConnectError: if psycopg2 raises ``OperationalError``.
        """
        if not self._args:
            self._args["host"] = None  # psycopg2 requires 1+ arguments

        pg = import_postgresql()
        # Decode into a local variable instead of writing back into self._args.
        # If create_connection() is called again (e.g. for another worker thread —
        # TODO confirm in ThreadedDatabase), a write-back would decode an
        # already-decoded password a second time.
        # A missing password stays None instead of raising KeyError.
        password = self._args.get("password")
        if password is not None:
            password = unquote(password)
        try:
            self._conn = pg.connect(
                database=self._args.get("database"),
                user=self._args.get("user"),
                password=password,
                host=self._args.get("host"),
                port=self._args.get("port"),
                keepalives=1,
                keepalives_idle=5,
                keepalives_interval=2,
                keepalives_count=2,
                options="-c search_path={}".format(self.default_schema),
            )
            if SESSION_TIME_ZONE:
                # Use a short-lived cursor that is closed right away instead of leaked.
                with self._conn.cursor() as cursor:
                    cursor.execute(f"SET TIME ZONE '{SESSION_TIME_ZONE}'")
            return self._conn
        except pg.OperationalError as e:
            raise ConnectError(*e.args) from e

    def select_table_schema(self, path: DbPath) -> str:
        """Return SQL reading column metadata for ``path`` from information_schema.columns.

        Unconstrained ``numeric`` columns report a precision of
        131072 + DEFAULT_NUMERIC_PRECISION and a scale of DEFAULT_NUMERIC_PRECISION.
        ``character_maximum_length`` is only reported for ``character varying``.
        """
        database, schema, table = self._normalize_table_path(path)

        info_schema_path = ["information_schema", "columns"]
        if database:
            info_schema_path.insert(0, database)
        return (
            f"SELECT column_name, data_type, datetime_precision, "
            f"CASE WHEN data_type = 'numeric' "
            f"THEN coalesce(numeric_precision, 131072 + {self.dialect.DEFAULT_NUMERIC_PRECISION}) "
            f"ELSE numeric_precision END AS numeric_precision, "
            f"CASE WHEN data_type = 'numeric' "
            f"THEN coalesce(numeric_scale, {self.dialect.DEFAULT_NUMERIC_PRECISION}) "
            f"ELSE numeric_scale END AS numeric_scale, "
            f"COALESCE(collation_name, NULL) AS collation_name, "
            f"CASE WHEN data_type = 'character varying' "
            f"THEN character_maximum_length END AS character_maximum_length "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
        )

    def select_table_unique_columns(self, path: DbPath) -> str:
        """Return SQL listing the columns of ``path`` that appear in information_schema.key_column_usage."""
        database, schema, table = self._normalize_table_path(path)

        info_schema_path = ["information_schema", "key_column_usage"]
        if database:
            info_schema_path.insert(0, database)

        return (
            "SELECT column_name "
            f"FROM {'.'.join(info_schema_path)} "
            f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
        )

    def _normalize_table_path(self, path: DbPath) -> DbPath:
        """Expand ``path`` into ``(database, schema, table)``.

        A missing database is None; a missing schema is the default schema.

        Raises:
            ValueError: if ``path`` has fewer than 1 or more than 3 parts.
        """
        if len(path) == 1:
            return None, self.default_schema, path[0]
        elif len(path) == 2:
            return None, path[0], path[1]
        elif len(path) == 3:
            return path

        raise ValueError(
            f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected format: table, schema.table, or database.schema.table"
        )

    def close(self):
        """Close base resources, then the stored connection.

        ``_conn`` is reset to None so a repeated close() does not close it again.
        """
        super().close()
        if self._conn is not None:
            self._conn.close()
            self._conn = None