graphiti-core 0.3.6__tar.gz → 0.3.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of graphiti-core might be problematic.
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/PKG-INFO +2 -2
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/edges.py +4 -5
- graphiti_core-0.3.8/graphiti_core/embedder/__init__.py +4 -0
- graphiti_core-0.3.6/graphiti_core/helpers.py → graphiti_core-0.3.8/graphiti_core/embedder/client.py +15 -4
- graphiti_core-0.3.8/graphiti_core/embedder/openai.py +48 -0
- graphiti_core-0.3.8/graphiti_core/embedder/voyage.py +47 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/graphiti.py +24 -18
- graphiti_core-0.3.8/graphiti_core/helpers.py +53 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/anthropic_client.py +0 -5
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/client.py +0 -4
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/config.py +0 -1
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/groq_client.py +0 -5
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/openai_client.py +0 -6
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/utils.py +3 -7
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/nodes.py +7 -9
- graphiti_core-0.3.8/graphiti_core/prompts/eval.py +90 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/lib.py +6 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/search.py +54 -49
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/search_utils.py +40 -146
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/community_operations.py +2 -1
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/graph_data_operations.py +17 -31
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/pyproject.toml +2 -2
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/LICENSE +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/README.md +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/__init__.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/errors.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/__init__.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/errors.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/__init__.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/dedupe_edges.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/dedupe_nodes.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/extract_edge_dates.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/extract_edges.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/extract_nodes.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/invalidate_edges.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/models.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/summarize_nodes.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/py.typed +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/__init__.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/search_config.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/search_config_recipes.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/__init__.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/bulk_utils.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/__init__.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/node_operations.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
- {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/utils.py +0 -0
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: graphiti-core
-Version: 0.3.6
+Version: 0.3.8
 Summary: A temporal graph building library
 License: Apache-2.0
 Author: Paul Paliychuk

@@ -14,7 +14,7 @@ Classifier: Programming Language :: Python :: 3.12
 Requires-Dist: diskcache (>=5.6.3,<6.0.0)
 Requires-Dist: neo4j (>=5.23.0,<6.0.0)
 Requires-Dist: numpy (>=1.0.0)
-Requires-Dist: openai (>=1.…
+Requires-Dist: openai (>=1.50.2,<2.0.0)
 Requires-Dist: pydantic (>=2.8.2,<3.0.0)
 Requires-Dist: tenacity (<9.0.0)
 Description-Content-Type: text/markdown
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/edges.py

@@ -24,9 +24,9 @@ from uuid import uuid4
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field

+from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 from graphiti_core.helpers import parse_db_date
-from graphiti_core.llm_client.config import EMBEDDING_DIM
 from graphiti_core.nodes import Node

 logger = logging.getLogger(__name__)

@@ -171,17 +171,16 @@ class EntityEdge(Edge):
         default=None, description='datetime of when the fact stopped being true'
     )

-    async def generate_embedding(self, embedder…
+    async def generate_embedding(self, embedder: EmbedderClient):
         start = time()

         text = self.fact.replace('\n', ' ')
-        …
-        self.fact_embedding = embedding[:EMBEDDING_DIM]
+        self.fact_embedding = await embedder.create(input=[text])

         end = time()
         logger.info(f'embedded {text} in {end - start} ms')

-        return
+        return self.fact_embedding

     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(
graphiti_core-0.3.6/graphiti_core/helpers.py → graphiti_core-0.3.8/graphiti_core/embedder/client.py (RENAMED)

@@ -14,10 +14,21 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-from datetime import datetime
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Literal

-from neo4j import time as neo4j_time
+from pydantic import BaseModel, Field

+EMBEDDING_DIM = 1024

-def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
-    return neo_date.to_native() if neo_date else None
+
+class EmbedderConfig(BaseModel):
+    embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
+
+
+class EmbedderClient(ABC):
+    @abstractmethod
+    async def create(
+        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+    ) -> list[float]:
+        pass
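The new EmbedderClient ABC makes the embedding backend pluggable. A minimal sketch of a custom implementation against the interface shown above (the fixed-vector backend is purely illustrative, not part of the package):

```python
from typing import Iterable, List

from graphiti_core.embedder import EmbedderClient


class ConstantEmbedder(EmbedderClient):
    """Illustrative embedder returning a fixed vector; useful as a test stub."""

    async def create(
        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        # EmbedderConfig pins embedding_dim to 1024, so implementations are
        # expected to return 1024-dimensional vectors.
        return [0.0] * 1024
```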
graphiti_core-0.3.8/graphiti_core/embedder/openai.py (NEW)

@@ -0,0 +1,48 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Iterable, List
+
+from openai import AsyncOpenAI
+from openai.types import EmbeddingModel
+
+from .client import EmbedderClient, EmbedderConfig
+
+DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
+
+
+class OpenAIEmbedderConfig(EmbedderConfig):
+    embedding_model: EmbeddingModel | str = DEFAULT_EMBEDDING_MODEL
+    api_key: str | None = None
+    base_url: str | None = None
+
+
+class OpenAIEmbedder(EmbedderClient):
+    """
+    OpenAI Embedder Client
+    """
+
+    def __init__(self, config: OpenAIEmbedderConfig | None = None):
+        if config is None:
+            config = OpenAIEmbedderConfig()
+        self.config = config
+        self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
+
+    async def create(
+        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+    ) -> list[float]:
+        result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
+        return result.data[0].embedding[: self.config.embedding_dim]
graphiti_core-0.3.8/graphiti_core/embedder/voyage.py (NEW)

@@ -0,0 +1,47 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from typing import Iterable, List
+
+import voyageai  # type: ignore
+from pydantic import Field
+
+from .client import EmbedderClient, EmbedderConfig
+
+DEFAULT_EMBEDDING_MODEL = 'voyage-3'
+
+
+class VoyageAIEmbedderConfig(EmbedderConfig):
+    embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
+    api_key: str | None = None
+
+
+class VoyageAIEmbedder(EmbedderClient):
+    """
+    VoyageAI Embedder Client
+    """
+
+    def __init__(self, config: VoyageAIEmbedderConfig | None = None):
+        if config is None:
+            config = VoyageAIEmbedderConfig()
+        self.config = config
+        self.client = voyageai.AsyncClient(api_key=config.api_key)
+
+    async def create(
+        self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
+    ) -> list[float]:
+        result = await self.client.embed(input, model=self.config.embedding_model)
+        return result.embeddings[0][: self.config.embedding_dim]
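Both concrete clients return only the first embedding, truncated to the configured 1024 dimensions. A construction sketch, assuming the modules shown above (the API key value is a placeholder; AsyncOpenAI falls back to the OPENAI_API_KEY environment variable when none is given):

```python
from graphiti_core.embedder.openai import OpenAIEmbedder
from graphiti_core.embedder.voyage import VoyageAIEmbedder, VoyageAIEmbedderConfig

# Default model (text-embedding-3-small), key taken from the environment.
openai_embedder = OpenAIEmbedder()

# Fully explicit configuration; the key is a placeholder.
voyage_embedder = VoyageAIEmbedder(
    VoyageAIEmbedderConfig(embedding_model='voyage-3', api_key='pa-...')
)
```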
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/graphiti.py

@@ -23,6 +23,7 @@ from dotenv import load_dotenv
 from neo4j import AsyncGraphDatabase

 from graphiti_core.edges import EntityEdge, EpisodicEdge
+from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import SearchConfig, search

@@ -83,6 +84,7 @@ class Graphiti:
         user: str,
         password: str,
         llm_client: LLMClient | None = None,
+        embedder: EmbedderClient | None = None,
         store_raw_episode_content: bool = True,
     ):
         """

@@ -128,6 +130,10 @@ class Graphiti:
             self.llm_client = llm_client
         else:
             self.llm_client = OpenAIClient()
+        if embedder:
+            self.embedder = embedder
+        else:
+            self.embedder = OpenAIEmbedder()

     async def close(self):
         """

@@ -161,7 +167,7 @@ class Graphiti:
         """
         await self.driver.close()

-    async def build_indices_and_constraints(self):
+    async def build_indices_and_constraints(self, delete_existing: bool = False):
         """
         Build indices and constraints in the Neo4j database.

@@ -171,6 +177,9 @@ class Graphiti:
         Parameters
         ----------
         self
+        delete_existing : bool, optional
+            Whether to clear existing indices before creating new ones.
+

         Returns
         -------

@@ -191,7 +200,7 @@ class Graphiti:
         Caution: Running this method on a large existing database may take some time
         and could impact database performance during execution.
         """
-        await build_indices_and_constraints(self.driver)
+        await build_indices_and_constraints(self.driver, delete_existing)

     async def retrieve_episodes(
         self,

@@ -287,7 +296,6 @@ class Graphiti:
             start = time()

             entity_edges: list[EntityEdge] = []
-            embedder = self.llm_client.get_embedder()
             now = datetime.now()

             previous_episodes = await self.retrieve_episodes(

@@ -315,7 +323,7 @@ class Graphiti:
             # Calculate Embeddings

             await asyncio.gather(
-                *[node.generate_name_embedding(embedder) for node in extracted_nodes]
+                *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
             )

             # Resolve extracted nodes with nodes already in the graph and extract facts

@@ -343,7 +351,7 @@ class Graphiti:
             # calculate embeddings
             await asyncio.gather(
                 *[
-                    edge.generate_embedding(embedder)
+                    edge.generate_embedding(self.embedder)
                     for edge in extracted_edges_with_resolved_pointers
                 ]
             )

@@ -436,7 +444,7 @@ class Graphiti:
             if update_communities:
                 await asyncio.gather(
                     *[
-                        update_community(self.driver, self.llm_client, embedder, node)
+                        update_community(self.driver, self.llm_client, self.embedder, node)
                         for node in nodes
                     ]
                 )

@@ -485,7 +493,6 @@ class Graphiti:
         """
         try:
             start = time()
-            embedder = self.llm_client.get_embedder()
             now = datetime.now()

             episodes = [

@@ -517,8 +524,8 @@ class Graphiti:

             # Generate embeddings
             await asyncio.gather(
-                *[node.generate_name_embedding(embedder) for node in extracted_nodes],
-                *[edge.generate_embedding(embedder) for edge in extracted_edges],
+                *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
+                *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
             )

             # Dedupe extracted nodes, compress extracted edges

@@ -561,14 +568,14 @@ class Graphiti:
             raise e

     async def build_communities(self):
-        embedder = self.llm_client.get_embedder()
-
         # Clear existing communities
         await remove_communities(self.driver)

         community_nodes, community_edges = await build_communities(self.driver, self.llm_client)

-        await asyncio.gather(…
+        await asyncio.gather(
+            *[node.generate_name_embedding(self.embedder) for node in community_nodes]
+        )

         await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
         await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])

@@ -619,7 +626,7 @@ class Graphiti:
         edges = (
             await search(
                 self.driver,
-                self.llm_client.get_embedder(),
+                self.embedder,
                 query,
                 group_ids,
                 search_config,

@@ -636,9 +643,7 @@ class Graphiti:
         group_ids: list[str] | None = None,
         center_node_uuid: str | None = None,
     ) -> SearchResults:
-        return await search(
-            self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
-        )
+        return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)

     async def get_nodes_by_query(
         self,

@@ -683,14 +688,15 @@ class Graphiti:
         to each individual search method before results are combined and deduplicated.
         If not specified, a default limit (defined in the search functions) will be used.
         """
-        embedder = self.llm_client.get_embedder()
         search_config = (
             NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
         )
         search_config.limit = limit

         nodes = (
-            await search(…
+            await search(
+                self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
+            )
         ).nodes
         return nodes
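The net effect of the graphiti.py changes: the embedder is injected once at construction and replaces every former llm_client.get_embedder() call site. A wiring sketch, assuming the constructor shown above and the package-root Graphiti export from the project README (the Neo4j URI and credentials are placeholders):

```python
import asyncio

from graphiti_core import Graphiti
from graphiti_core.embedder.voyage import VoyageAIEmbedder


async def main() -> None:
    graphiti = Graphiti(
        'bolt://localhost:7687',      # placeholder URI
        'neo4j',                      # placeholder user
        'password',                   # placeholder password
        embedder=VoyageAIEmbedder(),  # omitted -> defaults to OpenAIEmbedder()
    )
    # New in 0.3.8: optionally drop existing indices before rebuilding them.
    await graphiti.build_indices_and_constraints(delete_existing=True)
    await graphiti.close()


asyncio.run(main())
```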
graphiti_core-0.3.8/graphiti_core/helpers.py (NEW)

@@ -0,0 +1,53 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from datetime import datetime
+
+from neo4j import time as neo4j_time
+
+
+def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
+    return neo_date.to_native() if neo_date else None
+
+
+def lucene_sanitize(query: str) -> str:
+    # Escape special characters from a query before passing into Lucene
+    # + - && || ! ( ) { } [ ] ^ " ~ * ? : \
+    escape_map = str.maketrans(
+        {
+            '+': r'\+',
+            '-': r'\-',
+            '&': r'\&',
+            '|': r'\|',
+            '!': r'\!',
+            '(': r'\(',
+            ')': r'\)',
+            '{': r'\{',
+            '}': r'\}',
+            '[': r'\[',
+            ']': r'\]',
+            '^': r'\^',
+            '"': r'\"',
+            '~': r'\~',
+            '*': r'\*',
+            '?': r'\?',
+            ':': r'\:',
+            '\\': r'\\',
+        }
+    )
+
+    sanitized = query.translate(escape_map)
+    return sanitized
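A quick sanity check of lucene_sanitize, following the escape map above:

```python
from graphiti_core.helpers import lucene_sanitize

# Each Lucene special character gains a backslash escape.
print(lucene_sanitize('fact: (A + B)'))
# -> fact\: \(A \+ B\)
```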
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/anthropic_client.py

@@ -20,7 +20,6 @@ import typing

 import anthropic
 from anthropic import AsyncAnthropic
-from openai import AsyncOpenAI

 from ..prompts.models import Message
 from .client import LLMClient

@@ -47,10 +46,6 @@ class AnthropicClient(LLMClient):
             max_retries=1,
         )

-    def get_embedder(self) -> typing.Any:
-        openai_client = AsyncOpenAI()
-        return openai_client.embeddings
-
     async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
         system_message = messages[0]
         user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/client.py

@@ -55,10 +55,6 @@ class LLMClient(ABC):
         self.cache_enabled = cache
         self.cache_dir = Cache(DEFAULT_CACHE_DIR)  # Create a cache directory

-    @abstractmethod
-    def get_embedder(self) -> typing.Any:
-        pass
-
     @retry(
         stop=stop_after_attempt(4),
         wait=wait_random_exponential(multiplier=10, min=5, max=120),
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/groq_client.py

@@ -21,7 +21,6 @@ import typing
 import groq
 from groq import AsyncGroq
 from groq.types.chat import ChatCompletionMessageParam
-from openai import AsyncOpenAI

 from ..prompts.models import Message
 from .client import LLMClient

@@ -44,10 +43,6 @@ class GroqClient(LLMClient):

         self.client = AsyncGroq(api_key=config.api_key)

-    def get_embedder(self) -> typing.Any:
-        openai_client = AsyncOpenAI()
-        return openai_client.embeddings
-
     async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
         msgs: list[ChatCompletionMessageParam] = []
         for m in messages:
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/openai_client.py

@@ -49,9 +49,6 @@ class OpenAIClient(LLMClient):
     __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
         Initializes the OpenAIClient with the provided configuration, cache setting, and client.

-    get_embedder() -> typing.Any:
-        Returns the embedder from the OpenAI client.
-
     _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
         Generates a response from the language model based on the provided messages.
     """

@@ -78,9 +75,6 @@ class OpenAIClient(LLMClient):
         else:
             self.client = client

-    def get_embedder(self) -> typing.Any:
-        return self.client.embeddings
-
     async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
         openai_messages: list[ChatCompletionMessageParam] = []
         for m in messages:
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/utils.py

@@ -15,22 +15,18 @@ limitations under the License.
 """

 import logging
-import typing
 from time import time

-from graphiti_core.llm_client.config import EMBEDDING_DIM
+from graphiti_core.embedder.client import EmbedderClient

 logger = logging.getLogger(__name__)


-async def generate_embedding(
-    embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
-):
+async def generate_embedding(embedder: EmbedderClient, text: str):
     start = time()

     text = text.replace('\n', ' ')
-    embedding = …
-    embedding = embedding[:EMBEDDING_DIM]
+    embedding = await embedder.create(input=[text])

     end = time()
     logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
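The helper now accepts any EmbedderClient and no longer takes a model argument. A usage sketch, assuming an embedder instance as above:

```python
from graphiti_core.embedder.openai import OpenAIEmbedder
from graphiti_core.llm_client.utils import generate_embedding


async def embed_text() -> list[float]:
    # The model choice now lives on the embedder's config, not on this call.
    return await generate_embedding(OpenAIEmbedder(), 'temporal graphs')
```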
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/nodes.py

@@ -25,8 +25,8 @@ from uuid import uuid4
 from neo4j import AsyncDriver
 from pydantic import BaseModel, Field

+from graphiti_core.embedder import EmbedderClient
 from graphiti_core.errors import NodeNotFoundError
-from graphiti_core.llm_client.config import EMBEDDING_DIM

 logger = logging.getLogger(__name__)

@@ -212,15 +212,14 @@ class EntityNode(Node):
     name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
     summary: str = Field(description='regional summary of surrounding edges', default_factory=str)

-    async def generate_name_embedding(self, embedder…
+    async def generate_name_embedding(self, embedder: EmbedderClient):
         start = time()
         text = self.name.replace('\n', ' ')
-        …
-        self.name_embedding = embedding[:EMBEDDING_DIM]
+        self.name_embedding = await embedder.create(input=[text])
         end = time()
         logger.info(f'embedded {text} in {end - start} ms')

-        return
+        return self.name_embedding

     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(

@@ -323,15 +322,14 @@ class CommunityNode(Node):

         return result

-    async def generate_name_embedding(self, embedder…
+    async def generate_name_embedding(self, embedder: EmbedderClient):
         start = time()
         text = self.name.replace('\n', ' ')
-        …
-        self.name_embedding = embedding[:EMBEDDING_DIM]
+        self.name_embedding = await embedder.create(input=[text])
         end = time()
         logger.info(f'embedded {text} in {end - start} ms')

-        return
+        return self.name_embedding

     @classmethod
     async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
graphiti_core-0.3.8/graphiti_core/prompts/eval.py (NEW)

@@ -0,0 +1,90 @@
+"""
+Copyright 2024, Zep Software, Inc.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+import json
+from typing import Any, Protocol, TypedDict
+
+from .models import Message, PromptFunction, PromptVersion
+
+
+class Prompt(Protocol):
+    qa_prompt: PromptVersion
+    eval_prompt: PromptVersion
+
+
+class Versions(TypedDict):
+    qa_prompt: PromptFunction
+    eval_prompt: PromptFunction
+
+
+def qa_prompt(context: dict[str, Any]) -> list[Message]:
+    sys_prompt = """You are Alice and should respond to all questions from the first person perspective of Alice"""
+
+    user_prompt = f"""
+    Your task is to briefly answer the question in the way that you think Alice would answer the question.
+    You are given the following entity summaries and facts to help you determine the answer to your question.
+    <ENTITY_SUMMARIES>
+    {json.dumps(context['entity_summaries'])}
+    </ENTITY_SUMMARIES>
+    <FACTS>
+    {json.dumps(context['facts'])}
+    </FACTS>
+    <QUESTION>
+    {context['query']}
+    </QUESTION>
+    respond with a JSON object in the following format:
+    {{
+        "ANSWER": "how Alice would answer the question"
+    }}
+    """
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
+
+
+def eval_prompt(context: dict[str, Any]) -> list[Message]:
+    sys_prompt = (
+        """You are a judge that determines if answers to questions match a gold standard answer"""
+    )
+
+    user_prompt = f"""
+    Given the QUESTION and the gold standard ANSWER determine if the RESPONSE to the question is correct or incorrect.
+    Although the RESPONSE may be more verbose, mark it as correct as long as it references the same topic
+    as the gold standard ANSWER. Also include your reasoning for the grade.
+    <QUESTION>
+    {context['query']}
+    </QUESTION>
+    <ANSWER>
+    {context['answer']}
+    </ANSWER>
+    <RESPONSE>
+    {context['response']}
+    </RESPONSE>
+
+    respond with a JSON object in the following format:
+    {{
+        "is_correct": "boolean if the answer is correct or incorrect"
+        "reasoning": "why you determined the response was correct or incorrect"
+    }}
+    """
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
+
+
+versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
{graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/lib.py

@@ -34,6 +34,9 @@ from .dedupe_nodes import (
 from .dedupe_nodes import (
     versions as dedupe_nodes_versions,
 )
+from .eval import Prompt as EvalPrompt
+from .eval import Versions as EvalVersions
+from .eval import versions as eval_versions
 from .extract_edge_dates import (
     Prompt as ExtractEdgeDatesPrompt,
 )

@@ -84,6 +87,7 @@ class PromptLibrary(Protocol):
     invalidate_edges: InvalidateEdgesPrompt
     extract_edge_dates: ExtractEdgeDatesPrompt
     summarize_nodes: SummarizeNodesPrompt
+    eval: EvalPrompt


 class PromptLibraryImpl(TypedDict):

@@ -94,6 +98,7 @@ class PromptLibraryImpl(TypedDict):
     invalidate_edges: InvalidateEdgesVersions
     extract_edge_dates: ExtractEdgeDatesVersions
     summarize_nodes: SummarizeNodesVersions
+    eval: EvalVersions


 class VersionWrapper:

@@ -124,5 +129,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
     'invalidate_edges': invalidate_edges_versions,
     'extract_edge_dates': extract_edge_dates_versions,
     'summarize_nodes': summarize_nodes_versions,
+    'eval': eval_versions,
 }
 prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL)  # type: ignore[assignment]
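With eval registered in the library, the new prompts should be reachable through prompt_library like the existing prompt groups. A rendering sketch; the context values are made-up examples matching the keys the eval.py f-strings expect:

```python
from graphiti_core.prompts.lib import prompt_library

messages = prompt_library.eval.qa_prompt(
    {
        'entity_summaries': [{'name': 'Alice', 'summary': 'Enjoys hiking'}],
        'facts': ['Alice hiked Mount Rainier in July'],
        'query': 'What did you do in July?',
    }
)
for message in messages:
    print(f'{message.role}: {message.content[:60]}')
```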