execsql2 2.1.2__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (94)
  1. execsql/cli/__init__.py +436 -0
  2. execsql/cli/dsn.py +86 -0
  3. execsql/cli/help.py +140 -0
  4. execsql/{cli.py → cli/run.py} +14 -589
  5. execsql/config.py +65 -1
  6. execsql/db/access.py +27 -15
  7. execsql/db/base.py +328 -215
  8. execsql/db/dsn.py +10 -5
  9. execsql/db/duckdb.py +6 -2
  10. execsql/db/factory.py +21 -0
  11. execsql/db/firebird.py +27 -19
  12. execsql/db/mysql.py +12 -7
  13. execsql/db/oracle.py +15 -11
  14. execsql/db/postgres.py +31 -16
  15. execsql/db/sqlite.py +15 -11
  16. execsql/db/sqlserver.py +16 -5
  17. execsql/exceptions.py +25 -7
  18. execsql/exporters/base.py +12 -1
  19. execsql/exporters/delimited.py +80 -35
  20. execsql/exporters/duckdb.py +6 -2
  21. execsql/exporters/feather.py +10 -6
  22. execsql/exporters/html.py +89 -69
  23. execsql/exporters/json.py +52 -45
  24. execsql/exporters/latex.py +37 -27
  25. execsql/exporters/ods.py +32 -11
  26. execsql/exporters/parquet.py +5 -2
  27. execsql/exporters/pretty.py +16 -9
  28. execsql/exporters/raw.py +22 -16
  29. execsql/exporters/sqlite.py +6 -2
  30. execsql/exporters/templates.py +39 -21
  31. execsql/exporters/values.py +26 -20
  32. execsql/exporters/xls.py +30 -11
  33. execsql/exporters/xml.py +31 -13
  34. execsql/exporters/zip.py +15 -0
  35. execsql/importers/base.py +6 -4
  36. execsql/importers/csv.py +8 -6
  37. execsql/importers/feather.py +6 -4
  38. execsql/importers/ods.py +6 -4
  39. execsql/importers/xls.py +6 -4
  40. execsql/metacommands/__init__.py +208 -1548
  41. execsql/metacommands/conditions.py +101 -27
  42. execsql/metacommands/control.py +8 -4
  43. execsql/metacommands/data.py +6 -6
  44. execsql/metacommands/debug.py +6 -2
  45. execsql/metacommands/dispatch.py +2011 -0
  46. execsql/metacommands/io.py +67 -1310
  47. execsql/metacommands/io_export.py +442 -0
  48. execsql/metacommands/io_fileops.py +287 -0
  49. execsql/metacommands/io_import.py +398 -0
  50. execsql/metacommands/io_write.py +248 -0
  51. execsql/metacommands/prompt.py +22 -66
  52. execsql/metacommands/system.py +7 -2
  53. execsql/models.py +7 -0
  54. execsql/parser.py +10 -0
  55. execsql/py.typed +0 -0
  56. execsql/script/__init__.py +95 -0
  57. execsql/script/control.py +162 -0
  58. execsql/{script.py → script/engine.py} +184 -402
  59. execsql/script/variables.py +281 -0
  60. execsql/types.py +49 -20
  61. execsql/utils/auth.py +2 -0
  62. execsql/utils/crypto.py +4 -6
  63. execsql/utils/datetime.py +1 -0
  64. execsql/utils/errors.py +11 -0
  65. execsql/utils/fileio.py +33 -8
  66. execsql/utils/gui.py +46 -0
  67. execsql/utils/mail.py +7 -17
  68. execsql/utils/numeric.py +2 -0
  69. execsql/utils/regex.py +9 -0
  70. execsql/utils/strings.py +16 -0
  71. execsql/utils/timer.py +2 -0
  72. execsql2-2.4.0.data/data/execsql2_extras/README.md +65 -0
  73. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/execsql.conf +1 -1
  74. {execsql2-2.1.2.dist-info → execsql2-2.4.0.dist-info}/METADATA +13 -6
  75. execsql2-2.4.0.dist-info/RECORD +108 -0
  76. execsql2-2.1.2.data/data/execsql2_extras/READ_ME.rst +0 -127
  77. execsql2-2.1.2.dist-info/RECORD +0 -96
  78. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/config_settings.sqlite +0 -0
  79. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/example_config_prompt.sql +0 -0
  80. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/make_config_db.sql +0 -0
  81. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/md_compare.sql +0 -0
  82. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/md_glossary.sql +0 -0
  83. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/md_upsert.sql +0 -0
  84. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/pg_compare.sql +0 -0
  85. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/pg_glossary.sql +0 -0
  86. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/pg_upsert.sql +0 -0
  87. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/script_template.sql +0 -0
  88. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/ss_compare.sql +0 -0
  89. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/ss_glossary.sql +0 -0
  90. {execsql2-2.1.2.data → execsql2-2.4.0.data}/data/execsql2_extras/ss_upsert.sql +0 -0
  91. {execsql2-2.1.2.dist-info → execsql2-2.4.0.dist-info}/WHEEL +0 -0
  92. {execsql2-2.1.2.dist-info → execsql2-2.4.0.dist-info}/entry_points.txt +0 -0
  93. {execsql2-2.1.2.dist-info → execsql2-2.4.0.dist-info}/licenses/LICENSE.txt +0 -0
  94. {execsql2-2.1.2.dist-info → execsql2-2.4.0.dist-info}/licenses/NOTICE +0 -0
execsql/db/base.py CHANGED
@@ -14,7 +14,11 @@ open :class:`Database` instances and tracks which connection is currently
14
14
  active. It is the canonical ``_state.dbs`` object.
15
15
  """
16
16
 
17
+ import contextlib
18
+ import datetime
17
19
  import re
20
+ from abc import ABC, abstractmethod
21
+ from decimal import Decimal
18
22
  from typing import Any
19
23
  from collections.abc import Callable, Generator, Iterator
20
24
 
@@ -22,11 +26,40 @@ from execsql.exceptions import ErrInfo
22
26
  from execsql.utils.errors import exception_desc
23
27
  import execsql.state as _state
24
28
 
29
+ __all__ = ["Database", "DatabasePool"]
25
30
 
26
- class Database:
31
+
32
+ def _default_dt_cast() -> dict[type, Callable]:
33
+ """Build the default type-cast mapping used by all database backends."""
34
+ from execsql.types import DT_Boolean, DT_Timestamp, DT_Date, DT_Decimal
35
+
36
+ return {
37
+ int: int,
38
+ float: float,
39
+ str: str,
40
+ bool: DT_Boolean().from_data,
41
+ datetime.datetime: DT_Timestamp().from_data,
42
+ datetime.date: DT_Date().from_data,
43
+ Decimal: DT_Decimal().from_data,
44
+ bytearray: bytearray,
45
+ }
46
+
47
+
48
+ class Database(ABC):
27
49
  """Abstract base class for all database connections."""
28
50
 
29
- dt_cast: dict[type, Callable] = {} # populated per-subclass or in __init__
51
+ _dt_cast: dict[type, Callable] | None = None
52
+
53
+ @property
54
+ def dt_cast(self) -> dict[type, Callable]:
55
+ """Return the type-cast mapping, initialising it lazily on first access."""
56
+ if self._dt_cast is None:
57
+ self._dt_cast = _default_dt_cast()
58
+ return self._dt_cast
59
+
60
+ @dt_cast.setter
61
+ def dt_cast(self, value: dict[type, Callable]) -> None:
62
+ self._dt_cast = value
30
63
 
31
64
  def __init__(
32
65
  self,
@@ -57,22 +90,38 @@ class Database:
57
90
  )
58
91
 
59
92
  def name(self) -> str:
93
+ """Return a human-readable description of this connection (DBMS + server/file)."""
60
94
  if self.server_name:
61
95
  return f"{self.type.dbms_id}(server {self.server_name}; database {self.db_name})"
62
96
  else:
63
97
  return f"{self.type.dbms_id}(file {self.db_name})"
64
98
 
99
+ @abstractmethod
65
100
  def open_db(self) -> None:
66
- from execsql.exceptions import DatabaseNotImplementedError
67
-
68
- raise DatabaseNotImplementedError(self.name(), "open_db")
101
+ """Open the underlying database connection."""
102
+ ...
69
103
 
70
104
  def cursor(self):
105
+ """Return a new cursor, opening the connection first if it has not been opened yet."""
71
106
  if self.conn is None:
72
107
  self.open_db()
73
108
  return self.conn.cursor()
74
109
 
110
+ @contextlib.contextmanager
111
+ def _cursor(self):
112
+ """Context manager that yields a cursor and closes it on exit.
113
+
114
+ Works with any DB-API 2.0 cursor regardless of whether the driver
115
+ natively supports the context manager protocol.
116
+ """
117
+ curs = self.cursor()
118
+ try:
119
+ yield curs
120
+ finally:
121
+ curs.close()
122
+
75
123
  def close(self) -> None:
124
+ """Close the database connection, logging a warning if autocommit is off."""
76
125
  if self.conn:
77
126
  if not self.autocommit:
78
127
  _state.exec_log.log_status_info(
@@ -87,24 +136,26 @@ class Database:
87
136
  return '"' + identifier.replace('"', '""') + '"'
88
137
 
89
138
  def paramsubs(self, paramcount: int) -> str:
139
+ """Return a comma-separated string of *paramcount* parameter placeholders."""
90
140
  return ",".join((self.paramstr,) * paramcount)
91
141
 
92
142
  def execute(self, sql: Any, paramlist: list | None = None) -> None:
93
- # A shortcut to self.cursor().execute() that handles encoding.
94
- # Whether or not encoding is needed depends on the DBMS.
143
+ """Execute *sql* (optionally with *paramlist*), updating ``$LAST_ROWCOUNT``.
144
+
145
+ Rolls back the current transaction and re-raises on any driver error.
146
+ """
95
147
  if type(sql) in (tuple, list):
96
148
  sql = " ".join(sql)
97
149
  try:
98
- curs = self.cursor()
99
- if paramlist is None:
100
- curs.execute(sql)
101
- else:
102
- curs.execute(sql, paramlist)
103
- try:
104
- # DuckDB does not support the 'rowcount' attribute.
105
- _state.subvars.add_substitution("$LAST_ROWCOUNT", curs.rowcount)
106
- except Exception:
107
- pass # Non-critical: some drivers lack rowcount support.
150
+ with self._cursor() as curs:
151
+ if paramlist is None:
152
+ curs.execute(sql)
153
+ else:
154
+ curs.execute(sql, paramlist)
155
+ try:
156
+ _state.subvars.add_substitution("$LAST_ROWCOUNT", curs.rowcount)
157
+ except Exception:
158
+ pass # Non-critical: some drivers lack rowcount support.
108
159
  except Exception:
109
160
  try:
110
161
  self.rollback()
@@ -112,22 +163,26 @@ class Database:
112
163
  pass # Rollback is best-effort after a failed execute.
113
164
  raise
114
165
 
166
+ @abstractmethod
115
167
  def exec_cmd(self, querycommand: str) -> None:
116
- from execsql.exceptions import DatabaseNotImplementedError
117
-
118
- raise DatabaseNotImplementedError(self.name(), "exec_cmd")
168
+ """Execute a stored procedure or function by name."""
169
+ ...
119
170
 
120
171
  def autocommit_on(self) -> None:
172
+ """Enable autocommit mode so each statement is committed immediately."""
121
173
  self.autocommit = True
122
174
 
123
175
  def autocommit_off(self) -> None:
176
+ """Disable autocommit mode, grouping subsequent statements into a transaction."""
124
177
  self.autocommit = False
125
178
 
126
179
  def commit(self) -> None:
180
+ """Commit the current transaction if autocommit is enabled."""
127
181
  if self.conn and self.autocommit:
128
182
  self.conn.commit()
129
183
 
130
184
  def rollback(self) -> None:
185
+ """Roll back the current transaction; swallows errors (best-effort)."""
131
186
  if self.conn:
132
187
  try:
133
188
  self.conn.rollback()
@@ -135,6 +190,7 @@ class Database:
135
190
  pass # Best-effort; connection may already be closed.
136
191
 
137
192
  def schema_qualified_table_name(self, schema_name: str | None, table_name: str) -> str:
193
+ """Return the quoted, optionally schema-qualified form of *table_name*."""
138
194
  table_name = self.type.quoted(table_name)
139
195
  if schema_name:
140
196
  schema_name = self.type.quoted(schema_name)
@@ -142,21 +198,22 @@ class Database:
142
198
  return table_name
143
199
 
144
200
  def select_data(self, sql: str) -> tuple[list[str], list]:
145
- # Returns the results of the sql select statement.
146
- curs = self.cursor()
147
- try:
148
- curs.execute(sql)
149
- except Exception:
150
- self.rollback()
151
- raise
152
- try:
153
- _state.subvars.add_substitution("$LAST_ROWCOUNT", curs.rowcount)
154
- except Exception:
155
- pass # Non-critical: some drivers lack rowcount support.
156
- rows = curs.fetchall()
157
- return [d[0] for d in curs.description], rows
201
+ """Execute *sql* and return ``(column_names, rows)`` with all rows fetched into memory."""
202
+ with self._cursor() as curs:
203
+ try:
204
+ curs.execute(sql)
205
+ except Exception:
206
+ self.rollback()
207
+ raise
208
+ try:
209
+ _state.subvars.add_substitution("$LAST_ROWCOUNT", curs.rowcount)
210
+ except Exception:
211
+ pass # Non-critical: some drivers lack rowcount support.
212
+ rows = curs.fetchall()
213
+ return [d[0] for d in curs.description], rows
158
214
 
159
215
  def select_rowsource(self, sql: str) -> tuple[list[str], Generator]:
216
+ """Execute *sql* and return ``(column_names, row_generator)`` for streaming large result sets."""
160
217
  # Return 1) a list of column names, and 2) an iterable that yields rows.
161
218
  curs = self.cursor()
162
219
  try:
@@ -191,6 +248,7 @@ class Database:
191
248
  return [d[0] for d in curs.description], decode_row()
192
249
 
193
250
  def select_rowdict(self, sql: str) -> tuple[list[str], Iterator]:
251
+ """Execute *sql* and return ``(column_names, row_iterator)`` where each row is a ``dict``."""
194
252
  # Return an iterable that yields dictionaries of row data
195
253
  curs = self.cursor()
196
254
  try:
@@ -218,35 +276,35 @@ class Database:
218
276
  return hdrs, iter(dict_row, None)
219
277
 
220
278
  def schema_exists(self, schema_name: str) -> bool:
221
- curs = self.cursor()
222
- sql = f"SELECT schema_name FROM information_schema.schemata WHERE schema_name = {self.paramstr};"
223
- curs.execute(sql, (schema_name,))
224
- rows = curs.fetchall()
225
- curs.close()
279
+ """Return ``True`` if *schema_name* exists in this database."""
280
+ with self._cursor() as curs:
281
+ sql = f"SELECT schema_name FROM information_schema.schemata WHERE schema_name = {self.paramstr};"
282
+ curs.execute(sql, (schema_name,))
283
+ rows = curs.fetchall()
226
284
  return len(rows) > 0
227
285
 
228
286
  def table_exists(self, table_name: str, schema_name: str | None = None) -> bool:
229
- curs = self.cursor()
230
- params: list = [table_name]
231
- schema_clause = ""
232
- if schema_name:
233
- schema_clause = f" and table_schema={self.paramstr}"
234
- params.append(schema_name)
235
- sql = f"select table_name from information_schema.tables where table_name = {self.paramstr}{schema_clause};"
236
- try:
237
- curs.execute(sql, params)
238
- except ErrInfo:
239
- raise
240
- except Exception:
241
- self.rollback()
242
- raise ErrInfo(
243
- type="db",
244
- command_text=sql,
245
- exception_msg=exception_desc(),
246
- other_msg=f"Failed test for existence of table {table_name} in {self.name()}",
247
- )
248
- rows = curs.fetchall()
249
- curs.close()
287
+ """Return ``True`` if *table_name* (optionally in *schema_name*) exists."""
288
+ with self._cursor() as curs:
289
+ params: list = [table_name]
290
+ schema_clause = ""
291
+ if schema_name:
292
+ schema_clause = f" and table_schema={self.paramstr}"
293
+ params.append(schema_name)
294
+ sql = f"select table_name from information_schema.tables where table_name = {self.paramstr}{schema_clause};"
295
+ try:
296
+ curs.execute(sql, params)
297
+ except ErrInfo:
298
+ raise
299
+ except Exception as e:
300
+ self.rollback()
301
+ raise ErrInfo(
302
+ type="db",
303
+ command_text=sql,
304
+ exception_msg=exception_desc(),
305
+ other_msg=f"Failed test for existence of table {table_name} in {self.name()}",
306
+ ) from e
307
+ rows = curs.fetchall()
250
308
  return len(rows) > 0
251
309
 
252
310
  def column_exists(
@@ -255,92 +313,94 @@ class Database:
255
313
  column_name: str,
256
314
  schema_name: str | None = None,
257
315
  ) -> bool:
258
- curs = self.cursor()
259
- params: list = [table_name]
260
- schema_clause = ""
261
- if schema_name:
262
- schema_clause = f" and table_schema={self.paramstr}"
263
- params.append(schema_name)
264
- params.append(column_name)
265
- sql = (
266
- f"select column_name from information_schema.columns "
267
- f"where table_name={self.paramstr}{schema_clause} "
268
- f"and column_name={self.paramstr};"
269
- )
270
- try:
271
- curs.execute(sql, params)
272
- except ErrInfo:
273
- raise
274
- except Exception:
275
- self.rollback()
276
- raise ErrInfo(
277
- type="db",
278
- command_text=sql,
279
- exception_msg=exception_desc(),
280
- other_msg=f"Failed test for existence of column {column_name} in table {table_name} of {self.name()}",
316
+ """Return ``True`` if *column_name* exists in *table_name* (optionally in *schema_name*)."""
317
+ with self._cursor() as curs:
318
+ params: list = [table_name]
319
+ schema_clause = ""
320
+ if schema_name:
321
+ schema_clause = f" and table_schema={self.paramstr}"
322
+ params.append(schema_name)
323
+ params.append(column_name)
324
+ sql = (
325
+ f"select column_name from information_schema.columns "
326
+ f"where table_name={self.paramstr}{schema_clause} "
327
+ f"and column_name={self.paramstr};"
281
328
  )
282
- rows = curs.fetchall()
283
- curs.close()
329
+ try:
330
+ curs.execute(sql, params)
331
+ except ErrInfo:
332
+ raise
333
+ except Exception as e:
334
+ self.rollback()
335
+ raise ErrInfo(
336
+ type="db",
337
+ command_text=sql,
338
+ exception_msg=exception_desc(),
339
+ other_msg=f"Failed test for existence of column {column_name} in table {table_name} of {self.name()}",
340
+ ) from e
341
+ rows = curs.fetchall()
284
342
  return len(rows) > 0
285
343
 
286
344
  def table_columns(self, table_name: str, schema_name: str | None = None) -> list[str]:
287
- curs = self.cursor()
288
- params: list = [table_name]
289
- schema_clause = ""
290
- if schema_name:
291
- schema_clause = f" and table_schema={self.paramstr}"
292
- params.append(schema_name)
293
- sql = (
294
- f"select column_name from information_schema.columns "
295
- f"where table_name={self.paramstr}{schema_clause} "
296
- f"order by ordinal_position;"
297
- )
298
- try:
299
- curs.execute(sql, params)
300
- except ErrInfo:
301
- raise
302
- except Exception:
303
- self.rollback()
304
- raise ErrInfo(
305
- type="db",
306
- command_text=sql,
307
- exception_msg=exception_desc(),
308
- other_msg=f"Failed to get column names for table {table_name} of {self.name()}",
345
+ """Return the ordered list of column names for *table_name*."""
346
+ with self._cursor() as curs:
347
+ params: list = [table_name]
348
+ schema_clause = ""
349
+ if schema_name:
350
+ schema_clause = f" and table_schema={self.paramstr}"
351
+ params.append(schema_name)
352
+ sql = (
353
+ f"select column_name from information_schema.columns "
354
+ f"where table_name={self.paramstr}{schema_clause} "
355
+ f"order by ordinal_position;"
309
356
  )
310
- rows = curs.fetchall()
311
- curs.close()
357
+ try:
358
+ curs.execute(sql, params)
359
+ except ErrInfo:
360
+ raise
361
+ except Exception as e:
362
+ self.rollback()
363
+ raise ErrInfo(
364
+ type="db",
365
+ command_text=sql,
366
+ exception_msg=exception_desc(),
367
+ other_msg=f"Failed to get column names for table {table_name} of {self.name()}",
368
+ ) from e
369
+ rows = curs.fetchall()
312
370
  return [row[0] for row in rows]
313
371
 
314
372
  def view_exists(self, view_name: str, schema_name: str | None = None) -> bool:
315
- curs = self.cursor()
316
- params: list = [view_name]
317
- schema_clause = ""
318
- if schema_name:
319
- schema_clause = f" and table_schema={self.paramstr}"
320
- params.append(schema_name)
321
- sql = f"select table_name from information_schema.views where table_name = {self.paramstr}{schema_clause};"
322
- try:
323
- curs.execute(sql, params)
324
- except ErrInfo:
325
- raise
326
- except Exception:
327
- self.rollback()
328
- raise ErrInfo(
329
- type="db",
330
- command_text=sql,
331
- exception_msg=exception_desc(),
332
- other_msg=f"Failed test for existence of view {view_name} in {self.name()}",
333
- )
334
- rows = curs.fetchall()
335
- curs.close()
373
+ """Return ``True`` if *view_name* (optionally in *schema_name*) exists."""
374
+ with self._cursor() as curs:
375
+ params: list = [view_name]
376
+ schema_clause = ""
377
+ if schema_name:
378
+ schema_clause = f" and table_schema={self.paramstr}"
379
+ params.append(schema_name)
380
+ sql = f"select table_name from information_schema.views where table_name = {self.paramstr}{schema_clause};"
381
+ try:
382
+ curs.execute(sql, params)
383
+ except ErrInfo:
384
+ raise
385
+ except Exception as e:
386
+ self.rollback()
387
+ raise ErrInfo(
388
+ type="db",
389
+ command_text=sql,
390
+ exception_msg=exception_desc(),
391
+ other_msg=f"Failed test for existence of view {view_name} in {self.name()}",
392
+ ) from e
393
+ rows = curs.fetchall()
336
394
  return len(rows) > 0
337
395
 
338
396
  def role_exists(self, rolename: str) -> bool:
397
+ """Return ``True`` if *rolename* exists; subclasses must override this."""
339
398
  from execsql.exceptions import DatabaseNotImplementedError
340
399
 
341
400
  raise DatabaseNotImplementedError(self.name(), "role_exists")
342
401
 
343
402
  def drop_table(self, tablename: str) -> None:
403
+ """Drop *tablename* if it exists; *tablename* must already be schema-qualified and quoted."""
344
404
  # The 'tablename' argument should be schema-qualified and quoted as necessary.
345
405
  self.execute(f"drop table if exists {tablename} cascade;")
346
406
  self.commit()
@@ -353,6 +413,11 @@ class Database:
353
413
  column_list: list[str],
354
414
  tablespec_src: Callable,
355
415
  ) -> None:
416
+ """Bulk-insert rows from *rowsource* into *table_name* using the columns in *column_list*.
417
+
418
+ *rowsource* must be a generator yielding lists of values in column order.
419
+ *tablespec_src* is a zero-argument callable that returns the table's type specification.
420
+ """
356
421
  # The rowsource argument must be a generator yielding a list of values for the columns of the table.
357
422
  # The column_list argument must an iterable containing column names. This may be a subset of
358
423
  # the names of columns in the rowsource.
@@ -382,87 +447,126 @@ class Database:
382
447
  curs = self.cursor()
383
448
  eof = False
384
449
  total_rows = 0
385
- while True:
386
- b = []
387
- for _j in range(_state.conf.import_row_buffer):
388
- try:
389
- line = next(rows)
390
- except StopIteration:
391
- eof = True
392
- else:
393
- if len(line) > len(ts_colnames):
394
- extra_err = True
395
- if _state.conf.del_empty_cols:
396
- any_non_empty = False
397
- for cno in range(len(ts_colnames), len(line)):
398
- if not (
399
- line[cno] is None
400
- or (
401
- not _state.conf.empty_strings
402
- and isinstance(line[cno], _state.stringtypes)
403
- and len(line[cno].strip()) == 0
404
- )
405
- and _state.conf.del_empty_cols
406
- ):
407
- any_non_empty = True
408
- break
409
- extra_err = any_non_empty
410
- if extra_err:
411
- raise ErrInfo(
412
- type="error",
413
- other_msg=f"Too many data columns on line {{{line}}}",
414
- )
415
- else:
416
- line = line[: len(ts_colnames)]
417
- if not (len(line) == 1 and line[0] is None):
418
- if _state.conf.trim_strings or _state.conf.replace_newlines or not _state.conf.empty_strings:
419
- for i in range(len(line)):
420
- if line[i] is not None and isinstance(
421
- line[i],
422
- _state.stringtypes,
423
- ):
424
- if _state.conf.trim_strings:
425
- line[i] = line[i].strip()
426
- if _state.conf.replace_newlines:
427
- line[i] = re.sub(
428
- r"[\s\t]*[\r\n]+[\s\t]*",
429
- " ",
430
- line[i],
450
+
451
+ # Optional rich progress bar for long-running imports.
452
+ use_progress = getattr(_state.conf, "show_progress", False)
453
+ progress_ctx = None
454
+ task_id = None
455
+ if use_progress:
456
+ try:
457
+ from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
458
+ from rich.console import Console
459
+
460
+ progress_ctx = Progress(
461
+ SpinnerColumn(),
462
+ TextColumn("[bold blue]IMPORT[/bold blue] {task.description}"),
463
+ TextColumn("{task.completed:,} rows"),
464
+ TimeElapsedColumn(),
465
+ console=Console(stderr=True),
466
+ )
467
+ except ImportError:
468
+ use_progress = False
469
+
470
+ def _import_loop() -> int:
471
+ nonlocal eof, total_rows, task_id
472
+ while True:
473
+ b = []
474
+ for _j in range(_state.conf.import_row_buffer):
475
+ try:
476
+ line = next(rows)
477
+ except StopIteration:
478
+ eof = True
479
+ else:
480
+ if len(line) > len(ts_colnames):
481
+ extra_err = True
482
+ if _state.conf.del_empty_cols:
483
+ any_non_empty = False
484
+ for cno in range(len(ts_colnames), len(line)):
485
+ if not (
486
+ line[cno] is None
487
+ or (
488
+ not _state.conf.empty_strings
489
+ and isinstance(line[cno], _state.stringtypes)
490
+ and len(line[cno].strip()) == 0
431
491
  )
432
- if not _state.conf.empty_strings and line[i].strip() == "":
433
- line[i] = None
434
- lt = [type_objs[i].from_data(val) if val is not None else None for i, val in enumerate(line)]
435
- lt = [type_mod_fn[i](v) if type_mod_fn[i] else v for i, v in enumerate(lt)]
436
- row = []
437
- for i, v in enumerate(lt):
438
- if incl_col[i]:
439
- row.append(v)
440
- add_line = True
441
- if not _state.conf.empty_rows:
442
- add_line = not all(c is None for c in row)
443
- if add_line:
444
- b.append(row)
445
- if len(b) > 0:
446
- try:
447
- curs.executemany(sql, b)
448
- except ErrInfo:
449
- raise
450
- except Exception:
451
- self.rollback()
452
- raise ErrInfo(
453
- type="db",
454
- command_text=sql,
455
- exception_msg=exception_desc(),
456
- other_msg=f"Can't load data into table {sq_name} of {self.name()} from line {{{line}}}",
457
- )
458
- total_rows += len(b)
459
- interval = _state.conf.import_progress_interval
460
- if _state.exec_log and interval > 0 and total_rows % interval == 0:
461
- _state.exec_log.log_status_info(
462
- f"IMPORT into {sq_name}: {total_rows} rows imported so far.",
463
- )
464
- if eof:
465
- break
492
+ and _state.conf.del_empty_cols
493
+ ):
494
+ any_non_empty = True
495
+ break
496
+ extra_err = any_non_empty
497
+ if extra_err:
498
+ raise ErrInfo(
499
+ type="error",
500
+ other_msg=f"Too many data columns on line {{{line}}}",
501
+ )
502
+ else:
503
+ line = line[: len(ts_colnames)]
504
+ if not (len(line) == 1 and line[0] is None):
505
+ if (
506
+ _state.conf.trim_strings
507
+ or _state.conf.replace_newlines
508
+ or not _state.conf.empty_strings
509
+ ):
510
+ for i in range(len(line)):
511
+ if line[i] is not None and isinstance(
512
+ line[i],
513
+ _state.stringtypes,
514
+ ):
515
+ if _state.conf.trim_strings:
516
+ line[i] = line[i].strip()
517
+ if _state.conf.replace_newlines:
518
+ line[i] = re.sub(
519
+ r"[\s\t]*[\r\n]+[\s\t]*",
520
+ " ",
521
+ line[i],
522
+ )
523
+ if not _state.conf.empty_strings and line[i].strip() == "":
524
+ line[i] = None
525
+ lt = [
526
+ type_objs[i].from_data(val) if val is not None else None for i, val in enumerate(line)
527
+ ]
528
+ lt = [type_mod_fn[i](v) if type_mod_fn[i] else v for i, v in enumerate(lt)]
529
+ row = []
530
+ for i, v in enumerate(lt):
531
+ if incl_col[i]:
532
+ row.append(v)
533
+ add_line = True
534
+ if not _state.conf.empty_rows:
535
+ add_line = not all(c is None for c in row)
536
+ if add_line:
537
+ b.append(row)
538
+ if len(b) > 0:
539
+ try:
540
+ curs.executemany(sql, b)
541
+ except ErrInfo:
542
+ raise
543
+ except Exception as e:
544
+ self.rollback()
545
+ raise ErrInfo(
546
+ type="db",
547
+ command_text=sql,
548
+ exception_msg=exception_desc(),
549
+ other_msg=f"Can't load data into table {sq_name} of {self.name()} from line {{{line}}}",
550
+ ) from e
551
+ total_rows += len(b)
552
+ if use_progress and progress_ctx is not None and task_id is not None:
553
+ progress_ctx.update(task_id, completed=total_rows)
554
+ interval = _state.conf.import_progress_interval
555
+ if _state.exec_log and interval > 0 and total_rows % interval == 0:
556
+ _state.exec_log.log_status_info(
557
+ f"IMPORT into {sq_name}: {total_rows} rows imported so far.",
558
+ )
559
+ if eof:
560
+ break
561
+ return total_rows
562
+
563
+ if use_progress and progress_ctx is not None:
564
+ with progress_ctx:
565
+ task_id = progress_ctx.add_task(sq_name, total=None)
566
+ _import_loop()
567
+ else:
568
+ _import_loop()
569
+
466
570
  if _state.exec_log:
467
571
  _state.exec_log.log_status_info(
468
572
  f"IMPORT into {sq_name} complete: {total_rows} rows imported.",
@@ -475,6 +579,7 @@ class Database:
475
579
  csv_file_obj: Any,
476
580
  skipheader: bool,
477
581
  ) -> None:
582
+ """Import a CSV/tabular file into *table_name*; column names must be compatible."""
478
583
  # Import a text (CSV) file containing tabular data to a table. Columns must be compatible.
479
584
  if not self.table_exists(table_name, schema_name):
480
585
  raise ErrInfo(
@@ -517,11 +622,14 @@ class Database:
517
622
  column_name: str,
518
623
  file_name: str,
519
624
  ) -> None:
625
+ """Insert the raw binary content of *file_name* as a single row into *column_name* of *table_name*."""
520
626
  with open(file_name, "rb") as f:
521
627
  filedata = f.read()
522
628
  sq_name = self.schema_qualified_table_name(schema_name, table_name)
523
- sql = f"insert into {sq_name} ({column_name}) values ({self.paramsubs(1)});"
524
- self.cursor().execute(sql, (filedata,))
629
+ quoted_col = self.quote_identifier(column_name)
630
+ sql = f"insert into {sq_name} ({quoted_col}) values ({self.paramsubs(1)});"
631
+ with self._cursor() as curs:
632
+ curs.execute(sql, (filedata,))
525
633
 
526
634
 
527
635
  class DatabasePool:
@@ -538,6 +646,7 @@ class DatabasePool:
538
646
  return "DatabasePool()"
539
647
 
540
648
  def add(self, db_alias: str, db_obj: Database) -> None:
649
+ """Register *db_obj* under *db_alias*, setting it as initial/current if this is the first connection."""
541
650
  db_alias = db_alias.lower()
542
651
  if db_alias == "initial" and len(self.pool) > 0:
543
652
  raise ErrInfo(
@@ -561,25 +670,27 @@ class DatabasePool:
561
670
  self.pool[db_alias] = db_obj
562
671
 
563
672
  def aliases(self) -> list[str]:
564
- # Return a list of the currently defined aliases
673
+ """Return a list of all currently registered database aliases."""
565
674
  return list(self.pool)
566
675
 
567
676
  def current(self) -> Database:
568
- # Return the current db object.
677
+ """Return the currently active ``Database`` object."""
569
678
  return self.pool[self.current_db]
570
679
 
571
680
  def current_alias(self) -> str:
572
- # Return the alias of the current db object.
681
+ """Return the alias string for the currently active database."""
573
682
  return self.current_db
574
683
 
575
684
  def initial(self) -> Database:
685
+ """Return the first ``Database`` that was added to the pool."""
576
686
  return self.pool[self.initial_db]
577
687
 
578
688
  def aliased_as(self, db_alias: str) -> Database:
689
+ """Return the ``Database`` registered under *db_alias*."""
579
690
  return self.pool[db_alias]
580
691
 
581
692
  def make_current(self, db_alias: str) -> None:
582
- # Change the current database in use.
693
+ """Set the active database to *db_alias*; raises ``ErrInfo`` if the alias is unknown."""
583
694
  db_alias = db_alias.lower()
584
695
  if db_alias not in self.pool:
585
696
  raise ErrInfo(
@@ -589,6 +700,7 @@ class DatabasePool:
589
700
  self.current_db = db_alias
590
701
 
591
702
  def disconnect(self, alias: str) -> None:
703
+ """Close and remove the connection registered under *alias* from the pool."""
592
704
  if alias == self.current_db or (alias == "initial" and "initial" in self.pool):
593
705
  raise ErrInfo(
594
706
  type="error",
@@ -599,6 +711,7 @@ class DatabasePool:
599
711
  del self.pool[alias]
600
712
 
601
713
  def closeall(self) -> None:
714
+ """Roll back and close every connection in the pool, then reset the pool to empty."""
602
715
  for alias, db in self.pool.items():
603
716
  nm = db.name()
604
717
  try: