real-ladybug 0.0.1.dev1__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of real-ladybug might be problematic. Click here for more details.
- real_ladybug/__init__.py +83 -0
- real_ladybug/_lbug.cp312-win_amd64.pyd +0 -0
- real_ladybug/_lbug.exp +0 -0
- real_ladybug/_lbug.lib +0 -0
- real_ladybug/async_connection.py +226 -0
- real_ladybug/connection.py +323 -0
- real_ladybug/constants.py +7 -0
- real_ladybug/database.py +307 -0
- real_ladybug/prepared_statement.py +51 -0
- real_ladybug/py.typed +0 -0
- real_ladybug/query_result.py +511 -0
- real_ladybug/torch_geometric_feature_store.py +185 -0
- real_ladybug/torch_geometric_graph_store.py +131 -0
- real_ladybug/torch_geometric_result_converter.py +282 -0
- real_ladybug/types.py +39 -0
- real_ladybug-0.0.1.dev1.dist-info/METADATA +88 -0
- real_ladybug-0.0.1.dev1.dist-info/RECORD +114 -0
- real_ladybug-0.0.1.dev1.dist-info/WHEEL +5 -0
- real_ladybug-0.0.1.dev1.dist-info/licenses/LICENSE +21 -0
- real_ladybug-0.0.1.dev1.dist-info/top_level.txt +3 -0
- real_ladybug-0.0.1.dev1.dist-info/zip-safe +1 -0
- real_ladybug-source/scripts/antlr4/hash.py +2 -0
- real_ladybug-source/scripts/antlr4/keywordhandler.py +47 -0
- real_ladybug-source/scripts/collect-extensions.py +68 -0
- real_ladybug-source/scripts/collect-single-file-header.py +126 -0
- real_ladybug-source/scripts/export-dbs.py +101 -0
- real_ladybug-source/scripts/export-import-test.py +345 -0
- real_ladybug-source/scripts/extension/purge-beta.py +34 -0
- real_ladybug-source/scripts/generate-cpp-docs/collect_files.py +122 -0
- real_ladybug-source/scripts/generate-tinysnb.py +34 -0
- real_ladybug-source/scripts/get-clangd-diagnostics.py +233 -0
- real_ladybug-source/scripts/migrate-lbug-db.py +308 -0
- real_ladybug-source/scripts/multiplatform-test-helper/collect-results.py +71 -0
- real_ladybug-source/scripts/multiplatform-test-helper/notify-discord.py +68 -0
- real_ladybug-source/scripts/pip-package/package_tar.py +90 -0
- real_ladybug-source/scripts/pip-package/setup.py +130 -0
- real_ladybug-source/scripts/run-clang-format.py +408 -0
- real_ladybug-source/scripts/setup-extension-repo.py +67 -0
- real_ladybug-source/scripts/test-simsimd-dispatch.py +45 -0
- real_ladybug-source/scripts/update-nightly-build-version.py +81 -0
- real_ladybug-source/third_party/brotli/scripts/dictionary/step-01-download-rfc.py +16 -0
- real_ladybug-source/third_party/brotli/scripts/dictionary/step-02-rfc-to-bin.py +34 -0
- real_ladybug-source/third_party/brotli/scripts/dictionary/step-03-validate-bin.py +35 -0
- real_ladybug-source/third_party/brotli/scripts/dictionary/step-04-generate-java-literals.py +85 -0
- real_ladybug-source/third_party/pybind11/tools/codespell_ignore_lines_from_errors.py +35 -0
- real_ladybug-source/third_party/pybind11/tools/libsize.py +36 -0
- real_ladybug-source/third_party/pybind11/tools/make_changelog.py +63 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/__init__.py +83 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/async_connection.py +226 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/connection.py +323 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/constants.py +7 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/database.py +307 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/prepared_statement.py +51 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/py.typed +0 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/query_result.py +511 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_feature_store.py +185 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_graph_store.py +131 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_result_converter.py +282 -0
- real_ladybug-source/tools/python_api/build/real_ladybug/types.py +39 -0
- real_ladybug-source/tools/python_api/src_py/__init__.py +83 -0
- real_ladybug-source/tools/python_api/src_py/async_connection.py +226 -0
- real_ladybug-source/tools/python_api/src_py/connection.py +323 -0
- real_ladybug-source/tools/python_api/src_py/constants.py +7 -0
- real_ladybug-source/tools/python_api/src_py/database.py +307 -0
- real_ladybug-source/tools/python_api/src_py/prepared_statement.py +51 -0
- real_ladybug-source/tools/python_api/src_py/py.typed +0 -0
- real_ladybug-source/tools/python_api/src_py/query_result.py +511 -0
- real_ladybug-source/tools/python_api/src_py/torch_geometric_feature_store.py +185 -0
- real_ladybug-source/tools/python_api/src_py/torch_geometric_graph_store.py +131 -0
- real_ladybug-source/tools/python_api/src_py/torch_geometric_result_converter.py +282 -0
- real_ladybug-source/tools/python_api/src_py/types.py +39 -0
- real_ladybug-source/tools/python_api/test/conftest.py +230 -0
- real_ladybug-source/tools/python_api/test/disabled_test_extension.py +73 -0
- real_ladybug-source/tools/python_api/test/ground_truth.py +430 -0
- real_ladybug-source/tools/python_api/test/test_arrow.py +694 -0
- real_ladybug-source/tools/python_api/test/test_async_connection.py +159 -0
- real_ladybug-source/tools/python_api/test/test_blob_parameter.py +145 -0
- real_ladybug-source/tools/python_api/test/test_connection.py +49 -0
- real_ladybug-source/tools/python_api/test/test_database.py +234 -0
- real_ladybug-source/tools/python_api/test/test_datatype.py +372 -0
- real_ladybug-source/tools/python_api/test/test_df.py +564 -0
- real_ladybug-source/tools/python_api/test/test_dict.py +112 -0
- real_ladybug-source/tools/python_api/test/test_exception.py +54 -0
- real_ladybug-source/tools/python_api/test/test_fsm.py +227 -0
- real_ladybug-source/tools/python_api/test/test_get_header.py +49 -0
- real_ladybug-source/tools/python_api/test/test_helper.py +8 -0
- real_ladybug-source/tools/python_api/test/test_issue.py +147 -0
- real_ladybug-source/tools/python_api/test/test_iteration.py +96 -0
- real_ladybug-source/tools/python_api/test/test_networkx.py +437 -0
- real_ladybug-source/tools/python_api/test/test_parameter.py +340 -0
- real_ladybug-source/tools/python_api/test/test_prepared_statement.py +117 -0
- real_ladybug-source/tools/python_api/test/test_query_result.py +54 -0
- real_ladybug-source/tools/python_api/test/test_query_result_close.py +44 -0
- real_ladybug-source/tools/python_api/test/test_scan_pandas.py +676 -0
- real_ladybug-source/tools/python_api/test/test_scan_pandas_pyarrow.py +714 -0
- real_ladybug-source/tools/python_api/test/test_scan_polars.py +165 -0
- real_ladybug-source/tools/python_api/test/test_scan_pyarrow.py +167 -0
- real_ladybug-source/tools/python_api/test/test_timeout.py +11 -0
- real_ladybug-source/tools/python_api/test/test_torch_geometric.py +640 -0
- real_ladybug-source/tools/python_api/test/test_torch_geometric_remote_backend.py +111 -0
- real_ladybug-source/tools/python_api/test/test_udf.py +207 -0
- real_ladybug-source/tools/python_api/test/test_version.py +6 -0
- real_ladybug-source/tools/python_api/test/test_wal.py +80 -0
- real_ladybug-source/tools/python_api/test/type_aliases.py +10 -0
- real_ladybug-source/tools/rust_api/update_version.py +47 -0
- real_ladybug-source/tools/shell/test/conftest.py +218 -0
- real_ladybug-source/tools/shell/test/test_helper.py +60 -0
- real_ladybug-source/tools/shell/test/test_shell_basics.py +325 -0
- real_ladybug-source/tools/shell/test/test_shell_commands.py +656 -0
- real_ladybug-source/tools/shell/test/test_shell_control_edit.py +438 -0
- real_ladybug-source/tools/shell/test/test_shell_control_search.py +468 -0
- real_ladybug-source/tools/shell/test/test_shell_esc_edit.py +232 -0
- real_ladybug-source/tools/shell/test/test_shell_esc_search.py +162 -0
- real_ladybug-source/tools/shell/test/test_shell_flags.py +645 -0
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import multiprocessing
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout, GraphStore
|
|
10
|
+
|
|
11
|
+
from .connection import Connection
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
import sys
|
|
15
|
+
|
|
16
|
+
from torch_geometric.typing import EdgeTensorType
|
|
17
|
+
|
|
18
|
+
from .database import Database
|
|
19
|
+
|
|
20
|
+
if sys.version_info >= (3, 10):
|
|
21
|
+
from typing import TypeAlias
|
|
22
|
+
else:
|
|
23
|
+
from typing_extensions import TypeAlias
|
|
24
|
+
|
|
25
|
+
StoreKeyType: TypeAlias = tuple[tuple[str], Any, bool]
|
|
26
|
+
|
|
27
|
+
REL_BATCH_SIZE = 1000000
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass
class Rel:
    """Cached description of one relationship table held by the graph store.

    Entries are created either from the database catalogue (not yet
    materialized, ``edge_index`` is ``None``) or from an edge index that was
    put into the store explicitly (materialized).
    """

    # (source node table, relationship table, destination node table)
    edge_type: tuple[str, ...]
    # ``value`` of the ``EdgeLayout`` member this entry is stored under
    layout: str
    # Whether the edge index is flagged as sorted
    is_sorted: bool
    # (number of source nodes, number of destination nodes)
    size: tuple[int, ...]
    # True once ``edge_index`` holds actual data
    materialized: bool = False
    # The edge index itself; ``None`` until it has been materialized
    edge_index: EdgeTensorType | None = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class LbugGraphStore(GraphStore):  # type: ignore[misc]
    """Graph store compatible with `torch_geometric`.

    On construction one non-materialized, sorted COO entry is registered for
    every relationship table reported by the database. The edge index of such
    an entry is read from the database the first time it is requested.
    """

    def __init__(self, db: Database, num_threads: int | None = None):
        """
        Create the store and register all relationship tables of ``db``.

        Parameters
        ----------
        db : Database
            Database the edges are read from.
        num_threads : int | None
            Passed to the ``Connection`` used for reading edges. Defaults to
            ``multiprocessing.cpu_count()`` when ``None``.
        """
        super().__init__()
        self.db = db
        self.connection: Connection | None = None
        self.store: dict[StoreKeyType, Rel] = {}
        if num_threads is None:
            num_threads = multiprocessing.cpu_count()
        self.num_threads = num_threads
        self.__populate_edge_attrs()

    @staticmethod
    def key(attr: EdgeAttr) -> tuple[tuple[str], Any, bool]:
        """Return the store key ``(edge_type, layout value, is_sorted)`` of ``attr``."""
        return attr.edge_type, attr.layout.value, attr.is_sorted

    def _put_edge_index(self, edge_index: EdgeTensorType, edge_attr: EdgeAttr) -> bool:
        """
        Store ``edge_index`` under the key of ``edge_attr``.

        An existing entry is updated in place and flagged as materialized;
        otherwise a new materialized entry is created.

        Returns
        -------
        bool
            Always ``True``. The ``GraphStore`` base class declares this hook
            as returning ``bool`` and forwards the value from
            ``put_edge_index``.
        """
        key = self.key(edge_attr)
        rel = self.store.get(key)
        if rel is None:
            self.store[key] = Rel(key[0], key[1], key[2], edge_attr.size, True, edge_index)
        else:
            rel.edge_index = edge_index
            rel.materialized = True
            rel.size = edge_attr.size
        return True

    def _get_edge_index(self, edge_attr: EdgeAttr) -> EdgeTensorType | None:
        """
        Return the edge index registered for ``edge_attr``.

        Returns ``None`` when no entry exists for the key.

        Raises
        ------
        ValueError
            If the entry is not materialized and its layout is not COO.
        """
        # We always return a sorted COO edge index. If the request is for an
        # unsorted COO edge index, the is_sorted flag of the passed
        # ``edge_attr`` is changed to True (this mutates the caller's object)
        # and the sorted COO edge index is returned.
        if edge_attr.layout.value == EdgeLayout.COO.value and edge_attr.is_sorted is False:
            edge_attr.is_sorted = True

        key = self.key(edge_attr)
        rel = self.store.get(key)
        if rel is None:
            return None

        if not rel.materialized and rel.layout != EdgeLayout.COO.value:
            msg = "Only COO layout is supported"
            raise ValueError(msg)

        if rel.layout == EdgeLayout.COO.value:
            # No-op when the entry is already materialized.
            self.__get_edge_coo_from_database(key)
        return rel.edge_index

    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
        """
        Drop the entry registered for ``edge_attr``, if any.

        Returns
        -------
        bool
            ``True`` when an entry was removed, ``False`` when none existed.
            The ``GraphStore`` base class declares this hook as returning
            ``bool``.
        """
        return self.store.pop(self.key(edge_attr), None) is not None

    def get_all_edge_attrs(self) -> list[EdgeAttr]:
        """Return all EdgeAttr from the store values."""
        return [EdgeAttr(rel.edge_type, rel.layout, rel.is_sorted, rel.size) for rel in self.store.values()]

    def __ensure_connection(self) -> Connection:
        """Return the store's connection, creating it on first use."""
        if not self.connection:
            self.connection = Connection(self.db, self.num_threads)
        return self.connection

    def __get_edge_coo_from_database(self, key: StoreKeyType) -> None:
        """
        Materialize the COO edge index of the entry stored under ``key``.

        Does nothing when the entry is already materialized.

        Raises
        ------
        ValueError
            If the entry's layout is not COO.
        """
        connection = self.__ensure_connection()

        rel = self.store[key]
        if rel.layout != EdgeLayout.COO.value:
            msg = "Only COO layout is supported"
            raise ValueError(msg)
        if rel.materialized:
            return

        edge_type = rel.edge_type
        num_edges = connection._connection.get_num_rels(edge_type[1])
        # Flat buffer with room for two int64 values per edge.
        # NOTE(review): presumably the native call below fills it in place - confirm.
        result = np.empty(2 * num_edges, dtype=np.int64)
        connection._connection.get_all_edges_for_torch_geometric(
            result, edge_type[0], edge_type[1], edge_type[2], REL_BATCH_SIZE
        )
        edge_list = torch.from_numpy(result)
        edge_list = edge_list.reshape((2, edge_list.shape[0] // 2))
        rel.edge_index = edge_list
        rel.materialized = True

    def __populate_edge_attrs(self) -> None:
        """Register a non-materialized, sorted COO entry for every relationship table."""
        connection = self.__ensure_connection()
        for rel_table in connection._get_rel_table_names():
            edge_type = (rel_table["src"], rel_table["name"], rel_table["dst"])
            size = self.__get_size(edge_type)
            rel = Rel(edge_type, EdgeLayout.COO.value, True, size, False, None)
            self.store[self.key(EdgeAttr(edge_type, EdgeLayout.COO, True))] = rel

    def __get_size(self, edge_type: tuple[str, ...]) -> tuple[int, int]:
        """Return ``(source node count, destination node count)`` for ``edge_type``."""
        num_nodes = self.connection._connection.get_num_nodes  # type: ignore[union-attr]
        src_count, dst_count = num_nodes(edge_type[0]), num_nodes(edge_type[-1])
        return (src_count, dst_count)
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import warnings
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
from .types import Type
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
import torch_geometric.data as geo
|
|
10
|
+
|
|
11
|
+
from .query_result import QueryResult
|
|
12
|
+
|
|
13
|
+
from .constants import ID, LABEL, SRC, DST
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TorchGeometricResultConverter:
    """Convert graph results to `torch_geometric`.

    Node properties of type INT64, DOUBLE and BOOL are converted to tensors.
    Every other property, as well as properties containing nulls or values of
    inconsistent shape, is reported back as an "unconverted" plain Python
    list and a warning is emitted for it.
    """

    def __init__(self, query_result: QueryResult):
        """Prepare empty accumulators for the rows of ``query_result``."""
        self.query_result = query_result
        # label -> property name -> list of per-node values (scalars or tensors)
        self.nodes_dict: dict[str, Any] = {}
        # source label -> destination label -> list of (src_pos, dst_pos)
        self.edges_dict: dict[str, Any] = {}
        # (source label, destination label) -> property name -> list of values
        self.edges_properties: dict[str | tuple[str, str], dict[str, Any]] = {}
        # (src table, src offset, dst table, dst offset) -> relationship row
        self.rels: dict[tuple[Any, ...], dict[str, Any]] = {}
        # Per-label cache of the property metadata fetched from the connection
        self.nodes_property_names_dict: dict[str, Any] = {}
        self.table_to_label_dict: dict[int, str] = {}
        # (table, offset) -> position of the node within its label
        self.internal_id_to_pos_dict: dict[tuple[int, int], int | None] = {}
        # label -> position -> primary key value
        self.pos_to_primary_key_dict: dict[str, Any] = {}
        self.warning_messages: set[str] = set()
        # label -> property name -> list of raw (non-tensor) values
        self.unconverted_properties: dict[str, Any] = {}
        self.properties_to_extract = self.query_result._get_properties_to_extract()

    def __get_node_property_names(self, table_name: str) -> dict[str, Any]:
        """Return the property metadata of ``table_name``, fetching it once per table."""
        if table_name in self.nodes_property_names_dict:
            return self.nodes_property_names_dict[table_name]
        results = self.query_result.connection._get_node_property_names(table_name)
        self.nodes_property_names_dict[table_name] = results
        return results

    def __populate_nodes_dict_and_deduplicte_edges(self) -> None:
        """Walk all result rows, collecting each node once and each relationship once."""
        self.query_result.reset_iterator()
        while self.query_result.has_next():
            row = self.query_result.get_next()
            for i in self.properties_to_extract:
                column_type, _ = self.properties_to_extract[i]
                if column_type == Type.NODE.value:
                    node = row[i]
                    label = node[LABEL]
                    nid = node[ID]
                    internal_id = (nid["table"], nid["offset"])
                    self.table_to_label_dict[nid["table"]] = label

                    # Each node contributes its properties only once.
                    if internal_id in self.internal_id_to_pos_dict:
                        continue

                    node_property_names = self.__get_node_property_names(label)
                    pos, primary_key = self.__extract_properties_from_node(node, label, node_property_names)

                    self.internal_id_to_pos_dict[internal_id] = pos
                    self.pos_to_primary_key_dict.setdefault(label, {})[pos] = primary_key

                elif column_type == Type.REL.value:
                    src = row[i][SRC]
                    dst = row[i][DST]
                    # Keyed by both endpoints, so a repeated relationship overwrites itself.
                    self.rels[src["table"], src["offset"], dst["table"], dst["offset"]] = row[i]

    def __extract_properties_from_node(
        self,
        node: dict[str, Any],
        label: str,
        node_property_names: dict[str, Any],
    ) -> tuple[int | None, Any]:
        """
        Append the property values of ``node`` to the per-label accumulators.

        Returns
        -------
        tuple[int | None, Any]
            The position at which the node's values were appended (``None``
            when the table has no properties) and the node's primary key
            value (``None`` when no property is flagged as primary key).
        """
        import torch

        pos = None
        # Stays None when no property is flagged as primary key; without this
        # initialisation the return statement raised UnboundLocalError.
        primary_key = None
        tensor_types = {
            Type.INT64.value: torch.LongTensor,
            Type.DOUBLE.value: torch.FloatTensor,
            Type.BOOL.value: torch.BoolTensor,
        }

        for prop_name, prop_info in node_property_names.items():
            # Read primary key
            if prop_info["is_primary_key"]:
                primary_key = node[prop_name]

            # If property is already marked as unconverted, then add it directly without further checks
            if label in self.unconverted_properties and prop_name in self.unconverted_properties[label]:
                pos = self.__add_unconverted_property(node, label, prop_name)
                continue

            # Mark properties that are not supported by torch_geometric as unconverted
            prop_type = prop_info["type"]
            if prop_type not in tensor_types:
                self.warning_messages.add(
                    f"Property {label}.{prop_name} of type {prop_type} is not supported by torch_geometric. The property is marked as unconverted."
                )
                self.__mark_property_unconverted(label, prop_name)
                pos = self.__add_unconverted_property(node, label, prop_name)
                continue

            if node[prop_name] is None:
                self.warning_messages.add(
                    f"Property {label}.{prop_name} has a null value. torch_geometric does not support null values. The property is marked as unconverted."
                )
                self.__mark_property_unconverted(label, prop_name)
                pos = self.__add_unconverted_property(node, label, prop_name)
                continue

            if prop_info["dimension"] == 0:
                # Scalars are collected as-is and turned into one tensor at the end.
                curr_value = node[prop_name]
            else:
                try:
                    curr_value = tensor_types[prop_type](node[prop_name])
                except ValueError:
                    self.warning_messages.add(
                        f"Property {label}.{prop_name} cannot be converted to Tensor (likely due to nested list of variable length). The property is marked as unconverted."
                    )
                    self.__mark_property_unconverted(label, prop_name)
                    pos = self.__add_unconverted_property(node, label, prop_name)
                    continue

                # Check if the shape of the property is consistent with the
                # values collected so far (only tensors carry a shape).
                collected = self.nodes_dict.get(label, {}).get(prop_name)
                if collected is not None and curr_value.shape != collected[0].shape:
                    # If the shape is inconsistent, then mark the property as unconverted
                    self.warning_messages.add(
                        f"Property {label}.{prop_name} has an inconsistent shape. The property is marked as unconverted."
                    )
                    self.__mark_property_unconverted(label, prop_name)
                    pos = self.__add_unconverted_property(node, label, prop_name)
                    continue

            # Create the containers for the label / property if they do not
            # exist yet, then add the value
            values = self.nodes_dict.setdefault(label, {}).setdefault(prop_name, [])
            values.append(curr_value)

            # The pos will be overwritten for each property, but
            # it should be the same for all properties
            pos = len(values) - 1
        return pos, primary_key

    def __add_unconverted_property(self, node: dict[str, Any], label: str, prop_name: str) -> int:
        """Append the raw value of ``prop_name`` and return the position it was stored at."""
        values = self.unconverted_properties[label][prop_name]
        values.append(node[prop_name])
        return len(values) - 1

    def __mark_property_unconverted(self, label: str, prop_name: str) -> None:
        """
        Register ``label.prop_name`` as unconverted.

        Values already collected in ``nodes_dict`` for the property are moved
        over, with tensors turned back into plain lists.
        """
        import torch

        unconverted = self.unconverted_properties.setdefault(label, {})
        if prop_name in unconverted:
            return

        if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
            moved = self.nodes_dict[label].pop(prop_name)
            if len(self.nodes_dict[label]) == 0:
                del self.nodes_dict[label]
            # If the property is a tensor, convert it back to list (consistent with the original type)
            unconverted[prop_name] = [
                value.tolist() if torch.is_tensor(value) else value  # type: ignore[no-untyped-call]
                for value in moved
            ]
        else:
            unconverted[prop_name] = []

    def __populate_edges_dict(self) -> None:
        """Translate the collected relationships into position pairs and property lists."""
        # Post-process edges, map internal ids to positions
        for (src_table, src_offset, dst_table, dst_offset), rel_row in self.rels.items():
            src_pos = self.internal_id_to_pos_dict[src_table, src_offset]
            dst_pos = self.internal_id_to_pos_dict[dst_table, dst_offset]
            src_label = self.table_to_label_dict[src_table]
            dst_label = self.table_to_label_dict[dst_table]

            self.edges_dict.setdefault(src_label, {}).setdefault(dst_label, []).append((src_pos, dst_pos))

            edge_props = self.edges_properties.setdefault((src_label, dst_label), {})
            for prop_name, prop_value in rel_row.items():
                # Endpoints and the internal id are structural, not user properties.
                if prop_name in [SRC, DST, ID]:
                    continue
                edge_props.setdefault(prop_name, []).append(prop_value)

    def __print_warnings(self) -> None:
        """Emit every collected warning message once."""
        for message in self.warning_messages:
            warnings.warn(message, stacklevel=2)

    def __convert_to_torch_geometric(
        self,
    ) -> tuple[
        geo.Data | geo.HeteroData,
        dict[str, Any],
        dict[str, Any],
        dict[str | tuple[str, str], dict[str, Any]],
    ]:
        """
        Build the torch_geometric object from the populated accumulators.

        Returns the data object, the position-to-primary-key mapping, the
        unconverted node properties and the edge properties. With exactly one
        converted node label a ``Data`` object is built and the three
        dictionaries are returned without their label level; otherwise a
        ``HeteroData`` object is built and they stay keyed by label.
        """
        import torch
        import torch_geometric

        if len(self.nodes_dict) == 0:
            self.warning_messages.add("No nodes found or all node properties are not converted.")

        # If there is only one node type, then convert to torch_geometric.data.Data
        # Otherwise, convert to torch_geometric.data.HeteroData
        is_hetero = len(self.nodes_dict) != 1
        data = torch_geometric.data.HeteroData() if is_hetero else torch_geometric.data.Data()

        # Convert nodes to tensors
        converted: torch.Tensor
        for label, props in self.nodes_dict.items():
            for prop_name, values in props.items():
                prop_info = self.nodes_property_names_dict[label][prop_name]
                prop_type = prop_info["type"]
                if prop_info["dimension"] == 0:
                    if prop_type == Type.INT64.value:
                        converted = torch.LongTensor(values)
                    elif prop_type == Type.BOOL.value:
                        converted = torch.BoolTensor(values)
                    elif prop_type == Type.DOUBLE.value:
                        converted = torch.FloatTensor(values)
                else:
                    # Values are already tensors of identical shape.
                    converted = torch.stack(values, dim=0)
                if is_hetero:
                    data[label][prop_name] = converted
                else:
                    data[prop_name] = converted

        # Convert edges to tensors
        for src_label, by_dst in self.edges_dict.items():
            for dst_label, pairs in by_dst.items():
                edge_idx = torch.tensor(pairs, dtype=torch.long).t().contiguous()
                if is_hetero:
                    data[src_label, dst_label].edge_index = edge_idx
                else:
                    data.edge_index = edge_idx

        if is_hetero:
            return data, self.pos_to_primary_key_dict, self.unconverted_properties, self.edges_properties

        # Homogeneous result: strip the label level from the dictionaries.
        only_label = next(iter(self.nodes_dict))
        pos_to_primary_key_dict: dict[str, Any] = self.pos_to_primary_key_dict[only_label]
        unconverted_properties: dict[str, Any] = {}
        if len(self.unconverted_properties) != 0:
            unconverted_properties = self.unconverted_properties[next(iter(self.unconverted_properties))]
        edge_properties: dict[Any, Any] = {}
        if len(self.edges_properties) != 0:
            edge_properties = self.edges_properties[next(iter(self.edges_properties))]
        return data, pos_to_primary_key_dict, unconverted_properties, edge_properties

    def get_as_torch_geometric(
        self,
    ) -> tuple[
        geo.Data | geo.HeteroData,
        dict[str, Any],
        dict[str, Any],
        dict[str | tuple[str, str], dict[str, Any]],
    ]:
        """Convert graph data to `torch_geometric`."""
        self.__populate_nodes_dict_and_deduplicte_edges()
        self.__populate_edges_dict()
        result = self.__convert_to_torch_geometric()
        self.__print_warnings()
        return result
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Type(Enum):
    """The type of a value in the database.

    Each member's value is the upper-case type name as a string.
    """

    ANY = "ANY"
    NODE = "NODE"
    REL = "REL"
    RECURSIVE_REL = "RECURSIVE_REL"
    SERIAL = "SERIAL"
    BOOL = "BOOL"
    INT64 = "INT64"
    INT32 = "INT32"
    INT16 = "INT16"
    INT8 = "INT8"
    UINT64 = "UINT64"
    UINT32 = "UINT32"
    UINT16 = "UINT16"
    UINT8 = "UINT8"
    INT128 = "INT128"
    DOUBLE = "DOUBLE"
    FLOAT = "FLOAT"
    DATE = "DATE"
    TIMESTAMP = "TIMESTAMP"
    # Misspelled name kept as the canonical member so existing callers and
    # ``Type("TIMESTAMP_TZ").name`` keep working unchanged.
    TIMSTAMP_TZ = "TIMESTAMP_TZ"
    # Correctly spelled alias: same value, therefore the very same member.
    TIMESTAMP_TZ = "TIMESTAMP_TZ"
    TIMESTAMP_NS = "TIMESTAMP_NS"
    TIMESTAMP_MS = "TIMESTAMP_MS"
    TIMESTAMP_SEC = "TIMESTAMP_SEC"
    INTERVAL = "INTERVAL"
    INTERNAL_ID = "INTERNAL_ID"
    STRING = "STRING"
    BLOB = "BLOB"
    UUID = "UUID"
    LIST = "LIST"
    ARRAY = "ARRAY"
    STRUCT = "STRUCT"
    MAP = "MAP"
    UNION = "UNION"
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""
|
|
2
|
+
# Lbug Python API bindings.
|
|
3
|
+
|
|
4
|
+
This package provides a Python API for Lbug graph database management system.
|
|
5
|
+
|
|
6
|
+
To install the package, run:
|
|
7
|
+
```
|
|
8
|
+
python3 -m pip install real_ladybug
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
Example usage:
|
|
12
|
+
```python
|
|
13
|
+
import real_ladybug as lb
|
|
14
|
+
|
|
15
|
+
db = lb.Database("./test")
|
|
16
|
+
conn = lb.Connection(db)
|
|
17
|
+
|
|
18
|
+
# Define the schema
|
|
19
|
+
conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
|
|
20
|
+
conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))")
|
|
21
|
+
conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")
|
|
22
|
+
conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)")
|
|
23
|
+
|
|
24
|
+
# Load some data
|
|
25
|
+
conn.execute('COPY User FROM "user.csv"')
|
|
26
|
+
conn.execute('COPY City FROM "city.csv"')
|
|
27
|
+
conn.execute('COPY Follows FROM "follows.csv"')
|
|
28
|
+
conn.execute('COPY LivesIn FROM "lives-in.csv"')
|
|
29
|
+
|
|
30
|
+
# Query the data
|
|
31
|
+
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
|
|
32
|
+
while results.has_next():
|
|
33
|
+
print(results.get_next())
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
The dataset used in this example can be found [here](https://github.com/LadybugDB/ladybug/tree/master/dataset/demo-db/csv).
|
|
37
|
+
|
|
38
|
+
"""
|
|
39
|
+
|
|
40
|
+
from __future__ import annotations
|
|
41
|
+
|
|
42
|
+
import os
|
|
43
|
+
import sys
|
|
44
|
+
|
|
45
|
+
# Set RTLD_GLOBAL and RTLD_LAZY flags on Linux to fix the issue with loading
# extensions. The flags are interpreter-wide, so they are switched only for
# the duration of the package imports below.
if sys.platform == "linux":
    original_dlopen_flags = sys.getdlopenflags()
    sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)

try:
    from .async_connection import AsyncConnection
    from .connection import Connection
    from .database import Database
    from .prepared_statement import PreparedStatement
    from .query_result import QueryResult
    from .types import Type
finally:
    # Restore the original dlopen flags, also when one of the imports fails,
    # so a failed import does not leave the process with modified flags.
    if sys.platform == "linux":
        sys.setdlopenflags(original_dlopen_flags)


def __getattr__(name: str) -> str | int:
    """
    Resolve module attributes that are computed on access.

    ``version`` and ``__version__`` return ``Database.get_version()``;
    ``storage_version`` returns ``Database.get_storage_version()``.

    Raises
    ------
    AttributeError
        For any other attribute name.
    """
    if name in ("version", "__version__"):
        return Database.get_version()
    if name == "storage_version":
        return Database.get_storage_version()
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)
|
|
72
|
+
|
|
73
|
+
__all__ = [
|
|
74
|
+
"AsyncConnection",
|
|
75
|
+
"Connection",
|
|
76
|
+
"Database",
|
|
77
|
+
"PreparedStatement",
|
|
78
|
+
"QueryResult",
|
|
79
|
+
"Type",
|
|
80
|
+
"__version__", # noqa: F822
|
|
81
|
+
"storage_version", # noqa: F822
|
|
82
|
+
"version", # noqa: F822
|
|
83
|
+
]
|