graphiti-core 0.4.3__tar.gz → 0.5.0__tar.gz

This diff compares the contents of two publicly available package versions, each as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

Files changed (61)
  1. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/PKG-INFO +1 -1
  2. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/client.py +1 -1
  3. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/openai_reranker_client.py +2 -2
  4. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/edges.py +13 -10
  5. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/graphiti.py +25 -27
  6. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/helpers.py +25 -0
  7. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/llm_client/anthropic_client.py +4 -1
  8. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/llm_client/client.py +45 -5
  9. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/llm_client/errors.py +8 -0
  10. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/llm_client/groq_client.py +4 -1
  11. graphiti_core-0.5.0/graphiti_core/llm_client/openai_client.py +163 -0
  12. graphiti_core-0.4.3/graphiti_core/llm_client/openai_client.py → graphiti_core-0.5.0/graphiti_core/llm_client/openai_generic_client.py +67 -3
  13. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/nodes.py +16 -12
  14. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/dedupe_edges.py +20 -17
  15. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/dedupe_nodes.py +15 -1
  16. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/eval.py +17 -14
  17. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_edge_dates.py +15 -7
  18. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_edges.py +18 -19
  19. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/extract_nodes.py +11 -21
  20. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/invalidate_edges.py +13 -25
  21. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/summarize_nodes.py +17 -16
  22. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/search/search.py +5 -5
  23. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/search/search_utils.py +54 -13
  24. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/utils/bulk_utils.py +22 -15
  25. graphiti_core-0.5.0/graphiti_core/utils/datetime_utils.py +42 -0
  26. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/community_operations.py +13 -9
  27. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/edge_operations.py +26 -19
  28. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/graph_data_operations.py +3 -4
  29. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/node_operations.py +19 -13
  30. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/temporal_operations.py +16 -7
  31. graphiti_core-0.5.0/graphiti_core/utils/maintenance/utils.py +0 -0
  32. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/pyproject.toml +1 -1
  33. graphiti_core-0.4.3/graphiti_core/utils/__init__.py +0 -15
  34. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/LICENSE +0 -0
  35. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/README.md +0 -0
  36. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/__init__.py +0 -0
  37. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/__init__.py +0 -0
  38. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/cross_encoder/bge_reranker_client.py +0 -0
  39. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/embedder/__init__.py +0 -0
  40. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/embedder/client.py +0 -0
  41. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/embedder/openai.py +0 -0
  42. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/embedder/voyage.py +0 -0
  43. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/errors.py +0 -0
  44. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/llm_client/__init__.py +0 -0
  45. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/llm_client/config.py +0 -0
  46. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/llm_client/utils.py +0 -0
  47. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/models/__init__.py +0 -0
  48. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/models/edges/__init__.py +0 -0
  49. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/models/edges/edge_db_queries.py +0 -0
  50. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/models/nodes/__init__.py +0 -0
  51. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/models/nodes/node_db_queries.py +0 -0
  52. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/__init__.py +0 -0
  53. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/lib.py +0 -0
  54. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/models.py +0 -0
  55. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/prompts/prompt_helpers.py +0 -0
  56. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/py.typed +0 -0
  57. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/search/__init__.py +0 -0
  58. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/search/search_config.py +0 -0
  59. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/search/search_config_recipes.py +0 -0
  60. /graphiti_core-0.4.3/graphiti_core/utils/maintenance/utils.py → /graphiti_core-0.5.0/graphiti_core/utils/__init__.py +0 -0
  61. {graphiti_core-0.4.3 → graphiti_core-0.5.0}/graphiti_core/utils/maintenance/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: graphiti-core
3
- Version: 0.4.3
3
+ Version: 0.5.0
4
4
  Summary: A temporal graph building library
5
5
  License: Apache-2.0
6
6
  Author: Paul Paliychuk
@@ -34,7 +34,7 @@ class CrossEncoderClient(ABC):
34
34
  passages (list[str]): A list of passages to rank.
35
35
 
36
36
  Returns:
37
- List[tuple[str, float]]: A list of tuples containing the passage and its score,
37
+ list[tuple[str, float]]: A list of tuples containing the passage and its score,
38
38
  sorted in descending order of relevance.
39
39
  """
40
40
  pass
@@ -14,7 +14,6 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- import asyncio
18
17
  import logging
19
18
  from typing import Any
20
19
 
@@ -22,6 +21,7 @@ import openai
22
21
  from openai import AsyncOpenAI
23
22
  from pydantic import BaseModel
24
23
 
24
+ from ..helpers import semaphore_gather
25
25
  from ..llm_client import LLMConfig, RateLimitError
26
26
  from ..prompts import Message
27
27
  from .client import CrossEncoderClient
@@ -75,7 +75,7 @@ class OpenAIRerankerClient(CrossEncoderClient):
75
75
  for passage in passages
76
76
  ]
77
77
  try:
78
- responses = await asyncio.gather(
78
+ responses = await semaphore_gather(
79
79
  *[
80
80
  self.client.chat.completions.create(
81
81
  model=DEFAULT_MODEL,
@@ -27,7 +27,7 @@ from typing_extensions import LiteralString
27
27
 
28
28
  from graphiti_core.embedder import EmbedderClient
29
29
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
30
- from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date
30
+ from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
31
31
  from graphiti_core.models.edges.edge_db_queries import (
32
32
  COMMUNITY_EDGE_SAVE,
33
33
  ENTITY_EDGE_SAVE,
@@ -142,10 +142,11 @@ class EpisodicEdge(Edge):
142
142
  cls,
143
143
  driver: AsyncDriver,
144
144
  group_ids: list[str],
145
- limit: int = DEFAULT_PAGE_LIMIT,
145
+ limit: int | None = None,
146
146
  created_at: datetime | None = None,
147
147
  ):
148
148
  cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
149
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
149
150
 
150
151
  records, _, _ = await driver.execute_query(
151
152
  """
@@ -161,8 +162,8 @@ class EpisodicEdge(Edge):
161
162
  m.uuid AS target_node_uuid,
162
163
  e.created_at AS created_at
163
164
  ORDER BY e.uuid DESC
164
- LIMIT $limit
165
- """,
165
+ """
166
+ + limit_query,
166
167
  group_ids=group_ids,
167
168
  created_at=created_at,
168
169
  limit=limit,
@@ -294,10 +295,11 @@ class EntityEdge(Edge):
294
295
  cls,
295
296
  driver: AsyncDriver,
296
297
  group_ids: list[str],
297
- limit: int = DEFAULT_PAGE_LIMIT,
298
+ limit: int | None = None,
298
299
  created_at: datetime | None = None,
299
300
  ):
300
301
  cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
302
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
301
303
 
302
304
  records, _, _ = await driver.execute_query(
303
305
  """
@@ -320,8 +322,8 @@ class EntityEdge(Edge):
320
322
  e.valid_at AS valid_at,
321
323
  e.invalid_at AS invalid_at
322
324
  ORDER BY e.uuid DESC
323
- LIMIT $limit
324
- """,
325
+ """
326
+ + limit_query,
325
327
  group_ids=group_ids,
326
328
  created_at=created_at,
327
329
  limit=limit,
@@ -400,10 +402,11 @@ class CommunityEdge(Edge):
400
402
  cls,
401
403
  driver: AsyncDriver,
402
404
  group_ids: list[str],
403
- limit: int = DEFAULT_PAGE_LIMIT,
405
+ limit: int | None = None,
404
406
  created_at: datetime | None = None,
405
407
  ):
406
408
  cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''
409
+ limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
407
410
 
408
411
  records, _, _ = await driver.execute_query(
409
412
  """
@@ -419,8 +422,8 @@ class CommunityEdge(Edge):
419
422
  m.uuid AS target_node_uuid,
420
423
  e.created_at AS created_at
421
424
  ORDER BY e.uuid DESC
422
- LIMIT $limit
423
- """,
425
+ """
426
+ + limit_query,
424
427
  group_ids=group_ids,
425
428
  created_at=created_at,
426
429
  limit=limit,
@@ -14,9 +14,8 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
- import asyncio
18
17
  import logging
19
- from datetime import datetime, timezone
18
+ from datetime import datetime
20
19
  from time import time
21
20
 
22
21
  from dotenv import load_dotenv
@@ -27,7 +26,7 @@ from graphiti_core.cross_encoder.client import CrossEncoderClient
27
26
  from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
28
27
  from graphiti_core.edges import EntityEdge, EpisodicEdge
29
28
  from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
30
- from graphiti_core.helpers import DEFAULT_DATABASE
29
+ from graphiti_core.helpers import DEFAULT_DATABASE, semaphore_gather
31
30
  from graphiti_core.llm_client import LLMClient, OpenAIClient
32
31
  from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
33
32
  from graphiti_core.search.search import SearchConfig, search
@@ -43,10 +42,6 @@ from graphiti_core.search.search_utils import (
43
42
  get_relevant_edges,
44
43
  get_relevant_nodes,
45
44
  )
46
- from graphiti_core.utils import (
47
- build_episodic_edges,
48
- retrieve_episodes,
49
- )
50
45
  from graphiti_core.utils.bulk_utils import (
51
46
  RawEpisode,
52
47
  add_nodes_and_edges_bulk,
@@ -57,12 +52,14 @@ from graphiti_core.utils.bulk_utils import (
57
52
  resolve_edge_pointers,
58
53
  retrieve_previous_episodes_bulk,
59
54
  )
55
+ from graphiti_core.utils.datetime_utils import utc_now
60
56
  from graphiti_core.utils.maintenance.community_operations import (
61
57
  build_communities,
62
58
  remove_communities,
63
59
  update_community,
64
60
  )
65
61
  from graphiti_core.utils.maintenance.edge_operations import (
62
+ build_episodic_edges,
66
63
  dedupe_extracted_edge,
67
64
  extract_edges,
68
65
  resolve_edge_contradictions,
@@ -71,6 +68,7 @@ from graphiti_core.utils.maintenance.edge_operations import (
71
68
  from graphiti_core.utils.maintenance.graph_data_operations import (
72
69
  EPISODE_WINDOW_LEN,
73
70
  build_indices_and_constraints,
71
+ retrieve_episodes,
74
72
  )
75
73
  from graphiti_core.utils.maintenance.node_operations import (
76
74
  extract_nodes,
@@ -313,7 +311,7 @@ class Graphiti:
313
311
  start = time()
314
312
 
315
313
  entity_edges: list[EntityEdge] = []
316
- now = datetime.now(timezone.utc)
314
+ now = utc_now()
317
315
 
318
316
  previous_episodes = await self.retrieve_episodes(
319
317
  reference_time, last_n=RELEVANT_SCHEMA_LIMIT, group_ids=[group_id]
@@ -341,13 +339,13 @@ class Graphiti:
341
339
 
342
340
  # Calculate Embeddings
343
341
 
344
- await asyncio.gather(
342
+ await semaphore_gather(
345
343
  *[node.generate_name_embedding(self.embedder) for node in extracted_nodes]
346
344
  )
347
345
 
348
346
  # Find relevant nodes already in the graph
349
347
  existing_nodes_lists: list[list[EntityNode]] = list(
350
- await asyncio.gather(
348
+ await semaphore_gather(
351
349
  *[get_relevant_nodes(self.driver, [node]) for node in extracted_nodes]
352
350
  )
353
351
  )
@@ -355,7 +353,7 @@ class Graphiti:
355
353
  # Resolve extracted nodes with nodes already in the graph and extract facts
356
354
  logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
357
355
 
358
- (mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
356
+ (mentioned_nodes, uuid_map), extracted_edges = await semaphore_gather(
359
357
  resolve_extracted_nodes(
360
358
  self.llm_client,
361
359
  extracted_nodes,
@@ -375,7 +373,7 @@ class Graphiti:
375
373
  )
376
374
 
377
375
  # calculate embeddings
378
- await asyncio.gather(
376
+ await semaphore_gather(
379
377
  *[
380
378
  edge.generate_embedding(self.embedder)
381
379
  for edge in extracted_edges_with_resolved_pointers
@@ -384,7 +382,7 @@ class Graphiti:
384
382
 
385
383
  # Resolve extracted edges with related edges already in the graph
386
384
  related_edges_list: list[list[EntityEdge]] = list(
387
- await asyncio.gather(
385
+ await semaphore_gather(
388
386
  *[
389
387
  get_relevant_edges(
390
388
  self.driver,
@@ -405,7 +403,7 @@ class Graphiti:
405
403
  )
406
404
 
407
405
  existing_source_edges_list: list[list[EntityEdge]] = list(
408
- await asyncio.gather(
406
+ await semaphore_gather(
409
407
  *[
410
408
  get_relevant_edges(
411
409
  self.driver,
@@ -420,7 +418,7 @@ class Graphiti:
420
418
  )
421
419
 
422
420
  existing_target_edges_list: list[list[EntityEdge]] = list(
423
- await asyncio.gather(
421
+ await semaphore_gather(
424
422
  *[
425
423
  get_relevant_edges(
426
424
  self.driver,
@@ -469,7 +467,7 @@ class Graphiti:
469
467
 
470
468
  # Update any communities
471
469
  if update_communities:
472
- await asyncio.gather(
470
+ await semaphore_gather(
473
471
  *[
474
472
  update_community(self.driver, self.llm_client, self.embedder, node)
475
473
  for node in nodes
@@ -522,7 +520,7 @@ class Graphiti:
522
520
  """
523
521
  try:
524
522
  start = time()
525
- now = datetime.now(timezone.utc)
523
+ now = utc_now()
526
524
 
527
525
  episodes = [
528
526
  EpisodicNode(
@@ -539,7 +537,7 @@ class Graphiti:
539
537
  ]
540
538
 
541
539
  # Save all the episodes
542
- await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
540
+ await semaphore_gather(*[episode.save(self.driver) for episode in episodes])
543
541
 
544
542
  # Get previous episode context for each episode
545
543
  episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
@@ -552,19 +550,19 @@ class Graphiti:
552
550
  ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
553
551
 
554
552
  # Generate embeddings
555
- await asyncio.gather(
553
+ await semaphore_gather(
556
554
  *[node.generate_name_embedding(self.embedder) for node in extracted_nodes],
557
555
  *[edge.generate_embedding(self.embedder) for edge in extracted_edges],
558
556
  )
559
557
 
560
558
  # Dedupe extracted nodes, compress extracted edges
561
- (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
559
+ (nodes, uuid_map), extracted_edges_timestamped = await semaphore_gather(
562
560
  dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
563
561
  extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
564
562
  )
565
563
 
566
564
  # save nodes to KG
567
- await asyncio.gather(*[node.save(self.driver) for node in nodes])
565
+ await semaphore_gather(*[node.save(self.driver) for node in nodes])
568
566
 
569
567
  # re-map edge pointers so that they don't point to discard dupe nodes
570
568
  extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
@@ -575,7 +573,7 @@ class Graphiti:
575
573
  )
576
574
 
577
575
  # save episodic edges to KG
578
- await asyncio.gather(
576
+ await semaphore_gather(
579
577
  *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers]
580
578
  )
581
579
 
@@ -588,7 +586,7 @@ class Graphiti:
588
586
  # invalidate edges
589
587
 
590
588
  # save edges to KG
591
- await asyncio.gather(*[edge.save(self.driver) for edge in edges])
589
+ await semaphore_gather(*[edge.save(self.driver) for edge in edges])
592
590
 
593
591
  end = time()
594
592
  logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
@@ -611,12 +609,12 @@ class Graphiti:
611
609
  self.driver, self.llm_client, group_ids
612
610
  )
613
611
 
614
- await asyncio.gather(
612
+ await semaphore_gather(
615
613
  *[node.generate_name_embedding(self.embedder) for node in community_nodes]
616
614
  )
617
615
 
618
- await asyncio.gather(*[node.save(self.driver) for node in community_nodes])
619
- await asyncio.gather(*[edge.save(self.driver) for edge in community_edges])
616
+ await semaphore_gather(*[node.save(self.driver) for node in community_nodes])
617
+ await semaphore_gather(*[edge.save(self.driver) for edge in community_edges])
620
618
 
621
619
  return community_nodes
622
620
 
@@ -699,7 +697,7 @@ class Graphiti:
699
697
  async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults:
700
698
  episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
701
699
 
702
- edges_list = await asyncio.gather(
700
+ edges_list = await semaphore_gather(
703
701
  *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes]
704
702
  )
705
703
 
@@ -14,7 +14,9 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import asyncio
17
18
  import os
19
+ from collections.abc import Coroutine
18
20
  from datetime import datetime
19
21
 
20
22
  import numpy as np
@@ -25,6 +27,7 @@ load_dotenv()
25
27
 
26
28
  DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
27
29
  USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
30
+ SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 20))
28
31
  MAX_REFLEXION_ITERATIONS = 2
29
32
  DEFAULT_PAGE_LIMIT = 20
30
33
 
@@ -57,6 +60,12 @@ def lucene_sanitize(query: str) -> str:
57
60
  ':': r'\:',
58
61
  '\\': r'\\',
59
62
  '/': r'\/',
63
+ 'O': r'\O',
64
+ 'R': r'\R',
65
+ 'N': r'\N',
66
+ 'T': r'\T',
67
+ 'A': r'\A',
68
+ 'D': r'\D',
60
69
  }
61
70
  )
62
71
 
@@ -74,3 +83,19 @@ def normalize_l2(embedding: list[float]) -> list[float]:
74
83
  else:
75
84
  norm = np.linalg.norm(embedding_array, 2, axis=1, keepdims=True)
76
85
  return (np.where(norm == 0, embedding_array, embedding_array / norm)).tolist()
86
+
87
+
88
+ # Use this instead of asyncio.gather() to bound coroutines
89
+ async def semaphore_gather(
90
+ *coroutines: Coroutine, max_coroutines: int = SEMAPHORE_LIMIT, return_exceptions=True
91
+ ):
92
+ semaphore = asyncio.Semaphore(max_coroutines)
93
+
94
+ async def _wrap_coroutine(coroutine):
95
+ async with semaphore:
96
+ return await coroutine
97
+
98
+ return await asyncio.gather(
99
+ *(_wrap_coroutine(coroutine) for coroutine in coroutines),
100
+ return_exceptions=return_exceptions,
101
+ )
@@ -20,6 +20,7 @@ import typing
20
20
 
21
21
  import anthropic
22
22
  from anthropic import AsyncAnthropic
23
+ from pydantic import BaseModel
23
24
 
24
25
  from ..prompts.models import Message
25
26
  from .client import LLMClient
@@ -46,7 +47,9 @@ class AnthropicClient(LLMClient):
46
47
  max_retries=1,
47
48
  )
48
49
 
49
- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
50
+ async def _generate_response(
51
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
52
+ ) -> dict[str, typing.Any]:
50
53
  system_message = messages[0]
51
54
  user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]] + [
52
55
  {'role': 'assistant', 'content': '{'}
@@ -22,6 +22,7 @@ from abc import ABC, abstractmethod
22
22
 
23
23
  import httpx
24
24
  from diskcache import Cache
25
+ from pydantic import BaseModel
25
26
  from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
26
27
 
27
28
  from ..prompts.models import Message
@@ -55,6 +56,28 @@ class LLMClient(ABC):
55
56
  self.cache_enabled = cache
56
57
  self.cache_dir = Cache(DEFAULT_CACHE_DIR) # Create a cache directory
57
58
 
59
+ def _clean_input(self, input: str) -> str:
60
+ """Clean input string of invalid unicode and control characters.
61
+
62
+ Args:
63
+ input: Raw input string to be cleaned
64
+
65
+ Returns:
66
+ Cleaned string safe for LLM processing
67
+ """
68
+ # Clean any invalid Unicode
69
+ cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
70
+
71
+ # Remove zero-width characters and other invisible unicode
72
+ zero_width = '\u200b\u200c\u200d\ufeff\u2060'
73
+ for char in zero_width:
74
+ cleaned = cleaned.replace(char, '')
75
+
76
+ # Remove control characters except newlines, returns, and tabs
77
+ cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
78
+
79
+ return cleaned
80
+
58
81
  @retry(
59
82
  stop=stop_after_attempt(4),
60
83
  wait=wait_random_exponential(multiplier=10, min=5, max=120),
@@ -66,14 +89,18 @@ class LLMClient(ABC):
66
89
  else None,
67
90
  reraise=True,
68
91
  )
69
- async def _generate_response_with_retry(self, messages: list[Message]) -> dict[str, typing.Any]:
92
+ async def _generate_response_with_retry(
93
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
94
+ ) -> dict[str, typing.Any]:
70
95
  try:
71
- return await self._generate_response(messages)
96
+ return await self._generate_response(messages, response_model)
72
97
  except (httpx.HTTPStatusError, RateLimitError) as e:
73
98
  raise e
74
99
 
75
100
  @abstractmethod
76
- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
101
+ async def _generate_response(
102
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
103
+ ) -> dict[str, typing.Any]:
77
104
  pass
78
105
 
79
106
  def _get_cache_key(self, messages: list[Message]) -> str:
@@ -82,7 +109,17 @@ class LLMClient(ABC):
82
109
  key_str = f'{self.model}:{message_str}'
83
110
  return hashlib.md5(key_str.encode()).hexdigest()
84
111
 
85
- async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
112
+ async def generate_response(
113
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
114
+ ) -> dict[str, typing.Any]:
115
+ if response_model is not None:
116
+ serialized_model = json.dumps(response_model.model_json_schema())
117
+ messages[
118
+ -1
119
+ ].content += (
120
+ f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
121
+ )
122
+
86
123
  if self.cache_enabled:
87
124
  cache_key = self._get_cache_key(messages)
88
125
 
@@ -91,7 +128,10 @@ class LLMClient(ABC):
91
128
  logger.debug(f'Cache hit for {cache_key}')
92
129
  return cached_response
93
130
 
94
- response = await self._generate_response_with_retry(messages)
131
+ for message in messages:
132
+ message.content = self._clean_input(message.content)
133
+
134
+ response = await self._generate_response_with_retry(messages, response_model)
95
135
 
96
136
  if self.cache_enabled:
97
137
  self.cache_dir.set(cache_key, response)
@@ -21,3 +21,11 @@ class RateLimitError(Exception):
21
21
  def __init__(self, message='Rate limit exceeded. Please try again later.'):
22
22
  self.message = message
23
23
  super().__init__(self.message)
24
+
25
+
26
+ class RefusalError(Exception):
27
+ """Exception raised when the LLM refuses to generate a response."""
28
+
29
+ def __init__(self, message: str):
30
+ self.message = message
31
+ super().__init__(self.message)
@@ -21,6 +21,7 @@ import typing
21
21
  import groq
22
22
  from groq import AsyncGroq
23
23
  from groq.types.chat import ChatCompletionMessageParam
24
+ from pydantic import BaseModel
24
25
 
25
26
  from ..prompts.models import Message
26
27
  from .client import LLMClient
@@ -43,7 +44,9 @@ class GroqClient(LLMClient):
43
44
 
44
45
  self.client = AsyncGroq(api_key=config.api_key)
45
46
 
46
- async def _generate_response(self, messages: list[Message]) -> dict[str, typing.Any]:
47
+ async def _generate_response(
48
+ self, messages: list[Message], response_model: type[BaseModel] | None = None
49
+ ) -> dict[str, typing.Any]:
47
50
  msgs: list[ChatCompletionMessageParam] = []
48
51
  for m in messages:
49
52
  if m.role == 'user':