graphiti-core 0.2.3__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the package's registry listing for more details.

graphiti_core/nodes.py CHANGED
@@ -25,6 +25,7 @@ from uuid import uuid4
25
25
  from neo4j import AsyncDriver
26
26
  from pydantic import BaseModel, Field
27
27
 
28
+ from graphiti_core.errors import NodeNotFoundError
28
29
  from graphiti_core.llm_client.config import EMBEDDING_DIM
29
30
 
30
31
  logger = logging.getLogger(__name__)
@@ -76,8 +77,18 @@ class Node(BaseModel, ABC):
@abstractmethod
async def save(self, driver: AsyncDriver):
    """Write this node to the graph through ``driver``.

    Abstract on the base class: every concrete node type must supply its own
    persistence query.
    """
78
79
 
79
async def delete(self, driver: AsyncDriver):
    """Remove this node together with all of its relationships.

    The match is on ``uuid`` only (no label), so this single implementation
    serves every node subclass.
    NOTE(review): an unlabeled MATCH cannot use a label-scoped uuid index.

    Args:
        driver: Neo4j async driver used to run the query.

    Returns:
        Whatever ``driver.execute_query`` returns for the ``DETACH DELETE``.
    """
    delete_query = """
        MATCH (n {uuid: $uuid})
        DETACH DELETE n
        """
    outcome = await driver.execute_query(delete_query, uuid=self.uuid)

    logger.info(f'Deleted Node: {self.uuid}')

    return outcome
81
92
 
def __hash__(self):
    """Hash a node by its ``uuid`` so it can live in sets and serve as a dict key."""
    node_id = self.uuid
    return hash(node_id)
@@ -90,6 +101,9 @@ class Node(BaseModel, ABC):
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
    """Load the node of this type that has the given ``uuid``.

    Placeholder on the base class: it has no body and so returns ``None``.
    The node subclasses provide the real lookup query.
    """
92
103
 
104
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Load every node of this type whose uuid appears in ``uuids``.

    Placeholder on the base class: it has no body and so returns ``None``.
    The node subclasses provide the real batch lookup query.
    """
106
+
93
107
 
94
108
  class EpisodicNode(Node):
95
109
  source: EpisodeType = Field(description='source type')
@@ -125,24 +139,37 @@ class EpisodicNode(Node):
125
139
 
126
140
  return result
127
141
 
128
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
    """Load a single episodic node by its uuid.

    Args:
        driver: Neo4j async driver used to run the query.
        uuid: uuid of the ``Episodic`` node to load.

    Returns:
        The first matching node, built by ``get_episodic_node_from_record``.

    Raises:
        NodeNotFoundError: if no ``Episodic`` node has this uuid.
    """
    records, _, _ = await driver.execute_query(
        """
        MATCH (e:Episodic {uuid: $uuid})
        RETURN e.content AS content,
            e.created_at AS created_at,
            e.valid_at AS valid_at,
            e.uuid AS uuid,
            e.name AS name,
            e.group_id AS group_id,
            e.source_description AS source_description,
            e.source AS source
        """,
        uuid=uuid,
    )

    episodes = [get_episodic_node_from_record(record) for record in records]

    # Check for a miss before logging, so "Found Node" is only reported when a
    # node was actually returned.
    if not episodes:
        raise NodeNotFoundError(uuid)

    logger.info(f'Found Node: {uuid}')

    return episodes[0]
140
167
 
141
168
  @classmethod
142
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
169
+ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
143
170
  records, _, _ = await driver.execute_query(
144
171
  """
145
- MATCH (e:Episodic {uuid: $uuid})
172
+ MATCH (e:Episodic) WHERE e.uuid IN $uuids
146
173
  RETURN e.content AS content,
147
174
  e.created_at AS created_at,
148
175
  e.valid_at AS valid_at,
@@ -152,14 +179,14 @@ class EpisodicNode(Node):
152
179
  e.source_description AS source_description,
153
180
  e.source AS source
154
181
  """,
155
- uuid=uuid,
182
+ uuids=uuids,
156
183
  )
157
184
 
158
185
  episodes = [get_episodic_node_from_record(record) for record in records]
159
186
 
160
- logger.info(f'Found Node: {uuid}')
187
+ logger.info(f'Found Nodes: {uuids}')
161
188
 
162
- return episodes[0]
189
+ return episodes
163
190
 
164
191
 
165
192
  class EntityNode(Node):
@@ -194,24 +221,88 @@ class EntityNode(Node):
194
221
 
195
222
  return result
196
223
 
197
@classmethod
async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
    """Load a single entity node by its uuid.

    Args:
        driver: Neo4j async driver used to run the query.
        uuid: uuid of the ``Entity`` node to load.

    Returns:
        The first matching node, built by ``get_entity_node_from_record``.

    Raises:
        NodeNotFoundError: if no ``Entity`` node has this uuid. This matches
            ``EpisodicNode.get_by_uuid``; previously a miss surfaced as an
            ``IndexError`` from ``nodes[0]``.
    """
    # RETURN items must be comma-separated; the comma after ``group_id`` was
    # missing, which made the whole query a Cypher syntax error.
    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Entity {uuid: $uuid})
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
        """,
        uuid=uuid,
    )

    nodes = [get_entity_node_from_record(record) for record in records]

    if not nodes:
        raise NodeNotFoundError(uuid)

    logger.info(f'Found Node: {uuid}')

    return nodes[0]
245
+
246
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Load every entity node whose uuid is in ``uuids``.

    Args:
        driver: Neo4j async driver used to run the query.
        uuids: uuids of the ``Entity`` nodes to load.

    Returns:
        One node per returned record, built by ``get_entity_node_from_record``.
        Uuids with no match are simply absent; no error is raised. The list
        keeps the database's record order, not necessarily the order of ``uuids``.
    """
    batch_query = """
        MATCH (n:Entity) WHERE n.uuid IN $uuids
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
        """
    records, _, _ = await driver.execute_query(batch_query, uuids=uuids)

    found = [get_entity_node_from_record(rec) for rec in records]
    logger.info(f'Found Nodes: {uuids}')
    return found
267
+
268
+
269
class CommunityNode(Node):
    # Node type persisted under the Neo4j label ``Community`` (see its ``save``).

    # Embedding of ``name``; stays None until ``generate_name_embedding`` fills it.
    name_embedding: list[float] | None = Field(
        default=None,
        description='embedding of the name',
    )
    # Summary text covering the member nodes; defaults to the empty string.
    summary: str = Field(
        default_factory=str,
        description='region summary of member nodes',
    )
272
+
273
async def save(self, driver: AsyncDriver):
    """Upsert this community node into Neo4j.

    ``MERGE`` finds or creates the ``Community`` node keyed by ``uuid``. Then
    ``SET n = {...}`` replaces its entire property map with the current field
    values.

    Args:
        driver: Neo4j async driver used to run the query.

    Returns:
        Whatever ``driver.execute_query`` returns; the query itself returns the
        node's ``uuid``.
    """
    upsert_query = """
        MERGE (n:Community {uuid: $uuid})
        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, group_id: $group_id, summary: $summary, created_at: $created_at}
        RETURN n.uuid AS uuid"""
    save_result = await driver.execute_query(
        upsert_query,
        uuid=self.uuid,
        name=self.name,
        group_id=self.group_id,
        summary=self.summary,
        name_embedding=self.name_embedding,
        created_at=self.created_at,
    )

    logger.info(f'Saved Node to neo4j: {self.uuid}')

    return save_result
209
290
 
291
async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
    """Embed ``self.name`` and store the vector on ``self.name_embedding``.

    Newlines in the name are replaced with spaces before embedding. The stored
    vector is truncated to ``EMBEDDING_DIM`` entries.

    Args:
        embedder: async client exposing ``create(input=..., model=...)`` whose
            result has ``.data[0].embedding`` (presumably an OpenAI-style
            embeddings client; confirm against callers).
        model: embedding model name passed through to ``embedder.create``.

    Returns:
        The full embedding as returned by the embedder, i.e. before truncation.
    """
    start = time()
    text = self.name.replace('\n', ' ')
    embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
    self.name_embedding = embedding[:EMBEDDING_DIM]
    end = time()
    # time() is in seconds; convert so the logged value matches its "ms" unit.
    logger.info(f'embedded {text} in {(end - start) * 1000} ms')

    return embedding
300
+
210
301
  @classmethod
211
302
  async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
212
303
  records, _, _ = await driver.execute_query(
213
304
  """
214
- MATCH (n:Entity {uuid: $uuid})
305
+ MATCH (n:Community {uuid: $uuid})
215
306
  RETURN
216
307
  n.uuid As uuid,
217
308
  n.name AS name,
@@ -223,12 +314,34 @@ class EntityNode(Node):
223
314
  uuid=uuid,
224
315
  )
225
316
 
226
- nodes = [get_entity_node_from_record(record) for record in records]
317
+ nodes = [get_community_node_from_record(record) for record in records]
227
318
 
228
319
  logger.info(f'Found Node: {uuid}')
229
320
 
230
321
  return nodes[0]
231
322
 
323
@classmethod
async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
    """Load every community node whose uuid is in ``uuids``.

    Args:
        driver: Neo4j async driver used to run the query.
        uuids: uuids of the ``Community`` nodes to load.

    Returns:
        One node per returned record, built by
        ``get_community_node_from_record``. Uuids with no match are simply
        absent; no error is raised.
    """
    # RETURN items must be comma-separated; the comma after ``group_id`` was
    # missing, which made the whole query a Cypher syntax error.
    records, _, _ = await driver.execute_query(
        """
        MATCH (n:Community) WHERE n.uuid IN $uuids
        RETURN
            n.uuid As uuid,
            n.name AS name,
            n.name_embedding AS name_embedding,
            n.group_id AS group_id,
            n.created_at AS created_at,
            n.summary AS summary
        """,
        uuids=uuids,
    )

    nodes = [get_community_node_from_record(record) for record in records]

    logger.info(f'Found Nodes: {uuids}')

    return nodes
344
+
232
345
 
233
346
  # Node helpers
234
347
  def get_episodic_node_from_record(record: Any) -> EpisodicNode:
@@ -254,3 +367,14 @@ def get_entity_node_from_record(record: Any) -> EntityNode:
254
367
  created_at=record['created_at'].to_native(),
255
368
  summary=record['summary'],
256
369
  )
370
+
371
+
372
def get_community_node_from_record(record: Any) -> CommunityNode:
    """Build a ``CommunityNode`` from a Neo4j result record.

    The record must expose the keys returned by the community lookup queries:
    uuid, name, group_id, name_embedding, created_at and summary.
    ``created_at`` is converted with ``.to_native()`` (presumably a neo4j
    temporal value becoming a Python datetime; confirm).
    """
    node_uuid = record['uuid']
    node_name = record['name']
    node_group = record['group_id']
    node_embedding = record['name_embedding']
    node_created = record['created_at'].to_native()
    node_summary = record['summary']
    return CommunityNode(
        uuid=node_uuid,
        name=node_name,
        group_id=node_group,
        name_embedding=node_embedding,
        created_at=node_created,
        summary=node_summary,
    )
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  from typing import Any, Protocol, TypedDict
2
18
 
3
19
  from .models import Message, PromptFunction, PromptVersion
@@ -24,12 +24,14 @@ class Prompt(Protocol):
24
24
  v1: PromptVersion
25
25
  v2: PromptVersion
26
26
  extract_json: PromptVersion
27
+ extract_text: PromptVersion
27
28
 
28
29
 
class Versions(TypedDict):
    # Prompt-builder function registered under each version key of this prompt.
    v1: PromptFunction
    v2: PromptFunction
    extract_json: PromptFunction
    # Added in this release alongside the ``extract_text`` builder.
    extract_text: PromptFunction
33
35
 
34
36
 
35
37
  def v1(context: dict[str, Any]) -> list[Message]:
@@ -144,4 +146,44 @@ Respond with a JSON object in the following format:
144
146
  ]
145
147
 
146
148
 
147
- versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}
149
def extract_text(context: dict[str, Any]) -> list[Message]:
    """Build the chat messages that ask the LLM to extract entity nodes.

    Args:
        context: must provide ``'previous_episodes'`` (a list of mappings, each
            with a ``'content'`` key; the contents are JSON-serialised into the
            prompt) and ``'episode_content'`` (the current message to extract
            from).

    Returns:
        ``[system message, user message]``. The user message requests a JSON
        object with an ``extracted_nodes`` list of name/labels/summary entries.
    """
    sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

    # The guideline list previously started at "2." with no item 1; it is
    # renumbered so the model sees a consistent 1..5 list.
    user_prompt = f"""
Given the following conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:

Conversation:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
<CURRENT MESSAGE>
{context['episode_content']}

Guidelines:
1. Extract significant entities, concepts, or actors mentioned in the conversation.
2. Provide concise but informative summaries for each extracted node.
3. Avoid creating nodes for relationships or actions.
4. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
5. Be as explicit as possible in your node names, using full names and avoiding abbreviations.

Respond with a JSON object in the following format:
{{
    "extracted_nodes": [
        {{
            "name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
            "labels": ["Entity", "OptionalAdditionalLabel"],
            "summary": "Brief summary of the node's role or significance"
        }}
    ]
}}
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]
182
+
183
+
184
# Registry mapping each version key to the function that builds its messages.
versions: Versions = {
    'v1': v1,
    'v2': v2,
    'extract_json': extract_json,
    'extract_text': extract_text,  # new in this release
}
@@ -71,6 +71,9 @@ from .invalidate_edges import (
71
71
  versions as invalidate_edges_versions,
72
72
  )
73
73
  from .models import Message, PromptFunction
74
+ from .summarize_nodes import Prompt as SummarizeNodesPrompt
75
+ from .summarize_nodes import Versions as SummarizeNodesVersions
76
+ from .summarize_nodes import versions as summarize_nodes_versions
74
77
 
75
78
 
76
79
  class PromptLibrary(Protocol):
@@ -80,6 +83,7 @@ class PromptLibrary(Protocol):
80
83
  dedupe_edges: DedupeEdgesPrompt
81
84
  invalidate_edges: InvalidateEdgesPrompt
82
85
  extract_edge_dates: ExtractEdgeDatesPrompt
86
+ summarize_nodes: SummarizeNodesPrompt
83
87
 
84
88
 
85
89
  class PromptLibraryImpl(TypedDict):
@@ -89,6 +93,7 @@ class PromptLibraryImpl(TypedDict):
89
93
  dedupe_edges: DedupeEdgesVersions
90
94
  invalidate_edges: InvalidateEdgesVersions
91
95
  extract_edge_dates: ExtractEdgeDatesVersions
96
+ summarize_nodes: SummarizeNodesVersions
92
97
 
93
98
 
94
99
  class VersionWrapper:
@@ -118,5 +123,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
118
123
  'dedupe_edges': dedupe_edges_versions,
119
124
  'invalidate_edges': invalidate_edges_versions,
120
125
  'extract_edge_dates': extract_edge_dates_versions,
126
+ 'summarize_nodes': summarize_nodes_versions,
121
127
  }
# Public prompt library: wraps the raw implementation dict and is annotated with
# the PromptLibrary protocol. The ignore silences a mypy "assignment" error,
# presumably because the wrapper is not declared as a PromptLibrary.
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)  # type: ignore[assignment]
@@ -0,0 +1,79 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import json
18
+ from typing import Any, Protocol, TypedDict
19
+
20
+ from .models import Message, PromptFunction, PromptVersion
21
+
22
+
23
class Prompt(Protocol):
    # Shape of the ``summarize_nodes`` entry in the prompt library: one
    # attribute per prompt version.
    summarize_pair: PromptVersion
    summary_description: PromptVersion
26
+
27
+
28
class Versions(TypedDict):
    # Prompt-builder function registered under each version key.
    summarize_pair: PromptFunction
    summary_description: PromptFunction
31
+
32
+
33
def summarize_pair(context: dict[str, Any]) -> list[Message]:
    """Build messages asking the LLM to merge node summaries into one.

    Args:
        context: must contain ``'node_summaries'``, which is JSON-serialised
            into the prompt. The wording assumes exactly two summaries.

    Returns:
        ``[system, user]`` messages. The reply is requested as JSON with a
        ``summary`` key.
    """
    system_message = Message(
        role='system',
        content='You are a helpful assistant that combines summaries.',
    )
    rendered_summaries = json.dumps(context['node_summaries'], indent=2)
    user_message = Message(
        role='user',
        content=f"""
Synthesize the information from the following two summaries into a single succinct summary.

Summaries:
{rendered_summaries}

Respond with a JSON object in the following format:
{{
    "summary": "Summary containing the important information from both summaries"
}}
""",
    )
    return [system_message, user_message]
54
+
55
+
56
def summary_description(context: dict[str, Any]) -> list[Message]:
    """Build messages asking the LLM for a one-sentence description of a summary.

    Args:
        context: must contain ``'summary'``, which is JSON-serialised into the
            prompt.

    Returns:
        ``[system, user]`` messages. The reply is requested as JSON with a
        ``description`` key.
    """
    system_message = Message(
        role='system',
        content='You are a helpful assistant that describes provided contents in a single sentence.',
    )
    serialized_summary = json.dumps(context['summary'], indent=2)
    user_message = Message(
        role='user',
        content=f"""
Create a short one sentence description of the summary that explains what kind of information is summarized.

Summary:
{serialized_summary}

Respond with a JSON object in the following format:
{{
    "description": "One sentence description of the provided summary"
}}
""",
    )
    return [system_message, user_message]
77
+
78
+
79
# Registry mapping each version key to the function that builds its messages.
versions: Versions = {
    'summarize_pair': summarize_pair,
    'summary_description': summary_description,
}
graphiti_core/py.typed ADDED
@@ -0,0 +1 @@
1
+ # PEP 561 marker file: its presence tells type checkers that this package ships inline type information.