graphiti-core 0.2.3.tar.gz → 0.3.0.tar.gz

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.


Files changed (42)
  1. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/PKG-INFO +8 -2
  2. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/README.md +6 -0
  3. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/edges.py +68 -29
  4. graphiti_core-0.3.0/graphiti_core/errors.py +18 -0
  5. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/graphiti.py +18 -1
  6. graphiti_core-0.3.0/graphiti_core/llm_client/__init__.py +6 -0
  7. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/anthropic_client.py +9 -1
  8. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/client.py +17 -10
  9. graphiti_core-0.3.0/graphiti_core/llm_client/errors.py +6 -0
  10. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/groq_client.py +4 -0
  11. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/openai_client.py +4 -0
  12. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/nodes.py +144 -20
  13. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/extract_nodes.py +43 -1
  14. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/lib.py +6 -0
  15. graphiti_core-0.3.0/graphiti_core/prompts/summarize_nodes.py +79 -0
  16. graphiti_core-0.3.0/graphiti_core/py.typed +1 -0
  17. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/search/search_utils.py +27 -22
  18. graphiti_core-0.3.0/graphiti_core/utils/maintenance/community_operations.py +155 -0
  19. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/edge_operations.py +20 -2
  20. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/graph_data_operations.py +11 -0
  21. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/node_operations.py +26 -1
  22. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/pyproject.toml +3 -3
  23. graphiti_core-0.2.3/graphiti_core/llm_client/__init__.py +0 -5
  24. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/LICENSE +0 -0
  25. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/__init__.py +0 -0
  26. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/helpers.py +0 -0
  27. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/config.py +0 -0
  28. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/llm_client/utils.py +0 -0
  29. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/__init__.py +0 -0
  30. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/dedupe_edges.py +0 -0
  31. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  32. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  33. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/extract_edges.py +0 -0
  34. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/invalidate_edges.py +0 -0
  35. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/prompts/models.py +0 -0
  36. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/search/__init__.py +0 -0
  37. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/search/search.py +0 -0
  38. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/__init__.py +0 -0
  39. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/bulk_utils.py +0 -0
  40. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/__init__.py +0 -0
  41. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  42. {graphiti_core-0.2.3 → graphiti_core-0.3.0}/graphiti_core/utils/maintenance/utils.py +0 -0

--- graphiti_core-0.2.3/PKG-INFO
+++ graphiti_core-0.3.0/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: graphiti-core
-Version: 0.2.3
+Version: 0.3.0
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk
@@ -13,7 +13,7 @@ Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Requires-Dist: diskcache (>=5.6.3,<6.0.0)
 Requires-Dist: neo4j (>=5.23.0,<6.0.0)
-Requires-Dist: numpy (>=2.1.1,<3.0.0)
+Requires-Dist: numpy (>=1.0.0)
 Requires-Dist: openai (>=1.38.0,<2.0.0)
 Requires-Dist: pydantic (>=2.8.2,<3.0.0)
 Requires-Dist: tenacity (<9.0.0)
@@ -170,6 +170,12 @@ await graphiti.search('Who was the California Attorney General?', center_node_uu
 graphiti.close()
 ```
 
+## Graph Service
+
+The `server` directory contains an API service for interacting with the Graphiti API. It is built using FastAPI.
+
+Please see the [server README](./server/README.md) for more information.
+
 ## Documentation
 
 - [Guides and API documentation](https://help.getzep.com/graphiti).

--- graphiti_core-0.2.3/README.md
+++ graphiti_core-0.3.0/README.md
@@ -149,6 +149,12 @@ await graphiti.search('Who was the California Attorney General?', center_node_uu
 graphiti.close()
 ```
 
+## Graph Service
+
+The `server` directory contains an API service for interacting with the Graphiti API. It is built using FastAPI.
+
+Please see the [server README](./server/README.md) for more information.
+
 ## Documentation
 
 - [Guides and API documentation](https://help.getzep.com/graphiti).

--- graphiti_core-0.2.3/graphiti_core/edges.py
+++ graphiti_core-0.3.0/graphiti_core/edges.py
@@ -24,6 +24,7 @@ from uuid import uuid4
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
 
+from graphiti_core.errors import EdgeNotFoundError
 from graphiti_core.helpers import parse_db_date
 from graphiti_core.llm_client.config import EMBEDDING_DIM
 from graphiti_core.nodes import Node
@@ -41,8 +42,18 @@ class Edge(BaseModel, ABC):
     @abstractmethod
     async def save(self, driver: AsyncDriver): ...
 
-    @abstractmethod
-    async def delete(self, driver: AsyncDriver): ...
+    async def delete(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MATCH (n)-[e {uuid: $uuid}]->(m)
+        DELETE e
+        """,
+            uuid=self.uuid,
+        )
+
+        logger.info(f'Deleted Edge: {self.uuid}')
+
+        return result
 
     def __hash__(self):
         return hash(self.uuid)
@@ -76,19 +87,6 @@ class EpisodicEdge(Edge):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
-            """
-        MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
-        DELETE e
-        """,
-            uuid=self.uuid,
-        )
-
-        logger.info(f'Deleted Edge: {self.uuid}')
-
-        return result
-
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -107,7 +105,8 @@ class EpisodicEdge(Edge):
         edges = [get_episodic_edge_from_record(record) for record in records]
 
         logger.info(f'Found Edge: {uuid}')
-
+        if len(edges) == 0:
+            raise EdgeNotFoundError(uuid)
         return edges[0]
 
 
@@ -169,19 +168,6 @@ class EntityEdge(Edge):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
-            """
-        MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
-        DELETE e
-        """,
-            uuid=self.uuid,
-        )
-
-        logger.info(f'Deleted Edge: {self.uuid}')
-
-        return result
-
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
@@ -206,6 +192,49 @@ class EntityEdge(Edge):
 
         edges = [get_entity_edge_from_record(record) for record in records]
 
+        logger.info(f'Found Edge: {uuid}')
+        if len(edges) == 0:
+            raise EdgeNotFoundError(uuid)
+        return edges[0]
+
+
+class CommunityEdge(Edge):
+    async def save(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MATCH (community:Community {uuid: $community_uuid})
+        MATCH (node:Entity | Community {uuid: $entity_uuid})
+        MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
+        SET r = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
+        RETURN r.uuid AS uuid""",
+            community_uuid=self.source_node_uuid,
+            entity_uuid=self.target_node_uuid,
+            uuid=self.uuid,
+            group_id=self.group_id,
+            created_at=self.created_at,
+        )
+
+        logger.info(f'Saved edge to neo4j: {self.uuid}')
+
+        return result
+
+    @classmethod
+    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m:Entity | Community)
+        RETURN
+            e.uuid As uuid,
+            e.group_id AS group_id,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            e.created_at AS created_at
+        """,
+            uuid=uuid,
+        )
+
+        edges = [get_community_edge_from_record(record) for record in records]
+
         logger.info(f'Found Edge: {uuid}')
 
         return edges[0]
@@ -237,3 +266,13 @@ def get_entity_edge_from_record(record: Any) -> EntityEdge:
         valid_at=parse_db_date(record['valid_at']),
         invalid_at=parse_db_date(record['invalid_at']),
     )
+
+
+def get_community_edge_from_record(record: Any):
+    return CommunityEdge(
+        uuid=record['uuid'],
+        group_id=record['group_id'],
+        source_node_uuid=record['source_node_uuid'],
+        target_node_uuid=record['target_node_uuid'],
+        created_at=record['created_at'].to_native(),
+    )

--- /dev/null
+++ graphiti_core-0.3.0/graphiti_core/errors.py
@@ -0,0 +1,18 @@
+class GraphitiError(Exception):
+    """Base exception class for Graphiti Core."""
+
+
+class EdgeNotFoundError(GraphitiError):
+    """Raised when an edge is not found."""
+
+    def __init__(self, uuid: str):
+        self.message = f'edge {uuid} not found'
+        super().__init__(self.message)
+
+
+class NodeNotFoundError(GraphitiError):
+    """Raised when a node is not found."""
+
+    def __init__(self, uuid: str):
+        self.message = f'node {uuid} not found'
+        super().__init__(self.message)
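
The new typed errors replace what previously fell through to `edges[0]` on an empty result and surfaced as an IndexError. A minimal sketch of how a caller might handle them (connection details are placeholders, not part of this release):

```python
from neo4j import AsyncGraphDatabase

from graphiti_core.edges import EntityEdge
from graphiti_core.errors import EdgeNotFoundError

# Placeholder credentials -- substitute your own Neo4j instance.
driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))


async def fetch_edge(uuid: str) -> EntityEdge | None:
    try:
        # As of 0.3.0, get_by_uuid raises EdgeNotFoundError for a missing
        # uuid instead of failing on an empty result list.
        return await EntityEdge.get_by_uuid(driver, uuid)
    except EdgeNotFoundError:
        return None
```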

--- graphiti_core-0.2.3/graphiti_core/graphiti.py
+++ graphiti_core-0.3.0/graphiti_core/graphiti.py
@@ -46,6 +46,10 @@ from graphiti_core.utils.bulk_utils import (
     resolve_edge_pointers,
     retrieve_previous_episodes_bulk,
 )
+from graphiti_core.utils.maintenance.community_operations import (
+    build_communities,
+    remove_communities,
+)
 from graphiti_core.utils.maintenance.edge_operations import (
     extract_edges,
     resolve_extracted_edges,
@@ -412,7 +416,7 @@ class Graphiti:
         except Exception as e:
             raise e
 
-    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None):
+    async def add_episode_bulk(self, bulk_episodes: list[RawEpisode], group_id: str | None = None):
         """
         Process multiple episodes in bulk and update the graph.
 
@@ -526,6 +530,19 @@ class Graphiti:
         except Exception as e:
             raise e
 
+    async def build_communities(self):
+        embedder = self.llm_client.get_embedder()
+
+        # Clear existing communities
+        await remove_communities(self.driver)
+
+        community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
+
+        await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
+
+        await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
+        await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
+
     async def search(
         self,
         query: str,
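
A sketch of how the new community step might be driven from application code, assuming the Graphiti constructor still takes a Neo4j URI plus credentials as in the README example:

```python
import asyncio

from graphiti_core.graphiti import Graphiti


async def main():
    # Placeholder connection details -- substitute your own.
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    # ... ingest episodes via add_episode / add_episode_bulk ...

    # Rebuild communities from scratch: existing Community nodes are removed,
    # clusters are recomputed with GDS Leiden, then community nodes and
    # HAS_MEMBER edges are embedded and saved.
    await graphiti.build_communities()

    graphiti.close()


asyncio.run(main())
```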

--- /dev/null
+++ graphiti_core-0.3.0/graphiti_core/llm_client/__init__.py
@@ -0,0 +1,6 @@
+from .client import LLMClient
+from .config import LLMConfig
+from .errors import RateLimitError
+from .openai_client import OpenAIClient
+
+__all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig', 'RateLimitError']

--- graphiti_core-0.2.3/graphiti_core/llm_client/anthropic_client.py
+++ graphiti_core-0.3.0/graphiti_core/llm_client/anthropic_client.py
@@ -18,12 +18,14 @@ import json
 import logging
 import typing
 
+import anthropic
 from anthropic import AsyncAnthropic
 from openai import AsyncOpenAI
 
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
@@ -35,7 +37,11 @@ class AnthropicClient(LLMClient):
         if config is None:
             config = LLMConfig()
         super().__init__(config, cache)
-        self.client = AsyncAnthropic(api_key=config.api_key)
+        self.client = AsyncAnthropic(
+            api_key=config.api_key,
+            # we'll use tenacity to retry
+            max_retries=1,
+        )
 
     def get_embedder(self) -> typing.Any:
         openai_client = AsyncOpenAI()
@@ -58,6 +64,8 @@ class AnthropicClient(LLMClient):
             )
 
             return json.loads('{' + result.content[0].text)  # type: ignore
+        except anthropic.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise

--- graphiti_core-0.2.3/graphiti_core/llm_client/client.py
+++ graphiti_core-0.3.0/graphiti_core/llm_client/client.py
@@ -22,10 +22,11 @@ from abc import ABC, abstractmethod
 
 import httpx
 from diskcache import Cache
-from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
+from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 
 from ..prompts.models import Message
 from .config import LLMConfig
+from .errors import RateLimitError
 
 DEFAULT_TEMPERATURE = 0
 DEFAULT_CACHE_DIR = './llm_cache'
@@ -33,7 +34,10 @@ DEFAULT_CACHE_DIR = './llm_cache'
 logger = logging.getLogger(__name__)
 
 
-def is_server_error(exception):
+def is_server_or_retry_error(exception):
+    if isinstance(exception, RateLimitError):
+        return True
+
     return (
         isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
     )
@@ -56,18 +60,21 @@
         pass
 
     @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception(is_server_error),
+        stop=stop_after_attempt(4),
+        wait=wait_random_exponential(multiplier=10, min=5, max=120),
+        retry=retry_if_exception(is_server_or_retry_error),
+        after=lambda retry_state: logger.warning(
+            f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
+        )
+        if retry_state.attempt_number > 1
+        else None,
+        reraise=True,
    )
    async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
        try:
            return await self._generate_response(messages)
-        except httpx.HTTPStatusError as e:
-            if not is_server_error(e):
-                raise Exception(f'LLM request error: {e}') from e
-            else:
-                raise
+        except (httpx.HTTPStatusError, RateLimitError) as e:
+            raise e
 
    @abstractmethod
    async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
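
For context, the new retry policy in isolation: `wait_random_exponential` draws a jittered wait from an exponentially widening window (capped here at 120 s), and `reraise=True` surfaces the original exception once attempts are exhausted. A self-contained sketch with a stand-in error class (not Graphiti code; the real waits are long, so this takes a while if run):

```python
import asyncio

from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential


class RateLimitError(Exception):
    """Stand-in for graphiti_core.llm_client.errors.RateLimitError."""


def is_server_or_retry_error(exception: BaseException) -> bool:
    # Same shape as the predicate above, minus the httpx 5xx check.
    return isinstance(exception, RateLimitError)


attempts = 0


@retry(
    stop=stop_after_attempt(4),  # one initial call plus up to three retries
    wait=wait_random_exponential(multiplier=10, min=5, max=120),  # jittered backoff
    retry=retry_if_exception(is_server_or_retry_error),
    reraise=True,  # surface the original RateLimitError when attempts run out
)
async def flaky_llm_call() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise RateLimitError('429: slow down')  # hypothetical provider response
    return 'ok'


print(asyncio.run(flaky_llm_call()), attempts)  # 'ok' 3 -- succeeds on the third try
```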

--- /dev/null
+++ graphiti_core-0.3.0/graphiti_core/llm_client/errors.py
@@ -0,0 +1,6 @@
+class RateLimitError(Exception):
+    """Exception raised when the rate limit is exceeded."""
+
+    def __init__(self, message='Rate limit exceeded. Please try again later.'):
+        self.message = message
+        super().__init__(self.message)

--- graphiti_core-0.2.3/graphiti_core/llm_client/groq_client.py
+++ graphiti_core-0.3.0/graphiti_core/llm_client/groq_client.py
@@ -18,6 +18,7 @@ import json
 import logging
 import typing
 
+import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
 from openai import AsyncOpenAI
@@ -25,6 +26,7 @@ from openai import AsyncOpenAI
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +61,8 @@ class GroqClient(LLMClient):
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except groq.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise

--- graphiti_core-0.2.3/graphiti_core/llm_client/openai_client.py
+++ graphiti_core-0.3.0/graphiti_core/llm_client/openai_client.py
@@ -18,12 +18,14 @@ import json
 import logging
 import typing
 
+import openai
 from openai import AsyncOpenAI
 from openai.types.chat import ChatCompletionMessageParam
 
 from ..prompts.models import Message
 from .client import LLMClient
 from .config import LLMConfig
+from .errors import RateLimitError
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +61,8 @@ class OpenAIClient(LLMClient):
             )
             result = response.choices[0].message.content or ''
             return json.loads(result)
+        except openai.RateLimitError as e:
+            raise RateLimitError from e
         except Exception as e:
             logger.error(f'Error in generating LLM response: {e}')
             raise

--- graphiti_core-0.2.3/graphiti_core/nodes.py
+++ graphiti_core-0.3.0/graphiti_core/nodes.py
@@ -25,6 +25,7 @@ from uuid import uuid4
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field
 
+from graphiti_core.errors import NodeNotFoundError
 from graphiti_core.llm_client.config import EMBEDDING_DIM
 
 logger = logging.getLogger(__name__)
@@ -76,8 +77,18 @@ class Node(BaseModel, ABC):
     @abstractmethod
     async def save(self, driver: AsyncDriver): ...
 
-    @abstractmethod
-    async def delete(self, driver: AsyncDriver): ...
+    async def delete(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MATCH (n {uuid: $uuid})
+        DETACH DELETE n
+        """,
+            uuid=self.uuid,
+        )
+
+        logger.info(f'Deleted Node: {self.uuid}')
+
+        return result
 
     def __hash__(self):
         return hash(self.uuid)
@@ -90,6 +101,9 @@ class Node(BaseModel, ABC):
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
 
+    @classmethod
+    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
+
 
 class EpisodicNode(Node):
     source: EpisodeType = Field(description='source type')
@@ -125,24 +139,37 @@ class EpisodicNode(Node):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
+    @classmethod
+    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+        records, _, _ = await driver.execute_query(
             """
-        MATCH (n:Episodic {uuid: $uuid})
-        DETACH DELETE n
+        MATCH (e:Episodic {uuid: $uuid})
+        RETURN e.content AS content,
+            e.created_at AS created_at,
+            e.valid_at AS valid_at,
+            e.uuid AS uuid,
+            e.name AS name,
+            e.group_id AS group_id,
+            e.source_description AS source_description,
+            e.source AS source
         """,
-            uuid=self.uuid,
+            uuid=uuid,
         )
 
-        logger.info(f'Deleted Node: {self.uuid}')
+        episodes = [get_episodic_node_from_record(record) for record in records]
 
-        return result
+        logger.info(f'Found Node: {uuid}')
+
+        if len(episodes) == 0:
+            raise NodeNotFoundError(uuid)
+
+        return episodes[0]
 
     @classmethod
-    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
         records, _, _ = await driver.execute_query(
             """
-        MATCH (e:Episodic {uuid: $uuid})
+        MATCH (e:Episodic) WHERE e.uuid IN $uuids
         RETURN e.content AS content,
             e.created_at AS created_at,
            e.valid_at AS valid_at,
@@ -152,14 +179,14 @@ class EpisodicNode(Node):
             e.source_description AS source_description,
             e.source AS source
         """,
-            uuid=uuid,
+            uuids=uuids,
         )
 
         episodes = [get_episodic_node_from_record(record) for record in records]
 
-        logger.info(f'Found Node: {uuid}')
+        logger.info(f'Found Nodes: {uuids}')
 
-        return episodes[0]
+        return episodes
 
 
 class EntityNode(Node):
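
A small sketch of the new batch accessor in use; one query fetches a whole batch instead of a round trip per uuid (connection details are placeholders):

```python
from neo4j import AsyncGraphDatabase

from graphiti_core.nodes import EntityNode

driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))


async def load_cluster(uuids: list[str]) -> list[EntityNode]:
    # Single round trip for the whole batch; uuids with no matching node are
    # simply absent from the result (no NodeNotFoundError is raised here).
    return await EntityNode.get_by_uuids(driver, uuids)
```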
@@ -194,24 +221,88 @@ class EntityNode(Node):
 
         return result
 
-    async def delete(self, driver: AsyncDriver):
-        result = await driver.execute_query(
+    @classmethod
+    async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
+        records, _, _ = await driver.execute_query(
             """
         MATCH (n:Entity {uuid: $uuid})
-        DETACH DELETE n
+        RETURN
+            n.uuid As uuid,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.group_id AS group_id
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+            uuid=uuid,
+        )
+
+        nodes = [get_entity_node_from_record(record) for record in records]
+
+        logger.info(f'Found Node: {uuid}')
+
+        return nodes[0]
+
+    @classmethod
+    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (n:Entity) WHERE n.uuid IN $uuids
+        RETURN
+            n.uuid As uuid,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.group_id AS group_id,
+            n.created_at AS created_at,
+            n.summary AS summary
         """,
+            uuids=uuids,
+        )
+
+        nodes = [get_entity_node_from_record(record) for record in records]
+
+        logger.info(f'Found Nodes: {uuids}')
+
+        return nodes
+
+
+class CommunityNode(Node):
+    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
+    summary: str = Field(description='region summary of member nodes', default_factory=str)
+
+    async def save(self, driver: AsyncDriver):
+        result = await driver.execute_query(
+            """
+        MERGE (n:Community {uuid: $uuid})
+        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
+        RETURN n.uuid AS uuid""",
             uuid=self.uuid,
+            name=self.name,
+            group_id=self.group_id,
+            summary=self.summary,
+            name_embedding=self.name_embedding,
+            created_at=self.created_at,
         )
 
-        logger.info(f'Deleted Node: {self.uuid}')
+        logger.info(f'Saved Node to neo4j: {self.uuid}')
 
         return result
 
+    async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
+        start = time()
+        text = self.name.replace('\n', ' ')
+        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
+        self.name_embedding = embedding[:EMBEDDING_DIM]
+        end = time()
+        logger.info(f'embedded {text} in {end - start} ms')
+
+        return embedding
+
     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
         records, _, _ = await driver.execute_query(
             """
-        MATCH (n:Entity {uuid: $uuid})
+        MATCH (n:Community {uuid: $uuid})
         RETURN
             n.uuid As uuid,
             n.name AS name,
@@ -223,12 +314,34 @@ class EntityNode(Node):
             uuid=uuid,
         )
 
-        nodes = [get_entity_node_from_record(record) for record in records]
+        nodes = [get_community_node_from_record(record) for record in records]
 
         logger.info(f'Found Node: {uuid}')
 
         return nodes[0]
 
+    @classmethod
+    async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
+        records, _, _ = await driver.execute_query(
+            """
+        MATCH (n:Community) WHERE n.uuid IN $uuids
+        RETURN
+            n.uuid As uuid,
+            n.name AS name,
+            n.name_embedding AS name_embedding,
+            n.group_id AS group_id
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+            uuids=uuids,
+        )
+
+        nodes = [get_community_node_from_record(record) for record in records]
+
+        logger.info(f'Found Nodes: {uuids}')
+
+        return nodes
+
 
 # Node helpers
 def get_episodic_node_from_record(record: Any) -> EpisodicNode:
@@ -254,3 +367,14 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
         created_at=record['created_at'].to_native(),
         summary=record['summary'],
     )
+
+
+def get_community_node_from_record(record: Any) -> CommunityNode:
+    return CommunityNode(
+        uuid=record['uuid'],
+        name=record['name'],
+        group_id=record['group_id'],
+        name_embedding=record['name_embedding'],
+        created_at=record['created_at'].to_native(),
+        summary=record['summary'],
+    )
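
`CommunityNode.generate_name_embedding` follows the existing entity pattern: embed the name, then truncate to the index dimension. A standalone sketch of that pattern; the 1024-dimension assumption comes from the `community_name_embedding` vector index created later in this release, and the model name is the default shown in the hunk:

```python
from openai import AsyncOpenAI

EMBEDDING_DIM = 1024  # assumed to match the 1024-dim vector indexes below


async def embed_name(name: str) -> list[float]:
    embedder = AsyncOpenAI().embeddings
    text = name.replace('\n', ' ')
    response = await embedder.create(input=[text], model='text-embedding-3-small')
    # Truncate so the stored vector fits the Neo4j vector index dimensions.
    return response.data[0].embedding[:EMBEDDING_DIM]
```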

--- graphiti_core-0.2.3/graphiti_core/prompts/extract_nodes.py
+++ graphiti_core-0.3.0/graphiti_core/prompts/extract_nodes.py
@@ -24,12 +24,14 @@ class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
     extract_json: PromptVersion
+    extract_text: PromptVersion
 
 
 class Versions(TypedDict):
     v1: PromptFunction
     v2: PromptFunction
     extract_json: PromptFunction
+    extract_text: PromptFunction
 
 
 def v1(context: dict[str, Any]) -> list[Message]:
@@ -144,4 +146,44 @@
     ]
 
 
-versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}
+def extract_text(context: dict[str, Any]) -> list[Message]:
+    sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
+
+    user_prompt = f"""
+Given the following conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:
+
+Conversation:
+{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
+<CURRENT MESSAGE>
+{context["episode_content"]}
+
+Guidelines:
+2. Extract significant entities, concepts, or actors mentioned in the conversation.
+3. Provide concise but informative summaries for each extracted node.
+4. Avoid creating nodes for relationships or actions.
+5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
+6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
+
+Respond with a JSON object in the following format:
+{{
+    "extracted_nodes": [
+        {{
+            "name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
+            "labels": ["Entity", "OptionalAdditionalLabel"],
+            "summary": "Brief summary of the node's role or significance"
+        }}
+    ]
+}}
+"""
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
+
+
+versions: Versions = {
+    'v1': v1,
+    'v2': v2,
+    'extract_json': extract_json,
+    'extract_text': extract_text,
+}

--- graphiti_core-0.2.3/graphiti_core/prompts/lib.py
+++ graphiti_core-0.3.0/graphiti_core/prompts/lib.py
@@ -71,6 +71,9 @@ from .invalidate_edges import (
     versions as invalidate_edges_versions,
 )
 from .models import Message, PromptFunction
+from .summarize_nodes import Prompt as SummarizeNodesPrompt
+from .summarize_nodes import Versions as SummarizeNodesVersions
+from .summarize_nodes import versions as summarize_nodes_versions
 
 
 class PromptLibrary(Protocol):
@@ -80,6 +83,7 @@ class PromptLibrary(Protocol):
     dedupe_edges: DedupeEdgesPrompt
     invalidate_edges: InvalidateEdgesPrompt
     extract_edge_dates: ExtractEdgeDatesPrompt
+    summarize_nodes: SummarizeNodesPrompt
 
 
 class PromptLibraryImpl(TypedDict):
@@ -89,6 +93,7 @@ class PromptLibraryImpl(TypedDict):
     dedupe_edges: DedupeEdgesVersions
     invalidate_edges: InvalidateEdgesVersions
     extract_edge_dates: ExtractEdgeDatesVersions
+    summarize_nodes: SummarizeNodesVersions
 
 
 class VersionWrapper:
@@ -118,5 +123,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
     'dedupe_edges': dedupe_edges_versions,
     'invalidate_edges': invalidate_edges_versions,
     'extract_edge_dates': extract_edge_dates_versions,
+    'summarize_nodes': summarize_nodes_versions,
 }
 prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)  # type: ignore[assignment]

--- /dev/null
+++ graphiti_core-0.3.0/graphiti_core/prompts/summarize_nodes.py
@@ -0,0 +1,79 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+from typing import Any, Protocol, TypedDict
+
+from .models import Message, PromptFunction, PromptVersion
+
+
+class Prompt(Protocol):
+    summarize_pair: PromptVersion
+    summary_description: PromptVersion
+
+
+class Versions(TypedDict):
+    summarize_pair: PromptFunction
+    summary_description: PromptFunction
+
+
+def summarize_pair(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that combines summaries.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        Synthesize the information from the following two summaries into a single succinct summary.
+
+        Summaries:
+        {json.dumps(context['node_summaries'], indent=2)}
+
+        Respond with a JSON object in the following format:
+        {{
+            "summary": "Summary containing the important information from both summaries"
+        }}
+        """,
+        ),
+    ]
+
+
+def summary_description(context: dict[str, Any]) -> list[Message]:
+    return [
+        Message(
+            role='system',
+            content='You are a helpful assistant that describes provided contents in a single sentence.',
+        ),
+        Message(
+            role='user',
+            content=f"""
+        Create a short one sentence description of the summary that explains what kind of information is summarized.
+
+        Summary:
+        {json.dumps(context['summary'], indent=2)}
+
+        Respond with a JSON object in the following format:
+        {{
+            "description": "One sentence description of the provided summary"
+        }}
+        """,
+        ),
+    ]
+
+
+versions: Versions = {'summarize_pair': summarize_pair, 'summary_description': summary_description}

--- /dev/null
+++ graphiti_core-0.3.0/graphiti_core/py.typed
@@ -0,0 +1 @@
+# This file is intentionally left empty to indicate that the package is typed.

--- graphiti_core-0.2.3/graphiti_core/search/search_utils.py
+++ graphiti_core-0.3.0/graphiti_core/search/search_utils.py
@@ -496,34 +496,39 @@ async def node_distance_reranker(
     sorted_uuids = rrf(results)
     scores: dict[str, float] = {}
 
-    for uuid in sorted_uuids:
-        # Find the shortest path to center node
-        records, _, _ = await driver.execute_query(
-            """
+    # Find the shortest path to center node
+    query = Query("""
         MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
-        MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
-        WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
-        RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
-        """,
-            edge_uuid=uuid,
-            center_uuid=center_node_uuid,
-        )
-        distance = 0.01
+        MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
+        RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
+        """)
 
-        for record in records:
-            if (
-                record['source_uuid'] == center_node_uuid
-                or record['target_uuid'] == center_node_uuid
-            ):
-                continue
-            distance = record['score']
+    path_results = await asyncio.gather(
+        *[
+            driver.execute_query(
+                query,
+                edge_uuid=uuid,
+                center_uuid=center_node_uuid,
+            )
+            for uuid in sorted_uuids
+        ]
+    )
+
+    for uuid, result in zip(sorted_uuids, path_results):
+        records = result[0]
+        record = records[0] if len(records) > 0 else None
+        distance: float = record['score'] if record is not None else float('inf')
+        if record is not None and (
+            record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
+        ):
+            distance = 0
 
         if uuid in scores:
-            scores[uuid] = min(1 / distance, scores[uuid])
+            scores[uuid] = min(distance, scores[uuid])
         else:
-            scores[uuid] = 1 / distance
+            scores[uuid] = distance
 
     # rerank on shortest distance
-    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
+    sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
 
     return sorted_uuids
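
Net effect of this hunk: the per-edge path queries now run concurrently under asyncio.gather, and edges are sorted by raw shortest-path distance ascending rather than by 1/distance descending; raw distances also let unreachable edges sort last via float('inf'). A toy illustration of the ordering:

```python
# Made-up distances; float('inf') marks an edge with no path to the center.
scores = {'edge-a': 3.0, 'edge-b': 1.0, 'edge-c': float('inf'), 'edge-d': 0.0}

sorted_uuids = sorted(scores, key=lambda uuid: scores[uuid])
print(sorted_uuids)  # ['edge-d', 'edge-b', 'edge-a', 'edge-c']
```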

--- /dev/null
+++ graphiti_core-0.3.0/graphiti_core/utils/maintenance/community_operations.py
@@ -0,0 +1,155 @@
+import asyncio
+import logging
+from collections import defaultdict
+from datetime import datetime
+
+from neo4j import AsyncDriver
+
+from graphiti_core.edges import CommunityEdge
+from graphiti_core.llm_client import LLMClient
+from graphiti_core.nodes import CommunityNode, EntityNode
+from graphiti_core.prompts import prompt_library
+from graphiti_core.utils.maintenance.edge_operations import build_community_edges
+
+logger = logging.getLogger(__name__)
+
+
+async def build_community_projection(driver: AsyncDriver) -> str:
+    records, _, _ = await driver.execute_query("""
+    CALL gds.graph.project("communities", "Entity",
+        {RELATES_TO: {
+            type: "RELATES_TO",
+            orientation: "UNDIRECTED",
+            properties: {weight: {property: "*", aggregation: "COUNT"}}
+        }}
+    )
+    YIELD graphName AS graph, nodeProjection AS nodes, relationshipProjection AS edges
+    """)
+
+    return records[0]['graph']
+
+
+async def destroy_projection(driver: AsyncDriver, projection_name: str):
+    await driver.execute_query(
+        """
+    CALL gds.graph.drop($projection_name)
+    """,
+        projection_name=projection_name,
+    )
+
+
+async def get_community_clusters(
+    driver: AsyncDriver, projection_name: str
+) -> list[list[EntityNode]]:
+    records, _, _ = await driver.execute_query("""
+    CALL gds.leiden.stream("communities")
+    YIELD nodeId, communityId
+    RETURN gds.util.asNode(nodeId).uuid AS entity_uuid, communityId
+    """)
+    community_map: dict[int, list[str]] = defaultdict(list)
+    for record in records:
+        community_map[record['communityId']].append(record['entity_uuid'])
+
+    community_clusters: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[EntityNode.get_by_uuids(driver, cluster) for cluster in community_map.values()]
+        )
+    )
+
+    return community_clusters
+
+
+async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
+    # Prepare context for LLM
+    context = {'node_summaries': [{'summary': summary} for summary in summary_pair]}
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.summarize_nodes.summarize_pair(context)
+    )
+
+    pair_summary = llm_response.get('summary', '')
+
+    return pair_summary
+
+
+async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
+    context = {'summary': summary}
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.summarize_nodes.summary_description(context)
+    )
+
+    description = llm_response.get('description', '')
+
+    return description
+
+
+async def build_community(
+    llm_client: LLMClient, community_cluster: list[EntityNode]
+) -> tuple[CommunityNode, list[CommunityEdge]]:
+    summaries = [entity.summary for entity in community_cluster]
+    length = len(summaries)
+    while length > 1:
+        odd_one_out: str | None = None
+        if length % 2 == 1:
+            odd_one_out = summaries.pop()
+            length -= 1
+        new_summaries: list[str] = list(
+            await asyncio.gather(
+                *[
+                    summarize_pair(llm_client, (str(left_summary), str(right_summary)))
+                    for left_summary, right_summary in zip(
+                        summaries[: int(length / 2)], summaries[int(length / 2) :]
+                    )
+                ]
+            )
+        )
+        if odd_one_out is not None:
+            new_summaries.append(odd_one_out)
+        summaries = new_summaries
+        length = len(summaries)
+
+    summary = summaries[0]
+    name = await generate_summary_description(llm_client, summary)
+    now = datetime.now()
+    community_node = CommunityNode(
+        name=name,
+        group_id=community_cluster[0].group_id,
+        labels=['Community'],
+        created_at=now,
+        summary=summary,
+    )
+    community_edges = build_community_edges(community_cluster, community_node, now)
+
+    logger.info((community_node, community_edges))
+
+    return community_node, community_edges
+
+
+async def build_communities(
+    driver: AsyncDriver, llm_client: LLMClient
+) -> tuple[list[CommunityNode], list[CommunityEdge]]:
+    projection = await build_community_projection(driver)
+    community_clusters = await get_community_clusters(driver, projection)
+
+    communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
+        await asyncio.gather(
+            *[build_community(llm_client, cluster) for cluster in community_clusters]
+        )
+    )
+
+    community_nodes: list[CommunityNode] = []
+    community_edges: list[CommunityEdge] = []
+    for community in communities:
+        community_nodes.append(community[0])
+        community_edges.extend(community[1])
+
+    await destroy_projection(driver, projection)
+    return community_nodes, community_edges
+
+
+async def remove_communities(driver: AsyncDriver):
+    await driver.execute_query("""
+    MATCH (c:Community)
+    DETACH DELETE c
+    """)

--- graphiti_core-0.2.3/graphiti_core/utils/maintenance/edge_operations.py
+++ graphiti_core-0.3.0/graphiti_core/utils/maintenance/edge_operations.py
@@ -20,9 +20,9 @@ from datetime import datetime
 from time import time
 from typing import List
 
-from graphiti_core.edges import EntityEdge, EpisodicEdge
+from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient
-from graphiti_core.nodes import EntityNode, EpisodicNode
+from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 from graphiti_core.prompts import prompt_library
 from graphiti_core.utils.maintenance.temporal_operations import (
     extract_edge_dates,
@@ -50,6 +50,24 @@ def build_episodic_edges(
     return edges
 
 
+def build_community_edges(
+    entity_nodes: List[EntityNode],
+    community_node: CommunityNode,
+    created_at: datetime,
+) -> List[CommunityEdge]:
+    edges: List[CommunityEdge] = [
+        CommunityEdge(
+            source_node_uuid=community_node.uuid,
+            target_node_uuid=node.uuid,
+            created_at=created_at,
+            group_id=community_node.group_id,
+        )
+        for node in entity_nodes
+    ]
+
+    return edges
+
+
 async def extract_edges(
     llm_client: LLMClient,
     episode: EpisodicNode,

--- graphiti_core-0.2.3/graphiti_core/utils/maintenance/graph_data_operations.py
+++ graphiti_core-0.3.0/graphiti_core/utils/maintenance/graph_data_operations.py
@@ -32,8 +32,10 @@ async def build_indices_and_constraints(driver: AsyncDriver):
     range_indices: list[LiteralString] = [
         'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
         'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
+        'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
         'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
         'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
+        'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
         'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
         'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
         'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
@@ -51,6 +53,7 @@ async def build_indices_and_constraints(driver: AsyncDriver):
 
     fulltext_indices: list[LiteralString] = [
         'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]',
+        'CREATE FULLTEXT INDEX community_name IF NOT EXISTS FOR (n:Community) ON EACH [n.name]',
         'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]',
     ]
 
@@ -71,6 +74,14 @@ async def build_indices_and_constraints(driver: AsyncDriver):
             `vector.similarity_function`: 'cosine'
             }}
         """,
+        """
+        CREATE VECTOR INDEX community_name_embedding IF NOT EXISTS
+        FOR (n:Community) ON (n.name_embedding)
+        OPTIONS {indexConfig: {
+            `vector.dimensions`: 1024,
+            `vector.similarity_function`: 'cosine'
+        }}
+        """,
     ]
     index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices
 

--- graphiti_core-0.2.3/graphiti_core/utils/maintenance/node_operations.py
+++ graphiti_core-0.3.0/graphiti_core/utils/maintenance/node_operations.py
@@ -48,6 +48,29 @@ async def extract_message_nodes(
     return extracted_node_data
 
 
+async def extract_text_nodes(
+    llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
+) -> list[dict[str, Any]]:
+    # Prepare context for LLM
+    context = {
+        'episode_content': episode.content,
+        'episode_timestamp': episode.valid_at.isoformat(),
+        'previous_episodes': [
+            {
+                'content': ep.content,
+                'timestamp': ep.valid_at.isoformat(),
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_text(context)
+    )
+    extracted_node_data = llm_response.get('extracted_nodes', [])
+    return extracted_node_data
+
+
 async def extract_json_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
@@ -73,8 +96,10 @@ async def extract_nodes(
 ) -> list[EntityNode]:
     start = time()
     extracted_node_data: list[dict[str, Any]] = []
-    if episode.source in [EpisodeType.message, EpisodeType.text]:
+    if episode.source == EpisodeType.message:
         extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
+    elif episode.source == EpisodeType.text:
+        extracted_node_data = await extract_text_nodes(llm_client, episode, previous_episodes)
     elif episode.source == EpisodeType.json:
         extracted_node_data = await extract_json_nodes(llm_client, episode)
 
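
With this dispatch change, plain-text episodes use the new extract_text prompt instead of sharing the conversational message prompt. A hedged sketch of ingesting a text episode so that path is exercised; the add_episode parameter names here follow the README-era API and should be treated as illustrative:

```python
from datetime import datetime, timezone

from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EpisodeType


async def ingest_document(graphiti: Graphiti) -> None:
    # source=EpisodeType.text now routes extraction through extract_text_nodes.
    await graphiti.add_episode(
        name='press-release-001',  # illustrative values throughout
        episode_body='Kamala Harris was sworn in as California Attorney General.',
        source_description='press release',
        reference_time=datetime.now(timezone.utc),
        source=EpisodeType.text,
    )
```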

--- graphiti_core-0.2.3/pyproject.toml
+++ graphiti_core-0.3.0/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "graphiti-core"
-version = "0.2.3"
+version = "0.3.0"
 description = "A temporal graph building library"
 authors = [
     "Paul Paliychuk <paul@getzep.com>",
@@ -19,7 +19,7 @@ neo4j = "^5.23.0"
 diskcache = "^5.6.3"
 openai = "^1.38.0"
 tenacity = "<9.0.0"
-numpy = "^2.1.1"
+numpy = ">=1.0.0"
 
 [tool.poetry.dev-dependencies]
 pytest = "^8.3.2"
@@ -31,7 +31,7 @@ ruff = "^0.6.2"
 [tool.poetry.group.dev.dependencies]
 pydantic = "^2.8.2"
 mypy = "^1.11.1"
-groq = ">=0.9,<0.11"
+groq = ">=0.9,<0.12"
 anthropic = "^0.34.1"
 ipykernel = "^6.29.5"
 jupyterlab = "^4.2.4"
@@ -1,5 +0,0 @@
1
- from .client import LLMClient
2
- from .config import LLMConfig
3
- from .openai_client import OpenAIClient
4
-
5
- __all__ = ['LLMClient', 'OpenAIClient', 'LLMConfig']