meerschaum 2.1.0rc2__py3-none-any.whl → 2.1.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31) hide show
  1. meerschaum/actions/bootstrap.py +1 -2
  2. meerschaum/actions/delete.py +15 -1
  3. meerschaum/actions/sync.py +4 -4
  4. meerschaum/api/routes/_pipes.py +7 -11
  5. meerschaum/config/__init__.py +0 -2
  6. meerschaum/config/_default.py +3 -0
  7. meerschaum/config/_version.py +1 -1
  8. meerschaum/config/static/__init__.py +4 -0
  9. meerschaum/connectors/sql/SQLConnector.py +43 -3
  10. meerschaum/connectors/sql/_cli.py +27 -3
  11. meerschaum/connectors/sql/_instance.py +164 -0
  12. meerschaum/connectors/sql/_pipes.py +344 -304
  13. meerschaum/connectors/sql/_sql.py +52 -14
  14. meerschaum/connectors/sql/tables/__init__.py +65 -13
  15. meerschaum/connectors/sql/tables/pipes.py +9 -0
  16. meerschaum/core/Pipe/__init__.py +1 -1
  17. meerschaum/core/Pipe/_data.py +3 -4
  18. meerschaum/core/Pipe/_delete.py +12 -2
  19. meerschaum/core/Pipe/_sync.py +2 -5
  20. meerschaum/utils/dataframe.py +20 -4
  21. meerschaum/utils/dtypes/__init__.py +15 -1
  22. meerschaum/utils/dtypes/sql.py +1 -0
  23. meerschaum/utils/sql.py +485 -64
  24. {meerschaum-2.1.0rc2.dist-info → meerschaum-2.1.1rc1.dist-info}/METADATA +1 -1
  25. {meerschaum-2.1.0rc2.dist-info → meerschaum-2.1.1rc1.dist-info}/RECORD +31 -29
  26. {meerschaum-2.1.0rc2.dist-info → meerschaum-2.1.1rc1.dist-info}/LICENSE +0 -0
  27. {meerschaum-2.1.0rc2.dist-info → meerschaum-2.1.1rc1.dist-info}/NOTICE +0 -0
  28. {meerschaum-2.1.0rc2.dist-info → meerschaum-2.1.1rc1.dist-info}/WHEEL +0 -0
  29. {meerschaum-2.1.0rc2.dist-info → meerschaum-2.1.1rc1.dist-info}/entry_points.txt +0 -0
  30. {meerschaum-2.1.0rc2.dist-info → meerschaum-2.1.1rc1.dist-info}/top_level.txt +0 -0
  31. {meerschaum-2.1.0rc2.dist-info → meerschaum-2.1.1rc1.dist-info}/zip-safe +0 -0
meerschaum/utils/sql.py CHANGED
@@ -8,7 +8,7 @@ Flavor-specific SQL tools.
8
8
 
9
9
  from __future__ import annotations
10
10
  import meerschaum as mrsm
11
- from meerschaum.utils.typing import Optional, Dict, Any, Union, List, Iterable
11
+ from meerschaum.utils.typing import Optional, Dict, Any, Union, List, Iterable, Tuple
12
12
  ### Preserve legacy imports.
13
13
  from meerschaum.utils.dtypes.sql import (
14
14
  DB_TO_PD_DTYPES,
@@ -16,6 +16,8 @@ from meerschaum.utils.dtypes.sql import (
16
16
  get_pd_type_from_db_type as get_pd_type,
17
17
  get_db_type_from_pd_type as get_db_type,
18
18
  )
19
+ from meerschaum.utils.warnings import warn
20
+ from meerschaum.utils.debug import dprint
19
21
 
20
22
  test_queries = {
21
23
  'default' : 'SELECT 1',
@@ -26,7 +28,7 @@ test_queries = {
26
28
  ### `table_name` is the escaped name of the table.
27
29
  ### `table` is the unescaped name of the table.
28
30
  exists_queries = {
29
- 'default' : "SELECT COUNT(*) FROM {table_name} WHERE 1 = 0",
31
+ 'default': "SELECT COUNT(*) FROM {table_name} WHERE 1 = 0",
30
32
  }
31
33
  version_queries = {
32
34
  'default': "SELECT VERSION() AS {version_name}",
@@ -34,49 +36,94 @@ version_queries = {
34
36
  'mssql': "SELECT @@version",
35
37
  'oracle': "SELECT version from PRODUCT_COMPONENT_VERSION WHERE rownum = 1",
36
38
  }
39
+ SKIP_IF_EXISTS_FLAVORS = {'mssql'}
37
40
  update_queries = {
38
41
  'default': """
39
42
  UPDATE {target_table_name} AS f
40
43
  {sets_subquery_none}
41
44
  FROM {target_table_name} AS t
42
- INNER JOIN (SELECT DISTINCT * FROM {patch_table_name}) AS p
45
+ INNER JOIN (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p
43
46
  ON {and_subquery_t}
44
47
  WHERE
45
48
  {and_subquery_f}
49
+ {date_bounds_subquery}
50
+ """,
51
+ 'timescaledb-upsert': """
52
+ INSERT INTO {target_table_name}
53
+ SELECT {patch_cols_str}
54
+ FROM {patch_table_name}
55
+ ON CONFLICT ({join_cols_str}) DO UPDATE {sets_subquery_none_excluded}
56
+ """,
57
+ 'postgresql-upsert': """
58
+ INSERT INTO {target_table_name}
59
+ SELECT {patch_cols_str}
60
+ FROM {patch_table_name}
61
+ ON CONFLICT ({join_cols_str}) DO UPDATE {sets_subquery_none_excluded}
62
+ """,
63
+ 'citus-upsert': """
64
+ INSERT INTO {target_table_name}
65
+ SELECT {patch_cols_str}
66
+ FROM {patch_table_name}
67
+ ON CONFLICT ({join_cols_str}) DO UPDATE {sets_subquery_none_excluded}
68
+ """,
69
+ 'cockroachdb-upsert': """
70
+ INSERT INTO {target_table_name}
71
+ SELECT {patch_cols_str}
72
+ FROM {patch_table_name}
73
+ ON CONFLICT ({join_cols_str}) DO UPDATE {sets_subquery_none_excluded}
74
+ """,
75
+ 'duckdb-upsert': """
76
+ INSERT INTO {target_table_name}
77
+ SELECT {patch_cols_str}
78
+ FROM {patch_table_name}
79
+ ON CONFLICT ({join_cols_str}) DO UPDATE {sets_subquery_none_excluded}
46
80
  """,
47
81
  'mysql': """
48
82
  UPDATE {target_table_name} AS f,
49
- (SELECT DISTINCT * FROM {patch_table_name}) AS p
83
+ (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p
50
84
  {sets_subquery_f}
51
85
  WHERE
52
86
  {and_subquery_f}
87
+ {date_bounds_subquery}
88
+ """,
89
+ 'mysql-upsert': """
90
+ REPLACE INTO {target_table_name}
91
+ SELECT {patch_cols_str}
92
+ FROM {patch_table_name}
53
93
  """,
54
94
  'mariadb': """
55
95
  UPDATE {target_table_name} AS f,
56
- (SELECT DISTINCT * FROM {patch_table_name}) AS p
96
+ (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) AS p
57
97
  {sets_subquery_f}
58
98
  WHERE
59
99
  {and_subquery_f}
100
+ {date_bounds_subquery}
101
+ """,
102
+ 'mariadb-upsert': """
103
+ REPLACE INTO {target_table_name}
104
+ SELECT {patch_cols_str}
105
+ FROM {patch_table_name}
60
106
  """,
61
107
  'mssql': """
62
- MERGE {target_table_name} t
63
- USING (SELECT DISTINCT * FROM {patch_table_name}) p
64
- ON {and_subquery_t}
108
+ MERGE {target_table_name} f
109
+ USING (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) p
110
+ ON {and_subquery_f}
111
+ {date_bounds_subquery}
65
112
  WHEN MATCHED THEN
66
113
  UPDATE
67
114
  {sets_subquery_none};
68
115
  """,
69
116
  'oracle': """
70
- MERGE INTO {target_table_name} t
71
- USING (SELECT DISTINCT * FROM {patch_table_name}) p
117
+ MERGE INTO {target_table_name} f
118
+ USING (SELECT DISTINCT {patch_cols_str} FROM {patch_table_name}) p
72
119
  ON (
73
- {and_subquery_t}
120
+ {and_subquery_f}
74
121
  )
75
122
  WHEN MATCHED THEN
76
123
  UPDATE
77
124
  {sets_subquery_none}
78
125
  WHERE (
79
- {and_subquery_t}
126
+ {and_subquery_f}
80
127
  )
81
128
  """,
82
129
  'sqlite_delete_insert': [
@@ -91,19 +138,74 @@ update_queries = {
91
138
  """,
92
139
  """
93
140
  INSERT INTO {target_table_name} AS f
94
- SELECT DISTINCT * FROM {patch_table_name} AS p
141
+ SELECT DISTINCT {patch_cols_str} FROM {patch_table_name} AS p
95
142
  """,
96
143
  ],
144
+ }
145
+ columns_types_queries = {
97
146
  'default': """
98
- UPDATE {target_table_name} AS f
99
- {sets_subquery_none}
100
- FROM {target_table_name} AS t
101
- INNER JOIN (SELECT DISTINCT * FROM {patch_table_name}) AS p
102
- ON {and_subquery_t}
103
- WHERE
104
- {and_subquery_f}
147
+ SELECT
148
+ table_catalog AS database,
149
+ table_schema AS schema,
150
+ table_name AS table,
151
+ column_name AS column,
152
+ data_type AS type
153
+ FROM information_schema.columns
154
+ WHERE table_name = '{table}'
155
+ """,
156
+ 'sqlite': """
157
+ SELECT
158
+ '' "database",
159
+ '' "schema",
160
+ m.name "table",
161
+ p.name "column",
162
+ p.type "type"
163
+ FROM sqlite_master m
164
+ LEFT OUTER JOIN pragma_table_info((m.name)) p
165
+ ON m.name <> p.name
166
+ WHERE m.type = 'table'
167
+ AND m.name = '{table}'
168
+ """,
169
+ 'mssql': """
170
+ SELECT
171
+ TABLE_CATALOG AS [database],
172
+ TABLE_SCHEMA AS [schema],
173
+ TABLE_NAME AS [table],
174
+ COLUMN_NAME AS [column],
175
+ DATA_TYPE AS [type]
176
+ FROM INFORMATION_SCHEMA.COLUMNS
177
+ WHERE TABLE_NAME = '{table}'
178
+ """,
179
+ 'mysql': """
180
+ SELECT
181
+ TABLE_SCHEMA `database`,
182
+ TABLE_SCHEMA `schema`,
183
+ TABLE_NAME `table`,
184
+ COLUMN_NAME `column`,
185
+ DATA_TYPE `type`
186
+ FROM INFORMATION_SCHEMA.COLUMNS
187
+ WHERE TABLE_NAME = '{table}'
188
+ """,
189
+ 'mariadb': """
190
+ SELECT
191
+ TABLE_SCHEMA `database`,
192
+ TABLE_SCHEMA `schema`,
193
+ TABLE_NAME `table`,
194
+ COLUMN_NAME `column`,
195
+ DATA_TYPE `type`
196
+ FROM INFORMATION_SCHEMA.COLUMNS
197
+ WHERE TABLE_NAME = '{table}'
198
+ """,
199
+ 'oracle': """
200
+ SELECT
201
+ NULL AS "database",
202
+ NULL AS "schema",
203
+ TABLE_NAME AS "table",
204
+ COLUMN_NAME AS "column",
205
+ DATA_TYPE AS "type"
206
+ FROM all_tab_columns
207
+ WHERE TABLE_NAME = '{table}'
105
208
  """,
106
-
107
209
  }
108
210
  hypertable_queries = {
109
211
  'timescaledb': 'SELECT hypertable_size(\'{table_name}\')',
@@ -135,6 +237,16 @@ max_name_lens = {
135
237
  'mariadb' : 64,
136
238
  }
137
239
  json_flavors = {'postgresql', 'timescaledb', 'citus', 'cockroachdb'}
240
+ NO_SCHEMA_FLAVORS = {'oracle', 'sqlite', 'mysql', 'mariadb'}
241
+ DEFAULT_SCHEMA_FLAVORS = {
242
+ 'postgresql': 'public',
243
+ 'timescaledb': 'public',
244
+ 'citus': 'public',
245
+ 'cockroachdb': 'public',
246
+ 'mysql': 'mysql',
247
+ 'mariadb': 'mysql',
248
+ 'mssql': 'dbo',
249
+ }
138
250
  OMIT_NULLSFIRST_FLAVORS = {'mariadb', 'mysql', 'mssql'}
139
251
 
140
252
  SINGLE_ALTER_TABLE_FLAVORS = {'duckdb', 'sqlite', 'mssql', 'oracle'}
@@ -557,12 +669,13 @@ def build_where(
557
669
  import json
558
670
  from meerschaum.config.static import STATIC_CONFIG
559
671
  from meerschaum.utils.warnings import warn
672
+ from meerschaum.utils.dtypes import value_is_null, none_if_null
560
673
  negation_prefix = STATIC_CONFIG['system']['fetch_pipes_keys']['negation_prefix']
561
674
  try:
562
675
  params_json = json.dumps(params)
563
676
  except Exception as e:
564
677
  params_json = str(params)
565
- bad_words = ['drop', '--', ';']
678
+ bad_words = ['drop ', '--', ';']
566
679
  for word in bad_words:
567
680
  if word in params_json.lower():
568
681
  warn(f"Aborting build_where() due to possible SQL injection.")
@@ -577,21 +690,47 @@ def build_where(
577
690
  _key = sql_item_name(key, connector.flavor, None)
578
691
  ### search across a list (i.e. IN syntax)
579
692
  if isinstance(value, Iterable) and not isinstance(value, (dict, str)):
580
- includes = [item for item in value if not str(item).startswith(negation_prefix)]
581
- excludes = [item for item in value if str(item).startswith(negation_prefix)]
693
+ includes = [
694
+ none_if_null(item)
695
+ for item in value
696
+ if not str(item).startswith(negation_prefix)
697
+ ]
698
+ null_includes = [item for item in includes if item is None]
699
+ not_null_includes = [item for item in includes if item is not None]
700
+ excludes = [
701
+ none_if_null(str(item)[len(negation_prefix):])
702
+ for item in value
703
+ if str(item).startswith(negation_prefix)
704
+ ]
705
+ null_excludes = [item for item in excludes if item is None]
706
+ not_null_excludes = [item for item in excludes if item is not None]
707
+
582
708
  if includes:
583
- where += f"{leading_and}{_key} IN ("
584
- for item in includes:
709
+ where += f"{leading_and}("
710
+ if not_null_includes:
711
+ where += f"{_key} IN ("
712
+ for item in not_null_includes:
585
713
  quoted_item = str(item).replace("'", "''")
586
714
  where += f"'{quoted_item}', "
587
715
  where = where[:-2] + ")"
716
+ if null_includes:
717
+ where += ("\n OR " if not_null_includes else "") + f"{_key} IS NULL"
718
+ if includes:
719
+ where += ")"
720
+
588
721
  if excludes:
589
- where += f"{leading_and}{_key} NOT IN ("
590
- for item in excludes:
591
- item = str(item)[len(negation_prefix):]
722
+ where += f"{leading_and}("
723
+ if not_null_excludes:
724
+ where += f"{_key} NOT IN ("
725
+ for item in not_null_excludes:
592
726
  quoted_item = str(item).replace("'", "''")
593
727
  where += f"'{quoted_item}', "
594
728
  where = where[:-2] + ")"
729
+ if null_excludes:
730
+ where += ("\n AND " if not_null_excludes else "") + f"{_key} IS NOT NULL"
731
+ if excludes:
732
+ where += ")"
733
+
595
734
  continue
596
735
 
597
736
  ### search a dictionary
@@ -602,10 +741,16 @@ def build_where(
602
741
 
603
742
  eq_sign = '='
604
743
  is_null = 'IS NULL'
744
+ if value_is_null(str(value).lstrip(negation_prefix)):
745
+ value = (
746
+ (negation_prefix + 'None')
747
+ if str(value).startswith(negation_prefix)
748
+ else None
749
+ )
605
750
  if str(value).startswith(negation_prefix):
606
751
  value = str(value)[len(negation_prefix):]
607
752
  eq_sign = '!='
608
- if value == 'None':
753
+ if value_is_null(value):
609
754
  value = None
610
755
  is_null = 'IS NOT NULL'
611
756
  quoted_value = str(value).replace("'", "''")
@@ -725,12 +870,148 @@ def get_sqlalchemy_table(
725
870
  return tables[truncated_table_name]
726
871
 
727
872
 
873
+ def get_table_cols_types(
874
+ table: str,
875
+ connectable: Union[
876
+ 'mrsm.connectors.sql.SQLConnector',
877
+ 'sqlalchemy.orm.session.Session',
878
+ 'sqlalchemy.engine.base.Engine'
879
+ ],
880
+ flavor: Optional[str] = None,
881
+ schema: Optional[str] = None,
882
+ database: Optional[str] = None,
883
+ debug: bool = False,
884
+ ) -> Dict[str, str]:
885
+ """
886
+ Return a dictionary mapping a table's columns to data types.
887
+ This is useful for inspecting tables created during a not-yet-committed session.
888
+
889
+ NOTE: This may return incorrect columns if the schema is not explicitly stated.
890
+ Use this function if you are confident the table name is unique or if you have
891
+ an explicit schema.
892
+ To use the configured schema, get the columns from `get_sqlalchemy_table()` instead.
893
+
894
+ Parameters
895
+ ----------
896
+ table: str
897
+ The name of the table (unquoted).
898
+
899
+ connectable: Union[
900
+ 'mrsm.connectors.sql.SQLConnector',
901
+ 'sqlalchemy.orm.session.Session',
902
+ ]
903
+ The connection object used to fetch the columns and types.
904
+
905
+ flavor: Optional[str], default None
906
+ The database dialect flavor to use for the query.
907
+ If omitted, default to `connectable.flavor`.
908
+
909
+ schema: Optional[str], default None
910
+ If provided, restrict the query to this schema.
911
+
912
+ database: Optional[str], default None
913
+ If provided, restrict the query to this database.
914
+
915
+ Returns
916
+ -------
917
+ A dictionary mapping column names to data types.
918
+ """
919
+ from meerschaum.connectors import SQLConnector
920
+ from meerschaum.utils.misc import filter_keywords
921
+ sqlalchemy = mrsm.attempt_import('sqlalchemy')
922
+ flavor = flavor or getattr(connectable, 'flavor', None)
923
+ if not flavor:
924
+ raise ValueError(f"Please provide a database flavor.")
925
+ if flavor == 'duckdb' and not isinstance(connectable, SQLConnector):
926
+ raise ValueError(f"You must provide a SQLConnector when using DuckDB.")
927
+ if flavor in NO_SCHEMA_FLAVORS:
928
+ schema = None
929
+ if schema is None:
930
+ schema = DEFAULT_SCHEMA_FLAVORS.get(flavor, None)
931
+ if flavor in ('sqlite', 'duckdb', 'oracle'):
932
+ database = None
933
+ if flavor == 'oracle':
934
+ table = table.upper() if table.islower() else table
935
+
936
+ cols_types_query = sqlalchemy.text(
937
+ columns_types_queries.get(
938
+ flavor,
939
+ columns_types_queries['default']
940
+ ).format(table=table)
941
+ )
942
+
943
+ cols = ['database', 'schema', 'table', 'column', 'type']
944
+ result_cols_ix = dict(enumerate(cols))
945
+
946
+ debug_kwargs = {'debug': debug} if isinstance(connectable, SQLConnector) else {}
947
+ if not debug_kwargs and debug:
948
+ dprint(cols_types_query)
949
+
950
+ try:
951
+ result_rows = (
952
+ [
953
+ row
954
+ for row in connectable.execute(cols_types_query, **debug_kwargs).fetchall()
955
+ ]
956
+ if flavor != 'duckdb'
957
+ else [
958
+ (doc[col] for col in cols)
959
+ for doc in connectable.read(cols_types_query, debug=debug).to_dict(orient='records')
960
+ ]
961
+ )
962
+ cols_types_docs = [
963
+ {
964
+ result_cols_ix[i]: val
965
+ for i, val in enumerate(row)
966
+ }
967
+ for row in result_rows
968
+ ]
969
+ cols_types_docs_filtered = [
970
+ doc
971
+ for doc in cols_types_docs
972
+ if (
973
+ (
974
+ not schema
975
+ or doc['schema'] == schema
976
+ )
977
+ and
978
+ (
979
+ not database
980
+ or doc['database'] == database
981
+ )
982
+ )
983
+ ]
984
+ if debug:
985
+ dprint(f"schema={schema}, database={database}")
986
+ for doc in cols_types_docs:
987
+ print(doc)
988
+
989
+ ### NOTE: This may return incorrect columns if the schema is not explicitly stated.
990
+ if cols_types_docs and not cols_types_docs_filtered:
991
+ cols_types_docs_filtered = cols_types_docs
992
+
993
+ return {
994
+ doc['column']: doc['type'].upper()
995
+ for doc in cols_types_docs_filtered
996
+ }
997
+ except Exception as e:
998
+ warn(f"Failed to fetch columns for table '{table}':\n{e}")
999
+ return {}
1000
+
1001
+
728
1002
  def get_update_queries(
729
1003
  target: str,
730
1004
  patch: str,
731
- connector: mrsm.connectors.sql.SQLConnector,
1005
+ connectable: Union[
1006
+ mrsm.connectors.sql.SQLConnector,
1007
+ 'sqlalchemy.orm.session.Session'
1008
+ ],
732
1009
  join_cols: Iterable[str],
1010
+ flavor: Optional[str] = None,
1011
+ upsert: bool = False,
1012
+ datetime_col: Optional[str] = None,
733
1013
  schema: Optional[str] = None,
1014
+ patch_schema: Optional[str] = None,
734
1015
  debug: bool = False,
735
1016
  ) -> List[str]:
736
1017
  """
@@ -744,16 +1025,30 @@ def get_update_queries(
744
1025
  patch: str
745
1026
  The name of the patch table. This should have the same shape as the target.
746
1027
 
747
- connector: meerschaum.connectors.sql.SQLConnector
748
- The Meerschaum `SQLConnector` which will later execute the queries.
1028
+ connectable: Union[meerschaum.connectors.sql.SQLConnector, sqlalchemy.orm.session.Session]
1029
+ The `SQLConnector` or SQLAlchemy session which will later execute the queries.
749
1030
 
750
1031
  join_cols: List[str]
751
1032
  The columns to use to join the patch to the target.
752
1033
 
1034
+ flavor: Optional[str], default None
1035
+ If using a SQLAlchemy session, provide the expected database flavor.
1036
+
1037
+ upsert: bool, default False
1038
+ If `True`, return an upsert query rather than an update.
1039
+
1040
+ datetime_col: Optional[str], default None
1041
+ If provided, bound the join query using this column as the datetime index.
1042
+ This must be present on both tables.
1043
+
753
1044
  schema: Optional[str], default None
754
1045
  If provided, use this schema when quoting the target table.
755
1046
  Defaults to `connector.schema`.
756
1047
 
1048
+ patch_schema: Optional[str], default None
1049
+ If provided, use this schema when quoting the patch table.
1050
+ Defaults to `schema`.
1051
+
757
1052
  debug: bool, default False
758
1053
  Verbosity toggle.
759
1054
 
@@ -761,61 +1056,133 @@ def get_update_queries(
761
1056
  -------
762
1057
  A list of query strings to perform the update operation.
763
1058
  """
1059
+ from meerschaum.connectors import SQLConnector
764
1060
  from meerschaum.utils.debug import dprint
765
- flavor = connector.flavor
766
- if connector.flavor == 'sqlite' and connector.db_version < '3.33.0':
1061
+ flavor = flavor or (connectable.flavor if isinstance(connectable, SQLConnector) else None)
1062
+ if not flavor:
1063
+ raise ValueError("Provide a flavor if using a SQLAlchemy session.")
1064
+ if (
1065
+ flavor == 'sqlite'
1066
+ and isinstance(connectable, SQLConnector)
1067
+ and connectable.db_version < '3.33.0'
1068
+ ):
767
1069
  flavor = 'sqlite_delete_insert'
768
- base_queries = update_queries.get(flavor, update_queries['default'])
1070
+ flavor_key = (f'{flavor}-upsert' if upsert else flavor)
1071
+ base_queries = update_queries.get(
1072
+ flavor_key,
1073
+ update_queries['default']
1074
+ )
769
1075
  if not isinstance(base_queries, list):
770
1076
  base_queries = [base_queries]
771
- schema = schema or connector.schema
772
- target_table = get_sqlalchemy_table(target, connector, schema=schema)
1077
+ schema = schema or (connectable.schema if isinstance(connectable, SQLConnector) else None)
1078
+ patch_schema = patch_schema or schema
1079
+ target_table_columns = get_table_cols_types(
1080
+ target,
1081
+ connectable,
1082
+ flavor = flavor,
1083
+ schema = schema,
1084
+ debug = debug,
1085
+ )
1086
+ patch_table_columns = get_table_cols_types(
1087
+ patch,
1088
+ connectable,
1089
+ flavor = flavor,
1090
+ schema = patch_schema,
1091
+ debug = debug,
1092
+ )
1093
+
1094
+ patch_cols_str = ', '.join(
1095
+ [
1096
+ sql_item_name(col, flavor)
1097
+ for col in target_table_columns
1098
+ ]
1099
+ )
1100
+ join_cols_str = ','.join(
1101
+ [
1102
+ sql_item_name(col, flavor)
1103
+ for col in join_cols
1104
+ ]
1105
+ )
1106
+
773
1107
  value_cols = []
1108
+ join_cols_types = []
774
1109
  if debug:
775
- dprint(f"target_table.columns: {dict(target_table.columns)}")
776
- for c in target_table.columns:
777
- c_name, c_type = c.name, str(c.type)
778
- if c_name in join_cols:
779
- continue
780
- if connector.flavor in DB_FLAVORS_CAST_DTYPES:
781
- c_type = DB_FLAVORS_CAST_DTYPES[connector.flavor].get(c_type, c_type)
782
- value_cols.append((c_name, c_type))
1110
+ dprint(f"target_table_columns:")
1111
+ mrsm.pprint(target_table_columns)
1112
+ for c_name, c_type in target_table_columns.items():
1113
+ if flavor in DB_FLAVORS_CAST_DTYPES:
1114
+ c_type = DB_FLAVORS_CAST_DTYPES[flavor].get(c_type.upper(), c_type)
1115
+ (
1116
+ join_cols_types
1117
+ if c_name in join_cols
1118
+ else value_cols
1119
+ ).append((c_name, c_type))
783
1120
  if debug:
784
1121
  dprint(f"value_cols: {value_cols}")
785
1122
 
1123
+ if not value_cols or not join_cols_types:
1124
+ return []
1125
+
786
1126
  def sets_subquery(l_prefix: str, r_prefix: str):
787
1127
  return 'SET ' + ',\n'.join([
788
1128
  (
789
- l_prefix + sql_item_name(c_name, connector.flavor, None)
1129
+ l_prefix + sql_item_name(c_name, flavor, None)
790
1130
  + ' = '
791
- + ('CAST(' if connector.flavor != 'sqlite' else '')
1131
+ + ('CAST(' if flavor != 'sqlite' else '')
792
1132
  + r_prefix
793
- + sql_item_name(c_name, connector.flavor, None)
794
- + (' AS ' if connector.flavor != 'sqlite' else '')
795
- + (c_type.replace('_', ' ') if connector.flavor != 'sqlite' else '')
796
- + (')' if connector.flavor != 'sqlite' else '')
1133
+ + sql_item_name(c_name, flavor, None)
1134
+ + (' AS ' if flavor != 'sqlite' else '')
1135
+ + (c_type.replace('_', ' ') if flavor != 'sqlite' else '')
1136
+ + (')' if flavor != 'sqlite' else '')
797
1137
  ) for c_name, c_type in value_cols
798
1138
  ])
799
1139
 
800
1140
  def and_subquery(l_prefix: str, r_prefix: str):
801
1141
  return '\nAND\n'.join([
802
1142
  (
803
- l_prefix + sql_item_name(c, connector.flavor, None)
1143
+ "COALESCE("
1144
+ + l_prefix
1145
+ + sql_item_name(c_name, flavor, None)
1146
+ + ", "
1147
+ + get_null_replacement(c_type, flavor)
1148
+ + ")"
804
1149
  + ' = '
805
- + r_prefix + sql_item_name(c, connector.flavor, None)
806
- ) for c in join_cols
1150
+ + "COALESCE("
1151
+ + r_prefix
1152
+ + sql_item_name(c_name, flavor, None)
1153
+ + ", "
1154
+ + get_null_replacement(c_type, flavor)
1155
+ + ")"
1156
+ ) for c_name, c_type in join_cols_types
807
1157
  ])
808
1158
 
809
- return [base_query.format(
810
- sets_subquery_none = sets_subquery('', 'p.'),
811
- sets_subquery_f = sets_subquery('f.', 'p.'),
812
- and_subquery_f = and_subquery('p.', 'f.'),
813
- and_subquery_t = and_subquery('p.', 't.'),
814
- target_table_name = sql_item_name(target, connector.flavor, schema),
815
- patch_table_name = sql_item_name(patch, connector.flavor, schema),
816
- ) for base_query in base_queries]
1159
+ target_table_name = sql_item_name(target, flavor, schema)
1160
+ patch_table_name = sql_item_name(patch, flavor, patch_schema)
1161
+ dt_col_name = sql_item_name(datetime_col, flavor, None) if datetime_col else None
1162
+ date_bounds_subquery = (
1163
+ f"""
1164
+ AND f.{dt_col_name} >= (SELECT MIN({dt_col_name}) FROM {patch_table_name})
1165
+ AND f.{dt_col_name} <= (SELECT MAX({dt_col_name}) FROM {patch_table_name})
1166
+ """
1167
+ if datetime_col else ""
1168
+ )
1169
+
1170
+ return [
1171
+ base_query.format(
1172
+ sets_subquery_none = sets_subquery('', 'p.'),
1173
+ sets_subquery_none_excluded = sets_subquery('', 'EXCLUDED.'),
1174
+ sets_subquery_f = sets_subquery('f.', 'p.'),
1175
+ and_subquery_f = and_subquery('p.', 'f.'),
1176
+ and_subquery_t = and_subquery('p.', 't.'),
1177
+ target_table_name = target_table_name,
1178
+ patch_table_name = patch_table_name,
1179
+ patch_cols_str = patch_cols_str,
1180
+ date_bounds_subquery = date_bounds_subquery,
1181
+ join_cols_str = join_cols_str,
1182
+ )
1183
+ for base_query in base_queries
1184
+ ]
817
1185
 
818
-
819
1186
 
820
1187
  def get_null_replacement(typ: str, flavor: str) -> str:
821
1188
  """
@@ -1030,3 +1397,57 @@ def format_cte_subquery(
1030
1397
  + (f' AS {quoted_sub_name}' if flavor != 'oracle' else '') + """
1031
1398
  """
1032
1399
  )
1400
+
1401
+
1402
+ def session_execute(
1403
+ session: 'sqlalchemy.orm.session.Session',
1404
+ queries: Union[List[str], str],
1405
+ with_results: bool = False,
1406
+ debug: bool = False,
1407
+ ) -> Union[mrsm.SuccessTuple, Tuple[mrsm.SuccessTuple, List['sqlalchemy.sql.ResultProxy']]]:
1408
+ """
1409
+ Similar to `SQLConnector.exec_queries()`, execute a list of queries
1410
+ and roll back when one fails.
1411
+
1412
+ Parameters
1413
+ ----------
1414
+ session: sqlalchemy.orm.session.Session
1415
+ A SQLAlchemy session representing a transaction.
1416
+
1417
+ queries: Union[List[str], str]
1418
+ A query or list of queries to be executed.
1419
+ If a query fails, roll back the session.
1420
+
1421
+ with_results: bool, default False
1422
+ If `True`, return a list of result objects.
1423
+
1424
+ Returns
1425
+ -------
1426
+ A `SuccessTuple` indicating the queries were successfully executed.
1427
+ If `with_results`, return the `SuccessTuple` and a list of results.
1428
+ """
1429
+ sqlalchemy = mrsm.attempt_import('sqlalchemy')
1430
+ if not isinstance(queries, list):
1431
+ queries = [queries]
1432
+ successes, msgs, results = [], [], []
1433
+ for query in queries:
1434
+ query_text = sqlalchemy.text(query)
1435
+ fail_msg = f"Failed to execute queries."
1436
+ try:
1437
+ result = session.execute(query_text)
1438
+ query_success = result is not None
1439
+ query_msg = "Success" if query_success else fail_msg
1440
+ except Exception as e:
1441
+ query_success = False
1442
+ query_msg = f"{fail_msg}\n{e}"
1443
+ result = None
1444
+ successes.append(query_success)
1445
+ msgs.append(query_msg)
1446
+ results.append(result)
1447
+ if not query_success:
1448
+ session.rollback()
1449
+ break
1450
+ success, msg = all(successes), '\n'.join(msgs)
1451
+ if with_results:
1452
+ return (success, msg), results
1453
+ return success, msg