phenolrs 0.5.11__cp313-cp313-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of phenolrs might be problematic. Click here for more details.
- phenolrs/__init__.py +5 -0
- phenolrs/aql/__init__.py +4 -0
- phenolrs/aql/loader.py +723 -0
- phenolrs/aql/typings.py +98 -0
- phenolrs/networkx/__init__.py +1 -0
- phenolrs/networkx/loader.py +144 -0
- phenolrs/networkx/typings.py +17 -0
- phenolrs/numpy/__init__.py +1 -0
- phenolrs/numpy/loader.py +114 -0
- phenolrs/numpy/typings.py +12 -0
- phenolrs/phenolrs.cpython-313-darwin.so +0 -0
- phenolrs/phenolrs.pyi +84 -0
- phenolrs/pyg/__init__.py +1 -0
- phenolrs/pyg/loader.py +173 -0
- phenolrs/pyg/typings.py +2 -0
- phenolrs-0.5.11.dist-info/METADATA +24 -0
- phenolrs-0.5.11.dist-info/RECORD +19 -0
- phenolrs-0.5.11.dist-info/WHEEL +4 -0
- phenolrs-0.5.11.dist-info/licenses/LICENSE +85 -0
phenolrs/aql/loader.py
ADDED
|
@@ -0,0 +1,723 @@
|
|
|
1
|
+
"""AQL-based graph loading for retrieving subgraphs from ArangoDB.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for loading graphs from ArangoDB using
|
|
4
|
+
custom AQL queries, following the design specification for flexible
|
|
5
|
+
graph export via AQL.
|
|
6
|
+
|
|
7
|
+
The key benefit of AQL-based loading is flexibility:
|
|
8
|
+
- Use indexes or traversals to find the right subgraph
|
|
9
|
+
- Filter vertices and edges with arbitrary AQL conditions
|
|
10
|
+
- Support for graph traversals to extract connected subgraphs
|
|
11
|
+
- Control over execution order (sequential groups, parallel queries)
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import re
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
|
16
|
+
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
from phenolrs import (
|
|
20
|
+
PhenolError,
|
|
21
|
+
graph_aql_to_networkx_format,
|
|
22
|
+
graph_aql_to_numpy_format,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from .typings import AqlQuery, AttributeSpec, DatabaseConfig
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from torch_geometric.data import Data, HeteroData
|
|
29
|
+
|
|
30
|
+
try:
|
|
31
|
+
import torch
|
|
32
|
+
from torch_geometric.data import Data, HeteroData # noqa: F811
|
|
33
|
+
|
|
34
|
+
TORCH_AVAILABLE = True
|
|
35
|
+
except ImportError:
|
|
36
|
+
TORCH_AVAILABLE = False
|
|
37
|
+
|
|
38
|
+
# Valid AQL identifier: alphanumeric, underscore, hyphen; starts with letter/_
_VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_\-]*$")


def _validate_identifier(name: str, param_name: str) -> None:
    """Validate that a name is a safe AQL identifier.

    The name is later interpolated between backticks into query text, so it
    must consist solely of the characters allowed by ``_VALID_IDENTIFIER``.

    Args:
        name: Candidate identifier (collection, graph or field name).
        param_name: Human-readable parameter name used in the error message.

    Raises:
        ValueError: If ``name`` is not a string, is empty, or contains any
            character outside the allowed set.
    """
    # ``fullmatch`` is used instead of ``match``: with ``match`` the ``$``
    # anchor also succeeds just before a trailing newline, so "users\n"
    # would slip through. The isinstance guard turns non-string input into
    # the same ValueError instead of a TypeError from the regex engine.
    if not isinstance(name, str) or not _VALID_IDENTIFIER.fullmatch(name):
        raise ValueError(
            f"Invalid {param_name}: '{name}'. Must be alphanumeric with "
            "underscores/hyphens, starting with a letter or underscore."
        )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class AqlLoader:
    """Loader for AQL-based graph extraction from ArangoDB.

    This loader allows flexible graph extraction using custom AQL queries.
    Queries are organized into groups:
    - Outer list: Groups processed sequentially
    - Inner list: Queries within a group processed in parallel

    Example usage:
    ```python
    from phenolrs.aql import AqlLoader

    # Load a subgraph using filtered collections
    loader = AqlLoader(
        hosts=["http://localhost:8529"],
        database="mydb",
        username="root",
        password="password"
    )

    # Define the queries
    queries = [
        # First group: load vertices
        [
            {"query": "FOR v IN users FILTER v.active RETURN {vertices: [v]}"},
            {"query": "FOR v IN products FILTER v.inStock RETURN {vertices: [v]}"}
        ],
        # Second group: load edges
        [
            {"query": "FOR e IN purchases RETURN {edges: [e]}"}
        ]
    ]

    # Load into numpy format
    result = loader.load_to_numpy(
        queries=queries,
        vertex_attributes={"name": "string", "age": "i64"},
        edge_attributes={"amount": "f64"}
    )
    ```
    """

    # A bind-parameter reference such as ``@start`` and nothing else.
    _BIND_VAR_REF = re.compile(r"@[A-Za-z0-9_]+")
    # One complete single-quoted literal; an inner quote must be
    # backslash-escaped, so the literal cannot be closed early.
    _QUOTED_LITERAL = re.compile(r"'(?:[^'\\]|\\.)*'")

    def __init__(
        self,
        hosts: List[str],
        database: str = "_system",
        username: Optional[str] = None,
        password: Optional[str] = None,
        user_jwt: Optional[str] = None,
        tls_cert: Optional[str] = None,
        batch_size: int = 10000,
    ):
        """Initialize the AQL loader.

        Args:
            hosts: List of ArangoDB endpoint URLs (e.g., ["http://localhost:8529"]).
                A single URL string is accepted and wrapped into a list.
            database: Database name (default: "_system")
            username: Username for authentication
            password: Password for authentication
            user_jwt: JWT token for authentication (alternative to username/password)
            tls_cert: TLS certificate for secure connections
            batch_size: Number of items per batch (default: 10000)
        """
        # The request's "endpoints" entry is a list; wrap a bare string so a
        # single URL is not forwarded as a plain str.
        self.hosts = [hosts] if isinstance(hosts, str) else hosts
        self.database = database
        self.username = username
        self.password = password
        self.user_jwt = user_jwt
        self.tls_cert = tls_cert
        self.batch_size = batch_size

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _require_queries(queries: List[List[AqlQuery]]) -> None:
        """Raise PhenolError unless at least one group contains a query."""
        if not queries or not any(len(group) > 0 for group in queries):
            raise PhenolError("At least one AQL query must be provided")

    @staticmethod
    def _require_torch() -> None:
        """Raise ImportError when the optional torch/PyG imports failed."""
        if not TORCH_AVAILABLE:
            raise ImportError(
                "Missing required dependencies. "
                "Install with `pip install phenolrs[torch]`"
            )

    def _build_request(
        self,
        queries: List[List[AqlQuery]],
        vertex_attributes: Optional[AttributeSpec] = None,
        edge_attributes: Optional[AttributeSpec] = None,
        max_type_errors: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Build the request object for the Rust backend.

        Args:
            queries: Query groups, forwarded unchanged under "queries".
            vertex_attributes: Optional vertex schema; omitted when None.
            edge_attributes: Optional edge schema; omitted when None.
            max_type_errors: Optional error limit; omitted when None.

        Returns:
            A dict with "database_config", "batch_size" and "queries", plus
            whichever optional entries were supplied.
        """
        db_config: DatabaseConfig = {
            "endpoints": self.hosts,
            "database": self.database,
        }
        # Only include credentials when provided (not empty strings)
        if self.username:
            db_config["username"] = self.username
        if self.password:
            db_config["password"] = self.password
        if self.user_jwt:
            db_config["jwt_token"] = self.user_jwt
        if self.tls_cert:
            db_config["tls_cert"] = self.tls_cert

        request: Dict[str, Any] = {
            "database_config": db_config,
            "batch_size": self.batch_size,
            "queries": queries,
        }

        if vertex_attributes is not None:
            request["vertex_attributes"] = vertex_attributes
        if edge_attributes is not None:
            request["edge_attributes"] = edge_attributes
        if max_type_errors is not None:
            request["max_type_errors"] = max_type_errors

        return request

    def _load_numpy(
        self,
        queries: List[List[AqlQuery]],
        vertex_attributes: Optional[AttributeSpec],
        edge_attributes: Optional[AttributeSpec],
        max_type_errors: Optional[int],
    ) -> Any:
        """Validate the queries, build the request and run the numpy export."""
        self._require_queries(queries)
        request = self._build_request(
            queries, vertex_attributes, edge_attributes, max_type_errors
        )
        return graph_aql_to_numpy_format(request)  # type: ignore[arg-type]

    @staticmethod
    def _stack_numeric_attributes(
        features: Dict[str, Any],
        attr_names: List[str],
        collection: Optional[str] = None,
        auto_mapped: bool = False,
    ) -> Optional[Any]:
        """Stack the named attribute arrays column-wise into one tensor.

        Args:
            features: Mapping from attribute name to its loaded array.
            attr_names: Attributes to stack, in output column order.
            collection: Collection name to mention in error messages
                (heterogeneous loading); None for the homogeneous wording.
            auto_mapped: True when ``attr_names`` was derived from the loaded
                data rather than a user mapping; adds a hint to the type error.

        Returns:
            The concatenated float64 tensor, or None when there is nothing to
            stack or the result has no elements.

        Raises:
            PhenolError: If an attribute is missing or has a string/object dtype.
        """
        tensors = []
        for attr_name in attr_names:
            if attr_name not in features:
                available = list(features.keys())
                if collection is None:
                    raise PhenolError(
                        f"Attribute '{attr_name}' not found in loaded data. "
                        f"Available: {available}"
                    )
                raise PhenolError(
                    f"Attribute '{attr_name}' not found in collection "
                    f"'{collection}'. Available: {available}"
                )

            arr = features[attr_name]
            # Check if attribute is string type (not convertible to numeric)
            if arr.dtype.kind in ("U", "S", "O"):
                if collection is None:
                    message = (
                        f"Attribute '{attr_name}' has string/object type "
                        "which cannot be converted to PyG tensors. "
                        "PyG requires numeric types (i64, f64, bool)."
                    )
                    if auto_mapped:
                        message += (
                            " Exclude string attributes or use pyg_feature_mapping."
                        )
                else:
                    message = (
                        f"Attribute '{attr_name}' in '{collection}' has "
                        "string/object type which cannot be converted "
                        "to PyG tensors. Requires numeric (i64, f64, bool)."
                    )
                    if auto_mapped:
                        message += " Exclude string attrs or use pyg_feature_mapping."
                raise PhenolError(message)

            # 1-D arrays become a single column so they can be concatenated
            # along dim=1 with multi-column attributes.
            if arr.ndim == 1:
                arr = arr.reshape(-1, 1)
            tensors.append(torch.from_numpy(arr.astype(np.float64)))

        if not tensors:
            return None
        combined = torch.cat(tensors, dim=1)
        return combined if combined.numel() > 0 else None

    @staticmethod
    def _build_collection_query(
        collection: str,
        result_key: str,
        required_fields: Tuple[str, ...],
        filter_condition: Optional[str],
        projection: Optional[List[str]],
        bind_vars: Optional[Dict[str, Any]],
    ) -> AqlQuery:
        """Build a ``FOR doc IN <collection>`` query for vertices or edges.

        ``result_key`` is the key of the returned object ("vertices" or
        "edges"); ``required_fields`` are system fields that are always
        projected, unquoted, ahead of the user-requested ones.
        """
        _validate_identifier(collection, "collection")
        query_parts = [f"FOR doc IN `{collection}`"]

        if filter_condition:
            query_parts.append(f"FILTER {filter_condition}")

        if projection:
            # Validate field names to prevent injection
            for field in projection:
                _validate_identifier(field, "projection field")
            fields = [f"{name}: doc.{name}" for name in required_fields]
            # Skip required fields and repeated names to avoid duplicate keys
            # in the generated object literal (order of first use is kept).
            fields.extend(
                f"`{field}`: doc.`{field}`"
                for field in dict.fromkeys(projection)
                if field not in required_fields
            )
            return_expr = "{" + ", ".join(fields) + "}"
            query_parts.append(f"RETURN {{{result_key}: [{return_expr}]}}")
        else:
            query_parts.append(f"RETURN {{{result_key}: [doc]}}")

        return {
            "query": " ".join(query_parts),
            "bindVars": bind_vars or {},
        }

    # ------------------------------------------------------------------
    # Loading entry points
    # ------------------------------------------------------------------

    def load_to_numpy(
        self,
        queries: List[List[AqlQuery]],
        vertex_attributes: Optional[AttributeSpec] = None,
        edge_attributes: Optional[AttributeSpec] = None,
        max_type_errors: Optional[int] = None,
    ) -> Any:
        """Load a graph using AQL queries into numpy-compatible format.

        Args:
            queries: List of query groups. Outer list is sequential,
                inner lists are parallel.
                Each query should return {"vertices": [...], "edges": [...]}.
            vertex_attributes: Schema for vertex attributes. Either a dict
                mapping attribute names to types
                (e.g., {"name": "string", "age": "i64"})
                or a list of {"name": str, "type": str} objects.
            edge_attributes: Schema for edge attributes
                (same format as vertex_attributes).
            max_type_errors: Maximum number of type errors to report
                before stopping. None uses the library default.

        Returns:
            Tuple of (features_by_col, coo_map, col_to_key_to_ind,
                col_to_ind_to_key)

        Raises:
            PhenolError: If no queries are provided.
        """
        return self._load_numpy(
            queries, vertex_attributes, edge_attributes, max_type_errors
        )

    def load_to_networkx(
        self,
        queries: List[List[AqlQuery]],
        vertex_attributes: Optional[AttributeSpec] = None,
        edge_attributes: Optional[AttributeSpec] = None,
        load_adj_dict: bool = True,
        load_coo: bool = True,
        is_directed: bool = True,
        is_multigraph: bool = True,
        symmetrize_edges_if_directed: bool = False,
        max_type_errors: Optional[int] = None,
    ) -> Any:
        """Load a graph using AQL queries into NetworkX-compatible format.

        Args:
            queries: List of query groups. Outer list is sequential,
                inner lists are parallel.
                Each query should return {"vertices": [...], "edges": [...]}.
            vertex_attributes: Schema for vertex attributes.
            edge_attributes: Schema for edge attributes.
            load_adj_dict: Whether to load adjacency dictionary (default: True)
            load_coo: Whether to load COO format (default: True)
            is_directed: Whether the graph is directed (default: True)
            is_multigraph: Whether to allow multiple edges (default: True)
            symmetrize_edges_if_directed: Add reverse edges (default: False)
            max_type_errors: Maximum number of type errors to report
                before stopping. None uses the library default.

        Returns:
            A tuple of (node_dict, adj_dict, src_indices, dst_indices,
                edge_indices, vertex_id_to_index, edge_values)

        Raises:
            PhenolError: If no queries are provided.
        """
        self._require_queries(queries)

        request = self._build_request(
            queries, vertex_attributes, edge_attributes, max_type_errors
        )

        graph_config = {
            "load_adj_dict": load_adj_dict,
            "load_coo": load_coo,
            "is_directed": is_directed,
            "is_multigraph": is_multigraph,
            "symmetrize_edges_if_directed": symmetrize_edges_if_directed,
        }

        return graph_aql_to_networkx_format(
            request, graph_config  # type: ignore[arg-type]
        )

    def load_to_pyg_data(
        self,
        queries: List[List[AqlQuery]],
        vertex_attributes: Optional[AttributeSpec] = None,
        edge_attributes: Optional[AttributeSpec] = None,
        pyg_feature_mapping: Optional[Dict[str, List[str]]] = None,
        max_type_errors: Optional[int] = None,
    ) -> Tuple["Data", Dict[str, Dict[str, int]], Dict[str, Dict[int, str]]]:
        """Load a graph using AQL queries into PyTorch Geometric Data format.

        This method loads a homogeneous graph (single node type, single edge type)
        into a PyG Data object suitable for GNN training.

        Note:
            Edges are required for PyG Data format. The returned Data object
            will always have an edge_index tensor. Ensure your queries return
            edge documents in the "edges" array. For vertex-only graphs,
            use :meth:`load_to_networkx` or :meth:`load_to_numpy` instead.

        Args:
            queries: List of query groups. Outer list is sequential,
                inner lists are parallel.
                Each query should return {"vertices": [...], "edges": [...]}.
            vertex_attributes: Schema for vertex attributes.
                Maps attribute names to types (e.g., {"features": "f64"}).
                Attributes must be numeric (f64, i64) for PyG compatibility.
            edge_attributes: Schema for edge attributes.
            pyg_feature_mapping: Optional mapping from PyG attribute names to
                loaded attribute names.
                Example: {"x": ["feat1", "feat2"], "y": ["label"]}
                If None, all numeric vertex attributes are stacked into 'x'.
            max_type_errors: Maximum number of type errors to report.

        Returns:
            A tuple of (Data, col_to_key_to_ind, col_to_ind_to_key)

        Raises:
            ImportError: If PyTorch/PyG dependencies are not installed.
            PhenolError: If no queries are provided, no vertex/edge data is loaded,
                multiple vertex collections or edge types are found (use
                :meth:`load_to_pyg_heterodata` for heterogeneous graphs), or if
                attributes have incompatible types (strings/objects).

        Example:
            >>> loader = AqlLoader(hosts=["http://localhost:8529"], database="mydb")
            >>> queries = [[{"query": "FOR v IN users RETURN {vertices: [v]}"}],
            ...            [{"query": "FOR e IN follows RETURN {edges: [e]}"}]]
            >>> data, key_to_ind, ind_to_key = loader.load_to_pyg_data(
            ...     queries=queries,
            ...     vertex_attributes={"features": "f64", "label": "i64"},
            ...     pyg_feature_mapping={"x": ["features"], "y": ["label"]}
            ... )
        """
        self._require_torch()

        # Load data as numpy first
        (
            features_by_col,
            coo_map,
            col_to_adb_key_to_ind,
            col_to_ind_to_adb_key,
        ) = self._load_numpy(
            queries, vertex_attributes, edge_attributes, max_type_errors
        )

        # For homogeneous graph, we expect exactly one vertex collection
        vertex_cols = [c for c in features_by_col if c != "@collection_name"]
        if not vertex_cols:
            raise PhenolError("No vertex data loaded from AQL queries")
        if len(vertex_cols) > 1:
            raise PhenolError(
                f"Multiple vertex collections ({vertex_cols}) found. "
                "Use load_to_pyg_heterodata for heterogeneous graphs."
            )

        v_features = features_by_col[vertex_cols[0]]
        data = Data()

        if pyg_feature_mapping is not None:
            # User specified mapping
            for pyg_name, attr_list in pyg_feature_mapping.items():
                combined = self._stack_numeric_attributes(v_features, attr_list)
                if combined is not None:
                    data[pyg_name] = combined
        else:
            # Auto-mapping: stack all numeric attributes into 'x'
            auto_attrs = [a for a in v_features if a != "@collection_name"]
            combined = self._stack_numeric_attributes(
                v_features, auto_attrs, auto_mapped=True
            )
            if combined is not None:
                data.x = combined

        # Add edges - expect exactly one edge type for homogeneous graph
        if len(coo_map) == 0:
            raise PhenolError("No edge data loaded from AQL queries")
        if len(coo_map) > 1:
            raise PhenolError(
                "Multiple edge types found. "
                "Use load_to_pyg_heterodata for heterogeneous graphs."
            )

        edge_coo = next(iter(coo_map.values()))
        edge_index = torch.from_numpy(edge_coo.astype(np.int64))
        # Always assign edge_index, even if empty (use proper empty tensor shape)
        if edge_index.numel() > 0:
            data.edge_index = edge_index
        else:
            data.edge_index = torch.empty((2, 0), dtype=torch.long)

        return data, col_to_adb_key_to_ind, col_to_ind_to_adb_key

    def load_to_pyg_heterodata(
        self,
        queries: List[List[AqlQuery]],
        vertex_attributes: Optional[AttributeSpec] = None,
        edge_attributes: Optional[AttributeSpec] = None,
        pyg_feature_mapping: Optional[Dict[str, Dict[str, List[str]]]] = None,
        max_type_errors: Optional[int] = None,
    ) -> Tuple["HeteroData", Dict[str, Dict[str, int]], Dict[str, Dict[int, str]]]:
        """Load a graph using AQL queries into PyTorch Geometric HeteroData format.

        This method loads a heterogeneous graph (multiple node/edge types)
        into a PyG HeteroData object suitable for heterogeneous GNN training.

        Args:
            queries: List of query groups. Outer list is sequential,
                inner lists are parallel.
                Each query should return {"vertices": [...], "edges": [...]}.
            vertex_attributes: Schema for vertex attributes.
            edge_attributes: Schema for edge attributes.
            pyg_feature_mapping: Optional nested mapping from collection names to
                PyG attribute mappings. Example:
                {"Users": {"x": ["feat1"], "y": ["label"]},
                 "Products": {"x": ["features"]}}
                If None, all numeric attributes per collection are stacked into 'x'.
            max_type_errors: Maximum number of type errors to report.

        Returns:
            A tuple of (HeteroData, col_to_key_to_ind, col_to_ind_to_key)

        Raises:
            ImportError: If PyTorch/PyG dependencies are not installed.
            PhenolError: If no queries are provided, a collection named in
                ``pyg_feature_mapping`` has no loaded data, a mapped attribute
                is missing, or an attribute has a string/object type.

        Example:
            >>> loader = AqlLoader(hosts=["http://localhost:8529"], database="mydb")
            >>> queries = [
            ...     [{"query": "FOR v IN users RETURN {vertices: [v]}"},
            ...      {"query": "FOR v IN products RETURN {vertices: [v]}"}],
            ...     [{"query": "FOR e IN purchases RETURN {edges: [e]}"}]
            ... ]
            >>> data, key_to_ind, ind_to_key = loader.load_to_pyg_heterodata(
            ...     queries=queries,
            ...     vertex_attributes={"features": "f64", "label": "i64"},
            ...     pyg_feature_mapping={
            ...         "users": {"x": ["features"], "y": ["label"]},
            ...         "products": {"x": ["features"]}
            ...     }
            ... )
        """
        self._require_torch()

        # Load data as numpy first
        (
            features_by_col,
            coo_map,
            col_to_adb_key_to_ind,
            col_to_ind_to_adb_key,
        ) = self._load_numpy(
            queries, vertex_attributes, edge_attributes, max_type_errors
        )

        data = HeteroData()

        # Validate that collections referenced in pyg_feature_mapping exist
        # This catches cases where string attributes were requested (which are
        # silently dropped by the Rust backend, resulting in no vertex data)
        if pyg_feature_mapping is not None:
            for col_name in pyg_feature_mapping:
                if col_name not in features_by_col:
                    raise PhenolError(
                        f"No vertex data loaded for collection '{col_name}'. "
                        "This may occur if only string/object type attributes were "
                        "requested, which cannot be converted to PyG tensors."
                    )

        # Process vertex features per collection
        for col_name, col_features in features_by_col.items():
            if pyg_feature_mapping is not None and col_name in pyg_feature_mapping:
                # User specified mapping for this collection
                for pyg_name, attr_list in pyg_feature_mapping[col_name].items():
                    combined = self._stack_numeric_attributes(
                        col_features, attr_list, collection=col_name
                    )
                    if combined is not None:
                        data[col_name][pyg_name] = combined
            else:
                # Auto-mapping: stack all numeric attributes into 'x'
                auto_attrs = [a for a in col_features if a != "@collection_name"]
                combined = self._stack_numeric_attributes(
                    col_features, auto_attrs, collection=col_name, auto_mapped=True
                )
                if combined is not None:
                    data[col_name].x = combined

        # Add edges per edge type; edge types with no edges are left out
        for edge_key, edge_coo in coo_map.items():
            edge_col_name, from_col, to_col = edge_key
            edge_index = torch.from_numpy(edge_coo.astype(np.int64))
            if edge_index.numel() > 0:
                data[(from_col, edge_col_name, to_col)].edge_index = edge_index

        return data, col_to_adb_key_to_ind, col_to_ind_to_adb_key

    # ------------------------------------------------------------------
    # Query construction helpers
    # ------------------------------------------------------------------

    @staticmethod
    def create_vertex_query(
        collection: str,
        filter_condition: Optional[str] = None,
        projection: Optional[List[str]] = None,
        bind_vars: Optional[Dict[str, Any]] = None,
    ) -> AqlQuery:
        """Helper to create a vertex loading query.

        Args:
            collection: The vertex collection name
            filter_condition: Optional AQL filter condition
                (without FILTER keyword). Security note: Use bind_vars
                for any user-provided values to prevent AQL injection.
            projection: Optional list of fields to project; ``_id`` is
                always included. If None, returns full document.
            bind_vars: Optional bind variables

        Returns:
            An AqlQuery object ready to use

        Raises:
            ValueError: If the collection or a projection field is not a
                valid identifier.

        Example:
            >>> AqlLoader.create_vertex_query(
            ...     "users", "doc.active == true", ["name", "age"]
            ... )
        """
        return AqlLoader._build_collection_query(
            collection,
            "vertices",
            ("_id",),
            filter_condition,
            projection,
            bind_vars,
        )

    @staticmethod
    def create_edge_query(
        collection: str,
        filter_condition: Optional[str] = None,
        projection: Optional[List[str]] = None,
        bind_vars: Optional[Dict[str, Any]] = None,
    ) -> AqlQuery:
        """Helper to create an edge loading query.

        Args:
            collection: The edge collection name
            filter_condition: Optional AQL filter condition
                (without FILTER keyword). Security note: Use bind_vars
                for any user-provided values to prevent AQL injection.
            projection: Optional list of fields to project; ``_from`` and
                ``_to`` are always included. If None, returns full document.
            bind_vars: Optional bind variables

        Returns:
            An AqlQuery object ready to use

        Raises:
            ValueError: If the collection or a projection field is not a
                valid identifier.
        """
        return AqlLoader._build_collection_query(
            collection,
            "edges",
            ("_from", "_to"),
            filter_condition,
            projection,
            bind_vars,
        )

    @staticmethod
    def create_traversal_query(
        start_vertex: str,
        graph_name: str,
        min_depth: int = 1,
        max_depth: int = 1,
        direction: str = "OUTBOUND",
        prune_condition: Optional[str] = None,
        filter_condition: Optional[str] = None,
        bind_vars: Optional[Dict[str, Any]] = None,
    ) -> AqlQuery:
        """Helper to create a graph traversal query.

        Args:
            start_vertex: The starting vertex. Must be either:
                - A bind variable reference (e.g., "@start")
                - A quoted literal (e.g., "'users/alice'")
                Use bind_vars for user-provided values to prevent injection.
            graph_name: The named graph to traverse
            min_depth: Minimum traversal depth (default: 1)
            max_depth: Maximum traversal depth (default: 1)
            direction: Traversal direction - "OUTBOUND", "INBOUND", or "ANY"
                (default: "OUTBOUND")
            prune_condition: Optional PRUNE condition. Security note: Use
                bind_vars for user-provided values.
            filter_condition: Optional FILTER condition. Security note: Use
                bind_vars for user-provided values.
            bind_vars: Optional bind variables

        Returns:
            An AqlQuery object ready to use

        Raises:
            ValueError: If ``start_vertex`` is not exactly one bind variable
                reference or one quoted literal, ``direction`` is unknown,
                the depths are negative or inverted, or ``graph_name`` is
                not a valid identifier.

        Example:
            >>> AqlLoader.create_traversal_query(
            ...     "@start", "myGraph", 0, 3, bind_vars={"start": "users/1"}
            ... )
        """
        # Validate start_vertex format to prevent malformed queries. The whole
        # string must match: a prefix/suffix check alone would accept a lone
        # quote or extra AQL text after a bind variable or between two quotes.
        if not (
            isinstance(start_vertex, str)
            and (
                AqlLoader._BIND_VAR_REF.fullmatch(start_vertex)
                or AqlLoader._QUOTED_LITERAL.fullmatch(start_vertex)
            )
        ):
            raise ValueError(
                "start_vertex must be a bind variable (@var) or quoted literal "
                f"('value'), got: {start_vertex}"
            )

        # Validate direction parameter
        valid_directions = ("OUTBOUND", "INBOUND", "ANY")
        if direction not in valid_directions:
            raise ValueError(
                f"direction must be one of {valid_directions}, got: '{direction}'"
            )

        # Validate depth parameters
        if min_depth < 0:
            raise ValueError(f"min_depth must be non-negative, got: {min_depth}")
        if max_depth < 0:
            raise ValueError(f"max_depth must be non-negative, got: {max_depth}")
        if max_depth < min_depth:
            raise ValueError(
                f"max_depth ({max_depth}) must be >= min_depth ({min_depth})"
            )

        # The depth range is min_depth..max_depth exactly as given; pass
        # min_depth=0 to have the start vertex itself included in the result.
        _validate_identifier(graph_name, "graph_name")
        query_parts = [
            f"FOR v, e IN {min_depth}..{max_depth} {direction} "
            f"{start_vertex} GRAPH `{graph_name}`"
        ]

        if prune_condition:
            query_parts.append(f"PRUNE {prune_condition}")

        if filter_condition:
            query_parts.append(f"FILTER {filter_condition}")

        # Handle null edges when min_depth=0 (start vertex has no edge)
        query_parts.append("RETURN {vertices: [v], edges: (e == null ? [] : [e])}")

        return {
            "query": " ".join(query_parts),
            "bindVars": bind_vars or {},
        }
|