deepeval 3.4.7__py3-none-any.whl → 3.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,117 @@
+from typing import List
+from textwrap import dedent
+
+
+class ConversationalVerdictNodeTemplate:
+    @staticmethod
+    def generate_reason(verbose_steps: List[str], score: float, name: str):
+        return dedent(
+            f"""You are given a metric name, its score, and a traversal path through a conversational evaluation DAG (Directed Acyclic Graph).
+            This DAG reflects step-by-step reasoning over a dialogue to arrive at the final verdict.
+
+            Each step in the DAG represents a judgment based on parts of the conversation — including roles and the contents they spoke of.
+
+            Your task is to explain **why the score was assigned**, using the traversal steps to justify the reasoning.
+
+            Metric Name:
+            {name}
+
+            Score:
+            {score}
+
+            DAG Traversal:
+            {verbose_steps}
+
+            **
+            IMPORTANT: Only return JSON with a 'reason' key.
+            Example:
+            {{
+                "reason": "The score is {score} because the assistant repeatedly failed to clarify the user's ambiguous statements, as shown in the DAG traversal path."
+            }}
+            **
+            JSON:
+            """
+        )
+
+
+class ConversationalTaskNodeTemplate:
+    @staticmethod
+    def generate_task_output(instructions: str, text: str):
+        return dedent(
+            f"""You are given a set of task instructions and a full conversation between a user and an assistant.
+
+            Instructions:
+            {instructions}
+
+            {text}
+
+            ===END OF INPUT===
+
+            **
+            IMPORTANT: Only return a JSON with the 'output' key containing the result of applying the instructions to the conversation.
+            Example:
+            {{
+                "output": "..."
+            }}
+            **
+            JSON:
+            """
+        )
+
+
+class ConversationalBinaryJudgementTemplate:
+    @staticmethod
+    def generate_binary_verdict(criteria: str, text: str):
+        return dedent(
+            f"""{criteria}
+
+            Below is the full conversation you should evaluate. Consider dialogue context, speaker roles, and how responses were handled.
+
+            Full Conversation:
+            {text}
+
+            **
+            IMPORTANT: Only return JSON with two keys:
+            - 'verdict': true or false
+            - 'reason': justification based on specific parts of the conversation
+
+            Example:
+            {{
+                "verdict": true,
+                "reason": "The assistant provided a clear and direct answer in response to every user query."
+            }}
+            **
+            JSON:
+            """
+        )
+
+
+class ConversationalNonBinaryJudgementTemplate:
+    @staticmethod
+    def generate_non_binary_verdict(
+        criteria: str, text: str, options: List[str]
+    ):
+        return dedent(
+            f"""{criteria}
+
+            You are evaluating the following conversation. Choose one of the options that best reflects the assistant's behavior.
+
+            Options: {options}
+
+            Full Conversation:
+            {text}
+
+            **
+            IMPORTANT: Only return JSON with two keys:
+            - 'verdict': one of the listed options
+            - 'reason': explanation referencing specific conversation points
+
+            Example:
+            {{
+                "verdict": "{options[1]}",
+                "reason": "The assistant partially addressed the user’s issue but missed clarifying their follow-up question."
+            }}
+            **
+            JSON:
+            """
+        )
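
For orientation, a quick look at the prompt one of these new templates renders. The method and its signature come from the class above; the argument values are made up:

    # Illustrative call: the signature is from ConversationalVerdictNodeTemplate
    # above; the argument values and metric name are hypothetical.
    prompt = ConversationalVerdictNodeTemplate.generate_reason(
        verbose_steps=["Root judgement: verdict=False (assistant ignored the follow-up)"],
        score=0.0,
        name="Conversation Adherence",  # hypothetical metric name
    )
    print(prompt)  # ends with the 'JSON:' cue that the judge model completes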
@@ -13,8 +13,8 @@ from deepeval.models import DeepEvalBaseLLM
 from deepeval.metrics.indicator import metric_progress_indicator
 from deepeval.metrics.g_eval.schema import *
 from deepeval.metrics.dag.graph import DeepAcyclicGraph
-from deepeval.metrics.dag.utils import copy_graph
 from deepeval.metrics.dag.utils import (
+    copy_graph,
     is_valid_dag_from_roots,
     extract_required_params,
 )
@@ -34,7 +34,12 @@ class DAGMetric(BaseMetric):
         verbose_mode: bool = False,
         _include_dag_suffix: bool = True,
     ):
-        if is_valid_dag_from_roots(dag.root_nodes) == False:
+        if (
+            is_valid_dag_from_roots(
+                root_nodes=dag.root_nodes, multiturn=dag.multiturn
+            )
+            == False
+        ):
             raise ValueError("Cycle detected in DAG graph.")

         self._verbose_steps: List[str] = []
@@ -56,7 +61,9 @@ class DAGMetric(BaseMetric):
         _in_component: bool = False,
     ) -> float:
         check_llm_test_case_params(
-            test_case, extract_required_params(self.dag.root_nodes), self
+            test_case,
+            extract_required_params(self.dag.root_nodes, self.dag.multiturn),
+            self,
         )

         self.evaluation_cost = 0 if self.using_native_model else None
@@ -91,7 +98,9 @@ class DAGMetric(BaseMetric):
         _in_component: bool = False,
     ) -> float:
         check_llm_test_case_params(
-            test_case, extract_required_params(self.dag.root_nodes), self
+            test_case,
+            extract_required_params(self.dag.root_nodes, self.dag.multiturn),
+            self,
         )

         self.evaluation_cost = 0 if self.using_native_model else None
@@ -1,39 +1,71 @@
 import asyncio
-from typing import List
+from typing import List, Union

 from deepeval.metrics.dag import (
     BaseNode,
     NonBinaryJudgementNode,
     BinaryJudgementNode,
 )
-from deepeval.test_case import LLMTestCase
-from deepeval.metrics import BaseMetric
+from deepeval.metrics.conversational_dag import (
+    ConversationalBaseNode,
+    ConversationalBinaryJudgementNode,
+    ConversationalNonBinaryJudgementNode,
+)
+from deepeval.test_case import LLMTestCase, ConversationalTestCase
+from deepeval.metrics import BaseMetric, BaseConversationalMetric
+
+
+def validate_root_nodes(
+    root_nodes: Union[List[BaseNode], List[ConversationalBaseNode]],
+):
+    # All root nodes must be of the same type: multi-turn and single-turn
+    # nodes cannot be mixed in one graph.
+    if not all(isinstance(node, type(root_nodes[0])) for node in root_nodes):
+        raise ValueError("You cannot mix multi-turn and single-turn nodes")
+    return True


 class DeepAcyclicGraph:
+    multiturn: bool
+
     def __init__(
         self,
-        root_nodes: List[BaseNode],
+        root_nodes: Union[List[BaseNode], List[ConversationalBaseNode]],
     ):
-        for root_node in root_nodes:
-            if isinstance(root_node, NonBinaryJudgementNode) or isinstance(
-                root_node, BinaryJudgementNode
-            ):
-                if len(root_nodes) > 1:
-                    raise ValueError(
-                        "You cannot provide more than one root node when using 'BinaryJudgementNode' or 'NonBinaryJudgementNode' in root_nodes."
-                    )
+        validate_root_nodes(root_nodes)
+        self.multiturn = isinstance(root_nodes[0], ConversationalBaseNode)

+        if not self.multiturn:
+            for root_node in root_nodes:
+                if isinstance(root_node, NonBinaryJudgementNode) or isinstance(
+                    root_node, BinaryJudgementNode
+                ):
+                    if len(root_nodes) > 1:
+                        raise ValueError(
+                            "You cannot provide more than one root node when using 'BinaryJudgementNode' or 'NonBinaryJudgementNode' in root_nodes."
+                        )
+        else:
+            for root_node in root_nodes:
+                if isinstance(
+                    root_node, ConversationalNonBinaryJudgementNode
+                ) or isinstance(root_node, ConversationalBinaryJudgementNode):
+                    if len(root_nodes) > 1:
+                        raise ValueError(
+                            "You cannot provide more than one root node when using 'ConversationalBinaryJudgementNode' or 'ConversationalNonBinaryJudgementNode' in root_nodes."
+                        )
         self.root_nodes = root_nodes

-    def _execute(self, metric: BaseMetric, test_case: LLMTestCase) -> None:
+    def _execute(
+        self,
+        metric: Union[BaseMetric, BaseConversationalMetric],
+        test_case: Union[LLMTestCase, ConversationalTestCase],
+    ) -> None:
         for root_node in self.root_nodes:
             root_node._execute(metric=metric, test_case=test_case, depth=0)

     async def _a_execute(
         self,
-        metric: BaseMetric,
-        test_case: LLMTestCase,
+        metric: Union[BaseMetric, BaseConversationalMetric],
+        test_case: Union[LLMTestCase, ConversationalTestCase],
     ) -> None:
         await asyncio.gather(
             *(
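
A hedged sketch of how the new multi-turn mode might be used. `DeepAcyclicGraph(root_nodes=...)` and the `multiturn` attribute are confirmed by the hunk above, but the conversational node constructor arguments below (criteria, children, verdict, score) are assumed to mirror their single-turn counterparts and are not confirmed by this diff:

    # Hedged sketch: node constructor arguments are assumptions.
    from deepeval.metrics.dag import DeepAcyclicGraph
    from deepeval.metrics.conversational_dag import (
        ConversationalBinaryJudgementNode,
        ConversationalVerdictNode,
    )

    root = ConversationalBinaryJudgementNode(
        criteria="Did the assistant directly address every user query?",
        children=[
            ConversationalVerdictNode(verdict=False, score=0),
            ConversationalVerdictNode(verdict=True, score=10),
        ],
    )

    dag = DeepAcyclicGraph(root_nodes=[root])
    assert dag.multiturn  # inferred from the root node's type

    # Mixing conversational and single-turn root nodes now raises ValueError.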
@@ -1,4 +1,4 @@
-from typing import Set, Dict, Optional
+from typing import Set, Dict, Optional, Union
 import inspect

 from deepeval.metrics.dag import (
@@ -9,18 +9,33 @@ from deepeval.metrics.dag import (
     TaskNode,
     DeepAcyclicGraph,
 )
-from deepeval.test_case import LLMTestCaseParams
+from deepeval.metrics.conversational_dag import (
+    ConversationalBaseNode,
+    ConversationalBinaryJudgementNode,
+    ConversationalNonBinaryJudgementNode,
+    ConversationalTaskNode,
+    ConversationalVerdictNode,
+)
+from deepeval.test_case import LLMTestCaseParams, TurnParams


-def is_valid_dag_from_roots(root_nodes: list[BaseNode]) -> bool:
+def is_valid_dag_from_roots(
+    root_nodes: Union[list[BaseNode], list[ConversationalBaseNode]],
+    multiturn: bool,
+) -> bool:
     visited = set()
     for root in root_nodes:
-        if not is_valid_dag(root, visited, set()):
+        if not is_valid_dag(root, multiturn, visited, set()):
             return False
     return True


-def is_valid_dag(node: BaseNode, visited=None, stack=None) -> bool:
+def is_valid_dag(
+    node: Union[BaseNode, ConversationalBaseNode],
+    multiturn: bool,
+    visited=None,
+    stack=None,
+) -> bool:
     if visited is None:
         visited = set()
     if stack is None:
@@ -33,14 +48,24 @@ def is_valid_dag(node: BaseNode, visited=None, stack=None) -> bool:

     visited.add(node)
     stack.add(node)
-    if (
-        isinstance(node, TaskNode)
-        or isinstance(node, BinaryJudgementNode)
-        or isinstance(node, NonBinaryJudgementNode)
-    ):
-        for child in node.children:
-            if not is_valid_dag(child, visited, stack):
-                return False
+    if not multiturn:
+        if (
+            isinstance(node, TaskNode)
+            or isinstance(node, BinaryJudgementNode)
+            or isinstance(node, NonBinaryJudgementNode)
+        ):
+            for child in node.children:
+                if not is_valid_dag(child, multiturn, visited, stack):
+                    return False
+    else:
+        if (
+            isinstance(node, ConversationalTaskNode)
+            or isinstance(node, ConversationalBinaryJudgementNode)
+            or isinstance(node, ConversationalNonBinaryJudgementNode)
+        ):
+            for child in node.children:
+                if not is_valid_dag(child, multiturn, visited, stack):
+                    return False

     stack.remove(node)
     return True
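
The refactor threads a `multiturn` flag through what is otherwise a standard DFS cycle check: `visited` caches fully explored nodes, while `stack` holds the current recursion path, so meeting a node already on the stack signals a back edge. A minimal standalone sketch of the same idea, with a hypothetical Node class rather than deepeval's node types:

    # Minimal DFS cycle check (hypothetical Node class, illustration only).
    class Node:
        def __init__(self, children=None):
            self.children = children or []

    def is_acyclic(node, visited=None, stack=None):
        if visited is None:
            visited = set()
        if stack is None:
            stack = set()
        if node in stack:  # back edge: node is already on the current path
            return False
        if node in visited:  # fully explored via another path; safe
            return True
        visited.add(node)
        stack.add(node)
        for child in node.children:
            if not is_acyclic(child, visited, stack):
                return False
        stack.remove(node)
        return True

    a, b = Node(), Node()
    a.children = [b]
    b.children = [a]  # introduce a cycle
    assert not is_acyclic(a)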
@@ -48,29 +73,51 @@ def is_valid_dag(node: BaseNode, visited=None, stack=None) -> bool:

 def extract_required_params(
     nodes: list[BaseNode],
-    required_params: Optional[Set[LLMTestCaseParams]] = None,
-) -> Set[LLMTestCaseParams]:
+    multiturn: bool,
+    required_params: Optional[
+        Union[Set[LLMTestCaseParams], Set[TurnParams]]
+    ] = None,
+) -> Union[Set[LLMTestCaseParams], Set[TurnParams]]:
     if required_params is None:
         required_params = set()

     for node in nodes:
-        if (
-            isinstance(node, TaskNode)
-            or isinstance(node, BinaryJudgementNode)
-            or isinstance(node, NonBinaryJudgementNode)
-        ):
-            if node.evaluation_params is not None:
-                required_params.update(node.evaluation_params)
-            extract_required_params(node.children, required_params)
+        if not multiturn:
+            if (
+                isinstance(node, TaskNode)
+                or isinstance(node, BinaryJudgementNode)
+                or isinstance(node, NonBinaryJudgementNode)
+            ):
+                if node.evaluation_params is not None:
+                    required_params.update(node.evaluation_params)
+                extract_required_params(
+                    node.children, multiturn, required_params
+                )
+        else:
+            if (
+                isinstance(node, ConversationalTaskNode)
+                or isinstance(node, ConversationalBinaryJudgementNode)
+                or isinstance(node, ConversationalNonBinaryJudgementNode)
+            ):
+                if node.evaluation_params is not None:
+                    required_params.update(node.evaluation_params)
+                extract_required_params(
+                    node.children, multiturn, required_params
+                )

     return required_params


 def copy_graph(original_dag: DeepAcyclicGraph) -> DeepAcyclicGraph:
     # This mapping avoids re-copying nodes that appear in multiple places.
-    visited: Dict[BaseNode, BaseNode] = {}
-
-    def copy_node(node: BaseNode) -> BaseNode:
+    visited: Union[
+        Dict[BaseNode, BaseNode],
+        Dict[ConversationalBaseNode, ConversationalBaseNode],
+    ] = {}
+
+    def copy_node(
+        node: Union[BaseNode, ConversationalBaseNode],
+    ) -> Union[BaseNode, ConversationalBaseNode]:
         if node in visited:
             return visited[node]

@@ -98,22 +145,40 @@ def copy_graph(original_dag: DeepAcyclicGraph) -> DeepAcyclicGraph:
                 "_depth",
             ]
         }
-        if (
-            isinstance(node, TaskNode)
-            or isinstance(node, BinaryJudgementNode)
-            or isinstance(node, NonBinaryJudgementNode)
-        ):
-            copied_node = node_class(
-                **valid_args,
-                children=[copy_node(child) for child in node.children]
-            )
+        if not original_dag.multiturn:
+            if (
+                isinstance(node, TaskNode)
+                or isinstance(node, BinaryJudgementNode)
+                or isinstance(node, NonBinaryJudgementNode)
+            ):
+                copied_node = node_class(
+                    **valid_args,
+                    children=[copy_node(child) for child in node.children]
+                )
+            else:
+                if isinstance(node, VerdictNode) and node.child:
+                    copied_node = node_class(
+                        **valid_args, child=copy_node(node.child)
+                    )
+                else:
+                    copied_node = node_class(**valid_args)
         else:
-            if isinstance(node, VerdictNode) and node.child:
+            if (
+                isinstance(node, ConversationalTaskNode)
+                or isinstance(node, ConversationalBinaryJudgementNode)
+                or isinstance(node, ConversationalNonBinaryJudgementNode)
+            ):
                 copied_node = node_class(
-                    **valid_args, child=copy_node(node.child)
+                    **valid_args,
+                    children=[copy_node(child) for child in node.children]
                 )
             else:
-                copied_node = node_class(**valid_args)
+                if isinstance(node, ConversationalVerdictNode) and node.child:
+                    copied_node = node_class(
+                        **valid_args, child=copy_node(node.child)
+                    )
+                else:
+                    copied_node = node_class(**valid_args)

         visited[node] = copied_node
         return copied_node
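
The `visited` mapping in `copy_node` is what keeps shared subtrees shared: each original node is copied once, and later references reuse that copy, so the duplicate preserves the DAG shape rather than unrolling it into a tree. A standalone sketch of the memoization, again with a hypothetical Node class:

    # Memoized copy of a DAG (hypothetical Node class, illustration only).
    class Node:
        def __init__(self, name, children=None):
            self.name = name
            self.children = children or []

    def copy_node(node, visited=None):
        if visited is None:
            visited = {}  # maps original node -> its copy
        if node in visited:
            return visited[node]  # reuse the copy for shared nodes
        copied = Node(node.name, [copy_node(c, visited) for c in node.children])
        visited[node] = copied
        return copied

    shared = Node("shared")
    root = Node("root", [Node("a", [shared]), Node("b", [shared])])
    clone = copy_node(root)
    # Both branches of the clone point at the same copied 'shared' node.
    assert clone.children[0].children[0] is clone.children[1].children[0]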
@@ -1,47 +1,72 @@
-from typing import Optional, List, Dict, Union, Type
 import os

+from typing import Dict, List, Optional, Type, TYPE_CHECKING
+from types import SimpleNamespace
+
 from deepeval.models.base_model import DeepEvalBaseEmbeddingModel

-# check langchain availability
-try:
+
+if TYPE_CHECKING:
+    from chromadb.api.models.Collection import Collection
     from langchain_core.documents import Document as LCDocument
-    from langchain_text_splitters import TokenTextSplitter
     from langchain_text_splitters.base import TextSplitter
-    from langchain_community.document_loaders import (
-        PyPDFLoader,
-        TextLoader,
-        Docx2txtLoader,
-    )
     from langchain_community.document_loaders.base import BaseLoader

-    langchain_available = True
-except ImportError:
-    langchain_available = False
-
-# check chromadb availability
-try:
-    import chromadb
-    from chromadb import Metadata
-    from chromadb.api.models.Collection import Collection
-
-    chroma_db_available = True
-except ImportError:
-    chroma_db_available = False

-
-# Define a helper function to check availability
-def _check_chromadb_available():
-    if not chroma_db_available:
+# Lazy import caches
+_langchain_ns = None
+_chroma_mod = None
+_langchain_import_error = None
+_chroma_import_error = None
+
+
+def _get_langchain():
+    """Return a namespace of langchain classes, or raise ImportError with root cause."""
+    global _langchain_ns, _langchain_import_error
+    if _langchain_ns is not None:
+        return _langchain_ns
+    try:
+        from langchain_core.documents import Document as LCDocument  # type: ignore
+        from langchain_text_splitters import TokenTextSplitter  # type: ignore
+        from langchain_text_splitters.base import TextSplitter  # type: ignore
+        from langchain_community.document_loaders import (  # type: ignore
+            PyPDFLoader,
+            TextLoader,
+            Docx2txtLoader,
+        )
+        from langchain_community.document_loaders.base import BaseLoader  # type: ignore
+
+        _langchain_ns = SimpleNamespace(
+            LCDocument=LCDocument,
+            TokenTextSplitter=TokenTextSplitter,
+            TextSplitter=TextSplitter,
+            PyPDFLoader=PyPDFLoader,
+            TextLoader=TextLoader,
+            Docx2txtLoader=Docx2txtLoader,
+            BaseLoader=BaseLoader,
+        )
+        return _langchain_ns
+    except Exception as e:
+        _langchain_import_error = e
         raise ImportError(
-            "chromadb is required for this functionality. Install it via your package manager"
+            f"langchain, langchain_community, and langchain_text_splitters are required. Root cause: {e}"
         )


-def _check_langchain_available():
-    if not langchain_available:
+def _get_chromadb():
+    """Return the chromadb module, or raise ImportError with root cause."""
+    global _chroma_mod, _chroma_import_error
+    if _chroma_mod is not None:
+        return _chroma_mod
+    try:
+        import chromadb
+
+        _chroma_mod = chromadb
+        return _chroma_mod
+    except Exception as e:
+        _chroma_import_error = e
         raise ImportError(
-            "langchain, langchain_community, and langchain_text_splitters are required for this functionality. Install it via your package manager"
+            f"chromadb is required for this functionality. Root cause: {e}"
         )


@@ -50,22 +75,16 @@ class DocumentChunker:
         self,
         embedder: DeepEvalBaseEmbeddingModel,
     ):
-        _check_chromadb_available()
-        _check_langchain_available()
         self.text_token_count: Optional[int] = None  # set later

         self.source_file: Optional[str] = None
         self.chunks: Optional["Collection"] = None
-        self.sections: Optional[List[LCDocument]] = None
+        self.sections: Optional[List["LCDocument"]] = None
         self.embedder: DeepEvalBaseEmbeddingModel = embedder
         self.mean_embedding: Optional[float] = None

         # Mapping of file extensions to their respective loader classes
-        self.loader_mapping: Dict[str, Type[BaseLoader]] = {
-            ".pdf": PyPDFLoader,
-            ".txt": TextLoader,
-            ".docx": Docx2txtLoader,
-        }
+        self.loader_mapping: Dict[str, "Type[BaseLoader]"] = {}

     #########################################################
     ### Chunking Docs #######################################
@@ -74,7 +93,8 @@ class DocumentChunker:
     async def a_chunk_doc(
         self, chunk_size: int = 1024, chunk_overlap: int = 0
     ) -> "Collection":
-        _check_chromadb_available()
+        lc = _get_langchain()
+        chroma = _get_chromadb()

         # Raise error if chunk_doc is called before load_doc
         if self.sections is None or self.source_file is None:
@@ -85,13 +105,13 @@ class DocumentChunker:
         # Create ChromaDB client
         full_document_path, _ = os.path.splitext(self.source_file)
         document_name = os.path.basename(full_document_path)
-        client = chromadb.PersistentClient(path=f".vector_db/{document_name}")
+        client = chroma.PersistentClient(path=f".vector_db/{document_name}")

         collection_name = f"processed_chunks_{chunk_size}_{chunk_overlap}"
         try:
             collection = client.get_collection(name=collection_name)
         except Exception:
-            text_splitter: TextSplitter = TokenTextSplitter(
+            text_splitter: "TextSplitter" = lc.TokenTextSplitter(
                 chunk_size=chunk_size, chunk_overlap=chunk_overlap
             )
             # Collection doesn't exist, so create it and then add documents
@@ -108,7 +128,7 @@ class DocumentChunker:
                 batch_contents = contents[i:batch_end]
                 batch_embeddings = embeddings[i:batch_end]
                 batch_ids = ids[i:batch_end]
-                batch_metadatas: List["Metadata"] = [
+                batch_metadatas: List[dict] = [
                     {"source_file": self.source_file} for _ in batch_contents
                 ]

@@ -121,7 +141,8 @@ class DocumentChunker:
         return collection

     def chunk_doc(self, chunk_size: int = 1024, chunk_overlap: int = 0):
-        _check_chromadb_available()
+        lc = _get_langchain()
+        chroma = _get_chromadb()

         # Raise error if chunk_doc is called before load_doc
         if self.sections is None or self.source_file is None:
@@ -132,13 +153,13 @@ class DocumentChunker:
         # Create ChromaDB client
         full_document_path, _ = os.path.splitext(self.source_file)
         document_name = os.path.basename(full_document_path)
-        client = chromadb.PersistentClient(path=f".vector_db/{document_name}")
+        client = chroma.PersistentClient(path=f".vector_db/{document_name}")

         collection_name = f"processed_chunks_{chunk_size}_{chunk_overlap}"
         try:
             collection = client.get_collection(name=collection_name)
         except Exception:
-            text_splitter: TextSplitter = TokenTextSplitter(
+            text_splitter: "TextSplitter" = lc.TokenTextSplitter(
                 chunk_size=chunk_size, chunk_overlap=chunk_overlap
             )
             # Collection doesn't exist, so create it and then add documents
@@ -155,7 +176,7 @@ class DocumentChunker:
                 batch_contents = contents[i:batch_end]
                 batch_embeddings = embeddings[i:batch_end]
                 batch_ids = ids[i:batch_end]
-                batch_metadatas: List["Metadata"] = [
+                batch_metadatas: List[dict] = [
                     {"source_file": self.source_file} for _ in batch_contents
                 ]

@@ -172,17 +193,31 @@ class DocumentChunker:
     #########################################################

     def get_loader(self, path: str, encoding: Optional[str]) -> "BaseLoader":
+        lc = _get_langchain()
+        # set mapping lazily now that langchain classes exist
+        if not self.loader_mapping:
+            self.loader_mapping = {
+                ".pdf": lc.PyPDFLoader,
+                ".txt": lc.TextLoader,
+                ".docx": lc.Docx2txtLoader,
+                ".md": lc.TextLoader,
+                ".markdown": lc.TextLoader,
+                ".mdx": lc.TextLoader,
+            }
+
         # Find appropriate doc loader
         _, extension = os.path.splitext(path)
         extension = extension.lower()
-        loader: Optional[type[BaseLoader]] = self.loader_mapping.get(extension)
+        loader: Optional["Type[BaseLoader]"] = self.loader_mapping.get(
+            extension
+        )
         if loader is None:
             raise ValueError(f"Unsupported file format: {extension}")

-        # Load doc into sections and calculate total character count
-        if loader is TextLoader:
+        # Load doc into sections and calculate total token count
+        if loader is lc.TextLoader:
             return loader(path, encoding=encoding, autodetect_encoding=True)
-        elif loader is PyPDFLoader or loader is Docx2txtLoader:
+        elif loader in (lc.PyPDFLoader, lc.Docx2txtLoader):
             return loader(path)
         else:
             raise ValueError(f"Unsupported file format: {extension}")
@@ -200,5 +235,6 @@ class DocumentChunker:
         self.source_file = path

     def count_tokens(self, chunks: List["LCDocument"]):
-        counter = TokenTextSplitter(chunk_size=1, chunk_overlap=0)
+        lc = _get_langchain()
+        counter = lc.TokenTextSplitter(chunk_size=1, chunk_overlap=0)
         return len(counter.split_documents(chunks))
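
Taken together, the chunker changes defer both heavy imports to first use: constructing DocumentChunker no longer fails when chromadb or langchain is missing, the import happens once and is cached at module level, and a failure surfaces the original error as the root cause. A minimal sketch of the same lazy-import pattern, using a hypothetical dependency name for illustration:

    _dep_mod = None  # module-level cache, mirrors _chroma_mod above

    def _get_dep():
        """Import 'somedep' (hypothetical) on first use and cache it."""
        global _dep_mod
        if _dep_mod is not None:
            return _dep_mod
        try:
            import somedep  # hypothetical optional dependency
            _dep_mod = somedep
            return _dep_mod
        except Exception as e:
            raise ImportError(
                f"somedep is required for this functionality. Root cause: {e}"
            )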