graphmemory 0.1.2__tar.gz → 0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.1
2
+ Name: graphmemory
3
+ Version: 0.2
4
+ Summary: A package for creating a hybrid graph / vector database for use with GraphRAG.
5
+ Home-page: https://github.com/bradAGI/GraphMemory
6
+ Author: BradAGI
7
+ Author-email: cavemen_summary_0f@icloud.com
8
+ License: LICENSE.txt
9
+ Keywords: graphrag graph database rag vector database
10
+ Requires-Dist: duckdb==1.0.0
11
+ Requires-Dist: pydantic==2.7.3
@@ -1,7 +1,7 @@
1
1
  # GraphMemory - GraphRAG Database
2
2
 
3
3
  ## Overview
4
- This project provides an embedded graph database implementation with vector similarity search (VSS) using DuckDB. It includes a Python class `GraphRAG` for managing nodes and edges.
4
+ This project provides an embedded graph database implementation with vector similarity search (VSS) using DuckDB. It includes a Python class `GraphMemory` for managing nodes and edges.
5
5
 
6
6
  Each node has a unique ID, a JSON data field (any arbitrary dictionary), and a vector of floating point values.
7
7
 
@@ -12,37 +12,77 @@ This database can be used for any graph-based RAG application or knowledge graph
12
12
  Vector embeddings can be created using [sentence-transformers](https://www.sbert.net/) or other API based models.
13
13
 
14
14
  ## Installation
15
- 1. Clone the repository:
16
- ```sh
17
- git clone https://github.com/bradAGI/GraphMemory
18
- ```
19
- 2. Install the required packages:
20
- ```sh
21
- pip install -r requirements.txt
22
- ```
15
+ ```sh
16
+ pip install graphmemory
17
+ ```
23
18
 
24
19
  ## Usage
25
20
 
26
- ### GraphRAG Class
27
- The `GraphRAG` class provides methods to manage nodes and edges, perform bulk inserts, create indexes, and find nearest neighbors using vector similarity search.
21
+ ### GraphMemory Class
22
+ The `GraphMemory` class provides methods to manage nodes and edges, perform bulk inserts, create indexes, and find nearest neighbors using vector similarity search.
28
23
 
29
24
  ### Auto-Incrementing IDs
30
25
  If you do not provide an ID for a node or edge, the database will automatically assign a unique ID.
31
26
 
32
27
  ### Example
33
28
  ```python
34
- from graphrag.graphrag import GraphRAG
35
- from graphrag.models import Node, Edge
29
+ from graphmemory import GraphMemory
30
+ from graphmemory.models import Node, Edge
31
+
32
+ import json
33
+ from openai import OpenAI
34
+ import os
35
+
36
+ client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
37
+
38
+ # Use an LLM to extract structured data from unstructured text (there are a variety of ways to do this)
39
+ text = "George Washington was the first President of the United States and served from 1789 to 1797."
40
+
41
+ def extract_dict(text):
42
+ return json.loads(client.chat.completions.create(
43
+ model="gpt-3.5-turbo",
44
+ messages=[
45
+ {"role": "system", "content": "Extract structured data from this text: " + text}
46
+ ]
47
+ ).choices[0].message.content)
48
+
49
+
50
+ gw_dict = extract_dict(text)
51
+ print(gw_dict)
52
+
53
+ # Output:
54
+ # {
55
+ # "President": "George Washington",
56
+ # "Position": "First President",
57
+ # "Country": "United States",
58
+ # "Term": "1789-1797"
59
+ # }
60
+
61
+ def calculate_embedding(input_json):
62
+ return client.embeddings.create(
63
+ input=json.dumps(input_json),
64
+ model="text-embedding-3-small"
65
+ ).data[0].embedding
66
+
67
+
68
+ embedding = calculate_embedding(gw_dict)
69
+ print(embedding)
70
+
71
+ # Output:
72
+ # [-0.006929283495992422, -0.005336422007530928, ... (omitted for spacing), 0.04664124920964241, -0.024047505110502243]
36
73
 
37
74
 
38
75
  # Initialize the database from disk (make sure to set vector_length correctly)
39
- graph_db = GraphRAG(database='graph.db', vector_length=3)
76
+ graph_db = GraphMemory(database='graph.db', vector_length=len(embedding))
40
77
 
41
78
  # Insert nodes
42
- node1 = Node(data={"name": "George Washington", "role": "President"}, vector=[0.1, 0.2, 0.3])
79
+ node1 = Node(properties=gw_dict, vector=embedding)
43
80
  node1_id = graph_db.insert_node(node1)
44
81
 
45
- node2 = Node(data={"name": "Thomas Jefferson", "role": "Secretary of State"}, vector=[0.4, 0.5, 0.6])
82
+ text2 = "Thomas Jefferson was the first Secretary of State of the United States and served from 1790 to 1793."
83
+ tj_dict = extract_dict(text2)
84
+
85
+ node2 = Node(properties=tj_dict, vector=calculate_embedding(tj_dict))
46
86
  node2_id = graph_db.insert_node(node2)
47
87
 
48
88
  # Insert edge
@@ -62,49 +102,13 @@ connected_nodes = graph_db.connected_nodes(node1_id)
62
102
  print("Connected Nodes:", connected_nodes)
63
103
 
64
104
  # Find nearest neighbors
65
- neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=1)
105
+ neighbors = graph_db.nearest_nodes(calculate_embedding({"President": "George Washington"}), limit=1)
66
106
  print("Nearest Neighbors:", neighbors)
67
-
68
- # Insert an edge between the two nodes with a relation
69
- edge = Edge(source_id=node1_id, target_id=node2_id, relation="served_under", weight=0.5)
70
- graph_db.insert_edge(edge)
71
-
72
- # Define the additional nodes for bulk insert
73
- nodes = [
74
- Node(data={"name": "Alexander Hamilton", "role": "Secretary of the Treasury", "term": "1789–1795"}, vector=[0.7, 0.8, 0.9]),
75
- Node(data={"name": "Oliver Wolcott Jr.", "role": "Secretary of the Treasury", "term": "1795–1797"}, vector=[1.6, 1.7, 1.8]),
76
- ]
77
-
78
- # Bulk insert nodes
79
- graph_db.bulk_insert_nodes(nodes)
80
-
81
- # Define the additional edges for bulk insert
82
- edges = [
83
- Edge(source_id=nodes[0].id, target_id=nodes[1].id, relation="succeeded_by", weight=0.7),
84
- Edge(source_id=nodes[1].id, target_id=nodes[2].id, relation="succeeded_by", weight=0.8)
85
- ]
86
-
87
- # Bulk insert edges
88
- graph_db.bulk_insert_edges(edges)
89
-
90
- # Delete a node
91
- graph_db.delete_node(nodes[-1].id)
92
-
93
- # Delete an edge
94
- graph_db.delete_edge(1, 2)
95
-
96
- # Find nearest nodes to a given vector by distance
97
- neighbors = graph_db.nearest_neighbors(vector=[0.1, 0.2, 0.3], limit=2)
98
- print("Nearest Neighbors:", neighbors)
99
-
100
- # Find connected nodes
101
- connected_nodes = graph_db.connected_nodes(nodes[1].id)
102
- print("Connected Nodes:", connected_nodes)
103
107
  ```
104
108
 
105
- ## GraphRAG Class Methods
109
+ ## GraphMemory Class Methods
106
110
 
107
- The `GraphRAG` class provides the following public methods for interacting with the graph database:
111
+ The `GraphMemory` class provides the following public methods for interacting with the graph database:
108
112
 
109
113
  1. `__init__(self, database=None, vector_length=3)`
110
114
  - Initializes the database connection and sets up the database schema if necessary.
@@ -136,7 +140,7 @@ The `GraphRAG` class provides the following public methods for interacting with
136
140
  10. `create_index(self)`
137
141
  - Creates an index on the node vectors to improve search performance.
138
142
 
139
- 11. `nearest_neighbors(self, vector: List[float], limit: int) -> List[Neighbor]`
143
+ 11. `nearest_nodes(self, vector: List[float], limit: int) -> List[Neighbor]`
140
144
  - Finds and returns the nearest neighbor nodes based on vector similarity.
141
145
 
142
146
  12. `connected_nodes(self, node_id: int) -> List[Node]`
@@ -0,0 +1,3 @@
1
+ from .database import GraphMemory
2
+ from .models import Node, Edge, NearestNode
3
+
@@ -0,0 +1,429 @@
1
+ import duckdb
2
+ import json
3
+ import os
4
+ import logging
5
+ from contextlib import contextmanager
6
+ from graphmemory.models import Node, Edge, NearestNode
7
+ from typing import List, Any
8
+ from typing import Dict as D
9
+ import uuid
10
+
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class GraphMemory:
    """Embedded graph database with vector similarity search, backed by DuckDB."""

    def __init__(self, database=None, vector_length=3):
        """Open (or create) the database and make sure the schema exists.

        Args:
            database: Path of the DuckDB file; None gives an in-memory database.
            vector_length: Dimensionality of node vectors (FLOAT[n] column width).
        """
        self.database = database
        self.vector_length = vector_length
        self.conn = duckdb.connect(database=self.database)
        self._load_vss_extension()
        self._configure_database()

        # NOTE(review): this ATTACHes the same file the connection above already
        # opened — confirm this is intentional (duckdb may refuse the re-attach).
        if database and os.path.exists(database):
            self.load_database(database)

        # Create the schema unless both tables are already present.
        has_nodes = self.conn.execute(
            "SELECT 1 FROM information_schema.tables WHERE table_name = 'nodes';").fetchone()
        has_edges = self.conn.execute(
            "SELECT 1 FROM information_schema.tables WHERE table_name = 'edges';").fetchone()
        if not (has_nodes and has_edges):
            self.create_tables()
        logger.info("Tables created or verified successfully.")
37
+
38
+ def load_database(self, path):
39
+ if not os.path.exists(path):
40
+ logger.error(f"Database file not found: {path}")
41
+ return
42
+ try:
43
+ self.conn.execute(f"ATTACH DATABASE '{path}' AS vss;")
44
+ except duckdb.Error as e:
45
+ logger.error(f"Error loading database: {e}")
46
+
47
+ def _configure_database(self):
48
+ try:
49
+ self.conn.execute("SET hnsw_enable_experimental_persistence=true;")
50
+ except duckdb.Error as e:
51
+ logger.error(f"Error setting configuration: {e}")
52
+
53
+ def _load_vss_extension(self):
54
+ try:
55
+ self.conn.execute("INSTALL vss;")
56
+ self.conn.execute("LOAD vss;")
57
+ except duckdb.Error as e:
58
+ logger.error(f"Error loading VSS extension: {e}")
59
+
60
    def set_vector_length(self, vector_length):
        """Set the expected dimensionality used by vector validation.

        NOTE(review): changing this does not alter the FLOAT[n] column of an
        already-created `nodes` table — confirm the intended usage pattern.
        """
        self.vector_length = vector_length
        logger.info(f"Vector length set to: {self.vector_length}")
63
+
64
+ def create_tables(self):
65
+ # Correctly format the SQL string to include vector_length
66
+ self.conn.execute(f"""
67
+ CREATE TABLE IF NOT EXISTS nodes (
68
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
69
+ type TEXT,
70
+ properties JSON,
71
+ vector FLOAT[{self.vector_length}]
72
+ );
73
+ """)
74
+ self.conn.execute(f"""
75
+ CREATE TABLE IF NOT EXISTS edges (
76
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
77
+ source_id UUID,
78
+ target_id UUID,
79
+ relation TEXT,
80
+ weight FLOAT,
81
+ FOREIGN KEY (source_id) REFERENCES nodes(id),
82
+ FOREIGN KEY (target_id) REFERENCES nodes(id)
83
+ );
84
+ """)
85
+ logger.info("Tables 'nodes' and 'edges' created or already exist.")
86
+ self.conn.commit()
87
+
88
+ def insert_node(self, node: Node) -> uuid.UUID:
89
+ if node.vector and not self._validate_vector(node.vector):
90
+ logger.error("Invalid vector: Must be a list of float values.")
91
+ return None
92
+ try:
93
+ with self.transaction():
94
+ result = self.conn.execute(
95
+ "INSERT INTO nodes (id, type, properties, vector) VALUES (?, ?, ?, ?) RETURNING id;",
96
+ (str(node.id), node.type, json.dumps(node.properties), node.vector if node.vector else [0.0] * self.vector_length)
97
+ ).fetchone()
98
+ if result:
99
+ logger.info(f"Node inserted with ID: {result[0]}")
100
+ return result[0] if result else None
101
+ except duckdb.Error as e:
102
+ logger.error(f"Error during insert node: {e}")
103
+ return None
104
+
105
+ def insert_edge(self, edge: Edge):
106
+ try:
107
+ with self.transaction():
108
+ # Check if source and target nodes exist
109
+ source_exists = self.conn.execute(
110
+ "SELECT 1 FROM nodes WHERE id = ?", (str(edge.source_id),)).fetchone()
111
+ target_exists = self.conn.execute(
112
+ "SELECT 1 FROM nodes WHERE id = ?", (str(edge.target_id),)).fetchone()
113
+ if not source_exists or not target_exists:
114
+ raise ValueError("Source or target node does not exist.")
115
+
116
+ self.conn.execute("INSERT INTO edges (id, source_id, target_id, relation, weight) VALUES (?, ?, ?, ?, ?);", (
117
+ str(edge.id), str(edge.source_id), str(edge.target_id), edge.relation, edge.weight))
118
+ except duckdb.Error as e:
119
+ logger.error(f"Error during insert edge: {e}")
120
+ except ValueError as e:
121
+ logger.error(f"Error during insert edge: {e}")
122
+ raise
123
+
124
+ def bulk_insert_nodes(self, nodes: List[Node]) -> List[Node]:
125
+ try:
126
+ with self.transaction():
127
+ for node in nodes:
128
+ result = self.conn.execute(
129
+ "INSERT INTO nodes (id, type, properties, vector) VALUES (?, ?, ?, ?) RETURNING id;",
130
+ (str(node.id), node.type, json.dumps(node.properties), node.vector if node.vector else None)
131
+ ).fetchone()
132
+ if result:
133
+ node.id = result[0]
134
+ return nodes
135
+ except duckdb.Error as e:
136
+ logger.error(f"Error during bulk insert nodes: {e}")
137
+ return []
138
+
139
+ def bulk_insert_edges(self, edges: List[Edge]):
140
+ try:
141
+ with self.transaction():
142
+ self.conn.executemany(
143
+ "INSERT INTO edges (id, source_id, target_id, relation, weight) VALUES (?, ?, ?, ?, ?);",
144
+ [(str(edge.id), str(edge.source_id), str(edge.target_id), edge.relation, edge.weight)
145
+ for edge in edges]
146
+ )
147
+ except duckdb.Error as e:
148
+ logger.error(f"Error during bulk insert edges: {e}")
149
+
150
+ def delete_node(self, node_id: uuid.UUID):
151
+ try:
152
+ with self.transaction():
153
+ self.conn.execute(
154
+ "DELETE FROM nodes WHERE id = ?;", (str(node_id),))
155
+ self.conn.execute(
156
+ "DELETE FROM edges WHERE source_id = ? OR target_id = ?;", (str(node_id), str(node_id)))
157
+ except duckdb.Error as e:
158
+ logger.error(f"Error deleting node: {e}")
159
+
160
+ def delete_edge(self, source_id: uuid.UUID, target_id: uuid.UUID):
161
+ try:
162
+ with self.transaction():
163
+ self.conn.execute(
164
+ "DELETE FROM edges WHERE source_id = ? AND target_id = ?;", (str(source_id), str(target_id)))
165
+ except duckdb.Error as e:
166
+ logger.error(f"Error deleting edge: {e}")
167
+
168
+ def create_index(self):
169
+ try:
170
+ self.conn.execute(
171
+ "CREATE INDEX IF NOT EXISTS vss_idx ON nodes USING HNSW(vector);")
172
+ except duckdb.Error as e:
173
+ logger.error(f"Error creating index: {e}")
174
+
175
+ def nearest_nodes(self, vector: List[float], limit: int) -> List[NearestNode]:
176
+ if not self._validate_vector(vector):
177
+ logger.error("Invalid vector: Must be a list of float values.")
178
+ return []
179
+
180
+ query = f"""
181
+ SELECT id, type, properties, vector, array_distance(vector, CAST(? AS FLOAT[{self.vector_length}])) AS distance
182
+ FROM nodes
183
+ WHERE vector IS NOT NULL
184
+ ORDER BY distance LIMIT ?;
185
+ """
186
+ try:
187
+ results = self.conn.execute(query, (vector, limit)).fetchall()
188
+ return [
189
+ NearestNode(
190
+ node=Node(id=row[0], type=row[1], properties=json.loads(row[2]), vector=row[3]),
191
+ distance=row[4]
192
+ ) for row in results
193
+ ]
194
+ except duckdb.Error as e:
195
+ logger.error(f"Error fetching nearest neighbors: {e}")
196
+ return []
197
+
198
+ def connected_nodes(self, node_id: uuid.UUID) -> List[Node]:
199
+ query = """
200
+ SELECT n.id, n.type, n.properties, n.vector
201
+ FROM nodes n
202
+ WHERE n.id IN (
203
+ SELECT target_id FROM edges WHERE source_id = ?
204
+ UNION
205
+ SELECT source_id FROM edges WHERE target_id = ?
206
+ );
207
+ """
208
+ try:
209
+ logger.info(
210
+ f"Executing query to fetch connected nodes for node_id: {node_id}")
211
+ results = self.conn.execute(query, (str(node_id), str(node_id))).fetchall()
212
+ if results:
213
+ connected_nodes = [Node(id=uuid.UUID(str(row[0])), type=row[1], properties=json.loads(row[2]), vector=row[3]) for row in results]
214
+ logger.info(f"Found {len(connected_nodes)} connected nodes.")
215
+ else:
216
+ connected_nodes = []
217
+ logger.info("No connected nodes found.")
218
+ return connected_nodes
219
+ except duckdb.Error as e:
220
+ logger.error(f"Error fetching connected nodes: {e}")
221
+ return []
222
+
223
+ def nodes_to_json(self) -> List[D[str, Any]]:
224
+ try:
225
+ nodes = self.conn.execute(
226
+ "SELECT id, type, properties, vector FROM nodes;").fetchall()
227
+ return [{"id": row[0], "type": row[1], "properties": json.loads(row[2]), "vector": row[3]} for row in nodes]
228
+ except duckdb.Error as e:
229
+ logger.error(f"Error fetching nodes: {e}")
230
+ return []
231
+
232
+ def edges_to_json(self) -> List[D[str, Any]]:
233
+ try:
234
+ edges = self.conn.execute(
235
+ "SELECT id, source_id, target_id, relation, weight FROM edges;").fetchall()
236
+ return [{"id": str(row[0]), "source_id": str(row[1]), "target_id": str(row[2]), "relation": row[3], "weight": row[4]} for row in edges]
237
+ except duckdb.Error as e:
238
+ logger.error(f"Error fetching edges: {e}")
239
+ return []
240
+
241
+ def get_node(self, node_id: uuid.UUID) -> Node:
242
+ try:
243
+ node = self.conn.execute(
244
+ "SELECT id, type, properties, vector FROM nodes WHERE id = ?;", (str(node_id),)).fetchone()
245
+ if node:
246
+ return Node(id=node[0], type=node[1], properties=json.loads(node[2]), vector=node[3])
247
+ else:
248
+ return None
249
+ except duckdb.Error as e:
250
+ logger.error(f"Error fetching node: {e}")
251
+ return None
252
+
253
+ def nodes_by_attribute(self, attribute, value) -> List[Node]:
254
+ try:
255
+ query = f"SELECT id, type, properties, vector FROM nodes WHERE json_extract(properties, '$.{attribute}') = ?;"
256
+ nodes = self.conn.execute(query, (json.dumps(value),)).fetchall()
257
+ if nodes:
258
+ return [Node(id=row[0], type=row[1], properties=json.loads(row[2]), vector=row[3]) for row in nodes]
259
+ else:
260
+ return []
261
+ except duckdb.Error as e:
262
+ logger.error(f"Error fetching nodes: {e}")
263
+ return []
264
+
265
+ def get_nodes_vector(self, node_id: int) -> List[float]:
266
+ try:
267
+ vector = self.conn.execute(
268
+ "SELECT vector FROM nodes WHERE id = ?;", (node_id,)).fetchone()
269
+ return vector[0] if vector else []
270
+ except duckdb.Error as e:
271
+ logger.error(f"Error fetching vector: {e}")
272
+ return []
273
+
274
+ def print_json(self):
275
+ nodes_json = self.nodes_to_json()
276
+ edges_json = self.edges_to_json()
277
+ print("Nodes JSON:", json.dumps(nodes_json, indent=2))
278
+ print("Edges JSON:", json.dumps(edges_json, indent=2))
279
+
280
+ def cypher(self, cypher_query):
281
+ sql_query = self._cypher_to_sql(cypher_query)
282
+ try:
283
+ results = self.conn.execute(sql_query).fetchall()
284
+ logger.debug(f"Query results: {results}")
285
+ return results
286
+ except duckdb.Error as e:
287
+ logger.error(f"Error executing SQL query: {e}")
288
+ return []
289
+
290
+
291
    def _cypher_to_sql(self, cypher_query):
        """Translate a minimal `MATCH ... RETURN ...` Cypher query into SQL.

        Supports node patterns `(alias:Label {k: v})` and relationship
        patterns `[alias:LABEL {k: v}]`, mapping labels to the `type` column
        and properties to json_extract conditions.

        NOTE(review): several generated fragments do not match the schema
        created by create_tables and need confirmation:
          - joins reference `start_node_id`/`end_node_id`, but the edges
            table has `source_id`/`target_id`;
          - the edges table is never added to the FROM clause, so the
            relationship alias is undefined in the generated SQL;
          - relationship filters use `{rel_alias}.type` / `.properties`,
            but edges have `relation` and no properties column;
          - the special-cased `embedding` field corresponds to the `vector`
            column on nodes.
        Values are interpolated (not bound), so this must only be used with
        trusted query strings.
        """
        import re
        import json  # Added import for json
        # Define regex patterns for nodes, relationships, and properties
        node_pattern = re.compile(r"\((\w+)(?::(\w+))?(?:\s*{([^}]+)})?\)")
        rel_pattern = re.compile(r"\[(\w+)?(?::(\w+))?(?:\s*{([^}]+)})?\]")

        # Helper function to parse properties
        def parse_properties(prop_string):
            # Parses "k1: v1, k2: v2" into a dict, coercing ints and floats.
            properties = {}
            if prop_string:
                props = prop_string.split(',')
                for prop in props:
                    key, value = prop.split(':')
                    value = value.strip().strip('"\'')
                    if value.isdigit():
                        value = int(value)
                    elif re.match(r"^\d+?\.\d+?$", value):
                        value = float(value)
                    properties[key.strip()] = value
            return properties

        # Extract MATCH clause
        match_clause = re.search(r'MATCH\s+(.*)\s+RETURN', cypher_query, re.IGNORECASE)
        if not match_clause:
            raise ValueError("Invalid Cypher query: missing MATCH or RETURN clause")
        match_content = match_clause.group(1).strip()

        # Extract RETURN clause
        return_clause = re.search(r'RETURN\s+(.*)', cypher_query, re.IGNORECASE)
        if not return_clause:
            raise ValueError("Invalid Cypher query: missing RETURN clause")
        return_content = return_clause.group(1).strip().split(',')

        # Parse nodes and relationships together
        # Splitting on bracketed segments keeps `[...]` as its own element.
        elements = re.split(r'(\[.*?\])', match_content)

        nodes = []
        relationships = []
        for elem in elements:
            if '(' in elem:
                match = node_pattern.search(elem)
                if match:
                    alias, label, prop_string = match.groups()
                    nodes.append({
                        "alias": alias,
                        "label": label,
                        "properties": parse_properties(prop_string)
                    })
            elif '[' in elem:
                match = rel_pattern.search(elem)
                if match:
                    alias, label, prop_string = match.groups()
                    relationships.append({
                        "alias": alias or f"r{len(relationships)+1}",
                        "label": label,
                        "properties": parse_properties(prop_string)
                    })

        # Start building the SQL query
        sql_query = "SELECT "
        sql_parts = []

        # Determine what is being returned
        # Only `alias.embedding` projections are honored; anything else
        # falls through to SELECT *.
        for item in return_content:
            item = item.strip()
            if '.' in item:
                alias, field = item.split('.')
                if field == "embedding":
                    sql_parts.append(f"{alias}.{field}")
            else:
                sql_parts.append("*")

        if not sql_parts:
            sql_parts.append("*")

        from_clause = []
        where_conditions = []

        # Process nodes and relationships in sequence
        for i, node in enumerate(nodes):
            # dict.values() order relies on the insertion order above
            # (alias, label, properties).
            alias, label, properties = node.values()
            if i == 0:
                from_clause.append(f"nodes AS {alias}")
            else:
                prev_node = nodes[i-1]['alias']
                rel = relationships[i-1]
                rel_alias, rel_label, rel_properties = rel.values()
                from_clause.append(f"JOIN nodes AS {alias} ON {prev_node}.id = {rel_alias}.start_node_id AND {alias}.id = {rel_alias}.end_node_id")

            if label:
                where_conditions.append(f"{alias}.type = '{label}'")
            for prop, val in properties.items():
                if prop == "embedding":
                    sql_parts.append(f"{alias}.embedding")
                else:
                    if isinstance(val, (int, float)):
                        where_conditions.append(f"json_extract({alias}.properties, '$.{prop}') = json('{val}')")
                    else:
                        where_conditions.append(f"json_extract({alias}.properties, '$.{prop}') = json('{json.dumps(val)}')")

        for rel in relationships:
            rel_alias, rel_label, rel_properties = rel.values()
            if rel_label:
                where_conditions.append(f"{rel_alias}.type = '{rel_label}'")
            for prop, val in rel_properties.items():
                if isinstance(val, (int, float)):
                    where_conditions.append(f"json_extract({rel_alias}.properties, '$.{prop}') = json('{val}')")
                else:
                    where_conditions.append(f"json_extract({rel_alias}.properties, '$.{prop}') = json('{json.dumps(val)}')")

        sql_query += ", ".join(sql_parts)
        sql_query += " FROM " + " ".join(from_clause)

        if where_conditions:
            sql_query += " WHERE " + " AND ".join(where_conditions)

        return sql_query + ";"
409
+
410
+
411
+ def _validate_vector(self, vector):
412
+ return isinstance(vector, list) and len(vector) == self.vector_length and all(isinstance(x, float) for x in vector)
413
+
414
+ @contextmanager
415
+ def transaction(self):
416
+ try:
417
+ self.conn.execute("BEGIN TRANSACTION;")
418
+ yield
419
+ self.conn.execute("COMMIT;")
420
+ except Exception as e:
421
+ self.conn.execute("ROLLBACK;")
422
+ raise e
423
+
424
    def __enter__(self):
        # Allow `with GraphMemory(...) as db:` usage.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the duckdb connection when the with-block exits.
        self.conn.close()
        logger.info("Database connection closed.")
@@ -0,0 +1,30 @@
1
+ from pydantic import BaseModel, Field
2
+ from typing import List, Dict, Any, Optional
3
+ import uuid
4
+
5
class GraphEntity(BaseModel):
    """Base model for graph records; supplies a random UUID id by default."""

    id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier for the entity.")

    class Config:
        extra = 'forbid'  # unknown constructor kwargs raise instead of being dropped
10
+
11
class Node(GraphEntity):
    """A graph vertex: JSON properties plus an optional embedding vector."""

    # Arbitrary JSON-serializable payload, stored in the `properties` JSON column.
    properties: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Properties of the entity.")
    # Optional categorization label, mapped to the `type` column.
    type: Optional[str] = Field(default=None, description="Optional label for the node to categorize it, ex: Person")
    # Embedding; empty by default — GraphMemory.insert_node substitutes a zero vector.
    vector: Optional[List[float]] = Field(default_factory=list, description="Vector representation of the node.")

    class Config:
        extra = 'forbid'  # e.g. Node(data=...) raises instead of silently losing data
18
+
19
class Edge(GraphEntity):
    """A directed edge between two nodes, with an optional label and weight."""

    source_id: uuid.UUID
    target_id: uuid.UUID
    # Annotated Optional[str]: the field defaults to None, and pydantic v2
    # does not validate defaults, so the old bare `str` annotation let
    # Edge(...) silently carry a None where a str was promised.
    relation: Optional[str] = Field(default=None, description="Relation between the source and target nodes")
    weight: Optional[float] = Field(default=None, description="Weight of the edge")

    class Config:
        extra = 'forbid'
27
+
28
class NearestNode(BaseModel):
    """One result row of GraphMemory.nearest_nodes: a node and its distance."""

    node: Node
    # Distance produced by duckdb's array_distance (smaller = closer).
    distance: float
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.1
2
+ Name: graphmemory
3
+ Version: 0.2
4
+ Summary: A package for creating a hybrid graph / vector database for use with GraphRAG.
5
+ Home-page: https://github.com/bradAGI/GraphMemory
6
+ Author: BradAGI
7
+ Author-email: cavemen_summary_0f@icloud.com
8
+ License: LICENSE.txt
9
+ Keywords: graphrag graph database rag vector database
10
+ Requires-Dist: duckdb==1.0.0
11
+ Requires-Dist: pydantic==2.7.3
@@ -1,11 +1,11 @@
1
1
  README.md
2
2
  setup.py
3
+ graphmemory/__init__.py
4
+ graphmemory/database.py
5
+ graphmemory/models.py
3
6
  graphmemory.egg-info/PKG-INFO
4
7
  graphmemory.egg-info/SOURCES.txt
5
8
  graphmemory.egg-info/dependency_links.txt
6
9
  graphmemory.egg-info/requires.txt
7
10
  graphmemory.egg-info/top_level.txt
8
- graphrag/__init__.py
9
- graphrag/database.py
10
- graphrag/models.py
11
11
  tests/tests.py
@@ -0,0 +1,2 @@
1
+ duckdb==1.0.0
2
+ pydantic==2.7.3
@@ -0,0 +1 @@
1
+ graphmemory