graphiti-core 0.24.3__py3-none-any.whl → 0.25.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -19,6 +19,7 @@ from collections.abc import Coroutine
19
19
  from typing import Any
20
20
 
21
21
  from neo4j import AsyncGraphDatabase, EagerResult
22
+ from neo4j.exceptions import ClientError
22
23
  from typing_extensions import LiteralString
23
24
 
24
25
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
@@ -88,6 +89,21 @@ class Neo4jDriver(GraphDriver):
88
89
  'CALL db.indexes() YIELD name DROP INDEX name',
89
90
  )
90
91
 
92
+ async def _execute_index_query(self, query: LiteralString) -> EagerResult | None:
93
+ """Execute an index creation query, ignoring 'index already exists' errors.
94
+
95
+ Neo4j can raise EquivalentSchemaRuleAlreadyExists when concurrent CREATE INDEX
96
+ IF NOT EXISTS queries race, even though the index exists. This is safe to ignore.
97
+ """
98
+ try:
99
+ return await self.execute_query(query)
100
+ except ClientError as e:
101
+ # Ignore "equivalent index already exists" error (race condition with IF NOT EXISTS)
102
+ if 'EquivalentSchemaRuleAlreadyExists' in str(e):
103
+ logger.debug(f'Index already exists (concurrent creation): {query[:50]}...')
104
+ return None
105
+ raise
106
+
91
107
  async def build_indices_and_constraints(self, delete_existing: bool = False):
92
108
  if delete_existing:
93
109
  await self.delete_all_indexes()
@@ -98,14 +114,7 @@ class Neo4jDriver(GraphDriver):
98
114
 
99
115
  index_queries: list[LiteralString] = range_indices + fulltext_indices
100
116
 
101
- await semaphore_gather(
102
- *[
103
- self.execute_query(
104
- query,
105
- )
106
- for query in index_queries
107
- ]
108
- )
117
+ await semaphore_gather(*[self._execute_index_query(query) for query in index_queries])
109
118
 
110
119
  async def health_check(self) -> None:
111
120
  """Check Neo4j connectivity by running the driver's verify_connectivity method."""
graphiti_core/graphiti.py CHANGED
@@ -35,6 +35,7 @@ from graphiti_core.edges import (
35
35
  create_entity_edge_embeddings,
36
36
  )
37
37
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
38
+ from graphiti_core.errors import NodeNotFoundError
38
39
  from graphiti_core.graphiti_types import GraphitiClients
39
40
  from graphiti_core.helpers import (
40
41
  get_default_group_id,
@@ -384,6 +385,7 @@ class Graphiti:
384
385
  edge_types: dict[str, type[BaseModel]] | None,
385
386
  nodes: list[EntityNode],
386
387
  uuid_map: dict[str, str],
388
+ custom_extraction_instructions: str | None = None,
387
389
  ) -> tuple[list[EntityEdge], list[EntityEdge]]:
388
390
  """Extract edges from episode and resolve against existing graph."""
389
391
  extracted_edges = await extract_edges(
@@ -394,6 +396,7 @@ class Graphiti:
394
396
  edge_type_map,
395
397
  group_id,
396
398
  edge_types,
399
+ custom_extraction_instructions,
397
400
  )
398
401
 
399
402
  edges = resolve_edge_pointers(extracted_edges, uuid_map)
@@ -627,6 +630,7 @@ class Graphiti:
627
630
  previous_episode_uuids: list[str] | None = None,
628
631
  edge_types: dict[str, type[BaseModel]] | None = None,
629
632
  edge_type_map: dict[tuple[str, str], list[str]] | None = None,
633
+ custom_extraction_instructions: str | None = None,
630
634
  ) -> AddEpisodeResults:
631
635
  """
632
636
  Process an episode and update the graph.
@@ -661,6 +665,9 @@ class Graphiti:
661
665
  previous_episode_uuids : list[str] | None
662
666
  Optional. list of episode uuids to use as the previous episodes. If this is not provided,
663
667
  the most recent episodes by created_at date will be used.
668
+ custom_extraction_instructions : str | None
669
+ Optional. Custom extraction instructions string to be included in the extract entities and extract edges prompts.
670
+ This allows for additional instructions or context to guide the extraction process.
664
671
 
665
672
  Returns
666
673
  -------
@@ -739,7 +746,12 @@ class Graphiti:
739
746
 
740
747
  # Extract and resolve nodes
741
748
  extracted_nodes = await extract_nodes(
742
- self.clients, episode, previous_episodes, entity_types, excluded_entity_types
749
+ self.clients,
750
+ episode,
751
+ previous_episodes,
752
+ entity_types,
753
+ excluded_entity_types,
754
+ custom_extraction_instructions,
743
755
  )
744
756
 
745
757
  nodes, uuid_map, _ = await resolve_extracted_nodes(
@@ -760,6 +772,7 @@ class Graphiti:
760
772
  edge_types,
761
773
  nodes,
762
774
  uuid_map,
775
+ custom_extraction_instructions,
763
776
  )
764
777
 
765
778
  # Extract node attributes
@@ -1176,12 +1189,47 @@ class Graphiti:
1176
1189
  if edge.fact_embedding is None:
1177
1190
  await edge.generate_embedding(self.embedder)
1178
1191
 
1179
- nodes, uuid_map, _ = await resolve_extracted_nodes(
1180
- self.clients,
1181
- [source_node, target_node],
1182
- )
1192
+ try:
1193
+ resolved_source = await EntityNode.get_by_uuid(self.driver, source_node.uuid)
1194
+ except NodeNotFoundError:
1195
+ resolved_source_nodes, _, _ = await resolve_extracted_nodes(
1196
+ self.clients,
1197
+ [source_node],
1198
+ )
1199
+ resolved_source = resolved_source_nodes[0]
1200
+
1201
+ try:
1202
+ resolved_target = await EntityNode.get_by_uuid(self.driver, target_node.uuid)
1203
+ except NodeNotFoundError:
1204
+ resolved_target_nodes, _, _ = await resolve_extracted_nodes(
1205
+ self.clients,
1206
+ [target_node],
1207
+ )
1208
+ resolved_target = resolved_target_nodes[0]
1209
+
1210
+ nodes = [resolved_source, resolved_target]
1211
+
1212
+ # Merge user-provided properties from original nodes into resolved nodes (excluding uuid)
1213
+ # Update attributes dictionary (merge rather than replace)
1214
+ if source_node.attributes:
1215
+ resolved_source.attributes.update(source_node.attributes)
1216
+ if target_node.attributes:
1217
+ resolved_target.attributes.update(target_node.attributes)
1218
+
1219
+ # Update summary if provided by user (non-empty string)
1220
+ if source_node.summary:
1221
+ resolved_source.summary = source_node.summary
1222
+ if target_node.summary:
1223
+ resolved_target.summary = target_node.summary
1224
+
1225
+ # Update labels (merge with existing)
1226
+ if source_node.labels:
1227
+ resolved_source.labels = list(set(resolved_source.labels) | set(source_node.labels))
1228
+ if target_node.labels:
1229
+ resolved_target.labels = list(set(resolved_target.labels) | set(target_node.labels))
1183
1230
 
1184
- updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
1231
+ edge.source_node_uuid = resolved_source.uuid
1232
+ edge.target_node_uuid = resolved_target.uuid
1185
1233
 
1186
1234
  valid_edges = await EntityEdge.get_between_nodes(
1187
1235
  self.driver, edge.source_node_uuid, edge.target_node_uuid
@@ -1190,8 +1238,8 @@ class Graphiti:
1190
1238
  related_edges = (
1191
1239
  await search(
1192
1240
  self.clients,
1193
- updated_edge.fact,
1194
- group_ids=[updated_edge.group_id],
1241
+ edge.fact,
1242
+ group_ids=[edge.group_id],
1195
1243
  config=EDGE_HYBRID_SEARCH_RRF,
1196
1244
  search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
1197
1245
  )
@@ -1199,8 +1247,8 @@ class Graphiti:
1199
1247
  existing_edges = (
1200
1248
  await search(
1201
1249
  self.clients,
1202
- updated_edge.fact,
1203
- group_ids=[updated_edge.group_id],
1250
+ edge.fact,
1251
+ group_ids=[edge.group_id],
1204
1252
  config=EDGE_HYBRID_SEARCH_RRF,
1205
1253
  search_filter=SearchFilters(),
1206
1254
  )
@@ -1208,7 +1256,7 @@ class Graphiti:
1208
1256
 
1209
1257
  resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
1210
1258
  self.llm_client,
1211
- updated_edge,
1259
+ edge,
1212
1260
  related_edges,
1213
1261
  existing_edges,
1214
1262
  EpisodicNode(
graphiti_core/helpers.py CHANGED
@@ -34,9 +34,24 @@ load_dotenv()
34
34
 
35
35
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
36
36
  SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
37
- MAX_REFLEXION_ITERATIONS = int(os.getenv('MAX_REFLEXION_ITERATIONS', 0))
38
37
  DEFAULT_PAGE_LIMIT = 20
39
38
 
39
+ # Content chunking configuration for entity extraction
40
+ # Density-based chunking: only chunk high-density content (many entities per token)
41
+ # This targets the failure case (large entity-dense inputs) while preserving
42
+ # context for prose/narrative content
43
+ CHUNK_TOKEN_SIZE = int(os.getenv('CHUNK_TOKEN_SIZE', 3000))
44
+ CHUNK_OVERLAP_TOKENS = int(os.getenv('CHUNK_OVERLAP_TOKENS', 200))
45
+ # Minimum tokens before considering chunking - short content processes fine regardless of density
46
+ CHUNK_MIN_TOKENS = int(os.getenv('CHUNK_MIN_TOKENS', 1000))
47
+ # Entity density threshold: chunk if estimated density > this value
48
+ # For JSON: elements per 1000 tokens > threshold * 1000 (e.g., 0.15 = 150 elements/1000 tokens)
49
+ # For Text: capitalized words per 1000 tokens > threshold * 500 (e.g., 0.15 = 75 caps/1000 tokens)
50
+ # Higher values = more conservative (less chunking), targets P95+ density cases
51
+ # Examples that trigger chunking at 0.15: AWS cost data (12mo), bulk data imports, entity-dense JSON
52
+ # Examples that DON'T chunk at 0.15: meeting transcripts, news articles, documentation
53
+ CHUNK_DENSITY_THRESHOLD = float(os.getenv('CHUNK_DENSITY_THRESHOLD', 0.15))
54
+
40
55
 
41
56
  def parse_db_date(input_date: neo4j_time.DateTime | str | None) -> datetime | None:
42
57
  if isinstance(input_date, neo4j_time.DateTime):
@@ -31,8 +31,8 @@ from .errors import RateLimitError, RefusalError
31
31
 
32
32
  logger = logging.getLogger(__name__)
33
33
 
34
- DEFAULT_MODEL = 'gpt-5-mini'
35
- DEFAULT_SMALL_MODEL = 'gpt-5-nano'
34
+ DEFAULT_MODEL = 'gpt-4.1-mini'
35
+ DEFAULT_SMALL_MODEL = 'gpt-4.1-nano'
36
36
  DEFAULT_REASONING = 'minimal'
37
37
  DEFAULT_VERBOSITY = 'low'
38
38
 
@@ -78,15 +78,25 @@ class OpenAIClient(BaseOpenAIClient):
78
78
  model.startswith('gpt-5') or model.startswith('o1') or model.startswith('o3')
79
79
  )
80
80
 
81
- response = await self.client.responses.parse(
82
- model=model,
83
- input=messages, # type: ignore
84
- temperature=temperature if not is_reasoning_model else None,
85
- max_output_tokens=max_tokens,
86
- text_format=response_model, # type: ignore
87
- reasoning={'effort': reasoning} if reasoning is not None else None, # type: ignore
88
- text={'verbosity': verbosity} if verbosity is not None else None, # type: ignore
89
- )
81
+ request_kwargs = {
82
+ 'model': model,
83
+ 'input': messages, # type: ignore
84
+ 'max_output_tokens': max_tokens,
85
+ 'text_format': response_model, # type: ignore
86
+ }
87
+
88
+ temperature_value = temperature if not is_reasoning_model else None
89
+ if temperature_value is not None:
90
+ request_kwargs['temperature'] = temperature_value
91
+
92
+ # Only include reasoning and verbosity parameters for reasoning models
93
+ if is_reasoning_model and reasoning is not None:
94
+ request_kwargs['reasoning'] = {'effort': reasoning} # type: ignore
95
+
96
+ if is_reasoning_model and verbosity is not None:
97
+ request_kwargs['text'] = {'verbosity': verbosity} # type: ignore
98
+
99
+ response = await self.client.responses.parse(**request_kwargs)
90
100
 
91
101
  return response
92
102
 
@@ -110,7 +110,7 @@ Only extract facts that:
110
110
  You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
111
111
 
112
112
 
113
- {context['custom_prompt']}
113
+ {context['custom_extraction_instructions']}
114
114
 
115
115
  # EXTRACTION RULES
116
116
 
@@ -124,7 +124,7 @@ reference entities. Only extract distinct entities from the CURRENT MESSAGE. Don
124
124
  5. **Formatting**:
125
125
  - Be **explicit and unambiguous** in naming entities (e.g., use full names when available).
126
126
 
127
- {context['custom_prompt']}
127
+ {context['custom_extraction_instructions']}
128
128
  """
129
129
  return [
130
130
  Message(role='system', content=sys_prompt),
@@ -148,7 +148,7 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
148
148
  {context['episode_content']}
149
149
  </JSON>
150
150
 
151
- {context['custom_prompt']}
151
+ {context['custom_extraction_instructions']}
152
152
 
153
153
  Given the above source description and JSON, extract relevant entities from the provided JSON.
154
154
  For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
@@ -182,7 +182,7 @@ Given the above text, extract entities from the TEXT that are explicitly or impl
182
182
  For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
183
183
  Indicate the classified entity type by providing its entity_type_id.
184
184
 
185
- {context['custom_prompt']}
185
+ {context['custom_extraction_instructions']}
186
186
 
187
187
  Guidelines:
188
188
  1. Extract significant entities, concepts, or actors mentioned in the conversation.
@@ -35,7 +35,7 @@ class ComparisonOperator(Enum):
35
35
 
36
36
 
37
37
  class DateFilter(BaseModel):
38
- date: datetime | None = Field(description='A datetime to filter on')
38
+ date: datetime | None = Field(default=None, description='A datetime to filter on')
39
39
  comparison_operator: ComparisonOperator = Field(
40
40
  description='Comparison operator for date filter'
41
41
  )
@@ -44,6 +44,7 @@ class DateFilter(BaseModel):
44
44
  class PropertyFilter(BaseModel):
45
45
  property_name: str = Field(description='Property name')
46
46
  property_value: str | int | float | None = Field(
47
+ default=None,
47
48
  description='Value you want to match on for the property'
48
49
  )
49
50
  comparison_operator: ComparisonOperator = Field(