jupyter-duckdb 1.2.100__py3-none-any.whl → 1.2.102__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
duckdb_kernel/kernel.py CHANGED
@@ -10,6 +10,7 @@ from typing import Optional, Dict, List, Tuple
10
10
  from ipykernel.kernelbase import Kernel
11
11
 
12
12
  from .db import Connection, DatabaseError, Table
13
+ from .visualization.lib import *
13
14
  from .db.error import *
14
15
  from .magics import *
15
16
  from .parser import RAParser, DCParser, ParserError
@@ -54,6 +55,8 @@ class DuckDBKernel(Kernel):
54
55
  MagicCommand('all_dc').arg('value', '1').on(self._all_dc_magic),
55
56
  MagicCommand('auto_parser').disable('sql', 'ra', 'dc').code(True).on(self._auto_parser_magic),
56
57
  MagicCommand('guess_parser').arg('value', '1').on(self._guess_parser_magic),
58
+ MagicCommand('plotly').arg('type').arg('mapping').opt('title').result(True).on(self._plotly_magic),
59
+ MagicCommand('plotly_raw').opt('title').result(True).on(self._plotly_raw_magic)
57
60
  )
58
61
 
59
62
  # create placeholders for database and tests
@@ -116,6 +119,7 @@ class DuckDBKernel(Kernel):
116
119
  self._db = Postgres(host, port, username, password, database_name)
117
120
  except ImportError:
118
121
  self.print('psycopg could not be found', name='stderr')
122
+ return False
119
123
 
120
124
  # Otherwise the provided path is used to create an
121
125
  # in-process instance.
@@ -145,7 +149,7 @@ class DuckDBKernel(Kernel):
145
149
  else:
146
150
  return False
147
151
 
148
- def _execute_stmt(self, query: str, silent: bool,
152
+ def _execute_stmt(self, name: str, query: str, silent: bool,
149
153
  column_name_mapping: Dict[str, str],
150
154
  max_rows: Optional[int]) -> Tuple[Optional[List[str]], Optional[List[List]]]:
151
155
  if self._db is None:
@@ -161,52 +165,57 @@ class DuckDBKernel(Kernel):
161
165
 
162
166
  et = time.time()
163
167
 
164
- # return result if silent
165
- if silent:
166
- return columns, rows
168
+ # print result if not silent
169
+ if not silent:
170
+ # print EXPLAIN queries as raw text if using DuckDB
171
+ if query.strip().startswith('EXPLAIN') and self._db.plain_explain():
172
+ for ekey, evalue in rows:
173
+ html = f'<b>{ekey}</b><br><pre>{evalue}</pre>'
174
+ break
167
175
 
168
- # print EXPLAIN queries as raw text if using DuckDB
169
- if query.strip().startswith('EXPLAIN') and self._db.plain_explain():
170
- for ekey, evalue in rows:
171
- self.print_data(f'<b>{ekey}</b><br><pre>{evalue}</pre>')
176
+ return None, None
172
177
 
173
- return None, None
178
+ # print every other query as a table
179
+ else:
180
+ if columns is not None:
181
+ # table header
182
+ mapped_columns = (column_name_mapping.get(c, c) for c in columns)
183
+ table_header = ''.join(f'<th>{c}</th>' for c in mapped_columns)
184
+
185
+ # table data
186
+ if max_rows is not None and len(rows) > max_rows:
187
+ table_data = f'''
188
+ {rows_table(rows[:math.ceil(max_rows / 2)])}
189
+ <tr>
190
+ <td colspan="{len(columns)}"
191
+ style="text-align: center"
192
+ title="{row_count(len(rows) - max_rows)} omitted">
193
+ ...
194
+ </td>
195
+ </tr>
196
+ {rows_table(rows[-math.floor(max_rows // 2):])}
197
+ '''
198
+ else:
199
+ table_data = rows_table(rows)
200
+
201
+ # send to client
202
+ html = (f'''
203
+ <table class="duckdb-query-result-table">
204
+ {table_header}
205
+ {table_data}
206
+ </table>
207
+
208
+ {row_count(len(rows))} in {et - st:.3f}s
209
+ ''')
174
210
 
175
- # print every other query as a table
176
- else:
177
- if columns is not None:
178
- # table header
179
- mapped_columns = (column_name_mapping.get(c, c) for c in columns)
180
- table_header = ''.join(f'<th>{c}</th>' for c in mapped_columns)
181
-
182
- # table data
183
- if max_rows is not None and len(rows) > max_rows:
184
- table_data = f'''
185
- {rows_table(rows[:math.ceil(max_rows / 2)])}
186
- <tr>
187
- <td colspan="{len(columns)}"
188
- style="text-align: center"
189
- title="{row_count(len(rows) - max_rows)} omitted">
190
- ...
191
- </td>
192
- </tr>
193
- {rows_table(rows[-math.floor(max_rows // 2):])}
194
- '''
195
211
  else:
196
- table_data = rows_table(rows)
197
-
198
- # send to client
199
- self.print_data(f'''
200
- <table class="duckdb-query-result">
201
- {table_header}
202
- {table_data}
203
- </table>
204
- ''')
205
-
206
- self.print_data(f'{row_count(len(rows))} in {et - st:.3f}s')
212
+ html = f'statement executed without result in {et - st:.3f}s'
207
213
 
208
- else:
209
- self.print_data(f'statement executed without result in {et - st:.3f}s')
214
+ self.print_data(f'''
215
+ <div class="duckdb-query-result {name}">
216
+ {html}
217
+ </div>
218
+ ''')
210
219
 
211
220
  return columns, rows
212
221
 
@@ -496,16 +505,23 @@ class DuckDBKernel(Kernel):
496
505
  # create and show visualization
497
506
  if analyze:
498
507
  vd = RATreeDrawer(self._db, root_node, tables)
499
- svg = vd.to_svg(True)
500
-
501
- self.print_data(svg)
508
+ svg = vd.to_interactive_svg()
509
+ data = {
510
+ 'generated_code': {
511
+ node_id: node.to_sql_with_renamed_columns(tables)
512
+ for node_id, node in vd.nodes.items()
513
+ }
514
+ }
502
515
 
503
- # generate sql
504
- sql = root_node.to_sql_with_renamed_columns(tables)
516
+ else:
517
+ svg = ''
518
+ data = {
519
+ 'generated_code': root_node.to_sql_with_renamed_columns(tables)
520
+ }
505
521
 
506
- return {
507
- 'generated_code': sql
508
- }
522
+ # return data
523
+ self.print_data(svg)
524
+ return data
509
525
 
510
526
  def _all_ra_magic(self, silent: bool, value: str):
511
527
  if value.lower() in ('1', 'on', 'true'):
@@ -577,6 +593,90 @@ class DuckDBKernel(Kernel):
577
593
  if e.depth > 0:
578
594
  raise e
579
595
 
596
+ def _plotly_magic(self, silent: bool, cols: List, rows: List[Tuple], type: str, mapping: str, title: str = None):
597
+ # split mapping and handle asterisks
598
+ mapping = [m.strip() for m in mapping.split(',')]
599
+
600
+ for i in range(len(mapping)):
601
+ if mapping[i] == '*':
602
+ mapping = mapping[:i] + cols + mapping[i+1:]
603
+
604
+ # convert all column names to lower case
605
+ lower_cols = [c.lower() for c in cols]
606
+ lower_mapping = [m.lower() for m in mapping]
607
+
608
+ # map desired columns to indices
609
+ mapped_indices = {}
610
+ for ok, lk in zip(mapping, lower_mapping):
611
+ for i in range(len(lower_cols)):
612
+ if lk == lower_cols[i]:
613
+ mapped_indices[ok] = i
614
+ break
615
+ else:
616
+ raise ValueError(f'unknown column {ok}')
617
+
618
+ # map desired columns to value lists
619
+ mapped_values = {
620
+ m: [r[i] for r in rows]
621
+ for m, i in mapped_indices.items()
622
+ }
623
+ mapped_keys = iter(mapped_values.keys())
624
+
625
+ # get required chart type
626
+ match type.lower():
627
+ case 'scatter':
628
+ if len(lower_mapping) < 2: raise ValueError('scatter requires at least x and y values')
629
+ html = draw_scatter_chart(title,
630
+ mapped_values[next(mapped_keys)],
631
+ **{k: mapped_values[k] for k in mapped_keys})
632
+ case 'line':
633
+ if len(lower_mapping) < 2: raise ValueError('lines requires at least x and y values')
634
+ html = draw_line_chart(title,
635
+ mapped_values[next(mapped_keys)],
636
+ **{k: mapped_values[k] for k in mapped_keys})
637
+
638
+ case 'bar':
639
+ if len(lower_mapping) < 2: raise ValueError('bar requires at least x and y values')
640
+ html = draw_bar_chart(title,
641
+ mapped_values[next(mapped_keys)],
642
+ **{k: mapped_values[k] for k in mapped_keys})
643
+
644
+ case 'pie':
645
+ if len(lower_mapping) != 2: raise ValueError('pie requires labels and values')
646
+ html = draw_pie_chart(title,
647
+ mapped_values[next(mapped_keys)],
648
+ mapped_values[next(mapped_keys)])
649
+
650
+ case 'bubble':
651
+ if len(lower_mapping) != 4: raise ValueError('bubble requires x, y, size and color')
652
+ html = draw_bubble_chart(title,
653
+ mapped_values[next(mapped_keys)],
654
+ mapped_values[next(mapped_keys)],
655
+ mapped_values[next(mapped_keys)],
656
+ mapped_values[next(mapped_keys)])
657
+
658
+ case 'heatmap':
659
+ if len(lower_mapping) != 3: raise ValueError('heatmap requires x, y and z values')
660
+ html = draw_heatmap_chart(title,
661
+ mapped_values[next(mapped_keys)],
662
+ mapped_values[next(mapped_keys)],
663
+ mapped_values[next(mapped_keys)])
664
+
665
+ case _:
666
+ raise ValueError(f'unknown type: {type}')
667
+
668
+ # finally print the code
669
+ self.print_data(html, mime='text/html')
670
+
671
+ def _plotly_raw_magic(self, silent: bool, cols: List, rows: List[Tuple], title: str = None):
672
+ if len(cols) != 1 and len(rows) != 1:
673
+ raise ValueError(f'expected exactly one column and one row')
674
+
675
+ self.print_data(
676
+ draw_chart(title, rows[0][0]),
677
+ mime='text/html'
678
+ )
679
+
580
680
  # jupyter related functions
581
681
  def do_execute(self, code: str, silent: bool,
582
682
  store_history: bool = True, user_expressions: dict = None, allow_stdin: bool = False,
@@ -603,10 +703,14 @@ class DuckDBKernel(Kernel):
603
703
  execution_args['column_name_mapping'] = {}
604
704
 
605
705
  # execute statement if needed
606
- if clean_code.strip():
607
- cols, rows = self._execute_stmt(clean_code, silent, **execution_args)
608
- else:
609
- cols, rows = None, None
706
+ cols, rows = None, None
707
+
708
+ if not isinstance(clean_code, dict):
709
+ clean_code = {'default': clean_code}
710
+
711
+ for name, code in clean_code.items():
712
+ if code.strip():
713
+ cols, rows = self._execute_stmt(name, code, silent, **execution_args)
610
714
 
611
715
  # execute magic command here if it does depend on query results
612
716
  for callback in post_query_callbacks:
@@ -1,5 +1,5 @@
1
1
  import re
2
- from typing import Dict, Tuple, List, Optional
2
+ from typing import Dict, Tuple, List
3
3
 
4
4
  from . import MagicCommand, MagicCommandException, MagicCommandCallback
5
5
  from .StringWrapper import StringWrapper
@@ -62,6 +62,11 @@ class MagicCommandHandler:
62
62
  args = [group if group is not None else default
63
63
  for group, (_, default, _) in zip(match.groups(), magic.args)]
64
64
 
65
+ args = [arg[1:-1]
66
+ if arg is not None and (arg[0] == '"' and arg[-1] == '"' or arg[0] == "'" and arg[-1] == "'")
67
+ else arg
68
+ for arg in args]
69
+
65
70
  if any(arg is None for arg in args):
66
71
  raise MagicCommandException(f'could not parse parameters for command "{command}"')
67
72
 
@@ -87,6 +92,9 @@ class MagicCommandHandler:
87
92
  value = match.group(i + 2)
88
93
  i += 3
89
94
 
95
+ if value is not None and (value[0] == '"' and value[-1] == '"' or value[0] == "'" and value[-1] == "'"):
96
+ value = value[1:-1]
97
+
90
98
  if name is not None:
91
99
  optionals[name.lower()] = value
92
100
 
@@ -0,0 +1,144 @@
1
+ import json
2
+ from decimal import Decimal
3
+ from typing import Dict, List, Optional
4
+ from uuid import uuid4
5
+
6
+ from .lib import init_plotly
7
+
8
+
9
+ def __div_id() -> str:
10
+ return f'div-{str(uuid4())}'
11
+
12
+
13
+ def __layout(title: Optional[str]):
14
+ layout = {
15
+ 'dragmode': False,
16
+ 'xaxis': {
17
+ 'rangeselector': {
18
+ 'visible': False
19
+ }
20
+ }
21
+ }
22
+
23
+ if title is not None:
24
+ layout['title'] = {
25
+ 'text': title,
26
+ 'font': {
27
+ 'family': 'sans-serif',
28
+ 'size': 32,
29
+ 'color': 'rgb(0, 0, 0)'
30
+ },
31
+ 'xanchor': 'center'
32
+ }
33
+
34
+ return layout
35
+
36
+
37
+ def __config():
38
+ return {
39
+ 'displayModeBar': False,
40
+ 'scrollZoom': False
41
+ }
42
+
43
+
44
+ def __fix_decimal(x: List):
45
+ return [float(x) if isinstance(x, Decimal) else x
46
+ for x in x]
47
+
48
+
49
def draw_chart(title: Optional[str], traces: List[Dict] | Dict | str) -> str:
    """Return an HTML snippet that renders *traces* with ``Plotly.newPlot``.

    :param title: optional chart title, forwarded to the layout builder.
    :param traces: trace dict(s) to be JSON-encoded, or an already
        serialised string that is inserted into the script verbatim.
    :return: HTML consisting of a script with whatever ``init_plotly()``
        returned, an empty ``div`` with a random id, and a script that
        plots into that ``div``.
    """

    def script_json(value) -> str:
        # A "</" inside a JSON string (e.g. a title containing "</script>")
        # would end the surrounding <script> element early. "<\/" is an
        # equivalent JSON/JavaScript escape that is inert for the HTML parser.
        return json.dumps(value).replace('</', '<\\/')

    init = init_plotly()
    div_id = __div_id()
    layout = script_json(__layout(title))
    config = script_json(__config())

    # NOTE(review): pre-serialised traces are inserted as-is; they are
    # presumably authored by the notebook user — confirm before feeding
    # untrusted strings through this path.
    if not isinstance(traces, str):
        traces = script_json(traces)

    return f'''
        <script type="text/javascript">
            {init}
        </script>

        <div id="{div_id}"></div>
        <script type="text/javascript">
            Plotly.newPlot('{div_id}', {traces}, {layout}, {config});
        </script>
    '''
68
+
69
+
70
def draw_scatter_chart(title: Optional[str], x, **ys) -> str:
    """Draw one marker-only scatter trace per keyword series over shared *x*.

    Each keyword name becomes the trace name; its value supplies the y data.
    """
    traces = []
    for name, y in ys.items():
        traces.append({
            'x': __fix_decimal(x),
            'y': __fix_decimal(y),
            'mode': 'markers',
            'type': 'scatter',
            'name': name,
        })
    return draw_chart(title, traces)
81
+
82
+
83
def draw_line_chart(title: Optional[str], x, **ys) -> str:
    """Draw one ``lines+markers`` trace per keyword series over shared *x*.

    Each keyword name becomes the trace name; its value supplies the y data.
    """
    traces = []
    for name, y in ys.items():
        traces.append({
            'x': __fix_decimal(x),
            'y': __fix_decimal(y),
            'mode': 'lines+markers',
            'name': name,
        })
    return draw_chart(title, traces)
93
+
94
+
95
def draw_bar_chart(title: Optional[str], x, **ys) -> str:
    """Draw one bar trace per keyword series over shared *x*.

    Each keyword name becomes the trace name; its value supplies the y data.
    """
    traces = []
    for name, y in ys.items():
        traces.append({
            'x': __fix_decimal(x),
            'y': __fix_decimal(y),
            'type': 'bar',
            'name': name,
        })
    return draw_chart(title, traces)
105
+
106
+
107
def draw_pie_chart(title: Optional[str], x, y) -> str:
    """Draw a single pie trace with *x* as labels and *y* as values."""
    trace = {
        'values': __fix_decimal(y),
        'labels': __fix_decimal(x),
        'type': 'pie',
    }
    return draw_chart(title, [trace])
113
+
114
+
115
def draw_bubble_chart(title: Optional[str], x, y, s, c) -> str:
    """Draw a single marker trace; *s* sets marker sizes, *c* marker colors."""
    marker = {
        'size': __fix_decimal(s),
        'color': __fix_decimal(c),
    }
    # built after the marker so the conversion order stays x, y, s, c -> see below
    trace = {
        'x': __fix_decimal(x),
        'y': __fix_decimal(y),
        'mode': 'markers',
        'marker': marker,
    }
    return draw_chart(title, [trace])
125
+
126
+
127
def draw_heatmap_chart(title: Optional[str], x, y, z) -> str:
    """Draw a single heatmap trace.

    Only the first element of each argument is read: ``x[0]`` and ``y[0]``
    are used as the axis value lists and ``z[0]`` as a list of rows.
    NOTE(review): presumably each argument is a one-row column holding a
    list value — confirm against the caller.
    """
    trace = {
        'x': __fix_decimal(x[0]),
        'y': __fix_decimal(y[0]),
        'z': [__fix_decimal(row) for row in z[0]],
        'type': 'heatmap',
    }
    return draw_chart(title, [trace])
134
+
135
+
136
# Public chart builders exported by a star import of this module.
__all__ = [
    'draw_chart',
    'draw_scatter_chart', 'draw_line_chart', 'draw_bar_chart',
    'draw_pie_chart', 'draw_bubble_chart', 'draw_heatmap_chart',
]
@@ -1,12 +1,14 @@
1
- from typing import Dict
1
+ from typing import Dict, Optional
2
2
 
3
3
  from graphviz import Digraph
4
+ from uuid import uuid4
4
5
 
5
6
  from duckdb_kernel.db import Table
6
7
  from duckdb_kernel.parser.elements import RAElement
7
8
  from duckdb_kernel.util.formatting import row_count
8
9
  from .Drawer import Drawer
9
10
  from ..db import Connection
11
+ from .lib import *
10
12
 
11
13
 
12
14
  class RATreeDrawer(Drawer):
@@ -15,6 +17,9 @@ class RATreeDrawer(Drawer):
15
17
  self.root_node: RAElement = root_node
16
18
  self.tables: Dict[str, Table] = tables
17
19
 
20
+ self.nodes: Dict[str, RAElement] = {}
21
+ self.root_node_id: Optional[str] = None
22
+
18
23
  def to_graph(self) -> Digraph:
19
24
  # create graph
20
25
  ps = Digraph('Schema',
@@ -31,7 +36,11 @@ class RATreeDrawer(Drawer):
31
36
 
32
37
  def __add_node(self, ps: Digraph, node: RAElement) -> str:
33
38
  # use id of node object as identifier
34
- node_id = f'node_{id(node)}'
39
+ node_id = f'node_{str(uuid4()).replace("-", "_")}'
40
+
41
+ self.nodes[node_id] = node
42
+ if node == self.root_node:
43
+ self.root_node_id = node_id
35
44
 
36
45
  # generate child nodes
37
46
  child_ids = [self.__add_node(ps, child) for child in node.children]
@@ -69,3 +78,26 @@ class RATreeDrawer(Drawer):
69
78
 
70
79
  # return node identifier to generate edges
71
80
  return node_id
81
+
82
def to_interactive_svg(self) -> str:
    """Return the SVG wrapped in HTML with the styles and script that animate it.

    The snippet contains the result of ``init_css()`` in a ``style``
    element, the SVG inside a ``div`` with a random id, and a script that
    holds the result of ``init_ra()`` followed by an ``animate_ra`` call
    for that ``div`` and ``self.root_node_id``.
    """
    container_id = 'div-' + str(uuid4())

    stylesheet = init_css()
    script = init_ra()
    # Rendered before the template is formatted; root_node_id is presumably
    # assigned while the graph is built during this call — TODO confirm.
    drawing = self.to_svg(True)

    return f'''
        <style type="text/css">
            {stylesheet}
        </style>

        <div id="{container_id}">
            {drawing}
        </div>

        <script type="text/javascript">
            {script}

            animate_ra('{container_id}', '{self.root_node_id}')
        </script>
    '''
@@ -1,2 +1,3 @@
1
+ from .Plotly import *
1
2
  from .RATreeDrawer import RATreeDrawer
2
3
  from .SchemaDrawer import SchemaDrawer
@@ -0,0 +1,53 @@
1
"""Loaders for the static CSS/JS assets bundled next to this module.

Each ``init_*`` function returns the content of its asset on the first
successful call and an empty string on every later call.
"""
import os

__CSS_INITIALIZED = False
__RA_INITIALIZED = False
__PLOTLY_INITIALIZED = False

__location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))


def __read_asset(filename: str) -> str:
    """Return the text of the bundled asset *filename*."""
    # Decode explicitly as UTF-8: without it open() falls back to the
    # locale's preferred encoding, which can fail or garble non-ASCII
    # characters in the assets on some platforms.
    with open(os.path.join(__location, filename), encoding='utf-8') as asset_file:
        return asset_file.read()


def init_css() -> str:
    """Return the content of ``ra.css`` once, then ``''`` on later calls."""
    global __CSS_INITIALIZED

    if __CSS_INITIALIZED:
        return ''

    css = __read_asset('ra.css')
    # flag is set only after a successful read
    __CSS_INITIALIZED = True
    return css


def init_ra() -> str:
    """Return the content of ``ra.js`` once, then ``''`` on later calls."""
    global __RA_INITIALIZED

    if __RA_INITIALIZED:
        return ''

    ra = __read_asset('ra.js')
    # flag is set only after a successful read
    __RA_INITIALIZED = True
    return ra


def init_plotly() -> str:
    """Return the content of ``plotly-3.0.1.min.js`` once, then ``''``."""
    global __PLOTLY_INITIALIZED

    if __PLOTLY_INITIALIZED:
        return ''

    plotly = __read_asset('plotly-3.0.1.min.js')
    # flag is set only after a successful read
    __PLOTLY_INITIALIZED = True
    return plotly


__all__ = [
    'init_css',
    'init_ra',
    'init_plotly',
]