graphiti-core 0.21.0rc12__py3-none-any.whl → 0.22.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. (The original page links to more details; the link target is not preserved in this text.)

@@ -14,29 +14,16 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- import asyncio
18
17
  import copy
19
18
  import logging
20
19
  import os
21
20
  from abc import ABC, abstractmethod
22
21
  from collections.abc import Coroutine
23
- from datetime import datetime
24
22
  from enum import Enum
25
23
  from typing import Any
26
24
 
27
25
  from dotenv import load_dotenv
28
26
 
29
- from graphiti_core.embedder.client import EMBEDDING_DIM
30
-
31
- try:
32
- from opensearchpy import AsyncOpenSearch, helpers
33
-
34
- _HAS_OPENSEARCH = True
35
- except ImportError:
36
- OpenSearch = None
37
- helpers = None
38
- _HAS_OPENSEARCH = False
39
-
40
27
  logger = logging.getLogger(__name__)
41
28
 
42
29
  DEFAULT_SIZE = 10
@@ -56,91 +43,6 @@ class GraphProvider(Enum):
56
43
  NEPTUNE = 'neptune'
57
44
 
58
45
 
59
- aoss_indices = [
60
- {
61
- 'index_name': ENTITY_INDEX_NAME,
62
- 'body': {
63
- 'settings': {'index': {'knn': True}},
64
- 'mappings': {
65
- 'properties': {
66
- 'uuid': {'type': 'keyword'},
67
- 'name': {'type': 'text'},
68
- 'summary': {'type': 'text'},
69
- 'group_id': {'type': 'keyword'},
70
- 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
71
- 'name_embedding': {
72
- 'type': 'knn_vector',
73
- 'dimension': EMBEDDING_DIM,
74
- 'method': {
75
- 'engine': 'faiss',
76
- 'space_type': 'cosinesimil',
77
- 'name': 'hnsw',
78
- 'parameters': {'ef_construction': 128, 'm': 16},
79
- },
80
- },
81
- }
82
- },
83
- },
84
- },
85
- {
86
- 'index_name': COMMUNITY_INDEX_NAME,
87
- 'body': {
88
- 'mappings': {
89
- 'properties': {
90
- 'uuid': {'type': 'keyword'},
91
- 'name': {'type': 'text'},
92
- 'group_id': {'type': 'keyword'},
93
- }
94
- }
95
- },
96
- },
97
- {
98
- 'index_name': EPISODE_INDEX_NAME,
99
- 'body': {
100
- 'mappings': {
101
- 'properties': {
102
- 'uuid': {'type': 'keyword'},
103
- 'content': {'type': 'text'},
104
- 'source': {'type': 'text'},
105
- 'source_description': {'type': 'text'},
106
- 'group_id': {'type': 'keyword'},
107
- 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
108
- 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
109
- }
110
- }
111
- },
112
- },
113
- {
114
- 'index_name': ENTITY_EDGE_INDEX_NAME,
115
- 'body': {
116
- 'settings': {'index': {'knn': True}},
117
- 'mappings': {
118
- 'properties': {
119
- 'uuid': {'type': 'keyword'},
120
- 'name': {'type': 'text'},
121
- 'fact': {'type': 'text'},
122
- 'group_id': {'type': 'keyword'},
123
- 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
124
- 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
125
- 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
126
- 'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
127
- 'fact_embedding': {
128
- 'type': 'knn_vector',
129
- 'dimension': EMBEDDING_DIM,
130
- 'method': {
131
- 'engine': 'faiss',
132
- 'space_type': 'cosinesimil',
133
- 'name': 'hnsw',
134
- 'parameters': {'ef_construction': 128, 'm': 16},
135
- },
136
- },
137
- }
138
- },
139
- },
140
- },
141
- ]
142
-
143
-
144
46
  class GraphDriverSession(ABC):
145
47
  provider: GraphProvider
146
48
 
@@ -171,7 +73,7 @@ class GraphDriver(ABC):
171
73
  '' # Neo4j (default) syntax does not require a prefix for fulltext queries
172
74
  )
173
75
  _database: str
174
- aoss_client: AsyncOpenSearch | None # type: ignore
76
+ aoss_client: Any # type: ignore
175
77
 
176
78
  @abstractmethod
177
79
  def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -199,119 +101,6 @@ class GraphDriver(ABC):
199
101
 
200
102
  return cloned
201
103
 
202
- async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
203
- # No matter what happens above, always return True
204
- return self.delete_aoss_indices()
205
-
206
- async def create_aoss_indices(self):
207
- client = self.aoss_client
208
- if not client:
209
- logger.warning('No OpenSearch client found')
210
- return
211
-
212
- for index in aoss_indices:
213
- alias_name = index['index_name']
214
-
215
- # If alias already exists, skip (idempotent behavior)
216
- if await client.indices.exists_alias(name=alias_name):
217
- continue
218
-
219
- # Build a physical index name with timestamp
220
- ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
221
- physical_index_name = f'{alias_name}_{ts_suffix}'
222
-
223
- # Create the index
224
- await client.indices.create(index=physical_index_name, body=index['body'])
225
-
226
- # Point alias to it
227
- await client.indices.put_alias(index=physical_index_name, name=alias_name)
228
-
229
- # Allow some time for index creation
230
- await asyncio.sleep(1)
231
-
232
- async def delete_aoss_indices(self):
233
- client = self.aoss_client
234
-
235
- if not client:
236
- logger.warning('No OpenSearch client found')
237
- return
238
-
239
- for entry in aoss_indices:
240
- alias_name = entry['index_name']
241
-
242
- try:
243
- # Resolve alias → indices
244
- alias_info = await client.indices.get_alias(name=alias_name)
245
- indices = list(alias_info.keys())
246
-
247
- if not indices:
248
- logger.info(f"No indices found for alias '{alias_name}'")
249
- continue
250
-
251
- for index in indices:
252
- if await client.indices.exists(index=index):
253
- await client.indices.delete(index=index)
254
- logger.info(f"Deleted index '{index}' (alias: {alias_name})")
255
- else:
256
- logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
257
-
258
- except Exception as e:
259
- logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
260
-
261
- async def clear_aoss_indices(self):
262
- client = self.aoss_client
263
-
264
- if not client:
265
- logger.warning('No OpenSearch client found')
266
- return
267
-
268
- for index in aoss_indices:
269
- index_name = index['index_name']
270
-
271
- if await client.indices.exists(index=index_name):
272
- try:
273
- # Delete all documents but keep the index
274
- response = await client.delete_by_query(
275
- index=index_name,
276
- body={'query': {'match_all': {}}},
277
- )
278
- logger.info(f"Cleared index '{index_name}': {response}")
279
- except Exception as e:
280
- logger.error(f"Error clearing index '{index_name}': {e}")
281
- else:
282
- logger.warning(f"Index '{index_name}' does not exist")
283
-
284
- async def save_to_aoss(self, name: str, data: list[dict]) -> int:
285
- client = self.aoss_client
286
- if not client or not helpers:
287
- logger.warning('No OpenSearch client found')
288
- return 0
289
-
290
- for index in aoss_indices:
291
- if name.lower() == index['index_name']:
292
- to_index = []
293
- for d in data:
294
- doc = {}
295
- for p in index['body']['mappings']['properties']:
296
- if p in d: # protect against missing fields
297
- doc[p] = d[p]
298
-
299
- item = {
300
- '_index': name,
301
- '_id': d['uuid'],
302
- '_routing': d.get('group_id'),
303
- '_source': doc,
304
- }
305
- to_index.append(item)
306
-
307
- success, failed = await helpers.async_bulk(
308
- client, to_index, stats_only=True, request_timeout=60
309
- )
310
-
311
- return success if failed == 0 else success
312
-
313
- return 0
314
-
315
104
  def build_fulltext_query(
316
105
  self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
317
106
  ) -> str:
@@ -320,3 +109,9 @@ class GraphDriver(ABC):
320
109
  Only implemented by providers that need custom fulltext query building.
321
110
  """
322
111
  raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
112
+
113
+ async def save_to_aoss(self, name: str, data: list[dict]) -> int:
114
+ return 0
115
+
116
+ async def clear_aoss_indices(self):
117
+ return 1
@@ -22,28 +22,9 @@ from neo4j import AsyncGraphDatabase, EagerResult
22
22
  from typing_extensions import LiteralString
23
23
 
24
24
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
25
- from graphiti_core.helpers import semaphore_gather
26
25
 
27
26
  logger = logging.getLogger(__name__)
28
27
 
29
- try:
30
- import boto3
31
- from opensearchpy import (
32
- AIOHttpConnection,
33
- AsyncOpenSearch,
34
- AWSV4SignerAuth,
35
- Urllib3AWSV4SignerAuth,
36
- Urllib3HttpConnection,
37
- )
38
-
39
- _HAS_OPENSEARCH = True
40
- except ImportError:
41
- boto3 = None
42
- OpenSearch = None
43
- Urllib3AWSV4SignerAuth = None
44
- Urllib3HttpConnection = None
45
- _HAS_OPENSEARCH = False
46
-
47
28
 
48
29
  class Neo4jDriver(GraphDriver):
49
30
  provider = GraphProvider.NEO4J
@@ -54,11 +35,6 @@ class Neo4jDriver(GraphDriver):
54
35
  user: str | None,
55
36
  password: str | None,
56
37
  database: str = 'neo4j',
57
- aoss_host: str | None = None,
58
- aoss_port: int | None = None,
59
- aws_profile_name: str | None = None,
60
- aws_region: str | None = None,
61
- aws_service: str | None = None,
62
38
  ):
63
39
  super().__init__()
64
40
  self.client = AsyncGraphDatabase.driver(
@@ -68,24 +44,6 @@ class Neo4jDriver(GraphDriver):
68
44
  self._database = database
69
45
 
70
46
  self.aoss_client = None
71
- if aoss_host and aoss_port and boto3 is not None:
72
- try:
73
- region = aws_region
74
- service = aws_service
75
- credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
76
- auth = AWSV4SignerAuth(credentials, region or '', service or '')
77
-
78
- self.aoss_client = AsyncOpenSearch(
79
- hosts=[{'host': aoss_host, 'port': aoss_port}],
80
- auth=auth,
81
- use_ssl=True,
82
- verify_certs=True,
83
- connection_class=AIOHttpConnection,
84
- pool_maxsize=20,
85
- ) # type: ignore
86
- except Exception as e:
87
- logger.warning(f'Failed to initialize OpenSearch client: {e}')
88
- self.aoss_client = None
89
47
 
90
48
  async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
91
49
  # Check if database_ is provided in kwargs.
@@ -111,13 +69,6 @@ class Neo4jDriver(GraphDriver):
111
69
  return await self.client.close()
112
70
 
113
71
  def delete_all_indexes(self) -> Coroutine:
114
- if self.aoss_client:
115
- return semaphore_gather(
116
- self.client.execute_query(
117
- 'CALL db.indexes() YIELD name DROP INDEX name',
118
- ),
119
- self.delete_aoss_indices(),
120
- )
121
72
  return self.client.execute_query(
122
73
  'CALL db.indexes() YIELD name DROP INDEX name',
123
74
  )
@@ -22,21 +22,16 @@ from typing import Any
22
22
 
23
23
  import boto3
24
24
  from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
25
- from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
25
+ from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
26
26
 
27
- from graphiti_core.driver.driver import (
28
- DEFAULT_SIZE,
29
- GraphDriver,
30
- GraphDriverSession,
31
- GraphProvider,
32
- )
27
+ from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
33
28
 
34
29
  logger = logging.getLogger(__name__)
30
+ DEFAULT_SIZE = 10
35
31
 
36
- neptune_aoss_indices = [
32
+ aoss_indices = [
37
33
  {
38
34
  'index_name': 'node_name_and_summary',
39
- 'alias_name': 'entities',
40
35
  'body': {
41
36
  'mappings': {
42
37
  'properties': {
@@ -54,7 +49,6 @@ neptune_aoss_indices = [
54
49
  },
55
50
  {
56
51
  'index_name': 'community_name',
57
- 'alias_name': 'communities',
58
52
  'body': {
59
53
  'mappings': {
60
54
  'properties': {
@@ -71,7 +65,6 @@ neptune_aoss_indices = [
71
65
  },
72
66
  {
73
67
  'index_name': 'episode_content',
74
- 'alias_name': 'episodes',
75
68
  'body': {
76
69
  'mappings': {
77
70
  'properties': {
@@ -95,7 +88,6 @@ neptune_aoss_indices = [
95
88
  },
96
89
  {
97
90
  'index_name': 'edge_name_and_fact',
98
- 'alias_name': 'facts',
99
91
  'body': {
100
92
  'mappings': {
101
93
  'properties': {
@@ -228,27 +220,52 @@ class NeptuneDriver(GraphDriver):
228
220
  async def _delete_all_data(self) -> Any:
229
221
  return await self.execute_query('MATCH (n) DETACH DELETE n')
230
222
 
223
+ def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
224
+ return self.delete_all_indexes_impl()
225
+
226
+ async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
227
+ # No matter what happens above, always return True
228
+ return self.delete_aoss_indices()
229
+
231
230
  async def create_aoss_indices(self):
232
- for index in neptune_aoss_indices:
231
+ for index in aoss_indices:
233
232
  index_name = index['index_name']
234
233
  client = self.aoss_client
235
- if not client:
236
- raise ValueError(
237
- 'You must provide an AOSS endpoint to create an OpenSearch driver.'
238
- )
239
234
  if not client.indices.exists(index=index_name):
240
- await client.indices.create(index=index_name, body=index['body'])
241
-
242
- alias_name = index.get('alias_name', index_name)
243
-
244
- if not client.indices.exists_alias(name=alias_name, index=index_name):
245
- await client.indices.put_alias(index=index_name, name=alias_name)
246
-
235
+ client.indices.create(index=index_name, body=index['body'])
247
236
  # Sleep for 1 minute to let the index creation complete
248
237
  await asyncio.sleep(60)
249
238
 
250
- def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
251
- return self.delete_all_indexes_impl()
239
+ async def delete_aoss_indices(self):
240
+ for index in aoss_indices:
241
+ index_name = index['index_name']
242
+ client = self.aoss_client
243
+ if client.indices.exists(index=index_name):
244
+ client.indices.delete(index=index_name)
245
+
246
+ def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
247
+ for index in aoss_indices:
248
+ if name.lower() == index['index_name']:
249
+ index['query']['query']['multi_match']['query'] = query_text
250
+ query = {'size': limit, 'query': index['query']}
251
+ resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
252
+ return resp
253
+ return {}
254
+
255
+ def save_to_aoss(self, name: str, data: list[dict]) -> int:
256
+ for index in aoss_indices:
257
+ if name.lower() == index['index_name']:
258
+ to_index = []
259
+ for d in data:
260
+ item = {'_index': name, '_id': d['uuid']}
261
+ for p in index['body']['mappings']['properties']:
262
+ if p in d:
263
+ item[p] = d[p]
264
+ to_index.append(item)
265
+ success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
266
+ return success
267
+
268
+ return 0
252
269
 
253
270
 
254
271
  class NeptuneDriverSession(GraphDriverSession):
@@ -33,12 +33,16 @@ DEFAULT_TEMPERATURE = 0
33
33
  DEFAULT_CACHE_DIR = './llm_cache'
34
34
 
35
35
 
36
- def get_extraction_language_instruction() -> str:
36
+ def get_extraction_language_instruction(group_id: str | None = None) -> str:
37
37
  """Returns instruction for language extraction behavior.
38
38
 
39
39
  Override this function to customize language extraction:
40
40
  - Return empty string to disable multilingual instructions
41
41
  - Return custom instructions for specific language requirements
42
+ - Use group_id to provide different instructions per group/partition
43
+
44
+ Args:
45
+ group_id: Optional partition identifier for the graph
42
46
 
43
47
  Returns:
44
48
  str: Language instruction to append to system messages
@@ -142,6 +146,7 @@ class LLMClient(ABC):
142
146
  response_model: type[BaseModel] | None = None,
143
147
  max_tokens: int | None = None,
144
148
  model_size: ModelSize = ModelSize.medium,
149
+ group_id: str | None = None,
145
150
  ) -> dict[str, typing.Any]:
146
151
  if max_tokens is None:
147
152
  max_tokens = self.max_tokens
@@ -155,7 +160,7 @@ class LLMClient(ABC):
155
160
  )
156
161
 
157
162
  # Add multilingual extraction instructions
158
- messages[0].content += get_extraction_language_instruction()
163
+ messages[0].content += get_extraction_language_instruction(group_id)
159
164
 
160
165
  if self.cache_enabled and self.cache_dir is not None:
161
166
  cache_key = self._get_cache_key(messages)
@@ -357,6 +357,7 @@ class GeminiClient(LLMClient):
357
357
  response_model: type[BaseModel] | None = None,
358
358
  max_tokens: int | None = None,
359
359
  model_size: ModelSize = ModelSize.medium,
360
+ group_id: str | None = None,
360
361
  ) -> dict[str, typing.Any]:
361
362
  """
362
363
  Generate a response from the Gemini language model with retry logic and error handling.
@@ -367,6 +368,7 @@ class GeminiClient(LLMClient):
367
368
  response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
368
369
  max_tokens (int | None): The maximum number of tokens to generate in the response.
369
370
  model_size (ModelSize): The size of the model to use (small or medium).
371
+ group_id (str | None): Optional partition identifier for the graph.
370
372
 
371
373
  Returns:
372
374
  dict[str, typing.Any]: The response from the language model.
@@ -376,7 +378,7 @@ class GeminiClient(LLMClient):
376
378
  last_output = None
377
379
 
378
380
  # Add multilingual extraction instructions
379
- messages[0].content += get_extraction_language_instruction()
381
+ messages[0].content += get_extraction_language_instruction(group_id)
380
382
 
381
383
  while retry_count < self.MAX_RETRIES:
382
384
  try:
@@ -175,6 +175,7 @@ class BaseOpenAIClient(LLMClient):
175
175
  response_model: type[BaseModel] | None = None,
176
176
  max_tokens: int | None = None,
177
177
  model_size: ModelSize = ModelSize.medium,
178
+ group_id: str | None = None,
178
179
  ) -> dict[str, typing.Any]:
179
180
  """Generate a response with retry logic and error handling."""
180
181
  if max_tokens is None:
@@ -184,7 +185,7 @@ class BaseOpenAIClient(LLMClient):
184
185
  last_error = None
185
186
 
186
187
  # Add multilingual extraction instructions
187
- messages[0].content += get_extraction_language_instruction()
188
+ messages[0].content += get_extraction_language_instruction(group_id)
188
189
 
189
190
  while retry_count <= self.MAX_RETRIES:
190
191
  try:
@@ -120,6 +120,7 @@ class OpenAIGenericClient(LLMClient):
120
120
  response_model: type[BaseModel] | None = None,
121
121
  max_tokens: int | None = None,
122
122
  model_size: ModelSize = ModelSize.medium,
123
+ group_id: str | None = None,
123
124
  ) -> dict[str, typing.Any]:
124
125
  if max_tokens is None:
125
126
  max_tokens = self.max_tokens
@@ -136,7 +137,7 @@ class OpenAIGenericClient(LLMClient):
136
137
  )
137
138
 
138
139
  # Add multilingual extraction instructions
139
- messages[0].content += get_extraction_language_instruction()
140
+ messages[0].content += get_extraction_language_instruction(group_id)
140
141
 
141
142
  while retry_count <= self.MAX_RETRIES:
142
143
  try:
@@ -139,6 +139,7 @@ async def extract_edges(
139
139
  prompt_library.extract_edges.edge(context),
140
140
  response_model=ExtractedEdges,
141
141
  max_tokens=extract_edges_max_tokens,
142
+ group_id=group_id,
142
143
  )
143
144
  edges_data = ExtractedEdges(**llm_response).edges
144
145
 
@@ -150,6 +151,7 @@ async def extract_edges(
150
151
  prompt_library.extract_edges.reflexion(context),
151
152
  response_model=MissingFacts,
152
153
  max_tokens=extract_edges_max_tokens,
154
+ group_id=group_id,
153
155
  )
154
156
 
155
157
  missing_facts = reflexion_response.get('missing_facts', [])
@@ -177,6 +179,10 @@ async def extract_edges(
177
179
  valid_at_datetime = None
178
180
  invalid_at_datetime = None
179
181
 
182
+ # Filter out empty edges
183
+ if not edge_data.fact.strip():
184
+ continue
185
+
180
186
  source_node_idx = edge_data.source_entity_id
181
187
  target_node_idx = edge_data.target_entity_id
182
188
 
@@ -64,6 +64,7 @@ async def extract_nodes_reflexion(
64
64
  episode: EpisodicNode,
65
65
  previous_episodes: list[EpisodicNode],
66
66
  node_names: list[str],
67
+ group_id: str | None = None,
67
68
  ) -> list[str]:
68
69
  # Prepare context for LLM
69
70
  context = {
@@ -73,7 +74,9 @@ async def extract_nodes_reflexion(
73
74
  }
74
75
 
75
76
  llm_response = await llm_client.generate_response(
76
- prompt_library.extract_nodes.reflexion(context), MissedEntities
77
+ prompt_library.extract_nodes.reflexion(context),
78
+ MissedEntities,
79
+ group_id=group_id,
77
80
  )
78
81
  missed_entities = llm_response.get('missed_entities', [])
79
82
 
@@ -129,16 +132,19 @@ async def extract_nodes(
129
132
  llm_response = await llm_client.generate_response(
130
133
  prompt_library.extract_nodes.extract_message(context),
131
134
  response_model=ExtractedEntities,
135
+ group_id=episode.group_id,
132
136
  )
133
137
  elif episode.source == EpisodeType.text:
134
138
  llm_response = await llm_client.generate_response(
135
139
  prompt_library.extract_nodes.extract_text(context),
136
140
  response_model=ExtractedEntities,
141
+ group_id=episode.group_id,
137
142
  )
138
143
  elif episode.source == EpisodeType.json:
139
144
  llm_response = await llm_client.generate_response(
140
145
  prompt_library.extract_nodes.extract_json(context),
141
146
  response_model=ExtractedEntities,
147
+ group_id=episode.group_id,
142
148
  )
143
149
 
144
150
  response_object = ExtractedEntities(**llm_response)
@@ -152,6 +158,7 @@ async def extract_nodes(
152
158
  episode,
153
159
  previous_episodes,
154
160
  [entity.name for entity in extracted_entities],
161
+ episode.group_id,
155
162
  )
156
163
 
157
164
  entities_missed = len(missing_entities) != 0
@@ -192,6 +199,7 @@ async def extract_nodes(
192
199
  logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
193
200
 
194
201
  logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
202
+
195
203
  return extracted_nodes
196
204
 
197
205
 
@@ -477,63 +485,95 @@ async def extract_attributes_from_node(
477
485
  entity_type: type[BaseModel] | None = None,
478
486
  should_summarize_node: NodeSummaryFilter | None = None,
479
487
  ) -> EntityNode:
480
- node_context: dict[str, Any] = {
481
- 'name': node.name,
482
- 'summary': node.summary,
483
- 'entity_types': node.labels,
484
- 'attributes': node.attributes,
485
- }
488
+ # Extract attributes if entity type is defined and has attributes
489
+ llm_response = await _extract_entity_attributes(
490
+ llm_client, node, episode, previous_episodes, entity_type
491
+ )
486
492
 
487
- attributes_context: dict[str, Any] = {
488
- 'node': node_context,
489
- 'episode_content': episode.content if episode is not None else '',
490
- 'previous_episodes': (
491
- [ep.content for ep in previous_episodes] if previous_episodes is not None else []
492
- ),
493
- }
493
+ # Extract summary if needed
494
+ await _extract_entity_summary(
495
+ llm_client, node, episode, previous_episodes, should_summarize_node
496
+ )
497
+
498
+ node.attributes.update(llm_response)
499
+
500
+ return node
494
501
 
495
- summary_context: dict[str, Any] = {
496
- 'node': node_context,
497
- 'episode_content': episode.content if episode is not None else '',
498
- 'previous_episodes': (
499
- [ep.content for ep in previous_episodes] if previous_episodes is not None else []
500
- ),
501
- }
502
502
 
503
- has_entity_attributes: bool = bool(
504
- entity_type is not None and len(entity_type.model_fields) != 0
503
+ async def _extract_entity_attributes(
504
+ llm_client: LLMClient,
505
+ node: EntityNode,
506
+ episode: EpisodicNode | None,
507
+ previous_episodes: list[EpisodicNode] | None,
508
+ entity_type: type[BaseModel] | None,
509
+ ) -> dict[str, Any]:
510
+ if entity_type is None or len(entity_type.model_fields) == 0:
511
+ return {}
512
+
513
+ attributes_context = _build_episode_context(
514
+ # should not include summary
515
+ node_data={
516
+ 'name': node.name,
517
+ 'entity_types': node.labels,
518
+ 'attributes': node.attributes,
519
+ },
520
+ episode=episode,
521
+ previous_episodes=previous_episodes,
505
522
  )
506
523
 
507
- llm_response = (
508
- (
509
- await llm_client.generate_response(
510
- prompt_library.extract_nodes.extract_attributes(attributes_context),
511
- response_model=entity_type,
512
- model_size=ModelSize.small,
513
- )
514
- )
515
- if has_entity_attributes
516
- else {}
524
+ llm_response = await llm_client.generate_response(
525
+ prompt_library.extract_nodes.extract_attributes(attributes_context),
526
+ response_model=entity_type,
527
+ model_size=ModelSize.small,
528
+ group_id=node.group_id,
517
529
  )
518
530
 
519
- # Determine if summary should be generated
520
- generate_summary = True
521
- if should_summarize_node is not None:
522
- generate_summary = await should_summarize_node(node)
523
-
524
- # Conditionally generate summary
525
- if generate_summary:
526
- summary_response = await llm_client.generate_response(
527
- prompt_library.extract_nodes.extract_summary(summary_context),
528
- response_model=EntitySummary,
529
- model_size=ModelSize.small,
530
- )
531
- node.summary = summary_response.get('summary', '')
531
+ # validate response
532
+ entity_type(**llm_response)
532
533
 
533
- if has_entity_attributes and entity_type is not None:
534
- entity_type(**llm_response)
535
- node_attributes = {key: value for key, value in llm_response.items()}
534
+ return llm_response
536
535
 
537
- node.attributes.update(node_attributes)
538
536
 
539
- return node
537
+ async def _extract_entity_summary(
538
+ llm_client: LLMClient,
539
+ node: EntityNode,
540
+ episode: EpisodicNode | None,
541
+ previous_episodes: list[EpisodicNode] | None,
542
+ should_summarize_node: NodeSummaryFilter | None,
543
+ ) -> None:
544
+ if should_summarize_node is not None and not await should_summarize_node(node):
545
+ return
546
+
547
+ summary_context = _build_episode_context(
548
+ node_data={
549
+ 'name': node.name,
550
+ 'summary': node.summary,
551
+ 'entity_types': node.labels,
552
+ 'attributes': node.attributes,
553
+ },
554
+ episode=episode,
555
+ previous_episodes=previous_episodes,
556
+ )
557
+
558
+ summary_response = await llm_client.generate_response(
559
+ prompt_library.extract_nodes.extract_summary(summary_context),
560
+ response_model=EntitySummary,
561
+ model_size=ModelSize.small,
562
+ group_id=node.group_id,
563
+ )
564
+
565
+ node.summary = summary_response.get('summary', '')
566
+
567
+
568
+ def _build_episode_context(
569
+ node_data: dict[str, Any],
570
+ episode: EpisodicNode | None,
571
+ previous_episodes: list[EpisodicNode] | None,
572
+ ) -> dict[str, Any]:
573
+ return {
574
+ 'node': node_data,
575
+ 'episode_content': episode.content if episode is not None else '',
576
+ 'previous_episodes': (
577
+ [ep.content for ep in previous_episodes] if previous_episodes is not None else []
578
+ ),
579
+ }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphiti-core
3
- Version: 0.21.0rc12
3
+ Version: 0.22.0rc0
4
4
  Summary: A temporal graph building library
5
5
  Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
6
6
  Project-URL: Repository, https://github.com/getzep/graphiti
@@ -13,11 +13,11 @@ graphiti_core/cross_encoder/client.py,sha256=KLsbfWKOEaAV3adFe3XZlAeb-gje9_sVKCV
13
13
  graphiti_core/cross_encoder/gemini_reranker_client.py,sha256=hmITG5YIib52nrKvINwRi4xTfAO1U4jCCaEVIwImHw0,6208
14
14
  graphiti_core/cross_encoder/openai_reranker_client.py,sha256=WHMl6Q6gEslR2EzjwpFSZt2Kh6bnu8alkLvzmi0MDtg,4674
15
15
  graphiti_core/driver/__init__.py,sha256=kCWimqQU19airu5gKwCmZtZuXkDfaQfKSUhMDoL-rTA,626
16
- graphiti_core/driver/driver.py,sha256=EO9Aj5O2vpH7iyvQQcE5uJGQ8eA-_i6f8NwfAlW8r74,10831
16
+ graphiti_core/driver/driver.py,sha256=sF6CkGLNPIvUgrmWkVws7TvQCskRHiQKJze4Y4ibMmI,3357
17
17
  graphiti_core/driver/falkordb_driver.py,sha256=Q-dImfK4O2bkikqFzo0Wg2g7iFFRSuzy_c6u82tX6-M,9361
18
18
  graphiti_core/driver/kuzu_driver.py,sha256=RcWu8E0CCdofrFe34NmCeqfuhaZr_7ZN5jqDkI3VQMI,5453
19
- graphiti_core/driver/neo4j_driver.py,sha256=E93PdOZaH7wzEbIfoiDSYht49jr6zSzvMMyo1INGEOw,4096
20
- graphiti_core/driver/neptune_driver.py,sha256=akNLHhFHPEeQu-xO3PM51RomklntT6k5eA2CQ4AFbCc,10311
19
+ graphiti_core/driver/neo4j_driver.py,sha256=xiMUvGpW-XFM_2ab5nJJTHoi_LM7CvVZVq6ZO0BbNwc,2380
20
+ graphiti_core/driver/neptune_driver.py,sha256=dyQcaA5VnpNA_XkaWdvgGN3Q0QqbxWcVIud--yT8qhE,11266
21
21
  graphiti_core/embedder/__init__.py,sha256=EL564ZuE-DZjcuKNUK_exMn_XHXm2LdO9fzdXePVKL4,179
22
22
  graphiti_core/embedder/azure_openai.py,sha256=OyomPwC1fIsddI-3n6g00kQFdQznZorBhHwkQKCLUok,2384
23
23
  graphiti_core/embedder/client.py,sha256=BXFMXvuPWxaAzPaPILnxtqQQ4JWBFQv9GdBLOXUWgwE,1158
@@ -27,14 +27,14 @@ graphiti_core/embedder/voyage.py,sha256=oJHAZiNqjdEJOKgoKfGWcxK2-Ewqn5UB3vrBwIwP
27
27
  graphiti_core/llm_client/__init__.py,sha256=QgBWUiCeBp6YiA_xqyrDvJ9jIyy1hngH8g7FWahN3nw,776
28
28
  graphiti_core/llm_client/anthropic_client.py,sha256=xTFcrgMDK77BwnChBhYj51Jaa2mRNI850oJv2pKZI0A,12892
29
29
  graphiti_core/llm_client/azure_openai_client.py,sha256=ekERggAekbb7enes1RJqdRChf_mjaZTFXsnMbxO7azQ,2497
30
- graphiti_core/llm_client/client.py,sha256=KUWq7Gq9J4PdP06lLCBEb8OSZOE6luPqaQ3xgtpZwWg,6835
30
+ graphiti_core/llm_client/client.py,sha256=xF3KtXbgP0jC6nKHtIiP5m9dNzxuZaqqQHCKiexijjU,7053
31
31
  graphiti_core/llm_client/config.py,sha256=pivp29CDIbDPqgw5NF9Ok2AwcqTV5z5_Q1bgNs1CDGs,2560
32
32
  graphiti_core/llm_client/errors.py,sha256=pn6brRiLW60DAUIXJYKBT6MInrS4ueuH1hNLbn_JbQo,1243
33
- graphiti_core/llm_client/gemini_client.py,sha256=AxD7sqsPQdgfcZCBIGN302s1hFYlBN9FOQcDEV0tw08,17725
33
+ graphiti_core/llm_client/gemini_client.py,sha256=ohwuvJ-YTJ67xr6t5UYwSFo87WsyHeMiu8vNCifHod0,17850
34
34
  graphiti_core/llm_client/groq_client.py,sha256=bYLE_cg1QEhugsJOXh4b1vPbxagKeMWqk48240GCzMs,2922
35
- graphiti_core/llm_client/openai_base_client.py,sha256=LeEBZ33Y_bIz-YSr6aCbYKMI9r0SNPeZkALXQ0iFsSE,8488
35
+ graphiti_core/llm_client/openai_base_client.py,sha256=HGt4CyyFCSZyCBwR__IbUUBF0V6Qwr9Ydu_XLtXPIA8,8533
36
36
  graphiti_core/llm_client/openai_client.py,sha256=AuaCFQFMJEGzBkFVouccq3XentmWRIKW0RLRBCUMm7Y,3763
37
- graphiti_core/llm_client/openai_generic_client.py,sha256=lyOQwzIMVb9pk3WWrU5zsG38J26QGKebxC40-lRYMJg,7007
37
+ graphiti_core/llm_client/openai_generic_client.py,sha256=UseKg9rCqXizAdG1xGGU-jnfwuWJCvVkf-legT0MqjQ,7052
38
38
  graphiti_core/llm_client/utils.py,sha256=zKpxXEbKa369m4W7RDEf-m56kH46V1Mx3RowcWZEWWs,1000
39
39
  graphiti_core/migrations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
40
40
  graphiti_core/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -69,13 +69,13 @@ graphiti_core/utils/datetime_utils.py,sha256=J-zYSq7-H-2n9hYOXNIun12kM10vNX9mMAT
69
69
  graphiti_core/utils/maintenance/__init__.py,sha256=vW4H1KyapTl-OOz578uZABYcpND4wPx3Vt6aAPaXh78,301
70
70
  graphiti_core/utils/maintenance/community_operations.py,sha256=3IMxfOacZAYtZKebyYtWJYNZPLOPlS8Il-lzitEkoos,10681
71
71
  graphiti_core/utils/maintenance/dedup_helpers.py,sha256=B7k6KkB6Sii8PZCWNNTvsNiy4BNTNWpoLeGgrPLq6BE,9220
72
- graphiti_core/utils/maintenance/edge_operations.py,sha256=9jbFNM1Qm0wJJr9BR6gXyMiRuDgClim0MspDMBQmW40,26404
72
+ graphiti_core/utils/maintenance/edge_operations.py,sha256=1hlcJRFnxthGkSr07QyDcOVug7N8dQj5aIENJ17JrpA,26564
73
73
  graphiti_core/utils/maintenance/graph_data_operations.py,sha256=42icj3S_ELAJ-NK3jVS_rg_243dmnaZOyUitJj_uJ-M,6085
74
- graphiti_core/utils/maintenance/node_operations.py,sha256=IKiqRqTeePTVFsl1X_N8DVRVAIhrSob7YPkuLvRM_Rk,18622
74
+ graphiti_core/utils/maintenance/node_operations.py,sha256=ARng4x_pCpfA3g4bM0BncOkxBPaQ2IsdIaYfVq3V3X0,19603
75
75
  graphiti_core/utils/maintenance/temporal_operations.py,sha256=wq1I4kqeIoswit6sPohug91FEwrGaVnJ06g1vkJjSLY,3442
76
76
  graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
77
77
  graphiti_core/utils/ontology_utils/entity_types_utils.py,sha256=4eVgxLWY6Q8k9cRJ5pW59IYF--U4nXZsZIGOVb_yHfQ,1285
78
- graphiti_core-0.21.0rc12.dist-info/METADATA,sha256=5lFgZ88TQ2wk5EMyx_S3eljGZ9RL8jdLvqxzUTukVgA,27085
79
- graphiti_core-0.21.0rc12.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
80
- graphiti_core-0.21.0rc12.dist-info/licenses/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
81
- graphiti_core-0.21.0rc12.dist-info/RECORD,,
78
+ graphiti_core-0.22.0rc0.dist-info/METADATA,sha256=uYwosUYjFpfCidLgBO8OuZ_P2fcrwNkRIkXn33lIwXk,27084
79
+ graphiti_core-0.22.0rc0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
80
+ graphiti_core-0.22.0rc0.dist-info/licenses/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
81
+ graphiti_core-0.22.0rc0.dist-info/RECORD,,