corr_vars_widget 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,241 @@
1
+ # Modification of
2
+ # https://github.com/manzt/quak/blob/main/src/quak/_widget.py
3
+ # https://github.com/manzt/quak/blob/main/src/quak/_util.py
4
+
5
+ # MIT License
6
+
7
+ # Copyright (c) 2024 Trevor Manz
8
+
9
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
10
+ # of this software and associated documentation files (the "Software"), to deal
11
+ # in the Software without restriction, including without limitation the rights
12
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13
+ # copies of the Software, and to permit persons to whom the Software is
14
+ # furnished to do so, subject to the following conditions:
15
+
16
+ # The above copyright notice and this permission notice shall be included in all
17
+ # copies or substantial portions of the Software.
18
+
19
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25
+ # SOFTWARE.
26
+
27
+ from __future__ import annotations
28
+
29
+ import io
30
+ import logging
31
+ import pathlib
32
+ import time
33
+
34
+ import anywidget
35
+ import duckdb
36
+ import polars as pl
37
+ import pyarrow as pa
38
+ import pyarrow.feather as feather
39
+ import traitlets
40
+
41
+ import pathlib
42
+
43
+ import anywidget
44
+ import traitlets
45
+
46
# Directory named "static" next to this module; the widgets below load their
# JavaScript and CSS files from sub-folders of it.
bundler_assets_dir = pathlib.Path(__file__).parent.joinpath("static")

# Module-level logger. A NullHandler is attached so that records emitted here
# have a handler even when the application configures no logging.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Queries taking longer than this many milliseconds are logged as warnings
# rather than at info level.
SLOW_QUERY_THRESHOLD = 5_000
52
+
53
+
54
def table_to_ipc(
    table: pa.lib.Table | pa.lib.RecordBatch | pa.lib.RecordBatchReader,
) -> memoryview:
    """Convert Arrow tabular data to an Arrow IPC message.

    Args:
        table: A pyarrow ``Table``, ``RecordBatch`` or ``RecordBatchReader``.
            A reader is drained with ``read_all()`` and a single batch is
            wrapped in a one-batch ``Table`` before serialization.

    Returns:
        A memoryview over the bytes produced by ``feather.write_feather``
        with compression disabled.

    Raises:
        TypeError: If ``table`` is none of the supported pyarrow types.
    """
    if isinstance(table, pa.RecordBatchReader):
        table = table.read_all()
    elif isinstance(table, pa.RecordBatch):
        table = pa.Table.from_batches([table], schema=table.schema)
    elif not isinstance(table, pa.Table):
        # The two adjacent literals are concatenated into ONE message string.
        # Passing them as separate arguments would make the exception render
        # as a tuple instead of a readable sentence.
        raise TypeError(
            "Expected a pyarrow Table, RecordBatch, or RecordBatchReader, "
            f"got {type(table)!r}"
        )

    sink = io.BytesIO()
    feather.write_feather(table, sink, compression="uncompressed")
    return sink.getbuffer()
71
+
72
+
73
class ObsWidget(anywidget.AnyWidget):
    """An anywidget for displaying obs data in a table."""

    # Front-end module and stylesheet shipped in the package's static folder.
    _esm = bundler_assets_dir / "obs" / "obs.js"
    _css = bundler_assets_dir / "obs" / "main.css"

    # Observation level label; set to "" when none is given to __init__.
    _obs_level = traitlets.Unicode().tag(sync=True)

    # Name of the registered DuckDB table and the DataFrame's column names.
    _table_name = traitlets.Unicode().tag(sync=True)
    _columns = traitlets.List(traitlets.Unicode()).tag(sync=True)

    # The SQL query for the current data (read-only)
    sql = traitlets.Unicode().tag(sync=True)

    def __init__(self, data: pl.DataFrame, obs_level: str | None = None) -> None:
        """
        Initialize the ObsWidget.

        Args:
            data: A Polars DataFrame containing the observation data.
            obs_level: An optional string representing the observation level
                (e.g., 'Hospital', 'ICU stay').
        """
        table = "obs"

        conn = duckdb.connect(":memory:")
        # FIXME: special case pl.DataFrame for now until DuckDB
        # supports `[string,bytes]_view` Arrow data types
        # see: https://github.com/manzt/quak/issues/41
        # Polars .to_arrow() method will cast to non-view array types for us
        arrow_table = data.to_arrow()
        conn.register(table, arrow_table)
        self._conn = conn
        super().__init__(
            _obs_level=obs_level or "",
            _table_name=table,
            _columns=data.columns,
            sql=f'SELECT * FROM "{table}"',
        )
        self.on_msg(self._handle_custom_msg)

    def _handle_custom_msg(self, data: dict, buffers: list) -> None:
        """Run the query requested by a front-end message and send the reply.

        Args:
            data: Message payload. Expected keys are ``uuid`` (echoed back in
                the reply), ``sql`` (the statement to run) and ``type``
                (``"arrow"``, ``"exec"`` or ``"json"``).
            buffers: Binary buffers attached to the message; only logged.

        Any exception raised while handling the request is logged and
        reported to the front end as ``{"error": ..., "uuid": ...}``.
        """
        logger.debug("data=%r, buffers=%r", data, buffers)

        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start = time.perf_counter()

        # Use .get() so that a message missing a key ends up in the error
        # reply below instead of raising KeyError out of the handler.
        uuid = data.get("uuid")
        sql = data.get("sql")
        command = data.get("type")
        try:
            if command == "arrow":
                result = self._conn.query(sql).arrow()
                buf = table_to_ipc(result)
                self.send({"type": "arrow", "uuid": uuid}, buffers=[buf])
            elif command == "exec":
                self._conn.execute(sql)
                self.send({"type": "exec", "uuid": uuid})
            elif command == "json":
                result = self._conn.query(sql).df()
                records = result.to_dict(orient="records")
                self.send({"type": "json", "uuid": uuid, "result": records})
            else:
                raise ValueError(f"Unknown command {command}")
        except Exception as e:
            logger.exception("Error processing query")
            self.send({"error": str(e), "uuid": uuid})

        total = round((time.perf_counter() - start) * 1_000)
        if total > SLOW_QUERY_THRESHOLD:
            logger.warning("DONE. Slow query %s took %s ms.\n%s", uuid, total, sql)
        else:
            logger.info("DONE. Query %s took %s ms.\n%s", uuid, total, sql)

    def data(self) -> duckdb.DuckDBPyRelation:
        """Return the current SQL as a DuckDB relation."""
        return self._conn.query(self.sql)
149
+
150
+
151
class ObsmWidget(anywidget.AnyWidget):
    """An anywidget for displaying obsm data in a table."""

    # Front-end module and stylesheet shipped in the package's static folder.
    _esm = bundler_assets_dir / "obsm" / "obsm.js"
    _css = bundler_assets_dir / "obsm" / "main.css"

    # One entry per registered table: its name, column names and current SQL.
    _tables = traitlets.List(
        traitlets.Dict(
            per_key_traits={
                "_table_name": traitlets.Unicode(),
                "_columns": traitlets.List(traitlets.Unicode()),
                "sql": traitlets.Unicode(),
            }
        )
    ).tag(sync=True)

    def __init__(self, data: dict[str, pl.DataFrame]) -> None:
        """
        Initialize the ObsmWidget.

        Args:
            data: A dictionary mapping table names to Polars DataFrames.
        """

        conn = duckdb.connect(":memory:")
        tables = []
        # Distinct loop names keep the ``data`` argument from being rebound
        # while its items are being iterated.
        for name, frame in data.items():
            # FIXME: special case pl.DataFrame for now until DuckDB
            # supports `[string,bytes]_view` Arrow data types
            # see: https://github.com/manzt/quak/issues/41
            # Polars .to_arrow() method will cast to non-view array types for us
            arrow_table = frame.to_arrow()
            conn.register(name, arrow_table)
            # Inside a double-quoted SQL identifier an embedded `"` must be
            # doubled; names without quotes produce the same SQL as before.
            quoted = name.replace('"', '""')
            tables.append(
                {
                    "_table_name": name,
                    "_columns": frame.columns,
                    "sql": f'SELECT * FROM "{quoted}"',
                }
            )
        self._conn = conn
        super().__init__(_tables=tables)
        self.on_msg(self._handle_custom_msg)

    def _handle_custom_msg(self, data: dict, buffers: list) -> None:
        """Run the query requested by a front-end message and send the reply.

        Args:
            data: Message payload. Expected keys are ``uuid`` (echoed back in
                the reply), ``sql`` (the statement to run) and ``type``
                (``"arrow"``, ``"exec"`` or ``"json"``).
            buffers: Binary buffers attached to the message; only logged.

        Any exception raised while handling the request is logged and
        reported to the front end as ``{"error": ..., "uuid": ...}``.
        """
        logger.debug("data=%r, buffers=%r", data, buffers)

        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start = time.perf_counter()

        # Use .get() so that a message missing a key ends up in the error
        # reply below instead of raising KeyError out of the handler.
        uuid = data.get("uuid")
        sql = data.get("sql")
        command = data.get("type")
        try:
            if command == "arrow":
                result = self._conn.query(sql).arrow()
                buf = table_to_ipc(result)
                self.send({"type": "arrow", "uuid": uuid}, buffers=[buf])
            elif command == "exec":
                self._conn.execute(sql)
                self.send({"type": "exec", "uuid": uuid})
            elif command == "json":
                result = self._conn.query(sql).df()
                records = result.to_dict(orient="records")
                self.send({"type": "json", "uuid": uuid, "result": records})
            else:
                raise ValueError(f"Unknown command {command}")
        except Exception as e:
            logger.exception("Error processing query")
            self.send({"error": str(e), "uuid": uuid})

        total = round((time.perf_counter() - start) * 1_000)
        if total > SLOW_QUERY_THRESHOLD:
            logger.warning("DONE. Slow query %s took %s ms.\n%s", uuid, total, sql)
        else:
            logger.info("DONE. Query %s took %s ms.\n%s", uuid, total, sql)

    @property
    def sql(self) -> dict[str, str]:
        """Return the current SQL string of each table, keyed by table name."""
        return {table["_table_name"]: table["sql"] for table in self._tables}

    @property
    def data(self) -> dict[str, duckdb.DuckDBPyRelation]:
        """Return a DuckDB relation for each table's current SQL, keyed by table name."""
        return {
            table["_table_name"]: self._conn.query(table["sql"])
            for table in self._tables
        }
239
+
240
+
241
# Names exported by `from <module> import *`.
__all__ = [
    "ObsWidget",
    "ObsmWidget",
]