graphiti-core 0.9.6__tar.gz → 0.10.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/PKG-INFO +6 -1
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/README.md +5 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/cross_encoder/openai_reranker_client.py +11 -11
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/graphiti.py +1 -1
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/client.py +7 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/config.py +1 -1
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/openai_client.py +5 -2
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/openai_generic_client.py +5 -2
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/nodes.py +25 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/eval.py +45 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/search/search.py +116 -19
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/search/search_config.py +21 -1
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/search/search_config_recipes.py +21 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/search/search_helpers.py +10 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/search/search_utils.py +61 -9
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/graph_data_operations.py +2 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/pyproject.toml +1 -1
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/LICENSE +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/cross_encoder/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/cross_encoder/client.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/edges.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/embedder/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/embedder/client.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/embedder/gemini.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/embedder/openai.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/embedder/voyage.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/helpers.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/anthropic_client.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/gemini_client.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/groq_client.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/utils.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/models/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/models/edges/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/models/edges/edge_db_queries.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/models/nodes/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/models/nodes/node_db_queries.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/lib.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/prompt_helpers.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/search/search_filters.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/datetime_utils.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/community_operations.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/node_operations.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/utils.py +0 -0
- {graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/ontology_utils/entity_types_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: graphiti-core
|
|
3
|
-
Version: 0.9.6
|
|
3
|
+
Version: 0.10.0
|
|
4
4
|
Summary: A temporal graph building library
|
|
5
5
|
License: Apache-2.0
|
|
6
6
|
Author: Paul Paliychuk
|
|
@@ -126,6 +126,11 @@ Requirements:
|
|
|
126
126
|
- Neo4j 5.26 or higher (serves as the embeddings storage backend)
|
|
127
127
|
- OpenAI API key (for LLM inference and embedding)
|
|
128
128
|
|
|
129
|
+
> [!IMPORTANT]
|
|
130
|
+
> Graphiti works best with LLM services that support Structured Output (such as OpenAI and Gemini).
|
|
131
|
+
> Using other services may result in incorrect output schemas and ingestion failures. This is particularly
|
|
132
|
+
> problematic when using smaller models.
|
|
133
|
+
|
|
129
134
|
Optional:
|
|
130
135
|
|
|
131
136
|
- Google Gemini, Anthropic, or Groq API key (for alternative LLM providers)
|
|
@@ -94,6 +94,11 @@ Requirements:
|
|
|
94
94
|
- Neo4j 5.26 or higher (serves as the embeddings storage backend)
|
|
95
95
|
- OpenAI API key (for LLM inference and embedding)
|
|
96
96
|
|
|
97
|
+
> [!IMPORTANT]
|
|
98
|
+
> Graphiti works best with LLM services that support Structured Output (such as OpenAI and Gemini).
|
|
99
|
+
> Using other services may result in incorrect output schemas and ingestion failures. This is particularly
|
|
100
|
+
> problematic when using smaller models.
|
|
101
|
+
|
|
97
102
|
Optional:
|
|
98
103
|
|
|
99
104
|
- Google Gemini, Anthropic, or Groq API key (for alternative LLM providers)
|
{graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/cross_encoder/openai_reranker_client.py
RENAMED
|
@@ -17,9 +17,9 @@ limitations under the License.
|
|
|
17
17
|
import logging
|
|
18
18
|
from typing import Any
|
|
19
19
|
|
|
20
|
+
import numpy as np
|
|
20
21
|
import openai
|
|
21
22
|
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
|
22
|
-
from pydantic import BaseModel
|
|
23
23
|
|
|
24
24
|
from ..helpers import semaphore_gather
|
|
25
25
|
from ..llm_client import LLMConfig, RateLimitError
|
|
@@ -28,11 +28,7 @@ from .client import CrossEncoderClient
|
|
|
28
28
|
|
|
29
29
|
logger = logging.getLogger(__name__)
|
|
30
30
|
|
|
31
|
-
DEFAULT_MODEL = 'gpt-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
class BooleanClassifier(BaseModel):
|
|
35
|
-
isTrue: bool
|
|
31
|
+
DEFAULT_MODEL = 'gpt-4.1-nano'
|
|
36
32
|
|
|
37
33
|
|
|
38
34
|
class OpenAIRerankerClient(CrossEncoderClient):
|
|
@@ -107,11 +103,15 @@ class OpenAIRerankerClient(CrossEncoderClient):
|
|
|
107
103
|
]
|
|
108
104
|
scores: list[float] = []
|
|
109
105
|
for top_logprobs in responses_top_logprobs:
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
106
|
+
if len(top_logprobs) == 0:
|
|
107
|
+
continue
|
|
108
|
+
norm_logprobs = np.exp(top_logprobs[0].logprob)
|
|
109
|
+
if bool(top_logprobs[0].token):
|
|
110
|
+
scores.append(norm_logprobs)
|
|
111
|
+
else:
|
|
112
|
+
scores.append(1 - norm_logprobs)
|
|
113
|
+
|
|
114
|
+
results = [(passage, score) for passage, score in zip(passages, scores, strict=True)]
|
|
115
115
|
results.sort(reverse=True, key=lambda x: x[1])
|
|
116
116
|
return results
|
|
117
117
|
except openai.RateLimitError as e:
|
|
@@ -750,7 +750,7 @@ class Graphiti:
|
|
|
750
750
|
|
|
751
751
|
nodes = await get_mentioned_nodes(self.driver, episodes)
|
|
752
752
|
|
|
753
|
-
return SearchResults(edges=edges, nodes=nodes, communities=[])
|
|
753
|
+
return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
|
|
754
754
|
|
|
755
755
|
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
|
|
756
756
|
if source_node.name_embedding is None:
|
|
@@ -32,6 +32,10 @@ from .errors import RateLimitError
|
|
|
32
32
|
DEFAULT_TEMPERATURE = 0
|
|
33
33
|
DEFAULT_CACHE_DIR = './llm_cache'
|
|
34
34
|
|
|
35
|
+
MULTILINGUAL_EXTRACTION_RESPONSES = (
|
|
36
|
+
'\n\nAny extracted information should be returned in the same language as it was written in.'
|
|
37
|
+
)
|
|
38
|
+
|
|
35
39
|
logger = logging.getLogger(__name__)
|
|
36
40
|
|
|
37
41
|
|
|
@@ -133,6 +137,9 @@ class LLMClient(ABC):
|
|
|
133
137
|
f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
|
|
134
138
|
)
|
|
135
139
|
|
|
140
|
+
# Add multilingual extraction instructions
|
|
141
|
+
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
|
|
142
|
+
|
|
136
143
|
if self.cache_enabled and self.cache_dir is not None:
|
|
137
144
|
cache_key = self._get_cache_key(messages)
|
|
138
145
|
|
|
@@ -43,7 +43,7 @@ class LLMConfig:
|
|
|
43
43
|
This is required for making authorized requests.
|
|
44
44
|
|
|
45
45
|
model (str, optional): The specific LLM model to use for generating responses.
|
|
46
|
-
Defaults to "gpt-
|
|
46
|
+
Defaults to "gpt-4.1-mini", which appears to be a custom model name.
|
|
47
47
|
Common values might include "gpt-3.5-turbo" or "gpt-4".
|
|
48
48
|
|
|
49
49
|
base_url (str, optional): The base URL of the LLM API service.
|
|
@@ -24,13 +24,13 @@ from openai.types.chat import ChatCompletionMessageParam
|
|
|
24
24
|
from pydantic import BaseModel
|
|
25
25
|
|
|
26
26
|
from ..prompts.models import Message
|
|
27
|
-
from .client import LLMClient
|
|
27
|
+
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
|
28
28
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
|
29
29
|
from .errors import RateLimitError, RefusalError
|
|
30
30
|
|
|
31
31
|
logger = logging.getLogger(__name__)
|
|
32
32
|
|
|
33
|
-
DEFAULT_MODEL = 'gpt-
|
|
33
|
+
DEFAULT_MODEL = 'gpt-4.1-mini'
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
class OpenAIClient(LLMClient):
|
|
@@ -136,6 +136,9 @@ class OpenAIClient(LLMClient):
|
|
|
136
136
|
retry_count = 0
|
|
137
137
|
last_error = None
|
|
138
138
|
|
|
139
|
+
# Add multilingual extraction instructions
|
|
140
|
+
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
|
|
141
|
+
|
|
139
142
|
while retry_count <= self.MAX_RETRIES:
|
|
140
143
|
try:
|
|
141
144
|
response = await self._generate_response(messages, response_model, max_tokens)
|
{graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/llm_client/openai_generic_client.py
RENAMED
|
@@ -25,13 +25,13 @@ from openai.types.chat import ChatCompletionMessageParam
|
|
|
25
25
|
from pydantic import BaseModel
|
|
26
26
|
|
|
27
27
|
from ..prompts.models import Message
|
|
28
|
-
from .client import LLMClient
|
|
28
|
+
from .client import MULTILINGUAL_EXTRACTION_RESPONSES, LLMClient
|
|
29
29
|
from .config import DEFAULT_MAX_TOKENS, LLMConfig
|
|
30
30
|
from .errors import RateLimitError, RefusalError
|
|
31
31
|
|
|
32
32
|
logger = logging.getLogger(__name__)
|
|
33
33
|
|
|
34
|
-
DEFAULT_MODEL = 'gpt-
|
|
34
|
+
DEFAULT_MODEL = 'gpt-4.1-mini'
|
|
35
35
|
|
|
36
36
|
|
|
37
37
|
class OpenAIGenericClient(LLMClient):
|
|
@@ -130,6 +130,9 @@ class OpenAIGenericClient(LLMClient):
|
|
|
130
130
|
f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
|
|
131
131
|
)
|
|
132
132
|
|
|
133
|
+
# Add multilingual extraction instructions
|
|
134
|
+
messages[0].content += MULTILINGUAL_EXTRACTION_RESPONSES
|
|
135
|
+
|
|
133
136
|
while retry_count <= self.MAX_RETRIES:
|
|
134
137
|
try:
|
|
135
138
|
response = await self._generate_response(
|
|
@@ -251,6 +251,31 @@ class EpisodicNode(Node):
|
|
|
251
251
|
|
|
252
252
|
return episodes
|
|
253
253
|
|
|
254
|
+
@classmethod
|
|
255
|
+
async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
|
|
256
|
+
records, _, _ = await driver.execute_query(
|
|
257
|
+
"""
|
|
258
|
+
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
259
|
+
RETURN DISTINCT
|
|
260
|
+
e.content AS content,
|
|
261
|
+
e.created_at AS created_at,
|
|
262
|
+
e.valid_at AS valid_at,
|
|
263
|
+
e.uuid AS uuid,
|
|
264
|
+
e.name AS name,
|
|
265
|
+
e.group_id AS group_id,
|
|
266
|
+
e.source_description AS source_description,
|
|
267
|
+
e.source AS source,
|
|
268
|
+
e.entity_edges AS entity_edges
|
|
269
|
+
""",
|
|
270
|
+
entity_node_uuid=entity_node_uuid,
|
|
271
|
+
database_=DEFAULT_DATABASE,
|
|
272
|
+
routing_='r',
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
276
|
+
|
|
277
|
+
return episodes
|
|
278
|
+
|
|
254
279
|
|
|
255
280
|
class EntityNode(Node):
|
|
256
281
|
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
|
|
@@ -37,16 +37,28 @@ class EvalResponse(BaseModel):
|
|
|
37
37
|
)
|
|
38
38
|
|
|
39
39
|
|
|
40
|
+
class EvalAddEpisodeResults(BaseModel):
|
|
41
|
+
candidate_is_worse: bool = Field(
|
|
42
|
+
...,
|
|
43
|
+
description='boolean if the baseline extraction is higher quality than the candidate extraction.',
|
|
44
|
+
)
|
|
45
|
+
reasoning: str = Field(
|
|
46
|
+
..., description='why you determined the response was correct or incorrect'
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
|
|
40
50
|
class Prompt(Protocol):
|
|
41
51
|
qa_prompt: PromptVersion
|
|
42
52
|
eval_prompt: PromptVersion
|
|
43
53
|
query_expansion: PromptVersion
|
|
54
|
+
eval_add_episode_results: PromptVersion
|
|
44
55
|
|
|
45
56
|
|
|
46
57
|
class Versions(TypedDict):
|
|
47
58
|
qa_prompt: PromptFunction
|
|
48
59
|
eval_prompt: PromptFunction
|
|
49
60
|
query_expansion: PromptFunction
|
|
61
|
+
eval_add_episode_results: PromptFunction
|
|
50
62
|
|
|
51
63
|
|
|
52
64
|
def query_expansion(context: dict[str, Any]) -> list[Message]:
|
|
@@ -112,8 +124,41 @@ def eval_prompt(context: dict[str, Any]) -> list[Message]:
|
|
|
112
124
|
]
|
|
113
125
|
|
|
114
126
|
|
|
127
|
+
def eval_add_episode_results(context: dict[str, Any]) -> list[Message]:
|
|
128
|
+
sys_prompt = """You are a judge that determines whether a baseline graph building result from a list of messages is better
|
|
129
|
+
than a candidate graph building result based on the same messages."""
|
|
130
|
+
|
|
131
|
+
user_prompt = f"""
|
|
132
|
+
Given the following PREVIOUS MESSAGES and MESSAGE, determine if the BASELINE graph data extracted from the
|
|
133
|
+
conversation is higher quality than the CANDIDATE graph data extracted from the conversation.
|
|
134
|
+
|
|
135
|
+
Return False if the BASELINE extraction is better, and True otherwise. If the CANDIDATE extraction and
|
|
136
|
+
BASELINE extraction are nearly identical in quality, return True. Add your reasoning for your decision to the reasoning field
|
|
137
|
+
|
|
138
|
+
<PREVIOUS MESSAGES>
|
|
139
|
+
{context['previous_messages']}
|
|
140
|
+
</PREVIOUS MESSAGES>
|
|
141
|
+
<MESSAGE>
|
|
142
|
+
{context['message']}
|
|
143
|
+
</MESSAGE>
|
|
144
|
+
|
|
145
|
+
<BASELINE>
|
|
146
|
+
{context['baseline']}
|
|
147
|
+
</BASELINE>
|
|
148
|
+
|
|
149
|
+
<CANDIDATE>
|
|
150
|
+
{context['candidate']}
|
|
151
|
+
</CANDIDATE>
|
|
152
|
+
"""
|
|
153
|
+
return [
|
|
154
|
+
Message(role='system', content=sys_prompt),
|
|
155
|
+
Message(role='user', content=user_prompt),
|
|
156
|
+
]
|
|
157
|
+
|
|
158
|
+
|
|
115
159
|
versions: Versions = {
|
|
116
160
|
'qa_prompt': qa_prompt,
|
|
117
161
|
'eval_prompt': eval_prompt,
|
|
118
162
|
'query_expansion': query_expansion,
|
|
163
|
+
'eval_add_episode_results': eval_add_episode_results,
|
|
119
164
|
}
|
|
@@ -25,7 +25,7 @@ from graphiti_core.edges import EntityEdge
|
|
|
25
25
|
from graphiti_core.embedder import EmbedderClient
|
|
26
26
|
from graphiti_core.errors import SearchRerankerError
|
|
27
27
|
from graphiti_core.helpers import semaphore_gather
|
|
28
|
-
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
28
|
+
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|
29
29
|
from graphiti_core.search.search_config import (
|
|
30
30
|
DEFAULT_SEARCH_LIMIT,
|
|
31
31
|
CommunityReranker,
|
|
@@ -33,6 +33,8 @@ from graphiti_core.search.search_config import (
|
|
|
33
33
|
EdgeReranker,
|
|
34
34
|
EdgeSearchConfig,
|
|
35
35
|
EdgeSearchMethod,
|
|
36
|
+
EpisodeReranker,
|
|
37
|
+
EpisodeSearchConfig,
|
|
36
38
|
NodeReranker,
|
|
37
39
|
NodeSearchConfig,
|
|
38
40
|
NodeSearchMethod,
|
|
@@ -46,6 +48,7 @@ from graphiti_core.search.search_utils import (
|
|
|
46
48
|
edge_bfs_search,
|
|
47
49
|
edge_fulltext_search,
|
|
48
50
|
edge_similarity_search,
|
|
51
|
+
episode_fulltext_search,
|
|
49
52
|
episode_mentions_reranker,
|
|
50
53
|
maximal_marginal_relevance,
|
|
51
54
|
node_bfs_search,
|
|
@@ -74,13 +77,14 @@ async def search(
|
|
|
74
77
|
return SearchResults(
|
|
75
78
|
edges=[],
|
|
76
79
|
nodes=[],
|
|
80
|
+
episodes=[],
|
|
77
81
|
communities=[],
|
|
78
82
|
)
|
|
79
83
|
query_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
|
|
80
84
|
|
|
81
85
|
# if group_ids is empty, set it to None
|
|
82
86
|
group_ids = group_ids if group_ids else None
|
|
83
|
-
edges, nodes, communities = await semaphore_gather(
|
|
87
|
+
edges, nodes, episodes, communities = await semaphore_gather(
|
|
84
88
|
edge_search(
|
|
85
89
|
driver,
|
|
86
90
|
cross_encoder,
|
|
@@ -92,6 +96,7 @@ async def search(
|
|
|
92
96
|
center_node_uuid,
|
|
93
97
|
bfs_origin_node_uuids,
|
|
94
98
|
config.limit,
|
|
99
|
+
config.reranker_min_score,
|
|
95
100
|
),
|
|
96
101
|
node_search(
|
|
97
102
|
driver,
|
|
@@ -104,6 +109,18 @@ async def search(
|
|
|
104
109
|
center_node_uuid,
|
|
105
110
|
bfs_origin_node_uuids,
|
|
106
111
|
config.limit,
|
|
112
|
+
config.reranker_min_score,
|
|
113
|
+
),
|
|
114
|
+
episode_search(
|
|
115
|
+
driver,
|
|
116
|
+
cross_encoder,
|
|
117
|
+
query,
|
|
118
|
+
query_vector,
|
|
119
|
+
group_ids,
|
|
120
|
+
config.episode_config,
|
|
121
|
+
search_filter,
|
|
122
|
+
config.limit,
|
|
123
|
+
config.reranker_min_score,
|
|
107
124
|
),
|
|
108
125
|
community_search(
|
|
109
126
|
driver,
|
|
@@ -112,14 +129,15 @@ async def search(
|
|
|
112
129
|
query_vector,
|
|
113
130
|
group_ids,
|
|
114
131
|
config.community_config,
|
|
115
|
-
bfs_origin_node_uuids,
|
|
116
132
|
config.limit,
|
|
133
|
+
config.reranker_min_score,
|
|
117
134
|
),
|
|
118
135
|
)
|
|
119
136
|
|
|
120
137
|
results = SearchResults(
|
|
121
138
|
edges=edges,
|
|
122
139
|
nodes=nodes,
|
|
140
|
+
episodes=episodes,
|
|
123
141
|
communities=communities,
|
|
124
142
|
)
|
|
125
143
|
|
|
@@ -141,6 +159,7 @@ async def edge_search(
|
|
|
141
159
|
center_node_uuid: str | None = None,
|
|
142
160
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
143
161
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
162
|
+
reranker_min_score: float = 0,
|
|
144
163
|
) -> list[EntityEdge]:
|
|
145
164
|
if config is None:
|
|
146
165
|
return []
|
|
@@ -180,7 +199,7 @@ async def edge_search(
|
|
|
180
199
|
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
|
181
200
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
182
201
|
|
|
183
|
-
reranked_uuids = rrf(search_result_uuids)
|
|
202
|
+
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
184
203
|
elif config.reranker == EdgeReranker.mmr:
|
|
185
204
|
search_result_uuids_and_vectors = [
|
|
186
205
|
(edge.uuid, edge.fact_embedding if edge.fact_embedding is not None else [0.0] * 1024)
|
|
@@ -188,23 +207,31 @@ async def edge_search(
|
|
|
188
207
|
for edge in result
|
|
189
208
|
]
|
|
190
209
|
reranked_uuids = maximal_marginal_relevance(
|
|
191
|
-
query_vector,
|
|
210
|
+
query_vector,
|
|
211
|
+
search_result_uuids_and_vectors,
|
|
212
|
+
config.mmr_lambda,
|
|
213
|
+
min_score=reranker_min_score,
|
|
192
214
|
)
|
|
193
215
|
elif config.reranker == EdgeReranker.cross_encoder:
|
|
194
216
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
195
217
|
|
|
196
|
-
rrf_result_uuids = rrf(search_result_uuids)
|
|
218
|
+
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
197
219
|
rrf_edges = [edge_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
198
220
|
|
|
199
221
|
fact_to_uuid_map = {edge.fact: edge.uuid for edge in rrf_edges}
|
|
200
222
|
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
|
|
201
|
-
reranked_uuids = [
|
|
223
|
+
reranked_uuids = [
|
|
224
|
+
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
|
|
225
|
+
]
|
|
202
226
|
elif config.reranker == EdgeReranker.node_distance:
|
|
203
227
|
if center_node_uuid is None:
|
|
204
228
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
205
229
|
|
|
206
230
|
# use rrf as a preliminary sort
|
|
207
|
-
sorted_result_uuids = rrf(
|
|
231
|
+
sorted_result_uuids = rrf(
|
|
232
|
+
[[edge.uuid for edge in result] for result in search_results],
|
|
233
|
+
min_score=reranker_min_score,
|
|
234
|
+
)
|
|
208
235
|
sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
|
|
209
236
|
|
|
210
237
|
# node distance reranking
|
|
@@ -214,7 +241,9 @@ async def edge_search(
|
|
|
214
241
|
|
|
215
242
|
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
|
|
216
243
|
|
|
217
|
-
reranked_node_uuids = await node_distance_reranker(
|
|
244
|
+
reranked_node_uuids = await node_distance_reranker(
|
|
245
|
+
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
|
|
246
|
+
)
|
|
218
247
|
|
|
219
248
|
for node_uuid in reranked_node_uuids:
|
|
220
249
|
reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
|
|
@@ -238,6 +267,7 @@ async def node_search(
|
|
|
238
267
|
center_node_uuid: str | None = None,
|
|
239
268
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
240
269
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
270
|
+
reranker_min_score: float = 0,
|
|
241
271
|
) -> list[EntityNode]:
|
|
242
272
|
if config is None:
|
|
243
273
|
return []
|
|
@@ -269,7 +299,7 @@ async def node_search(
|
|
|
269
299
|
|
|
270
300
|
reranked_uuids: list[str] = []
|
|
271
301
|
if config.reranker == NodeReranker.rrf:
|
|
272
|
-
reranked_uuids = rrf(search_result_uuids)
|
|
302
|
+
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
273
303
|
elif config.reranker == NodeReranker.mmr:
|
|
274
304
|
search_result_uuids_and_vectors = [
|
|
275
305
|
(node.uuid, node.name_embedding if node.name_embedding is not None else [0.0] * 1024)
|
|
@@ -277,24 +307,36 @@ async def node_search(
|
|
|
277
307
|
for node in result
|
|
278
308
|
]
|
|
279
309
|
reranked_uuids = maximal_marginal_relevance(
|
|
280
|
-
query_vector,
|
|
310
|
+
query_vector,
|
|
311
|
+
search_result_uuids_and_vectors,
|
|
312
|
+
config.mmr_lambda,
|
|
313
|
+
min_score=reranker_min_score,
|
|
281
314
|
)
|
|
282
315
|
elif config.reranker == NodeReranker.cross_encoder:
|
|
283
316
|
# use rrf as a preliminary reranker
|
|
284
|
-
rrf_result_uuids = rrf(search_result_uuids)
|
|
317
|
+
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
285
318
|
rrf_results = [node_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
286
319
|
|
|
287
320
|
summary_to_uuid_map = {node.summary: node.uuid for node in rrf_results}
|
|
288
321
|
|
|
289
322
|
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
|
290
|
-
reranked_uuids = [
|
|
323
|
+
reranked_uuids = [
|
|
324
|
+
summary_to_uuid_map[fact]
|
|
325
|
+
for fact, score in reranked_summaries
|
|
326
|
+
if score >= reranker_min_score
|
|
327
|
+
]
|
|
291
328
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
292
|
-
reranked_uuids = await episode_mentions_reranker(
|
|
329
|
+
reranked_uuids = await episode_mentions_reranker(
|
|
330
|
+
driver, search_result_uuids, min_score=reranker_min_score
|
|
331
|
+
)
|
|
293
332
|
elif config.reranker == NodeReranker.node_distance:
|
|
294
333
|
if center_node_uuid is None:
|
|
295
334
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
296
335
|
reranked_uuids = await node_distance_reranker(
|
|
297
|
-
driver,
|
|
336
|
+
driver,
|
|
337
|
+
rrf(search_result_uuids, min_score=reranker_min_score),
|
|
338
|
+
center_node_uuid,
|
|
339
|
+
min_score=reranker_min_score,
|
|
298
340
|
)
|
|
299
341
|
|
|
300
342
|
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
|
@@ -302,6 +344,54 @@ async def node_search(
|
|
|
302
344
|
return reranked_nodes[:limit]
|
|
303
345
|
|
|
304
346
|
|
|
347
|
+
async def episode_search(
|
|
348
|
+
driver: AsyncDriver,
|
|
349
|
+
cross_encoder: CrossEncoderClient,
|
|
350
|
+
query: str,
|
|
351
|
+
_query_vector: list[float],
|
|
352
|
+
group_ids: list[str] | None,
|
|
353
|
+
config: EpisodeSearchConfig | None,
|
|
354
|
+
search_filter: SearchFilters,
|
|
355
|
+
limit=DEFAULT_SEARCH_LIMIT,
|
|
356
|
+
reranker_min_score: float = 0,
|
|
357
|
+
) -> list[EpisodicNode]:
|
|
358
|
+
if config is None:
|
|
359
|
+
return []
|
|
360
|
+
|
|
361
|
+
search_results: list[list[EpisodicNode]] = list(
|
|
362
|
+
await semaphore_gather(
|
|
363
|
+
*[
|
|
364
|
+
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
|
|
365
|
+
]
|
|
366
|
+
)
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
|
|
370
|
+
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
|
|
371
|
+
|
|
372
|
+
reranked_uuids: list[str] = []
|
|
373
|
+
if config.reranker == EpisodeReranker.rrf:
|
|
374
|
+
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
375
|
+
|
|
376
|
+
elif config.reranker == EpisodeReranker.cross_encoder:
|
|
377
|
+
# use rrf as a preliminary reranker
|
|
378
|
+
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
379
|
+
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
380
|
+
|
|
381
|
+
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
|
|
382
|
+
|
|
383
|
+
reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys()))
|
|
384
|
+
reranked_uuids = [
|
|
385
|
+
content_to_uuid_map[content]
|
|
386
|
+
for content, score in reranked_contents
|
|
387
|
+
if score >= reranker_min_score
|
|
388
|
+
]
|
|
389
|
+
|
|
390
|
+
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
|
|
391
|
+
|
|
392
|
+
return reranked_episodes[:limit]
|
|
393
|
+
|
|
394
|
+
|
|
305
395
|
async def community_search(
|
|
306
396
|
driver: AsyncDriver,
|
|
307
397
|
cross_encoder: CrossEncoderClient,
|
|
@@ -309,8 +399,8 @@ async def community_search(
|
|
|
309
399
|
query_vector: list[float],
|
|
310
400
|
group_ids: list[str] | None,
|
|
311
401
|
config: CommunitySearchConfig | None,
|
|
312
|
-
bfs_origin_node_uuids: list[str] | None = None,
|
|
313
402
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
403
|
+
reranker_min_score: float = 0,
|
|
314
404
|
) -> list[CommunityNode]:
|
|
315
405
|
if config is None:
|
|
316
406
|
return []
|
|
@@ -333,7 +423,7 @@ async def community_search(
|
|
|
333
423
|
|
|
334
424
|
reranked_uuids: list[str] = []
|
|
335
425
|
if config.reranker == CommunityReranker.rrf:
|
|
336
|
-
reranked_uuids = rrf(search_result_uuids)
|
|
426
|
+
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
337
427
|
elif config.reranker == CommunityReranker.mmr:
|
|
338
428
|
search_result_uuids_and_vectors = [
|
|
339
429
|
(
|
|
@@ -344,14 +434,21 @@ async def community_search(
|
|
|
344
434
|
for community in result
|
|
345
435
|
]
|
|
346
436
|
reranked_uuids = maximal_marginal_relevance(
|
|
347
|
-
query_vector,
|
|
437
|
+
query_vector,
|
|
438
|
+
search_result_uuids_and_vectors,
|
|
439
|
+
config.mmr_lambda,
|
|
440
|
+
min_score=reranker_min_score,
|
|
348
441
|
)
|
|
349
442
|
elif config.reranker == CommunityReranker.cross_encoder:
|
|
350
443
|
summary_to_uuid_map = {
|
|
351
444
|
node.summary: node.uuid for result in search_results for node in result
|
|
352
445
|
}
|
|
353
446
|
reranked_summaries = await cross_encoder.rank(query, list(summary_to_uuid_map.keys()))
|
|
354
|
-
reranked_uuids = [
|
|
447
|
+
reranked_uuids = [
|
|
448
|
+
summary_to_uuid_map[fact]
|
|
449
|
+
for fact, score in reranked_summaries
|
|
450
|
+
if score >= reranker_min_score
|
|
451
|
+
]
|
|
355
452
|
|
|
356
453
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
|
357
454
|
|
|
@@ -19,7 +19,7 @@ from enum import Enum
|
|
|
19
19
|
from pydantic import BaseModel, Field
|
|
20
20
|
|
|
21
21
|
from graphiti_core.edges import EntityEdge
|
|
22
|
-
from graphiti_core.nodes import CommunityNode, EntityNode
|
|
22
|
+
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
|
|
23
23
|
from graphiti_core.search.search_utils import (
|
|
24
24
|
DEFAULT_MIN_SCORE,
|
|
25
25
|
DEFAULT_MMR_LAMBDA,
|
|
@@ -41,6 +41,10 @@ class NodeSearchMethod(Enum):
|
|
|
41
41
|
bfs = 'breadth_first_search'
|
|
42
42
|
|
|
43
43
|
|
|
44
|
+
class EpisodeSearchMethod(Enum):
|
|
45
|
+
bm25 = 'bm25'
|
|
46
|
+
|
|
47
|
+
|
|
44
48
|
class CommunitySearchMethod(Enum):
|
|
45
49
|
cosine_similarity = 'cosine_similarity'
|
|
46
50
|
bm25 = 'bm25'
|
|
@@ -62,6 +66,11 @@ class NodeReranker(Enum):
|
|
|
62
66
|
cross_encoder = 'cross_encoder'
|
|
63
67
|
|
|
64
68
|
|
|
69
|
+
class EpisodeReranker(Enum):
|
|
70
|
+
rrf = 'reciprocal_rank_fusion'
|
|
71
|
+
cross_encoder = 'cross_encoder'
|
|
72
|
+
|
|
73
|
+
|
|
65
74
|
class CommunityReranker(Enum):
|
|
66
75
|
rrf = 'reciprocal_rank_fusion'
|
|
67
76
|
mmr = 'mmr'
|
|
@@ -84,6 +93,14 @@ class NodeSearchConfig(BaseModel):
|
|
|
84
93
|
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
85
94
|
|
|
86
95
|
|
|
96
|
+
class EpisodeSearchConfig(BaseModel):
|
|
97
|
+
search_methods: list[EpisodeSearchMethod]
|
|
98
|
+
reranker: EpisodeReranker = Field(default=EpisodeReranker.rrf)
|
|
99
|
+
sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
|
|
100
|
+
mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
|
|
101
|
+
bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
|
|
102
|
+
|
|
103
|
+
|
|
87
104
|
class CommunitySearchConfig(BaseModel):
|
|
88
105
|
search_methods: list[CommunitySearchMethod]
|
|
89
106
|
reranker: CommunityReranker = Field(default=CommunityReranker.rrf)
|
|
@@ -95,11 +112,14 @@ class CommunitySearchConfig(BaseModel):
|
|
|
95
112
|
class SearchConfig(BaseModel):
|
|
96
113
|
edge_config: EdgeSearchConfig | None = Field(default=None)
|
|
97
114
|
node_config: NodeSearchConfig | None = Field(default=None)
|
|
115
|
+
episode_config: EpisodeSearchConfig | None = Field(default=None)
|
|
98
116
|
community_config: CommunitySearchConfig | None = Field(default=None)
|
|
99
117
|
limit: int = Field(default=DEFAULT_SEARCH_LIMIT)
|
|
118
|
+
reranker_min_score: float = Field(default=0)
|
|
100
119
|
|
|
101
120
|
|
|
102
121
|
class SearchResults(BaseModel):
|
|
103
122
|
edges: list[EntityEdge]
|
|
104
123
|
nodes: list[EntityNode]
|
|
124
|
+
episodes: list[EpisodicNode]
|
|
105
125
|
communities: list[CommunityNode]
|
|
@@ -21,6 +21,9 @@ from graphiti_core.search.search_config import (
|
|
|
21
21
|
EdgeReranker,
|
|
22
22
|
EdgeSearchConfig,
|
|
23
23
|
EdgeSearchMethod,
|
|
24
|
+
EpisodeReranker,
|
|
25
|
+
EpisodeSearchConfig,
|
|
26
|
+
EpisodeSearchMethod,
|
|
24
27
|
NodeReranker,
|
|
25
28
|
NodeSearchConfig,
|
|
26
29
|
NodeSearchMethod,
|
|
@@ -37,6 +40,12 @@ COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
|
|
|
37
40
|
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
|
|
38
41
|
reranker=NodeReranker.rrf,
|
|
39
42
|
),
|
|
43
|
+
episode_config=EpisodeSearchConfig(
|
|
44
|
+
search_methods=[
|
|
45
|
+
EpisodeSearchMethod.bm25,
|
|
46
|
+
],
|
|
47
|
+
reranker=EpisodeReranker.rrf,
|
|
48
|
+
),
|
|
40
49
|
community_config=CommunitySearchConfig(
|
|
41
50
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
|
42
51
|
reranker=CommunityReranker.rrf,
|
|
@@ -55,6 +64,12 @@ COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
|
|
|
55
64
|
reranker=NodeReranker.mmr,
|
|
56
65
|
mmr_lambda=1,
|
|
57
66
|
),
|
|
67
|
+
episode_config=EpisodeSearchConfig(
|
|
68
|
+
search_methods=[
|
|
69
|
+
EpisodeSearchMethod.bm25,
|
|
70
|
+
],
|
|
71
|
+
reranker=EpisodeReranker.rrf,
|
|
72
|
+
),
|
|
58
73
|
community_config=CommunitySearchConfig(
|
|
59
74
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
|
60
75
|
reranker=CommunityReranker.mmr,
|
|
@@ -80,6 +95,12 @@ COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
|
|
|
80
95
|
],
|
|
81
96
|
reranker=NodeReranker.cross_encoder,
|
|
82
97
|
),
|
|
98
|
+
episode_config=EpisodeSearchConfig(
|
|
99
|
+
search_methods=[
|
|
100
|
+
EpisodeSearchMethod.bm25,
|
|
101
|
+
],
|
|
102
|
+
reranker=EpisodeReranker.cross_encoder,
|
|
103
|
+
),
|
|
83
104
|
community_config=CommunitySearchConfig(
|
|
84
105
|
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
|
|
85
106
|
reranker=CommunityReranker.cross_encoder,
|
|
@@ -38,6 +38,13 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
|
|
|
38
38
|
entity_json = [
|
|
39
39
|
{'entity_name': node.name, 'summary': node.summary} for node in search_results.nodes
|
|
40
40
|
]
|
|
41
|
+
episode_json = [
|
|
42
|
+
{
|
|
43
|
+
'source_description': episode.source_description,
|
|
44
|
+
'content': episode.content,
|
|
45
|
+
}
|
|
46
|
+
for episode in search_results.episodes
|
|
47
|
+
]
|
|
41
48
|
community_json = [
|
|
42
49
|
{'community_name': community.name, 'summary': community.summary}
|
|
43
50
|
for community in search_results.communities
|
|
@@ -55,6 +62,9 @@ def search_results_to_context_string(search_results: SearchResults) -> str:
|
|
|
55
62
|
<ENTITIES>
|
|
56
63
|
{json.dumps(entity_json, indent=12)}
|
|
57
64
|
</ENTITIES>
|
|
65
|
+
<EPISODES>
|
|
66
|
+
{json.dumps(episode_json, indent=12)}
|
|
67
|
+
</EPISODES>
|
|
58
68
|
<COMMUNITIES>
|
|
59
69
|
{json.dumps(community_json, indent=12)}
|
|
60
70
|
</COMMUNITIES>
|
|
@@ -37,6 +37,7 @@ from graphiti_core.nodes import (
|
|
|
37
37
|
EpisodicNode,
|
|
38
38
|
get_community_node_from_record,
|
|
39
39
|
get_entity_node_from_record,
|
|
40
|
+
get_episodic_node_from_record,
|
|
40
41
|
)
|
|
41
42
|
from graphiti_core.search.search_filters import (
|
|
42
43
|
SearchFilters,
|
|
@@ -229,8 +230,8 @@ async def edge_similarity_search(
|
|
|
229
230
|
|
|
230
231
|
query: LiteralString = (
|
|
231
232
|
"""
|
|
232
|
-
|
|
233
|
-
|
|
233
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
234
|
+
"""
|
|
234
235
|
+ group_filter_query
|
|
235
236
|
+ filter_query
|
|
236
237
|
+ """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
|
|
@@ -475,6 +476,48 @@ async def node_bfs_search(
|
|
|
475
476
|
return nodes
|
|
476
477
|
|
|
477
478
|
|
|
479
|
+
async def episode_fulltext_search(
|
|
480
|
+
driver: AsyncDriver,
|
|
481
|
+
query: str,
|
|
482
|
+
_search_filter: SearchFilters,
|
|
483
|
+
group_ids: list[str] | None = None,
|
|
484
|
+
limit=RELEVANT_SCHEMA_LIMIT,
|
|
485
|
+
) -> list[EpisodicNode]:
|
|
486
|
+
# BM25 search to get top episodes
|
|
487
|
+
fuzzy_query = fulltext_query(query, group_ids)
|
|
488
|
+
if fuzzy_query == '':
|
|
489
|
+
return []
|
|
490
|
+
|
|
491
|
+
records, _, _ = await driver.execute_query(
|
|
492
|
+
"""
|
|
493
|
+
CALL db.index.fulltext.queryNodes("episode_content", $query, {limit: $limit})
|
|
494
|
+
YIELD node AS episode, score
|
|
495
|
+
MATCH (e:Episodic)
|
|
496
|
+
WHERE e.uuid = episode.uuid
|
|
497
|
+
RETURN
|
|
498
|
+
e.content AS content,
|
|
499
|
+
e.created_at AS created_at,
|
|
500
|
+
e.valid_at AS valid_at,
|
|
501
|
+
e.uuid AS uuid,
|
|
502
|
+
e.name AS name,
|
|
503
|
+
e.group_id AS group_id,
|
|
504
|
+
e.source_description AS source_description,
|
|
505
|
+
e.source AS source,
|
|
506
|
+
e.entity_edges AS entity_edges
|
|
507
|
+
ORDER BY score DESC
|
|
508
|
+
LIMIT $limit
|
|
509
|
+
""",
|
|
510
|
+
query=fuzzy_query,
|
|
511
|
+
group_ids=group_ids,
|
|
512
|
+
limit=limit,
|
|
513
|
+
database_=DEFAULT_DATABASE,
|
|
514
|
+
routing_='r',
|
|
515
|
+
)
|
|
516
|
+
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
517
|
+
|
|
518
|
+
return episodes
|
|
519
|
+
|
|
520
|
+
|
|
478
521
|
async def community_fulltext_search(
|
|
479
522
|
driver: AsyncDriver,
|
|
480
523
|
query: str,
|
|
@@ -718,7 +761,7 @@ async def get_relevant_edges(
|
|
|
718
761
|
|
|
719
762
|
|
|
720
763
|
# takes in a list of rankings of uuids
|
|
721
|
-
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
|
764
|
+
def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
|
|
722
765
|
scores: dict[str, float] = defaultdict(float)
|
|
723
766
|
for result in results:
|
|
724
767
|
for i, uuid in enumerate(result):
|
|
@@ -729,11 +772,14 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
|
|
|
729
772
|
|
|
730
773
|
sorted_uuids = [term[0] for term in scored_uuids]
|
|
731
774
|
|
|
732
|
-
return sorted_uuids
|
|
775
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
733
776
|
|
|
734
777
|
|
|
735
778
|
async def node_distance_reranker(
|
|
736
|
-
driver: AsyncDriver, node_uuids: list[str], center_node_uuid: str
|
|
779
|
+
driver: AsyncDriver,
|
|
780
|
+
node_uuids: list[str],
|
|
781
|
+
center_node_uuid: str,
|
|
782
|
+
min_score: float = 0,
|
|
737
783
|
) -> list[str]:
|
|
738
784
|
# filter out node_uuid center node node uuid
|
|
739
785
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
@@ -767,12 +813,15 @@ async def node_distance_reranker(
|
|
|
767
813
|
|
|
768
814
|
# add back in filtered center uuid if it was filtered out
|
|
769
815
|
if center_node_uuid in node_uuids:
|
|
816
|
+
scores[center_node_uuid] = 0.1
|
|
770
817
|
filtered_uuids = [center_node_uuid] + filtered_uuids
|
|
771
818
|
|
|
772
|
-
return filtered_uuids
|
|
819
|
+
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
|
|
773
820
|
|
|
774
821
|
|
|
775
|
-
async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]:
|
|
822
|
+
async def episode_mentions_reranker(
|
|
823
|
+
driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
|
|
824
|
+
) -> list[str]:
|
|
776
825
|
# use rrf as a preliminary ranker
|
|
777
826
|
sorted_uuids = rrf(node_uuids)
|
|
778
827
|
scores: dict[str, float] = {}
|
|
@@ -796,13 +845,14 @@ async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[s
|
|
|
796
845
|
# rerank on shortest distance
|
|
797
846
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
798
847
|
|
|
799
|
-
return sorted_uuids
|
|
848
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
800
849
|
|
|
801
850
|
|
|
802
851
|
def maximal_marginal_relevance(
|
|
803
852
|
query_vector: list[float],
|
|
804
853
|
candidates: list[tuple[str, list[float]]],
|
|
805
854
|
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
|
855
|
+
min_score: float = 0,
|
|
806
856
|
):
|
|
807
857
|
candidates_with_mmr: list[tuple[str, float]] = []
|
|
808
858
|
for candidate in candidates:
|
|
@@ -812,4 +862,6 @@ def maximal_marginal_relevance(
|
|
|
812
862
|
|
|
813
863
|
candidates_with_mmr.sort(reverse=True, key=lambda c: c[1])
|
|
814
864
|
|
|
815
|
-
return list(set([candidate[0] for candidate in candidates_with_mmr]))
|
|
865
|
+
return list(
|
|
866
|
+
set([candidate[0] for candidate in candidates_with_mmr if candidate[1] >= min_score])
|
|
867
|
+
)
|
|
@@ -71,6 +71,8 @@ async def build_indices_and_constraints(driver: AsyncDriver, delete_existing: bo
|
|
|
71
71
|
]
|
|
72
72
|
|
|
73
73
|
fulltext_indices: list[LiteralString] = [
|
|
74
|
+
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
|
|
75
|
+
FOR (e:Episodic) ON EACH [e.content, e.source, e.group_id]""",
|
|
74
76
|
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
|
|
75
77
|
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
|
|
76
78
|
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "graphiti-core"
|
|
3
3
|
description = "A temporal graph building library"
|
|
4
|
-
version = "0.9.6"
|
|
4
|
+
version = "0.10.0"
|
|
5
5
|
authors = [
|
|
6
6
|
{ "name" = "Paul Paliychuk", "email" = "paul@getzep.com" },
|
|
7
7
|
{ "name" = "Preston Rasmussen", "email" = "preston@getzep.com" },
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/cross_encoder/bge_reranker_client.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/community_operations.py
RENAMED
|
File without changes
|
{graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/edge_operations.py
RENAMED
|
File without changes
|
{graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/node_operations.py
RENAMED
|
File without changes
|
{graphiti_core-0.9.6 → graphiti_core-0.10.0}/graphiti_core/utils/maintenance/temporal_operations.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|