jupyter-duckdb 1.2.0.1-py3-none-any.whl → 1.4.111-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62):
  1. duckdb_kernel/db/Connection.py +3 -0
  2. duckdb_kernel/db/Table.py +8 -0
  3. duckdb_kernel/db/implementation/duckdb/Connection.py +27 -13
  4. duckdb_kernel/db/implementation/postgres/Connection.py +27 -12
  5. duckdb_kernel/db/implementation/sqlite/Connection.py +9 -3
  6. duckdb_kernel/kernel.py +407 -200
  7. duckdb_kernel/magics/MagicCommand.py +34 -10
  8. duckdb_kernel/magics/MagicCommandCallback.py +11 -7
  9. duckdb_kernel/magics/MagicCommandHandler.py +58 -9
  10. duckdb_kernel/magics/MagicState.py +11 -0
  11. duckdb_kernel/magics/__init__.py +1 -0
  12. duckdb_kernel/parser/DCParser.py +17 -7
  13. duckdb_kernel/parser/LogicParser.py +6 -6
  14. duckdb_kernel/parser/ParserError.py +18 -0
  15. duckdb_kernel/parser/RAParser.py +29 -21
  16. duckdb_kernel/parser/__init__.py +1 -0
  17. duckdb_kernel/parser/elements/DCOperand.py +7 -4
  18. duckdb_kernel/parser/elements/LogicElement.py +0 -2
  19. duckdb_kernel/parser/elements/RAElement.py +4 -1
  20. duckdb_kernel/parser/elements/RARelationReference.py +86 -0
  21. duckdb_kernel/parser/elements/RAUnaryOperator.py +6 -0
  22. duckdb_kernel/parser/elements/__init__.py +2 -1
  23. duckdb_kernel/parser/elements/binary/And.py +1 -1
  24. duckdb_kernel/parser/elements/binary/ConditionalSet.py +37 -10
  25. duckdb_kernel/parser/elements/binary/Cross.py +2 -2
  26. duckdb_kernel/parser/elements/binary/Difference.py +1 -1
  27. duckdb_kernel/parser/elements/binary/Divide.py +1 -1
  28. duckdb_kernel/parser/elements/binary/Division.py +0 -4
  29. duckdb_kernel/parser/elements/binary/FullOuterJoin.py +40 -0
  30. duckdb_kernel/parser/elements/binary/Join.py +4 -1
  31. duckdb_kernel/parser/elements/binary/LeftOuterJoin.py +27 -0
  32. duckdb_kernel/parser/elements/binary/LeftSemiJoin.py +27 -0
  33. duckdb_kernel/parser/elements/binary/RightOuterJoin.py +27 -0
  34. duckdb_kernel/parser/elements/binary/RightSemiJoin.py +27 -0
  35. duckdb_kernel/parser/elements/binary/__init__.py +21 -6
  36. duckdb_kernel/parser/elements/unary/AttributeRename.py +39 -0
  37. duckdb_kernel/parser/elements/unary/Projection.py +1 -1
  38. duckdb_kernel/parser/elements/unary/Rename.py +68 -14
  39. duckdb_kernel/parser/elements/unary/__init__.py +2 -0
  40. duckdb_kernel/parser/tokenizer/Token.py +24 -3
  41. duckdb_kernel/parser/util/QuerySplitter.py +87 -0
  42. duckdb_kernel/parser/util/RenamableColumnList.py +10 -2
  43. duckdb_kernel/tests/__init__.py +76 -0
  44. duckdb_kernel/tests/test_dc.py +483 -0
  45. duckdb_kernel/tests/test_ra.py +1966 -0
  46. duckdb_kernel/tests/test_result_comparison.py +173 -0
  47. duckdb_kernel/tests/test_sql.py +48 -0
  48. duckdb_kernel/util/ResultSetComparator.py +22 -4
  49. duckdb_kernel/util/SQL.py +6 -0
  50. duckdb_kernel/util/TestError.py +4 -0
  51. duckdb_kernel/visualization/Plotly.py +144 -0
  52. duckdb_kernel/visualization/RATreeDrawer.py +34 -2
  53. duckdb_kernel/visualization/__init__.py +1 -0
  54. duckdb_kernel/visualization/lib/__init__.py +53 -0
  55. duckdb_kernel/visualization/lib/plotly-3.0.1.min.js +3879 -0
  56. duckdb_kernel/visualization/lib/ra.css +3 -0
  57. duckdb_kernel/visualization/lib/ra.js +55 -0
  58. {jupyter_duckdb-1.2.0.1.dist-info → jupyter_duckdb-1.4.111.dist-info}/METADATA +53 -19
  59. jupyter_duckdb-1.4.111.dist-info/RECORD +104 -0
  60. {jupyter_duckdb-1.2.0.1.dist-info → jupyter_duckdb-1.4.111.dist-info}/WHEEL +1 -1
  61. jupyter_duckdb-1.2.0.1.dist-info/RECORD +0 -82
  62. {jupyter_duckdb-1.2.0.1.dist-info → jupyter_duckdb-1.4.111.dist-info}/top_level.txt +0 -0
duckdb_kernel/kernel.py CHANGED
@@ -9,11 +9,14 @@ from typing import Optional, Dict, List, Tuple
9
9
 
10
10
  from ipykernel.kernelbase import Kernel
11
11
 
12
- from .db import Connection, DatabaseError
12
+ from .db import Connection, DatabaseError, Table
13
13
  from .db.error import *
14
14
  from .magics import *
15
- from .parser import RAParser, DCParser
15
+ from .parser import RAParser, DCParser, ParserError
16
+ from .parser.util.QuerySplitter import split_queries, get_last_query
16
17
  from .util.ResultSetComparator import ResultSetComparator
18
+ from .util.SQL import SQL_KEYWORDS
19
+ from .util.TestError import TestError
17
20
  from .util.formatting import row_count, rows_table, wrap_image
18
21
  from .visualization import *
19
22
 
@@ -25,9 +28,11 @@ class DuckDBKernel(Kernel):
25
28
  implementation_version = '1.0'
26
29
  banner = 'DuckDB Kernel'
27
30
  language_info = {
28
- 'name': 'duckdb',
29
- 'mimetype': 'application/sql',
31
+ 'name': 'sql',
30
32
  'file_extension': '.sql',
33
+ 'mimetype': 'text/x-sql',
34
+ 'codemirror_mode': 'sql',
35
+ 'pygments_lexer': 'sql',
31
36
  }
32
37
 
33
38
  def __init__(self, **kwargs):
@@ -37,21 +42,31 @@ class DuckDBKernel(Kernel):
37
42
  self._magics: MagicCommandHandler = MagicCommandHandler()
38
43
 
39
44
  self._magics.add(
40
- MagicCommand('create').arg('database').opt('of').opt('with_tests').on(self._create_magic),
41
- MagicCommand('load').arg('database').opt('with_tests').on(self._load_magic),
45
+ MagicCommand('create').arg('database').opt('of').opt('name').on(self._create_magic),
46
+ MagicCommand('load').arg('database').opt('name').on(self._load_magic),
47
+ MagicCommand('copy').arg('source').arg('target').on(self._copy_magic),
48
+ MagicCommand('use').arg('name').on(self._use_magic),
49
+ MagicCommand('load_tests').arg('tests').on(self._load_tests_magic),
42
50
  MagicCommand('test').arg('name').result(True).on(self._test_magic),
43
51
  MagicCommand('all', 'all_rows').on(self._all_magic),
44
52
  MagicCommand('max_rows').arg('count').on(self._max_rows_magic),
45
53
  MagicCommand('query_max_rows').arg('count').on(self._query_max_rows_magic),
46
54
  MagicCommand('schema').flag('td').opt('only').on(self._schema_magic),
47
55
  MagicCommand('store').arg('file').flag('noheader').result(True).on(self._store_magic),
48
- MagicCommand('ra').flag('analyze').code(True).on(self._ra_magic),
49
- MagicCommand('dc').code(True).on(self._dc_magic)
56
+ MagicCommand('sql').disable('ra', 'dc', 'auto_parser'),
57
+ MagicCommand('ra').disable('sql', 'dc', 'auto_parser').flag('analyze').code(True).on(self._ra_magic),
58
+ MagicCommand('all_ra').arg('value', '1').on(self._all_ra_magic),
59
+ MagicCommand('dc').disable('sql', 'ra', 'auto_parser').code(True).on(self._dc_magic),
60
+ MagicCommand('all_dc').arg('value', '1').on(self._all_dc_magic),
61
+ MagicCommand('auto_parser').disable('sql', 'ra', 'dc').code(True).on(self._auto_parser_magic),
62
+ MagicCommand('guess_parser').arg('value', '1').on(self._guess_parser_magic),
63
+ MagicCommand('plotly').arg('type').arg('mapping').opt('title').result(True).on(self._plotly_magic),
64
+ MagicCommand('plotly_raw').opt('title').result(True).on(self._plotly_raw_magic)
50
65
  )
51
66
 
52
67
  # create placeholders for database and tests
53
- self._db: Optional[Connection] = None
54
- self._tests: Optional[Dict] = None
68
+ self._db: Dict[str, Connection] = {}
69
+ self._tests: Dict = {}
55
70
 
56
71
  # output related functions
57
72
  def print(self, text: str, name: str = 'stdout'):
@@ -88,129 +103,145 @@ class DuckDBKernel(Kernel):
88
103
  })
89
104
 
90
105
  # database related functions
91
- def _load_database(self, path: str):
92
- if self._db is None:
93
- # If the provided path looks like a postgres url,
94
- # we want to use the postgres driver.
95
- if path.startswith(('postgresql://', 'postgres://', 'pgsql://', 'psql://', 'pg://')):
96
- # pull data from connection string
97
- re_expr = r'(postgresql|postgres|pgsql|psql|pg)://((.*?)(:(.*?))?@)?(.*?)(:(\d+))?(/(.*))?'
98
- match = re.fullmatch(re_expr, path)
99
-
100
- host = match.group(6)
101
- port = int(match.group(8)) if match.group(8) is not None else 5432
102
- username = match.group(3)
103
- password = match.group(5)
104
- database_name = match.group(10)
105
-
106
- # load and create instance
107
- try:
108
- from .db.implementation.postgres import Connection as Postgres
109
- self._db = Postgres(host, port, username, password, database_name)
110
- except ImportError:
111
- self.print('psycopg could not be found', name='stderr')
112
-
113
- # Otherwise the provided path is used to create an
114
- # in-process instance.
115
- else:
116
- # By default, we try to load DuckDB.
117
- try:
118
- from .db.implementation.duckdb import Connection as DuckDB
119
- self._db = DuckDB(path)
106
+ def _load_database(self, path: str, name: str) -> Connection:
107
+ if name in self._db:
108
+ raise ValueError(f'duplicate database name {name}')
109
+
110
+ # If the provided path looks like a postgres url,
111
+ # we want to use the postgres driver.
112
+ if path.startswith(('postgresql://', 'postgres://', 'pgsql://', 'psql://', 'pg://')):
113
+ # pull data from connection string
114
+ re_expr = r'(postgresql|postgres|pgsql|psql|pg)://((.*?)(:(.*?))?@)?(.*?)(:(\d+))?(/(.*))?'
115
+ match = re.fullmatch(re_expr, path)
116
+
117
+ host = match.group(6)
118
+ port = int(match.group(8)) if match.group(8) is not None else 5432
119
+ username = match.group(3)
120
+ password = match.group(5)
121
+ database_name = match.group(10)
122
+
123
+ # load and create instance
124
+ from .db.implementation.postgres import Connection as Postgres
125
+ self._db[name] = Postgres(host, port, username, password, database_name)
126
+
127
+ # Otherwise the provided path is used to create an
128
+ # in-process instance.
129
+ else:
130
+ # By default, we try to load DuckDB.
131
+ try:
132
+ from .db.implementation.duckdb import Connection as DuckDB
133
+ self._db[name] = DuckDB(path)
120
134
 
121
- # If DuckDB is not installed or fails to load,
122
- # we use SQLite instead.
123
- except ImportError:
124
- self.print('DuckDB is not available\n')
135
+ # If DuckDB is not installed or fails to load,
136
+ # we use SQLite instead.
137
+ except ImportError:
138
+ self.print('DuckDB is not available\n')
125
139
 
126
- from .db.implementation.sqlite import Connection as SQLite
127
- self._db = SQLite(path)
140
+ from .db.implementation.sqlite import Connection as SQLite
141
+ self._db[name] = SQLite(path)
128
142
 
129
- return True
130
- else:
131
- return False
143
+ return self._db[name]
144
+
145
+ def _unload_database(self, name: str):
146
+ if name in self._db:
147
+ self._db[name].close()
148
+ del self._db[name]
132
149
 
133
- def _unload_database(self):
134
- if self._db is not None:
135
- self._db.close()
136
- self._db = None
137
150
  return True
138
151
  else:
139
152
  return False
140
153
 
141
- def _execute_stmt(self, query: str, silent: bool,
142
- max_rows: Optional[int]) -> Tuple[Optional[List[str]], Optional[List[List]]]:
143
- if self._db is None:
154
+ def _execute_stmt(self, silent: bool, state: MagicState, name: str, query: str) \
155
+ -> Tuple[Optional[List[str]], Optional[List[List]]]:
156
+ if state.db is None:
144
157
  raise AssertionError('load a database first')
145
158
 
146
159
  # execute query and store start and end timestamp
147
160
  st = time.time()
148
161
 
149
162
  try:
150
- columns, rows = self._db.execute(query)
163
+ columns, rows = state.db.execute(query)
151
164
  except EmptyResultError:
152
165
  columns, rows = None, None
153
166
 
154
167
  et = time.time()
155
168
 
156
- # return result if silent
157
- if silent:
158
- return columns, rows
159
-
160
- # print EXPLAIN queries as raw text if using DuckDB
161
- if query.strip().startswith('EXPLAIN') and self._db.plain_explain():
162
- for ekey, evalue in rows:
163
- self.print_data(f'<b>{ekey}</b><br><pre>{evalue}</pre>')
169
+ # print result if not silent
170
+ if not silent:
171
+ # print EXPLAIN queries as raw text if using DuckDB
172
+ last_query = get_last_query(query, remove_comments=True).strip()
164
173
 
165
- return None, None
166
-
167
- # print every other query as a table
168
- else:
169
- if columns is not None:
170
- # table header
171
- table_header = ''.join(f'<th>{c}</th>' for c in columns)
172
-
173
- # table data
174
- if max_rows is not None and len(rows) > max_rows:
175
- table_data = f'''
176
- {rows_table(rows[:math.ceil(max_rows / 2)])}
177
- <tr>
178
- <td colspan="{len(columns)}"
179
- style="text-align: center"
180
- title="{row_count(len(rows) - max_rows)} omitted">
181
- ...
182
- </td>
183
- </tr>
184
- {rows_table(rows[-math.floor(max_rows // 2):])}
185
- '''
174
+ if last_query.startswith('EXPLAIN') and state.db.plain_explain():
175
+ for ekey, evalue in rows:
176
+ html = f'<b>{ekey}</b><br><pre>{evalue}</pre>'
177
+ break
186
178
  else:
187
- table_data = rows_table(rows)
179
+ html = ''
188
180
 
189
- # send to client
190
- self.print_data(f'''
191
- <table class="duckdb-query-result">
192
- {table_header}
193
- {table_data}
194
- </table>
195
- ''')
181
+ # print every other query as a table
182
+ else:
183
+ if columns is not None:
184
+ # table header
185
+ mapped_columns = (state.column_name_mapping.get(c, c) for c in columns)
186
+ table_header = ''.join(f'<th>{c}</th>' for c in mapped_columns)
187
+
188
+ # table data
189
+ if state.max_rows is not None and len(rows) > state.max_rows:
190
+ table_data = f'''
191
+ {rows_table(rows[:math.ceil(state.max_rows / 2)])}
192
+ <tr>
193
+ <td colspan="{len(columns)}"
194
+ style="text-align: center"
195
+ title="{row_count(len(rows) - state.max_rows)} omitted">
196
+ ...
197
+ </td>
198
+ </tr>
199
+ {rows_table(rows[-math.floor(state.max_rows // 2):])}
200
+ '''
201
+ else:
202
+ table_data = rows_table(rows)
203
+
204
+ # send to client
205
+ html = (f'''
206
+ <table class="duckdb-query-result-table">
207
+ {table_header}
208
+ {table_data}
209
+ </table>
210
+
211
+ {row_count(len(rows))} in {et - st:.3f}s
212
+ ''')
196
213
 
197
- self.print_data(f'{row_count(len(rows))} in {et - st:.3f}s')
214
+ else:
215
+ html = f'statement executed without result in {et - st:.3f}s'
198
216
 
199
- else:
200
- self.print_data(f'statement executed without result in {et - st:.3f}s')
217
+ self.print_data(f'''
218
+ <div class="duckdb-query-result {name}">
219
+ {html}
220
+ </div>
221
+ ''')
201
222
 
202
223
  return columns, rows
203
224
 
204
225
  # magic command related functions
205
- def _create_magic(self, silent: bool, path: str, of: Optional[str], with_tests: Optional[str]):
206
- self._load(silent, path, True, of, with_tests)
226
+ def _create_magic(self, silent: bool, state: MagicState,
227
+ path: str, of: Optional[str], name: Optional[str]):
228
+ self._load(silent, state, path, True, of, name)
229
+
230
+ def _load_magic(self, silent: bool, state: MagicState,
231
+ path: str, name: Optional[str]):
232
+ self._load(silent, state, path, False, None, name)
207
233
 
208
- def _load_magic(self, silent: bool, path: str, with_tests: Optional[str]):
209
- self._load(silent, path, False, None, with_tests)
234
+ def _load(self, silent: bool, state: MagicState,
235
+ path: str, create: bool, of: Optional[str], name: Optional[str]):
236
+ # use default name if non provided
237
+ if name is None:
238
+ name = 'default'
239
+
240
+ if not silent:
241
+ self.print(f'--- connection {name} ---\n')
210
242
 
211
- def _load(self, silent: bool, path: str, create: bool, of: Optional[str], with_tests: Optional[str]):
212
243
  # unload current database if necessary
213
- if self._unload_database():
244
+ if self._unload_database(name):
214
245
  if not silent:
215
246
  self.print('unloaded database\n')
216
247
 
@@ -227,10 +258,10 @@ class DuckDBKernel(Kernel):
227
258
  if create and os.path.exists(path):
228
259
  os.remove(path)
229
260
 
230
- if self._load_database(path):
231
- if not silent:
232
- # self.print(f'loaded database "{path}"\n')
233
- self.print(str(self._db) + '\n')
261
+ state.db = self._load_database(path, name)
262
+ if not silent:
263
+ # self.print(f'loaded database "{path}"\n')
264
+ self.print(str(state.db) + '\n')
234
265
 
235
266
  # copy data from source database
236
267
  if of is not None:
@@ -246,18 +277,17 @@ class DuckDBKernel(Kernel):
246
277
  content = file.read()
247
278
 
248
279
  # You can only execute one statement at a time using SQLite.
249
- if not self._db.multiple_statements_per_query():
250
- statements = re.split(r';\r?\n', content)
251
- for statement in statements:
280
+ if not state.db.multiple_statements_per_query():
281
+ for statement in split_queries(content):
252
282
  try:
253
- self._db.execute(statement)
283
+ state.db.execute(statement)
254
284
  except EmptyResultError:
255
285
  pass
256
286
 
257
287
  # Other DBMS can execute multiple statements at a time.
258
288
  else:
259
289
  try:
260
- self._db.execute(content)
290
+ state.db.execute(content)
261
291
  except EmptyResultError:
262
292
  pass
263
293
 
@@ -271,29 +301,56 @@ class DuckDBKernel(Kernel):
271
301
  of_db.execute('SHOW TABLES')
272
302
  for table, in of_db.fetchall():
273
303
  transfer_df = of_db.query(f'SELECT * FROM {table}').to_df()
274
- self._db.execute(f'CREATE TABLE {table} AS SELECT * FROM transfer_df')
304
+ state.db.execute(f'CREATE TABLE {table} AS SELECT * FROM transfer_df')
275
305
 
276
306
  if not silent:
277
307
  self.print(f'transferred table {table}\n')
278
308
 
279
- # load tests
280
- if with_tests is None:
281
- self._tests = {}
282
- else:
283
- with open(with_tests, 'r', encoding='utf-8') as tests_file:
284
- self._tests = json.load(tests_file)
285
- for test in self._tests.values():
286
- if 'attributes' in test:
287
- rows = {k: [] for k in test['attributes']}
288
- for row in test['equals']:
289
- for k, v in zip(test['attributes'], row):
290
- rows[k].append(v)
309
+ def _copy_magic(self, silent: bool, state: MagicState, source: str, target: str):
310
+ if source not in self._db:
311
+ raise ValueError(f'unknown connection {source}')
291
312
 
292
- test['equals'] = rows
313
+ if not silent:
314
+ self.print(f'--- connection {target} ---\n')
293
315
 
294
- self.print(f'loaded tests from {with_tests}\n')
316
+ # unload current database if necessary
317
+ if self._unload_database(target):
318
+ if not silent:
319
+ self.print('unloaded database\n')
320
+
321
+ # copy connection
322
+ self._db[target] = self._db[source].copy()
323
+ state.db = self._db[target]
324
+
325
+ if not silent:
326
+ self.print(str(state.db) + '\n')
327
+
328
+ def _use_magic(self, silent: bool, state: MagicState, name: str):
329
+ if name not in self._db:
330
+ raise ValueError(f'unknown connection {name}')
331
+
332
+ state.db = self._db[name]
333
+
334
+ def _load_tests_magic(self, silent: bool, state: MagicState, tests: str):
335
+ with open(tests, 'r', encoding='utf-8') as tests_file:
336
+ self._tests = json.load(tests_file)
337
+ for test in self._tests.values():
338
+ if 'attributes' in test:
339
+ rows = {k: [] for k in test['attributes']}
340
+ for row in test['equals']:
341
+ for k, v in zip(test['attributes'], row):
342
+ rows[k].append(v)
343
+
344
+ test['equals'] = rows
345
+
346
+ self.print(f'loaded tests from {tests}\n')
347
+
348
+ def _test_magic(self, silent: bool, state: MagicState, result_columns: List[str], result: List[List], name: str):
349
+ # If the query was empty, result_columns and result may be None.
350
+ if result_columns is None or result is None:
351
+ self.print_data(wrap_image(False, 'Statement did not return data.'))
352
+ return
295
353
 
296
- def _test_magic(self, silent: bool, result_columns: List[str], result: List[List], name: str):
297
354
  # Testing makes no sense if there is no output.
298
355
  if silent:
299
356
  return
@@ -302,55 +359,64 @@ class DuckDBKernel(Kernel):
302
359
  result_columns = [col.rsplit('.', 1)[-1] for col in result_columns]
303
360
 
304
361
  # extract data for test
305
- data = self._tests[name]
362
+ test_data = self._tests[name]
306
363
 
364
+ # execute test
365
+ try:
366
+ self._execute_test(test_data, result_columns, result)
367
+ self.print_data(wrap_image(True))
368
+ except TestError as e:
369
+ self.print_data(wrap_image(False, e.message))
370
+ if os.environ.get('DUCKDB_TESTS_RAISE_EXCEPTION', 'false').lower() in ('true', '1'):
371
+ raise e
372
+
373
+ @staticmethod
374
+ def _execute_test(test_data: Dict, result_columns: List[str], result: List[List]):
307
375
  # check columns if required
308
- if isinstance(data['equals'], dict):
376
+ if isinstance(test_data['equals'], dict):
309
377
  # get column order
310
- data_columns = list(data['equals'].keys())
378
+ data_columns = list(test_data['equals'].keys())
311
379
  column_order = []
312
380
 
313
381
  for dc in data_columns:
314
382
  found = 0
315
383
  for i, rc in enumerate(result_columns):
316
- if dc == rc:
384
+ if dc.lower() == rc.lower():
317
385
  column_order.append(i)
318
386
  found += 1
319
387
 
320
388
  if found == 0:
321
- return self.print_data(wrap_image(False, f'attribute {dc} missing'))
389
+ raise TestError(f'attribute {dc} missing')
322
390
  if found >= 2:
323
- return self.print_data(wrap_image(False, f'ambiguous attribute {dc}'))
391
+ raise TestError(f'ambiguous attribute {dc}')
324
392
 
325
393
  # abort if columns from result are unnecessary
326
394
  for i, rc in enumerate(result_columns):
327
395
  if i not in column_order:
328
- return self.print_data(wrap_image(False, f'unnecessary attribute {rc}'))
396
+ raise TestError(f'unnecessary attribute {rc}')
329
397
 
330
398
  # reorder columns and transform to list of lists
331
399
  sorted_columns = [x for _, x in sorted(zip(column_order, data_columns))]
332
400
  rows = []
333
401
 
334
- for row in zip(*(data['equals'][col] for col in sorted_columns)):
402
+ for row in zip(*(test_data['equals'][col] for col in sorted_columns)):
335
403
  rows.append(row)
336
404
 
337
405
  else:
338
- rows = data['equals']
406
+ rows = test_data['equals']
339
407
 
340
408
  # ordered test
341
- if data['ordered']:
409
+ if test_data['ordered']:
342
410
  # calculate diff
343
411
  rsc = ResultSetComparator(result, rows)
344
412
 
345
413
  missing = len(rsc.ordered_right_only)
346
414
  if missing > 0:
347
- return self.print_data(wrap_image(False, f'{row_count(missing)} missing'))
415
+ raise TestError(f'{row_count(missing)} missing')
348
416
 
349
417
  missing = len(rsc.ordered_left_only)
350
418
  if missing > 0:
351
- return self.print_data(wrap_image(False, f'{row_count(missing)} more than required'))
352
-
353
- return self.print_data(wrap_image(True))
419
+ raise TestError(f'{row_count(missing)} more than required')
354
420
 
355
421
  # unordered test
356
422
  else:
@@ -362,39 +428,35 @@ class DuckDBKernel(Kernel):
362
428
 
363
429
  # print result
364
430
  if below > 0 and above > 0:
365
- self.print_data(wrap_image(False, f'{row_count(below)} missing, {row_count(above)} unnecessary'))
431
+ raise TestError(f'{row_count(below)} missing, {row_count(above)} unnecessary')
366
432
  elif below > 0:
367
- self.print_data(wrap_image(False, f'{row_count(below)} missing'))
433
+ raise TestError(f'{row_count(below)} missing')
368
434
  elif above > 0:
369
- self.print_data(wrap_image(False, f'{row_count(above)} unnecessary'))
370
- else:
371
- self.print_data(wrap_image(True))
435
+ raise TestError(f'{row_count(above)} unnecessary')
372
436
 
373
- def _all_magic(self, silent: bool):
374
- return {
375
- 'max_rows': None
376
- }
437
+ def _all_magic(self, silent: bool, state: MagicState):
438
+ state.max_rows = None
377
439
 
378
- def _max_rows_magic(self, silent: bool, count: str):
440
+ def _max_rows_magic(self, silent: bool, state: MagicState, count: str):
379
441
  if count.lower() != 'none':
380
442
  DuckDBKernel.DEFAULT_MAX_ROWS = int(count)
381
443
  else:
382
444
  DuckDBKernel.DEFAULT_MAX_ROWS = None
383
445
 
384
- def _query_max_rows_magic(self, silent: bool, count: str):
385
- return {
386
- 'max_rows': int(count) if count.lower() != 'none' else None
387
- }
446
+ state.max_rows = DuckDBKernel.DEFAULT_MAX_ROWS
388
447
 
389
- def _schema_magic(self, silent: bool, td: bool, only: Optional[str]):
390
- if self._db is None:
391
- raise AssertionError('load a database first')
448
+ def _query_max_rows_magic(self, silent: bool, state: MagicState, count: str):
449
+ state.max_rows = int(count) if count.lower() != 'none' else None
392
450
 
451
+ def _schema_magic(self, silent: bool, state: MagicState, td: bool, only: Optional[str]):
393
452
  if silent:
394
453
  return
395
454
 
455
+ if state.db is None:
456
+ raise AssertionError('load a database first')
457
+
396
458
  # analyze tables
397
- tables = self._db.analyze()
459
+ tables = state.db.analyze()
398
460
 
399
461
  # apply filter
400
462
  if only is None:
@@ -404,7 +466,7 @@ class DuckDBKernel(Kernel):
404
466
  whitelist = set()
405
467
 
406
468
  # split and strip names
407
- names = [n.strip() for n in re.split(r'[, \t]', only)]
469
+ names = [Table.normalize_name(n.strip()) for n in re.split(r'[, \t]', only)]
408
470
 
409
471
  # add initial tables to result set
410
472
  for name in names:
@@ -436,7 +498,9 @@ class DuckDBKernel(Kernel):
436
498
 
437
499
  self.print_data(svg)
438
500
 
439
- def _store_magic(self, silent: bool, result_columns: List[str], result: List[List], file: str, noheader: bool):
501
+ def _store_magic(self, silent: bool, state: MagicState,
502
+ result_columns: List[str], result: List[List],
503
+ file: str, noheader: bool):
440
504
  _, ext = file.rsplit('.', 1)
441
505
 
442
506
  # csv
@@ -456,58 +520,198 @@ class DuckDBKernel(Kernel):
456
520
  else:
457
521
  raise ValueError(f'extension {ext} not supported')
458
522
 
459
- def _ra_magic(self, silent: bool, code: str, analyze: bool):
460
- if self._db is None:
461
- raise AssertionError('load a database first')
462
-
523
+ def _ra_magic(self, silent: bool, state: MagicState, analyze: bool):
463
524
  if silent:
464
525
  return
465
526
 
466
- if not code.strip():
527
+ if not state.code.strip():
467
528
  return
468
529
 
530
+ if state.db is None:
531
+ raise AssertionError('load a database first')
532
+
469
533
  # analyze tables
470
- tables = self._db.analyze()
534
+ tables = state.db.analyze()
471
535
 
472
536
  # parse ra input
473
- root_node = RAParser.parse_query(code)
537
+ root_node = RAParser.parse_query(state.code)
538
+ if root_node is None:
539
+ return
474
540
 
475
541
  # create and show visualization
476
542
  if analyze:
477
- vd = RATreeDrawer(self._db, root_node, tables)
478
- svg = vd.to_svg(True)
543
+ vd = RATreeDrawer(state.db, root_node, tables)
479
544
 
545
+ svg = vd.to_interactive_svg()
480
546
  self.print_data(svg)
481
547
 
482
- # generate sql
483
- sql = root_node.to_sql_with_renamed_columns(tables)
548
+ state.code = {
549
+ node_id: node.to_sql_with_renamed_columns(tables)
550
+ for node_id, node in vd.nodes.items()
551
+ }
484
552
 
485
- return {
486
- 'generated_code': sql
487
- }
553
+ else:
554
+ state.code = root_node.to_sql_with_renamed_columns(tables)
488
555
 
489
- def _dc_magic(self, silent: bool, code: str):
490
- if self._db is None:
491
- raise AssertionError('load a database first')
556
+ def _all_ra_magic(self, silent: bool, state: MagicState, value: str):
557
+ if value.lower() in ('1', 'on', 'true'):
558
+ self._magics['ra'].default(True)
559
+ self._magics['dc'].default(False)
560
+
561
+ self.print('All further cells are interpreted as %RA.\n')
562
+ else:
563
+ self._magics['ra'].default(False)
492
564
 
565
+ def _dc_magic(self, silent: bool, state: MagicState):
493
566
  if silent:
494
567
  return
495
568
 
496
- if not code.strip():
569
+ if not state.code.strip():
497
570
  return
498
571
 
572
+ if state.db is None:
573
+ raise AssertionError('load a database first')
574
+
499
575
  # analyze tables
500
- tables = self._db.analyze()
576
+ tables = state.db.analyze()
501
577
 
502
578
  # parse dc input
503
- root_node = DCParser.parse_query(code)
579
+ root_node = DCParser.parse_query(state.code)
580
+ if root_node is None:
581
+ return
504
582
 
505
583
  # generate sql
506
- sql = root_node.to_sql(tables)
584
+ sql, cnm = root_node.to_sql_with_renamed_columns(tables)
585
+
586
+ state.code = sql
587
+ state.column_name_mapping.update(cnm)
588
+
589
+ def _all_dc_magic(self, silent: bool, state: MagicState, value: str):
590
+ if value.lower() in ('1', 'on', 'true'):
591
+ self._magics['dc'].default(True)
592
+ self._magics['ra'].default(False)
593
+
594
+ self.print('All further cells are interpreted as %DC.\n')
595
+ else:
596
+ self._magics['dc'].default(False)
597
+
598
+ def _guess_parser_magic(self, silent: bool, state: MagicState, value: str):
599
+ if value.lower() in ('1', 'on', 'true'):
600
+ self._magics['auto_parser'].default(True)
601
+ self.print('The correct parser is guessed for each subsequently executed cell.\n')
602
+ else:
603
+ self._magics['auto_parser'].default(False)
507
604
 
508
- return {
509
- 'generated_code': sql
605
+ def _auto_parser_magic(self, silent: bool, state: MagicState):
606
+ # do not handle statements starting with SQL keywords
607
+ first_word = state.code.strip().split(maxsplit=1)
608
+ if len(first_word) > 0:
609
+ if first_word[0].upper() in SQL_KEYWORDS:
610
+ return
611
+
612
+ # try to parse DC
613
+ try:
614
+ self._dc_magic(silent, state)
615
+ return
616
+ except ParserError as e:
617
+ if e.depth > 0:
618
+ raise e
619
+
620
+ # try to parse RA
621
+ try:
622
+ self._ra_magic(silent, state, analyze=False)
623
+ return
624
+ except ParserError as e:
625
+ if e.depth > 0:
626
+ raise e
627
+
628
+ def _plotly_magic(self, silent: bool, state: MagicState,
629
+ cols: List, rows: List[Tuple],
630
+ type: str, mapping: str, title: str = None):
631
+ # split mapping and handle asterisks
632
+ mapping = [m.strip() for m in mapping.split(',')]
633
+
634
+ for i in range(len(mapping)):
635
+ if mapping[i] == '*':
636
+ mapping = mapping[:i] + cols + mapping[i + 1:]
637
+
638
+ # convert all column names to lower case
639
+ lower_cols = [c.lower() for c in cols]
640
+ lower_mapping = [m.lower() for m in mapping]
641
+
642
+ # map desired columns to indices
643
+ mapped_indices = {}
644
+ for ok, lk in zip(mapping, lower_mapping):
645
+ for i in range(len(lower_cols)):
646
+ if lk == lower_cols[i]:
647
+ mapped_indices[ok] = i
648
+ break
649
+ else:
650
+ raise ValueError(f'unknown column {ok}')
651
+
652
+ # map desired columns to value lists
653
+ mapped_values = {
654
+ m: [r[i] for r in rows]
655
+ for m, i in mapped_indices.items()
510
656
  }
657
+ mapped_keys = iter(mapped_values.keys())
658
+
659
+ # get required chart type
660
+ match type.lower():
661
+ case 'scatter':
662
+ if len(lower_mapping) < 2: raise ValueError('scatter requires at least x and y values')
663
+ html = draw_scatter_chart(title,
664
+ mapped_values[next(mapped_keys)],
665
+ **{k: mapped_values[k] for k in mapped_keys})
666
+ case 'line':
667
+ if len(lower_mapping) < 2: raise ValueError('lines requires at least x and y values')
668
+ html = draw_line_chart(title,
669
+ mapped_values[next(mapped_keys)],
670
+ **{k: mapped_values[k] for k in mapped_keys})
671
+
672
+ case 'bar':
673
+ if len(lower_mapping) < 2: raise ValueError('bar requires at least x and y values')
674
+ html = draw_bar_chart(title,
675
+ mapped_values[next(mapped_keys)],
676
+ **{k: mapped_values[k] for k in mapped_keys})
677
+
678
+ case 'pie':
679
+ if len(lower_mapping) != 2: raise ValueError('pie requires labels and values')
680
+ html = draw_pie_chart(title,
681
+ mapped_values[next(mapped_keys)],
682
+ mapped_values[next(mapped_keys)])
683
+
684
+ case 'bubble':
685
+ if len(lower_mapping) != 4: raise ValueError('bubble requires x, y, size and color')
686
+ html = draw_bubble_chart(title,
687
+ mapped_values[next(mapped_keys)],
688
+ mapped_values[next(mapped_keys)],
689
+ mapped_values[next(mapped_keys)],
690
+ mapped_values[next(mapped_keys)])
691
+
692
+ case 'heatmap':
693
+ if len(lower_mapping) != 3: raise ValueError('heatmap requires x, y and z values')
694
+ html = draw_heatmap_chart(title,
695
+ mapped_values[next(mapped_keys)],
696
+ mapped_values[next(mapped_keys)],
697
+ mapped_values[next(mapped_keys)])
698
+
699
+ case _:
700
+ raise ValueError(f'unknown type: {type}')
701
+
702
+ # finally print the code
703
+ self.print_data(html, mime='text/html')
704
+
705
+ def _plotly_raw_magic(self, silent: bool, state: MagicState,
706
+ cols: List, rows: List[Tuple],
707
+ title: str = None):
708
+ if len(cols) != 1 and len(rows) != 1:
709
+ raise ValueError(f'expected exactly one column and one row')
710
+
711
+ self.print_data(
712
+ draw_chart(title, rows[0][0]),
713
+ mime='text/html'
714
+ )
511
715
 
512
716
  # jupyter related functions
513
717
  def do_execute(self, code: str, silent: bool,
@@ -515,26 +719,27 @@ class DuckDBKernel(Kernel):
515
719
  **kwargs):
516
720
  try:
517
721
  # get magic command
518
- clean_code, pre_query_callbacks, post_query_callbacks = self._magics(silent, code)
722
+ if len(self._db) > 0:
723
+ init_db = self._db[list(self._db.keys())[0]]
724
+ else:
725
+ init_db = None
519
726
 
520
- # execute magic commands here if it does not depend on query results
521
- execution_args = {
522
- 'max_rows': DuckDBKernel.DEFAULT_MAX_ROWS
523
- }
727
+ magic_state = MagicState(init_db, code, DuckDBKernel.DEFAULT_MAX_ROWS)
728
+ pre_query_callbacks, post_query_callbacks = self._magics(silent, magic_state)
524
729
 
730
+ # execute magic commands here if it does not depend on query results
525
731
  for callback in pre_query_callbacks:
526
- execution_args.update(callback())
527
-
528
- # overwrite clean_code with generated code
529
- if 'generated_code' in execution_args:
530
- clean_code = execution_args['generated_code']
531
- del execution_args['generated_code']
732
+ callback()
532
733
 
533
734
  # execute statement if needed
534
- if clean_code.strip():
535
- cols, rows = self._execute_stmt(clean_code, silent, **execution_args)
536
- else:
537
- cols, rows = None, None
735
+ cols, rows = None, None
736
+
737
+ if not isinstance(magic_state.code, dict):
738
+ magic_state.code = {'default': magic_state.code}
739
+
740
+ for name, code in reversed(magic_state.code.items()):
741
+ if code.strip():
742
+ cols, rows = self._execute_stmt(silent, magic_state, name, code)
538
743
 
539
744
  # execute magic command here if it does depend on query results
540
745
  for callback in post_query_callbacks:
@@ -558,5 +763,7 @@ class DuckDBKernel(Kernel):
558
763
  }
559
764
 
560
765
  def do_shutdown(self, restart):
561
- self._unload_database()
766
+ for name in list(self._db.keys()):
767
+ self._unload_database(name)
768
+
562
769
  return super().do_shutdown(restart)