real-ladybug 0.0.1.dev1__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of real-ladybug might be problematic. Click here for more details.

Files changed (114) hide show
  1. real_ladybug/__init__.py +83 -0
  2. real_ladybug/_lbug.cp312-win_amd64.pyd +0 -0
  3. real_ladybug/_lbug.exp +0 -0
  4. real_ladybug/_lbug.lib +0 -0
  5. real_ladybug/async_connection.py +226 -0
  6. real_ladybug/connection.py +323 -0
  7. real_ladybug/constants.py +7 -0
  8. real_ladybug/database.py +307 -0
  9. real_ladybug/prepared_statement.py +51 -0
  10. real_ladybug/py.typed +0 -0
  11. real_ladybug/query_result.py +511 -0
  12. real_ladybug/torch_geometric_feature_store.py +185 -0
  13. real_ladybug/torch_geometric_graph_store.py +131 -0
  14. real_ladybug/torch_geometric_result_converter.py +282 -0
  15. real_ladybug/types.py +39 -0
  16. real_ladybug-0.0.1.dev1.dist-info/METADATA +88 -0
  17. real_ladybug-0.0.1.dev1.dist-info/RECORD +114 -0
  18. real_ladybug-0.0.1.dev1.dist-info/WHEEL +5 -0
  19. real_ladybug-0.0.1.dev1.dist-info/licenses/LICENSE +21 -0
  20. real_ladybug-0.0.1.dev1.dist-info/top_level.txt +3 -0
  21. real_ladybug-0.0.1.dev1.dist-info/zip-safe +1 -0
  22. real_ladybug-source/scripts/antlr4/hash.py +2 -0
  23. real_ladybug-source/scripts/antlr4/keywordhandler.py +47 -0
  24. real_ladybug-source/scripts/collect-extensions.py +68 -0
  25. real_ladybug-source/scripts/collect-single-file-header.py +126 -0
  26. real_ladybug-source/scripts/export-dbs.py +101 -0
  27. real_ladybug-source/scripts/export-import-test.py +345 -0
  28. real_ladybug-source/scripts/extension/purge-beta.py +34 -0
  29. real_ladybug-source/scripts/generate-cpp-docs/collect_files.py +122 -0
  30. real_ladybug-source/scripts/generate-tinysnb.py +34 -0
  31. real_ladybug-source/scripts/get-clangd-diagnostics.py +233 -0
  32. real_ladybug-source/scripts/migrate-lbug-db.py +308 -0
  33. real_ladybug-source/scripts/multiplatform-test-helper/collect-results.py +71 -0
  34. real_ladybug-source/scripts/multiplatform-test-helper/notify-discord.py +68 -0
  35. real_ladybug-source/scripts/pip-package/package_tar.py +90 -0
  36. real_ladybug-source/scripts/pip-package/setup.py +130 -0
  37. real_ladybug-source/scripts/run-clang-format.py +408 -0
  38. real_ladybug-source/scripts/setup-extension-repo.py +67 -0
  39. real_ladybug-source/scripts/test-simsimd-dispatch.py +45 -0
  40. real_ladybug-source/scripts/update-nightly-build-version.py +81 -0
  41. real_ladybug-source/third_party/brotli/scripts/dictionary/step-01-download-rfc.py +16 -0
  42. real_ladybug-source/third_party/brotli/scripts/dictionary/step-02-rfc-to-bin.py +34 -0
  43. real_ladybug-source/third_party/brotli/scripts/dictionary/step-03-validate-bin.py +35 -0
  44. real_ladybug-source/third_party/brotli/scripts/dictionary/step-04-generate-java-literals.py +85 -0
  45. real_ladybug-source/third_party/pybind11/tools/codespell_ignore_lines_from_errors.py +35 -0
  46. real_ladybug-source/third_party/pybind11/tools/libsize.py +36 -0
  47. real_ladybug-source/third_party/pybind11/tools/make_changelog.py +63 -0
  48. real_ladybug-source/tools/python_api/build/real_ladybug/__init__.py +83 -0
  49. real_ladybug-source/tools/python_api/build/real_ladybug/async_connection.py +226 -0
  50. real_ladybug-source/tools/python_api/build/real_ladybug/connection.py +323 -0
  51. real_ladybug-source/tools/python_api/build/real_ladybug/constants.py +7 -0
  52. real_ladybug-source/tools/python_api/build/real_ladybug/database.py +307 -0
  53. real_ladybug-source/tools/python_api/build/real_ladybug/prepared_statement.py +51 -0
  54. real_ladybug-source/tools/python_api/build/real_ladybug/py.typed +0 -0
  55. real_ladybug-source/tools/python_api/build/real_ladybug/query_result.py +511 -0
  56. real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_feature_store.py +185 -0
  57. real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_graph_store.py +131 -0
  58. real_ladybug-source/tools/python_api/build/real_ladybug/torch_geometric_result_converter.py +282 -0
  59. real_ladybug-source/tools/python_api/build/real_ladybug/types.py +39 -0
  60. real_ladybug-source/tools/python_api/src_py/__init__.py +83 -0
  61. real_ladybug-source/tools/python_api/src_py/async_connection.py +226 -0
  62. real_ladybug-source/tools/python_api/src_py/connection.py +323 -0
  63. real_ladybug-source/tools/python_api/src_py/constants.py +7 -0
  64. real_ladybug-source/tools/python_api/src_py/database.py +307 -0
  65. real_ladybug-source/tools/python_api/src_py/prepared_statement.py +51 -0
  66. real_ladybug-source/tools/python_api/src_py/py.typed +0 -0
  67. real_ladybug-source/tools/python_api/src_py/query_result.py +511 -0
  68. real_ladybug-source/tools/python_api/src_py/torch_geometric_feature_store.py +185 -0
  69. real_ladybug-source/tools/python_api/src_py/torch_geometric_graph_store.py +131 -0
  70. real_ladybug-source/tools/python_api/src_py/torch_geometric_result_converter.py +282 -0
  71. real_ladybug-source/tools/python_api/src_py/types.py +39 -0
  72. real_ladybug-source/tools/python_api/test/conftest.py +230 -0
  73. real_ladybug-source/tools/python_api/test/disabled_test_extension.py +73 -0
  74. real_ladybug-source/tools/python_api/test/ground_truth.py +430 -0
  75. real_ladybug-source/tools/python_api/test/test_arrow.py +694 -0
  76. real_ladybug-source/tools/python_api/test/test_async_connection.py +159 -0
  77. real_ladybug-source/tools/python_api/test/test_blob_parameter.py +145 -0
  78. real_ladybug-source/tools/python_api/test/test_connection.py +49 -0
  79. real_ladybug-source/tools/python_api/test/test_database.py +234 -0
  80. real_ladybug-source/tools/python_api/test/test_datatype.py +372 -0
  81. real_ladybug-source/tools/python_api/test/test_df.py +564 -0
  82. real_ladybug-source/tools/python_api/test/test_dict.py +112 -0
  83. real_ladybug-source/tools/python_api/test/test_exception.py +54 -0
  84. real_ladybug-source/tools/python_api/test/test_fsm.py +227 -0
  85. real_ladybug-source/tools/python_api/test/test_get_header.py +49 -0
  86. real_ladybug-source/tools/python_api/test/test_helper.py +8 -0
  87. real_ladybug-source/tools/python_api/test/test_issue.py +147 -0
  88. real_ladybug-source/tools/python_api/test/test_iteration.py +96 -0
  89. real_ladybug-source/tools/python_api/test/test_networkx.py +437 -0
  90. real_ladybug-source/tools/python_api/test/test_parameter.py +340 -0
  91. real_ladybug-source/tools/python_api/test/test_prepared_statement.py +117 -0
  92. real_ladybug-source/tools/python_api/test/test_query_result.py +54 -0
  93. real_ladybug-source/tools/python_api/test/test_query_result_close.py +44 -0
  94. real_ladybug-source/tools/python_api/test/test_scan_pandas.py +676 -0
  95. real_ladybug-source/tools/python_api/test/test_scan_pandas_pyarrow.py +714 -0
  96. real_ladybug-source/tools/python_api/test/test_scan_polars.py +165 -0
  97. real_ladybug-source/tools/python_api/test/test_scan_pyarrow.py +167 -0
  98. real_ladybug-source/tools/python_api/test/test_timeout.py +11 -0
  99. real_ladybug-source/tools/python_api/test/test_torch_geometric.py +640 -0
  100. real_ladybug-source/tools/python_api/test/test_torch_geometric_remote_backend.py +111 -0
  101. real_ladybug-source/tools/python_api/test/test_udf.py +207 -0
  102. real_ladybug-source/tools/python_api/test/test_version.py +6 -0
  103. real_ladybug-source/tools/python_api/test/test_wal.py +80 -0
  104. real_ladybug-source/tools/python_api/test/type_aliases.py +10 -0
  105. real_ladybug-source/tools/rust_api/update_version.py +47 -0
  106. real_ladybug-source/tools/shell/test/conftest.py +218 -0
  107. real_ladybug-source/tools/shell/test/test_helper.py +60 -0
  108. real_ladybug-source/tools/shell/test/test_shell_basics.py +325 -0
  109. real_ladybug-source/tools/shell/test/test_shell_commands.py +656 -0
  110. real_ladybug-source/tools/shell/test/test_shell_control_edit.py +438 -0
  111. real_ladybug-source/tools/shell/test/test_shell_control_search.py +468 -0
  112. real_ladybug-source/tools/shell/test/test_shell_esc_edit.py +232 -0
  113. real_ladybug-source/tools/shell/test/test_shell_esc_search.py +162 -0
  114. real_ladybug-source/tools/shell/test/test_shell_flags.py +645 -0
@@ -0,0 +1,131 @@
1
+ from __future__ import annotations
2
+
3
+ import multiprocessing
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch_geometric.data.graph_store import EdgeAttr, EdgeLayout, GraphStore
10
+
11
+ from .connection import Connection
12
+
13
+ if TYPE_CHECKING:
14
+ import sys
15
+
16
+ from torch_geometric.typing import EdgeTensorType
17
+
18
+ from .database import Database
19
+
20
+ if sys.version_info >= (3, 10):
21
+ from typing import TypeAlias
22
+ else:
23
+ from typing_extensions import TypeAlias
24
+
25
+ StoreKeyType: TypeAlias = tuple[tuple[str], Any, bool]
26
+
27
+ REL_BATCH_SIZE = 1000000
28
+
29
+
30
@dataclass
class Rel:
    """Store entry describing one edge type and, once loaded, its edge index."""

    # Edge type identifier; the graph store builds it as (src, name, dst).
    edge_type: tuple[str, ...]
    # Layout value the entry is stored under (an ``EdgeLayout`` member's ``.value``).
    layout: str
    # Whether the edge index is registered as sorted.
    is_sorted: bool
    # Size associated with the edge type; the graph store fills it with
    # (source node count, destination node count) or with ``EdgeAttr.size``.
    size: tuple[int, ...]
    # True once ``edge_index`` has been set for this entry.
    materialized: bool = False
    # The edge index itself, or None while it has not been loaded or provided.
    edge_index: EdgeTensorType | None = None
38
+
39
+
40
class LbugGraphStore(GraphStore):  # type: ignore[misc]
    """Graph store compatible with `torch_geometric`.

    One ``Rel`` entry per relationship table is registered when the store is
    created; the corresponding COO edge index is read from the database the
    first time it is requested.
    """

    def __init__(self, db: Database, num_threads: int | None = None):
        """Register all relationship tables of ``db``.

        Parameters
        ----------
        db : Database
            Database the edge indices are read from.
        num_threads : int | None
            Thread count passed to the ``Connection``; when None, the value of
            ``multiprocessing.cpu_count()`` is used.
        """
        super().__init__()
        self.db = db
        self.connection: Connection | None = None
        self.store: dict[StoreKeyType, Rel] = {}
        self.num_threads = multiprocessing.cpu_count() if num_threads is None else num_threads
        self.__populate_edge_attrs()

    @staticmethod
    def key(attr: EdgeAttr) -> tuple[tuple[str], Any, bool]:
        """Build the store key ``(edge_type, layout value, is_sorted)`` for ``attr``."""
        return (attr.edge_type, attr.layout.value, attr.is_sorted)

    def _put_edge_index(self, edge_index: EdgeTensorType, edge_attr: EdgeAttr) -> None:
        """Store ``edge_index`` under the key of ``edge_attr`` and mark it materialized."""
        key = self.key(edge_attr)
        entry = self.store.get(key)
        if entry is None:
            edge_type, layout, is_sorted = key
            self.store[key] = Rel(edge_type, layout, is_sorted, edge_attr.size, True, edge_index)
            return

        entry.edge_index = edge_index
        entry.materialized = True
        entry.size = edge_attr.size

    def _get_edge_index(self, edge_attr: EdgeAttr) -> EdgeTensorType | None:
        """Return the edge index registered for ``edge_attr``, or None if unknown.

        Raises ``ValueError`` for a non-COO entry that has not been materialized.
        Note that ``edge_attr.is_sorted`` is set to True for COO requests.
        """
        coo = EdgeLayout.COO.value

        # Only a sorted COO edge index is ever returned, so a request for an
        # unsorted COO edge index is turned into a request for the sorted one.
        if edge_attr.layout.value == coo and edge_attr.is_sorted is False:
            edge_attr.is_sorted = True

        key = self.key(edge_attr)
        if key not in self.store:
            return None

        rel = self.store[key]
        if rel.layout == coo:
            # No-op when the entry is already materialized.
            self.__get_edge_coo_from_database(key)
        elif not rel.materialized:
            msg = "Only COO layout is supported"
            raise ValueError(msg)
        return rel.edge_index

    def _remove_edge_index(self, edge_attr: EdgeAttr) -> None:
        """Drop the entry registered for ``edge_attr``; unknown keys are ignored."""
        self.store.pop(self.key(edge_attr), None)

    def get_all_edge_attrs(self) -> list[EdgeAttr]:
        """Return one EdgeAttr per entry currently held in the store."""
        return [
            EdgeAttr(entry.edge_type, entry.layout, entry.is_sorted, entry.size)
            for entry in self.store.values()
        ]

    def __get_edge_coo_from_database(self, key: StoreKeyType) -> None:
        """Load the COO edge index of the entry at ``key`` unless it is already materialized."""
        if not self.connection:
            self.connection = Connection(self.db, self.num_threads)

        rel = self.store[key]
        if rel.layout != EdgeLayout.COO.value:
            msg = "Only COO layout is supported"
            raise ValueError(msg)
        if rel.materialized:
            return

        edge_type = rel.edge_type
        native = self.connection._connection
        edge_count = native.get_num_rels(edge_type[1])

        # Flat buffer with room for two int64 values per edge; it is filled by
        # the native call and then viewed as a 2 x num_edges tensor.
        buffer = np.empty(2 * edge_count, dtype=np.int64)
        native.get_all_edges_for_torch_geometric(buffer, edge_type[0], edge_type[1], edge_type[2], REL_BATCH_SIZE)
        flat = torch.from_numpy(buffer)
        rel.edge_index = flat.reshape((2, flat.shape[0] // 2))
        rel.materialized = True

    def __populate_edge_attrs(self) -> None:
        """Register a non-materialized, sorted COO entry for every relationship table."""
        if not self.connection:
            self.connection = Connection(self.db, self.num_threads)

        for table in self.connection._get_rel_table_names():
            edge_type = (table["src"], table["name"], table["dst"])
            entry = Rel(edge_type, EdgeLayout.COO.value, True, self.__get_size(edge_type), False, None)
            self.store[self.key(EdgeAttr(edge_type, EdgeLayout.COO, True))] = entry

    def __get_size(self, edge_type: tuple[str, ...]) -> tuple[int, int]:
        """Return the node counts of the first and last table named in ``edge_type``."""
        count_nodes = self.connection._connection.get_num_nodes  # type: ignore[union-attr]
        return (count_nodes(edge_type[0]), count_nodes(edge_type[-1]))
@@ -0,0 +1,282 @@
1
+ from __future__ import annotations
2
+
3
+ import warnings
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ from .types import Type
7
+
8
+ if TYPE_CHECKING:
9
+ import torch_geometric.data as geo
10
+
11
+ from .query_result import QueryResult
12
+
13
+ from .constants import ID, LABEL, SRC, DST
14
+
15
+
16
class TorchGeometricResultConverter:
    """Convert graph results to `torch_geometric`.

    Node properties of type INT64, DOUBLE and BOOL are collected into tensors.
    Any other property, and any property that contains a null value, cannot be
    turned into a tensor, or changes shape between nodes, is kept as plain
    values in ``unconverted_properties`` and a warning is recorded for it.
    """

    def __init__(self, query_result: QueryResult):
        self.query_result = query_result
        # label -> property name -> list of per-node values (scalars or tensors)
        self.nodes_dict: dict[str, Any] = {}
        # source label -> destination label -> list of (src_pos, dst_pos) pairs
        self.edges_dict: dict[str, Any] = {}
        # (source label, destination label) -> property name -> list of values
        self.edges_properties: dict[str | tuple[str, str], dict[str, Any]] = {}
        # (src table, src offset, dst table, dst offset) -> relationship row
        self.rels: dict[tuple[Any, ...], dict[str, Any]] = {}
        # label -> property metadata fetched from the connection
        self.nodes_property_names_dict: dict[str, Any] = {}
        # internal table id -> label
        self.table_to_label_dict: dict[int, str] = {}
        # (table id, offset) -> position assigned to the node
        self.internal_id_to_pos_dict: dict[tuple[int, int], int | None] = {}
        # label -> position -> primary key value
        self.pos_to_primary_key_dict: dict[str, Any] = {}
        self.warning_messages: set[str] = set()
        # label -> property name -> list of raw (non-tensor) values
        self.unconverted_properties: dict[str, Any] = {}
        self.properties_to_extract = self.query_result._get_properties_to_extract()

    def __get_node_property_names(self, table_name: str) -> dict[str, Any]:
        """Return the property metadata of ``table_name``, asking the connection once per table."""
        if table_name not in self.nodes_property_names_dict:
            self.nodes_property_names_dict[table_name] = self.query_result.connection._get_node_property_names(
                table_name
            )
        return self.nodes_property_names_dict[table_name]

    def __populate_nodes_dict_and_deduplicte_edges(self) -> None:
        """Walk the whole result once, registering each distinct node and relationship."""
        self.query_result.reset_iterator()
        while self.query_result.has_next():
            row = self.query_result.get_next()
            for i in self.properties_to_extract:
                column_type, _ = self.properties_to_extract[i]
                if column_type == Type.NODE.value:
                    node = row[i]
                    label = node[LABEL]
                    nid = node[ID]
                    node_key = (nid["table"], nid["offset"])
                    self.table_to_label_dict[nid["table"]] = label

                    # A node returned by several rows is only extracted once.
                    if node_key in self.internal_id_to_pos_dict:
                        continue

                    node_property_names = self.__get_node_property_names(label)
                    pos, primary_key = self.__extract_properties_from_node(node, label, node_property_names)

                    self.internal_id_to_pos_dict[node_key] = pos
                    self.pos_to_primary_key_dict.setdefault(label, {})[pos] = primary_key

                elif column_type == Type.REL.value:
                    rel = row[i]
                    src = rel[SRC]
                    dst = rel[DST]
                    # Keyed by both endpoints, so a repeated relationship
                    # simply overwrites its earlier entry.
                    self.rels[src["table"], src["offset"], dst["table"], dst["offset"]] = rel

    @staticmethod
    def __to_tensor(values: Any, prop_type: str) -> Any:
        """Build the tensor matching ``prop_type`` (INT64, DOUBLE or BOOL) from ``values``."""
        import torch

        tensor_types = {
            Type.INT64.value: torch.LongTensor,
            Type.DOUBLE.value: torch.FloatTensor,
            Type.BOOL.value: torch.BoolTensor,
        }
        return tensor_types[prop_type](values)

    def __demote_property(self, node: dict[str, Any], label: str, prop_name: str, message: str) -> int:
        """Record ``message``, mark the property as unconverted and store the node's raw value."""
        self.warning_messages.add(message)
        self.__mark_property_unconverted(label, prop_name)
        return self.__add_unconverted_property(node, label, prop_name)

    def __extract_properties_from_node(
        self,
        node: dict[str, Any],
        label: str,
        node_property_names: dict[str, Any],
    ) -> tuple[int | None, Any]:
        """Store every property of ``node`` and return ``(position, primary key)``.

        The position is the index the node's values occupy in the per-property
        lists; it is None when ``node_property_names`` is empty. The primary key
        is None when no property is flagged as primary key.
        """
        pos = None
        # Bound up front so the return below cannot hit an unbound local when
        # no property carries the primary-key flag.
        primary_key = None

        for prop_name, prop_info in node_property_names.items():
            # Read primary key
            if prop_info["is_primary_key"]:
                primary_key = node[prop_name]

            # If property is already marked as unconverted, then add it directly without further checks
            if label in self.unconverted_properties and prop_name in self.unconverted_properties[label]:
                pos = self.__add_unconverted_property(node, label, prop_name)
                continue

            # Mark properties that are not supported by torch_geometric as unconverted
            prop_type = prop_info["type"]
            if prop_type not in (Type.INT64.value, Type.DOUBLE.value, Type.BOOL.value):
                pos = self.__demote_property(
                    node,
                    label,
                    prop_name,
                    f"Property {label}.{prop_name} of type {prop_type} is not supported by torch_geometric. The property is marked as unconverted.",
                )
                continue
            if node[prop_name] is None:
                pos = self.__demote_property(
                    node,
                    label,
                    prop_name,
                    f"Property {label}.{prop_name} has a null value. torch_geometric does not support null values. The property is marked as unconverted.",
                )
                continue

            if prop_info["dimension"] == 0:
                curr_value = node[prop_name]
            else:
                try:
                    curr_value = self.__to_tensor(node[prop_name], prop_type)
                except ValueError:
                    pos = self.__demote_property(
                        node,
                        label,
                        prop_name,
                        f"Property {label}.{prop_name} cannot be converted to Tensor (likely due to nested list of variable length). The property is marked as unconverted.",
                    )
                    continue

                # Check if the shape of the property is consistent with the
                # values collected so far; only tensors carry a shape, so this
                # check applies to non-scalar properties only.
                collected = self.nodes_dict.get(label, {}).get(prop_name)
                if collected is not None and curr_value.shape != collected[0].shape:
                    pos = self.__demote_property(
                        node,
                        label,
                        prop_name,
                        f"Property {label}.{prop_name} has an inconsistent shape. The property is marked as unconverted.",
                    )
                    continue

            # Create the containers on first use, then add the property
            values = self.nodes_dict.setdefault(label, {}).setdefault(prop_name, [])
            values.append(curr_value)

            # The pos will be overwritten for each property, but
            # it should be the same for all properties
            pos = len(values) - 1
        return pos, primary_key

    def __add_unconverted_property(self, node: dict[str, Any], label: str, prop_name: str) -> int:
        """Append the node's raw value to the unconverted list and return its index."""
        values = self.unconverted_properties[label][prop_name]
        values.append(node[prop_name])
        return len(values) - 1

    def __mark_property_unconverted(self, label: str, prop_name: str) -> None:
        """Move ``label.prop_name`` from the converted to the unconverted store (no-op if already there)."""
        import torch

        label_props = self.unconverted_properties.setdefault(label, {})
        if prop_name in label_props:
            return

        if label in self.nodes_dict and prop_name in self.nodes_dict[label]:
            values = self.nodes_dict[label].pop(prop_name)
            if not self.nodes_dict[label]:
                del self.nodes_dict[label]
            # If a value is a tensor, convert it back to list (consistent with the original type)
            values[:] = [
                value.tolist() if torch.is_tensor(value) else value  # type: ignore[no-untyped-call]
                for value in values
            ]
            label_props[prop_name] = values
        else:
            label_props[prop_name] = []

    def __populate_edges_dict(self) -> None:
        """Post-process relationships: map internal ids to positions and collect edge properties."""
        for r, curr_edge_properties in self.rels.items():
            src_pos = self.internal_id_to_pos_dict[r[0], r[1]]
            dst_pos = self.internal_id_to_pos_dict[r[2], r[3]]
            src_label = self.table_to_label_dict[r[0]]
            dst_label = self.table_to_label_dict[r[2]]

            self.edges_dict.setdefault(src_label, {}).setdefault(dst_label, []).append((src_pos, dst_pos))

            properties = self.edges_properties.setdefault((src_label, dst_label), {})
            for prop_name in curr_edge_properties:
                # Endpoints and the internal id are structural, not user properties.
                if prop_name in (SRC, DST, ID):
                    continue
                properties.setdefault(prop_name, []).append(curr_edge_properties[prop_name])

    def __print_warnings(self) -> None:
        """Emit every recorded warning message through ``warnings.warn``."""
        for message in self.warning_messages:
            warnings.warn(message, stacklevel=2)

    def __convert_to_torch_geometric(
        self,
    ) -> tuple[
        geo.Data | geo.HeteroData,
        dict[str, Any],
        dict[str, Any],
        dict[str | tuple[str, str], dict[str, Any]],
    ]:
        """Build the `torch_geometric` object from the collected nodes and edges.

        Returns ``(data, pos_to_primary_key, unconverted_properties,
        edge_properties)``. With exactly one converted node label ``data`` is a
        ``Data`` object and the dictionaries are un-nested; otherwise ``data``
        is a ``HeteroData`` object and the dictionaries keep their label keys.
        """
        import torch
        import torch_geometric

        if not self.nodes_dict:
            self.warning_messages.add("No nodes found or all node properties are not converted.")

        # If there is only one node type, then convert to torch_geometric.data.Data
        # Otherwise, convert to torch_geometric.data.HeteroData
        is_hetero = len(self.nodes_dict) != 1
        data = torch_geometric.data.HeteroData() if is_hetero else torch_geometric.data.Data()

        # Convert nodes to tensors
        converted: torch.Tensor
        for label, label_props in self.nodes_dict.items():
            for prop_name, values in label_props.items():
                prop_info = self.nodes_property_names_dict[label][prop_name]
                prop_type = prop_info["type"]
                if prop_info["dimension"] == 0:
                    if prop_type == Type.INT64.value:
                        converted = torch.LongTensor(values)
                    elif prop_type == Type.BOOL.value:
                        converted = torch.BoolTensor(values)
                    elif prop_type == Type.DOUBLE.value:
                        converted = torch.FloatTensor(values)
                else:
                    # Non-scalar values were already stored as tensors of equal shape.
                    converted = torch.stack(values, dim=0)
                if is_hetero:
                    data[label][prop_name] = converted
                else:
                    data[prop_name] = converted

        # Convert edges to tensors
        for src_label, by_dst in self.edges_dict.items():
            for dst_label, pairs in by_dst.items():
                edge_idx = torch.tensor(pairs, dtype=torch.long).t().contiguous()
                if is_hetero:
                    data[src_label, dst_label].edge_index = edge_idx
                else:
                    data.edge_index = edge_idx

        if is_hetero:
            return data, self.pos_to_primary_key_dict, self.unconverted_properties, self.edges_properties

        # Homogeneous case: pick the single converted label explicitly rather
        # than relying on a loop variable left over from the code above.
        sole_label = next(iter(self.nodes_dict))
        pos_to_primary_key_dict: dict[str, Any] = self.pos_to_primary_key_dict[sole_label]

        unconverted_properties: dict[str, Any] = {}
        if self.unconverted_properties:
            unconverted_properties = self.unconverted_properties[next(iter(self.unconverted_properties))]

        edge_properties: dict[Any, Any] = {}
        if self.edges_properties:
            edge_properties = self.edges_properties[next(iter(self.edges_properties))]
        return data, pos_to_primary_key_dict, unconverted_properties, edge_properties

    def get_as_torch_geometric(
        self,
    ) -> tuple[
        geo.Data | geo.HeteroData,
        dict[str, Any],
        dict[str, Any],
        dict[str | tuple[str, str], dict[str, Any]],
    ]:
        """Convert graph data to `torch_geometric`."""
        self.__populate_nodes_dict_and_deduplicte_edges()
        self.__populate_edges_dict()
        result = self.__convert_to_torch_geometric()
        # Warnings are emitted last so that messages added during conversion are included.
        self.__print_warnings()
        return result
@@ -0,0 +1,39 @@
1
+ from enum import Enum
2
+
3
+
4
class Type(Enum):
    """The type of a value in the database.

    Each member's value is the upper-case type name as a string.
    """

    ANY = "ANY"
    NODE = "NODE"
    REL = "REL"
    RECURSIVE_REL = "RECURSIVE_REL"
    SERIAL = "SERIAL"
    BOOL = "BOOL"
    INT64 = "INT64"
    INT32 = "INT32"
    INT16 = "INT16"
    INT8 = "INT8"
    UINT64 = "UINT64"
    UINT32 = "UINT32"
    UINT16 = "UINT16"
    UINT8 = "UINT8"
    INT128 = "INT128"
    DOUBLE = "DOUBLE"
    FLOAT = "FLOAT"
    DATE = "DATE"
    TIMESTAMP = "TIMESTAMP"
    # The misspelled name is kept first so it stays the canonical member:
    # existing references, ``.name`` and iteration order are unchanged.
    TIMSTAMP_TZ = "TIMESTAMP_TZ"
    # Correctly spelled alias; a repeated value makes this the same member.
    TIMESTAMP_TZ = "TIMESTAMP_TZ"  # noqa: PIE796
    TIMESTAMP_NS = "TIMESTAMP_NS"
    TIMESTAMP_MS = "TIMESTAMP_MS"
    TIMESTAMP_SEC = "TIMESTAMP_SEC"
    INTERVAL = "INTERVAL"
    INTERNAL_ID = "INTERNAL_ID"
    STRING = "STRING"
    BLOB = "BLOB"
    UUID = "UUID"
    LIST = "LIST"
    ARRAY = "ARRAY"
    STRUCT = "STRUCT"
    MAP = "MAP"
    UNION = "UNION"
@@ -0,0 +1,230 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import shutil
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ import pytest
10
+ from test_helper import LBUG_ROOT
11
+
12
+ python_build_dir = Path(__file__).parent.parent / "build"
13
+ try:
14
+ import real_ladybug as lb
15
+ except ModuleNotFoundError:
16
+ sys.path.append(str(python_build_dir))
17
+ import real_ladybug as lb
18
+
19
+ if TYPE_CHECKING:
20
+ from type_aliases import ConnDB
21
+
22
+
23
def init_npy(conn: lb.Connection) -> None:
    """Create the ``npyoned`` and ``npytwod`` node tables and fill them column-wise from ``.npy`` files."""
    create_one_dim = """
        CREATE NODE TABLE npyoned (
          i64 INT64,
          i32 INT32,
          i16 INT16,
          f64 DOUBLE,
          f32 FLOAT,
          PRIMARY KEY(i64)
        );
        """
    copy_one_dim = f"""
        COPY npyoned from (
          "{LBUG_ROOT}/dataset/npy-1d/one_dim_int64.npy",
          "{LBUG_ROOT}/dataset/npy-1d/one_dim_int32.npy",
          "{LBUG_ROOT}/dataset/npy-1d/one_dim_int16.npy",
          "{LBUG_ROOT}/dataset/npy-1d/one_dim_double.npy",
          "{LBUG_ROOT}/dataset/npy-1d/one_dim_float.npy") BY COLUMN;
        """
    create_two_dim = """
        CREATE NODE TABLE npytwod (
          id INT64,
          i64 INT64[3],
          i32 INT32[3],
          i16 INT16[3],
          f64 DOUBLE[3],
          f32 FLOAT[3],
          PRIMARY KEY(id)
        );
        """
    copy_two_dim = f"""
        COPY npytwod FROM (
          "{LBUG_ROOT}/dataset/npy-2d/id_int64.npy",
          "{LBUG_ROOT}/dataset/npy-2d/two_dim_int64.npy",
          "{LBUG_ROOT}/dataset/npy-2d/two_dim_int32.npy",
          "{LBUG_ROOT}/dataset/npy-2d/two_dim_int16.npy",
          "{LBUG_ROOT}/dataset/npy-2d/two_dim_double.npy",
          "{LBUG_ROOT}/dataset/npy-2d/two_dim_float.npy") BY COLUMN;
        """

    # Each table has to exist before its COPY statement runs.
    for statement in (create_one_dim, copy_one_dim, create_two_dim, copy_two_dim):
        conn.execute(statement)
70
+
71
+
72
def init_tensor(conn: lb.Connection) -> None:
    """Create the ``tensor`` node table and load it from ``vTensor.csv``."""
    create_table = """
        CREATE NODE TABLE tensor (
          ID INT64,
          boolTensor BOOLEAN[],
          doubleTensor DOUBLE[][],
          intTensor INT64[][][],
          oneDimInt INT64,
          PRIMARY KEY (ID)
        );
        """
    copy_table = f'COPY tensor FROM "{LBUG_ROOT}/dataset/tensor-list/vTensor.csv" (HEADER=true)'

    conn.execute(create_table)
    conn.execute(copy_table)
86
+
87
+
88
def init_long_str(conn: lb.Connection) -> None:
    """Create and load the ``personLongString`` nodes and ``knowsLongString`` relationships."""
    statements = (
        "CREATE NODE TABLE personLongString (name STRING, spouse STRING, PRIMARY KEY(name))",
        f'COPY personLongString FROM "{LBUG_ROOT}/dataset/long-string-pk-tests/vPerson.csv"',
        "CREATE REL TABLE knowsLongString (FROM personLongString TO personLongString, MANY_MANY)",
        f'COPY knowsLongString FROM "{LBUG_ROOT}/dataset/long-string-pk-tests/eKnows.csv"',
    )
    # Order matters: node table, its data, then the relationship table and its data.
    for statement in statements:
        conn.execute(statement)
93
+
94
+
95
def init_tinysnb(conn: lb.Connection) -> None:
    """Create and load the tinysnb dataset by running its two cypher scripts.

    ``schema.cypher`` and ``copy.cypher`` are read from
    ``{LBUG_ROOT}/dataset/tinysnb`` (resolved relative to this file's
    directory), and every non-blank line is executed as one statement.
    """
    tiny_snb_path = (Path(__file__).parent / f"{LBUG_ROOT}/dataset/tinysnb").resolve()

    # The encoding is given explicitly so that reading the scripts does not
    # depend on the platform's locale default.
    with (tiny_snb_path / "schema.cypher").open(mode="r", encoding="utf-8") as f:
        for raw_line in f:
            statement = raw_line.strip()
            if statement:
                conn.execute(statement)

    with (tiny_snb_path / "copy.cypher").open(mode="r", encoding="utf-8") as f:
        for raw_line in f:
            # The script refers to the dataset by a relative path; prefix it
            # with LBUG_ROOT before executing.
            statement = raw_line.strip().replace("dataset/tinysnb", f"{LBUG_ROOT}/dataset/tinysnb")
            if statement:
                conn.execute(statement)
111
+
112
+
113
def init_demo(conn: lb.Connection) -> None:
    """Create and load the demo-db dataset by running its two cypher scripts.

    ``schema.cypher`` and ``copy.cypher`` are read from
    ``{LBUG_ROOT}/dataset/demo-db/csv`` (resolved relative to this file's
    directory), and every non-blank line is executed as one statement.
    """
    demo_db_path = (Path(__file__).parent / f"{LBUG_ROOT}/dataset/demo-db/csv").resolve()

    # The encoding is given explicitly so that reading the scripts does not
    # depend on the platform's locale default.
    with (demo_db_path / "schema.cypher").open(mode="r", encoding="utf-8") as f:
        for raw_line in f:
            statement = raw_line.strip()
            if statement:
                conn.execute(statement)

    with (demo_db_path / "copy.cypher").open(mode="r", encoding="utf-8") as f:
        for raw_line in f:
            # The script refers to the dataset by a relative path; prefix it
            # with LBUG_ROOT before executing.
            statement = raw_line.strip().replace("dataset/demo-db/csv", f"{LBUG_ROOT}/dataset/demo-db/csv")
            if statement:
                conn.execute(statement)
129
+
130
+
131
def init_movie_serial(conn: lb.Connection) -> None:
    """Create the ``moviesSerial`` node table (SERIAL primary key) and load it from ``vMovies.csv``."""
    create_table = """
        CREATE NODE TABLE moviesSerial (
          ID SERIAL,
          name STRING,
          length INT32,
          note STRING,
          PRIMARY KEY (ID)
        );"""
    copy_table = f'COPY moviesSerial from "{LBUG_ROOT}/dataset/tinysnb-serial/vMovies.csv"'

    conn.execute(create_table)
    conn.execute(copy_table)
143
+
144
+
145
+ _POOL_SIZE_: int = 256 * 1024 * 1024
146
+
147
+
148
def get_db_file_path(tmp_path: Path) -> Path:
    """Return the location of the ``db.kz`` database file inside ``tmp_path``."""
    return tmp_path.joinpath("db.kz")
151
+
152
+
153
def init_db(path: Path) -> Path:
    """Recreate ``path`` as an empty directory, build the test database in it and return the database file path."""
    # Start from a clean directory so earlier contents cannot leak into the new database.
    if Path(path).exists():
        shutil.rmtree(path)
    path.mkdir()

    db_path = get_db_file_path(path)
    # The database object is bound (to ``_``) for the duration of the loading below.
    conn, _ = create_conn_db(db_path, read_only=False)

    loaders = (
        init_tinysnb,
        init_demo,
        init_npy,
        init_tensor,
        init_long_str,
        init_movie_serial,
    )
    for load in loaders:
        load(conn)
    return db_path
167
+
168
+
169
+ _READONLY_CONN_DB_: ConnDB | None = None
170
+ _READONLY_ASYNC_CONNECTION_: lb.AsyncConnection | None = None
171
+
172
+
173
def create_conn_db(path: Path, *, read_only: bool) -> ConnDB:
    """Open the database at ``path`` and return ``(connection, database)``.

    The database uses a buffer pool of ``_POOL_SIZE_`` bytes and the
    connection is created with four threads.
    """
    database = lb.Database(path, buffer_pool_size=_POOL_SIZE_, read_only=read_only)
    return lb.Connection(database, num_threads=4), database
178
+
179
+
180
@pytest.fixture
def conn_db_readonly(tmp_path: Path) -> ConnDB:
    """Return the shared read-only connection and database, creating them on first use."""
    global _READONLY_CONN_DB_
    if _READONLY_CONN_DB_ is not None:
        return _READONLY_CONN_DB_

    # First request: build the database once and keep the pair in the module-level cache.
    _READONLY_CONN_DB_ = create_conn_db(init_db(tmp_path), read_only=True)
    return _READONLY_CONN_DB_
187
+
188
+
189
@pytest.fixture
def conn_db_readwrite(tmp_path: Path) -> ConnDB:
    """Build a fresh database under ``tmp_path`` and return a writable connection and database."""
    db_path = init_db(tmp_path)
    return create_conn_db(db_path, read_only=False)
193
+
194
+
195
@pytest.fixture
def async_connection_readonly(tmp_path: Path) -> lb.AsyncConnection:
    """Return the shared read-only async connection, creating it on first use."""
    global _READONLY_ASYNC_CONNECTION_
    if _READONLY_ASYNC_CONNECTION_ is not None:
        return _READONLY_ASYNC_CONNECTION_

    conn, db = create_conn_db(init_db(tmp_path), read_only=True)
    # Only the database is needed; the synchronous connection is closed right away.
    conn.close()
    _READONLY_ASYNC_CONNECTION_ = lb.AsyncConnection(db, max_threads_per_query=4)
    return _READONLY_ASYNC_CONNECTION_
204
+
205
+
206
@pytest.fixture
def async_connection_readwrite(tmp_path: Path) -> lb.AsyncConnection:
    """Build a fresh database under ``tmp_path`` and return a writable async connection to it."""
    db_path = init_db(tmp_path)
    conn, db = create_conn_db(db_path, read_only=False)
    # Only the database is needed; the synchronous connection is closed right away.
    conn.close()
    async_conn = lb.AsyncConnection(db, max_threads_per_query=4)
    return async_conn
212
+
213
+
214
@pytest.fixture
def conn_db_empty(tmp_path: Path) -> ConnDB:
    """Return a writable connection and database at ``tmp_path`` without loading any dataset."""
    db_path = get_db_file_path(tmp_path)
    return create_conn_db(db_path, read_only=False)
218
+
219
+
220
@pytest.fixture
def conn_db_in_mem() -> ConnDB:
    """Return a writable connection and database opened on the ``:memory:`` path."""
    database = lb.Database(database_path=":memory:", buffer_pool_size=_POOL_SIZE_, read_only=False)
    return lb.Connection(database, num_threads=4), database
226
+
227
+
228
@pytest.fixture
def build_dir() -> Path:
    """Return the module-level ``python_build_dir`` path."""
    return python_build_dir
230
+ return python_build_dir