graphmemory 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphmemory-0.1.0/PKG-INFO +193 -0
- graphmemory-0.1.0/README.md +177 -0
- graphmemory-0.1.0/graphmemory.egg-info/PKG-INFO +193 -0
- graphmemory-0.1.0/graphmemory.egg-info/SOURCES.txt +11 -0
- graphmemory-0.1.0/graphmemory.egg-info/dependency_links.txt +1 -0
- graphmemory-0.1.0/graphmemory.egg-info/requires.txt +5 -0
- graphmemory-0.1.0/graphmemory.egg-info/top_level.txt +1 -0
- graphmemory-0.1.0/graphrag/__init__.py +3 -0
- graphmemory-0.1.0/graphrag/database.py +258 -0
- graphmemory-0.1.0/graphrag/models.py +17 -0
- graphmemory-0.1.0/setup.cfg +4 -0
- graphmemory-0.1.0/setup.py +22 -0
- graphmemory-0.1.0/tests/tests.py +239 -0
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: graphmemory
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A package for creating a graph database for use with GraphRAG.
|
|
5
|
+
Home-page: http://pypi.python.org/pypi/graphmemory/
|
|
6
|
+
Author: BradAGI
|
|
7
|
+
Author-email: cavemen_summary_0f@icloud.com
|
|
8
|
+
License: LICENSE.txt
|
|
9
|
+
Keywords: graphrag graph database rag
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: duckdb
|
|
12
|
+
Requires-Dist: json
|
|
13
|
+
Requires-Dist: pydantic
|
|
14
|
+
Requires-Dist: os
|
|
15
|
+
Requires-Dist: logging
|
|
16
|
+
|
|
17
|
+
# GraphMemory - GraphRAG Database
|
|
18
|
+
|
|
19
|
+
## Overview
|
|
20
|
+
This project provides an embedded graph database implementation with vector similarity search (VSS) using DuckDB. It includes a Python class `GraphRAG` for managing nodes and edges.
|
|
21
|
+
|
|
22
|
+
Each node has a unique ID, a JSON data field (any arbitrary dictionary), and a vector of floating point values.
|
|
23
|
+
|
|
24
|
+
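Each edge has a source node ID, a target node ID, a relationship type, and a weight. Edges are identified by their (source, target) pair, so at most one edge can exist in each direction between two nodes.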
|
|
25
|
+
|
|
26
|
+
This database can be used for any graph-based RAG application or knowledge graph application.
|
|
27
|
+
|
|
28
|
+
Vector embeddings can be created using [sentence-transformers](https://www.sbert.net/) or other API based models.
|
|
29
|
+
|
|
30
|
+
## Requirements
|
|
31
|
+
- Python 3.x
|
|
32
|
+
- DuckDB
|
|
33
|
+
- Pydantic
|
|
34
|
+
|
|
35
|
+
## Installation
|
|
36
|
+
1. Clone the repository:
|
|
37
|
+
```sh
|
|
38
|
+
git clone https://github.com/bradAGI/GraphMemory
|
|
39
|
+
```
|
|
40
|
+
2. Install the required packages:
|
|
41
|
+
```sh
|
|
42
|
+
pip install -r requirements.txt
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
## Usage
|
|
46
|
+
|
|
47
|
+
### GraphRAG Class
|
|
48
|
+
The `GraphRAG` class provides methods to manage nodes and edges, perform bulk inserts, create indexes, and find nearest neighbors using vector similarity search.
|
|
49
|
+
|
|
50
|
+
### Auto-Incrementing IDs
|
|
51
|
+
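Node IDs are always assigned by the database when a node is inserted (an `id` already set on a `Node` is not used): `insert_node` returns the new ID, and `bulk_insert_nodes` sets `id` on each node it inserts. Edges have no separate ID; they are identified by their source and target node IDs.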
|
|
52
|
+
|
|
53
|
+
### Example
|
|
54
|
+
```python
|
|
55
|
+
from graphrag.database import GraphRAG
|
|
56
|
+
from graphrag.models import Node, Edge
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Initialize the database from disk (make sure to set vector_length correctly)
|
|
60
|
+
graph_db = GraphRAG(database='graph.db', vector_length=3)
|
|
61
|
+
|
|
62
|
+
# Insert nodes
|
|
63
|
+
node1 = Node(data={"name": "George Washington", "role": "President"}, vector=[0.1, 0.2, 0.3])
|
|
64
|
+
node1_id = graph_db.insert_node(node1)
|
|
65
|
+
|
|
66
|
+
node2 = Node(data={"name": "Thomas Jefferson", "role": "Secretary of State"}, vector=[0.4, 0.5, 0.6])
|
|
67
|
+
node2_id = graph_db.insert_node(node2)
|
|
68
|
+
|
|
69
|
+
# Insert edge
|
|
70
|
+
edge = Edge(source_id=node1_id, target_id=node2_id, relation="served_under", weight=0.5)
|
|
71
|
+
graph_db.insert_edge(edge)
|
|
72
|
+
|
|
73
|
+
# Print all nodes in the database
|
|
74
|
+
nodes = graph_db.nodes_to_json()
|
|
75
|
+
print("Nodes:", nodes)
|
|
76
|
+
|
|
77
|
+
# Print all edges in the database
|
|
78
|
+
edges = graph_db.edges_to_json()
|
|
79
|
+
print("Edges:", edges)
|
|
80
|
+
|
|
81
|
+
# Find connected nodes
|
|
82
|
+
connected_nodes = graph_db.connected_nodes(node1_id)
|
|
83
|
+
print("Connected Nodes:", connected_nodes)
|
|
84
|
+
|
|
85
|
+
# Find nearest neighbors
|
|
86
|
+
neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=1)
|
|
87
|
+
print("Nearest Neighbors:", neighbors)
|
|
88
|
+
|
|
89
|
+
# Insert an edge between the two nodes with a relation
|
|
90
|
+
edge = Edge(source_id=node2_id, target_id=node1_id, relation="served_under", weight=0.5)
|
|
91
|
+
graph_db.insert_edge(edge)
|
|
92
|
+
|
|
93
|
+
# Define the additional nodes for bulk insert
|
|
94
|
+
nodes = [
|
|
95
|
+
Node(data={"name": "Alexander Hamilton", "role": "Secretary of the Treasury", "term": "1789–1795"}, vector=[0.7, 0.8, 0.9]),
|
|
96
|
+
Node(data={"name": "Oliver Wolcott Jr.", "role": "Secretary of the Treasury", "term": "1795–1797"}, vector=[1.6, 1.7, 1.8]),
|
|
97
|
+
]
|
|
98
|
+
|
|
99
|
+
# Bulk insert nodes
|
|
100
|
+
graph_db.bulk_insert_nodes(nodes)
|
|
101
|
+
|
|
102
|
+
# Define the additional edges for bulk insert
|
|
103
|
+
edges = [
|
|
104
|
+
Edge(source_id=nodes[0].id, target_id=nodes[1].id, relation="succeeded_by", weight=0.7),
|
|
105
|
+
Edge(source_id=nodes[1].id, target_id=node1_id, relation="served_under", weight=0.8)
|
|
106
|
+
]
|
|
107
|
+
|
|
108
|
+
# Bulk insert edges
|
|
109
|
+
graph_db.bulk_insert_edges(edges)
|
|
110
|
+
|
|
111
|
+
# Delete a node
|
|
112
|
+
graph_db.delete_node(nodes[-1].id)
|
|
113
|
+
|
|
114
|
+
# Delete an edge
|
|
115
|
+
graph_db.delete_edge(node1_id, node2_id)
|
|
116
|
+
|
|
117
|
+
# Find nearest nodes to a given vector by distance
|
|
118
|
+
neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
|
|
119
|
+
print("Nearest Neighbors:", neighbors)
|
|
120
|
+
|
|
121
|
+
# Find connected nodes
|
|
122
|
+
connected_nodes = graph_db.connected_nodes(nodes[1].id)
|
|
123
|
+
print("Connected Nodes:", connected_nodes)
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
## GraphRAG Class Methods
|
|
127
|
+
|
|
128
|
+
The `GraphRAG` class provides the following public methods for interacting with the graph database:
|
|
129
|
+
|
|
130
|
+
1. `__init__(self, database=None, vector_length=3)`
|
|
131
|
+
- Initializes the database connection and sets up the database schema if necessary.
|
|
132
|
+
|
|
133
|
+
2. `set_vector_length(self, vector_length)`
|
|
134
|
+
- Sets the length of the vectors for the nodes in the database.
|
|
135
|
+
|
|
136
|
+
3. `create_tables(self)`
|
|
137
|
+
- Creates the necessary database tables for nodes and edges if they do not exist.
|
|
138
|
+
|
|
139
|
+
4. `insert_node(self, node: Node) -> int`
|
|
140
|
+
- Inserts a node into the database and returns the node ID.
|
|
141
|
+
|
|
142
|
+
5. `insert_edge(self, edge: Edge)`
|
|
143
|
+
- Inserts an edge between two nodes in the database.
|
|
144
|
+
|
|
145
|
+
6. `bulk_insert_nodes(self, nodes: List[Node]) -> List[Node]`
|
|
146
|
+
- Performs a bulk insert of multiple nodes into the database.
|
|
147
|
+
|
|
148
|
+
7. `bulk_insert_edges(self, edges: List[Edge])`
|
|
149
|
+
- Performs a bulk insert of multiple edges into the database.
|
|
150
|
+
|
|
151
|
+
8. `delete_node(self, node_id: int)`
|
|
152
|
+
- Deletes a node and its associated edges from the database.
|
|
153
|
+
|
|
154
|
+
9. `delete_edge(self, source_id: int, target_id: int)`
|
|
155
|
+
- Deletes an edge from the database.
|
|
156
|
+
|
|
157
|
+
10. `create_index(self)`
|
|
158
|
+
- Creates an index on the node vectors to improve search performance.
|
|
159
|
+
|
|
160
|
+
11. `nearest_neighbors(self, vector: List[float], limit: int) -> List[Neighbor]`
|
|
161
|
+
- Finds and returns the nearest neighbor nodes based on vector similarity.
|
|
162
|
+
|
|
163
|
+
12. `connected_nodes(self, node_id: int) -> List[Node]`
|
|
164
|
+
- Retrieves all nodes directly connected to the specified node.
|
|
165
|
+
|
|
166
|
+
13. `nodes_to_json(self)`
|
|
167
|
+
- Returns a JSON representation of all nodes in the database.
|
|
168
|
+
|
|
169
|
+
14. `edges_to_json(self)`
|
|
170
|
+
- Returns a JSON representation of all edges in the database.
|
|
171
|
+
|
|
172
|
+
15. `get_node(self, node_id: int)`
|
|
173
|
+
- Retrieves a specific node by its ID.
|
|
174
|
+
|
|
175
|
+
16. `print_json(self)`
|
|
176
|
+
- Prints the JSON representation of all nodes and edges in the database.
|
|
177
|
+
|
|
178
|
+
These methods facilitate the management and querying of the graph database, allowing for efficient data handling and retrieval.
|
|
179
|
+
|
|
180
|
+
## Testing
|
|
181
|
+
Unit tests are provided in `tests/tests.py`.
|
|
182
|
+
|
|
183
|
+
### Running Tests
|
|
184
|
+
To run the unit tests, use the following command:
|
|
185
|
+
```sh
|
|
186
|
+
python -m unittest discover -s tests
|
|
187
|
+
```
|
|
188
|
+
|
|
189
|
+
## License
|
|
190
|
+
This project is licensed under the MIT License. See the LICENSE file for details.
|
|
191
|
+
|
|
192
|
+
## Contributing
|
|
193
|
+
Contributions are welcome! Please open an issue or submit a pull request.
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
# GraphMemory - GraphRAG Database
|
|
2
|
+
|
|
3
|
+
## Overview
|
|
4
|
+
This project provides an embedded graph database implementation with vector similarity search (VSS) using DuckDB. It includes a Python class `GraphRAG` for managing nodes and edges.
|
|
5
|
+
|
|
6
|
+
Each node has a unique ID, a JSON data field (any arbitrary dictionary), and a vector of floating point values.
|
|
7
|
+
|
|
8
|
+
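Each edge has a source node ID, a target node ID, a relationship type, and a weight. Edges are identified by their (source, target) pair, so at most one edge can exist in each direction between two nodes.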
|
|
9
|
+
|
|
10
|
+
This database can be used for any graph-based RAG application or knowledge graph application.
|
|
11
|
+
|
|
12
|
+
Vector embeddings can be created using [sentence-transformers](https://www.sbert.net/) or other API based models.
|
|
13
|
+
|
|
14
|
+
## Requirements
|
|
15
|
+
- Python 3.x
|
|
16
|
+
- DuckDB
|
|
17
|
+
- Pydantic
|
|
18
|
+
|
|
19
|
+
## Installation
|
|
20
|
+
1. Clone the repository:
|
|
21
|
+
```sh
|
|
22
|
+
git clone https://github.com/bradAGI/GraphMemory
|
|
23
|
+
```
|
|
24
|
+
2. Install the required packages:
|
|
25
|
+
```sh
|
|
26
|
+
pip install -r requirements.txt
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
## Usage
|
|
30
|
+
|
|
31
|
+
### GraphRAG Class
|
|
32
|
+
The `GraphRAG` class provides methods to manage nodes and edges, perform bulk inserts, create indexes, and find nearest neighbors using vector similarity search.
|
|
33
|
+
|
|
34
|
+
### Auto-Incrementing IDs
|
|
35
|
+
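Node IDs are always assigned by the database when a node is inserted (an `id` already set on a `Node` is not used): `insert_node` returns the new ID, and `bulk_insert_nodes` sets `id` on each node it inserts. Edges have no separate ID; they are identified by their source and target node IDs.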
|
|
36
|
+
|
|
37
|
+
### Example
|
|
38
|
+
```python
|
|
39
|
+
from graphrag.database import GraphRAG
|
|
40
|
+
from graphrag.models import Node, Edge
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Initialize the database from disk (make sure to set vector_length correctly)
|
|
44
|
+
graph_db = GraphRAG(database='graph.db', vector_length=3)
|
|
45
|
+
|
|
46
|
+
# Insert nodes
|
|
47
|
+
node1 = Node(data={"name": "George Washington", "role": "President"}, vector=[0.1, 0.2, 0.3])
|
|
48
|
+
node1_id = graph_db.insert_node(node1)
|
|
49
|
+
|
|
50
|
+
node2 = Node(data={"name": "Thomas Jefferson", "role": "Secretary of State"}, vector=[0.4, 0.5, 0.6])
|
|
51
|
+
node2_id = graph_db.insert_node(node2)
|
|
52
|
+
|
|
53
|
+
# Insert edge
|
|
54
|
+
edge = Edge(source_id=node1_id, target_id=node2_id, relation="served_under", weight=0.5)
|
|
55
|
+
graph_db.insert_edge(edge)
|
|
56
|
+
|
|
57
|
+
# Print all nodes in the database
|
|
58
|
+
nodes = graph_db.nodes_to_json()
|
|
59
|
+
print("Nodes:", nodes)
|
|
60
|
+
|
|
61
|
+
# Print all edges in the database
|
|
62
|
+
edges = graph_db.edges_to_json()
|
|
63
|
+
print("Edges:", edges)
|
|
64
|
+
|
|
65
|
+
# Find connected nodes
|
|
66
|
+
connected_nodes = graph_db.connected_nodes(node1_id)
|
|
67
|
+
print("Connected Nodes:", connected_nodes)
|
|
68
|
+
|
|
69
|
+
# Find nearest neighbors
|
|
70
|
+
neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=1)
|
|
71
|
+
print("Nearest Neighbors:", neighbors)
|
|
72
|
+
|
|
73
|
+
# Insert an edge between the two nodes with a relation
|
|
74
|
+
edge = Edge(source_id=node2_id, target_id=node1_id, relation="served_under", weight=0.5)
|
|
75
|
+
graph_db.insert_edge(edge)
|
|
76
|
+
|
|
77
|
+
# Define the additional nodes for bulk insert
|
|
78
|
+
nodes = [
|
|
79
|
+
Node(data={"name": "Alexander Hamilton", "role": "Secretary of the Treasury", "term": "1789–1795"}, vector=[0.7, 0.8, 0.9]),
|
|
80
|
+
Node(data={"name": "Oliver Wolcott Jr.", "role": "Secretary of the Treasury", "term": "1795–1797"}, vector=[1.6, 1.7, 1.8]),
|
|
81
|
+
]
|
|
82
|
+
|
|
83
|
+
# Bulk insert nodes
|
|
84
|
+
graph_db.bulk_insert_nodes(nodes)
|
|
85
|
+
|
|
86
|
+
# Define the additional edges for bulk insert
|
|
87
|
+
edges = [
|
|
88
|
+
Edge(source_id=nodes[0].id, target_id=nodes[1].id, relation="succeeded_by", weight=0.7),
|
|
89
|
+
Edge(source_id=nodes[1].id, target_id=node1_id, relation="served_under", weight=0.8)
|
|
90
|
+
]
|
|
91
|
+
|
|
92
|
+
# Bulk insert edges
|
|
93
|
+
graph_db.bulk_insert_edges(edges)
|
|
94
|
+
|
|
95
|
+
# Delete a node
|
|
96
|
+
graph_db.delete_node(nodes[-1].id)
|
|
97
|
+
|
|
98
|
+
# Delete an edge
|
|
99
|
+
graph_db.delete_edge(node1_id, node2_id)
|
|
100
|
+
|
|
101
|
+
# Find nearest nodes to a given vector by distance
|
|
102
|
+
neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
|
|
103
|
+
print("Nearest Neighbors:", neighbors)
|
|
104
|
+
|
|
105
|
+
# Find connected nodes
|
|
106
|
+
connected_nodes = graph_db.connected_nodes(nodes[1].id)
|
|
107
|
+
print("Connected Nodes:", connected_nodes)
|
|
108
|
+
```
|
|
109
|
+
|
|
110
|
+
## GraphRAG Class Methods
|
|
111
|
+
|
|
112
|
+
The `GraphRAG` class provides the following public methods for interacting with the graph database:
|
|
113
|
+
|
|
114
|
+
1. `__init__(self, database=None, vector_length=3)`
|
|
115
|
+
- Initializes the database connection and sets up the database schema if necessary.
|
|
116
|
+
|
|
117
|
+
2. `set_vector_length(self, vector_length)`
|
|
118
|
+
- Sets the length of the vectors for the nodes in the database.
|
|
119
|
+
|
|
120
|
+
3. `create_tables(self)`
|
|
121
|
+
- Creates the necessary database tables for nodes and edges if they do not exist.
|
|
122
|
+
|
|
123
|
+
4. `insert_node(self, node: Node) -> int`
|
|
124
|
+
- Inserts a node into the database and returns the node ID.
|
|
125
|
+
|
|
126
|
+
5. `insert_edge(self, edge: Edge)`
|
|
127
|
+
- Inserts an edge between two nodes in the database.
|
|
128
|
+
|
|
129
|
+
6. `bulk_insert_nodes(self, nodes: List[Node]) -> List[Node]`
|
|
130
|
+
- Performs a bulk insert of multiple nodes into the database.
|
|
131
|
+
|
|
132
|
+
7. `bulk_insert_edges(self, edges: List[Edge])`
|
|
133
|
+
- Performs a bulk insert of multiple edges into the database.
|
|
134
|
+
|
|
135
|
+
8. `delete_node(self, node_id: int)`
|
|
136
|
+
- Deletes a node and its associated edges from the database.
|
|
137
|
+
|
|
138
|
+
9. `delete_edge(self, source_id: int, target_id: int)`
|
|
139
|
+
- Deletes an edge from the database.
|
|
140
|
+
|
|
141
|
+
10. `create_index(self)`
|
|
142
|
+
- Creates an index on the node vectors to improve search performance.
|
|
143
|
+
|
|
144
|
+
11. `nearest_neighbors(self, vector: List[float], limit: int) -> List[Neighbor]`
|
|
145
|
+
- Finds and returns the nearest neighbor nodes based on vector similarity.
|
|
146
|
+
|
|
147
|
+
12. `connected_nodes(self, node_id: int) -> List[Node]`
|
|
148
|
+
- Retrieves all nodes directly connected to the specified node.
|
|
149
|
+
|
|
150
|
+
13. `nodes_to_json(self)`
|
|
151
|
+
- Returns a JSON representation of all nodes in the database.
|
|
152
|
+
|
|
153
|
+
14. `edges_to_json(self)`
|
|
154
|
+
- Returns a JSON representation of all edges in the database.
|
|
155
|
+
|
|
156
|
+
15. `get_node(self, node_id: int)`
|
|
157
|
+
- Retrieves a specific node by its ID.
|
|
158
|
+
|
|
159
|
+
16. `print_json(self)`
|
|
160
|
+
- Prints the JSON representation of all nodes and edges in the database.
|
|
161
|
+
|
|
162
|
+
These methods facilitate the management and querying of the graph database, allowing for efficient data handling and retrieval.
|
|
163
|
+
|
|
164
|
+
## Testing
|
|
165
|
+
Unit tests are provided in `tests/tests.py`.
|
|
166
|
+
|
|
167
|
+
### Running Tests
|
|
168
|
+
To run the unit tests, use the following command:
|
|
169
|
+
```sh
|
|
170
|
+
python -m unittest discover -s tests
|
|
171
|
+
```
|
|
172
|
+
|
|
173
|
+
## License
|
|
174
|
+
This project is licensed under the MIT License. See the LICENSE file for details.
|
|
175
|
+
|
|
176
|
+
## Contributing
|
|
177
|
+
Contributions are welcome! Please open an issue or submit a pull request.
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: graphmemory
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: A package for creating a graph database for use with GraphRAG.
|
|
5
|
+
Home-page: http://pypi.python.org/pypi/graphmemory/
|
|
6
|
+
Author: BradAGI
|
|
7
|
+
Author-email: cavemen_summary_0f@icloud.com
|
|
8
|
+
License: LICENSE.txt
|
|
9
|
+
Keywords: graphrag graph database rag
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
Requires-Dist: duckdb
|
|
12
|
+
Requires-Dist: json
|
|
13
|
+
Requires-Dist: pydantic
|
|
14
|
+
Requires-Dist: os
|
|
15
|
+
Requires-Dist: logging
|
|
16
|
+
|
|
17
|
+
# GraphMemory - GraphRAG Database
|
|
18
|
+
|
|
19
|
+
## Overview
|
|
20
|
+
This project provides an embedded graph database implementation with vector similarity search (VSS) using DuckDB. It includes a Python class `GraphRAG` for managing nodes and edges.
|
|
21
|
+
|
|
22
|
+
Each node has a unique ID, a JSON data field (any arbitrary dictionary), and a vector of floating point values.
|
|
23
|
+
|
|
24
|
+
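Each edge has a source node ID, a target node ID, a relationship type, and a weight. Edges are identified by their (source, target) pair, so at most one edge can exist in each direction between two nodes.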
|
|
25
|
+
|
|
26
|
+
This database can be used for any graph-based RAG application or knowledge graph application.
|
|
27
|
+
|
|
28
|
+
Vector embeddings can be created using [sentence-transformers](https://www.sbert.net/) or other API based models.
|
|
29
|
+
|
|
30
|
+
## Requirements
|
|
31
|
+
- Python 3.x
|
|
32
|
+
- DuckDB
|
|
33
|
+
- Pydantic
|
|
34
|
+
|
|
35
|
+
## Installation
|
|
36
|
+
1. Clone the repository:
|
|
37
|
+
```sh
|
|
38
|
+
git clone https://github.com/bradAGI/GraphMemory
|
|
39
|
+
```
|
|
40
|
+
2. Install the required packages:
|
|
41
|
+
```sh
|
|
42
|
+
pip install -r requirements.txt
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
## Usage
|
|
46
|
+
|
|
47
|
+
### GraphRAG Class
|
|
48
|
+
The `GraphRAG` class provides methods to manage nodes and edges, perform bulk inserts, create indexes, and find nearest neighbors using vector similarity search.
|
|
49
|
+
|
|
50
|
+
### Auto-Incrementing IDs
|
|
51
|
+
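Node IDs are always assigned by the database when a node is inserted (an `id` already set on a `Node` is not used): `insert_node` returns the new ID, and `bulk_insert_nodes` sets `id` on each node it inserts. Edges have no separate ID; they are identified by their source and target node IDs.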
|
|
52
|
+
|
|
53
|
+
### Example
|
|
54
|
+
```python
|
|
55
|
+
from graphrag.database import GraphRAG
|
|
56
|
+
from graphrag.models import Node, Edge
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Initialize the database from disk (make sure to set vector_length correctly)
|
|
60
|
+
graph_db = GraphRAG(database='graph.db', vector_length=3)
|
|
61
|
+
|
|
62
|
+
# Insert nodes
|
|
63
|
+
node1 = Node(data={"name": "George Washington", "role": "President"}, vector=[0.1, 0.2, 0.3])
|
|
64
|
+
node1_id = graph_db.insert_node(node1)
|
|
65
|
+
|
|
66
|
+
node2 = Node(data={"name": "Thomas Jefferson", "role": "Secretary of State"}, vector=[0.4, 0.5, 0.6])
|
|
67
|
+
node2_id = graph_db.insert_node(node2)
|
|
68
|
+
|
|
69
|
+
# Insert edge
|
|
70
|
+
edge = Edge(source_id=node1_id, target_id=node2_id, relation="served_under", weight=0.5)
|
|
71
|
+
graph_db.insert_edge(edge)
|
|
72
|
+
|
|
73
|
+
# Print all nodes in the database
|
|
74
|
+
nodes = graph_db.nodes_to_json()
|
|
75
|
+
print("Nodes:", nodes)
|
|
76
|
+
|
|
77
|
+
# Print all edges in the database
|
|
78
|
+
edges = graph_db.edges_to_json()
|
|
79
|
+
print("Edges:", edges)
|
|
80
|
+
|
|
81
|
+
# Find connected nodes
|
|
82
|
+
connected_nodes = graph_db.connected_nodes(node1_id)
|
|
83
|
+
print("Connected Nodes:", connected_nodes)
|
|
84
|
+
|
|
85
|
+
# Find nearest neighbors
|
|
86
|
+
neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=1)
|
|
87
|
+
print("Nearest Neighbors:", neighbors)
|
|
88
|
+
|
|
89
|
+
# Insert an edge between the two nodes with a relation
|
|
90
|
+
edge = Edge(source_id=node2_id, target_id=node1_id, relation="served_under", weight=0.5)
|
|
91
|
+
graph_db.insert_edge(edge)
|
|
92
|
+
|
|
93
|
+
# Define the additional nodes for bulk insert
|
|
94
|
+
nodes = [
|
|
95
|
+
Node(data={"name": "Alexander Hamilton", "role": "Secretary of the Treasury", "term": "1789–1795"}, vector=[0.7, 0.8, 0.9]),
|
|
96
|
+
Node(data={"name": "Oliver Wolcott Jr.", "role": "Secretary of the Treasury", "term": "1795–1797"}, vector=[1.6, 1.7, 1.8]),
|
|
97
|
+
]
|
|
98
|
+
|
|
99
|
+
# Bulk insert nodes
|
|
100
|
+
graph_db.bulk_insert_nodes(nodes)
|
|
101
|
+
|
|
102
|
+
# Define the additional edges for bulk insert
|
|
103
|
+
edges = [
|
|
104
|
+
Edge(source_id=nodes[0].id, target_id=nodes[1].id, relation="succeeded_by", weight=0.7),
|
|
105
|
+
Edge(source_id=nodes[1].id, target_id=node1_id, relation="served_under", weight=0.8)
|
|
106
|
+
]
|
|
107
|
+
|
|
108
|
+
# Bulk insert edges
|
|
109
|
+
graph_db.bulk_insert_edges(edges)
|
|
110
|
+
|
|
111
|
+
# Delete a node
|
|
112
|
+
graph_db.delete_node(nodes[-1].id)
|
|
113
|
+
|
|
114
|
+
# Delete an edge
|
|
115
|
+
graph_db.delete_edge(node1_id, node2_id)
|
|
116
|
+
|
|
117
|
+
# Find nearest nodes to a given vector by distance
|
|
118
|
+
neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
|
|
119
|
+
print("Nearest Neighbors:", neighbors)
|
|
120
|
+
|
|
121
|
+
# Find connected nodes
|
|
122
|
+
connected_nodes = graph_db.connected_nodes(nodes[1].id)
|
|
123
|
+
print("Connected Nodes:", connected_nodes)
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
## GraphRAG Class Methods
|
|
127
|
+
|
|
128
|
+
The `GraphRAG` class provides the following public methods for interacting with the graph database:
|
|
129
|
+
|
|
130
|
+
1. `__init__(self, database=None, vector_length=3)`
|
|
131
|
+
- Initializes the database connection and sets up the database schema if necessary.
|
|
132
|
+
|
|
133
|
+
2. `set_vector_length(self, vector_length)`
|
|
134
|
+
- Sets the length of the vectors for the nodes in the database.
|
|
135
|
+
|
|
136
|
+
3. `create_tables(self)`
|
|
137
|
+
- Creates the necessary database tables for nodes and edges if they do not exist.
|
|
138
|
+
|
|
139
|
+
4. `insert_node(self, node: Node) -> int`
|
|
140
|
+
- Inserts a node into the database and returns the node ID.
|
|
141
|
+
|
|
142
|
+
5. `insert_edge(self, edge: Edge)`
|
|
143
|
+
- Inserts an edge between two nodes in the database.
|
|
144
|
+
|
|
145
|
+
6. `bulk_insert_nodes(self, nodes: List[Node]) -> List[Node]`
|
|
146
|
+
- Performs a bulk insert of multiple nodes into the database.
|
|
147
|
+
|
|
148
|
+
7. `bulk_insert_edges(self, edges: List[Edge])`
|
|
149
|
+
- Performs a bulk insert of multiple edges into the database.
|
|
150
|
+
|
|
151
|
+
8. `delete_node(self, node_id: int)`
|
|
152
|
+
- Deletes a node and its associated edges from the database.
|
|
153
|
+
|
|
154
|
+
9. `delete_edge(self, source_id: int, target_id: int)`
|
|
155
|
+
- Deletes an edge from the database.
|
|
156
|
+
|
|
157
|
+
10. `create_index(self)`
|
|
158
|
+
- Creates an index on the node vectors to improve search performance.
|
|
159
|
+
|
|
160
|
+
11. `nearest_neighbors(self, vector: List[float], limit: int) -> List[Neighbor]`
|
|
161
|
+
- Finds and returns the nearest neighbor nodes based on vector similarity.
|
|
162
|
+
|
|
163
|
+
12. `connected_nodes(self, node_id: int) -> List[Node]`
|
|
164
|
+
- Retrieves all nodes directly connected to the specified node.
|
|
165
|
+
|
|
166
|
+
13. `nodes_to_json(self)`
|
|
167
|
+
- Returns a JSON representation of all nodes in the database.
|
|
168
|
+
|
|
169
|
+
14. `edges_to_json(self)`
|
|
170
|
+
- Returns a JSON representation of all edges in the database.
|
|
171
|
+
|
|
172
|
+
15. `get_node(self, node_id: int)`
|
|
173
|
+
- Retrieves a specific node by its ID.
|
|
174
|
+
|
|
175
|
+
16. `print_json(self)`
|
|
176
|
+
- Prints the JSON representation of all nodes and edges in the database.
|
|
177
|
+
|
|
178
|
+
These methods facilitate the management and querying of the graph database, allowing for efficient data handling and retrieval.
|
|
179
|
+
|
|
180
|
+
## Testing
|
|
181
|
+
Unit tests are provided in `tests/tests.py`.
|
|
182
|
+
|
|
183
|
+
### Running Tests
|
|
184
|
+
To run the unit tests, use the following command:
|
|
185
|
+
```sh
|
|
186
|
+
python -m unittest discover -s tests
|
|
187
|
+
```
|
|
188
|
+
|
|
189
|
+
## License
|
|
190
|
+
This project is licensed under the MIT License. See the LICENSE file for details.
|
|
191
|
+
|
|
192
|
+
## Contributing
|
|
193
|
+
Contributions are welcome! Please open an issue or submit a pull request.
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
setup.py
|
|
3
|
+
graphmemory.egg-info/PKG-INFO
|
|
4
|
+
graphmemory.egg-info/SOURCES.txt
|
|
5
|
+
graphmemory.egg-info/dependency_links.txt
|
|
6
|
+
graphmemory.egg-info/requires.txt
|
|
7
|
+
graphmemory.egg-info/top_level.txt
|
|
8
|
+
graphrag/__init__.py
|
|
9
|
+
graphrag/database.py
|
|
10
|
+
graphrag/models.py
|
|
11
|
+
tests/tests.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
graphrag
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import duckdb
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from graphrag.models import Node, Edge, Neighbor
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Module-level logger used by every GraphRAG method.
logger = logging.getLogger(__name__)
# NOTE(review): basicConfig() configures the *root* logger at import time,
# which also affects the host application's logging; libraries usually leave
# this to the application. Kept to preserve current behaviour.
logging.basicConfig(level=logging.INFO)
|
|
13
|
+
|
|
14
|
+
class GraphRAG:
|
|
15
|
+
def __init__(self, database=None, vector_length=3):
    """Open a DuckDB-backed graph store and make sure its tables exist.

    Args:
        database: Path of the DuckDB database file to open. ``None`` (the
            default) uses an in-memory database.
        vector_length: Number of floats in each node vector. It is used as
            the ``FLOAT[n]`` size when the ``nodes`` table is created and
            when query vectors are cast in ``nearest_neighbors``.
    """
    self.database = database
    self.vector_length = vector_length
    # ":memory:" is DuckDB's name for an in-memory database. Map None to it
    # explicitly instead of relying on how duckdb.connect treats None.
    self.conn = duckdb.connect(database=":memory:" if database is None else database)
    self._load_vss_extension()
    self._configure_database()

    # The file named by ``database`` is already this connection's main
    # database, so it is not re-ATTACHed through load_database(). Nothing in
    # this class queries the ``vss`` alias, and DuckDB is expected to refuse
    # attaching a file that is already attached, which would only produce
    # an error log on every start-up.

    # Create the schema unless both tables are already present.
    nodes_exist = self.conn.execute(
        "SELECT 1 FROM information_schema.tables WHERE table_name = 'nodes';"
    ).fetchone()
    edges_exist = self.conn.execute(
        "SELECT 1 FROM information_schema.tables WHERE table_name = 'edges';"
    ).fetchone()

    if not nodes_exist or not edges_exist:
        self.create_tables()
    logger.info("Tables created or verified successfully.")
|
|
32
|
+
|
|
33
|
+
def load_database(self, path, alias="vss"):
    """ATTACH an existing DuckDB database file to this connection.

    Args:
        path: Path of the database file (``str``, ``bytes`` or
            ``os.PathLike``).
        alias: Catalog name the attached database is exposed under.
            Defaults to ``"vss"``.

    A missing file or a DuckDB error is logged, and the method returns
    without raising.
    """
    if not os.path.exists(path):
        logger.error(f"Database file not found: {path}")
        return
    # Escape the path as a SQL string literal ('' for ') and the alias as a
    # quoted identifier ("" for "), so names containing quotes cannot break
    # or alter the statement.
    path_literal = os.fsdecode(path).replace("'", "''")
    alias_ident = str(alias).replace('"', '""')
    try:
        self.conn.execute(f"ATTACH DATABASE '{path_literal}' AS \"{alias_ident}\";")
    except duckdb.Error as e:
        logger.error(f"Error loading database: {e}")
|
|
41
|
+
|
|
42
|
+
def _configure_database(self):
    """Enable the ``hnsw_enable_experimental_persistence`` DuckDB setting.

    A DuckDB error while applying the setting is logged, not raised.
    """
    setting = "SET hnsw_enable_experimental_persistence=true;"
    try:
        self.conn.execute(setting)
    except duckdb.Error as err:
        logger.error(f"Error setting configuration: {err}")
|
|
47
|
+
|
|
48
|
+
def _load_vss_extension(self):
    """Install and then load DuckDB's ``vss`` extension.

    The statements run in order. If one raises a DuckDB error, the error is
    logged and the remaining statement is skipped.
    """
    try:
        for statement in ("INSTALL vss;", "LOAD vss;"):
            self.conn.execute(statement)
    except duckdb.Error as err:
        logger.error(f"Error loading VSS extension: {err}")
|
|
54
|
+
|
|
55
|
+
def set_vector_length(self, vector_length):
    """Change the vector length used for schema creation and queries.

    Only the attribute is updated. A ``nodes`` table that already exists
    keeps the ``FLOAT[n]`` size it was created with.
    """
    self.vector_length = vector_length
    message = f"Vector length set to: {self.vector_length}"
    logger.info(message)
|
|
58
|
+
|
|
59
|
+
def create_tables(self):
    """Create the ``seq_id`` sequence and the ``nodes``/``edges`` tables.

    Every statement uses ``IF NOT EXISTS``, so running this against an
    already-initialised database changes nothing.

    Schema:
        * ``nodes.id`` defaults to ``nextval('seq_id')``.
        * ``nodes.vector`` is a ``FLOAT[self.vector_length]`` array.
        * ``edges`` is keyed by ``(source_id, target_id)``.

    The changes are committed afterwards.
    """
    ddl_statements = (
        "CREATE SEQUENCE IF NOT EXISTS seq_id;",
        f"""
        CREATE TABLE IF NOT EXISTS nodes (
            id INTEGER DEFAULT nextval('seq_id'),
            data JSON,
            vector FLOAT[{self.vector_length}],
            PRIMARY KEY (id)
        );
        """,
        """
        CREATE TABLE IF NOT EXISTS edges (
            source_id INTEGER,
            target_id INTEGER,
            weight FLOAT,
            relation TEXT,
            PRIMARY KEY (source_id, target_id)
        );
        """,
    )
    for ddl in ddl_statements:
        self.conn.execute(ddl)
    logger.info("Tables 'nodes' and 'edges' created or already exist.")
    self.conn.commit()  # Ensure changes are committed
|
|
80
|
+
|
|
81
|
+
def insert_node(self, node: Node) -> int:
    """Insert one node and return its database-assigned ID.

    ``node.id`` is not used for the insert; the ID comes from the default
    of the ``nodes.id`` column. On success the new ID is also written to
    ``node.id``, as ``bulk_insert_nodes`` does for the nodes it inserts.

    Returns:
        The new node ID, or ``-1`` in any of these cases (all logged where
        noted):
          * the vector fails ``_validate_vector`` (logged);
          * the insert returns no row;
          * DuckDB raises an error (logged).
    """
    if not self._validate_vector(node.vector):
        logger.error("Invalid vector: Must be a list of float values.")
        return -1
    # Pre-initialise so a transaction block that exits without assigning
    # still yields -1 instead of falling through to None.
    result = None
    try:
        with self.transaction():
            result = self.conn.execute(
                "INSERT INTO nodes (data, vector) VALUES (?, ?) RETURNING id;",
                (json.dumps(node.data), node.vector),
            ).fetchone()
    except duckdb.Error as e:
        logger.error(f"Error during insert node: {e}")
        return -1
    if not result:
        return -1
    node_id = result[0]
    logger.info(f"Node inserted with ID: {node_id}")
    # Only record the ID once the transaction block has exited successfully.
    node.id = node_id
    return node_id
|
|
97
|
+
|
|
98
|
+
def insert_edge(self, edge: Edge):
    """Insert one edge after checking that both endpoint nodes exist.

    Raises:
        ValueError: If ``edge.source_id`` or ``edge.target_id`` has no
            matching row in ``nodes``. The error is logged before it is
            re-raised.

    DuckDB errors, for example a second edge with the same
    ``(source_id, target_id)`` key, are logged and not raised.
    """
    lookup_sql = "SELECT 1 FROM nodes WHERE id = ?"
    try:
        with self.transaction():
            # Look up both endpoints before deciding.
            endpoint_rows = [
                self.conn.execute(lookup_sql, (endpoint_id,)).fetchone()
                for endpoint_id in (edge.source_id, edge.target_id)
            ]
            if not all(endpoint_rows):
                raise ValueError("Source or target node does not exist.")

            self.conn.execute(
                "INSERT INTO edges (source_id, target_id, weight, relation) VALUES (?, ?, ?, ?);",
                (edge.source_id, edge.target_id, edge.weight, edge.relation),
            )
    except duckdb.Error as err:
        logger.error(f"Error during insert edge: {err}")
    except ValueError as err:
        logger.error(f"Error during insert edge: {err}")
        raise  # Re-raise the ValueError
|
|
114
|
+
|
|
115
|
+
def bulk_insert_nodes(self, nodes: List[Node]) -> List[Node]:
    """Insert several nodes inside one ``self.transaction()`` block.

    Returns:
        The same ``nodes`` list, with ``id`` set on each node whose insert
        returned a row. Returns ``[]`` if a DuckDB error occurred; the error
        is logged and no IDs are written to the nodes.
    """
    assigned = []
    try:
        with self.transaction():
            for node in nodes:
                row = self.conn.execute(
                    "INSERT INTO nodes (data, vector) VALUES (?, ?) RETURNING id;",
                    (json.dumps(node.data), node.vector),
                ).fetchone()
                if row:
                    assigned.append((node, row[0]))
    except duckdb.Error as e:
        logger.error(f"Error during bulk insert nodes: {e}")
        return []
    # Write IDs only after the transaction block has exited successfully.
    # A failed batch (presumably rolled back by transaction()) must not
    # leave callers holding IDs of rows that were never committed.
    for node, node_id in assigned:
        node.id = node_id  # Set the node ID
    return nodes
|
|
129
|
+
|
|
130
|
+
def bulk_insert_edges(self, edges: List[Edge]):
    """Insert many edges with one ``executemany`` call inside a transaction.

    Unlike ``insert_edge``, this does not check that the endpoint nodes
    exist. DuckDB errors, for instance a repeated ``(source_id, target_id)``
    pair, are logged and not raised.
    """
    insert_sql = "INSERT INTO edges (source_id, target_id, weight, relation) VALUES (?, ?, ?, ?);"
    try:
        with self.transaction():
            parameter_rows = [(e.source_id, e.target_id, e.weight, e.relation) for e in edges]
            self.conn.executemany(insert_sql, parameter_rows)
    except duckdb.Error as err:
        logger.error(f"Error during bulk insert edges: {err}")
|
|
139
|
+
|
|
140
|
+
def delete_node(self, node_id: int):
    """Delete a node together with every edge that touches it.

    Both DELETE statements run inside one self.transaction(), so an error in
    either rolls back both. A duckdb.Error is logged and swallowed. Deleting an
    id that does not exist matches no rows and changes nothing.

    Args:
        node_id: id of the node to remove.
    """
    try:
        with self.transaction():
            # Remove the referencing edges before the node itself (child rows
            # before parent). If the schema declares a foreign key from edges to
            # nodes (not visible here; confirm in the table DDL), deleting the
            # node first would violate it while edges still point at it.
            self.conn.execute("DELETE FROM edges WHERE source_id = ? OR target_id = ?;", (node_id, node_id))
            self.conn.execute("DELETE FROM nodes WHERE id = ?;", (node_id,))
    except duckdb.Error as e:
        logger.error(f"Error deleting node: {e}")
|
|
147
|
+
|
|
148
|
+
def delete_edge(self, source_id: int, target_id: int):
    """Remove every edge running from source_id to target_id.

    All matching rows are deleted regardless of relation or weight. Edges in the
    opposite direction (target_id -> source_id) are not touched. A duckdb.Error
    is logged and swallowed.
    """
    delete_sql = "DELETE FROM edges WHERE source_id = ? AND target_id = ?;"
    endpoints = (source_id, target_id)
    try:
        with self.transaction():
            self.conn.execute(delete_sql, endpoints)
    except duckdb.Error as e:
        logger.error(f"Error deleting edge: {e}")
|
|
154
|
+
|
|
155
|
+
def create_index(self):
    """Create the HNSW similarity index ``vss_idx`` on nodes.vector if absent.

    The statement runs outside self.transaction(). A duckdb.Error is logged
    and swallowed. NOTE(review): the HNSW index type presumably requires the
    DuckDB vss extension to be loaded on self.conn; confirm where that happens.
    """
    index_ddl = "CREATE INDEX IF NOT EXISTS vss_idx ON nodes USING HNSW(vector);"
    try:
        self.conn.execute(index_ddl)
    except duckdb.Error as e:
        logger.error(f"Error creating index: {e}")
|
|
160
|
+
|
|
161
|
+
def nearest_neighbors(self, vector: List[float], limit: int) -> List[Neighbor]:
    """Return up to ``limit`` nodes ordered by array_distance to ``vector``.

    Args:
        vector: query vector. It must pass self._validate_vector and is cast to
            FLOAT[self.vector_length] inside the query.
        limit: maximum number of neighbours to return.

    Returns:
        Neighbor models, nearest first. An empty list is returned when the vector
        is rejected (logged) or when DuckDB raises an error (logged).
    """
    if not self._validate_vector(vector):
        logger.error("Invalid vector: Must be a list of float values.")
        return []

    # vector_length comes from the instance, not from caller input, so
    # interpolating it into the type of the CAST is safe.
    query = f"""
    SELECT id, data, vector, array_distance(vector, CAST(? AS FLOAT[{self.vector_length}])) AS distance
    FROM nodes
    ORDER BY distance LIMIT ?;
    """
    try:
        rows = self.conn.execute(query, (vector, limit)).fetchall()
        neighbours = []
        for node_id, raw_data, stored_vector, distance in rows:
            match = Node(id=node_id, data=json.loads(raw_data), vector=stored_vector)
            neighbours.append(Neighbor(node=match, distance=distance))
        return neighbours
    except duckdb.Error as e:
        logger.error(f"Error fetching nearest neighbors: {e}")
        return []
|
|
183
|
+
|
|
184
|
+
def connected_nodes(self, node_id: int) -> List[Node]:
    """Return every node joined to ``node_id`` by an edge in either direction.

    Targets of outgoing edges and sources of incoming edges are merged with a
    SQL UNION, so a node linked in both directions appears only once. Progress
    is logged at INFO level. A duckdb.Error is logged and an empty list is
    returned.
    """
    query = """
    SELECT n.id, n.data, n.vector
    FROM nodes n
    WHERE n.id IN (
        SELECT target_id FROM edges WHERE source_id = CAST(? AS INTEGER)
        UNION
        SELECT source_id FROM edges WHERE target_id = CAST(? AS INTEGER)
    );
    """
    try:
        logger.info(f"Executing query to fetch connected nodes for node_id: {node_id}")
        rows = self.conn.execute(query, (node_id, node_id)).fetchall()
        if not rows:
            logger.info("No connected nodes found.")
            return []
        neighbours = [
            Node(id=row_id, data=json.loads(raw_data), vector=stored_vector)
            for row_id, raw_data, stored_vector in rows
        ]
        logger.info(f"Found {len(neighbours)} connected nodes.")
        return neighbours
    except duckdb.Error as e:
        logger.error(f"Error fetching connected nodes: {e}")
        return []
|
|
207
|
+
|
|
208
|
+
def nodes_to_json(self):
    """Dump every row of the nodes table as a list of JSON-ready dicts.

    Each dict has the keys "id", "data" (the stored JSON text decoded with
    json.loads) and "vector". A duckdb.Error is logged and [] is returned.
    """
    try:
        rows = self.conn.execute("SELECT id, data, vector FROM nodes;").fetchall()
        return [
            {"id": node_id, "data": json.loads(raw_data), "vector": stored_vector}
            for node_id, raw_data, stored_vector in rows
        ]
    except duckdb.Error as e:
        logger.error(f"Error fetching nodes: {e}")
        return []
|
|
215
|
+
|
|
216
|
+
def edges_to_json(self):
    """Dump every row of the edges table as a list of JSON-ready dicts.

    Each dict carries source_id, target_id, weight and relation. A duckdb.Error
    is logged and [] is returned.
    """
    columns = ("source_id", "target_id", "weight", "relation")
    try:
        rows = self.conn.execute("SELECT source_id, target_id, weight, relation FROM edges;").fetchall()
        return [dict(zip(columns, row)) for row in rows]
    except duckdb.Error as e:
        logger.error(f"Error fetching edges: {e}")
        return []
|
|
223
|
+
|
|
224
|
+
def get_node(self, node_id: int):
    """Fetch a single node by id as a plain dict.

    Args:
        node_id: id of the node to look up.

    Returns:
        {"id": ..., "data": ..., "vector": ...} for the matching row, with
        "data" decoded from JSON. Returns {} when no node has this id, or when
        DuckDB raises an error (the error is logged).
    """
    try:
        row = self.conn.execute("SELECT id, data, vector FROM nodes WHERE id = ?;", (node_id,)).fetchone()
    except duckdb.Error as e:
        logger.error(f"Error fetching node: {e}")
        return {}
    if row is None:
        # fetchone() returns None when nothing matches; indexing it used to
        # raise an uncaught TypeError instead of signalling "not found".
        return {}
    return {"id": row[0], "data": json.loads(row[1]), "vector": row[2]}
|
|
232
|
+
|
|
233
|
+
def print_json(self):
    """Print all nodes, then all edges, to stdout as indented JSON."""
    # Both dumps are fetched (nodes first) before anything is printed.
    sections = (
        ("Nodes JSON:", self.nodes_to_json()),
        ("Edges JSON:", self.edges_to_json()),
    )
    for label, payload in sections:
        print(label, json.dumps(payload, indent=2))
|
|
238
|
+
|
|
239
|
+
def _validate_vector(self, vector):
    """Return True if ``vector`` is a list of exactly self.vector_length numbers.

    Integer components are accepted alongside floats. nearest_neighbors casts
    the query vector to FLOAT[...] in DuckDB, and pydantic's List[float]
    coerces ints, so rejecting a vector such as [1, 0, 0] was needlessly
    strict. bool is excluded explicitly because it is a subclass of int.
    Non-list sequences (e.g. tuples) are still rejected.
    """
    if not isinstance(vector, list) or len(vector) != self.vector_length:
        return False
    return all(
        isinstance(component, (int, float)) and not isinstance(component, bool)
        for component in vector
    )
|
|
241
|
+
|
|
242
|
+
@contextmanager
def transaction(self):
    """Run the enclosed ``with`` block inside an explicit DuckDB transaction.

    COMMIT is issued when the block finishes normally. On any exception
    (including BaseException subclasses such as KeyboardInterrupt, which
    previously left the transaction open), ROLLBACK is issued and the original
    exception is re-raised with its traceback intact.

    BEGIN runs before the ``try`` block. If BEGIN itself fails, no ROLLBACK is
    attempted against a transaction that never started, so the real error is
    not buried under a second one.
    """
    self.conn.execute("BEGIN TRANSACTION;")
    try:
        yield
        self.conn.execute("COMMIT;")
    except BaseException:
        self.conn.execute("ROLLBACK;")
        raise
|
|
251
|
+
|
|
252
|
+
def __enter__(self):
    """Support ``with`` usage by handing back this same instance."""
    instance = self
    return instance
|
|
254
|
+
|
|
255
|
+
def __exit__(self, exc_type, exc_value, traceback):
    """Close the DuckDB connection when a ``with`` block ends.

    Returns None, so any exception raised inside the block keeps propagating.
    """
    connection = self.conn
    connection.close()
    logger.info("Database connection closed.")
|
|
258
|
+
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from typing import Any, Dict, List, Optional

from pydantic import BaseModel
|
|
3
|
+
|
|
4
|
+
# A graph vertex: an arbitrary JSON-serialisable payload plus an embedding vector.
# `id` is None until the database assigns one on insert. It is declared
# Optional[int]: a bare `int = None` is a non-nullable int field, so under
# pydantic v2 an explicit Node(id=None, ...) would fail validation.
class Node(BaseModel):
    id: Optional[int] = None
    data: Dict[str, Any]  # stored in the database as JSON text (json.dumps)
    vector: List[float]   # embedding used for similarity search


# A relationship from source_id to target_id, both of which are node ids.
class Edge(BaseModel):
    source_id: int
    target_id: int
    relation: str  # free-form label, e.g. "friendship"
    weight: float


# One similarity-search hit: the matched node and its array_distance to the query.
class Neighbor(BaseModel):
    node: Node
    distance: float
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from setuptools import setup, find_packages
|
|
2
|
+
|
|
3
|
+
# Read the long description with an explicit encoding and close the file
# deterministically. A bare open(...).read() leaks the handle and depends on
# the locale's default encoding.
with open('README.md', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

setup(
    name='graphmemory',
    version='0.1.0',
    author='BradAGI',
    author_email='cavemen_summary_0f@icloud.com',
    packages=find_packages(),
    url='http://pypi.python.org/pypi/graphmemory/',
    # NOTE(review): `license` expects a license name (e.g. 'MIT'), not a file
    # path, and LICENSE.txt is not among the sdist's files; confirm intent.
    license='LICENSE.txt',
    description='A package for creating a graph database for use with GraphRAG.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    # Only third-party distributions belong here. json, os and logging are
    # standard-library modules; listing them makes pip try to resolve
    # same-named projects on PyPI, which are not the stdlib modules.
    install_requires=[
        'duckdb',
        'pydantic',
    ],
    keywords='graphrag graph database rag',
)
|
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import os
|
|
3
|
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
|
|
4
|
+
import duckdb
|
|
5
|
+
import unittest
|
|
6
|
+
from src.graphrag import GraphRAG
|
|
7
|
+
from src.models import Node, Edge, Neighbor
|
|
8
|
+
from pydantic import ValidationError
|
|
9
|
+
|
|
10
|
+
class TestGraphRAG(unittest.TestCase):
    """Integration tests for GraphRAG against a fresh in-memory DuckDB database.

    Tests use the node ids returned by the insert methods rather than assuming
    ids 1, 2, 3, so they do not silently depend on id generation.
    test_auto_increment_node_id is the one test that pins the numbering.
    """

    def setUp(self):
        # New in-memory database with 3-dimensional vectors for every test.
        self.db = GraphRAG(database=':memory:', vector_length=3)

    def _insert_two_nodes(self):
        """Insert the two standard fixture nodes and return their ids."""
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        return self.db.insert_node(node1), self.db.insert_node(node2)

    def test_insert_node(self):
        node = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node_id = self.db.insert_node(node)
        result = self.db.conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
        self.assertIsNotNone(result)
        self.assertEqual(result[0], node_id)
        self.assertEqual(result[1], '{"name": "node1"}')
        self.assertAlmostEqual(result[2][0], 0.1, places=7)
        self.assertAlmostEqual(result[2][1], 0.2, places=7)
        self.assertAlmostEqual(result[2][2], 0.3, places=7)

    def test_insert_invalid_node(self):
        # pydantic rejects the non-numeric component while building the Node.
        with self.assertRaises(ValidationError):
            node = Node(data={"name": "node1"}, vector=[0.1, 0.2, "invalid"])
            self.db.insert_node(node)

    def test_insert_node_invalid_vector_length(self):
        node = Node(data={"name": "node1"}, vector=[0.1, 0.2])
        result = self.db.insert_node(node)
        self.assertEqual(result, -1)

    def test_insert_edge(self):
        node1_id, node2_id = self._insert_two_nodes()
        edge = Edge(source_id=node1_id, target_id=node2_id, weight=0.5, relation="friendship")
        self.db.insert_edge(edge)
        result = self.db.conn.execute(
            "SELECT source_id, target_id, weight, relation FROM edges WHERE source_id = ? AND target_id = ?",
            (node1_id, node2_id),
        ).fetchone()
        self.assertIsNotNone(result)
        self.assertEqual(result[0], node1_id)
        self.assertEqual(result[1], node2_id)
        self.assertEqual(result[2], 0.5)
        self.assertEqual(result[3], "friendship")

    def test_insert_edge_non_existent_nodes(self):
        edge = Edge(source_id=999, target_id=1000, weight=0.5, relation="friendship")
        with self.assertRaises(ValueError):
            self.db.insert_edge(edge)

    def test_nearest_neighbors_empty_db(self):
        neighbors = self.db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
        self.assertEqual(len(neighbors), 0)

    def test_nearest_neighbors(self):
        nodes = [
            Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3]),
            Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6]),
            Node(data={"name": "node3"}, vector=[0.7, 0.8, 0.9]),
        ]
        node_ids = [self.db.insert_node(node) for node in nodes]
        self.db.create_index()
        neighbors = self.db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
        self.assertEqual(len(neighbors), 2)
        # node1 matches the query exactly; node2 is the next closest.
        self.assertEqual(neighbors[0].node.id, node_ids[0])
        self.assertEqual(neighbors[1].node.id, node_ids[1])

    def test_nodes_to_json(self):
        node = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node_id = self.db.insert_node(node)
        nodes_json = self.db.nodes_to_json()
        self.assertEqual(len(nodes_json), 1)
        self.assertEqual(nodes_json[0]['id'], node_id)
        self.assertEqual(nodes_json[0]['data'], {"name": "node1"})
        self.assertAlmostEqual(nodes_json[0]['vector'][0], 0.1, places=7)
        self.assertAlmostEqual(nodes_json[0]['vector'][1], 0.2, places=7)
        self.assertAlmostEqual(nodes_json[0]['vector'][2], 0.3, places=7)

    def test_edges_to_json(self):
        node1_id, node2_id = self._insert_two_nodes()
        edge = Edge(source_id=node1_id, target_id=node2_id, weight=0.5, relation="friendship")
        self.db.insert_edge(edge)
        edges_json = self.db.edges_to_json()
        self.assertEqual(len(edges_json), 1)
        self.assertEqual(edges_json[0]['source_id'], node1_id)
        self.assertEqual(edges_json[0]['target_id'], node2_id)
        self.assertEqual(edges_json[0]['weight'], 0.5)
        self.assertEqual(edges_json[0]['relation'], "friendship")

    def test_bulk_insert_nodes(self):
        nodes = [
            Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3]),
            Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6]),
        ]
        inserted_nodes = self.db.bulk_insert_nodes(nodes)
        self.assertEqual(len(inserted_nodes), 2)
        self.assertTrue(all(node.id is not None for node in inserted_nodes))

    def test_bulk_insert_edges(self):
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        node3 = Node(data={"name": "node3"}, vector=[0.7, 0.8, 0.9])
        inserted = self.db.bulk_insert_nodes([node1, node2, node3])
        self.assertEqual(len(inserted), 3)
        id1, id2, id3 = (node.id for node in inserted)
        edges = [
            Edge(source_id=id1, target_id=id2, weight=0.5, relation="friendship"),
            Edge(source_id=id2, target_id=id3, weight=0.7, relation="colleague"),
        ]
        self.db.bulk_insert_edges(edges)
        result = self.db.conn.execute("SELECT * FROM edges").fetchall()
        self.assertEqual(len(result), 2)

    def test_delete_node(self):
        node = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node_id = self.db.insert_node(node)
        self.db.delete_node(node_id)
        result = self.db.conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
        self.assertIsNone(result)

    def test_delete_non_existent_node(self):
        self.db.delete_node(999)
        result = self.db.conn.execute("SELECT * FROM nodes WHERE id = 999").fetchone()
        self.assertIsNone(result)

    def test_delete_edge(self):
        node1_id, node2_id = self._insert_two_nodes()
        edge = Edge(source_id=node1_id, target_id=node2_id, weight=0.5, relation="friendship")
        self.db.insert_edge(edge)
        self.db.delete_edge(node1_id, node2_id)
        result = self.db.conn.execute(
            "SELECT * FROM edges WHERE source_id = ? AND target_id = ?", (node1_id, node2_id)
        ).fetchone()
        self.assertIsNone(result)

    def test_delete_non_existent_edge(self):
        self.db.delete_edge(999, 1000)
        result = self.db.conn.execute("SELECT * FROM edges WHERE source_id = 999 AND target_id = 1000").fetchone()
        self.assertIsNone(result)

    def test_transaction_handling(self):
        # A dedicated exception type ensures we only accept the forced rollback.
        # The previous bare `except: pass` would also have swallowed a failing
        # INSERT, which would let the test pass without exercising rollback.
        class ForcedRollback(Exception):
            pass

        with self.assertRaises(ForcedRollback):
            with self.db.transaction():
                self.db.conn.execute("INSERT INTO nodes (data, vector) VALUES ('{\"name\": \"node1\"}', [0.1, 0.2, 0.3]);")
                raise ForcedRollback("Force rollback")
        result = self.db.conn.execute("SELECT * FROM nodes WHERE data = '{\"name\": \"node1\"}'").fetchone()
        self.assertIsNone(result)

    def test_auto_increment_node_id(self):
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        id1 = self.db.insert_node(node1)
        id2 = self.db.insert_node(node2)

        self.assertEqual(id1, 1)
        self.assertEqual(id2, 2)

    def test_auto_increment_edge_id(self):
        id1, id2 = self._insert_two_nodes()

        edge1 = Edge(source_id=id1, target_id=id2, weight=0.5, relation="friendship")
        edge2 = Edge(source_id=id2, target_id=id1, weight=0.7, relation="colleague")
        self.db.insert_edge(edge1)
        self.db.insert_edge(edge2)

        result1 = self.db.conn.execute("SELECT source_id, target_id FROM edges WHERE source_id = ? AND target_id = ?", (id1, id2)).fetchone()
        result2 = self.db.conn.execute("SELECT source_id, target_id FROM edges WHERE source_id = ? AND target_id = ?", (id2, id1)).fetchone()

        self.assertIsNotNone(result1)
        self.assertIsNotNone(result2)
        self.assertEqual(result1[0], id1)
        self.assertEqual(result1[1], id2)
        self.assertEqual(result2[0], id2)
        self.assertEqual(result2[1], id1)

    def tearDown(self):
        self.db.conn.close()
        # Remove export directories if present. NOTE(review): no test in this
        # class creates them; presumably left over from earlier export tests.
        import shutil
        if os.path.exists('test_db_parquet'):
            shutil.rmtree('test_db_parquet')
        if os.path.exists('test_db_csv'):
            shutil.rmtree('test_db_csv')
|
|
196
|
+
|
|
197
|
+
class TestGraphRAGGetConnectedNodes(unittest.TestCase):
    """Neighbour lookup on the chain node1 -> node2 -> node3.

    NOTE(review): these tests call GraphRAG.get_connected_nodes and
    GraphRAG.create_tables, while the visible part of database.py names the
    lookup connected_nodes(). Confirm get_connected_nodes exists, otherwise
    these tests fail with AttributeError.
    """

    def setUp(self):
        # Fresh in-memory database for each test, with tables created explicitly.
        self.db = GraphRAG(database=':memory:', vector_length=3)
        self.db.create_tables()

        # Three fixture nodes, all built before any is inserted.
        fixtures = [
            ({"name": "node1"}, [0.1, 0.2, 0.3]),
            ({"name": "node2"}, [0.4, 0.5, 0.6]),
            ({"name": "node3"}, [0.7, 0.8, 0.9]),
        ]
        self.node1, self.node2, self.node3 = [Node(data=data, vector=vec) for data, vec in fixtures]
        self.node1_id, self.node2_id, self.node3_id = [
            self.db.insert_node(node) for node in (self.node1, self.node2, self.node3)
        ]

        # Link them as a chain: node1 -> node2 -> node3.
        chain = [
            (self.node1_id, self.node2_id, 0.5, "friendship"),
            (self.node2_id, self.node3_id, 0.7, "colleague"),
        ]
        for source, target, weight, relation in chain:
            self.db.insert_edge(Edge(source_id=source, target_id=target, weight=weight, relation=relation))

    def test_get_connected_nodes(self):
        # node1 only touches node2, through its outgoing edge.
        neighbours = self.db.get_connected_nodes(self.node1_id)
        self.assertEqual(len(neighbours), 1)
        self.assertEqual(neighbours[0].id, self.node2_id)

        # node2 sits in the middle, so it reaches both ends.
        neighbours = self.db.get_connected_nodes(self.node2_id)
        self.assertEqual(len(neighbours), 2)
        neighbour_ids = [n.id for n in neighbours]
        self.assertIn(self.node1_id, neighbour_ids)
        self.assertIn(self.node3_id, neighbour_ids)

        # node3 only touches node2, through its incoming edge.
        neighbours = self.db.get_connected_nodes(self.node3_id)
        self.assertEqual(len(neighbours), 1)
        self.assertEqual(neighbours[0].id, self.node2_id)

    def tearDown(self):
        self.db.conn.close()
|
|
236
|
+
|
|
237
|
+
if __name__ == "__main__":
    # Let the file be run directly, e.g. `python tests/tests.py`.
    unittest.main()
|
|
239
|
+
|