phenolrs 0.5.12__cp313-cp313-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- phenolrs/__init__.py +5 -0
- phenolrs/aql/__init__.py +4 -0
- phenolrs/aql/loader.py +723 -0
- phenolrs/aql/typings.py +98 -0
- phenolrs/networkx/__init__.py +1 -0
- phenolrs/networkx/loader.py +144 -0
- phenolrs/networkx/typings.py +17 -0
- phenolrs/numpy/__init__.py +1 -0
- phenolrs/numpy/loader.py +114 -0
- phenolrs/numpy/typings.py +12 -0
- phenolrs/phenolrs.cpython-313-aarch64-linux-gnu.so +0 -0
- phenolrs/phenolrs.pyi +84 -0
- phenolrs/pyg/__init__.py +1 -0
- phenolrs/pyg/loader.py +173 -0
- phenolrs/pyg/typings.py +2 -0
- phenolrs-0.5.12.dist-info/METADATA +24 -0
- phenolrs-0.5.12.dist-info/RECORD +19 -0
- phenolrs-0.5.12.dist-info/WHEEL +4 -0
- phenolrs-0.5.12.dist-info/licenses/LICENSE +85 -0
phenolrs/aql/typings.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Type definitions for AQL-based graph loading.
|
|
2
|
+
|
|
3
|
+
This module follows the design specification from:
|
|
4
|
+
https://github.com/arangodb/documents/blob/master/DesignDocuments/02_PLANNING/GetGraphsOutOfArangoDBviaAQL.md
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Any, Dict, List, Literal, TypedDict, Union
|
|
8
|
+
|
|
9
|
+
# Supported data types for attributes
|
|
10
|
+
DataType = Literal[
|
|
11
|
+
"bool", "string", "u64", "i64", "f64", "json", "number", "int", "float"
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class AqlQuery(TypedDict, total=False):
|
|
16
|
+
"""An AQL query with optional bind variables.
|
|
17
|
+
|
|
18
|
+
Each query should return items of the form:
|
|
19
|
+
{"vertices": [...], "edges": [...]}
|
|
20
|
+
|
|
21
|
+
Both vertices and edges are optional in the return value.
|
|
22
|
+
Vertices must have at least an _id attribute.
|
|
23
|
+
Edges must have at least _from and _to attributes.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
query: str # Required: The AQL query string
|
|
27
|
+
bindVars: Dict[str, Any] # Optional: Bind variables for the query
|
|
28
|
+
bind_vars: Dict[str, Any] # Alternative name for bindVars
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class AttributeItem(TypedDict):
|
|
32
|
+
"""A single attribute definition with name and type."""
|
|
33
|
+
|
|
34
|
+
name: str
|
|
35
|
+
type: DataType
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
# Attribute specification can be either:
|
|
39
|
+
# - A dict mapping attribute names to types: {"name": "string", "age": "i64"}
|
|
40
|
+
# - A list of attribute items: [{"name": "name", "type": "string"}, ...]
|
|
41
|
+
AttributeSpec = Union[Dict[str, DataType], List[AttributeItem]]
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class DatabaseConfig(TypedDict, total=False):
|
|
45
|
+
"""Database connection configuration."""
|
|
46
|
+
|
|
47
|
+
endpoints: List[str] # List of ArangoDB endpoints
|
|
48
|
+
database: str # Database name (default: "_system")
|
|
49
|
+
username: str # Username for authentication
|
|
50
|
+
password: str # Password for authentication
|
|
51
|
+
jwt_token: str # JWT token for authentication (alternative to username/password)
|
|
52
|
+
tls_cert: str # TLS certificate for secure connections
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class AqlDataLoadRequest(TypedDict, total=False):
|
|
56
|
+
"""Request for AQL-based graph loading.
|
|
57
|
+
|
|
58
|
+
The queries field is a list of lists of AQL queries:
|
|
59
|
+
- The outer list is processed **sequentially**
|
|
60
|
+
- Each inner list of queries can be executed **in parallel**
|
|
61
|
+
|
|
62
|
+
This structure allows for:
|
|
63
|
+
1. Loading vertices first, then edges (for efficient memory usage)
|
|
64
|
+
2. Parallel loading of multiple vertex/edge collections
|
|
65
|
+
3. Sequential execution of dependent operations
|
|
66
|
+
|
|
67
|
+
Example for use case 1 (filtered collections)::
|
|
68
|
+
|
|
69
|
+
{
|
|
70
|
+
"queries": [
|
|
71
|
+
# First group: load all vertices in parallel
|
|
72
|
+
[
|
|
73
|
+
{"query": "FOR x IN v1 RETURN {vertices: [x]}"},
|
|
74
|
+
{"query": "FOR x IN v2 RETURN {vertices: [x]}"}
|
|
75
|
+
],
|
|
76
|
+
# Second group: load all edges in parallel
|
|
77
|
+
[
|
|
78
|
+
{"query": "FOR e IN edges1 RETURN {edges: [e]}"},
|
|
79
|
+
]
|
|
80
|
+
]
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
Example for use case 2 (graph traversals)::
|
|
84
|
+
|
|
85
|
+
{
|
|
86
|
+
"queries": [[{
|
|
87
|
+
"query": "FOR v, e IN 0..10 OUTBOUND @s GRAPH 'g' "
|
|
88
|
+
"RETURN {vertices: [v], edges: [e]}",
|
|
89
|
+
"bindVars": {"s": "vertex/1"}
|
|
90
|
+
}]]
|
|
91
|
+
}
|
|
92
|
+
"""
|
|
93
|
+
|
|
94
|
+
database_config: DatabaseConfig
|
|
95
|
+
batch_size: int # Number of items per batch (default: 10000)
|
|
96
|
+
vertex_attributes: AttributeSpec # Schema for vertex attributes
|
|
97
|
+
edge_attributes: AttributeSpec # Schema for edge attributes
|
|
98
|
+
queries: List[List[AqlQuery]] # List of query groups
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .loader import NetworkXLoader # noqa: F401
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from typing import Any, Set, Tuple
|
|
2
|
+
|
|
3
|
+
from phenolrs import PhenolError, graph_to_networkx_format
|
|
4
|
+
|
|
5
|
+
from .typings import (
|
|
6
|
+
ArangoIDtoIndex,
|
|
7
|
+
DiGraphAdjDict,
|
|
8
|
+
DstIndices,
|
|
9
|
+
EdgeIndices,
|
|
10
|
+
EdgeValuesDict,
|
|
11
|
+
GraphAdjDict,
|
|
12
|
+
MultiDiGraphAdjDict,
|
|
13
|
+
MultiGraphAdjDict,
|
|
14
|
+
NodeDict,
|
|
15
|
+
SrcIndices,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class NetworkXLoader:
    @staticmethod
    def load_into_networkx(
        database: str,
        metagraph: dict[str, dict[str, Set[str]]],
        hosts: list[str],
        user_jwt: str | None = None,
        username: str | None = None,
        password: str | None = None,
        tls_cert: Any | None = None,
        parallelism: int | None = None,
        batch_size: int | None = None,
        load_adj_dict: bool = True,
        load_coo: bool = True,
        load_all_vertex_attributes: bool = True,
        load_all_edge_attributes: bool = True,
        is_directed: bool = True,
        is_multigraph: bool = True,
        symmetrize_edges_if_directed: bool = False,
    ) -> Tuple[
        NodeDict,
        GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict,
        SrcIndices,
        DstIndices,
        EdgeIndices,
        ArangoIDtoIndex,
        EdgeValuesDict,
    ]:
        """Load an ArangoDB graph into NetworkX-compatible structures.

        Validates the metagraph / flag combination, assembles the request
        dicts, and delegates the actual loading to
        ``graph_to_networkx_format``, whose 7-tuple result is returned
        unchanged.

        Raises:
            PhenolError: if the metagraph is missing sections, is empty,
                or the requested flag combination is inconsistent.
        """
        for section in ("vertexCollections", "edgeCollections"):
            if section not in metagraph:
                raise PhenolError(f"{section} not found in metagraph")

        v_cols = metagraph["vertexCollections"]
        e_cols = metagraph["edgeCollections"]

        if not v_cols and not e_cols:
            raise PhenolError("vertexCollections and edgeCollections cannot both be empty")

        if not e_cols and (load_adj_dict or load_coo):
            raise PhenolError(
                "edgeCollections must be non-empty if **load_adj_dict** or **load_coo** is True"  # noqa
            )

        # When "load all attributes" is requested, explicit per-collection
        # attribute lists are contradictory and rejected.
        if load_all_vertex_attributes:
            for entries in v_cols.values():
                if entries:
                    raise PhenolError(
                        f"load_all_vertex_attributes is True, but a vertexCollections entry contains attributes: {entries}"  # noqa
                    )

        if load_all_edge_attributes:
            for entries in e_cols.values():
                if entries:
                    raise PhenolError(
                        f"load_all_edge_attributes is True, but an edgeCollections entry contains attributes: {entries}"  # noqa
                    )

        if e_cols and not (load_coo or load_adj_dict):
            raise PhenolError(
                "load_coo and load_adj_dict cannot both be False if edgeCollections is non-empty"  # noqa
            )

        # TODO: replace with pydantic validation
        db_config_options: dict[str, Any] = {
            "endpoints": hosts,
            "database": database,
        }
        # Only forward credentials that were actually supplied.
        for cred_key, cred_value in (
            ("username", username),
            ("password", password),
            ("jwt_token", user_jwt),
            ("tls_cert", tls_cert),
        ):
            if cred_value:
                db_config_options[cred_key] = cred_value

        load_config_options: dict[str, Any] = {
            "parallelism": 8 if parallelism is None else parallelism,
            "batch_size": 100000 if batch_size is None else batch_size,
            "prefetch_count": 5,
            "load_all_vertex_attributes": load_all_vertex_attributes,
            "load_all_edge_attributes": load_all_edge_attributes,
        }

        graph_config = {
            "load_adj_dict": load_adj_dict,
            "load_coo": load_coo,
            "is_directed": is_directed,
            "is_multigraph": is_multigraph,
            "symmetrize_edges_if_directed": symmetrize_edges_if_directed,
        }

        request = {
            "vertex_collections": [
                {"name": name, "fields": list(fields)}
                for name, fields in v_cols.items()
            ],
            "edge_collections": [
                {"name": name, "fields": list(fields)}
                for name, fields in e_cols.items()
            ],
            "database_config": db_config_options,
            "load_config": load_config_options,
        }

        # The core loader already returns the tuple in the documented order.
        return graph_to_networkx_format(
            request=request,
            graph_config=graph_config,  # TODO Anthony: Move into request
        )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from typing import Any

import numpy as np
import numpy.typing as npt

# A raw JSON document (arbitrary keys/values).
Json = dict[str, Any]
# Maps a vertex identifier string to its document.
NodeDict = dict[str, Json]
# Adjacency: source id -> destination id -> edge document
# (shape of an undirected nx.Graph adjacency mapping).
GraphAdjDict = dict[str, dict[str, Json]]
# Outer key presumably selects direction (e.g. successors/predecessors)
# over a GraphAdjDict — TODO confirm against the Rust core's output.
DiGraphAdjDict = dict[str, GraphAdjDict]
# MultiGraph adjacency: source id -> destination id -> edge key -> edge document.
MultiGraphAdjDict = dict[str, dict[str, dict[int, Json]]]
# Directed variant of MultiGraphAdjDict (same outer-key caveat as DiGraphAdjDict).
MultiDiGraphAdjDict = dict[str, MultiGraphAdjDict]
# Edge attribute name -> list of numeric values, one per edge.
EdgeValuesDict = dict[str, list[int | float]]

# COO edge representation: parallel int64 index arrays.
SrcIndices = npt.NDArray[np.int64]
DstIndices = npt.NDArray[np.int64]
EdgeIndices = npt.NDArray[np.int64]
# ArangoDB _id string -> contiguous integer index used by the COO arrays.
ArangoIDtoIndex = dict[str, int]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .loader import NumpyLoader # noqa: F401
|
phenolrs/numpy/loader.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
1
|
+
from typing import Any, Tuple
|
|
2
|
+
|
|
3
|
+
from phenolrs import PhenolError, graph_to_numpy_format
|
|
4
|
+
|
|
5
|
+
from .typings import (
|
|
6
|
+
ArangoCollectionSourceToOutput,
|
|
7
|
+
ArangoCollectionToArangoKeyToIndex,
|
|
8
|
+
ArangoCollectionToIndexToArangoKey,
|
|
9
|
+
ArangoCollectionToNodeFeatures,
|
|
10
|
+
COOByEdgeType,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class NumpyLoader:
    @staticmethod
    def load_graph_to_numpy(
        database: str,
        metagraph: dict[str, Any],
        hosts: list[str],
        user_jwt: str | None = None,
        username: str | None = None,
        password: str | None = None,
        tls_cert: Any | None = None,
        parallelism: int | None = None,
        batch_size: int | None = None,
    ) -> Tuple[
        ArangoCollectionToNodeFeatures,
        COOByEdgeType,
        ArangoCollectionToArangoKeyToIndex,
        ArangoCollectionToIndexToArangoKey,
        ArangoCollectionSourceToOutput,
    ]:
        """Load an ArangoDB graph into numpy structures.

        Delegates to ``graph_to_numpy_format`` and additionally returns the
        per-collection mapping of source attribute names to their requested
        output names.

        Raises:
            PhenolError: if ``vertexCollections`` is missing or a feature
                specification is malformed.
        """
        # TODO: replace with pydantic validation
        db_config_options: dict[str, Any] = {
            "endpoints": hosts,
            "database": database,
        }
        load_config_options: dict[str, Any] = {
            "parallelism": 8 if parallelism is None else parallelism,
            "batch_size": 100000 if batch_size is None else batch_size,
            "prefetch_count": 5,
            "load_all_vertex_attributes": False,
            "load_all_edge_attributes": False,
        }
        # Only forward credentials that were actually supplied.
        for cred_key, cred_value in (
            ("username", username),
            ("password", password),
            ("jwt_token", user_jwt),
            ("tls_cert", tls_cert),
        ):
            if cred_value:
                db_config_options[cred_key] = cred_value

        if "vertexCollections" not in metagraph:
            raise PhenolError("vertexCollections not found in metagraph")

        # Normalize shorthand feature specs of the form
        #   "USER": {"x": {"features": None}}
        # into
        #   "USER": {"x": "features"}
        # NOTE: this rewrites the caller's metagraph in place — deliberate:
        # PygLoader reads the normalized values back after this call returns.
        entries: dict[str, Any]
        for v_col_name, entries in metagraph["vertexCollections"].items():
            for source_name, value in entries.items():
                if not isinstance(value, dict):
                    continue
                if len(value) != 1:
                    raise PhenolError(
                        f"Only one feature field should be specified per attribute. Found {value}"  # noqa: E501
                    )
                (value_key,) = value
                if value[value_key] is not None:
                    raise PhenolError(
                        f"Invalid value for feature {source_name}: {value_key}. Found {value[value_key]}"  # noqa: E501
                    )
                entries[source_name] = value_key

        vertex_collections = [
            {"name": name, "fields": list(fields.values())}
            for name, fields in metagraph["vertexCollections"].items()
        ]
        # metagraph maps output name -> source name; invert per collection.
        vertex_cols_source_to_output = {
            name: {source: output for output, source in fields.items()}
            for name, fields in metagraph["vertexCollections"].items()
        }

        edge_collections = []
        if "edgeCollections" in metagraph:
            edge_collections = [
                {"name": name, "fields": list(fields.values())}
                for name, fields in metagraph["edgeCollections"].items()
            ]

        (
            features_by_col,
            coo_map,
            col_to_adb_key_to_ind,
            col_to_ind_to_adb_key,
        ) = graph_to_numpy_format(
            {
                "vertex_collections": vertex_collections,
                "edge_collections": edge_collections,
                "database_config": db_config_options,
                "load_config": load_config_options,
            }
        )

        return (
            features_by_col,
            coo_map,
            col_to_adb_key_to_ind,
            col_to_ind_to_adb_key,
            vertex_cols_source_to_output,
        )
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from typing import Tuple

import numpy as np
import numpy.typing as npt

# (edge_collection, from_collection, to_collection) identifying one edge relation.
EdgeType = Tuple[str, str, str]

# vertex collection -> feature name -> per-vertex feature array.
ArangoCollectionToNodeFeatures = dict[str, dict[str, npt.NDArray[np.float64]]]
# edge type -> COO index array.
# NOTE(review): annotated float64, but consumers (PygLoader) cast these
# arrays with .astype(np.int64) before use — element type may be inaccurate;
# confirm against the Rust core's actual output dtype.
COOByEdgeType = dict[EdgeType, npt.NDArray[np.float64]]
# collection -> ArangoDB _key -> contiguous integer index.
ArangoCollectionToArangoKeyToIndex = dict[str, dict[str, int]]
# collection -> contiguous integer index -> ArangoDB _key (inverse of above).
ArangoCollectionToIndexToArangoKey = dict[str, dict[int, str]]
# collection -> source attribute name -> requested output attribute name.
ArangoCollectionSourceToOutput = dict[str, dict[str, str]]
|
|
Binary file
|
phenolrs/phenolrs.pyi
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# Type stubs for the compiled phenolrs extension module
# (phenolrs.cpython-*.so); declarations only, no implementations.
import typing

from .networkx.typings import (
    ArangoIDtoIndex,
    DiGraphAdjDict,
    DstIndices,
    EdgeIndices,
    EdgeValuesDict,
    GraphAdjDict,
    MultiDiGraphAdjDict,
    MultiGraphAdjDict,
    NodeDict,
    SrcIndices,
)
from .numpy.typings import (
    ArangoCollectionToArangoKeyToIndex,
    ArangoCollectionToIndexToArangoKey,
    ArangoCollectionToNodeFeatures,
    COOByEdgeType,
)

# Collection-based loaders: driven by a request dict carrying
# vertex_collections / edge_collections / database_config / load_config
# (assembled by NumpyLoader and NetworkXLoader respectively).
def graph_to_numpy_format(request: dict[str, typing.Any]) -> typing.Tuple[
    ArangoCollectionToNodeFeatures,
    COOByEdgeType,
    ArangoCollectionToArangoKeyToIndex,
    ArangoCollectionToIndexToArangoKey,
]: ...
def graph_to_networkx_format(
    request: dict[str, typing.Any], graph_config: dict[str, typing.Any]
) -> typing.Tuple[
    NodeDict,
    GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict,
    SrcIndices,
    DstIndices,
    EdgeIndices,
    ArangoIDtoIndex,
    EdgeValuesDict,
]: ...

# AQL-based graph loading types
class _AqlQueryRequired(typing.TypedDict):
    """Required fields for AqlQuery."""

    query: str  # The AQL query string

class AqlQuery(_AqlQueryRequired, total=False):
    """An AQL query with optional bind variables."""

    bindVars: dict[str, typing.Any]  # Optional: bind variables for the query

AttributeSpec = dict[str, str]  # {"attr_name": "type_name"}
# OR list of {"name": str, "type": str}

# Functional-syntax TypedDict (all fields optional); mirrors
# phenolrs.aql.typings.AqlDataLoadRequest.
AqlDataLoadRequest = typing.TypedDict(
    "AqlDataLoadRequest",
    {
        "database_config": dict[str, typing.Any],
        "batch_size": int,
        "vertex_attributes": AttributeSpec | list[dict[str, str]],
        "edge_attributes": AttributeSpec | list[dict[str, str]],
        "queries": list[list[AqlQuery]],
    },
    total=False,
)

# AQL-based loaders: same result shapes as the collection-based loaders
# above, but driven by user-supplied AQL query groups.
def graph_aql_to_numpy_format(request: AqlDataLoadRequest) -> typing.Tuple[
    ArangoCollectionToNodeFeatures,
    COOByEdgeType,
    ArangoCollectionToArangoKeyToIndex,
    ArangoCollectionToIndexToArangoKey,
]: ...
def graph_aql_to_networkx_format(
    request: AqlDataLoadRequest, graph_config: dict[str, typing.Any]
) -> typing.Tuple[
    NodeDict,
    GraphAdjDict | DiGraphAdjDict | MultiGraphAdjDict | MultiDiGraphAdjDict,
    SrcIndices,
    DstIndices,
    EdgeIndices,
    ArangoIDtoIndex,
    EdgeValuesDict,
]: ...

# Error raised by the extension and by the Python loaders for invalid
# requests or load failures.
class PhenolError(Exception): ...
|
phenolrs/pyg/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .loader import PygLoader # noqa: F401
|
phenolrs/pyg/loader.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from typing import Any, Tuple
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from phenolrs import PhenolError
|
|
6
|
+
from phenolrs.numpy import NumpyLoader
|
|
7
|
+
|
|
8
|
+
from .typings import (
|
|
9
|
+
ArangoCollectionToArangoKeyToIndex,
|
|
10
|
+
ArangoCollectionToIndexToArangoKey,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
try:
    # torch / torch_geometric are optional extras; the loaders below check
    # TORCH_AVAILABLE and raise ImportError with install instructions
    # (`pip install phenolrs[torch]`) when they are missing.
    import torch
    from torch_geometric.data import Data, HeteroData

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class PygLoader:
    @staticmethod
    def load_into_pyg_data(
        database: str,
        metagraph: dict[str, Any],
        hosts: list[str],
        user_jwt: str | None = None,
        username: str | None = None,
        password: str | None = None,
        tls_cert: Any | None = None,
        parallelism: int | None = None,
        batch_size: int | None = None,
    ) -> Tuple[
        "Data", ArangoCollectionToArangoKeyToIndex, ArangoCollectionToIndexToArangoKey
    ]:
        """Load a single-vertex-collection, single-edge-collection graph
        into a homogeneous PyG ``Data`` object.

        Raises:
            ImportError: if torch/torch_geometric are not installed.
            PhenolError: on invalid metagraph or missing features.
        """
        if not TORCH_AVAILABLE:
            raise ImportError(
                "Missing required dependencies. Install with `pip install phenolrs[torch]`"  # noqa: E501
            )

        for section in ("vertexCollections", "edgeCollections"):
            if section not in metagraph:
                raise PhenolError(f"{section} not found in metagraph")

        if not metagraph["vertexCollections"]:
            raise PhenolError("vertexCollections must map to non-empty dictionary")
        if not metagraph["edgeCollections"]:
            raise PhenolError("edgeCollections must map to non-empty dictionary")

        if len(metagraph["vertexCollections"]) > 1:
            raise PhenolError(
                "More than one vertex collection specified for homogeneous dataset"
            )
        if len(metagraph["edgeCollections"]) > 1:
            raise PhenolError(
                "More than one edge collection specified for homogeneous dataset"
            )

        # Exactly one vertex collection at this point (validated above).
        [(v_col_spec_name, v_col_spec)] = metagraph["vertexCollections"].items()

        (
            features_by_col,
            coo_map,
            col_to_adb_key_to_ind,
            col_to_ind_to_adb_key,
            _source_to_output,
        ) = NumpyLoader.load_graph_to_numpy(
            database,
            metagraph,
            hosts,
            user_jwt,
            username,
            password,
            tls_cert,
            parallelism,
            batch_size,
        )

        data = Data()
        # add the features
        if v_col_spec_name not in features_by_col:
            raise PhenolError(f"Unable to load data for collection {v_col_spec_name}")

        col_features = features_by_col[v_col_spec_name]
        # NumpyLoader has normalized v_col_spec values to source-key strings.
        for feature, source_key in v_col_spec.items():
            if source_key not in col_features:
                raise PhenolError(
                    f"Unable to load features {source_key} for collection {v_col_spec_name}"  # noqa: E501
                )
            tensor = torch.from_numpy(col_features[source_key].astype(np.float64))
            if tensor.numel() > 0:
                data[feature] = tensor

        # finally add the edges (only those of the single edge collection)
        [edge_col_name] = metagraph["edgeCollections"].keys()
        for (e_name, _from_name, _to_name), coo in coo_map.items():
            if e_name != edge_col_name:
                continue
            edge_index = torch.from_numpy(coo.astype(np.int64))
            if edge_index.numel() > 0:
                data["edge_index"] = edge_index

        return data, col_to_adb_key_to_ind, col_to_ind_to_adb_key

    @staticmethod
    def load_into_pyg_heterodata(
        database: str,
        metagraph: dict[str, Any],
        hosts: list[str],
        user_jwt: str | None = None,
        username: str | None = None,
        password: str | None = None,
        tls_cert: Any | None = None,
        parallelism: int | None = None,
        batch_size: int | None = None,
    ) -> tuple[
        "HeteroData",
        ArangoCollectionToArangoKeyToIndex,
        ArangoCollectionToIndexToArangoKey,
    ]:
        """Load a multi-collection graph into a heterogeneous PyG
        ``HeteroData`` object.

        Raises:
            ImportError: if torch/torch_geometric are not installed.
            PhenolError: on invalid metagraph.
        """
        if not TORCH_AVAILABLE:
            raise ImportError(
                "Missing required dependencies. Install with `pip install phenolrs[torch]`"  # noqa: E501
            )

        for section in ("vertexCollections", "edgeCollections"):
            if section not in metagraph:
                raise PhenolError(f"{section} not found in metagraph")

        if not metagraph["vertexCollections"]:
            raise PhenolError("vertexCollections must map to non-empty dictionary")
        if not metagraph["edgeCollections"]:
            raise PhenolError("edgeCollections must map to non-empty dictionary")

        (
            features_by_col,
            coo_map,
            col_to_adb_key_to_ind,
            col_to_ind_to_adb_key,
            source_to_output,
        ) = NumpyLoader.load_graph_to_numpy(
            database,
            metagraph,
            hosts,
            user_jwt,
            username,
            password,
            tls_cert,
            parallelism,
            batch_size,
        )

        data = HeteroData()
        for col, feature_map in features_by_col.items():
            renames = source_to_output[col]
            for source_key, array in feature_map.items():
                # Synthetic bookkeeping attribute, never exposed as a feature.
                if source_key == "@collection_name":
                    continue

                target_name = renames[source_key]
                tensor = torch.from_numpy(array.astype(np.float64))
                if tensor.numel() > 0:
                    data[col][target_name] = tensor

        for (e_name, from_name, to_name), coo in coo_map.items():
            edge_index = torch.from_numpy(coo.astype(np.int64))
            if edge_index.numel() > 0:
                data[(from_name, e_name, to_name)].edge_index = edge_index

        return data, col_to_adb_key_to_ind, col_to_ind_to_adb_key
|
phenolrs/pyg/typings.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: phenolrs
|
|
3
|
+
Version: 0.5.12
|
|
4
|
+
Classifier: Programming Language :: Rust
|
|
5
|
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
|
6
|
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
|
7
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
8
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
11
|
+
Classifier: Typing :: Typed
|
|
12
|
+
Requires-Dist: numpy
|
|
13
|
+
Requires-Dist: python-arango>=8.2.6
|
|
14
|
+
Requires-Dist: pytest ; extra == 'tests'
|
|
15
|
+
Requires-Dist: arango-datasets ; extra == 'tests'
|
|
16
|
+
Requires-Dist: adbnx-adapter ; extra == 'tests'
|
|
17
|
+
Requires-Dist: torch ; extra == 'torch'
|
|
18
|
+
Requires-Dist: torch-geometric ; extra == 'torch'
|
|
19
|
+
Requires-Dist: version ; extra == 'dynamic'
|
|
20
|
+
Provides-Extra: tests
|
|
21
|
+
Provides-Extra: torch
|
|
22
|
+
Provides-Extra: dynamic
|
|
23
|
+
License-File: LICENSE
|
|
24
|
+
Requires-Python: >=3.10
|