graphiti-core 0.2.2.tar.gz → 0.2.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. See the package registry page for more details.

Files changed (37)
  1. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/PKG-INFO +2 -3
  2. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/edges.py +39 -32
  3. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/graphiti.py +45 -30
  4. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/nodes.py +40 -39
  5. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/search/search.py +5 -2
  6. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/search/search_utils.py +74 -143
  7. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/bulk_utils.py +31 -3
  8. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/edge_operations.py +7 -5
  9. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/graph_data_operations.py +17 -8
  10. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/node_operations.py +1 -0
  11. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/pyproject.toml +3 -6
  12. graphiti_core-0.2.2/graphiti_core/utils/utils.py +0 -60
  13. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/LICENSE +0 -0
  14. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/README.md +0 -0
  15. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/__init__.py +0 -0
  16. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/helpers.py +0 -0
  17. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/__init__.py +0 -0
  18. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/anthropic_client.py +0 -0
  19. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/client.py +0 -0
  20. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/config.py +0 -0
  21. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/groq_client.py +0 -0
  22. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/openai_client.py +0 -0
  23. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/llm_client/utils.py +0 -0
  24. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/__init__.py +0 -0
  25. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/dedupe_edges.py +0 -0
  26. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  27. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  28. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_edges.py +0 -0
  29. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/extract_nodes.py +0 -0
  30. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/invalidate_edges.py +0 -0
  31. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/lib.py +0 -0
  32. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/prompts/models.py +0 -0
  33. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/search/__init__.py +0 -0
  34. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/__init__.py +0 -0
  35. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/__init__.py +0 -0
  36. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  37. {graphiti_core-0.2.2 → graphiti_core-0.2.3}/graphiti_core/utils/maintenance/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.2.2
3
+ Version: 0.2.3
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -12,11 +12,10 @@ Classifier: Programming Language :: Python :: 3.10
12
12
  Classifier: Programming Language :: Python :: 3.11
13
13
  Classifier: Programming Language :: Python :: 3.12
14
14
  Requires-Dist: diskcache (>=5.6.3,<6.0.0)
15
- Requires-Dist: fastapi (>=0.112.0,<0.113.0)
16
15
  Requires-Dist: neo4j (>=5.23.0,<6.0.0)
16
+ Requires-Dist: numpy (>=2.1.1,<3.0.0)
17
17
  Requires-Dist: openai (>=1.38.0,<2.0.0)
18
18
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
19
- Requires-Dist: sentence-transformers (>=3.0.1,<4.0.0)
20
19
  Requires-Dist: tenacity (<9.0.0)
21
20
  Description-Content-Type: text/markdown
22
21
 
@@ -18,6 +18,7 @@ import logging
18
18
  from abc import ABC, abstractmethod
19
19
  from datetime import datetime
20
20
  from time import time
21
+ from typing import Any
21
22
  from uuid import uuid4
22
23
 
23
24
  from neo4j import AsyncDriver
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
32
33
 
33
34
  class Edge(BaseModel, ABC):
34
35
  uuid: str = Field(default_factory=lambda: uuid4().hex)
36
+ group_id: str | None = Field(description='partition of the graph')
35
37
  source_node_uuid: str
36
38
  target_node_uuid: str
37
39
  created_at: datetime
@@ -61,11 +63,12 @@ class EpisodicEdge(Edge):
61
63
  MATCH (episode:Episodic {uuid: $episode_uuid})
62
64
  MATCH (node:Entity {uuid: $entity_uuid})
63
65
  MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
64
- SET r = {uuid: $uuid, created_at: $created_at}
66
+ SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
65
67
  RETURN r.uuid AS uuid""",
66
68
  episode_uuid=self.source_node_uuid,
67
69
  entity_uuid=self.target_node_uuid,
68
70
  uuid=self.uuid,
71
+ group_id=self.group_id,
69
72
  created_at=self.created_at,
70
73
  )
71
74
 
@@ -92,7 +95,8 @@ class EpisodicEdge(Edge):
92
95
  """
93
96
  MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
94
97
  RETURN
95
- e.uuid As uuid,
98
+ e.uuid As uuid,
99
+ e.group_id AS group_id,
96
100
  n.uuid AS source_node_uuid,
97
101
  m.uuid AS target_node_uuid,
98
102
  e.created_at AS created_at
@@ -100,17 +104,7 @@ class EpisodicEdge(Edge):
100
104
  uuid=uuid,
101
105
  )
102
106
 
103
- edges: list[EpisodicEdge] = []
104
-
105
- for record in records:
106
- edges.append(
107
- EpisodicEdge(
108
- uuid=record['uuid'],
109
- source_node_uuid=record['source_node_uuid'],
110
- target_node_uuid=record['target_node_uuid'],
111
- created_at=record['created_at'].to_native(),
112
- )
113
- )
107
+ edges = [get_episodic_edge_from_record(record) for record in records]
114
108
 
115
109
  logger.info(f'Found Edge: {uuid}')
116
110
 
@@ -153,7 +147,7 @@ class EntityEdge(Edge):
153
147
  MATCH (source:Entity {uuid: $source_uuid})
154
148
  MATCH (target:Entity {uuid: $target_uuid})
155
149
  MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target)
156
- SET r = {uuid: $uuid, name: $name, fact: $fact, fact_embedding: $fact_embedding,
150
+ SET r = {uuid: $uuid, name: $name, group_id: $group_id, fact: $fact, fact_embedding: $fact_embedding,
157
151
  episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
158
152
  valid_at: $valid_at, invalid_at: $invalid_at}
159
153
  RETURN r.uuid AS uuid""",
@@ -161,6 +155,7 @@ class EntityEdge(Edge):
161
155
  target_uuid=self.target_node_uuid,
162
156
  uuid=self.uuid,
163
157
  name=self.name,
158
+ group_id=self.group_id,
164
159
  fact=self.fact,
165
160
  fact_embedding=self.fact_embedding,
166
161
  episodes=self.episodes,
@@ -198,6 +193,7 @@ class EntityEdge(Edge):
198
193
  m.uuid AS target_node_uuid,
199
194
  e.created_at AS created_at,
200
195
  e.name AS name,
196
+ e.group_id AS group_id,
201
197
  e.fact AS fact,
202
198
  e.fact_embedding AS fact_embedding,
203
199
  e.episodes AS episodes,
@@ -208,25 +204,36 @@ class EntityEdge(Edge):
208
204
  uuid=uuid,
209
205
  )
210
206
 
211
- edges: list[EntityEdge] = []
212
-
213
- for record in records:
214
- edges.append(
215
- EntityEdge(
216
- uuid=record['uuid'],
217
- source_node_uuid=record['source_node_uuid'],
218
- target_node_uuid=record['target_node_uuid'],
219
- fact=record['fact'],
220
- name=record['name'],
221
- episodes=record['episodes'],
222
- fact_embedding=record['fact_embedding'],
223
- created_at=record['created_at'].to_native(),
224
- expired_at=parse_db_date(record['expired_at']),
225
- valid_at=parse_db_date(record['valid_at']),
226
- invalid_at=parse_db_date(record['invalid_at']),
227
- )
228
- )
207
+ edges = [get_entity_edge_from_record(record) for record in records]
229
208
 
230
209
  logger.info(f'Found Edge: {uuid}')
231
210
 
232
211
  return edges[0]
212
+
213
+
214
+ # Edge helpers
215
+ def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
216
+ return EpisodicEdge(
217
+ uuid=record['uuid'],
218
+ group_id=record['group_id'],
219
+ source_node_uuid=record['source_node_uuid'],
220
+ target_node_uuid=record['target_node_uuid'],
221
+ created_at=record['created_at'].to_native(),
222
+ )
223
+
224
+
225
+ def get_entity_edge_from_record(record: Any) -> EntityEdge:
226
+ return EntityEdge(
227
+ uuid=record['uuid'],
228
+ source_node_uuid=record['source_node_uuid'],
229
+ target_node_uuid=record['target_node_uuid'],
230
+ fact=record['fact'],
231
+ name=record['name'],
232
+ group_id=record['group_id'],
233
+ episodes=record['episodes'],
234
+ fact_embedding=record['fact_embedding'],
235
+ created_at=record['created_at'].to_native(),
236
+ expired_at=parse_db_date(record['expired_at']),
237
+ valid_at=parse_db_date(record['valid_at']),
238
+ invalid_at=parse_db_date(record['invalid_at']),
239
+ )
@@ -18,7 +18,6 @@ import asyncio
18
18
  import logging
19
19
  from datetime import datetime
20
20
  from time import time
21
- from typing import Callable
22
21
 
23
22
  from dotenv import load_dotenv
24
23
  from neo4j import AsyncGraphDatabase
@@ -120,7 +119,7 @@ class Graphiti:
120
119
 
121
120
  Parameters
122
121
  ----------
123
- None
122
+ self
124
123
 
125
124
  Returns
126
125
  -------
@@ -151,7 +150,7 @@ class Graphiti:
151
150
 
152
151
  Parameters
153
152
  ----------
154
- None
153
+ self
155
154
 
156
155
  Returns
157
156
  -------
@@ -178,6 +177,7 @@ class Graphiti:
178
177
  self,
179
178
  reference_time: datetime,
180
179
  last_n: int = EPISODE_WINDOW_LEN,
180
+ group_ids: list[str | None] | None = None,
181
181
  ) -> list[EpisodicNode]:
182
182
  """
183
183
  Retrieve the last n episodic nodes from the graph.
@@ -191,6 +191,8 @@ class Graphiti:
191
191
  The reference time to retrieve episodes before.
192
192
  last_n : int, optional
193
193
  The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
194
+ group_ids : list[str | None], optional
195
+ The group ids to return data from.
194
196
 
195
197
  Returns
196
198
  -------
@@ -202,7 +204,7 @@ class Graphiti:
202
204
  The actual retrieval is performed by the `retrieve_episodes` function
203
205
  from the `graphiti_core.utils` module.
204
206
  """
205
- return await retrieve_episodes(self.driver, reference_time, last_n)
207
+ return await retrieve_episodes(self.driver, reference_time, last_n, group_ids)
206
208
 
207
209
  async def add_episode(
208
210
  self,
@@ -211,8 +213,8 @@ class Graphiti:
211
213
  source_description: str,
212
214
  reference_time: datetime,
213
215
  source: EpisodeType = EpisodeType.message,
214
- success_callback: Callable | None = None,
215
- error_callback: Callable | None = None,
216
+ group_id: str | None = None,
217
+ uuid: str | None = None,
216
218
  ):
217
219
  """
218
220
  Process an episode and update the graph.
@@ -232,10 +234,10 @@ class Graphiti:
232
234
  The reference time for the episode.
233
235
  source : EpisodeType, optional
234
236
  The type of the episode. Defaults to EpisodeType.message.
235
- success_callback : Callable | None, optional
236
- A callback function to be called upon successful processing.
237
- error_callback : Callable | None, optional
238
- A callback function to be called if an error occurs during processing.
237
+ group_id : str | None
238
+ An id for the graph partition the episode is a part of.
239
+ uuid : str | None
240
+ Optional uuid of the episode.
239
241
 
240
242
  Returns
241
243
  -------
@@ -266,9 +268,12 @@ class Graphiti:
266
268
  embedder = self.llm_client.get_embedder()
267
269
  now = datetime.now()
268
270
 
269
- previous_episodes = await self.retrieve_episodes(reference_time, last_n=3)
271
+ previous_episodes = await self.retrieve_episodes(
272
+ reference_time, last_n=3, group_ids=[group_id]
273
+ )
270
274
  episode = EpisodicNode(
271
275
  name=name,
276
+ group_id=group_id,
272
277
  labels=[],
273
278
  source=source,
274
279
  content=episode_body,
@@ -276,6 +281,7 @@ class Graphiti:
276
281
  created_at=now,
277
282
  valid_at=reference_time,
278
283
  )
284
+ episode.uuid = uuid if uuid is not None else episode.uuid
279
285
 
280
286
  # Extract entities as nodes
281
287
 
@@ -299,7 +305,9 @@ class Graphiti:
299
305
 
300
306
  (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
301
307
  resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
302
- extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
308
+ extract_edges(
309
+ self.llm_client, episode, extracted_nodes, previous_episodes, group_id
310
+ ),
303
311
  )
304
312
  logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
305
313
  nodes.extend(mentioned_nodes)
@@ -388,11 +396,7 @@ class Graphiti:
388
396
 
389
397
  logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')
390
398
 
391
- episodic_edges: list[EpisodicEdge] = build_episodic_edges(
392
- mentioned_nodes,
393
- episode,
394
- now,
395
- )
399
+ episodic_edges: list[EpisodicEdge] = build_episodic_edges(mentioned_nodes, episode, now)
396
400
 
397
401
  logger.info(f'Built episodic edges: {episodic_edges}')
398
402
 
@@ -405,18 +409,10 @@ class Graphiti:
405
409
  end = time()
406
410
  logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
407
411
 
408
- if success_callback:
409
- await success_callback(episode)
410
412
  except Exception as e:
411
- if error_callback:
412
- await error_callback(episode, e)
413
- else:
414
- raise e
413
+ raise e
415
414
 
416
- async def add_episode_bulk(
417
- self,
418
- bulk_episodes: list[RawEpisode],
419
- ):
415
+ async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
420
416
  """
421
417
  Process multiple episodes in bulk and update the graph.
422
418
 
@@ -427,6 +423,8 @@ class Graphiti:
427
423
  ----------
428
424
  bulk_episodes : list[RawEpisode]
429
425
  A list of RawEpisode objects to be processed and added to the graph.
426
+ group_id : str | None
427
+ An id for the graph partition the episode is a part of.
430
428
 
431
429
  Returns
432
430
  -------
@@ -463,6 +461,7 @@ class Graphiti:
463
461
  source=episode.source,
464
462
  content=episode.content,
465
463
  source_description=episode.source_description,
464
+ group_id=group_id,
466
465
  created_at=now,
467
466
  valid_at=episode.reference_time,
468
467
  )
@@ -527,7 +526,13 @@ class Graphiti:
527
526
  except Exception as e:
528
527
  raise e
529
528
 
530
- async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
529
+ async def search(
530
+ self,
531
+ query: str,
532
+ center_node_uuid: str | None = None,
533
+ group_ids: list[str | None] | None = None,
534
+ num_results=10,
535
+ ):
531
536
  """
532
537
  Perform a hybrid search on the knowledge graph.
533
538
 
@@ -540,6 +545,8 @@ class Graphiti:
540
545
  The search query string.
541
546
  center_node_uuid: str, optional
542
547
  Facts will be reranked based on proximity to this node
548
+ group_ids : list[str | None] | None, optional
549
+ The graph partitions to return data from.
543
550
  num_results : int, optional
544
551
  The maximum number of results to return. Defaults to 10.
545
552
 
@@ -562,6 +569,7 @@ class Graphiti:
562
569
  num_episodes=0,
563
570
  num_edges=num_results,
564
571
  num_nodes=0,
572
+ group_ids=group_ids,
565
573
  search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
566
574
  reranker=reranker,
567
575
  )
@@ -590,7 +598,10 @@ class Graphiti:
590
598
  )
591
599
 
592
600
  async def get_nodes_by_query(
593
- self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
601
+ self,
602
+ query: str,
603
+ group_ids: list[str | None] | None = None,
604
+ limit: int = RELEVANT_SCHEMA_LIMIT,
594
605
  ) -> list[EntityNode]:
595
606
  """
596
607
  Retrieve nodes from the graph database based on a text query.
@@ -602,6 +613,8 @@ class Graphiti:
602
613
  ----------
603
614
  query : str
604
615
  The text query to search for in the graph.
616
+ group_ids : list[str | None] | None, optional
617
+ The graph partitions to return data from.
605
618
  limit : int | None, optional
606
619
  The maximum number of results to return per search method.
607
620
  If None, a default limit will be applied.
@@ -626,5 +639,7 @@ class Graphiti:
626
639
  """
627
640
  embedder = self.llm_client.get_embedder()
628
641
  query_embedding = await generate_embedding(embedder, query)
629
- relevant_nodes = await hybrid_node_search([query], [query_embedding], self.driver, limit)
642
+ relevant_nodes = await hybrid_node_search(
643
+ [query], [query_embedding], self.driver, group_ids, limit
644
+ )
630
645
  return relevant_nodes
@@ -19,10 +19,10 @@ from abc import ABC, abstractmethod
19
19
  from datetime import datetime
20
20
  from enum import Enum
21
21
  from time import time
22
+ from typing import Any
22
23
  from uuid import uuid4
23
24
 
24
25
  from neo4j import AsyncDriver
25
- from openai import OpenAI
26
26
  from pydantic import BaseModel, Field
27
27
 
28
28
  from graphiti_core.llm_client.config import EMBEDDING_DIM
@@ -69,6 +69,7 @@ class EpisodeType(Enum):
69
69
  class Node(BaseModel, ABC):
70
70
  uuid: str = Field(default_factory=lambda: uuid4().hex)
71
71
  name: str = Field(description='name of the node')
72
+ group_id: str | None = Field(description='partition of the graph')
72
73
  labels: list[str] = Field(default_factory=list)
73
74
  created_at: datetime = Field(default_factory=lambda: datetime.now())
74
75
 
@@ -106,11 +107,12 @@ class EpisodicNode(Node):
106
107
  result = await driver.execute_query(
107
108
  """
108
109
  MERGE (n:Episodic {uuid: $uuid})
109
- SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content,
110
+ SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
110
111
  entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
111
112
  RETURN n.uuid AS uuid""",
112
113
  uuid=self.uuid,
113
114
  name=self.name,
115
+ group_id=self.group_id,
114
116
  source_description=self.source_description,
115
117
  content=self.content,
116
118
  entity_edges=self.entity_edges,
@@ -141,29 +143,19 @@ class EpisodicNode(Node):
141
143
  records, _, _ = await driver.execute_query(
142
144
  """
143
145
  MATCH (e:Episodic {uuid: $uuid})
144
- RETURN e.content as content,
145
- e.created_at as created_at,
146
- e.valid_at as valid_at,
147
- e.uuid as uuid,
148
- e.name as name,
149
- e.source_description as source_description,
150
- e.source as source
146
+ RETURN e.content AS content,
147
+ e.created_at AS created_at,
148
+ e.valid_at AS valid_at,
149
+ e.uuid AS uuid,
150
+ e.name AS name,
151
+ e.group_id AS group_id
152
+ e.source_description AS source_description,
153
+ e.source AS source
151
154
  """,
152
155
  uuid=uuid,
153
156
  )
154
157
 
155
- episodes = [
156
- EpisodicNode(
157
- content=record['content'],
158
- created_at=record['created_at'].to_native().timestamp(),
159
- valid_at=(record['valid_at'].to_native()),
160
- uuid=record['uuid'],
161
- source=EpisodeType.from_str(record['source']),
162
- name=record['name'],
163
- source_description=record['source_description'],
164
- )
165
- for record in records
166
- ]
158
+ episodes = [get_episodic_node_from_record(record) for record in records]
167
159
 
168
160
  logger.info(f'Found Node: {uuid}')
169
161
 
@@ -174,10 +166,6 @@ class EntityNode(Node):
174
166
  name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
175
167
  summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
176
168
 
177
- async def update_summary(self, driver: AsyncDriver): ...
178
-
179
- async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...
180
-
181
169
  async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
182
170
  start = time()
183
171
  text = self.name.replace('\n', ' ')
@@ -192,10 +180,11 @@ class EntityNode(Node):
192
180
  result = await driver.execute_query(
193
181
  """
194
182
  MERGE (n:Entity {uuid: $uuid})
195
- SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
183
+ SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
196
184
  RETURN n.uuid AS uuid""",
197
185
  uuid=self.uuid,
198
186
  name=self.name,
187
+ group_id=self.group_id,
199
188
  summary=self.summary,
200
189
  name_embedding=self.name_embedding,
201
190
  created_at=self.created_at,
@@ -227,25 +216,14 @@ class EntityNode(Node):
227
216
  n.uuid As uuid,
228
217
  n.name AS name,
229
218
  n.name_embedding AS name_embedding,
219
+ n.group_id AS group_id
230
220
  n.created_at AS created_at,
231
221
  n.summary AS summary
232
222
  """,
233
223
  uuid=uuid,
234
224
  )
235
225
 
236
- nodes: list[EntityNode] = []
237
-
238
- for record in records:
239
- nodes.append(
240
- EntityNode(
241
- uuid=record['uuid'],
242
- name=record['name'],
243
- name_embedding=record['name_embedding'],
244
- labels=['Entity'],
245
- created_at=record['created_at'].to_native(),
246
- summary=record['summary'],
247
- )
248
- )
226
+ nodes = [get_entity_node_from_record(record) for record in records]
249
227
 
250
228
  logger.info(f'Found Node: {uuid}')
251
229
 
@@ -253,3 +231,26 @@ class EntityNode(Node):
253
231
 
254
232
 
255
233
  # Node helpers
234
+ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
235
+ return EpisodicNode(
236
+ content=record['content'],
237
+ created_at=record['created_at'].to_native().timestamp(),
238
+ valid_at=(record['valid_at'].to_native()),
239
+ uuid=record['uuid'],
240
+ group_id=record['group_id'],
241
+ source=EpisodeType.from_str(record['source']),
242
+ name=record['name'],
243
+ source_description=record['source_description'],
244
+ )
245
+
246
+
247
+ def get_entity_node_from_record(record: Any) -> EntityNode:
248
+ return EntityNode(
249
+ uuid=record['uuid'],
250
+ name=record['name'],
251
+ group_id=record['group_id'],
252
+ name_embedding=record['name_embedding'],
253
+ labels=['Entity'],
254
+ created_at=record['created_at'].to_native(),
255
+ summary=record['summary'],
256
+ )
@@ -52,6 +52,7 @@ class SearchConfig(BaseModel):
52
52
  num_edges: int = Field(default=10)
53
53
  num_nodes: int = Field(default=10)
54
54
  num_episodes: int = EPISODE_WINDOW_LEN
55
+ group_ids: list[str | None] | None
55
56
  search_methods: list[SearchMethod]
56
57
  reranker: Reranker | None
57
58
 
@@ -83,7 +84,9 @@ async def hybrid_search(
83
84
  nodes.extend(await get_mentioned_nodes(driver, episodes))
84
85
 
85
86
  if SearchMethod.bm25 in config.search_methods:
86
- text_search = await edge_fulltext_search(driver, query, None, None, 2 * config.num_edges)
87
+ text_search = await edge_fulltext_search(
88
+ driver, query, None, None, config.group_ids, 2 * config.num_edges
89
+ )
87
90
  search_results.append(text_search)
88
91
 
89
92
  if SearchMethod.cosine_similarity in config.search_methods:
@@ -95,7 +98,7 @@ async def hybrid_search(
95
98
  )
96
99
 
97
100
  similarity_search = await edge_similarity_search(
98
- driver, search_vector, None, None, 2 * config.num_edges
101
+ driver, search_vector, None, None, config.group_ids, 2 * config.num_edges
99
102
  )
100
103
  search_results.append(similarity_search)
101
104