graphmemory 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,193 @@
1
+ Metadata-Version: 2.1
2
+ Name: graphmemory
3
+ Version: 0.1.0
4
+ Summary: A package for creating a graph database for use with GraphRAG.
5
+ Home-page: http://pypi.python.org/pypi/graphmemory/
6
+ Author: BradAGI
7
+ Author-email: cavemen_summary_0f@icloud.com
8
+ License: LICENSE.txt
9
+ Keywords: graphrag graph database rag
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: duckdb
12
+ Requires-Dist: json
13
+ Requires-Dist: pydantic
14
+ Requires-Dist: os
15
+ Requires-Dist: logging
16
+
17
+ # GraphMemory - GraphRAG Database
18
+
19
+ ## Overview
20
+ This project provides an embedded graph database implementation with vector similarity search (VSS) using DuckDB. It includes a Python class `GraphRAG` for managing nodes and edges.
21
+
22
+ Each node has a unique ID, a JSON data field (any arbitrary dictionary), and a vector of floating point values.
23
+
24
+ Each edge connects a source node ID to a target node ID and carries a relationship type and a weight. Edges have no separate ID; each (source, target) pair is unique.
25
+
26
+ This database can be used for any graph-based RAG application or knowledge graph application.
27
+
28
+ Vector embeddings can be created using [sentence-transformers](https://www.sbert.net/) or other API based models.
29
+
30
+ ## Requirements
31
+ - Python 3.x
32
+ - DuckDB
33
+ - Pydantic
34
+
35
+ ## Installation
36
+ 1. Clone the repository:
37
+ ```sh
38
+ git clone https://github.com/bradAGI/GraphMemory
39
+ ```
40
+ 2. Install the required packages:
41
+ ```sh
42
+ pip install -r requirements.txt
43
+ ```
44
+
45
+ ## Usage
46
+
47
+ ### GraphRAG Class
48
+ The `GraphRAG` class provides methods to manage nodes and edges, perform bulk inserts, create indexes, and find nearest neighbors using vector similarity search.
49
+
50
+ ### Auto-Incrementing IDs
51
+ Node IDs are assigned automatically by the database when a node is inserted. Edges have no ID of their own; they are identified by their source and target node IDs.
52
+
53
+ ### Example
54
+ ```python
55
+ from graphrag.database import GraphRAG
56
+ from graphrag.models import Node, Edge
57
+
58
+
59
+ # Initialize the database from disk (make sure to set vector_length correctly)
60
+ graph_db = GraphRAG(database='graph.db', vector_length=3)
61
+
62
+ # Insert nodes
63
+ node1 = Node(data={"name": "George Washington", "role": "President"}, vector=[0.1, 0.2, 0.3])
64
+ node1_id = graph_db.insert_node(node1)
65
+
66
+ node2 = Node(data={"name": "Thomas Jefferson", "role": "Secretary of State"}, vector=[0.4, 0.5, 0.6])
67
+ node2_id = graph_db.insert_node(node2)
68
+
69
+ # Insert edge
70
+ edge = Edge(source_id=node1_id, target_id=node2_id, relation="served_under", weight=0.5)
71
+ graph_db.insert_edge(edge)
72
+
73
+ # Print all nodes in the database
74
+ nodes = graph_db.nodes_to_json()
75
+ print("Nodes:", nodes)
76
+
77
+ # Print all edges in the database
78
+ edges = graph_db.edges_to_json()
79
+ print("Edges:", edges)
80
+
81
+ # Find connected nodes
82
+ connected_nodes = graph_db.connected_nodes(node1_id)
83
+ print("Connected Nodes:", connected_nodes)
84
+
85
+ # Find nearest neighbors
86
+ neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=1)
87
+ print("Nearest Neighbors:", neighbors)
88
+
89
+ # Insert an edge between the two nodes with a relation
90
+ edge = Edge(source_id=node2_id, target_id=node1_id, relation="served_under", weight=0.5)
91
+ graph_db.insert_edge(edge)
92
+
93
+ # Define the additional nodes for bulk insert
94
+ nodes = [
95
+ Node(data={"name": "Alexander Hamilton", "role": "Secretary of the Treasury", "term": "1789–1795"}, vector=[0.7, 0.8, 0.9]),
96
+ Node(data={"name": "Oliver Wolcott Jr.", "role": "Secretary of the Treasury", "term": "1795–1797"}, vector=[1.6, 1.7, 1.8]),
97
+ ]
98
+
99
+ # Bulk insert nodes
100
+ graph_db.bulk_insert_nodes(nodes)
101
+
102
+ # Define the additional edges for bulk insert
103
+ edges = [
104
+ Edge(source_id=nodes[0].id, target_id=nodes[1].id, relation="succeeded_by", weight=0.7),
105
+ Edge(source_id=nodes[1].id, target_id=node1_id, relation="served_under", weight=0.8)
106
+ ]
107
+
108
+ # Bulk insert edges
109
+ graph_db.bulk_insert_edges(edges)
110
+
111
+ # Delete a node
112
+ graph_db.delete_node(nodes[-1].id)
113
+
114
+ # Delete an edge
115
+ graph_db.delete_edge(1, 2)
116
+
117
+ # Find nearest nodes to a given vector by distance
118
+ neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
119
+ print("Nearest Neighbors:", neighbors)
120
+
121
+ # Find connected nodes
122
+ connected_nodes = graph_db.connected_nodes(nodes[1].id)
123
+ print("Connected Nodes:", connected_nodes)
124
+ ```
125
+
126
+ ## GraphRAG Class Methods
127
+
128
+ The `GraphRAG` class provides the following public methods for interacting with the graph database:
129
+
130
+ 1. `__init__(self, database=None, vector_length=3)`
131
+ - Initializes the database connection and sets up the database schema if necessary.
132
+
133
+ 2. `set_vector_length(self, vector_length)`
134
+ - Sets the length of the vectors for the nodes in the database.
135
+
136
+ 3. `create_tables(self)`
137
+ - Creates the necessary database tables for nodes and edges if they do not exist.
138
+
139
+ 4. `insert_node(self, node: Node) -> int`
140
+ - Inserts a node into the database and returns the node ID.
141
+
142
+ 5. `insert_edge(self, edge: Edge)`
143
+ - Inserts an edge between two nodes in the database.
144
+
145
+ 6. `bulk_insert_nodes(self, nodes: List[Node]) -> List[Node]`
146
+ - Performs a bulk insert of multiple nodes into the database.
147
+
148
+ 7. `bulk_insert_edges(self, edges: List[Edge])`
149
+ - Performs a bulk insert of multiple edges into the database.
150
+
151
+ 8. `delete_node(self, node_id: int)`
152
+ - Deletes a node and its associated edges from the database.
153
+
154
+ 9. `delete_edge(self, source_id: int, target_id: int)`
155
+ - Deletes an edge from the database.
156
+
157
+ 10. `create_index(self)`
158
+ - Creates an index on the node vectors to improve search performance.
159
+
160
+ 11. `nearest_neighbors(self, vector: List[float], limit: int) -> List[Neighbor]`
161
+ - Finds and returns the nearest neighbor nodes based on vector similarity.
162
+
163
+ 12. `connected_nodes(self, node_id: int) -> List[Node]`
164
+ - Retrieves all nodes directly connected to the specified node.
165
+
166
+ 13. `nodes_to_json(self)`
167
+ - Returns a JSON representation of all nodes in the database.
168
+
169
+ 14. `edges_to_json(self)`
170
+ - Returns a JSON representation of all edges in the database.
171
+
172
+ 15. `get_node(self, node_id: int)`
173
+ - Retrieves a specific node by its ID.
174
+
175
+ 16. `print_json(self)`
176
+ - Prints the JSON representation of all nodes and edges in the database.
177
+
178
+ These methods facilitate the management and querying of the graph database, allowing for efficient data handling and retrieval.
179
+
180
+ ## Testing
181
+ Unit tests are provided in `tests/tests.py`.
182
+
183
+ ### Running Tests
184
+ To run the unit tests, use the following command:
185
+ ```sh
186
+ python -m unittest discover -s tests
187
+ ```
188
+
189
+ ## License
190
+ This project is licensed under the MIT License. See the LICENSE file for details.
191
+
192
+ ## Contributing
193
+ Contributions are welcome! Please open an issue or submit a pull request.
@@ -0,0 +1,177 @@
1
+ # GraphMemory - GraphRAG Database
2
+
3
+ ## Overview
4
+ This project provides an embedded graph database implementation with vector similarity search (VSS) using DuckDB. It includes a Python class `GraphRAG` for managing nodes and edges.
5
+
6
+ Each node has a unique ID, a JSON data field (any arbitrary dictionary), and a vector of floating point values.
7
+
8
+ Each edge connects a source node ID to a target node ID and carries a relationship type and a weight. Edges have no separate ID; each (source, target) pair is unique.
9
+
10
+ This database can be used for any graph-based RAG application or knowledge graph application.
11
+
12
+ Vector embeddings can be created using [sentence-transformers](https://www.sbert.net/) or other API based models.
13
+
14
+ ## Requirements
15
+ - Python 3.x
16
+ - DuckDB
17
+ - Pydantic
18
+
19
+ ## Installation
20
+ 1. Clone the repository:
21
+ ```sh
22
+ git clone https://github.com/bradAGI/GraphMemory
23
+ ```
24
+ 2. Install the required packages:
25
+ ```sh
26
+ pip install -r requirements.txt
27
+ ```
28
+
29
+ ## Usage
30
+
31
+ ### GraphRAG Class
32
+ The `GraphRAG` class provides methods to manage nodes and edges, perform bulk inserts, create indexes, and find nearest neighbors using vector similarity search.
33
+
34
+ ### Auto-Incrementing IDs
35
+ Node IDs are assigned automatically by the database when a node is inserted. Edges have no ID of their own; they are identified by their source and target node IDs.
36
+
37
+ ### Example
38
+ ```python
39
+ from graphrag.database import GraphRAG
40
+ from graphrag.models import Node, Edge
41
+
42
+
43
+ # Initialize the database from disk (make sure to set vector_length correctly)
44
+ graph_db = GraphRAG(database='graph.db', vector_length=3)
45
+
46
+ # Insert nodes
47
+ node1 = Node(data={"name": "George Washington", "role": "President"}, vector=[0.1, 0.2, 0.3])
48
+ node1_id = graph_db.insert_node(node1)
49
+
50
+ node2 = Node(data={"name": "Thomas Jefferson", "role": "Secretary of State"}, vector=[0.4, 0.5, 0.6])
51
+ node2_id = graph_db.insert_node(node2)
52
+
53
+ # Insert edge
54
+ edge = Edge(source_id=node1_id, target_id=node2_id, relation="served_under", weight=0.5)
55
+ graph_db.insert_edge(edge)
56
+
57
+ # Print all nodes in the database
58
+ nodes = graph_db.nodes_to_json()
59
+ print("Nodes:", nodes)
60
+
61
+ # Print all edges in the database
62
+ edges = graph_db.edges_to_json()
63
+ print("Edges:", edges)
64
+
65
+ # Find connected nodes
66
+ connected_nodes = graph_db.connected_nodes(node1_id)
67
+ print("Connected Nodes:", connected_nodes)
68
+
69
+ # Find nearest neighbors
70
+ neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=1)
71
+ print("Nearest Neighbors:", neighbors)
72
+
73
+ # Insert an edge between the two nodes with a relation
74
+ edge = Edge(source_id=node2_id, target_id=node1_id, relation="served_under", weight=0.5)
75
+ graph_db.insert_edge(edge)
76
+
77
+ # Define the additional nodes for bulk insert
78
+ nodes = [
79
+ Node(data={"name": "Alexander Hamilton", "role": "Secretary of the Treasury", "term": "1789–1795"}, vector=[0.7, 0.8, 0.9]),
80
+ Node(data={"name": "Oliver Wolcott Jr.", "role": "Secretary of the Treasury", "term": "1795–1797"}, vector=[1.6, 1.7, 1.8]),
81
+ ]
82
+
83
+ # Bulk insert nodes
84
+ graph_db.bulk_insert_nodes(nodes)
85
+
86
+ # Define the additional edges for bulk insert
87
+ edges = [
88
+ Edge(source_id=nodes[0].id, target_id=nodes[1].id, relation="succeeded_by", weight=0.7),
89
+ Edge(source_id=nodes[1].id, target_id=node1_id, relation="served_under", weight=0.8)
90
+ ]
91
+
92
+ # Bulk insert edges
93
+ graph_db.bulk_insert_edges(edges)
94
+
95
+ # Delete a node
96
+ graph_db.delete_node(nodes[-1].id)
97
+
98
+ # Delete an edge
99
+ graph_db.delete_edge(1, 2)
100
+
101
+ # Find nearest nodes to a given vector by distance
102
+ neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
103
+ print("Nearest Neighbors:", neighbors)
104
+
105
+ # Find connected nodes
106
+ connected_nodes = graph_db.connected_nodes(nodes[1].id)
107
+ print("Connected Nodes:", connected_nodes)
108
+ ```
109
+
110
+ ## GraphRAG Class Methods
111
+
112
+ The `GraphRAG` class provides the following public methods for interacting with the graph database:
113
+
114
+ 1. `__init__(self, database=None, vector_length=3)`
115
+ - Initializes the database connection and sets up the database schema if necessary.
116
+
117
+ 2. `set_vector_length(self, vector_length)`
118
+ - Sets the length of the vectors for the nodes in the database.
119
+
120
+ 3. `create_tables(self)`
121
+ - Creates the necessary database tables for nodes and edges if they do not exist.
122
+
123
+ 4. `insert_node(self, node: Node) -> int`
124
+ - Inserts a node into the database and returns the node ID.
125
+
126
+ 5. `insert_edge(self, edge: Edge)`
127
+ - Inserts an edge between two nodes in the database.
128
+
129
+ 6. `bulk_insert_nodes(self, nodes: List[Node]) -> List[Node]`
130
+ - Performs a bulk insert of multiple nodes into the database.
131
+
132
+ 7. `bulk_insert_edges(self, edges: List[Edge])`
133
+ - Performs a bulk insert of multiple edges into the database.
134
+
135
+ 8. `delete_node(self, node_id: int)`
136
+ - Deletes a node and its associated edges from the database.
137
+
138
+ 9. `delete_edge(self, source_id: int, target_id: int)`
139
+ - Deletes an edge from the database.
140
+
141
+ 10. `create_index(self)`
142
+ - Creates an index on the node vectors to improve search performance.
143
+
144
+ 11. `nearest_neighbors(self, vector: List[float], limit: int) -> List[Neighbor]`
145
+ - Finds and returns the nearest neighbor nodes based on vector similarity.
146
+
147
+ 12. `connected_nodes(self, node_id: int) -> List[Node]`
148
+ - Retrieves all nodes directly connected to the specified node.
149
+
150
+ 13. `nodes_to_json(self)`
151
+ - Returns a JSON representation of all nodes in the database.
152
+
153
+ 14. `edges_to_json(self)`
154
+ - Returns a JSON representation of all edges in the database.
155
+
156
+ 15. `get_node(self, node_id: int)`
157
+ - Retrieves a specific node by its ID.
158
+
159
+ 16. `print_json(self)`
160
+ - Prints the JSON representation of all nodes and edges in the database.
161
+
162
+ These methods facilitate the management and querying of the graph database, allowing for efficient data handling and retrieval.
163
+
164
+ ## Testing
165
+ Unit tests are provided in `tests/tests.py`.
166
+
167
+ ### Running Tests
168
+ To run the unit tests, use the following command:
169
+ ```sh
170
+ python -m unittest discover -s tests
171
+ ```
172
+
173
+ ## License
174
+ This project is licensed under the MIT License. See the LICENSE file for details.
175
+
176
+ ## Contributing
177
+ Contributions are welcome! Please open an issue or submit a pull request.
@@ -0,0 +1,193 @@
1
+ Metadata-Version: 2.1
2
+ Name: graphmemory
3
+ Version: 0.1.0
4
+ Summary: A package for creating a graph database for use with GraphRAG.
5
+ Home-page: http://pypi.python.org/pypi/graphmemory/
6
+ Author: BradAGI
7
+ Author-email: cavemen_summary_0f@icloud.com
8
+ License: LICENSE.txt
9
+ Keywords: graphrag graph database rag
10
+ Description-Content-Type: text/markdown
11
+ Requires-Dist: duckdb
12
+ Requires-Dist: json
13
+ Requires-Dist: pydantic
14
+ Requires-Dist: os
15
+ Requires-Dist: logging
16
+
17
+ # GraphMemory - GraphRAG Database
18
+
19
+ ## Overview
20
+ This project provides an embedded graph database implementation with vector similarity search (VSS) using DuckDB. It includes a Python class `GraphRAG` for managing nodes and edges.
21
+
22
+ Each node has a unique ID, a JSON data field (any arbitrary dictionary), and a vector of floating point values.
23
+
24
+ Each edge connects a source node ID to a target node ID and carries a relationship type and a weight. Edges have no separate ID; each (source, target) pair is unique.
25
+
26
+ This database can be used for any graph-based RAG application or knowledge graph application.
27
+
28
+ Vector embeddings can be created using [sentence-transformers](https://www.sbert.net/) or other API based models.
29
+
30
+ ## Requirements
31
+ - Python 3.x
32
+ - DuckDB
33
+ - Pydantic
34
+
35
+ ## Installation
36
+ 1. Clone the repository:
37
+ ```sh
38
+ git clone https://github.com/bradAGI/GraphMemory
39
+ ```
40
+ 2. Install the required packages:
41
+ ```sh
42
+ pip install -r requirements.txt
43
+ ```
44
+
45
+ ## Usage
46
+
47
+ ### GraphRAG Class
48
+ The `GraphRAG` class provides methods to manage nodes and edges, perform bulk inserts, create indexes, and find nearest neighbors using vector similarity search.
49
+
50
+ ### Auto-Incrementing IDs
51
+ Node IDs are assigned automatically by the database when a node is inserted. Edges have no ID of their own; they are identified by their source and target node IDs.
52
+
53
+ ### Example
54
+ ```python
55
+ from graphrag.database import GraphRAG
56
+ from graphrag.models import Node, Edge
57
+
58
+
59
+ # Initialize the database from disk (make sure to set vector_length correctly)
60
+ graph_db = GraphRAG(database='graph.db', vector_length=3)
61
+
62
+ # Insert nodes
63
+ node1 = Node(data={"name": "George Washington", "role": "President"}, vector=[0.1, 0.2, 0.3])
64
+ node1_id = graph_db.insert_node(node1)
65
+
66
+ node2 = Node(data={"name": "Thomas Jefferson", "role": "Secretary of State"}, vector=[0.4, 0.5, 0.6])
67
+ node2_id = graph_db.insert_node(node2)
68
+
69
+ # Insert edge
70
+ edge = Edge(source_id=node1_id, target_id=node2_id, relation="served_under", weight=0.5)
71
+ graph_db.insert_edge(edge)
72
+
73
+ # Print all nodes in the database
74
+ nodes = graph_db.nodes_to_json()
75
+ print("Nodes:", nodes)
76
+
77
+ # Print all edges in the database
78
+ edges = graph_db.edges_to_json()
79
+ print("Edges:", edges)
80
+
81
+ # Find connected nodes
82
+ connected_nodes = graph_db.connected_nodes(node1_id)
83
+ print("Connected Nodes:", connected_nodes)
84
+
85
+ # Find nearest neighbors
86
+ neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=1)
87
+ print("Nearest Neighbors:", neighbors)
88
+
89
+ # Insert an edge between the two nodes with a relation
90
+ edge = Edge(source_id=node2_id, target_id=node1_id, relation="served_under", weight=0.5)
91
+ graph_db.insert_edge(edge)
92
+
93
+ # Define the additional nodes for bulk insert
94
+ nodes = [
95
+ Node(data={"name": "Alexander Hamilton", "role": "Secretary of the Treasury", "term": "1789–1795"}, vector=[0.7, 0.8, 0.9]),
96
+ Node(data={"name": "Oliver Wolcott Jr.", "role": "Secretary of the Treasury", "term": "1795–1797"}, vector=[1.6, 1.7, 1.8]),
97
+ ]
98
+
99
+ # Bulk insert nodes
100
+ graph_db.bulk_insert_nodes(nodes)
101
+
102
+ # Define the additional edges for bulk insert
103
+ edges = [
104
+ Edge(source_id=nodes[0].id, target_id=nodes[1].id, relation="succeeded_by", weight=0.7),
105
+ Edge(source_id=nodes[1].id, target_id=node1_id, relation="served_under", weight=0.8)
106
+ ]
107
+
108
+ # Bulk insert edges
109
+ graph_db.bulk_insert_edges(edges)
110
+
111
+ # Delete a node
112
+ graph_db.delete_node(nodes[-1].id)
113
+
114
+ # Delete an edge
115
+ graph_db.delete_edge(1, 2)
116
+
117
+ # Find nearest nodes to a given vector by distance
118
+ neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
119
+ print("Nearest Neighbors:", neighbors)
120
+
121
+ # Find connected nodes
122
+ connected_nodes = graph_db.connected_nodes(nodes[1].id)
123
+ print("Connected Nodes:", connected_nodes)
124
+ ```
125
+
126
+ ## GraphRAG Class Methods
127
+
128
+ The `GraphRAG` class provides the following public methods for interacting with the graph database:
129
+
130
+ 1. `__init__(self, database=None, vector_length=3)`
131
+ - Initializes the database connection and sets up the database schema if necessary.
132
+
133
+ 2. `set_vector_length(self, vector_length)`
134
+ - Sets the length of the vectors for the nodes in the database.
135
+
136
+ 3. `create_tables(self)`
137
+ - Creates the necessary database tables for nodes and edges if they do not exist.
138
+
139
+ 4. `insert_node(self, node: Node) -> int`
140
+ - Inserts a node into the database and returns the node ID.
141
+
142
+ 5. `insert_edge(self, edge: Edge)`
143
+ - Inserts an edge between two nodes in the database.
144
+
145
+ 6. `bulk_insert_nodes(self, nodes: List[Node]) -> List[Node]`
146
+ - Performs a bulk insert of multiple nodes into the database.
147
+
148
+ 7. `bulk_insert_edges(self, edges: List[Edge])`
149
+ - Performs a bulk insert of multiple edges into the database.
150
+
151
+ 8. `delete_node(self, node_id: int)`
152
+ - Deletes a node and its associated edges from the database.
153
+
154
+ 9. `delete_edge(self, source_id: int, target_id: int)`
155
+ - Deletes an edge from the database.
156
+
157
+ 10. `create_index(self)`
158
+ - Creates an index on the node vectors to improve search performance.
159
+
160
+ 11. `nearest_neighbors(self, vector: List[float], limit: int) -> List[Neighbor]`
161
+ - Finds and returns the nearest neighbor nodes based on vector similarity.
162
+
163
+ 12. `connected_nodes(self, node_id: int) -> List[Node]`
164
+ - Retrieves all nodes directly connected to the specified node.
165
+
166
+ 13. `nodes_to_json(self)`
167
+ - Returns a JSON representation of all nodes in the database.
168
+
169
+ 14. `edges_to_json(self)`
170
+ - Returns a JSON representation of all edges in the database.
171
+
172
+ 15. `get_node(self, node_id: int)`
173
+ - Retrieves a specific node by its ID.
174
+
175
+ 16. `print_json(self)`
176
+ - Prints the JSON representation of all nodes and edges in the database.
177
+
178
+ These methods facilitate the management and querying of the graph database, allowing for efficient data handling and retrieval.
179
+
180
+ ## Testing
181
+ Unit tests are provided in `tests/tests.py`.
182
+
183
+ ### Running Tests
184
+ To run the unit tests, use the following command:
185
+ ```sh
186
+ python -m unittest discover -s tests
187
+ ```
188
+
189
+ ## License
190
+ This project is licensed under the MIT License. See the LICENSE file for details.
191
+
192
+ ## Contributing
193
+ Contributions are welcome! Please open an issue or submit a pull request.
@@ -0,0 +1,11 @@
1
+ README.md
2
+ setup.py
3
+ graphmemory.egg-info/PKG-INFO
4
+ graphmemory.egg-info/SOURCES.txt
5
+ graphmemory.egg-info/dependency_links.txt
6
+ graphmemory.egg-info/requires.txt
7
+ graphmemory.egg-info/top_level.txt
8
+ graphrag/__init__.py
9
+ graphrag/database.py
10
+ graphrag/models.py
11
+ tests/tests.py
@@ -0,0 +1,5 @@
1
+ duckdb
2
+ json
3
+ pydantic
4
+ os
5
+ logging
@@ -0,0 +1 @@
1
+ graphrag
@@ -0,0 +1,3 @@
1
+ from .database import GraphRAG
2
+ from .models import Node, Edge, Neighbor
3
+
@@ -0,0 +1,258 @@
1
+ import duckdb
2
+ import json
3
+ import os
4
+ import logging
5
+ from contextlib import contextmanager
6
+ from graphrag.models import Node, Edge, Neighbor
7
+ from typing import List
8
+
9
+
10
# Configure logging.
# NOTE(review): logging.basicConfig() runs when this module is imported and
# configures the root logger at INFO level if it has no handlers yet.
# Libraries usually leave this to the application - confirm this is intended.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by every GraphRAG method below.
logger = logging.getLogger(__name__)
13
+
14
class GraphRAG:
    """Graph store with vector similarity search, backed by a DuckDB connection.

    Nodes live in a ``nodes`` table (auto-assigned integer ``id``, JSON
    ``data``, fixed-length ``FLOAT[vector_length]`` ``vector``). Edges live in
    an ``edges`` table keyed by the ``(source_id, target_id)`` pair, with a
    ``weight`` and a ``relation``.

    Most methods log DuckDB errors and return a neutral value (``-1``, ``[]``
    or ``{}``) instead of raising; ``insert_edge`` additionally raises
    ``ValueError`` when an endpoint node is missing.

    The instance can be used as a context manager; leaving the ``with`` block
    closes the connection.
    """

    def __init__(self, database=None, vector_length=3):
        """Open the database, load the VSS extension and ensure the schema.

        Args:
            database: Path of the DuckDB file, ``':memory:'``, or ``None``.
                ``None`` is mapped explicitly to an in-memory database.
            vector_length: Number of components of every node vector. It is
                baked into the ``nodes.vector`` column type when the table is
                first created.
        """
        self.database = database
        self.vector_length = vector_length
        self.conn = duckdb.connect(database=self.database or ':memory:')
        self._load_vss_extension()
        self._configure_database()

        # The connection above already opened (or created) ``database`` as the
        # default catalog, so the file is not ATTACHed a second time here:
        # re-attaching the same file only produced a logged error.

        # Create the tables when either one is missing.
        if not (self._table_exists('nodes') and self._table_exists('edges')):
            self.create_tables()
        logger.info("Tables created or verified successfully.")

    def _table_exists(self, table_name):
        """Return True when ``information_schema.tables`` lists ``table_name``."""
        row = self.conn.execute(
            "SELECT 1 FROM information_schema.tables WHERE table_name = ?;",
            (table_name,),
        ).fetchone()
        return row is not None

    def load_database(self, path):
        """Attach the DuckDB file at ``path`` under the alias ``vss``.

        Logs an error and returns when the file does not exist or when DuckDB
        rejects the ATTACH statement.
        """
        if not os.path.exists(path):
            logger.error("Database file not found: %s", path)
            return
        # The path is embedded in the statement as a string literal, so any
        # single quote is doubled to keep it from terminating the literal.
        escaped_path = str(path).replace("'", "''")
        try:
            self.conn.execute(f"ATTACH DATABASE '{escaped_path}' AS vss;")
        except duckdb.Error as e:
            logger.error("Error loading database: %s", e)

    def _configure_database(self):
        """Enable experimental persistence of HNSW indexes (errors are logged)."""
        try:
            self.conn.execute("SET hnsw_enable_experimental_persistence=true;")
        except duckdb.Error as e:
            logger.error("Error setting configuration: %s", e)

    def _load_vss_extension(self):
        """Install and load DuckDB's ``vss`` extension (errors are logged)."""
        try:
            self.conn.execute("INSTALL vss;")
            self.conn.execute("LOAD vss;")
        except duckdb.Error as e:
            logger.error("Error loading VSS extension: %s", e)

    def set_vector_length(self, vector_length):
        """Change the vector length used for validation and query casts.

        NOTE: this only updates the attribute; an already created ``nodes``
        table keeps the column type it was created with.
        """
        self.vector_length = vector_length
        logger.info("Vector length set to: %s", self.vector_length)

    def create_tables(self):
        """Create the id sequence and the ``nodes`` / ``edges`` tables if absent."""
        # Coerce to int so that only a number can reach the DDL text.
        dimension = int(self.vector_length)
        self.conn.execute("CREATE SEQUENCE IF NOT EXISTS seq_id;")
        self.conn.execute(f"""
            CREATE TABLE IF NOT EXISTS nodes (
                id INTEGER DEFAULT nextval('seq_id'),
                data JSON,
                vector FLOAT[{dimension}],
                PRIMARY KEY (id)
            );
        """)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS edges (
                source_id INTEGER,
                target_id INTEGER,
                weight FLOAT,
                relation TEXT,
                PRIMARY KEY (source_id, target_id)
            );
        """)
        logger.info("Tables 'nodes' and 'edges' created or already exist.")
        self.conn.commit()  # Ensure changes are committed

    def insert_node(self, node: "Node") -> int:
        """Insert ``node`` and return its database-assigned id.

        The id is also stored on ``node.id``, matching ``bulk_insert_nodes``.

        Returns:
            The new node id, or ``-1`` when the vector is invalid or the
            insert fails (the reason is logged).
        """
        if not self._validate_vector(node.vector):
            logger.error("Invalid vector: Must be a list of float values.")
            return -1
        try:
            with self.transaction():
                row = self.conn.execute(
                    "INSERT INTO nodes (data, vector) VALUES (?, ?) RETURNING id;",
                    (json.dumps(node.data), node.vector),
                ).fetchone()
        except duckdb.Error as e:
            logger.error("Error during insert node: %s", e)
            return -1
        if not row:
            return -1
        # Reported only after the transaction has committed.
        node.id = row[0]
        logger.info("Node inserted with ID: %s", row[0])
        return row[0]

    def insert_edge(self, edge: "Edge"):
        """Insert ``edge`` after checking that both endpoint nodes exist.

        Raises:
            ValueError: If the source or target node is not in ``nodes``.

        DuckDB errors (for example a duplicate ``(source_id, target_id)``
        pair) are logged and not raised.
        """
        try:
            with self.transaction():
                source_exists = self.conn.execute(
                    "SELECT 1 FROM nodes WHERE id = ?", (edge.source_id,)
                ).fetchone()
                target_exists = self.conn.execute(
                    "SELECT 1 FROM nodes WHERE id = ?", (edge.target_id,)
                ).fetchone()
                if not source_exists or not target_exists:
                    raise ValueError("Source or target node does not exist.")

                self.conn.execute(
                    "INSERT INTO edges (source_id, target_id, weight, relation) VALUES (?, ?, ?, ?);",
                    (edge.source_id, edge.target_id, edge.weight, edge.relation),
                )
        except duckdb.Error as e:
            logger.error("Error during insert edge: %s", e)
        except ValueError as e:
            logger.error("Error during insert edge: %s", e)
            raise  # Callers are expected to see the missing-node error.

    def bulk_insert_nodes(self, nodes: List["Node"]) -> List["Node"]:
        """Insert all ``nodes`` in one transaction and set their ``id`` fields.

        Returns:
            The same list with ids filled in, or ``[]`` when any vector is
            invalid or the insert fails. On failure nothing is inserted and
            no node object is modified.
        """
        for node in nodes:
            if not self._validate_vector(node.vector):
                logger.error("Invalid vector: Must be a list of float values.")
                return []
        assigned_ids = []
        try:
            with self.transaction():
                for node in nodes:
                    row = self.conn.execute(
                        "INSERT INTO nodes (data, vector) VALUES (?, ?) RETURNING id;",
                        (json.dumps(node.data), node.vector),
                    ).fetchone()
                    assigned_ids.append(row[0] if row else None)
        except duckdb.Error as e:
            logger.error("Error during bulk insert nodes: %s", e)
            return []
        # Ids are copied onto the models only after the commit succeeded, so a
        # rolled-back batch never leaves stale ids behind.
        for node, new_id in zip(nodes, assigned_ids):
            if new_id is not None:
                node.id = new_id
        return nodes

    def bulk_insert_edges(self, edges: List["Edge"]):
        """Insert all ``edges`` in one transaction.

        Unlike ``insert_edge`` this does not check that the endpoint nodes
        exist. DuckDB errors are logged and roll back the whole batch.
        """
        if not edges:
            # Nothing to do; avoids handing an empty parameter list to
            # executemany().
            return
        rows = [
            (edge.source_id, edge.target_id, edge.weight, edge.relation)
            for edge in edges
        ]
        try:
            with self.transaction():
                self.conn.executemany(
                    "INSERT INTO edges (source_id, target_id, weight, relation) VALUES (?, ?, ?, ?);",
                    rows,
                )
        except duckdb.Error as e:
            logger.error("Error during bulk insert edges: %s", e)

    def delete_node(self, node_id: int):
        """Delete the node ``node_id`` and every edge that touches it."""
        try:
            with self.transaction():
                self.conn.execute("DELETE FROM nodes WHERE id = ?;", (node_id,))
                self.conn.execute(
                    "DELETE FROM edges WHERE source_id = ? OR target_id = ?;",
                    (node_id, node_id),
                )
        except duckdb.Error as e:
            logger.error("Error deleting node: %s", e)

    def delete_edge(self, source_id: int, target_id: int):
        """Delete the edge going from ``source_id`` to ``target_id``."""
        try:
            with self.transaction():
                self.conn.execute(
                    "DELETE FROM edges WHERE source_id = ? AND target_id = ?;",
                    (source_id, target_id),
                )
        except duckdb.Error as e:
            logger.error("Error deleting edge: %s", e)

    def create_index(self):
        """Create the HNSW index ``vss_idx`` on ``nodes.vector`` if absent."""
        try:
            self.conn.execute("CREATE INDEX IF NOT EXISTS vss_idx ON nodes USING HNSW(vector);")
        except duckdb.Error as e:
            logger.error("Error creating index: %s", e)

    def nearest_neighbors(self, vector: List[float], limit: int) -> List["Neighbor"]:
        """Return up to ``limit`` nodes ordered by ``array_distance`` to ``vector``.

        Returns:
            ``Neighbor`` objects, closest first; ``[]`` when the vector is
            invalid or the query fails (the reason is logged).
        """
        if not self._validate_vector(vector):
            logger.error("Invalid vector: Must be a list of float values.")
            return []

        dimension = int(self.vector_length)
        query = f"""
            SELECT id, data, vector, array_distance(vector, CAST(? AS FLOAT[{dimension}])) AS distance
            FROM nodes
            ORDER BY distance LIMIT ?;
        """
        # Integer components are accepted by validation; send them as floats.
        query_vector = [float(component) for component in vector]
        try:
            rows = self.conn.execute(query, (query_vector, limit)).fetchall()
        except duckdb.Error as e:
            logger.error("Error fetching nearest neighbors: %s", e)
            return []
        return [
            Neighbor(
                node=Node(id=row[0], data=json.loads(row[1]), vector=row[2]),
                distance=row[3],
            )
            for row in rows
        ]

    def connected_nodes(self, node_id: int) -> List["Node"]:
        """Return the nodes joined to ``node_id`` by an edge in either direction."""
        query = """
            SELECT n.id, n.data, n.vector
            FROM nodes n
            WHERE n.id IN (
                SELECT target_id FROM edges WHERE source_id = CAST(? AS INTEGER)
                UNION
                SELECT source_id FROM edges WHERE target_id = CAST(? AS INTEGER)
            );
        """
        try:
            logger.info("Executing query to fetch connected nodes for node_id: %s", node_id)
            rows = self.conn.execute(query, (node_id, node_id)).fetchall()
        except duckdb.Error as e:
            logger.error("Error fetching connected nodes: %s", e)
            return []
        found = [
            Node(id=row[0], data=json.loads(row[1]), vector=row[2]) for row in rows
        ]
        if found:
            logger.info("Found %d connected nodes.", len(found))
        else:
            logger.info("No connected nodes found.")
        return found

    def nodes_to_json(self):
        """Return every node as a ``{"id", "data", "vector"}`` dict (``[]`` on error)."""
        try:
            rows = self.conn.execute("SELECT id, data, vector FROM nodes;").fetchall()
        except duckdb.Error as e:
            logger.error("Error fetching nodes: %s", e)
            return []
        return [
            {"id": row[0], "data": json.loads(row[1]), "vector": row[2]}
            for row in rows
        ]

    def edges_to_json(self):
        """Return every edge as a dict of its four columns (``[]`` on error)."""
        try:
            rows = self.conn.execute(
                "SELECT source_id, target_id, weight, relation FROM edges;"
            ).fetchall()
        except duckdb.Error as e:
            logger.error("Error fetching edges: %s", e)
            return []
        return [
            {"source_id": row[0], "target_id": row[1], "weight": row[2], "relation": row[3]}
            for row in rows
        ]

    def get_node(self, node_id: int):
        """Return the node ``node_id`` as a dict.

        Returns:
            ``{"id", "data", "vector"}``, or ``{}`` when no such node exists
            or the query fails.
        """
        try:
            row = self.conn.execute(
                "SELECT id, data, vector FROM nodes WHERE id = ?;", (node_id,)
            ).fetchone()
        except duckdb.Error as e:
            logger.error("Error fetching node: %s", e)
            return {}
        if row is None:
            # fetchone() yields None for an unknown id.
            return {}
        return {"id": row[0], "data": json.loads(row[1]), "vector": row[2]}

    def print_json(self):
        """Print all nodes and all edges as indented JSON."""
        nodes_json = self.nodes_to_json()
        edges_json = self.edges_to_json()
        print("Nodes JSON:", json.dumps(nodes_json, indent=2))
        print("Edges JSON:", json.dumps(edges_json, indent=2))

    def _validate_vector(self, vector):
        """Return True when ``vector`` is a list of ``vector_length`` numbers.

        Components may be ``float`` or ``int``; ``bool`` is rejected even
        though it is a subclass of ``int``.
        """
        if not isinstance(vector, list) or len(vector) != self.vector_length:
            return False
        return all(
            isinstance(component, (int, float)) and not isinstance(component, bool)
            for component in vector
        )

    @contextmanager
    def transaction(self):
        """Run the ``with`` body inside BEGIN/COMMIT, rolling back on any error.

        The original exception is always re-raised unchanged.
        """
        # BEGIN stays outside the try block: if it fails there is no
        # transaction to roll back.
        self.conn.execute("BEGIN TRANSACTION;")
        try:
            yield
            self.conn.execute("COMMIT;")
        except BaseException:
            # BaseException so that an interrupt does not leave the
            # transaction open on the connection.
            try:
                self.conn.execute("ROLLBACK;")
            except Exception as rollback_error:
                # Never let a failed rollback hide the original error.
                logger.error("Error rolling back transaction: %s", rollback_error)
            raise

    def close(self):
        """Close the underlying DuckDB connection."""
        self.conn.close()
        logger.info("Database connection closed.")

    def __enter__(self):
        """Return this instance for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection; exceptions from the ``with`` body propagate."""
        self.close()
258
+
@@ -0,0 +1,17 @@
1
+ from pydantic import BaseModel
2
+ from typing import List, Dict, Any, Optional
3
+
4
class Node(BaseModel):
    """A graph node: an optional id, a JSON-serialisable payload and a vector."""

    # Database-assigned identifier. It stays None until the node has been
    # stored, so the annotation is Optional rather than a bare int.
    id: Optional[int] = None
    # Arbitrary payload; GraphRAG stores it with json.dumps().
    data: Dict[str, Any]
    # Vector components; GraphRAG checks the length against its vector_length.
    vector: List[float]
8
+
9
class Edge(BaseModel):
    """An edge linking a source node to a target node."""

    # id of the node the edge starts from.
    source_id: int
    # id of the node the edge points to.
    target_id: int
    # Relationship label, stored in the TEXT column ``relation``.
    relation: str
    # Edge weight, stored in the FLOAT column ``weight``.
    weight: float
14
+
15
class Neighbor(BaseModel):
    """One result of ``GraphRAG.nearest_neighbors``: a node and its distance."""

    # The matched node.
    node: Node
    # ``array_distance`` between the node's vector and the query vector.
    distance: float
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,22 @@
1
+ from setuptools import setup, find_packages
2
+
3
# Read the README with an explicit encoding and close the file right away,
# so the build does not depend on the platform default encoding or leave
# the handle open.
with open('README.md', encoding='utf-8') as readme_file:
    long_description = readme_file.read()

setup(
    name='graphmemory',
    version='0.1.0',
    author='BradAGI',
    author_email='cavemen_summary_0f@icloud.com',
    packages=find_packages(),
    url='http://pypi.python.org/pypi/graphmemory/',
    license='LICENSE.txt',
    description='A package for creating a graph database for use with GraphRAG.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    # Only third-party distributions belong here. json, os and logging are
    # standard-library modules and must not be requested from the package
    # index.
    install_requires=[
        'duckdb',
        'pydantic',
    ],
    keywords='graphrag graph database rag',
)
@@ -0,0 +1,239 @@
1
+ import sys
2
+ import os
3
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
4
+ import duckdb
5
+ import unittest
6
+ from src.graphrag import GraphRAG
7
+ from src.models import Node, Edge, Neighbor
8
+ from pydantic import ValidationError
9
+
10
class TestGraphRAG(unittest.TestCase):
    """Tests for GraphRAG node/edge CRUD, JSON export, search and transactions.

    Every test runs against a fresh in-memory database with 3-component
    vectors, created in ``setUp`` and closed in ``tearDown``.
    """

    def setUp(self):
        self.db = GraphRAG(database=':memory:', vector_length=3)

    def test_insert_node(self):
        node = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node_id = self.db.insert_node(node)
        result = self.db.conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
        self.assertIsNotNone(result)
        self.assertEqual(result[0], node_id)
        self.assertEqual(result[1], '{"name": "node1"}')
        self.assertAlmostEqual(result[2][0], 0.1, places=7)
        self.assertAlmostEqual(result[2][1], 0.2, places=7)
        self.assertAlmostEqual(result[2][2], 0.3, places=7)

    def test_insert_invalid_node(self):
        # Validation fails while the model is built, so no insert call is
        # needed (or reachable) inside the assertRaises block.
        with self.assertRaises(ValidationError):
            Node(data={"name": "node1"}, vector=[0.1, 0.2, "invalid"])

    def test_insert_node_invalid_vector_length(self):
        node = Node(data={"name": "node1"}, vector=[0.1, 0.2])
        result = self.db.insert_node(node)
        self.assertEqual(result, -1)

    def test_insert_edge(self):
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        node1_id = self.db.insert_node(node1)
        node2_id = self.db.insert_node(node2)
        edge = Edge(source_id=node1_id, target_id=node2_id, weight=0.5, relation="friendship")
        self.db.insert_edge(edge)
        result = self.db.conn.execute("SELECT source_id, target_id, weight, relation FROM edges WHERE source_id = ? AND target_id = ?", (node1_id, node2_id)).fetchone()
        self.assertIsNotNone(result)
        self.assertEqual(result[0], node1_id)
        self.assertEqual(result[1], node2_id)
        self.assertEqual(result[2], 0.5)
        self.assertEqual(result[3], "friendship")

    def test_insert_edge_non_existent_nodes(self):
        edge = Edge(source_id=999, target_id=1000, weight=0.5, relation="friendship")
        with self.assertRaises(ValueError):
            self.db.insert_edge(edge)

    def test_nearest_neighbors_empty_db(self):
        neighbors = self.db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
        self.assertEqual(len(neighbors), 0)

    def test_nearest_neighbors(self):
        nodes = [
            Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3]),
            Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6]),
            Node(data={"name": "node3"}, vector=[0.7, 0.8, 0.9]),
        ]
        # Compare against the ids the database actually handed out rather
        # than assuming they are 1 and 2.
        node_ids = [self.db.insert_node(node) for node in nodes]
        self.db.create_index()
        neighbors = self.db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
        self.assertEqual(len(neighbors), 2)
        self.assertEqual(neighbors[0].node.id, node_ids[0])
        self.assertEqual(neighbors[1].node.id, node_ids[1])

    def test_nodes_to_json(self):
        node = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node_id = self.db.insert_node(node)
        nodes_json = self.db.nodes_to_json()
        self.assertEqual(len(nodes_json), 1)
        self.assertEqual(nodes_json[0]['id'], node_id)
        self.assertEqual(nodes_json[0]['data'], {"name": "node1"})
        self.assertAlmostEqual(nodes_json[0]['vector'][0], 0.1, places=7)
        self.assertAlmostEqual(nodes_json[0]['vector'][1], 0.2, places=7)
        self.assertAlmostEqual(nodes_json[0]['vector'][2], 0.3, places=7)

    def test_edges_to_json(self):
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        node1_id = self.db.insert_node(node1)
        node2_id = self.db.insert_node(node2)
        edge = Edge(source_id=node1_id, target_id=node2_id, weight=0.5, relation="friendship")
        self.db.insert_edge(edge)
        edges_json = self.db.edges_to_json()
        self.assertEqual(len(edges_json), 1)
        self.assertEqual(edges_json[0]['source_id'], node1_id)
        self.assertEqual(edges_json[0]['target_id'], node2_id)
        self.assertEqual(edges_json[0]['weight'], 0.5)
        self.assertEqual(edges_json[0]['relation'], "friendship")

    def test_bulk_insert_nodes(self):
        nodes = [
            Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3]),
            Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6]),
        ]
        inserted_nodes = self.db.bulk_insert_nodes(nodes)
        self.assertEqual(len(inserted_nodes), 2)
        self.assertTrue(all(node.id is not None for node in inserted_nodes))

    def test_bulk_insert_edges(self):
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        node3 = Node(data={"name": "node3"}, vector=[0.7, 0.8, 0.9])
        self.db.bulk_insert_nodes([node1, node2, node3])
        # NOTE(review): assumes the bulk insert assigns ids 1..3 in order on
        # an empty database - confirm against bulk_insert_nodes.
        edges = [
            Edge(source_id=1, target_id=2, weight=0.5, relation="friendship"),
            Edge(source_id=2, target_id=3, weight=0.7, relation="colleague"),
        ]
        self.db.bulk_insert_edges(edges)
        result = self.db.conn.execute("SELECT * FROM edges").fetchall()
        self.assertEqual(len(result), 2)

    def test_delete_node(self):
        node = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node_id = self.db.insert_node(node)
        self.db.delete_node(node_id)
        result = self.db.conn.execute("SELECT * FROM nodes WHERE id = ?", (node_id,)).fetchone()
        self.assertIsNone(result)

    def test_delete_non_existent_node(self):
        self.db.delete_node(999)
        result = self.db.conn.execute("SELECT * FROM nodes WHERE id = 999").fetchone()
        self.assertIsNone(result)

    def test_delete_edge(self):
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        node1_id = self.db.insert_node(node1)
        node2_id = self.db.insert_node(node2)
        edge = Edge(source_id=node1_id, target_id=node2_id, weight=0.5, relation="friendship")
        self.db.insert_edge(edge)
        self.db.delete_edge(node1_id, node2_id)
        result = self.db.conn.execute("SELECT * FROM edges WHERE source_id = ? AND target_id = ?", (node1_id, node2_id)).fetchone()
        self.assertIsNone(result)

    def test_delete_non_existent_edge(self):
        self.db.delete_edge(999, 1000)
        result = self.db.conn.execute("SELECT * FROM edges WHERE source_id = 999 AND target_id = 1000").fetchone()
        self.assertIsNone(result)

    def test_transaction_handling(self):
        # assertRaises pins down that the forced error is the one that
        # propagates; a bare except would also hide a failing INSERT and let
        # the test pass for the wrong reason.
        with self.assertRaises(RuntimeError):
            with self.db.transaction():
                self.db.conn.execute("INSERT INTO nodes (data, vector) VALUES ('{\"name\": \"node1\"}', [0.1, 0.2, 0.3]);")
                raise RuntimeError("Force rollback")
        result = self.db.conn.execute("SELECT * FROM nodes WHERE data = '{\"name\": \"node1\"}'").fetchone()
        self.assertIsNone(result)

    def test_auto_increment_node_id(self):
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        id1 = self.db.insert_node(node1)
        id2 = self.db.insert_node(node2)

        self.assertEqual(id1, 1)
        self.assertEqual(id2, 2)

    def test_auto_increment_edge_id(self):
        node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        id1 = self.db.insert_node(node1)
        id2 = self.db.insert_node(node2)

        edge1 = Edge(source_id=id1, target_id=id2, weight=0.5, relation="friendship")
        edge2 = Edge(source_id=id2, target_id=id1, weight=0.7, relation="colleague")
        self.db.insert_edge(edge1)
        self.db.insert_edge(edge2)

        result1 = self.db.conn.execute("SELECT source_id, target_id FROM edges WHERE source_id = ? AND target_id = ?", (id1, id2)).fetchone()
        result2 = self.db.conn.execute("SELECT source_id, target_id FROM edges WHERE source_id = ? AND target_id = ?", (id2, id1)).fetchone()

        self.assertIsNotNone(result1)
        self.assertIsNotNone(result2)
        self.assertEqual(result1[0], id1)
        self.assertEqual(result1[1], id2)
        self.assertEqual(result2[0], id2)
        self.assertEqual(result2[1], id1)

    def tearDown(self):
        self.db.conn.close()
        # Remove export directories if a test left them behind.
        import shutil
        for leftover in ('test_db_parquet', 'test_db_csv'):
            if os.path.exists(leftover):
                shutil.rmtree(leftover)
196
+
197
class TestGraphRAGGetConnectedNodes(unittest.TestCase):
    """Check ``get_connected_nodes`` on a three-node chain: 1 - 2 - 3."""

    def setUp(self):
        # Fresh in-memory database with its tables created explicitly.
        self.db = GraphRAG(database=':memory:', vector_length=3)
        self.db.create_tables()

        # Three nodes, inserted in order.
        self.node1 = Node(data={"name": "node1"}, vector=[0.1, 0.2, 0.3])
        self.node2 = Node(data={"name": "node2"}, vector=[0.4, 0.5, 0.6])
        self.node3 = Node(data={"name": "node3"}, vector=[0.7, 0.8, 0.9])
        self.node1_id = self.db.insert_node(self.node1)
        self.node2_id = self.db.insert_node(self.node2)
        self.node3_id = self.db.insert_node(self.node3)

        # Two edges: node1 -> node2 and node2 -> node3.
        chain = (
            (self.node1_id, self.node2_id, 0.5, "friendship"),
            (self.node2_id, self.node3_id, 0.7, "colleague"),
        )
        for source, target, weight, relation in chain:
            self.db.insert_edge(Edge(source_id=source, target_id=target, weight=weight, relation=relation))

    def _ids_connected_to(self, node_id):
        # Ids of the nodes reported as connected to node_id, in result order.
        return [node.id for node in self.db.get_connected_nodes(node_id)]

    def test_get_connected_nodes(self):
        # node1 is linked only to node2.
        self.assertEqual(self._ids_connected_to(self.node1_id), [self.node2_id])

        # node2 sits in the middle and sees both ends of the chain.
        around_node2 = self._ids_connected_to(self.node2_id)
        self.assertEqual(len(around_node2), 2)
        self.assertIn(self.node1_id, around_node2)
        self.assertIn(self.node3_id, around_node2)

        # node3 is linked only to node2.
        self.assertEqual(self._ids_connected_to(self.node3_id), [self.node2_id])

    def tearDown(self):
        self.db.conn.close()
236
+
237
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
239
+