graphiti-core 0.3.6__tar.gz → 0.3.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the registry's package page for more details.

Files changed (48)
  1. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/PKG-INFO +2 -2
  2. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/edges.py +4 -5
  3. graphiti_core-0.3.8/graphiti_core/embedder/__init__.py +4 -0
  4. graphiti_core-0.3.6/graphiti_core/helpers.py → graphiti_core-0.3.8/graphiti_core/embedder/client.py +15 -4
  5. graphiti_core-0.3.8/graphiti_core/embedder/openai.py +48 -0
  6. graphiti_core-0.3.8/graphiti_core/embedder/voyage.py +47 -0
  7. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/graphiti.py +24 -18
  8. graphiti_core-0.3.8/graphiti_core/helpers.py +53 -0
  9. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/anthropic_client.py +0 -5
  10. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/client.py +0 -4
  11. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/config.py +0 -1
  12. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/groq_client.py +0 -5
  13. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/openai_client.py +0 -6
  14. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/utils.py +3 -7
  15. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/nodes.py +7 -9
  16. graphiti_core-0.3.8/graphiti_core/prompts/eval.py +90 -0
  17. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/lib.py +6 -0
  18. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/search.py +54 -49
  19. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/search_utils.py +40 -146
  20. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/community_operations.py +2 -1
  21. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/graph_data_operations.py +17 -31
  22. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/pyproject.toml +2 -2
  23. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/LICENSE +0 -0
  24. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/README.md +0 -0
  25. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/__init__.py +0 -0
  26. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/errors.py +0 -0
  27. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/__init__.py +0 -0
  28. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/llm_client/errors.py +0 -0
  29. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/__init__.py +0 -0
  30. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/dedupe_edges.py +0 -0
  31. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/dedupe_nodes.py +0 -0
  32. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/extract_edge_dates.py +0 -0
  33. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/extract_edges.py +0 -0
  34. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/extract_nodes.py +0 -0
  35. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/invalidate_edges.py +0 -0
  36. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/models.py +0 -0
  37. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/prompts/summarize_nodes.py +0 -0
  38. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/py.typed +0 -0
  39. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/__init__.py +0 -0
  40. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/search_config.py +0 -0
  41. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/search/search_config_recipes.py +0 -0
  42. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/__init__.py +0 -0
  43. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/bulk_utils.py +0 -0
  44. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/__init__.py +0 -0
  45. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/edge_operations.py +0 -0
  46. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/node_operations.py +0 -0
  47. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/temporal_operations.py +0 -0
  48. {graphiti_core-0.3.6 → graphiti_core-0.3.8}/graphiti_core/utils/maintenance/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.3.6
3
+ Version: 0.3.8
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -14,7 +14,7 @@ Classifier: Programming Language :: Python :: 3.12
14
14
  Requires-Dist: diskcache (>=5.6.3,<6.0.0)
15
15
  Requires-Dist: neo4j (>=5.23.0,<6.0.0)
16
16
  Requires-Dist: numpy (>=1.0.0)
17
- Requires-Dist: openai (>=1.38.0,<2.0.0)
17
+ Requires-Dist: openai (>=1.50.2,<2.0.0)
18
18
  Requires-Dist: pydantic (>=2.8.2,<3.0.0)
19
19
  Requires-Dist: tenacity (<9.0.0)
20
20
  Description-Content-Type: text/markdown
@@ -24,9 +24,9 @@ from uuid import uuid4
24
24
  from neo4j import AsyncDriver
25
25
  from pydantic import BaseModel, Field
26
26
 
27
+ from graphiti_core.embedder import EmbedderClient
27
28
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
28
29
  from graphiti_core.helpers import parse_db_date
29
- from graphiti_core.llm_client.config import EMBEDDING_DIM
30
30
  from graphiti_core.nodes import Node
31
31
 
32
32
  logger = logging.getLogger(__name__)
@@ -171,17 +171,16 @@ class EntityEdge(Edge):
171
171
  default=None, description='datetime of when the fact stopped being true'
172
172
  )
173
173
 
174
- async def generate_embedding(self, embedder, model='text-embedding-3-small'):
174
+ async def generate_embedding(self, embedder: EmbedderClient):
175
175
  start = time()
176
176
 
177
177
  text = self.fact.replace('\n', ' ')
178
- embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
179
- self.fact_embedding = embedding[:EMBEDDING_DIM]
178
+ self.fact_embedding = await embedder.create(input=[text])
180
179
 
181
180
  end = time()
182
181
  logger.info(f'embedded {text} in {end - start} ms')
183
182
 
184
- return embedding
183
+ return self.fact_embedding
185
184
 
186
185
  async def save(self, driver: AsyncDriver):
187
186
  result = await driver.execute_query(
@@ -0,0 +1,4 @@
1
+ from .client import EmbedderClient
2
+ from .openai import OpenAIEmbedder, OpenAIEmbedderConfig
3
+
4
+ __all__ = ['EmbedderClient', 'OpenAIEmbedder', 'OpenAIEmbedderConfig']
@@ -14,10 +14,21 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- from datetime import datetime
17
+ from abc import ABC, abstractmethod
18
+ from typing import Iterable, List, Literal
18
19
 
19
- from neo4j import time as neo4j_time
20
+ from pydantic import BaseModel, Field
20
21
 
22
+ EMBEDDING_DIM = 1024
21
23
 
22
- def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
23
- return neo_date.to_native() if neo_date else None
24
+
25
+ class EmbedderConfig(BaseModel):
26
+ embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
27
+
28
+
29
+ class EmbedderClient(ABC):
30
+ @abstractmethod
31
+ async def create(
32
+ self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
33
+ ) -> list[float]:
34
+ pass
@@ -0,0 +1,48 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from typing import Iterable, List
18
+
19
+ from openai import AsyncOpenAI
20
+ from openai.types import EmbeddingModel
21
+
22
+ from .client import EmbedderClient, EmbedderConfig
23
+
24
+ DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small'
25
+
26
+
27
+ class OpenAIEmbedderConfig(EmbedderConfig):
28
+ embedding_model: EmbeddingModel | str = DEFAULT_EMBEDDING_MODEL
29
+ api_key: str | None = None
30
+ base_url: str | None = None
31
+
32
+
33
+ class OpenAIEmbedder(EmbedderClient):
34
+ """
35
+ OpenAI Embedder Client
36
+ """
37
+
38
+ def __init__(self, config: OpenAIEmbedderConfig | None = None):
39
+ if config is None:
40
+ config = OpenAIEmbedderConfig()
41
+ self.config = config
42
+ self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
43
+
44
+ async def create(
45
+ self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
46
+ ) -> list[float]:
47
+ result = await self.client.embeddings.create(input=input, model=self.config.embedding_model)
48
+ return result.data[0].embedding[: self.config.embedding_dim]
@@ -0,0 +1,47 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from typing import Iterable, List
18
+
19
+ import voyageai # type: ignore
20
+ from pydantic import Field
21
+
22
+ from .client import EmbedderClient, EmbedderConfig
23
+
24
+ DEFAULT_EMBEDDING_MODEL = 'voyage-3'
25
+
26
+
27
+ class VoyageAIEmbedderConfig(EmbedderConfig):
28
+ embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
29
+ api_key: str | None = None
30
+
31
+
32
+ class VoyageAIEmbedder(EmbedderClient):
33
+ """
34
+ VoyageAI Embedder Client
35
+ """
36
+
37
+ def __init__(self, config: VoyageAIEmbedderConfig | None = None):
38
+ if config is None:
39
+ config = VoyageAIEmbedderConfig()
40
+ self.config = config
41
+ self.client = voyageai.AsyncClient(api_key=config.api_key)
42
+
43
+ async def create(
44
+ self, input: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
45
+ ) -> list[float]:
46
+ result = await self.client.embed(input, model=self.config.embedding_model)
47
+ return result.embeddings[0][: self.config.embedding_dim]
@@ -23,6 +23,7 @@ from dotenv import load_dotenv
23
23
  from neo4j import AsyncGraphDatabase
24
24
 
25
25
  from graphiti_core.edges import EntityEdge, EpisodicEdge
26
+ from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
26
27
  from graphiti_core.llm_client import LLMClient, OpenAIClient
27
28
  from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
28
29
  from graphiti_core.search.search import SearchConfig, search
@@ -83,6 +84,7 @@ class Graphiti:
83
84
  user: str,
84
85
  password: str,
85
86
  llm_client: LLMClient | None = None,
87
+ embedder: EmbedderClient | None = None,
86
88
  store_raw_episode_content: bool = True,
87
89
  ):
88
90
  """
@@ -128,6 +130,10 @@ class Graphiti:
128
130
  self.llm_client = llm_client
129
131
  else:
130
132
  self.llm_client = OpenAIClient()
133
+ if embedder:
134
+ self.embedder = embedder
135
+ else:
136
+ self.embedder = OpenAIEmbedder()
131
137
 
132
138
  async def close(self):
133
139
  """
@@ -161,7 +167,7 @@ class Graphiti:
161
167
  """
162
168
  await self.driver.close()
163
169
 
164
- async def build_indices_and_constraints(self):
170
+ async def build_indices_and_constraints(self, delete_existing: bool = False):
165
171
  """
166
172
  Build indices and constraints in the Neo4j database.
167
173
 
@@ -171,6 +177,9 @@ class Graphiti:
171
177
  Parameters
172
178
  ----------
173
179
  self
180
+ delete_existing : bool, optional
181
+ Whether to clear existing indices before creating new ones.
182
+
174
183
 
175
184
  Returns
176
185
  -------
@@ -191,7 +200,7 @@ class Graphiti:
191
200
  Caution: Running this method on a large existing database may take some time
192
201
  and could impact database performance during execution.
193
202
  """
194
- await build_indices_and_constraints(self.driver)
203
+ await build_indices_and_constraints(self.driver, delete_existing)
195
204
 
196
205
  async def retrieve_episodes(
197
206
  self,
@@ -287,7 +296,6 @@ class Graphiti:
287
296
  start = time()
288
297
 
289
298
  entity_edges: list[EntityEdge] = []
290
- embedder = self.llm_client.get_embedder()
291
299
  now = datetime.now()
292
300
 
293
301
  previous_episodes = await self.retrieve_episodes(
@@ -315,7 +323,7 @@ class Graphiti:
315
323
  # Calculate Embeddings
316
324
 
317
325
  await asyncio.gather(
318
- *[node.generate_name_embedding(embedder) for node in extracted_nodes]
326
+ *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
319
327
  )
320
328
 
321
329
  # Resolve extracted nodes with nodes already in the graph and extract facts
@@ -343,7 +351,7 @@ class Graphiti:
343
351
  # calculate embeddings
344
352
  await asyncio.gather(
345
353
  *[
346
- edge.generate_embedding(embedder)
354
+ edge.generate_embedding(self.embedder)
347
355
  for edge in extracted_edges_with_resolved_pointers
348
356
  ]
349
357
  )
@@ -436,7 +444,7 @@ class Graphiti:
436
444
  if update_communities:
437
445
  await asyncio.gather(
438
446
  *[
439
- update_community(self.driver, self.llm_client, embedder, node)
447
+ update_community(self.driver, self.llm_client, self.embedder, node)
440
448
  for node in nodes
441
449
  ]
442
450
  )
@@ -485,7 +493,6 @@ class Graphiti:
485
493
  """
486
494
  try:
487
495
  start = time()
488
- embedder = self.llm_client.get_embedder()
489
496
  now = datetime.now()
490
497
 
491
498
  episodes = [
@@ -517,8 +524,8 @@ class Graphiti:
517
524
 
518
525
  # Generate embeddings
519
526
  await asyncio.gather(
520
- *[node.generate_name_embedding(embedder) for node in extracted_nodes],
521
- *[edge.generate_embedding(embedder) for edge in extracted_edges],
527
+ *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
528
+ *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
522
529
  )
523
530
 
524
531
  # Dedupe extracted nodes, compress extracted edges
@@ -561,14 +568,14 @@ class Graphiti:
561
568
  raise e
562
569
 
563
570
  async def build_communities(self):
564
- embedder = self.llm_client.get_embedder()
565
-
566
571
  # Clear existing communities
567
572
  await remove_communities(self.driver)
568
573
 
569
574
  community_nodes, community_edges = await build_communities(self.driver, self.llm_client)
570
575
 
571
- await asyncio.gather(*[node.generate_name_embedding(embedder) for node in community_nodes])
576
+ await asyncio.gather(
577
+ *[node.generate_name_embedding(self.embedder) for node in community_nodes]
578
+ )
572
579
 
573
580
  await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
574
581
  await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
@@ -619,7 +626,7 @@ class Graphiti:
619
626
  edges = (
620
627
  await search(
621
628
  self.driver,
622
- self.llm_client.get_embedder(),
629
+ self.embedder,
623
630
  query,
624
631
  group_ids,
625
632
  search_config,
@@ -636,9 +643,7 @@ class Graphiti:
636
643
  group_ids: list[str] | None = None,
637
644
  center_node_uuid: str | None = None,
638
645
  ) -> SearchResults:
639
- return await search(
640
- self.driver, self.llm_client.get_embedder(), query, group_ids, config, center_node_uuid
641
- )
646
+ return await search(self.driver, self.embedder, query, group_ids, config, center_node_uuid)
642
647
 
643
648
  async def get_nodes_by_query(
644
649
  self,
@@ -683,14 +688,15 @@ class Graphiti:
683
688
  to each individual search method before results are combined and deduplicated.
684
689
  If not specified, a default limit (defined in the search functions) will be used.
685
690
  """
686
- embedder = self.llm_client.get_embedder()
687
691
  search_config = (
688
692
  NODE_HYBRID_SEARCH_RRF if center_node_uuid is None else NODE_HYBRID_SEARCH_NODE_DISTANCE
689
693
  )
690
694
  search_config.limit = limit
691
695
 
692
696
  nodes = (
693
- await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid)
697
+ await search(
698
+ self.driver, self.embedder, query, group_ids, search_config, center_node_uuid
699
+ )
694
700
  ).nodes
695
701
  return nodes
696
702
 
@@ -0,0 +1,53 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ from datetime import datetime
18
+
19
+ from neo4j import time as neo4j_time
20
+
21
+
22
+ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
23
+ return neo_date.to_native() if neo_date else None
24
+
25
+
26
+ def lucene_sanitize(query: str) -> str:
27
+ # Escape special characters from a query before passing into Lucene
28
+ # + - && || ! ( ) { } [ ] ^ " ~ * ? : \
29
+ escape_map = str.maketrans(
30
+ {
31
+ '+': r'\+',
32
+ '-': r'\-',
33
+ '&': r'\&',
34
+ '|': r'\|',
35
+ '!': r'\!',
36
+ '(': r'\(',
37
+ ')': r'\)',
38
+ '{': r'\{',
39
+ '}': r'\}',
40
+ '[': r'\[',
41
+ ']': r'\]',
42
+ '^': r'\^',
43
+ '"': r'\"',
44
+ '~': r'\~',
45
+ '*': r'\*',
46
+ '?': r'\?',
47
+ ':': r'\:',
48
+ '\\': r'\\',
49
+ }
50
+ )
51
+
52
+ sanitized = query.translate(escape_map)
53
+ return sanitized
@@ -20,7 +20,6 @@ import typing
20
20
 
21
21
  import anthropic
22
22
  from anthropic import AsyncAnthropic
23
- from openai import AsyncOpenAI
24
23
 
25
24
  from ..prompts.models import Message
26
25
  from .client import LLMClient
@@ -47,10 +46,6 @@ class AnthropicClient(LLMClient):
47
46
  max_retries=1,
48
47
  )
49
48
 
50
- def get_embedder(self) -> typing.Any:
51
- openai_client = AsyncOpenAI()
52
- return openai_client.embeddings
53
-
54
49
  async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
55
50
  system_message = messages[0]
56
51
  user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
@@ -55,10 +55,6 @@ class LLMClient(ABC):
55
55
  self.cache_enabled = cache
56
56
  self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
57
57
 
58
- @abstractmethod
59
- def get_embedder(self) -> typing.Any:
60
- pass
61
-
62
58
  @retry(
63
59
  stop=stop_after_attempt(4),
64
60
  wait=wait_random_exponential(multiplier=10, min=5, max=120),
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- EMBEDDING_DIM = 1024
18
17
  DEFAULT_MAX_TOKENS = 16384
19
18
  DEFAULT_TEMPERATURE = 0
20
19
 
@@ -21,7 +21,6 @@ import typing
21
21
  import groq
22
22
  from groq import AsyncGroq
23
23
  from groq.types.chat import ChatCompletionMessageParam
24
- from openai import AsyncOpenAI
25
24
 
26
25
  from ..prompts.models import Message
27
26
  from .client import LLMClient
@@ -44,10 +43,6 @@ class GroqClient(LLMClient):
44
43
 
45
44
  self.client = AsyncGroq(api_key=config.api_key)
46
45
 
47
- def get_embedder(self) -> typing.Any:
48
- openai_client = AsyncOpenAI()
49
- return openai_client.embeddings
50
-
51
46
  async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
52
47
  msgs: list[ChatCompletionMessageParam] = []
53
48
  for m in messages:
@@ -49,9 +49,6 @@ class OpenAIClient(LLMClient):
49
49
  __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
50
50
  Initializes the OpenAIClient with the provided configuration, cache setting, and client.
51
51
 
52
- get_embedder() -> typing.Any:
53
- Returns the embedder from the OpenAI client.
54
-
55
52
  _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
56
53
  Generates a response from the language model based on the provided messages.
57
54
  """
@@ -78,9 +75,6 @@ class OpenAIClient(LLMClient):
78
75
  else:
79
76
  self.client = client
80
77
 
81
- def get_embedder(self) -> typing.Any:
82
- return self.client.embeddings
83
-
84
78
  async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
85
79
  openai_messages: list[ChatCompletionMessageParam] = []
86
80
  for m in messages:
@@ -15,22 +15,18 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
- import typing
19
18
  from time import time
20
19
 
21
- from graphiti_core.llm_client.config import EMBEDDING_DIM
20
+ from graphiti_core.embedder.client import EmbedderClient
22
21
 
23
22
  logger = logging.getLogger(__name__)
24
23
 
25
24
 
26
- async def generate_embedding(
27
- embedder: typing.Any, text: str, model: str = 'text-embedding-3-small'
28
- ):
25
+ async def generate_embedding(embedder: EmbedderClient, text: str):
29
26
  start = time()
30
27
 
31
28
  text = text.replace('\n', ' ')
32
- embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
33
- embedding = embedding[:EMBEDDING_DIM]
29
+ embedding = await embedder.create(input=[text])
34
30
 
35
31
  end = time()
36
32
  logger.debug(f'embedded text of length {len(text)} in {end - start} ms')
@@ -25,8 +25,8 @@ from uuid import uuid4
25
25
  from neo4j import AsyncDriver
26
26
  from pydantic import BaseModel, Field
27
27
 
28
+ from graphiti_core.embedder import EmbedderClient
28
29
  from graphiti_core.errors import NodeNotFoundError
29
- from graphiti_core.llm_client.config import EMBEDDING_DIM
30
30
 
31
31
  logger = logging.getLogger(__name__)
32
32
 
@@ -212,15 +212,14 @@ class EntityNode(Node):
212
212
  name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
213
213
  summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
214
214
 
215
- async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
215
+ async def generate_name_embedding(self, embedder: EmbedderClient):
216
216
  start = time()
217
217
  text = self.name.replace('\n', ' ')
218
- embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
219
- self.name_embedding = embedding[:EMBEDDING_DIM]
218
+ self.name_embedding = await embedder.create(input=[text])
220
219
  end = time()
221
220
  logger.info(f'embedded {text} in {end - start} ms')
222
221
 
223
- return embedding
222
+ return self.name_embedding
224
223
 
225
224
  async def save(self, driver: AsyncDriver):
226
225
  result = await driver.execute_query(
@@ -323,15 +322,14 @@ class CommunityNode(Node):
323
322
 
324
323
  return result
325
324
 
326
- async def generate_name_embedding(self, embedder, model='text-embedding-3-small'):
325
+ async def generate_name_embedding(self, embedder: EmbedderClient):
327
326
  start = time()
328
327
  text = self.name.replace('\n', ' ')
329
- embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
330
- self.name_embedding = embedding[:EMBEDDING_DIM]
328
+ self.name_embedding = await embedder.create(input=[text])
331
329
  end = time()
332
330
  logger.info(f'embedded {text} in {end - start} ms')
333
331
 
334
- return embedding
332
+ return self.name_embedding
335
333
 
336
334
  @classmethod
337
335
  async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
@@ -0,0 +1,90 @@
1
+ """
2
+ Copyright 2024, Zep Software, Inc.
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
17
+ import json
18
+ from typing import Any, Protocol, TypedDict
19
+
20
+ from .models import Message, PromptFunction, PromptVersion
21
+
22
+
23
+ class Prompt(Protocol):
24
+ qa_prompt: PromptVersion
25
+ eval_prompt: PromptVersion
26
+
27
+
28
+ class Versions(TypedDict):
29
+ qa_prompt: PromptFunction
30
+ eval_prompt: PromptFunction
31
+
32
+
33
+ def qa_prompt(context: dict[str, Any]) -> list[Message]:
34
+ sys_prompt = """You are Alice and should respond to all questions from the first person perspective of Alice"""
35
+
36
+ user_prompt = f"""
37
+ Your task is to briefly answer the question in the way that you think Alice would answer the question.
38
+ You are given the following entity summaries and facts to help you determine the answer to your question.
39
+ <ENTITY_SUMMARIES>
40
+ {json.dumps(context['entity_summaries'])}
41
+ </ENTITY_SUMMARIES
42
+ <FACTS>
43
+ {json.dumps(context['facts'])}
44
+ </FACTS>
45
+ <QUESTION>
46
+ {context['query']}
47
+ </QUESTION>
48
+ respond with a JSON object in the following format:
49
+ {{
50
+ "ANSWER": "how Alice would answer the question"
51
+ }}
52
+ """
53
+ return [
54
+ Message(role='system', content=sys_prompt),
55
+ Message(role='user', content=user_prompt),
56
+ ]
57
+
58
+
59
+ def eval_prompt(context: dict[str, Any]) -> list[Message]:
60
+ sys_prompt = (
61
+ """You are a judge that determines if answers to questions match a gold standard answer"""
62
+ )
63
+
64
+ user_prompt = f"""
65
+ Given the QUESTION and the gold standard ANSWER determine if the RESPONSE to the question is correct or incorrect.
66
+ Although the RESPONSE may be more verbose, mark it as correct as long as it references the same topic
67
+ as the gold standard ANSWER. Also include your reasoning for the grade.
68
+ <QUESTION>
69
+ {context['query']}
70
+ </QUESTION>
71
+ <ANSWER>
72
+ {context['answer']}
73
+ </ANSWER>
74
+ <RESPONSE>
75
+ {context['response']}
76
+ </RESPONSE>
77
+
78
+ respond with a JSON object in the following format:
79
+ {{
80
+ "is_correct": "boolean if the answer is correct or incorrect"
81
+ "reasoning": "why you determined the response was correct or incorrect"
82
+ }}
83
+ """
84
+ return [
85
+ Message(role='system', content=sys_prompt),
86
+ Message(role='user', content=user_prompt),
87
+ ]
88
+
89
+
90
+ versions: Versions = {'qa_prompt': qa_prompt, 'eval_prompt': eval_prompt}
@@ -34,6 +34,9 @@ from .dedupe_nodes import (
34
34
  from .dedupe_nodes import (
35
35
  versions as dedupe_nodes_versions,
36
36
  )
37
+ from .eval import Prompt as EvalPrompt
38
+ from .eval import Versions as EvalVersions
39
+ from .eval import versions as eval_versions
37
40
  from .extract_edge_dates import (
38
41
  Prompt as ExtractEdgeDatesPrompt,
39
42
  )
@@ -84,6 +87,7 @@ class PromptLibrary(Protocol):
84
87
  invalidate_edges: InvalidateEdgesPrompt
85
88
  extract_edge_dates: ExtractEdgeDatesPrompt
86
89
  summarize_nodes: SummarizeNodesPrompt
90
+ eval: EvalPrompt
87
91
 
88
92
 
89
93
  class PromptLibraryImpl(TypedDict):
@@ -94,6 +98,7 @@ class PromptLibraryImpl(TypedDict):
94
98
  invalidate_edges: InvalidateEdgesVersions
95
99
  extract_edge_dates: ExtractEdgeDatesVersions
96
100
  summarize_nodes: SummarizeNodesVersions
101
+ eval: EvalVersions
97
102
 
98
103
 
99
104
  class VersionWrapper:
@@ -124,5 +129,6 @@ PROMPT_LIBRARY_IMPL: PromptLibraryImpl = {
124
129
  'invalidate_edges': invalidate_edges_versions,
125
130
  'extract_edge_dates': extract_edge_dates_versions,
126
131
  'summarize_nodes': summarize_nodes_versions,
132
+ 'eval': eval_versions,
127
133
  }
128
134
  prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]