real-ladybug 0.0.1.dev1__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of real-ladybug might be problematic. See the registry's advisory listing for more details.

Files changed (114)
  1. real_ladybug/__init__.py +83 -0
  2. real_ladybug/_lbug.cp311-win_amd64.pyd +0 -0
  3. real_ladybug/_lbug.exp +0 -0
  4. real_ladybug/_lbug.lib +0 -0
  5. real_ladybug/async_connection.py +226 -0
  6. real_ladybug/connection.py +323 -0
  7. real_ladybug/constants.py +7 -0
  8. real_ladybug/database.py +307 -0
  9. real_ladybug/prepared_statement.py +51 -0
  10. real_ladybug/py.typed +0 -0
  11. real_ladybug/query_result.py +511 -0
  12. real_ladybug/torch_geometric_feature_store.py +185 -0
  13. real_ladybug/torch_geometric_graph_store.py +131 -0
  14. real_ladybug/torch_geometric_result_converter.py +282 -0
  15. real_ladybug/types.py +39 -0
  16. real_ladybug-0.0.1.dev1.dist-info/METADATA +88 -0
  17. real_ladybug-0.0.1.dev1.dist-info/RECORD +114 -0
  18. real_ladybug-0.0.1.dev1.dist-info/WHEEL +5 -0
  19. real_ladybug-0.0.1.dev1.dist-info/licenses/LICENSE +21 -0
  20. real_ladybug-0.0.1.dev1.dist-info/top_level.txt +3 -0
  21. real_ladybug-0.0.1.dev1.dist-info/zip-safe +1 -0
  22. real_ladybug-source/scripts/antlr4/hash.py +2 -0
  23. real_ladybug-source/scripts/antlr4/keywordhandler.py +47 -0
  24. real_ladybug-source/scripts/collect-extensions.py +68 -0
  25. real_ladybug-source/scripts/collect-single-file-header.py +126 -0
  26. real_ladybug-source/scripts/export-dbs.py +101 -0
  27. real_ladybug-source/scripts/export-import-test.py +345 -0
  28. real_ladybug-source/scripts/extension/purge-beta.py +34 -0
  29. real_ladybug-source/scripts/generate-cpp-docs/collect_files.py +122 -0
  30. real_ladybug-source/scripts/generate-tinysnb.py +34 -0
  31. real_ladybug-source/scripts/get-clangd-diagnostics.py +233 -0
  32. real_ladybug-source/scripts/migrate-lbug-db.py +308 -0
  33. real_ladybug-source/scripts/multiplatform-test-helper/collect-results.py +71 -0
  34. real_ladybug-source/scripts/multiplatform-test-helper/notify-discord.py +68 -0
  35. real_ladybug-source/scripts/pip-package/package_tar.py +90 -0
  36. real_ladybug-source/scripts/pip-package/setup.py +130 -0
  37. real_ladybug-source/scripts/run-clang-format.py +408 -0
  38. real_ladybug-source/scripts/setup-extension-repo.py +67 -0
  39. real_ladybug-source/scripts/test-simsimd-dispatch.py +45 -0
  40. real_ladybug-source/scripts/update-nightly-build-version.py +81 -0
  41. real_ladybug-source/third_party/brotli/scripts/dictionary/step-01-download-rfc.py +16 -0
  42. real_ladybug-source/third_party/brotli/scripts/dictionary/step-02-rfc-to-bin.py +34 -0
  43. real_ladybug-source/third_party/brotli/scripts/dictionary/step-03-validate-bin.py +35 -0
  44. real_ladybug-source/third_party/brotli/scripts/dictionary/step-04-generate-java-literals.py +85 -0
  45. real_ladybug-source/third_party/pybind11/tools/codespell_ignore_lines_from_errors.py +35 -0
  46. real_ladybug-source/third_party/pybind11/tools/libsize.py +36 -0
  47. real_ladybug-source/third_party/pybind11/tools/make_changelog.py +63 -0
  48. real_ladybug-source/tools/python_api/build/real_ladybug/__init__.py +83 -0
  49. real_ladybug-source/tools/python_api/build/real_ladybug/async_connection.py +226 -0
  50. real_ladybug-source/tools/python_api/build/real_ladybug/connection.py +323 -0
  51. real_ladybug-source/tools/python_api/build/real_ladybug/constants.py +7 -0
  52. real_ladybug-source/tools/python_api/build/real_ladybug/database.py +307 -0
  53. real_ladybug-source/tools/python_api/build/real_ladybug/prepared_statement.py +51 -0
  54. real_ladybug-source/tools/python_api/build/real_ladybug/py.typed +0 -0
  55. real_ladybug-source/tools/python_api/build/real_ladybug/query_result.py +511 -0
  56. real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_feature_store.py +185 -0
  57. real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_graph_store.py +131 -0
  58. real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_result_converter.py +282 -0
  59. real_ladybug-source/tools/python_api/build/real_ladybug/types.py +39 -0
  60. real_ladybug-source/tools/python_api/src_py/__init__.py +83 -0
  61. real_ladybug-source/tools/python_api/src_py/async_connection.py +226 -0
  62. real_ladybug-source/tools/python_api/src_py/connection.py +323 -0
  63. real_ladybug-source/tools/python_api/src_py/constants.py +7 -0
  64. real_ladybug-source/tools/python_api/src_py/database.py +307 -0
  65. real_ladybug-source/tools/python_api/src_py/prepared_statement.py +51 -0
  66. real_ladybug-source/tools/python_api/src_py/py.typed +0 -0
  67. real_ladybug-source/tools/python_api/src_py/query_result.py +511 -0
  68. real_ladybug-source/tools/python_api/src_py/torch_geometric_feature_store.py +185 -0
  69. real_ladybug-source/tools/python_api/src_py/torch_geometric_graph_store.py +131 -0
  70. real_ladybug-source/tools/python_api/src_py/torch_geometric_result_converter.py +282 -0
  71. real_ladybug-source/tools/python_api/src_py/types.py +39 -0
  72. real_ladybug-source/tools/python_api/test/conftest.py +230 -0
  73. real_ladybug-source/tools/python_api/test/disabled_test_extension.py +73 -0
  74. real_ladybug-source/tools/python_api/test/ground_truth.py +430 -0
  75. real_ladybug-source/tools/python_api/test/test_arrow.py +694 -0
  76. real_ladybug-source/tools/python_api/test/test_async_connection.py +159 -0
  77. real_ladybug-source/tools/python_api/test/test_blob_parameter.py +145 -0
  78. real_ladybug-source/tools/python_api/test/test_connection.py +49 -0
  79. real_ladybug-source/tools/python_api/test/test_database.py +234 -0
  80. real_ladybug-source/tools/python_api/test/test_datatype.py +372 -0
  81. real_ladybug-source/tools/python_api/test/test_df.py +564 -0
  82. real_ladybug-source/tools/python_api/test/test_dict.py +112 -0
  83. real_ladybug-source/tools/python_api/test/test_exception.py +54 -0
  84. real_ladybug-source/tools/python_api/test/test_fsm.py +227 -0
  85. real_ladybug-source/tools/python_api/test/test_get_header.py +49 -0
  86. real_ladybug-source/tools/python_api/test/test_helper.py +8 -0
  87. real_ladybug-source/tools/python_api/test/test_issue.py +147 -0
  88. real_ladybug-source/tools/python_api/test/test_iteration.py +96 -0
  89. real_ladybug-source/tools/python_api/test/test_networkx.py +437 -0
  90. real_ladybug-source/tools/python_api/test/test_parameter.py +340 -0
  91. real_ladybug-source/tools/python_api/test/test_prepared_statement.py +117 -0
  92. real_ladybug-source/tools/python_api/test/test_query_result.py +54 -0
  93. real_ladybug-source/tools/python_api/test/test_query_result_close.py +44 -0
  94. real_ladybug-source/tools/python_api/test/test_scan_pandas.py +676 -0
  95. real_ladybug-source/tools/python_api/test/test_scan_pandas_pyarrow.py +714 -0
  96. real_ladybug-source/tools/python_api/test/test_scan_polars.py +165 -0
  97. real_ladybug-source/tools/python_api/test/test_scan_pyarrow.py +167 -0
  98. real_ladybug-source/tools/python_api/test/test_timeout.py +11 -0
  99. real_ladybug-source/tools/python_api/test/test_torch_geometric.py +640 -0
  100. real_ladybug-source/tools/python_api/test/test_torch_geometric_remote_backend.py +111 -0
  101. real_ladybug-source/tools/python_api/test/test_udf.py +207 -0
  102. real_ladybug-source/tools/python_api/test/test_version.py +6 -0
  103. real_ladybug-source/tools/python_api/test/test_wal.py +80 -0
  104. real_ladybug-source/tools/python_api/test/type_aliases.py +10 -0
  105. real_ladybug-source/tools/rust_api/update_version.py +47 -0
  106. real_ladybug-source/tools/shell/test/conftest.py +218 -0
  107. real_ladybug-source/tools/shell/test/test_helper.py +60 -0
  108. real_ladybug-source/tools/shell/test/test_shell_basics.py +325 -0
  109. real_ladybug-source/tools/shell/test/test_shell_commands.py +656 -0
  110. real_ladybug-source/tools/shell/test/test_shell_control_edit.py +438 -0
  111. real_ladybug-source/tools/shell/test/test_shell_control_search.py +468 -0
  112. real_ladybug-source/tools/shell/test/test_shell_esc_edit.py +232 -0
  113. real_ladybug-source/tools/shell/test/test_shell_esc_search.py +162 -0
  114. real_ladybug-source/tools/shell/test/test_shell_flags.py +645 -0
@@ -0,0 +1,511 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from .torch_geometric_result_converter import TorchGeometricResultConverter
6
+ from .types import Type
7
+
8
+ from .constants import ID, LABEL, SRC, DST, NODES, RELS
9
+
10
+ if TYPE_CHECKING:
11
+ import sys
12
+ from collections.abc import Iterator
13
+ from types import TracebackType
14
+ from typing import Any
15
+
16
+ import networkx as nx
17
+ import pandas as pd
18
+ import polars as pl
19
+ import pyarrow as pa
20
+ import torch_geometric.data as geo
21
+
22
+ from . import _lbug
23
+
24
+ if sys.version_info >= (3, 11):
25
+ from typing import Self
26
+ else:
27
+ from typing_extensions import Self
28
+
29
+
30
class QueryResult:
    """Stores the result of a query execution and exposes row, DataFrame and graph accessors."""

    def __init__(self, connection: _lbug.Connection, query_result: _lbug.QueryResult):  # type: ignore[name-defined]
        """
        Parameters
        ----------
        connection : _lbug.Connection
            The underlying C++ connection object from pybind11.

        query_result : _lbug.QueryResult
            The underlying C++ query result object from pybind11.

        """
        self.connection = connection
        self._query_result = query_result
        self.is_closed = False
        self.as_dict = False
        # Column names used for dict-formatted rows. Initialized here so that
        # setting `as_dict` directly (without calling rows_as_dict()) raises a
        # clear length-mismatch error instead of an AttributeError.
        self.columns: list[str] = []

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    def __del__(self) -> None:
        # Be defensive: __del__ can run on a partially-initialized instance
        # (e.g. if __init__ raised before setting is_closed).
        if not getattr(self, "is_closed", True):
            self.close()

    def __iter__(self) -> Iterator[list[Any] | dict[str, Any]]:
        return self

    def __next__(self) -> list[Any] | dict[str, Any]:
        if self.has_next():
            return self.get_next()

        raise StopIteration

    def has_next(self) -> bool:
        """
        Check if there are more rows in the query result.

        Returns
        -------
        bool
            True if there are more rows in the query result, False otherwise.

        Raises
        ------
        RuntimeError
            If the query result is closed.
        """
        self.check_for_query_result_close()
        return self._query_result.hasNext()

    def get_next(self) -> list[Any] | dict[str, Any]:
        """
        Get the next row in the query result.

        Returns
        -------
        list or dict
            Next row, as a list of values, or as a column-name-keyed dict
            after rows_as_dict() has been called.

        Raises
        ------
        RuntimeError
            If the query result is closed.
        """
        self.check_for_query_result_close()
        row = self._query_result.getNext()
        return _row_to_dict(self.columns, row) if self.as_dict else row

    def get_all(self) -> list[list[Any] | dict[str, Any]]:
        """
        Get all remaining rows in the query result.

        Returns
        -------
        list
            All remaining rows in the query result.
        """
        return list(self)

    def get_n(self, count: int) -> list[list[Any] | dict[str, Any]]:
        """
        Get up to `count` rows from the query result.

        Parameters
        ----------
        count : int
            Maximum number of rows to fetch.

        Returns
        -------
        list
            Up to `count` rows in the query result.
        """
        results: list[list[Any] | dict[str, Any]] = []
        while self.has_next() and count > 0:
            results.append(self.get_next())
            count -= 1
        return results

    def close(self) -> None:
        """Close the query result and release the underlying native resources."""
        if not self.is_closed:
            self._query_result.close()
            # Drop the connection reference so it can be garbage collected if
            # the query result is closed manually by the user.
            self.connection = None
            self.is_closed = True

    def check_for_query_result_close(self) -> None:
        """
        Check if the query result is closed and raise an exception if it is.

        Raises
        ------
        RuntimeError
            If the query result is closed.

        """
        if self.is_closed:
            msg = "Query result is closed"
            raise RuntimeError(msg)

    def get_as_df(self) -> pd.DataFrame:
        """
        Get the query result as a Pandas DataFrame.

        See Also
        --------
        get_as_pl : Get the query result as a Polars DataFrame.
        get_as_arrow : Get the query result as a PyArrow Table.

        Returns
        -------
        pandas.DataFrame
            Query result as a Pandas DataFrame.

        """
        self.check_for_query_result_close()

        return self._query_result.getAsDF()

    def get_as_pl(self) -> pl.DataFrame:
        """
        Get the query result as a Polars DataFrame.

        See Also
        --------
        get_as_df : Get the query result as a Pandas DataFrame.
        get_as_arrow : Get the query result as a PyArrow Table.

        Returns
        -------
        polars.DataFrame
            Query result as a Polars DataFrame.
        """
        import polars as pl

        self.check_for_query_result_close()

        # Polars rechunks to a single chunk anyway, so export one chunk (-1)
        # up front and avoid the extra copy.
        return pl.from_arrow(  # type: ignore[return-value]
            data=self.get_as_arrow(chunk_size=-1, fallbackExtensionTypes=True),
        )

    def get_as_arrow(self, chunk_size: int | None = None, *, fallbackExtensionTypes: bool = False) -> pa.Table:
        """
        Get the query result as a PyArrow Table.

        Parameters
        ----------
        chunk_size : int, optional
            Number of rows to include in each chunk:
                None
                    The chunk size is adaptive and depends on the number of
                    columns in the query result.
                -1 or 0
                    The entire result is returned as a single chunk.
                > 0
                    The chunk size is the number of rows specified.

        fallbackExtensionTypes : bool
            Avoid using Arrow extension types, for compatibility with Polars.

        See Also
        --------
        get_as_pl : Get the query result as a Polars DataFrame.
        get_as_df : Get the query result as a Pandas DataFrame.

        Returns
        -------
        pyarrow.Table
            Query result as a PyArrow Table.
        """
        self.check_for_query_result_close()

        if chunk_size is None:
            # Adaptive: target ~10M total elements per chunk (e.g. 10 columns
            # gives a 1M-row chunk size), never going below 10 rows.
            target_n_elems = 10_000_000
            chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
        elif chunk_size <= 0:
            # No chunking: return the entire result as a single chunk.
            chunk_size = self.get_num_tuples()

        return self._query_result.getAsArrow(chunk_size, fallbackExtensionTypes)

    def get_column_data_types(self) -> list[str]:
        """
        Get the data types of the columns in the query result.

        Returns
        -------
        list
            Data types of the columns in the query result.

        """
        self.check_for_query_result_close()
        return self._query_result.getColumnDataTypes()

    def get_column_names(self) -> list[str]:
        """
        Get the names of the columns in the query result.

        Returns
        -------
        list
            Names of the columns in the query result.

        """
        self.check_for_query_result_close()
        return self._query_result.getColumnNames()

    def get_schema(self) -> dict[str, str]:
        """
        Get the column schema of the query result.

        Returns
        -------
        dict
            Mapping of column name to column data type.

        """
        self.check_for_query_result_close()
        return dict(
            zip(
                self._query_result.getColumnNames(),
                self._query_result.getColumnDataTypes(),
            )
        )

    def reset_iterator(self) -> None:
        """Reset the iterator of the query result."""
        self.check_for_query_result_close()
        self._query_result.resetIterator()

    def get_as_networkx(
        self,
        directed: bool = True,  # noqa: FBT001
    ) -> nx.MultiGraph | nx.MultiDiGraph:
        """
        Convert the nodes and rels in the query result into a NetworkX graph.

        Conversion rules:
            - Columns with data type other than node or rel are ignored.
            - Duplicated nodes and rels are converted only once.

        Parameters
        ----------
        directed : bool
            Whether the graph should be directed. Defaults to True.

        Returns
        -------
        networkx.MultiDiGraph or networkx.MultiGraph
            Query result as a NetworkX graph.

        """
        self.check_for_query_result_close()
        import networkx as nx

        nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
        properties_to_extract = self._get_properties_to_extract()

        self.reset_iterator()

        nodes: dict[tuple[Any, Any], dict[str, Any]] = {}
        rels: dict[tuple[Any, Any], dict[str, Any]] = {}
        table_primary_key_dict: dict[str, str] = {}

        def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
            # A node is identified by its label plus its primary-key value.
            node_label = node[LABEL]
            return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"

        def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]:
            return rel[ID]["table"], rel[ID]["offset"]

        # De-duplicate nodes and rels across all result rows.
        while self.has_next():
            row = self.get_next()
            for i in properties_to_extract:
                # Skip empty nodes and rels, which may be returned by
                # OPTIONAL MATCH.
                if row[i] is None or row[i] == {}:
                    continue
                column_type, _ = properties_to_extract[i]
                if column_type == Type.NODE.value:
                    nid = row[i][ID]
                    nodes[nid["table"], nid["offset"]] = row[i]

                elif column_type == Type.REL.value:
                    rels[encode_rel_id(row[i])] = row[i]

                elif column_type == Type.RECURSIVE_REL.value:
                    # Recursive rels carry their interior nodes and rels.
                    for node in row[i][NODES]:
                        nid = node[ID]
                        nodes[nid["table"], nid["offset"]] = node
                    for rel in row[i][RELS]:
                        # Interior rels carry null-valued properties; strip them.
                        for key in list(rel.keys()):
                            if rel[key] is None:
                                del rel[key]
                        rels[encode_rel_id(rel)] = rel

        # Add nodes, keyed as "<label>_<primary-key-value>".
        for node in nodes.values():
            if node[LABEL] not in table_primary_key_dict:
                # Look up (and cache) the primary-key property for this label.
                props = self.connection._get_node_property_names(node[LABEL])
                for prop_name in props:
                    if props[prop_name]["is_primary_key"]:
                        table_primary_key_dict[node[LABEL]] = prop_name
                        break
            node_id = encode_node_id(node, table_primary_key_dict)
            # Tag the node with its label so consumers can filter on it.
            node[node[LABEL]] = True
            nx_graph.add_node(node_id, **node)

        # Add rels between the encoded node ids.
        for rel in rels.values():
            src = rel[SRC]
            dst = rel[DST]
            src_node = nodes[src["table"], src["offset"]]
            dst_node = nodes[dst["table"], dst["offset"]]
            src_id = encode_node_id(src_node, table_primary_key_dict)
            dst_id = encode_node_id(dst_node, table_primary_key_dict)
            nx_graph.add_edge(src_id, dst_id, **rel)
        return nx_graph

    def _get_properties_to_extract(self) -> dict[int, tuple[str, str]]:
        """Map column index -> (column type, column name) for node/rel/recursive-rel columns."""
        column_names = self.get_column_names()
        column_types = self.get_column_data_types()
        properties_to_extract: dict[int, tuple[str, str]] = {}

        # Only node, rel and recursive-rel columns are extractable; other
        # columns are ignored.
        extractable = {Type.NODE.value, Type.REL.value, Type.RECURSIVE_REL.value}
        for i, (column_name, column_type) in enumerate(zip(column_names, column_types)):
            if column_type in extractable:
                properties_to_extract[i] = (column_type, column_name)
        return properties_to_extract

    def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
        """
        Convert the nodes and rels in query result into a PyTorch Geometric graph representation
        torch_geometric.data.Data or torch_geometric.data.HeteroData.

        For node conversion, numerical and boolean properties are directly converted into tensor and
        stored in Data/HeteroData. For properties cannot be converted into tensor automatically
        (please refer to the notes below for more detail), they are returned as unconverted_properties.

        For rel conversion, rel is converted into edge_index tensor director. Edge properties are returned
        as edge_properties.

        Node properties that cannot be converted into tensor automatically:
        - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
          automatically.
        - If a node property contains a null value, it cannot be converted automatically.
        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
          converted automatically.
        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
          is 6 for one node but 5 for another node), it cannot be converted automatically.

        Additional conversion rules:
        - Columns with data type other than node or rel will be ignored.
        - Duplicated nodes and rels will be converted only once.

        Returns
        -------
        torch_geometric.data.Data or torch_geometric.data.HeteroData
            Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties
            and edge_index tensor.

        dict
            A dictionary that maps the positional offset of each node in Data/HeteroData to its primary
            key in the database.

        dict
            A dictionary contains node properties that cannot be converted into tensor automatically. The
            order of values for each property is aligned with nodes in Data/HeteroData.

        dict
            A dictionary contains edge properties. The order of values for each property is aligned with
            edge_index in Data/HeteroData.
        """
        self.check_for_query_result_close()
        # NOTE(review): an earlier comment claimed torch_geometric is imported
        # here to fail early, but no such import was performed; the converter
        # performs all torch-dependent work.
        converter = TorchGeometricResultConverter(self)
        return converter.get_as_torch_geometric()

    def get_execution_time(self) -> float:
        """
        Get the time in ms which was required for executing the query.

        Returns
        -------
        float
            Query execution time in ms. (Annotation fixed: the value is a
            double, as documented, not an int.)

        """
        self.check_for_query_result_close()
        return self._query_result.getExecutionTime()

    def get_compiling_time(self) -> float:
        """
        Get the time in ms which was required for compiling the query.

        Returns
        -------
        float
            Query compile time in ms. (Annotation fixed: the value is a
            double, as documented, not an int.)

        """
        self.check_for_query_result_close()
        return self._query_result.getCompilingTime()

    def get_num_tuples(self) -> int:
        """
        Get the number of tuples which the query returned.

        Returns
        -------
        int
            Number of tuples.

        """
        self.check_for_query_result_close()
        return self._query_result.getNumTuples()

    def rows_as_dict(self, state: bool = True) -> Self:  # noqa: FBT001, FBT002
        """
        Change the format of the results, such that each row is a dict with the
        column name as a key.

        Parameters
        ----------
        state : bool
            Whether to turn dict formatting on or off. Turns it on by default.

        Returns
        -------
        self
            The object itself, to allow chaining.

        """
        self.as_dict = state
        if state:
            self.columns = self.get_column_names()
        return self
504
+
505
+
506
+ def _row_to_dict(columns: list[str], row: list[Any]) -> dict[str, Any]:
507
+ if len(columns) != len(row):
508
+ msg = "Number of columns in output row does not match number of columns"
509
+ raise RuntimeError(msg)
510
+
511
+ return dict(zip(columns, row))
@@ -0,0 +1,185 @@
1
+ from __future__ import annotations
2
+
3
+ import multiprocessing
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch import Tensor
9
+ from torch_geometric.data.feature_store import FeatureStore, TensorAttr
10
+
11
+ from .connection import Connection
12
+ from .types import Type
13
+
14
+ if TYPE_CHECKING:
15
+ from torch_geometric.typing import FeatureTensorType
16
+
17
+ from .database import Database
18
+
19
+
20
class LbugFeatureStore(FeatureStore):  # type: ignore[misc]
    """Feature store compatible with `torch_geometric`.

    Node properties are served either by a bulk native scan (for fixed-shape
    numeric/boolean properties) or by a fallback Cypher query for everything
    else.
    """

    # Property types that can be scanned directly into a tensor. Shared by
    # _get_tensor and get_all_tensor_attrs (previously duplicated inline).
    _TENSOR_TYPES = frozenset(
        {
            Type.INT64.value,
            Type.INT32.value,
            Type.INT16.value,
            Type.DOUBLE.value,
            Type.FLOAT.value,
            Type.BOOL.value,
        }
    )

    def __init__(self, db: Database, num_threads: int | None = None):
        """
        Parameters
        ----------
        db : Database
            Database to serve node features from.
        num_threads : int, optional
            Number of threads used for native scans; defaults to the CPU count.
        """
        super().__init__()
        self.db = db
        # The connection is created lazily on first use (see __get_connection).
        self.connection: Connection = None  # type: ignore[assignment]
        # Cache: table name -> {property name -> property-info dict}.
        self.node_properties_cache: dict[str, dict[str, Any]] = {}
        self.num_threads = multiprocessing.cpu_count() if num_threads is None else num_threads

    def __get_connection(self) -> Connection:
        # Lazily create and cache the connection.
        if self.connection is None:
            self.connection = Connection(self.db, self.num_threads)
        return self.connection

    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        """Writing tensors into the store is not supported."""
        raise NotImplementedError

    def _get_tensor(self, attr: TensorAttr) -> FeatureTensorType | None:
        """Fetch the tensor for `attr`, preferring the fast native-scan path."""
        attr_info = self.__get_node_property(attr.group_name, attr.attr_name)
        # Fall back to a Cypher query for non-scannable types, or for list
        # properties whose per-node shape is not fixed.
        if attr_info["type"] not in self._TENSOR_TYPES or (
            attr_info["dimension"] > 0 and "shape" not in attr_info
        ):
            return self.__get_tensor_by_query(attr)
        return self.__get_tensor_by_scan(attr)

    def __get_tensor_by_scan(self, attr: TensorAttr) -> FeatureTensorType | None:
        """Bulk-scan a fixed-shape numeric/boolean property into a torch tensor."""
        table_name = attr.group_name
        attr_name = attr.attr_name
        indices = attr.index

        # Ensure the lazily-created connection exists before using it below.
        self.__get_connection()

        # Normalize `indices` to a numpy array of row offsets.
        if indices is None:
            # No index given: scan every node in the table. (The original had
            # a redundant nested `if indices is None` here.)
            shape = self.get_tensor_size(attr)
            indices = np.arange(shape[0], dtype=np.uint64)
        elif isinstance(indices, slice):
            step = 1 if indices.step is None else indices.step
            indices = np.arange(indices.start, indices.stop, step, dtype=np.uint64)
        elif isinstance(indices, int):
            indices = np.array([indices])

        if table_name not in self.node_properties_cache:
            self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
        attr_info = self.node_properties_cache[table_name][attr_name]

        # Flattened element count per node for multi-dimensional properties.
        flat_dim = 1
        if attr_info["dimension"] > 0:
            for i in range(attr_info["dimension"]):
                flat_dim *= attr_info["shape"][i]
        scan_result = self.connection.database._scan_node_table(
            table_name, attr_name, attr_info["type"], flat_dim, indices, self.num_threads
        )

        if attr_info["dimension"] > 0 and "shape" in attr_info:
            # Restore the per-node shape on top of the row dimension.
            scan_result = scan_result.reshape((len(indices),) + attr_info["shape"])

        return torch.from_numpy(scan_result)

    def __get_tensor_by_query(self, attr: TensorAttr) -> FeatureTensorType | None:
        """Fetch a property via a Cypher query; used when a native scan is impossible."""
        table_name = attr.group_name
        attr_name = attr.attr_name
        indices = attr.index

        self.__get_connection()

        match_clause = f"MATCH (item:{table_name})"
        return_clause = f"RETURN item.{attr_name}"

        # Translate the requested index selection into a WHERE clause.
        if indices is None:
            where_clause = ""
        elif isinstance(indices, int):
            where_clause = f"WHERE offset(id(item)) = {indices}"
        elif isinstance(indices, slice):
            if indices.step is None or indices.step == 1:
                where_clause = f"WHERE offset(id(item)) >= {indices.start} AND offset(id(item)) < {indices.stop}"
            else:
                where_clause = (
                    f"WHERE offset(id(item)) >= {indices.start} AND offset(id(item)) < {indices.stop} "
                    f"AND (offset(id(item)) - {indices.start}) % {indices.step} = 0"
                )
        elif isinstance(indices, (Tensor, list, np.ndarray, tuple)):
            # Join instead of append-then-truncate; also avoids emitting the
            # malformed clause "WH" for an empty index collection.
            where_clause = "WHERE " + " OR ".join(f"offset(id(item)) = {int(i)}" for i in indices)
        else:
            msg = f"Invalid attr.index type: {type(indices)!s}"
            raise ValueError(msg)

        query = f"{match_clause} {where_clause} {return_clause}"
        res = self.connection.execute(query)

        result_list = []
        while res.has_next():
            value_array = res.get_next()
            # Single-column results come back as one-element rows; unwrap them.
            if len(value_array) == 1:
                value_array = value_array[0]
            result_list.append(value_array)
        try:
            return torch.tensor(result_list)
        except Exception:
            # Best-effort fallback (as before): values torch cannot convert
            # (strings, ragged lists, ...) are returned as a plain list.
            return result_list

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        """Removing tensors from the store is not supported."""
        raise NotImplementedError

    def _get_tensor_size(self, attr: TensorAttr) -> tuple[Any, ...]:
        """Return the full tensor shape for `attr`: (num_nodes, *property_shape)."""
        self.__get_connection()
        length_query = f"MATCH (item:{attr.group_name}) RETURN count(item)"
        res = self.connection.execute(length_query)
        length = res.get_next()[0]
        attr_info = self.__get_node_property(attr.group_name, attr.attr_name)
        if attr_info["dimension"] == 0:
            return (length,)
        # NOTE(review): assumes attr_info["shape"] is a tuple — confirm against
        # Connection._get_node_property_names.
        return (length,) + attr_info["shape"]

    def __get_node_property(self, table_name: str, attr_name: str) -> dict[str, Any]:
        """Return (and cache) the property-info dict for `table_name`.`attr_name`.

        Raises
        ------
        ValueError
            If the attribute does not exist in the given table.
        """
        cached = self.node_properties_cache.get(table_name)
        if cached is not None and attr_name in cached:
            return cached[attr_name]
        self.__get_connection()
        if table_name not in self.node_properties_cache:
            self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
        if attr_name not in self.node_properties_cache[table_name]:
            msg = f"Attribute {attr_name} not found in group {table_name}"
            raise ValueError(msg)
        return self.node_properties_cache[table_name][attr_name]

    def get_all_tensor_attrs(self) -> list[TensorAttr]:
        """Return a TensorAttr for every node property that can be scanned into a tensor."""
        result_list = []
        self.__get_connection()
        for table_name in self.connection._get_node_table_names():
            if table_name not in self.node_properties_cache:
                self.node_properties_cache[table_name] = self.connection._get_node_property_names(table_name)
            for attr_name, info in self.node_properties_cache[table_name].items():
                # Same scannability criterion as _get_tensor's fast path.
                if info["type"] in self._TENSOR_TYPES and (
                    info["dimension"] == 0 or "shape" in info
                ):
                    result_list.append(TensorAttr(table_name, attr_name))
        return result_list