SQLAlchemy 2.1.0b1__cp313-cp313-win_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlalchemy/__init__.py +295 -0
- sqlalchemy/connectors/__init__.py +18 -0
- sqlalchemy/connectors/aioodbc.py +161 -0
- sqlalchemy/connectors/asyncio.py +476 -0
- sqlalchemy/connectors/pyodbc.py +250 -0
- sqlalchemy/dialects/__init__.py +62 -0
- sqlalchemy/dialects/_typing.py +30 -0
- sqlalchemy/dialects/mssql/__init__.py +88 -0
- sqlalchemy/dialects/mssql/aioodbc.py +63 -0
- sqlalchemy/dialects/mssql/base.py +4110 -0
- sqlalchemy/dialects/mssql/information_schema.py +285 -0
- sqlalchemy/dialects/mssql/json.py +129 -0
- sqlalchemy/dialects/mssql/provision.py +185 -0
- sqlalchemy/dialects/mssql/pymssql.py +126 -0
- sqlalchemy/dialects/mssql/pyodbc.py +758 -0
- sqlalchemy/dialects/mysql/__init__.py +106 -0
- sqlalchemy/dialects/mysql/_mariadb_shim.py +312 -0
- sqlalchemy/dialects/mysql/aiomysql.py +226 -0
- sqlalchemy/dialects/mysql/asyncmy.py +214 -0
- sqlalchemy/dialects/mysql/base.py +3870 -0
- sqlalchemy/dialects/mysql/cymysql.py +106 -0
- sqlalchemy/dialects/mysql/dml.py +279 -0
- sqlalchemy/dialects/mysql/enumerated.py +277 -0
- sqlalchemy/dialects/mysql/expression.py +146 -0
- sqlalchemy/dialects/mysql/json.py +91 -0
- sqlalchemy/dialects/mysql/mariadb.py +67 -0
- sqlalchemy/dialects/mysql/mariadbconnector.py +330 -0
- sqlalchemy/dialects/mysql/mysqlconnector.py +296 -0
- sqlalchemy/dialects/mysql/mysqldb.py +312 -0
- sqlalchemy/dialects/mysql/provision.py +147 -0
- sqlalchemy/dialects/mysql/pymysql.py +157 -0
- sqlalchemy/dialects/mysql/pyodbc.py +156 -0
- sqlalchemy/dialects/mysql/reflection.py +724 -0
- sqlalchemy/dialects/mysql/reserved_words.py +570 -0
- sqlalchemy/dialects/mysql/types.py +845 -0
- sqlalchemy/dialects/oracle/__init__.py +83 -0
- sqlalchemy/dialects/oracle/base.py +3871 -0
- sqlalchemy/dialects/oracle/cx_oracle.py +1522 -0
- sqlalchemy/dialects/oracle/dictionary.py +507 -0
- sqlalchemy/dialects/oracle/oracledb.py +894 -0
- sqlalchemy/dialects/oracle/provision.py +288 -0
- sqlalchemy/dialects/oracle/types.py +350 -0
- sqlalchemy/dialects/oracle/vector.py +368 -0
- sqlalchemy/dialects/postgresql/__init__.py +171 -0
- sqlalchemy/dialects/postgresql/_psycopg_common.py +193 -0
- sqlalchemy/dialects/postgresql/array.py +534 -0
- sqlalchemy/dialects/postgresql/asyncpg.py +1331 -0
- sqlalchemy/dialects/postgresql/base.py +5729 -0
- sqlalchemy/dialects/postgresql/bitstring.py +327 -0
- sqlalchemy/dialects/postgresql/dml.py +360 -0
- sqlalchemy/dialects/postgresql/ext.py +593 -0
- sqlalchemy/dialects/postgresql/hstore.py +413 -0
- sqlalchemy/dialects/postgresql/json.py +407 -0
- sqlalchemy/dialects/postgresql/named_types.py +521 -0
- sqlalchemy/dialects/postgresql/operators.py +130 -0
- sqlalchemy/dialects/postgresql/pg8000.py +672 -0
- sqlalchemy/dialects/postgresql/pg_catalog.py +344 -0
- sqlalchemy/dialects/postgresql/provision.py +175 -0
- sqlalchemy/dialects/postgresql/psycopg.py +815 -0
- sqlalchemy/dialects/postgresql/psycopg2.py +887 -0
- sqlalchemy/dialects/postgresql/psycopg2cffi.py +61 -0
- sqlalchemy/dialects/postgresql/ranges.py +1002 -0
- sqlalchemy/dialects/postgresql/types.py +388 -0
- sqlalchemy/dialects/sqlite/__init__.py +57 -0
- sqlalchemy/dialects/sqlite/aiosqlite.py +321 -0
- sqlalchemy/dialects/sqlite/base.py +3050 -0
- sqlalchemy/dialects/sqlite/dml.py +279 -0
- sqlalchemy/dialects/sqlite/json.py +89 -0
- sqlalchemy/dialects/sqlite/provision.py +223 -0
- sqlalchemy/dialects/sqlite/pysqlcipher.py +157 -0
- sqlalchemy/dialects/sqlite/pysqlite.py +754 -0
- sqlalchemy/dialects/type_migration_guidelines.txt +145 -0
- sqlalchemy/engine/__init__.py +62 -0
- sqlalchemy/engine/_processors_cy.cp313-win_arm64.pyd +0 -0
- sqlalchemy/engine/_processors_cy.py +92 -0
- sqlalchemy/engine/_result_cy.cp313-win_arm64.pyd +0 -0
- sqlalchemy/engine/_result_cy.py +633 -0
- sqlalchemy/engine/_row_cy.cp313-win_arm64.pyd +0 -0
- sqlalchemy/engine/_row_cy.py +232 -0
- sqlalchemy/engine/_util_cy.cp313-win_arm64.pyd +0 -0
- sqlalchemy/engine/_util_cy.py +136 -0
- sqlalchemy/engine/base.py +3334 -0
- sqlalchemy/engine/characteristics.py +155 -0
- sqlalchemy/engine/create.py +869 -0
- sqlalchemy/engine/cursor.py +2416 -0
- sqlalchemy/engine/default.py +2393 -0
- sqlalchemy/engine/events.py +965 -0
- sqlalchemy/engine/interfaces.py +3465 -0
- sqlalchemy/engine/mock.py +134 -0
- sqlalchemy/engine/processors.py +82 -0
- sqlalchemy/engine/reflection.py +2100 -0
- sqlalchemy/engine/result.py +1932 -0
- sqlalchemy/engine/row.py +397 -0
- sqlalchemy/engine/strategies.py +16 -0
- sqlalchemy/engine/url.py +922 -0
- sqlalchemy/engine/util.py +156 -0
- sqlalchemy/event/__init__.py +26 -0
- sqlalchemy/event/api.py +220 -0
- sqlalchemy/event/attr.py +674 -0
- sqlalchemy/event/base.py +472 -0
- sqlalchemy/event/legacy.py +258 -0
- sqlalchemy/event/registry.py +390 -0
- sqlalchemy/events.py +17 -0
- sqlalchemy/exc.py +922 -0
- sqlalchemy/ext/__init__.py +11 -0
- sqlalchemy/ext/associationproxy.py +2072 -0
- sqlalchemy/ext/asyncio/__init__.py +29 -0
- sqlalchemy/ext/asyncio/base.py +281 -0
- sqlalchemy/ext/asyncio/engine.py +1475 -0
- sqlalchemy/ext/asyncio/exc.py +21 -0
- sqlalchemy/ext/asyncio/result.py +994 -0
- sqlalchemy/ext/asyncio/scoping.py +1667 -0
- sqlalchemy/ext/asyncio/session.py +1993 -0
- sqlalchemy/ext/automap.py +1701 -0
- sqlalchemy/ext/baked.py +559 -0
- sqlalchemy/ext/compiler.py +600 -0
- sqlalchemy/ext/declarative/__init__.py +65 -0
- sqlalchemy/ext/declarative/extensions.py +560 -0
- sqlalchemy/ext/horizontal_shard.py +481 -0
- sqlalchemy/ext/hybrid.py +1877 -0
- sqlalchemy/ext/indexable.py +364 -0
- sqlalchemy/ext/instrumentation.py +450 -0
- sqlalchemy/ext/mutable.py +1081 -0
- sqlalchemy/ext/orderinglist.py +439 -0
- sqlalchemy/ext/serializer.py +185 -0
- sqlalchemy/future/__init__.py +16 -0
- sqlalchemy/future/engine.py +15 -0
- sqlalchemy/inspection.py +174 -0
- sqlalchemy/log.py +283 -0
- sqlalchemy/orm/__init__.py +175 -0
- sqlalchemy/orm/_orm_constructors.py +2694 -0
- sqlalchemy/orm/_typing.py +179 -0
- sqlalchemy/orm/attributes.py +2868 -0
- sqlalchemy/orm/base.py +970 -0
- sqlalchemy/orm/bulk_persistence.py +2152 -0
- sqlalchemy/orm/clsregistry.py +582 -0
- sqlalchemy/orm/collections.py +1568 -0
- sqlalchemy/orm/context.py +3471 -0
- sqlalchemy/orm/decl_api.py +2257 -0
- sqlalchemy/orm/decl_base.py +2304 -0
- sqlalchemy/orm/dependency.py +1306 -0
- sqlalchemy/orm/descriptor_props.py +1183 -0
- sqlalchemy/orm/dynamic.py +300 -0
- sqlalchemy/orm/evaluator.py +379 -0
- sqlalchemy/orm/events.py +3386 -0
- sqlalchemy/orm/exc.py +237 -0
- sqlalchemy/orm/identity.py +302 -0
- sqlalchemy/orm/instrumentation.py +746 -0
- sqlalchemy/orm/interfaces.py +1589 -0
- sqlalchemy/orm/loading.py +1684 -0
- sqlalchemy/orm/mapped_collection.py +557 -0
- sqlalchemy/orm/mapper.py +4406 -0
- sqlalchemy/orm/path_registry.py +814 -0
- sqlalchemy/orm/persistence.py +1789 -0
- sqlalchemy/orm/properties.py +973 -0
- sqlalchemy/orm/query.py +3521 -0
- sqlalchemy/orm/relationships.py +3570 -0
- sqlalchemy/orm/scoping.py +2220 -0
- sqlalchemy/orm/session.py +5389 -0
- sqlalchemy/orm/state.py +1175 -0
- sqlalchemy/orm/state_changes.py +196 -0
- sqlalchemy/orm/strategies.py +3480 -0
- sqlalchemy/orm/strategy_options.py +2544 -0
- sqlalchemy/orm/sync.py +164 -0
- sqlalchemy/orm/unitofwork.py +798 -0
- sqlalchemy/orm/util.py +2435 -0
- sqlalchemy/orm/writeonly.py +694 -0
- sqlalchemy/pool/__init__.py +41 -0
- sqlalchemy/pool/base.py +1514 -0
- sqlalchemy/pool/events.py +372 -0
- sqlalchemy/pool/impl.py +582 -0
- sqlalchemy/py.typed +0 -0
- sqlalchemy/schema.py +72 -0
- sqlalchemy/sql/__init__.py +153 -0
- sqlalchemy/sql/_dml_constructors.py +132 -0
- sqlalchemy/sql/_elements_constructors.py +2147 -0
- sqlalchemy/sql/_orm_types.py +20 -0
- sqlalchemy/sql/_selectable_constructors.py +773 -0
- sqlalchemy/sql/_typing.py +486 -0
- sqlalchemy/sql/_util_cy.cp313-win_arm64.pyd +0 -0
- sqlalchemy/sql/_util_cy.py +127 -0
- sqlalchemy/sql/annotation.py +590 -0
- sqlalchemy/sql/base.py +2602 -0
- sqlalchemy/sql/cache_key.py +1066 -0
- sqlalchemy/sql/coercions.py +1373 -0
- sqlalchemy/sql/compiler.py +8259 -0
- sqlalchemy/sql/crud.py +1807 -0
- sqlalchemy/sql/ddl.py +1928 -0
- sqlalchemy/sql/default_comparator.py +654 -0
- sqlalchemy/sql/dml.py +1974 -0
- sqlalchemy/sql/elements.py +6016 -0
- sqlalchemy/sql/events.py +458 -0
- sqlalchemy/sql/expression.py +170 -0
- sqlalchemy/sql/functions.py +2257 -0
- sqlalchemy/sql/lambdas.py +1443 -0
- sqlalchemy/sql/naming.py +209 -0
- sqlalchemy/sql/operators.py +2897 -0
- sqlalchemy/sql/roles.py +332 -0
- sqlalchemy/sql/schema.py +6560 -0
- sqlalchemy/sql/selectable.py +7497 -0
- sqlalchemy/sql/sqltypes.py +4050 -0
- sqlalchemy/sql/traversals.py +1042 -0
- sqlalchemy/sql/type_api.py +2425 -0
- sqlalchemy/sql/util.py +1495 -0
- sqlalchemy/sql/visitors.py +1157 -0
- sqlalchemy/testing/__init__.py +96 -0
- sqlalchemy/testing/assertions.py +1007 -0
- sqlalchemy/testing/assertsql.py +519 -0
- sqlalchemy/testing/asyncio.py +128 -0
- sqlalchemy/testing/config.py +440 -0
- sqlalchemy/testing/engines.py +478 -0
- sqlalchemy/testing/entities.py +117 -0
- sqlalchemy/testing/exclusions.py +476 -0
- sqlalchemy/testing/fixtures/__init__.py +30 -0
- sqlalchemy/testing/fixtures/base.py +366 -0
- sqlalchemy/testing/fixtures/mypy.py +247 -0
- sqlalchemy/testing/fixtures/orm.py +227 -0
- sqlalchemy/testing/fixtures/sql.py +538 -0
- sqlalchemy/testing/pickleable.py +155 -0
- sqlalchemy/testing/plugin/__init__.py +6 -0
- sqlalchemy/testing/plugin/bootstrap.py +51 -0
- sqlalchemy/testing/plugin/plugin_base.py +828 -0
- sqlalchemy/testing/plugin/pytestplugin.py +892 -0
- sqlalchemy/testing/profiling.py +329 -0
- sqlalchemy/testing/provision.py +596 -0
- sqlalchemy/testing/requirements.py +1973 -0
- sqlalchemy/testing/schema.py +198 -0
- sqlalchemy/testing/suite/__init__.py +19 -0
- sqlalchemy/testing/suite/test_cte.py +237 -0
- sqlalchemy/testing/suite/test_ddl.py +420 -0
- sqlalchemy/testing/suite/test_dialect.py +776 -0
- sqlalchemy/testing/suite/test_insert.py +630 -0
- sqlalchemy/testing/suite/test_reflection.py +3557 -0
- sqlalchemy/testing/suite/test_results.py +660 -0
- sqlalchemy/testing/suite/test_rowcount.py +258 -0
- sqlalchemy/testing/suite/test_select.py +2112 -0
- sqlalchemy/testing/suite/test_sequence.py +317 -0
- sqlalchemy/testing/suite/test_table_via_select.py +686 -0
- sqlalchemy/testing/suite/test_types.py +2253 -0
- sqlalchemy/testing/suite/test_unicode_ddl.py +189 -0
- sqlalchemy/testing/suite/test_update_delete.py +139 -0
- sqlalchemy/testing/util.py +535 -0
- sqlalchemy/testing/warnings.py +52 -0
- sqlalchemy/types.py +76 -0
- sqlalchemy/util/__init__.py +157 -0
- sqlalchemy/util/_collections.py +693 -0
- sqlalchemy/util/_collections_cy.cp313-win_arm64.pyd +0 -0
- sqlalchemy/util/_collections_cy.pxd +8 -0
- sqlalchemy/util/_collections_cy.py +516 -0
- sqlalchemy/util/_has_cython.py +46 -0
- sqlalchemy/util/_immutabledict_cy.cp313-win_arm64.pyd +0 -0
- sqlalchemy/util/_immutabledict_cy.py +240 -0
- sqlalchemy/util/compat.py +287 -0
- sqlalchemy/util/concurrency.py +322 -0
- sqlalchemy/util/cython.py +79 -0
- sqlalchemy/util/deprecations.py +401 -0
- sqlalchemy/util/langhelpers.py +2256 -0
- sqlalchemy/util/preloaded.py +152 -0
- sqlalchemy/util/queue.py +304 -0
- sqlalchemy/util/tool_support.py +201 -0
- sqlalchemy/util/topological.py +120 -0
- sqlalchemy/util/typing.py +711 -0
- sqlalchemy-2.1.0b1.dist-info/METADATA +267 -0
- sqlalchemy-2.1.0b1.dist-info/RECORD +267 -0
- sqlalchemy-2.1.0b1.dist-info/WHEEL +5 -0
- sqlalchemy-2.1.0b1.dist-info/licenses/LICENSE +19 -0
- sqlalchemy-2.1.0b1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,519 @@
|
|
|
1
|
+
# testing/assertsql.py
|
|
2
|
+
# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
|
|
3
|
+
# <see AUTHORS file>
|
|
4
|
+
#
|
|
5
|
+
# This module is part of SQLAlchemy and is released under
|
|
6
|
+
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
|
7
|
+
# mypy: ignore-errors
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import collections
|
|
13
|
+
import contextlib
|
|
14
|
+
import itertools
|
|
15
|
+
import re
|
|
16
|
+
|
|
17
|
+
from .. import event
|
|
18
|
+
from ..engine import url
|
|
19
|
+
from ..engine.default import DefaultDialect
|
|
20
|
+
from ..schema import BaseDDLElement
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class AssertRule:
    """Base class for rules checked against observed SQL executions.

    :class:`.SQLAsserter` feeds each observed execution to
    :meth:`process_statement`; a rule reports success by setting
    ``is_consumed`` and failure by setting ``errormessage``.  When the
    observed executions run out while a rule is still pending,
    :meth:`no_more_statements` is called.
    """

    # True once the rule has matched everything it expects
    is_consumed = False
    # description of the most recent mismatch, or None
    errormessage = None
    # whether the rule swallows the statement it was given; container
    # rules (EachOf) stop offering a statement to further rules when True
    consume_statement = True

    def process_statement(self, execute_observed):
        """Examine one observed execution; the base rule does nothing."""

    def no_more_statements(self):
        """Fail because this rule is still pending with no statements left.

        :raises AssertionError: always.
        """
        # raise explicitly instead of ``assert False`` so the failure is
        # not silently stripped when Python runs with -O
        raise AssertionError(
            "All statements are complete, but pending "
            "assertion rules remain"
        )
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class SQLMatchRule(AssertRule):
    """Common base for rules that match against SQL statement text.

    It adds no behavior of its own on top of :class:`.AssertRule`; it only
    groups the statement-matching rules of this module under one type.
    """
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CursorSQL(SQLMatchRule):
    """Match the next raw DBAPI cursor statement of an execution exactly.

    ``statement`` is compared with the cursor-level SQL string; ``params``,
    when not None, must equal the cursor-level parameters as well.  A match
    removes that cursor statement from the observed execution.
    """

    def __init__(self, statement, params=None, consume_statement=True):
        self.statement = statement
        self.params = params
        self.consume_statement = consume_statement

    def process_statement(self, execute_observed):
        pending = execute_observed.statements
        cursor_stmt = pending[0]

        mismatch = self.statement != cursor_stmt.statement or (
            self.params is not None and self.params != cursor_stmt.parameters
        )

        if not mismatch:
            # matched: drop this cursor statement from the observed batch
            pending.pop(0)
            self.is_consumed = True
            if not pending:
                # nothing left in this execution; stop offering it around
                self.consume_statement = True
            return

        self.consume_statement = True
        expected_sql, expected_params = self.statement, self.params
        self.errormessage = (
            "Testing for exact SQL %s parameters %s received %s %s"
            % (
                expected_sql,
                expected_params,
                cursor_stmt.statement,
                cursor_stmt.parameters,
            )
        )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class CompiledSQL(SQLMatchRule):
    """Match an execution by recompiling its statement against a dialect.

    The observed ``clauseelement`` is compiled again (by default against a
    :class:`.DefaultDialect` configured in :meth:`_compile_dialect`) and the
    resulting SQL string is compared with the expected ``statement``; both
    have newlines and tabs removed first.  ``params``, when given, is a
    *partial* match: each expected parameter dict must pair with a distinct
    received parameter set that contains all of its keys with equal values,
    and the number of expected and received sets must agree.
    """

    def __init__(
        self, statement, params=None, dialect="default", enable_returning=True
    ):
        # expected SQL string
        self.statement = statement
        # None, a dict, a list of dicts, or a callable taking the execution
        # context and returning one of those; see _all_params()
        self.params = params
        # "default" selects a configured DefaultDialect; any other value is
        # passed to url.URL.create() to locate a dialect
        self.dialect = dialect
        # only consulted when self.dialect == "default"
        self.enable_returning = enable_returning

    def _compare_sql(self, execute_observed, received_statement):
        """Return True if the expected statement, with newlines and tabs
        removed, equals ``received_statement`` exactly."""
        stmt = re.sub(r"[\n\t]", "", self.statement)
        return received_statement == stmt

    def _compile_dialect(self, execute_observed):
        """Return the dialect the observed statement is recompiled against."""
        if self.dialect == "default":
            dialect = DefaultDialect()
            # this is currently what tests are expecting
            # dialect.supports_default_values = True
            dialect.supports_default_metavalue = True

            if self.enable_returning:
                dialect.insert_returning = dialect.update_returning = (
                    dialect.delete_returning
                ) = True
                dialect.use_insertmanyvalues = True
                dialect.supports_multivalues_insert = True
                dialect.update_returning_multifrom = True
                dialect.delete_returning_multifrom = True
                # dialect.favor_returning_over_lastrowid = True
                # dialect.insert_null_pk_still_autoincrements = True

                # this is calculated but we need it to be True for this
                # to look like all the current RETURNING dialects
                assert dialect.insert_executemany_returning

            return dialect
        else:
            # the result of get_dialect() is called to obtain an instance
            return url.URL.create(self.dialect).get_dialect()()

    def _received_statement(self, execute_observed):
        """reconstruct the statement and params in terms
        of a target dialect, which for CompiledSQL is just DefaultDialect.

        Returns a 2-tuple of (SQL string with newlines/tabs removed,
        list of parameter dicts built by ``compiled.construct_params()``).
        """

        context = execute_observed.context
        compare_dialect = self._compile_dialect(execute_observed)

        # received_statement runs a full compile(). we should not need to
        # consider extracted_parameters; if we do this indicates some state
        # is being sent from a previous cached query, which some misbehaviors
        # in the ORM can cause, see #6881
        cache_key = None  # execute_observed.context.compiled.cache_key
        extracted_parameters = (
            None  # execute_observed.context.extracted_parameters
        )

        # honor a schema_translate_map execution option, if one was used
        if "schema_translate_map" in context.execution_options:
            map_ = context.execution_options["schema_translate_map"]
        else:
            map_ = None

        # DDL constructs are compiled without the DML-specific arguments
        if isinstance(execute_observed.clauseelement, BaseDDLElement):
            compiled = execute_observed.clauseelement.compile(
                dialect=compare_dialect,
                schema_translate_map=map_,
            )
        else:
            compiled = execute_observed.clauseelement.compile(
                cache_key=cache_key,
                dialect=compare_dialect,
                column_keys=context.compiled.column_keys,
                for_executemany=context.compiled.for_executemany,
                schema_translate_map=map_,
            )
        _received_statement = re.sub(r"[\n\t]", "", str(compiled))
        parameters = execute_observed.parameters

        # no explicit parameters: one set built from the statement alone;
        # otherwise one set per parameter dict that was passed in
        if not parameters:
            _received_parameters = [
                compiled.construct_params(
                    extracted_parameters=extracted_parameters
                )
            ]
        else:
            _received_parameters = [
                compiled.construct_params(
                    m, extracted_parameters=extracted_parameters
                )
                for m in parameters
            ]

        return _received_statement, _received_parameters

    def process_statement(self, execute_observed):
        """Compare one observed execution; sets ``is_consumed`` on a match,
        otherwise sets ``errormessage``."""
        context = execute_observed.context

        _received_statement, _received_parameters = self._received_statement(
            execute_observed
        )
        params = self._all_params(context)

        equivalent = self._compare_sql(execute_observed, _received_statement)

        if equivalent:
            if params is not None:
                all_params = list(params)
                all_received = list(_received_parameters)
                # pair each expected dict with the first unused received
                # set that contains all of its key/value pairs
                while all_params and all_received:
                    param = dict(all_params.pop(0))

                    for idx, received in enumerate(list(all_received)):
                        # do a positive compare only
                        for param_key in param:
                            # a key in param did not match current
                            # 'received'
                            if (
                                param_key not in received
                                or received[param_key] != param[param_key]
                            ):
                                break
                        else:
                            # all keys in param matched 'received';
                            # onto next param
                            del all_received[idx]
                            break
                    else:
                        # param did not match any entry
                        # in all_received
                        equivalent = False
                        break
                # leftovers on either side mean the counts differ
                if all_params or all_received:
                    equivalent = False

        if equivalent:
            self.is_consumed = True
            self.errormessage = None
        else:
            self.errormessage = self._failure_message(
                execute_observed, params
            ) % {
                "received_statement": _received_statement,
                "received_parameters": _received_parameters,
            }

    def _all_params(self, context):
        """Return the expected parameters as a list, or None to skip the
        parameter check (a falsy ``self.params`` disables it)."""
        if self.params:
            if callable(self.params):
                params = self.params(context)
            else:
                params = self.params
            # a single dict is treated as a one-element list
            if not isinstance(params, list):
                params = [params]
            return params
        else:
            return None

    def _failure_message(self, execute_observed, expected_params):
        """Return a %-template for the mismatch message.

        Expected values are interpolated immediately (with literal ``%``
        escaped); the ``received_statement`` / ``received_parameters`` keys
        are left for the caller to fill in.
        """
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self.statement.replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
class RegexSQL(CompiledSQL):
    """Variant of :class:`.CompiledSQL` whose expected statement is a regex.

    The recompiled SQL must match ``regex`` from its start (``re.match``);
    parameter checking is inherited unchanged.
    """

    def __init__(
        self, regex, params=None, dialect="default", enable_returning=False
    ):
        SQLMatchRule.__init__(self)
        pattern = re.compile(regex)
        self.regex = pattern
        # keep the source text for failure messages
        self.orig_regex = regex
        self.params = params
        self.dialect = dialect
        self.enable_returning = enable_returning

    def _failure_message(self, execute_observed, expected_params):
        escaped_regex = self.orig_regex.replace("%", "%%")
        escaped_params = repr(expected_params).replace("%", "%%")
        template = (
            "Testing for compiled statement ~%r partial params %s, "
            "received %%(received_statement)r with params "
            "%%(received_parameters)r"
        )
        return template % (escaped_regex, escaped_params)

    def _compare_sql(self, execute_observed, received_statement):
        return self.regex.match(received_statement) is not None
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class DialectSQL(CompiledSQL):
    """Match an execution using the dialect it actually ran on.

    The observed statement is recompiled against
    ``execute_observed.context.dialect``.  The expected ``statement`` is
    written with ``:name`` placeholders and is rewritten into that
    dialect's DBAPI ``paramstyle`` before being compared.
    """

    def _compile_dialect(self, execute_observed):
        """Recompile against the dialect the statement was executed with."""
        return execute_observed.context.dialect

    def _compare_no_space(self, real_stmt, received_stmt):
        """Return True if ``real_stmt`` with newlines/tabs removed equals
        ``received_stmt``."""
        stmt = re.sub(r"[\n\t]", "", real_stmt)
        return received_stmt == stmt

    def _received_statement(self, execute_observed):
        """Return the recompiled SQL and the context's compiled parameters.

        :raises AssertionError: if the recompiled SQL does not match any
            cursor statement recorded for this execution.
        """
        received_stmt, _ = super()._received_statement(execute_observed)

        # TODO: why do we need this part?
        for real_stmt in execute_observed.statements:
            if self._compare_no_space(
                real_stmt.context.statement, received_stmt
            ):
                break
        else:
            raise AssertionError(
                "Can't locate compiled statement %r in list of "
                "statements actually invoked" % received_stmt
            )

        return received_stmt, execute_observed.context.compiled_parameters

    def _dialect_adjusted_statement(self, dialect):
        """Rewrite ``:name`` placeholders of the expected statement into
        ``dialect.paramstyle``.

        "pyformat" becomes ``%(name)s``; "qmark" becomes ``?``; "format"
        becomes ``%s``; "numeric" / "numeric_dollar" become ``:1`` /
        ``$1`` numbered in order of appearance.  Any other paramstyle
        (e.g. "named") already uses ``:name`` and is left unchanged.
        """
        paramstyle = dialect.paramstyle
        stmt = re.sub(r"[\n\t]", "", self.statement)

        # temporarily escape out PG double colons
        stmt = stmt.replace("::", "!!")

        if paramstyle == "pyformat":
            stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
        else:
            # positional params
            repl = None
            if paramstyle == "qmark":
                repl = "?"
            elif paramstyle == "format":
                repl = r"%s"
            elif paramstyle.startswith("numeric"):
                counter = itertools.count(1)

                num_identifier = "$" if paramstyle == "numeric_dollar" else ":"

                def repl(m):
                    return f"{num_identifier}{next(counter)}"

            # "named" (and any other non-positional style) keeps the
            # ":name" form; passing repl=None to re.sub() would raise
            # TypeError instead of leaving the statement as-is
            if repl is not None:
                stmt = re.sub(r":([\w_]+)", repl, stmt)

        # put them back
        stmt = stmt.replace("!!", "::")

        return stmt

    def _compare_sql(self, execute_observed, received_statement):
        """Compare against the expected statement adjusted to the
        execution's paramstyle."""
        stmt = self._dialect_adjusted_statement(
            execute_observed.context.dialect
        )
        return received_statement == stmt

    def _failure_message(self, execute_observed, expected_params):
        """Return a %-template for the mismatch message, showing the
        paramstyle-adjusted expected statement."""
        return (
            "Testing for compiled statement\n%r partial params %s, "
            "received\n%%(received_statement)r with params "
            "%%(received_parameters)r"
            % (
                self._dialect_adjusted_statement(
                    execute_observed.context.dialect
                ).replace("%", "%%"),
                repr(expected_params).replace("%", "%%"),
            )
        )
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
class CountStatements(AssertRule):
    """Assert that exactly ``count`` executions were observed.

    Every observed execution is counted; the rule never marks itself
    consumed, so the comparison happens in :meth:`no_more_statements`.
    """

    def __init__(self, count):
        self.count = count
        self._statement_count = 0

    def process_statement(self, execute_observed):
        """Count one observed execution."""
        self._statement_count += 1

    def no_more_statements(self):
        """Compare the observed count with the expected one.

        :raises AssertionError: if the counts differ.
        """
        # explicit raise so the check survives python -O
        if self.count != self._statement_count:
            raise AssertionError(
                "desired statement count %d does not match %d"
                % (self.count, self._statement_count)
            )
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
class AllOf(AssertRule):
    """Require every sub-rule to match, in any order.

    Each statement is offered to the pending sub-rules until one consumes
    it or reports that it is still in progress.  The group is consumed once
    no sub-rules remain.
    """

    def __init__(self, *rules):
        self.rules = set(rules)

    def process_statement(self, execute_observed):
        for candidate in list(self.rules):
            candidate.errormessage = None
            candidate.process_statement(execute_observed)

            if candidate.is_consumed:
                self.rules.discard(candidate)
                if not self.rules:
                    self.is_consumed = True
                return

            if not candidate.errormessage:
                # rule is not done yet
                self.errormessage = None
                return

        # every pending rule rejected the statement
        self.errormessage = list(self.rules)[0].errormessage
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
class EachOf(AssertRule):
    """Require the sub-rules to match one after another, in order.

    A statement goes to the first pending sub-rule.  It is passed on to the
    following sub-rule only while the current one reports
    ``consume_statement`` as False.
    """

    def __init__(self, *rules):
        self.rules = list(rules)

    def process_statement(self, execute_observed):
        pending = self.rules

        if not pending:
            # nothing to match: done, and let an enclosing rule offer this
            # same statement to whatever comes next
            self.is_consumed = True
            self.consume_statement = False

        while pending:
            current = pending[0]
            current.process_statement(execute_observed)
            if current.is_consumed:
                del pending[0]
            elif current.errormessage:
                self.errormessage = current.errormessage
            if current.consume_statement:
                break

        if not pending:
            self.is_consumed = True

    def no_more_statements(self):
        if not self.rules:
            return
        head = self.rules[0]
        if head.is_consumed:
            super().no_more_statements()
        else:
            head.no_more_statements()
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
class Conditional(EachOf):
    """An ordered rule group whose rules depend on ``condition``.

    ``rules`` are used when ``condition`` is truthy, ``else_rules``
    otherwise.
    """

    def __init__(self, condition, rules, else_rules):
        selected = rules if condition else else_rules
        super().__init__(*selected)
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
class Or(AllOf):
    """Succeed as soon as any one of the sub-rules consumes a statement."""

    def process_statement(self, execute_observed):
        for alternative in self.rules:
            alternative.process_statement(execute_observed)
            if alternative.is_consumed:
                self.is_consumed = True
                return

        # no alternative matched; report the first one's error
        self.errormessage = list(self.rules)[0].errormessage
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
class SQLExecuteObserved:
    """One statement execution, as seen by :func:`assert_engine`.

    ``parameters`` is ``multiparams`` when that is non-empty, else a
    one-element list holding ``params`` when that is non-empty, else an
    empty list.  ``statements`` collects the cursor-level executions that
    belong to this execution.
    """

    def __init__(self, context, clauseelement, multiparams, params):
        self.context = context
        self.clauseelement = clauseelement
        self.parameters = multiparams or ([params] if params else [])
        self.statements = []

    def __repr__(self):
        return f"{self.statements}"
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
class SQLCursorExecuteObserved(
    collections.namedtuple(
        "SQLCursorExecuteObserved",
        "statement parameters context executemany",
    )
):
    """One DBAPI cursor execution, as passed to ``after_cursor_execute``."""
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
class SQLAsserter:
    """Collects observed executions and checks them against rules.

    While recording, :class:`.SQLExecuteObserved` objects are appended to
    ``accumulated``.  :meth:`_close` freezes them into ``_final``, after
    which :meth:`assert_` may be called.
    """

    def __init__(self):
        self.accumulated = []

    def _close(self):
        """Stop accumulating; move the recorded executions to ``_final``."""
        self._final = self.accumulated
        del self.accumulated

    def assert_(self, *rules):
        """Check the recorded executions against ``rules``, in order.

        :raises AssertionError: if a rule reports a mismatch, if executions
            remain after all rules are satisfied, or if rules remain after
            all executions are processed.
        """
        rule = EachOf(*rules)

        observed = list(self._final)
        while observed:
            statement = observed.pop(0)
            rule.process_statement(statement)
            if rule.is_consumed:
                break
            elif rule.errormessage:
                # explicit raise so failures are not stripped by python -O
                raise AssertionError(rule.errormessage)
        if observed:
            raise AssertionError(
                "Additional SQL statements remain:\n%s" % observed
            )
        elif not rule.is_consumed:
            rule.no_more_statements()
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
@contextlib.contextmanager
def assert_engine(engine):
    """Record SQL executed on ``engine`` while the block runs.

    Yields a :class:`.SQLAsserter`.  On exit, the event listeners are
    removed and the asserter is closed so that its ``assert_()`` method can
    be used.
    """
    asserter = SQLAsserter()

    # last (clauseelement, multiparams, params) seen by before_execute
    orig = []

    @event.listens_for(engine, "before_execute")
    def connection_execute(
        conn, clauseelement, multiparams, params, execution_options
    ):
        # grab the original statement + params before any cursor
        # execution
        orig[:] = clauseelement, multiparams, params

    @event.listens_for(engine, "after_cursor_execute")
    def cursor_execute(
        conn, cursor, statement, parameters, context, executemany
    ):
        if not context:
            return

        # cursor executions sharing one context are grouped together
        recorded = asserter.accumulated
        same_context = bool(recorded) and recorded[-1].context is context
        if same_context:
            observed = recorded[-1]
        else:
            observed = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
            recorded.append(observed)

        cursor_obs = SQLCursorExecuteObserved(
            statement, parameters, context, executemany
        )
        observed.statements.append(cursor_obs)

    try:
        yield asserter
    finally:
        event.remove(engine, "after_cursor_execute", cursor_execute)
        event.remove(engine, "before_execute", connection_execute)
        asserter._close()
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# testing/asyncio.py
|
|
2
|
+
# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
|
|
3
|
+
# <see AUTHORS file>
|
|
4
|
+
#
|
|
5
|
+
# This module is part of SQLAlchemy and is released under
|
|
6
|
+
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
|
7
|
+
# mypy: ignore-errors
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# functions and wrappers to run tests, fixtures, provisioning and
|
|
11
|
+
# setup/teardown in an asyncio event loop, conditionally based on the
|
|
12
|
+
# current DB driver being used for a test.
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
from functools import wraps
|
|
17
|
+
import inspect
|
|
18
|
+
|
|
19
|
+
from . import config
|
|
20
|
+
from ..util.concurrency import _AsyncUtil
|
|
21
|
+
|
|
22
|
+
# Module-wide switch consulted by _assume_async, _maybe_async_provisioning
# and _maybe_async: when False they call the target function directly
# rather than going through _async_util.run_in_greenlet().
# May be set to False if the
# --disable-asyncio flag is passed to the test runner.
ENABLE_ASYNCIO = True

# Shared runner used by every helper in this module.
_async_util = _AsyncUtil()  # it has lazy init so just always create one
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _shutdown():
    """Close the module's shared async runner once the test finishes."""
    runner = _async_util
    runner.close()
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _run_coroutine_function(fn, *args, **kwargs):
    """Run ``fn`` through the shared runner's ``run()`` and return its result.

    Judging by the name, ``fn`` is presumably a coroutine function.
    """
    runner = _async_util
    return runner.run(fn, *args, **kwargs)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _assume_async(fn, *args, **kwargs):
    """Run a function in an asyncio loop unconditionally.

    Used for provisioning features such as testing a database connection
    for server info.  When ``ENABLE_ASYNCIO`` is off, ``fn`` is simply
    called directly.

    Note that for blocking IO database drivers, this means they block the
    event loop.
    """
    if ENABLE_ASYNCIO:
        return _async_util.run_in_greenlet(fn, *args, **kwargs)
    return fn(*args, **kwargs)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _maybe_async_provisioning(fn, *args, **kwargs):
    """Run a function in an asyncio loop if any configured driver might need it.

    Intended for provisioning work that happens before a specific database
    driver is selected.  If any configured driver is async
    (``config.any_async``), the call is routed through the event loop so an
    async driver used for provisioning does not fail.

    Note that for blocking IO database drivers, this means they block the
    event loop.
    """
    if ENABLE_ASYNCIO and config.any_async:
        return _async_util.run_in_greenlet(fn, *args, **kwargs)
    return fn(*args, **kwargs)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _maybe_async(fn, *args, **kwargs):
    """Run a function in an asyncio loop if the currently selected driver is
    async.

    Used for test setup/teardown and for tests themselves, where the current
    DB driver (``config._current``) is known.
    """
    if ENABLE_ASYNCIO and config._current.is_async:
        return _async_util.run_in_greenlet(fn, *args, **kwargs)
    return fn(*args, **kwargs)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _maybe_async_wrapper(fn):
    """Wrap ``fn`` so each call goes through :func:`_maybe_async`.

    Generator functions are supported: every step of the generator is
    advanced via ``_maybe_async`` and its values are re-yielded.  This is
    currently used for pytest fixtures that support generator use.
    """
    if not inspect.isgeneratorfunction(fn):

        @wraps(fn)
        def wrap_fixture(*args, **kwargs):
            return _maybe_async(fn, *args, **kwargs)

        return wrap_fixture

    exhausted = object()

    def advance(gen):
        # StopIteration can't be raised in an awaitable, so report the end
        # of the generator through a private sentinel instead
        try:
            return next(gen)
        except StopIteration:
            return exhausted

    @wraps(fn)
    def wrap_fixture(*args, **kwargs):
        gen = fn(*args, **kwargs)
        value = _maybe_async(advance, gen)
        while value is not exhausted:
            yield value
            value = _maybe_async(advance, gen)

    return wrap_fixture
|