polars-runtime-compat 1.34.0b2__cp39-abi3-win_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of polars-runtime-compat might be problematic. Click here for more details.

Files changed (203) hide show
  1. _polars_runtime_compat/.gitkeep +0 -0
  2. _polars_runtime_compat/_polars_runtime_compat.pyd +0 -0
  3. polars/__init__.py +528 -0
  4. polars/_cpu_check.py +265 -0
  5. polars/_dependencies.py +355 -0
  6. polars/_plr.py +99 -0
  7. polars/_plr.pyi +2496 -0
  8. polars/_reexport.py +23 -0
  9. polars/_typing.py +478 -0
  10. polars/_utils/__init__.py +37 -0
  11. polars/_utils/async_.py +102 -0
  12. polars/_utils/cache.py +176 -0
  13. polars/_utils/cloud.py +40 -0
  14. polars/_utils/constants.py +29 -0
  15. polars/_utils/construction/__init__.py +46 -0
  16. polars/_utils/construction/dataframe.py +1397 -0
  17. polars/_utils/construction/other.py +72 -0
  18. polars/_utils/construction/series.py +560 -0
  19. polars/_utils/construction/utils.py +118 -0
  20. polars/_utils/convert.py +224 -0
  21. polars/_utils/deprecation.py +406 -0
  22. polars/_utils/getitem.py +457 -0
  23. polars/_utils/logging.py +11 -0
  24. polars/_utils/nest_asyncio.py +264 -0
  25. polars/_utils/parquet.py +15 -0
  26. polars/_utils/parse/__init__.py +12 -0
  27. polars/_utils/parse/expr.py +242 -0
  28. polars/_utils/polars_version.py +19 -0
  29. polars/_utils/pycapsule.py +53 -0
  30. polars/_utils/scan.py +27 -0
  31. polars/_utils/serde.py +63 -0
  32. polars/_utils/slice.py +215 -0
  33. polars/_utils/udfs.py +1251 -0
  34. polars/_utils/unstable.py +63 -0
  35. polars/_utils/various.py +782 -0
  36. polars/_utils/wrap.py +25 -0
  37. polars/api.py +370 -0
  38. polars/catalog/__init__.py +0 -0
  39. polars/catalog/unity/__init__.py +19 -0
  40. polars/catalog/unity/client.py +733 -0
  41. polars/catalog/unity/models.py +152 -0
  42. polars/config.py +1571 -0
  43. polars/convert/__init__.py +25 -0
  44. polars/convert/general.py +1046 -0
  45. polars/convert/normalize.py +261 -0
  46. polars/dataframe/__init__.py +5 -0
  47. polars/dataframe/_html.py +186 -0
  48. polars/dataframe/frame.py +12582 -0
  49. polars/dataframe/group_by.py +1067 -0
  50. polars/dataframe/plotting.py +257 -0
  51. polars/datatype_expr/__init__.py +5 -0
  52. polars/datatype_expr/array.py +56 -0
  53. polars/datatype_expr/datatype_expr.py +304 -0
  54. polars/datatype_expr/list.py +18 -0
  55. polars/datatype_expr/struct.py +69 -0
  56. polars/datatypes/__init__.py +122 -0
  57. polars/datatypes/_parse.py +195 -0
  58. polars/datatypes/_utils.py +48 -0
  59. polars/datatypes/classes.py +1213 -0
  60. polars/datatypes/constants.py +11 -0
  61. polars/datatypes/constructor.py +172 -0
  62. polars/datatypes/convert.py +366 -0
  63. polars/datatypes/group.py +130 -0
  64. polars/exceptions.py +230 -0
  65. polars/expr/__init__.py +7 -0
  66. polars/expr/array.py +964 -0
  67. polars/expr/binary.py +346 -0
  68. polars/expr/categorical.py +306 -0
  69. polars/expr/datetime.py +2620 -0
  70. polars/expr/expr.py +11272 -0
  71. polars/expr/list.py +1408 -0
  72. polars/expr/meta.py +444 -0
  73. polars/expr/name.py +321 -0
  74. polars/expr/string.py +3045 -0
  75. polars/expr/struct.py +357 -0
  76. polars/expr/whenthen.py +185 -0
  77. polars/functions/__init__.py +193 -0
  78. polars/functions/aggregation/__init__.py +33 -0
  79. polars/functions/aggregation/horizontal.py +298 -0
  80. polars/functions/aggregation/vertical.py +341 -0
  81. polars/functions/as_datatype.py +848 -0
  82. polars/functions/business.py +138 -0
  83. polars/functions/col.py +384 -0
  84. polars/functions/datatype.py +121 -0
  85. polars/functions/eager.py +524 -0
  86. polars/functions/escape_regex.py +29 -0
  87. polars/functions/lazy.py +2751 -0
  88. polars/functions/len.py +68 -0
  89. polars/functions/lit.py +210 -0
  90. polars/functions/random.py +22 -0
  91. polars/functions/range/__init__.py +19 -0
  92. polars/functions/range/_utils.py +15 -0
  93. polars/functions/range/date_range.py +303 -0
  94. polars/functions/range/datetime_range.py +370 -0
  95. polars/functions/range/int_range.py +348 -0
  96. polars/functions/range/linear_space.py +311 -0
  97. polars/functions/range/time_range.py +287 -0
  98. polars/functions/repeat.py +301 -0
  99. polars/functions/whenthen.py +353 -0
  100. polars/interchange/__init__.py +10 -0
  101. polars/interchange/buffer.py +77 -0
  102. polars/interchange/column.py +190 -0
  103. polars/interchange/dataframe.py +230 -0
  104. polars/interchange/from_dataframe.py +328 -0
  105. polars/interchange/protocol.py +303 -0
  106. polars/interchange/utils.py +170 -0
  107. polars/io/__init__.py +64 -0
  108. polars/io/_utils.py +317 -0
  109. polars/io/avro.py +49 -0
  110. polars/io/clipboard.py +36 -0
  111. polars/io/cloud/__init__.py +17 -0
  112. polars/io/cloud/_utils.py +80 -0
  113. polars/io/cloud/credential_provider/__init__.py +17 -0
  114. polars/io/cloud/credential_provider/_builder.py +520 -0
  115. polars/io/cloud/credential_provider/_providers.py +618 -0
  116. polars/io/csv/__init__.py +9 -0
  117. polars/io/csv/_utils.py +38 -0
  118. polars/io/csv/batched_reader.py +142 -0
  119. polars/io/csv/functions.py +1495 -0
  120. polars/io/database/__init__.py +6 -0
  121. polars/io/database/_arrow_registry.py +70 -0
  122. polars/io/database/_cursor_proxies.py +147 -0
  123. polars/io/database/_executor.py +578 -0
  124. polars/io/database/_inference.py +314 -0
  125. polars/io/database/_utils.py +144 -0
  126. polars/io/database/functions.py +516 -0
  127. polars/io/delta.py +499 -0
  128. polars/io/iceberg/__init__.py +3 -0
  129. polars/io/iceberg/_utils.py +697 -0
  130. polars/io/iceberg/dataset.py +556 -0
  131. polars/io/iceberg/functions.py +151 -0
  132. polars/io/ipc/__init__.py +8 -0
  133. polars/io/ipc/functions.py +514 -0
  134. polars/io/json/__init__.py +3 -0
  135. polars/io/json/read.py +101 -0
  136. polars/io/ndjson.py +332 -0
  137. polars/io/parquet/__init__.py +17 -0
  138. polars/io/parquet/field_overwrites.py +140 -0
  139. polars/io/parquet/functions.py +722 -0
  140. polars/io/partition.py +491 -0
  141. polars/io/plugins.py +187 -0
  142. polars/io/pyarrow_dataset/__init__.py +5 -0
  143. polars/io/pyarrow_dataset/anonymous_scan.py +109 -0
  144. polars/io/pyarrow_dataset/functions.py +79 -0
  145. polars/io/scan_options/__init__.py +5 -0
  146. polars/io/scan_options/_options.py +59 -0
  147. polars/io/scan_options/cast_options.py +126 -0
  148. polars/io/spreadsheet/__init__.py +6 -0
  149. polars/io/spreadsheet/_utils.py +52 -0
  150. polars/io/spreadsheet/_write_utils.py +647 -0
  151. polars/io/spreadsheet/functions.py +1323 -0
  152. polars/lazyframe/__init__.py +9 -0
  153. polars/lazyframe/engine_config.py +61 -0
  154. polars/lazyframe/frame.py +8564 -0
  155. polars/lazyframe/group_by.py +669 -0
  156. polars/lazyframe/in_process.py +42 -0
  157. polars/lazyframe/opt_flags.py +333 -0
  158. polars/meta/__init__.py +14 -0
  159. polars/meta/build.py +33 -0
  160. polars/meta/index_type.py +27 -0
  161. polars/meta/thread_pool.py +50 -0
  162. polars/meta/versions.py +120 -0
  163. polars/ml/__init__.py +0 -0
  164. polars/ml/torch.py +213 -0
  165. polars/ml/utilities.py +30 -0
  166. polars/plugins.py +155 -0
  167. polars/py.typed +0 -0
  168. polars/pyproject.toml +96 -0
  169. polars/schema.py +265 -0
  170. polars/selectors.py +3117 -0
  171. polars/series/__init__.py +5 -0
  172. polars/series/array.py +776 -0
  173. polars/series/binary.py +254 -0
  174. polars/series/categorical.py +246 -0
  175. polars/series/datetime.py +2275 -0
  176. polars/series/list.py +1087 -0
  177. polars/series/plotting.py +191 -0
  178. polars/series/series.py +9197 -0
  179. polars/series/string.py +2367 -0
  180. polars/series/struct.py +154 -0
  181. polars/series/utils.py +191 -0
  182. polars/sql/__init__.py +7 -0
  183. polars/sql/context.py +677 -0
  184. polars/sql/functions.py +139 -0
  185. polars/string_cache.py +185 -0
  186. polars/testing/__init__.py +13 -0
  187. polars/testing/asserts/__init__.py +9 -0
  188. polars/testing/asserts/frame.py +231 -0
  189. polars/testing/asserts/series.py +219 -0
  190. polars/testing/asserts/utils.py +12 -0
  191. polars/testing/parametric/__init__.py +33 -0
  192. polars/testing/parametric/profiles.py +107 -0
  193. polars/testing/parametric/strategies/__init__.py +22 -0
  194. polars/testing/parametric/strategies/_utils.py +14 -0
  195. polars/testing/parametric/strategies/core.py +615 -0
  196. polars/testing/parametric/strategies/data.py +452 -0
  197. polars/testing/parametric/strategies/dtype.py +436 -0
  198. polars/testing/parametric/strategies/legacy.py +169 -0
  199. polars/type_aliases.py +24 -0
  200. polars_runtime_compat-1.34.0b2.dist-info/METADATA +31 -0
  201. polars_runtime_compat-1.34.0b2.dist-info/RECORD +203 -0
  202. polars_runtime_compat-1.34.0b2.dist-info/WHEEL +4 -0
  203. polars_runtime_compat-1.34.0b2.dist-info/licenses/LICENSE +1 -0
@@ -0,0 +1,578 @@
1
+ from __future__ import annotations
2
+
3
+ import re
4
+ from collections.abc import Coroutine, Sequence
5
+ from contextlib import suppress
6
+ from inspect import Parameter, signature
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ from polars import functions as F
10
+ from polars._utils.various import parse_version
11
+ from polars.convert import from_arrow
12
+ from polars.datatypes import N_INFER_DEFAULT
13
+ from polars.exceptions import (
14
+ DuplicateError,
15
+ ModuleUpgradeRequiredError,
16
+ UnsuitableSQLError,
17
+ )
18
+ from polars.io.database._arrow_registry import ARROW_DRIVER_REGISTRY
19
+ from polars.io.database._cursor_proxies import ODBCCursorProxy, SurrealDBCursorProxy
20
+ from polars.io.database._inference import dtype_from_cursor_description
21
+ from polars.io.database._utils import _run_async
22
+
23
+ if TYPE_CHECKING:
24
+ import sys
25
+ from collections.abc import Iterable, Iterator
26
+ from types import TracebackType
27
+
28
+ import pyarrow as pa
29
+
30
+ from polars.io.database._arrow_registry import ArrowDriverProperties
31
+
32
+ if sys.version_info >= (3, 11):
33
+ from typing import Self
34
+ else:
35
+ from typing_extensions import Self
36
+
37
+ from sqlalchemy.sql.elements import TextClause
38
+ from sqlalchemy.sql.expression import Selectable
39
+
40
+ from polars import DataFrame
41
+ from polars._typing import ConnectionOrCursor, Cursor, SchemaDict
42
+
43
# Leading SQL keywords that indicate a mutating / non-result-set statement.
# `ConnectionExecutor.execute` rejects queries starting with any of these
# when `select_queries_only=True` (matched against the first word of the
# query after stripping /* ... */ comments).
_INVALID_QUERY_TYPES = {
    "ALTER",
    "ANALYZE",
    "CREATE",
    "DELETE",
    "DROP",
    "INSERT",
    "REPLACE",
    "UPDATE",
    "UPSERT",
    "USE",
    "VACUUM",
}
56
+
57
+
58
class CloseAfterFrameIter:
    """Allows cursor close to be deferred until the last batch is returned."""

    def __init__(self, frames: Any, *, cursor: Cursor) -> None:
        # Keep a handle on both the frame stream and the cursor that
        # produced it; the cursor is only released once iteration ends.
        self._cursor = cursor
        self._frames = frames

    def __iter__(self) -> Iterator[DataFrame]:
        # Hand back each frame in turn; only after the stream is fully
        # exhausted do we attempt to close the underlying cursor.
        for frame in self._frames:
            yield frame

        if hasattr(self._cursor, "close"):
            self._cursor.close()
70
+
71
+
72
class ConnectionExecutor:
    """
    Abstraction for querying databases with user-supplied connection objects.

    Wraps a heterogeneous set of connection/cursor types (DB-API 2.0 cursors,
    SQLAlchemy sync/async connections, sessions and engines, and the local
    ODBC/SurrealDB cursor proxies) behind a uniform execute/to_polars flow.
    """

    # indicate if we can/should close the cursor on scope exit. note that we
    # should never close the underlying connection, or a user-supplied cursor.
    can_close_cursor: bool = False

    def __init__(self, connection: ConnectionOrCursor) -> None:
        # driver name is derived from the root module of the connection's
        # class (eg: "sqlalchemy", "duckdb"); the ODBC proxy is special-cased
        self.driver_name = (
            "arrow_odbc_proxy"
            if isinstance(connection, ODBCCursorProxy)
            else type(connection).__module__.split(".", 1)[0].lower()
        )
        if self.driver_name == "surrealdb":
            connection = SurrealDBCursorProxy(client=connection)

        self.cursor = self._normalise_cursor(connection)
        # populated by `execute`; remains None until a query has run
        self.result: Any = None

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # if we created it and are finished with it, we can
        # close the cursor (but NOT the connection)
        if self._is_alchemy_async(self.cursor):
            from sqlalchemy.ext.asyncio import AsyncConnection

            if isinstance(self.cursor, AsyncConnection):
                # async close must be driven through the event loop helper
                _run_async(self._close_async_cursor())
        elif self.can_close_cursor and hasattr(self.cursor, "close"):
            self.cursor.close()

    def __repr__(self) -> str:
        return f"<{type(self).__name__} module={self.driver_name!r}>"

    @staticmethod
    def _apply_overrides(df: DataFrame, schema_overrides: SchemaDict) -> DataFrame:
        """Apply schema overrides to a DataFrame."""
        existing_schema = df.schema
        # only cast columns that exist AND whose dtype actually differs
        if cast_cols := [
            F.col(col).cast(dtype)
            for col, dtype in schema_overrides.items()
            if col in existing_schema and dtype != existing_schema[col]
        ]:
            df = df.with_columns(cast_cols)
        return df

    async def _close_async_cursor(self) -> None:
        """Close an async SQLAlchemy cursor we created, ignoring never-started contexts."""
        if self.can_close_cursor and hasattr(self.cursor, "close"):
            from sqlalchemy.ext.asyncio.exc import AsyncContextNotStarted

            with suppress(AsyncContextNotStarted):
                await self.cursor.close()

    @staticmethod
    def _check_module_version(module_name: str, minimum_version: str) -> None:
        """
        Check the module version against a minimum required version.

        Raises
        ------
        ModuleUpgradeRequiredError
            If a version attribute is found and is below `minimum_version`.
            Modules exposing no recognisable version attribute pass silently.
        """
        mod = __import__(module_name)
        with suppress(AttributeError):
            module_version: tuple[int, ...] | None = None
            for version_attr in ("__version__", "version"):
                if isinstance(ver := getattr(mod, version_attr, None), str):
                    module_version = parse_version(ver)
                    break
            if module_version and module_version < parse_version(minimum_version):
                msg = f"`read_database` queries require at least {module_name} version {minimum_version}"
                raise ModuleUpgradeRequiredError(msg)

    def _fetch_arrow(
        self,
        driver_properties: ArrowDriverProperties,
        *,
        batch_size: int | None,
        iter_batches: bool,
    ) -> Iterable[pa.RecordBatch]:
        """Yield Arrow data as a generator of one or more RecordBatches or Tables."""
        fetch_batches = driver_properties["fetch_batches"]
        if not iter_batches or fetch_batches is None:
            # single-shot fetch of the entire result set
            fetch_method = driver_properties["fetch_all"]
            yield getattr(self.result, fetch_method)()
        else:
            # batched fetch; some drivers take an explicit batch size arg
            size = [batch_size] if driver_properties["exact_batch_size"] else []
            repeat_batch_calls = driver_properties["repeat_batch_calls"]
            fetchmany_arrow = getattr(self.result, fetch_batches)
            if not repeat_batch_calls:
                # the driver call returns an iterable of batches directly
                yield from fetchmany_arrow(*size)
            else:
                # the driver call must be invoked repeatedly, returning one
                # batch per call until a falsy value signals exhaustion
                while True:
                    arrow = fetchmany_arrow(*size)
                    if not arrow:
                        break
                    yield arrow

    @staticmethod
    def _fetchall_rows(result: Cursor, *, is_alchemy: bool) -> Iterable[Sequence[Any]]:
        """Fetch row data in a single call, returning the complete result set."""
        rows = result.fetchall()
        # normalise driver-specific row objects to plain tuples unless the
        # rows are already list/tuple/dict (or sqlalchemy Row objects)
        return (
            rows
            if rows and (is_alchemy or isinstance(rows[0], (list, tuple, dict)))
            else [tuple(row) for row in rows]
        )

    def _fetchmany_rows(
        self, result: Cursor, *, batch_size: int | None, is_alchemy: bool
    ) -> Iterable[Sequence[Any]]:
        """Fetch row data incrementally, yielding over the complete result set."""
        while True:
            rows = result.fetchmany(batch_size)
            if not rows:
                break
            elif is_alchemy or isinstance(rows[0], (list, tuple, dict)):
                yield rows
            else:
                # normalise driver-specific row objects to plain tuples
                yield [tuple(row) for row in rows]

    def _from_arrow(
        self,
        *,
        batch_size: int | None,
        iter_batches: bool,
        schema_overrides: SchemaDict | None,
        infer_schema_length: int | None,
    ) -> DataFrame | Iterator[DataFrame] | None:
        """
        Return resultset data in Arrow format for frame init.

        Returns None when the driver has no registered Arrow support, or the
        connection reports that Arrow is unsupported (allowing a row-wise
        fallback). Note: `infer_schema_length` is unused on the Arrow path.
        """
        from polars import DataFrame

        try:
            # registry keys are regex patterns matched against the driver name
            for driver, driver_properties in ARROW_DRIVER_REGISTRY.items():
                if re.match(f"^{driver}$", self.driver_name):
                    if ver := driver_properties["minimum_version"]:
                        self._check_module_version(self.driver_name, ver)

                    if iter_batches and (
                        driver_properties["exact_batch_size"] and not batch_size
                    ):
                        msg = f"Cannot set `iter_batches` for {self.driver_name} without also setting a non-zero `batch_size`"
                        raise ValueError(msg)  # noqa: TRY301

                    # lazily convert each fetched batch to a DataFrame
                    frames = (
                        self._apply_overrides(batch, (schema_overrides or {}))
                        if isinstance(batch, DataFrame)
                        else from_arrow(batch, schema_overrides=schema_overrides)
                        for batch in self._fetch_arrow(
                            driver_properties,
                            iter_batches=iter_batches,
                            batch_size=batch_size,
                        )
                    )
                    return frames if iter_batches else next(frames)  # type: ignore[arg-type,return-value]
        except Exception as err:
            # eg: valid turbodbc/snowflake connection, but no arrow support
            # compiled in to the underlying driver (or on this connection)
            arrow_not_supported = (
                "does not support Apache Arrow",
                "Apache Arrow format is not supported",
            )
            if not any(e in str(err) for e in arrow_not_supported):
                raise

        return None

    def _from_rows(
        self,
        *,
        batch_size: int | None,
        iter_batches: bool,
        schema_overrides: SchemaDict | None,
        infer_schema_length: int | None,
    ) -> DataFrame | Iterator[DataFrame] | None:
        """
        Return resultset data row-wise for frame init.

        Returns None if the result object has no `fetchall` method (meaning
        there is nothing we know how to read row-wise from it).
        """
        from polars import DataFrame

        if iter_batches and not batch_size:
            msg = (
                "Cannot set `iter_batches` without also setting a non-zero `batch_size`"
            )
            raise ValueError(msg)

        # an awaitable result means an async driver; resolve it synchronously
        if is_async := isinstance(original_result := self.result, Coroutine):
            self.result = _run_async(self.result)
        try:
            if hasattr(self.result, "fetchall"):
                if is_alchemy := (self.driver_name == "sqlalchemy"):
                    # derive column metadata from the underlying DBAPI cursor
                    # description where available, else sqlalchemy's metadata
                    if hasattr(self.result, "cursor"):
                        cursor_desc = [
                            (d[0], d[1:]) for d in self.result.cursor.description
                        ]
                    elif hasattr(self.result, "_metadata"):
                        cursor_desc = [(k, None) for k in self.result._metadata.keys]
                    else:
                        msg = f"Unable to determine metadata from query result; {self.result!r}"
                        raise ValueError(msg)

                elif hasattr(self.result, "description"):
                    # standard DB-API 2.0: (name, type_code, ...) per column
                    cursor_desc = [(d[0], d[1:]) for d in self.result.description]
                else:
                    cursor_desc = []

                schema_overrides = self._inject_type_overrides(
                    description=cursor_desc,
                    schema_overrides=(schema_overrides or {}),
                )
                result_columns = [nm for nm, _ in cursor_desc]
                frames = (
                    DataFrame(
                        data=rows,
                        schema=result_columns or None,
                        schema_overrides=schema_overrides,
                        infer_schema_length=infer_schema_length,
                        orient="row",
                    )
                    for rows in (
                        self._fetchmany_rows(
                            self.result,
                            batch_size=batch_size,
                            is_alchemy=is_alchemy,
                        )
                        if iter_batches
                        else [self._fetchall_rows(self.result, is_alchemy=is_alchemy)]  # type: ignore[list-item]
                    )
                )
                return frames if iter_batches else next(frames)  # type: ignore[arg-type]
            return None
        finally:
            if is_async:
                # close the original coroutine to avoid "never awaited" warnings
                original_result.close()

    def _inject_type_overrides(
        self,
        description: list[tuple[str, Any]],
        schema_overrides: SchemaDict,
    ) -> SchemaDict:
        """
        Attempt basic dtype inference from a cursor description.

        Raises
        ------
        DuplicateError
            If the same column name appears more than once in `description`.

        Notes
        -----
        This is limited; the `type_code` description attr may contain almost anything,
        from strings or python types to driver-specific codes, classes, enums, etc.
        We currently only do the additional inference from string/python type values.
        (Further refinement will require per-driver module knowledge and lookups).

        NOTE(review): mutates and returns the `schema_overrides` dict passed in
        (user-supplied overrides always take precedence over inferred dtypes).
        """
        dupe_check = set()
        for nm, desc in description:
            if nm in dupe_check:
                msg = f"column {nm!r} appears more than once in the query/result cursor"
                raise DuplicateError(msg)
            elif desc is not None and nm not in schema_overrides:
                dtype = dtype_from_cursor_description(self.cursor, desc)
                if dtype is not None:
                    schema_overrides[nm] = dtype  # type: ignore[index]
            dupe_check.add(nm)

        return schema_overrides

    @staticmethod
    def _is_alchemy_async(conn: Any) -> bool:
        """Check if the given connection is SQLALchemy async."""
        try:
            from sqlalchemy.ext.asyncio import (
                AsyncConnection,
                AsyncSession,
                async_sessionmaker,
            )

            return isinstance(conn, (AsyncConnection, AsyncSession, async_sessionmaker))
        except ImportError:
            # older sqlalchemy without asyncio support (or missing extras)
            return False

    @staticmethod
    def _is_alchemy_engine(conn: Any) -> bool:
        """Check if the given connection is a SQLAlchemy Engine."""
        from sqlalchemy.engine import Engine

        if isinstance(conn, Engine):
            return True
        try:
            from sqlalchemy.ext.asyncio import AsyncEngine

            return isinstance(conn, AsyncEngine)
        except ImportError:
            return False

    @staticmethod
    def _is_alchemy_object(conn: Any) -> bool:
        """Check if the given connection is a SQLAlchemy object (of any kind)."""
        return type(conn).__module__.split(".", 1)[0] == "sqlalchemy"

    @staticmethod
    def _is_alchemy_session(conn: Any) -> bool:
        """Check if the given connection is a SQLAlchemy Session object."""
        from sqlalchemy.ext.asyncio import AsyncSession
        from sqlalchemy.orm import Session, sessionmaker

        if isinstance(conn, (AsyncSession, Session, sessionmaker)):
            return True

        try:
            # async_sessionmaker is imported separately; it may not exist
            # on the installed sqlalchemy version
            from sqlalchemy.ext.asyncio import async_sessionmaker

            return isinstance(conn, async_sessionmaker)
        except ImportError:
            return False

    @staticmethod
    def _is_alchemy_result(result: Any) -> bool:
        """Check if the given result is a SQLAlchemy Result object."""
        try:
            from sqlalchemy.engine import CursorResult

            if isinstance(result, CursorResult):
                return True

            from sqlalchemy.ext.asyncio import AsyncResult

            return isinstance(result, AsyncResult)
        except ImportError:
            return False

    def _normalise_cursor(self, conn: Any) -> Cursor:
        """
        Normalise a connection object such that we have the query executor.

        May update `self.driver_name` (eg: unwrapping sqlalchemy-wrapped
        databricks/duckdb drivers) and set `can_close_cursor` when the cursor
        is one we created ourselves.
        """
        if self.driver_name == "sqlalchemy":
            if self._is_alchemy_session(conn):
                return conn
            else:
                # where possible, use the raw connection to access arrow integration
                if conn.engine.driver == "databricks-sql-python":
                    self.driver_name = "databricks"
                    return conn.engine.raw_connection().cursor()
                elif conn.engine.driver == "duckdb_engine":
                    self.driver_name = "duckdb"
                    return conn
                elif self._is_alchemy_engine(conn):
                    # note: if we create it, we can close it
                    self.can_close_cursor = True
                    return conn.connect()
                else:
                    return conn

        elif hasattr(conn, "cursor"):
            # connection has a dedicated cursor; prefer over direct execute
            # (note: `cursor` may be a factory method or an attribute)
            cursor = cursor() if callable(cursor := conn.cursor) else cursor
            self.can_close_cursor = True
            return cursor

        elif hasattr(conn, "execute"):
            # can execute directly (given cursor, sqlalchemy connection, etc)
            return conn

        msg = f"""Unrecognised connection type "{conn!r}"; no 'execute' or 'cursor' method"""
        raise TypeError(msg)

    async def _sqlalchemy_async_execute(self, query: TextClause, **options: Any) -> Any:
        """Execute a query using an async SQLAlchemy connection."""
        is_session = self._is_alchemy_session(self.cursor)
        # sessions need an explicit transaction context; connections do not
        cursor = self.cursor.begin() if is_session else self.cursor  # type: ignore[attr-defined]
        async with cursor as conn:  # type: ignore[union-attr]
            if is_session and not hasattr(conn, "execute"):
                # a sessionmaker 'begin' yields a transaction; unwrap to session
                conn = conn.session
            result = await conn.execute(query, **options)
            return result

    def _sqlalchemy_setup(
        self, query: str | TextClause | Selectable, options: dict[str, Any]
    ) -> tuple[Any, dict[str, Any], str | TextClause | Selectable]:
        """
        Prepare a query for execution using a SQLAlchemy connection.

        Returns the callable to execute with, the (possibly adjusted) options
        dict, and the (possibly converted) query object.
        """
        from sqlalchemy.orm import Session
        from sqlalchemy.sql import text
        from sqlalchemy.sql.elements import TextClause

        param_key = "parameters"
        cursor_execute = None
        # Session.execute takes "params", not "parameters"; remap if needed
        if (
            isinstance(self.cursor, Session)
            and "parameters" in options
            and "params" not in options
        ):
            options = options.copy()
            options["params"] = options.pop("parameters")
            param_key = "params"

        params = options.get(param_key)
        is_async = self._is_alchemy_async(self.cursor)
        if (
            not is_async
            and isinstance(params, Sequence)
            and hasattr(self.cursor, "exec_driver_sql")
        ):
            # positional param sequences go through the raw driver-level SQL
            # path, which expects a plain string query
            cursor_execute = self.cursor.exec_driver_sql
            if isinstance(query, TextClause):
                query = str(query)
            if isinstance(params, list) and not all(
                isinstance(p, (dict, tuple)) for p in params
            ):
                options[param_key] = tuple(params)

        elif isinstance(query, str):
            # plain strings must be wrapped for sqlalchemy execution
            query = text(query)

        if cursor_execute is None:
            cursor_execute = (
                self._sqlalchemy_async_execute if is_async else self.cursor.execute
            )
        return cursor_execute, options, query

    def execute(
        self,
        query: str | TextClause | Selectable,
        *,
        options: dict[str, Any] | None = None,
        select_queries_only: bool = True,
    ) -> Self:
        """
        Execute a query and reference the result set.

        Raises
        ------
        UnsuitableSQLError
            If `select_queries_only` is set and the query's leading keyword
            is a mutating statement type (INSERT, UPDATE, DROP, etc).
        """
        if select_queries_only and isinstance(query, str):
            # strip /* ... */ comments, then inspect the first word (3+ chars)
            q = re.search(r"\w{3,}", re.sub(r"/\*(.|[\r\n])*?\*/", "", query))
            if (query_type := "" if not q else q.group(0)) in _INVALID_QUERY_TYPES:
                msg = f"{query_type} statements are not valid 'read' queries"
                raise UnsuitableSQLError(msg)

        options = options or {}

        if self._is_alchemy_object(self.cursor):
            cursor_execute, options, query = self._sqlalchemy_setup(query, options)
        else:
            cursor_execute = self.cursor.execute

        # note: some cursor execute methods (eg: sqlite3) only take positional
        # params, hence the slightly convoluted resolution of the 'options' dict
        try:
            params = signature(cursor_execute).parameters
        except ValueError:
            # builtin/C-implemented execute methods may expose no signature
            params = {}  # type: ignore[assignment]

        if not options or any(
            p.kind in (Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
            for p in params.values()
        ):
            result = cursor_execute(query, **options)
        else:
            # positional-only execute: pass option values in signature order
            positional_options = (
                options[o] for o in (params or options) if (not options or o in options)
            )
            result = cursor_execute(query, *positional_options)

        # note: some cursors execute in-place, some access results via a property
        result = self.cursor if (result is None or result is True) else result
        if self.driver_name == "duckdb" and self._is_alchemy_result(result):
            # unwrap to the raw duckdb cursor for its arrow fetch methods
            result = result.cursor

        self.result = result
        return self

    def to_polars(
        self,
        *,
        iter_batches: bool = False,
        batch_size: int | None = None,
        schema_overrides: SchemaDict | None = None,
        infer_schema_length: int | None = N_INFER_DEFAULT,
    ) -> DataFrame | Iterator[DataFrame]:
        """
        Convert the result set to a DataFrame.

        Wherever possible we try to return arrow-native data directly; only
        fall back to initialising with row-level data if no other option.

        Raises
        ------
        RuntimeError
            If called before `execute`.
        NotImplementedError
            If neither the arrow path nor the row-wise path can read the
            result for this driver/connection combination.
        """
        if self.result is None:
            msg = "cannot return a frame before executing a query"
            raise RuntimeError(msg)

        can_close = self.can_close_cursor

        # when streaming batches, the cursor must outlive this method; defer
        # closing it until the returned iterator is exhausted
        if defer_cursor_close := (iter_batches and can_close):
            self.can_close_cursor = False

        for frame_init in (
            self._from_arrow,  # init from arrow-native data (where support exists)
            self._from_rows,  # row-wise fallback (sqlalchemy, dbapi2, pyodbc, etc)
        ):
            frame = frame_init(
                batch_size=batch_size,
                iter_batches=iter_batches,
                schema_overrides=schema_overrides,
                infer_schema_length=infer_schema_length,
            )
            if frame is not None:
                if defer_cursor_close:
                    # wrap so the cursor closes after the final batch is yielded
                    frame = (
                        df
                        for df in CloseAfterFrameIter(
                            frame,
                            cursor=self.result,
                        )
                    )
                return frame

        msg = (
            f"Currently no support for {self.driver_name!r} connection {self.cursor!r}"
        )
        raise NotImplementedError(msg)