real-ladybug 0.0.1.dev1__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of real-ladybug has been flagged as potentially problematic. Click here for more details.
- real_ladybug/__init__.py +83 -0
- real_ladybug/_lbug.cp312-win_amd64.pyd +0 -0
- real_ladybug/_lbug.exp +0 -0
- real_ladybug/_lbug.lib +0 -0
- real_ladybug/async_connection.py +226 -0
- real_ladybug/connection.py +323 -0
- real_ladybug/constants.py +7 -0
- real_ladybug/database.py +307 -0
- real_ladybug/prepared_statement.py +51 -0
- real_ladybug/py.typed +0 -0
- real_ladybug/query_result.py +511 -0
- real_ladybug/torch_geometric_feature_store.py +185 -0
- real_ladybug/torch_geometric_graph_store.py +131 -0
- real_ladybug/torch_geometric_result_converter.py +282 -0
- real_ladybug/types.py +39 -0
- real_ladybug-0.0.1.dev1.dist-info/METADATA +88 -0
- real_ladybug-0.0.1.dev1.dist-info/RECORD +114 -0
- real_ladybug-0.0.1.dev1.dist-info/WHEEL +5 -0
- real_ladybug-0.0.1.dev1.dist-info/licenses/LICENSE +21 -0
- real_ladybug-0.0.1.dev1.dist-info/top_level.txt +3 -0
- real_ladybug-0.0.1.dev1.dist-info/zip-safe +1 -0
- real_ladybug-source/scripts/antlr4/hash.py +2 -0
- real_ladybug-source/scripts/antlr4/keywordhandler.py +47 -0
- real_ladybug-source/scripts/collect-extensions.py +68 -0
- real_ladybug-source/scripts/collect-single-file-header.py +126 -0
- real_ladybug-source/scripts/export-dbs.py +101 -0
- real_ladybug-source/scripts/export-import-test.py +345 -0
- real_ladybug-source/scripts/extension/purge-beta.py +34 -0
- real_ladybug-source/scripts/generate-cpp-docs/collect_files.py +122 -0
- real_ladybug-source/scripts/generate-tinysnb.py +34 -0
- real_ladybug-source/scripts/get-clangd-diagnostics.py +233 -0
- real_ladybug-source/scripts/migrate-lbug-db.py +308 -0
- real_ladybug-source/scripts/multiplatform-test-helper/collect-results.py +71 -0
- real_ladybug-source/scripts/multiplatform-test-helper/notify-discord.py +68 -0
- real_ladybug-source/scripts/pip-package/package_tar.py +90 -0
- real_ladybug-source/scripts/pip-package/setup.py +130 -0
- real_ladybug-source/scripts/run-clang-format.py +408 -0
- real_ladybug-source/scripts/setup-extension-repo.py +67 -0
- real_ladybug-source/scripts/test-simsimd-dispatch.py +45 -0
- real_ladybug-source/scripts/update-nightly-build-version.py +81 -0
- real_ladybug-source/third_party/brotli/scripts/dictionary/step-01-download-rfc.py +16 -0
- real_ladybug-source/third_party/brotli/scripts/dictionary/step-02-rfc-to-bin.py +34 -0
- real_ladybug-source/third_party/brotli/scripts/dictionary/step-03-validate-bin.py +35 -0
- real_ladybug-source/third_party/brotli/scripts/dictionary/step-04-generate-java-literals.py +85 -0
- real_ladybug-source/third_party/pybind11/tools/codespell_ignore_lines_from_errors.py +35 -0
- real_ladybug-source/third_party/pybind11/tools/libsize.py +36 -0
- real_ladybug-source/third_party/pybind11/tools/make_changelog.py +63 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/__init__.py +83 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/async_connection.py +226 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/connection.py +323 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/constants.py +7 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/database.py +307 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/prepared_statement.py +51 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/py.typed +0 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/query_result.py +511 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_feature_store.py +185 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_graph_store.py +131 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_result_converter.py +282 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/types.py +39 -0
- real_ladybug-source/tools/python_api/src_py/__init__.py +83 -0
- real_ladybug-source/tools/python_api/src_py/async_connection.py +226 -0
- real_ladybug-source/tools/python_api/src_py/connection.py +323 -0
- real_ladybug-source/tools/python_api/src_py/constants.py +7 -0
- real_ladybug-source/tools/python_api/src_py/database.py +307 -0
- real_ladybug-source/tools/python_api/src_py/prepared_statement.py +51 -0
- real_ladybug-source/tools/python_api/src_py/py.typed +0 -0
- real_ladybug-source/tools/python_api/src_py/query_result.py +511 -0
- real_ladybug-source/tools/python_api/src_py/torch_geometric_feature_store.py +185 -0
- real_ladybug-source/tools/python_api/src_py/torch_geometric_graph_store.py +131 -0
- real_ladybug-source/tools/python_api/src_py/torch_geometric_result_converter.py +282 -0
- real_ladybug-source/tools/python_api/src_py/types.py +39 -0
- real_ladybug-source/tools/python_api/test/conftest.py +230 -0
- real_ladybug-source/tools/python_api/test/disabled_test_extension.py +73 -0
- real_ladybug-source/tools/python_api/test/ground_truth.py +430 -0
- real_ladybug-source/tools/python_api/test/test_arrow.py +694 -0
- real_ladybug-source/tools/python_api/test/test_async_connection.py +159 -0
- real_ladybug-source/tools/python_api/test/test_blob_parameter.py +145 -0
- real_ladybug-source/tools/python_api/test/test_connection.py +49 -0
- real_ladybug-source/tools/python_api/test/test_database.py +234 -0
- real_ladybug-source/tools/python_api/test/test_datatype.py +372 -0
- real_ladybug-source/tools/python_api/test/test_df.py +564 -0
- real_ladybug-source/tools/python_api/test/test_dict.py +112 -0
- real_ladybug-source/tools/python_api/test/test_exception.py +54 -0
- real_ladybug-source/tools/python_api/test/test_fsm.py +227 -0
- real_ladybug-source/tools/python_api/test/test_get_header.py +49 -0
- real_ladybug-source/tools/python_api/test/test_helper.py +8 -0
- real_ladybug-source/tools/python_api/test/test_issue.py +147 -0
- real_ladybug-source/tools/python_api/test/test_iteration.py +96 -0
- real_ladybug-source/tools/python_api/test/test_networkx.py +437 -0
- real_ladybug-source/tools/python_api/test/test_parameter.py +340 -0
- real_ladybug-source/tools/python_api/test/test_prepared_statement.py +117 -0
- real_ladybug-source/tools/python_api/test/test_query_result.py +54 -0
- real_ladybug-source/tools/python_api/test/test_query_result_close.py +44 -0
- real_ladybug-source/tools/python_api/test/test_scan_pandas.py +676 -0
- real_ladybug-source/tools/python_api/test/test_scan_pandas_pyarrow.py +714 -0
- real_ladybug-source/tools/python_api/test/test_scan_polars.py +165 -0
- real_ladybug-source/tools/python_api/test/test_scan_pyarrow.py +167 -0
- real_ladybug-source/tools/python_api/test/test_timeout.py +11 -0
- real_ladybug-source/tools/python_api/test/test_torch_geometric.py +640 -0
- real_ladybug-source/tools/python_api/test/test_torch_geometric_remote_backend.py +111 -0
- real_ladybug-source/tools/python_api/test/test_udf.py +207 -0
- real_ladybug-source/tools/python_api/test/test_version.py +6 -0
- real_ladybug-source/tools/python_api/test/test_wal.py +80 -0
- real_ladybug-source/tools/python_api/test/type_aliases.py +10 -0
- real_ladybug-source/tools/rust_api/update_version.py +47 -0
- real_ladybug-source/tools/shell/test/conftest.py +218 -0
- real_ladybug-source/tools/shell/test/test_helper.py +60 -0
- real_ladybug-source/tools/shell/test/test_shell_basics.py +325 -0
- real_ladybug-source/tools/shell/test/test_shell_commands.py +656 -0
- real_ladybug-source/tools/shell/test/test_shell_control_edit.py +438 -0
- real_ladybug-source/tools/shell/test/test_shell_control_search.py +468 -0
- real_ladybug-source/tools/shell/test/test_shell_esc_edit.py +232 -0
- real_ladybug-source/tools/shell/test/test_shell_esc_search.py +162 -0
- real_ladybug-source/tools/shell/test/test_shell_flags.py +645 -0
|
@@ -0,0 +1,511 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from .torch_geometric_result_converter import TorchGeometricResultConverter
|
|
6
|
+
from .types import Type
|
|
7
|
+
|
|
8
|
+
from .constants import ID, LABEL, SRC, DST, NODES, RELS
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import sys
|
|
12
|
+
from collections.abc import Iterator
|
|
13
|
+
from types import TracebackType
|
|
14
|
+
from typing import Any
|
|
15
|
+
|
|
16
|
+
import networkx as nx
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import polars as pl
|
|
19
|
+
import pyarrow as pa
|
|
20
|
+
import torch_geometric.data as geo
|
|
21
|
+
|
|
22
|
+
from . import _lbug
|
|
23
|
+
|
|
24
|
+
if sys.version_info >= (3, 11):
|
|
25
|
+
from typing import Self
|
|
26
|
+
else:
|
|
27
|
+
from typing_extensions import Self
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class QueryResult:
    """QueryResult stores the result of a query execution."""

    def __init__(self, connection: _lbug.Connection, query_result: _lbug.QueryResult):  # type: ignore[name-defined]
        """
        Parameters
        ----------
        connection : _lbug.Connection
            The underlying C++ connection object from pybind11.

        query_result : _lbug.QueryResult
            The underlying C++ query result object from pybind11.

        """
        self.connection = connection
        self._query_result = query_result
        self.is_closed = False
        self.as_dict = False
        # Column names used for dict-formatted rows; refreshed by rows_as_dict().
        # Initialized here so the attribute always exists.
        self.columns: list[str] = []

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    def __del__(self) -> None:
        self.close()

    def __iter__(self) -> Iterator[list[Any] | dict[str, Any]]:
        return self

    def __next__(self) -> list[Any] | dict[str, Any]:
        if self.has_next():
            return self.get_next()

        raise StopIteration

    def has_next(self) -> bool:
        """
        Check if there are more rows in the query result.

        Returns
        -------
        bool
            True if there are more rows in the query result, False otherwise.
        """
        self.check_for_query_result_close()
        return self._query_result.hasNext()

    def get_next(self) -> list[Any] | dict[str, Any]:
        """
        Get the next row in the query result.

        Returns
        -------
        list or dict
            Next row in the query result; a dict keyed by column name when
            rows_as_dict() has been enabled, a list otherwise.

        Raises
        ------
        Exception
            If there are no more rows.
        """
        self.check_for_query_result_close()
        row = self._query_result.getNext()
        return _row_to_dict(self.columns, row) if self.as_dict else row

    def get_all(self) -> list[list[Any] | dict[str, Any]]:
        """
        Get all remaining rows in the query result.

        Returns
        -------
        list
            All remaining rows in the query result.
        """
        return list(self)

    def get_n(self, count: int) -> list[list[Any] | dict[str, Any]]:
        """
        Get many rows in the query result.

        Parameters
        ----------
        count : int
            Maximum number of rows to fetch.

        Returns
        -------
        list
            Up to `count` rows in the query result.
        """
        results = []
        while self.has_next() and count > 0:
            results.append(self.get_next())
            count -= 1
        return results

    def close(self) -> None:
        """Close the query result."""
        if not self.is_closed:
            # Allows the connection to be garbage collected if the query result
            # is closed manually by the user.
            self._query_result.close()
            self.connection = None
            self.is_closed = True

    def check_for_query_result_close(self) -> None:
        """
        Check if the query result is closed and raise an exception if it is.

        Raises
        ------
        Exception
            If the query result is closed.

        """
        if self.is_closed:
            msg = "Query result is closed"
            raise RuntimeError(msg)

    def get_as_df(self) -> pd.DataFrame:
        """
        Get the query result as a Pandas DataFrame.

        See Also
        --------
        get_as_pl : Get the query result as a Polars DataFrame.
        get_as_arrow : Get the query result as a PyArrow Table.

        Returns
        -------
        pandas.DataFrame
            Query result as a Pandas DataFrame.

        """
        self.check_for_query_result_close()

        return self._query_result.getAsDF()

    def get_as_pl(self) -> pl.DataFrame:
        """
        Get the query result as a Polars DataFrame.

        See Also
        --------
        get_as_df : Get the query result as a Pandas DataFrame.
        get_as_arrow : Get the query result as a PyArrow Table.

        Returns
        -------
        polars.DataFrame
            Query result as a Polars DataFrame.
        """
        import polars as pl

        self.check_for_query_result_close()

        # note: polars should always export just a single chunk,
        # (eg: "-1") otherwise it will just need to rechunk anyway
        return pl.from_arrow(  # type: ignore[return-value]
            data=self.get_as_arrow(chunk_size=-1, fallbackExtensionTypes=True),
        )

    def get_as_arrow(self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False) -> pa.Table:
        """
        Get the query result as a PyArrow Table.

        Parameters
        ----------
        chunk_size : Number of rows to include in each chunk.
            None
                The chunk size is adaptive and depends on the number of columns in the query result.
            -1 or 0
                The entire result is returned as a single chunk.
            > 0
                The chunk size is the number of rows specified.

        fallbackExtensionTypes : bool
            Avoid using Arrow extension types for compatibility with Polars

        See Also
        --------
        get_as_pl : Get the query result as a Polars DataFrame.
        get_as_df : Get the query result as a Pandas DataFrame.

        Returns
        -------
        pyarrow.Table
            Query result as a PyArrow Table.
        """
        self.check_for_query_result_close()

        if chunk_size is None:
            # Adaptive; target 10m total elements in each chunk.
            # (eg: if we had 10 cols, this would result in a 1m row chunk_size).
            target_n_elems = 10_000_000
            chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
        elif chunk_size <= 0:
            # No chunking: return the entire result as a single chunk
            chunk_size = self.get_num_tuples()

        return self._query_result.getAsArrow(chunk_size, fallbackExtensionTypes)

    def get_column_data_types(self) -> list[str]:
        """
        Get the data types of the columns in the query result.

        Returns
        -------
        list
            Data types of the columns in the query result.

        """
        self.check_for_query_result_close()
        return self._query_result.getColumnDataTypes()

    def get_column_names(self) -> list[str]:
        """
        Get the names of the columns in the query result.

        Returns
        -------
        list
            Names of the columns in the query result.

        """
        self.check_for_query_result_close()
        return self._query_result.getColumnNames()

    def get_schema(self) -> dict[str, str]:
        """
        Get the column schema of the query result.

        Returns
        -------
        dict
            Schema of the query result, mapping column name to data type.

        """
        self.check_for_query_result_close()
        return dict(
            zip(
                self._query_result.getColumnNames(),
                self._query_result.getColumnDataTypes(),
            )
        )

    def reset_iterator(self) -> None:
        """Reset the iterator of the query result."""
        self.check_for_query_result_close()
        self._query_result.resetIterator()

    def get_as_networkx(
        self,
        directed: bool = True,  # noqa: FBT001
    ) -> nx.MultiGraph | nx.MultiDiGraph:
        """
        Convert the nodes and rels in query result into a NetworkX directed or undirected graph
        with the following rules:
        Columns with data type other than node or rel will be ignored.
        Duplicated nodes and rels will be converted only once.

        Parameters
        ----------
        directed : bool
            Whether the graph should be directed. Defaults to True.

        Returns
        -------
        networkx.MultiDiGraph or networkx.MultiGraph
            Query result as a NetworkX graph.

        """
        self.check_for_query_result_close()
        import networkx as nx

        nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
        properties_to_extract = self._get_properties_to_extract()

        self.reset_iterator()

        nodes = {}
        rels = {}
        table_to_label_dict = {}
        table_primary_key_dict = {}

        def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
            node_label = node[LABEL]
            return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"

        def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]:
            return rel[ID]["table"], rel[ID]["offset"]

        # De-duplicate nodes and rels
        while self.has_next():
            row = self.get_next()
            for i in properties_to_extract:
                # Skip empty nodes and rels, which may be returned by
                # OPTIONAL MATCH
                if row[i] is None or row[i] == {}:
                    continue
                column_type, _ = properties_to_extract[i]
                if column_type == Type.NODE.value:
                    nid = row[i][ID]
                    nodes[nid["table"], nid["offset"]] = row[i]
                    table_to_label_dict[nid["table"]] = row[i][LABEL]

                elif column_type == Type.REL.value:
                    rels[encode_rel_id(row[i])] = row[i]

                elif column_type == Type.RECURSIVE_REL.value:
                    for node in row[i][NODES]:
                        nid = node[ID]
                        nodes[nid["table"], nid["offset"]] = node
                        table_to_label_dict[nid["table"]] = node[LABEL]
                    for rel in row[i][RELS]:
                        # Strip null-valued properties from path rels.
                        for key in list(rel.keys()):
                            if rel[key] is None:
                                del rel[key]
                        rels[encode_rel_id(rel)] = rel

        # Add nodes
        for node in nodes.values():
            # Lazily resolve each label's primary-key property, then encode a
            # stable node id as "<label>_<primary key value>".
            if node[LABEL] not in table_primary_key_dict:
                props = self.connection._get_node_property_names(node[LABEL])
                for prop_name in props:
                    if props[prop_name]["is_primary_key"]:
                        table_primary_key_dict[node[LABEL]] = prop_name
                        break
            node_id = encode_node_id(node, table_primary_key_dict)
            # Mark label membership with a boolean attribute named after the label.
            node[node[LABEL]] = True
            nx_graph.add_node(node_id, **node)

        # Add rels
        for rel in rels.values():
            src = rel[SRC]
            dst = rel[DST]
            src_node = nodes[src["table"], src["offset"]]
            dst_node = nodes[dst["table"], dst["offset"]]
            src_id = encode_node_id(src_node, table_primary_key_dict)
            dst_id = encode_node_id(dst_node, table_primary_key_dict)
            nx_graph.add_edge(src_id, dst_id, **rel)
        return nx_graph

    def _get_properties_to_extract(self) -> dict[int, tuple[str, str]]:
        """Map column index -> (type, name) for node/rel/recursive-rel columns."""
        column_names = self.get_column_names()
        column_types = self.get_column_data_types()
        properties_to_extract = {}

        # Iterate over columns and extract nodes and rels, ignoring other columns
        for i, (column_name, column_type) in enumerate(zip(column_names, column_types)):
            if column_type in [
                Type.NODE.value,
                Type.REL.value,
                Type.RECURSIVE_REL.value,
            ]:
                properties_to_extract[i] = (column_type, column_name)
        return properties_to_extract

    def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
        """
        Convert the nodes and rels in query result into a PyTorch Geometric graph representation
        torch_geometric.data.Data or torch_geometric.data.HeteroData.

        For node conversion, numerical and boolean properties are directly converted into tensor and
        stored in Data/HeteroData. For properties cannot be converted into tensor automatically
        (please refer to the notes below for more detail), they are returned as unconverted_properties.

        For rel conversion, rel is converted into edge_index tensor director. Edge properties are returned
        as edge_properties.

        Node properties that cannot be converted into tensor automatically:
        - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
          automatically.
        - If a node property contains a null value, it cannot be converted automatically.
        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
          converted automatically.
        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
          is 6 for one node but 5 for another node), it cannot be converted automatically.

        Additional conversion rules:
        - Columns with data type other than node or rel will be ignored.
        - Duplicated nodes and rels will be converted only once.

        Returns
        -------
        torch_geometric.data.Data or torch_geometric.data.HeteroData
            Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties
            and edge_index tensor.

        dict
            A dictionary that maps the positional offset of each node in Data/HeteroData to its primary
            key in the database.

        dict
            A dictionary contains node properties that cannot be converted into tensor automatically. The
            order of values for each property is aligned with nodes in Data/HeteroData.

        dict
            A dictionary contains edge properties. The order of values for each property is aligned with
            edge_index in Data/HeteroData.
        """
        self.check_for_query_result_close()
        # NOTE(review): the original comment claims torch_geometric is imported
        # here to fail early when it is missing, but no such import is present —
        # the failure instead surfaces inside the converter. Confirm intent.
        converter = TorchGeometricResultConverter(self)
        return converter.get_as_torch_geometric()

    def get_execution_time(self) -> int:
        """
        Get the time in ms which was required for executing the query.

        Returns
        -------
        double
            Query execution time as double in ms.

        """
        # NOTE(review): docstring says double (ms) but annotation says int —
        # confirm native getExecutionTime() return type.
        self.check_for_query_result_close()
        return self._query_result.getExecutionTime()

    def get_compiling_time(self) -> int:
        """
        Get the time in ms which was required for compiling the query.

        Returns
        -------
        double
            Query compile time as double in ms.

        """
        # NOTE(review): docstring says double (ms) but annotation says int —
        # confirm native getCompilingTime() return type.
        self.check_for_query_result_close()
        return self._query_result.getCompilingTime()

    def get_num_tuples(self) -> int:
        """
        Get the number of tuples which the query returned.

        Returns
        -------
        int
            Number of tuples.

        """
        self.check_for_query_result_close()
        return self._query_result.getNumTuples()

    def rows_as_dict(self, state: bool = True) -> Self:
        """
        Change the format of the results, such that each row is a dict with the
        column name as a key.

        Parameters
        ----------
        state : bool
            Whether to turn dict formatting on or off. Turns it on by default.

        Returns
        -------
        self
            The object itself.

        """
        self.as_dict = state
        if state:
            self.columns = self.get_column_names()
        return self
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def _row_to_dict(columns: list[str], row: list[Any]) -> dict[str, Any]:
|
|
507
|
+
if len(columns) != len(row):
|
|
508
|
+
msg = "Number of columns in output row does not match number of columns"
|
|
509
|
+
raise RuntimeError(msg)
|
|
510
|
+
|
|
511
|
+
return dict(zip(columns, row))
|
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import multiprocessing
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
from torch import Tensor
|
|
9
|
+
from torch_geometric.data.feature_store import FeatureStore, TensorAttr
|
|
10
|
+
|
|
11
|
+
from .connection import Connection
|
|
12
|
+
from .types import Type
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from torch_geometric.typing import FeatureTensorType
|
|
16
|
+
|
|
17
|
+
from .database import Database
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class LbugFeatureStore(FeatureStore):  # type: ignore[misc]
    """Feature store compatible with `torch_geometric`."""

    def __init__(self, db: Database, num_threads: int | None = None):
        """
        Parameters
        ----------
        db : Database
            The database from which node features are served.

        num_threads : int | None
            Number of threads used when scanning node tables.
            Defaults to the machine's CPU count.

        """
        super().__init__()
        self.db = db
        # The connection is created lazily on first use; see __get_connection().
        self.connection: Connection = None  # type: ignore[assignment]
        # Per-table property metadata cache: {table_name: {prop_name: info}}.
        self.node_properties_cache: dict[str, dict[str, Any]] = {}
        if num_threads is None:
            num_threads = multiprocessing.cpu_count()
        self.num_threads = num_threads

    def __get_connection(self) -> Connection:
        # Lazily create and memoize the underlying connection.
        if self.connection is None:
            self.connection = Connection(self.db, self.num_threads)
        return self.connection

    @staticmethod
    def __scannable_types() -> set[str]:
        # Property types that can be converted into tensors directly.
        # Kept in one place so _get_tensor and get_all_tensor_attrs agree.
        return {
            Type.INT64.value,
            Type.INT32.value,
            Type.INT16.value,
            Type.DOUBLE.value,
            Type.FLOAT.value,
            Type.BOOL.value,
        }

    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        raise NotImplementedError

    def _get_tensor(self, attr: TensorAttr) -> FeatureTensorType | None:
        """Fetch the tensor described by *attr*.

        Numeric/boolean properties with a known shape are read via the fast
        native table scan; everything else falls back to a Cypher query.
        """
        attr_info = self.__get_node_property(attr.group_name, attr.attr_name)
        if attr_info["type"] not in self.__scannable_types() or (
            attr_info["dimension"] > 0 and "shape" not in attr_info
        ):
            return self.__get_tensor_by_query(attr)
        return self.__get_tensor_by_scan(attr)

    def __get_tensor_by_scan(self, attr: TensorAttr) -> FeatureTensorType | None:
        """Read a property column through the native node-table scan."""
        table_name = attr.group_name
        attr_name = attr.attr_name
        indices = attr.index

        # Normalize the requested index into a uint64 numpy array.
        if indices is None:
            # No explicit index: scan every node in the table.
            shape = self.get_tensor_size(attr)
            indices = np.arange(shape[0], dtype=np.uint64)
        elif isinstance(indices, slice):
            if indices.step is None or indices.step == 1:
                indices = np.arange(indices.start, indices.stop, dtype=np.uint64)
            else:
                indices = np.arange(indices.start, indices.stop, indices.step, dtype=np.uint64)
        elif isinstance(indices, int):
            indices = np.array([indices])

        # Ensure the lazy connection exists before touching it directly.
        self.__get_connection()
        if table_name not in self.node_properties_cache:
            self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
        attr_info = self.node_properties_cache[table_name][attr_name]

        # Fixed-size array properties are scanned flattened, then reshaped below.
        flat_dim = 1
        if attr_info["dimension"] > 0:
            for i in range(attr_info["dimension"]):
                flat_dim *= attr_info["shape"][i]
        scan_result = self.connection.database._scan_node_table(
            table_name, attr_name, attr_info["type"], flat_dim, indices, self.num_threads
        )

        if attr_info["dimension"] > 0 and "shape" in attr_info:
            result_shape = (len(indices),) + attr_info["shape"]
            scan_result = scan_result.reshape(result_shape)

        return torch.from_numpy(scan_result)

    def __get_tensor_by_query(self, attr: TensorAttr) -> FeatureTensorType | None:
        """Fetch a property column with a Cypher query (slow fallback path)."""
        table_name = attr.group_name
        attr_name = attr.attr_name
        indices = attr.index

        self.__get_connection()

        match_clause = f"MATCH (item:{table_name})"
        return_clause = f"RETURN item.{attr_name}"

        # Translate the index into a WHERE clause over internal node offsets.
        if indices is None:
            where_clause = ""
        elif isinstance(indices, int):
            where_clause = f"WHERE offset(id(item)) = {indices}"
        elif isinstance(indices, slice):
            if indices.step is None or indices.step == 1:
                where_clause = f"WHERE offset(id(item)) >= {indices.start} AND offset(id(item)) < {indices.stop}"
            else:
                where_clause = (
                    f"WHERE offset(id(item)) >= {indices.start} AND offset(id(item)) < {indices.stop} "
                    f"AND (offset(id(item)) - {indices.start}) % {indices.step} = 0"
                )
        elif isinstance(indices, (Tensor, list, np.ndarray, tuple)):
            # NOTE(review): this builds one OR term per index, which scales
            # poorly for large index collections — consider a parameterized IN.
            where_clause = "WHERE"
            for i in indices:
                where_clause += f" offset(id(item)) = {int(i)} OR"
            where_clause = where_clause[:-3]
        else:
            msg = f"Invalid attr.index type: {type(indices)!s}"
            raise ValueError(msg)

        query = f"{match_clause} {where_clause} {return_clause}"
        res = self.connection.execute(query)

        result_list = []
        while res.has_next():
            value_array = res.get_next()
            if len(value_array) == 1:
                value_array = value_array[0]
            result_list.append(value_array)
        # Best effort: ragged or non-numeric values cannot form a tensor, so
        # fall back to returning the raw list.
        try:
            return torch.tensor(result_list)
        except Exception:
            return result_list

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        raise NotImplementedError

    def _get_tensor_size(self, attr: TensorAttr) -> tuple[Any, ...]:
        """Return the full tensor shape for *attr*: (num_nodes, *property shape)."""
        self.__get_connection()
        length_query = f"MATCH (item:{attr.group_name}) RETURN count(item)"
        res = self.connection.execute(length_query)
        length = res.get_next()[0]
        attr_info = self.__get_node_property(attr.group_name, attr.attr_name)
        if attr_info["dimension"] == 0:
            return (length,)
        return (length,) + attr_info["shape"]

    def __get_node_property(self, table_name: str, attr_name: str) -> dict[str, Any]:
        """Return cached metadata for one property, loading the table on a miss.

        Raises ValueError when the property does not exist in the table.
        """
        if table_name in self.node_properties_cache and attr_name in self.node_properties_cache[table_name]:
            return self.node_properties_cache[table_name][attr_name]
        self.__get_connection()
        if table_name not in self.node_properties_cache:
            self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
        if attr_name not in self.node_properties_cache[table_name]:
            msg = f"Attribute {attr_name} not found in group {table_name}"
            raise ValueError(msg)
        return self.node_properties_cache[table_name][attr_name]

    def get_all_tensor_attrs(self) -> list[TensorAttr]:
        """Return all TensorAttr from the table nodes."""
        result_list = []
        self.__get_connection()
        scannable = self.__scannable_types()
        for table_name in self.connection._get_node_table_names():
            if table_name not in self.node_properties_cache:
                self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
            # Expose only properties that can actually be served as tensors.
            for attr_name, info in self.node_properties_cache[table_name].items():
                if info["type"] in scannable and (info["dimension"] == 0 or "shape" in info):
                    result_list.append(TensorAttr(table_name, attr_name))
        return result_list