deepeval 3.4.7__py3-none-any.whl → 3.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- deepeval/__init__.py +0 -2
- deepeval/_version.py +1 -1
- deepeval/cli/dotenv_handler.py +71 -0
- deepeval/cli/main.py +1039 -132
- deepeval/cli/utils.py +116 -2
- deepeval/key_handler.py +63 -2
- deepeval/metrics/__init__.py +4 -1
- deepeval/metrics/conversational_dag/__init__.py +7 -0
- deepeval/metrics/conversational_dag/conversational_dag.py +139 -0
- deepeval/metrics/conversational_dag/nodes.py +931 -0
- deepeval/metrics/conversational_dag/templates.py +117 -0
- deepeval/metrics/dag/dag.py +13 -4
- deepeval/metrics/dag/graph.py +47 -15
- deepeval/metrics/dag/utils.py +103 -38
- deepeval/synthesizer/chunking/doc_chunker.py +87 -51
- {deepeval-3.4.7.dist-info → deepeval-3.4.8.dist-info}/METADATA +1 -1
- {deepeval-3.4.7.dist-info → deepeval-3.4.8.dist-info}/RECORD +20 -15
- {deepeval-3.4.7.dist-info → deepeval-3.4.8.dist-info}/LICENSE.md +0 -0
- {deepeval-3.4.7.dist-info → deepeval-3.4.8.dist-info}/WHEEL +0 -0
- {deepeval-3.4.7.dist-info → deepeval-3.4.8.dist-info}/entry_points.txt +0 -0
deepeval/metrics/conversational_dag/templates.py
ADDED

@@ -0,0 +1,117 @@
+from typing import List
+from textwrap import dedent
+
+
+class ConversationalVerdictNodeTemplate:
+    @staticmethod
+    def generate_reason(verbose_steps: List[str], score: float, name: str):
+        return dedent(
+            f"""You are given a metric name, its score, and a traversal path through a conversational evaluation DAG (Directed Acyclic Graph).
+            This DAG reflects step-by-step reasoning over a dialogue to arrive at the final verdict.
+
+            Each step in the DAG represents a judgment based on parts of the conversation — including roles and the contents they spoke of.
+
+            Your task is to explain **why the score was assigned**, using the traversal steps to justify the reasoning.
+
+            Metric Name:
+            {name}
+
+            Score:
+            {score}
+
+            DAG Traversal:
+            {verbose_steps}
+
+            **
+            IMPORTANT: Only return JSON with a 'reason' key.
+            Example:
+            {{
+                "reason": "The score is {score} because the assistant repeatedly failed to clarify the user's ambiguous statements, as shown in the DAG traversal path."
+            }}
+            **
+            JSON:
+            """
+        )
+
+
+class ConversationalTaskNodeTemplate:
+    @staticmethod
+    def generate_task_output(instructions: str, text: str):
+        return dedent(
+            f"""You are given a set of task instructions and a full conversation between a user and an assistant.
+
+            Instructions:
+            {instructions}
+
+            {text}
+
+            ===END OF INPUT===
+
+            **
+            IMPORTANT: Only return a JSON with the 'output' key containing the result of applying the instructions to the conversation.
+            Example:
+            {{
+                "output": "..."
+            }}
+            **
+            JSON:
+            """
+        )
+
+
+class ConversationalBinaryJudgementTemplate:
+    @staticmethod
+    def generate_binary_verdict(criteria: str, text: str):
+        return dedent(
+            f"""{criteria}
+
+            Below is the full conversation you should evaluate. Consider dialogue context, speaker roles, and how responses were handled.
+
+            Full Conversation:
+            {text}
+
+            **
+            IMPORTANT: Only return JSON with two keys:
+            - 'verdict': true or false
+            - 'reason': justification based on specific parts of the conversation
+
+            Example:
+            {{
+                "verdict": true,
+                "reason": "The assistant provided a clear and direct answer in response to every user query."
+            }}
+            **
+            JSON:
+            """
+        )
+
+
+class ConversationalNonBinaryJudgementTemplate:
+    @staticmethod
+    def generate_non_binary_verdict(
+        criteria: str, text: str, options: List[str]
+    ):
+        return dedent(
+            f"""{criteria}
+
+            You are evaluating the following conversation. Choose one of the options that best reflects the assistant's behavior.
+
+            Options: {options}
+
+            Full Conversation:
+            {text}
+
+            **
+            IMPORTANT: Only return JSON with two keys:
+            - 'verdict': one of the listed options
+            - 'reason': explanation referencing specific conversation points
+
+            Example:
+            {{
+                "verdict": "{options[1]}",
+                "reason": "The assistant partially addressed the user’s issue but missed clarifying their follow-up question."
+            }}
+            **
+            JSON:
+            """
+        )
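For orientation, a minimal sketch of how one of the new templates can be rendered into a judge prompt; the conversation string is an illustrative stand-in for the text deepeval would normally supply, not part of the package:

from deepeval.metrics.conversational_dag.templates import (
    ConversationalBinaryJudgementTemplate,
)

# Hand-written stand-in for a flattened conversation.
conversation = (
    "user: Where is my order?\n"
    "assistant: It shipped yesterday and should arrive on Friday."
)

prompt = ConversationalBinaryJudgementTemplate.generate_binary_verdict(
    criteria="Determine whether the assistant answers every user question directly.",
    text=conversation,
)
print(prompt)  # asks the judge LLM for {"verdict": true/false, "reason": "..."}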
deepeval/metrics/dag/dag.py
CHANGED

@@ -13,8 +13,8 @@ from deepeval.models import DeepEvalBaseLLM
 from deepeval.metrics.indicator import metric_progress_indicator
 from deepeval.metrics.g_eval.schema import *
 from deepeval.metrics.dag.graph import DeepAcyclicGraph
-from deepeval.metrics.dag.utils import copy_graph
 from deepeval.metrics.dag.utils import (
+    copy_graph,
     is_valid_dag_from_roots,
     extract_required_params,
 )

@@ -34,7 +34,12 @@ class DAGMetric(BaseMetric):
         verbose_mode: bool = False,
         _include_dag_suffix: bool = True,
     ):
-        if
+        if (
+            is_valid_dag_from_roots(
+                root_nodes=dag.root_nodes, multiturn=dag.multiturn
+            )
+            == False
+        ):
             raise ValueError("Cycle detected in DAG graph.")

         self._verbose_steps: List[str] = []

@@ -56,7 +61,9 @@ class DAGMetric(BaseMetric):
         _in_component: bool = False,
     ) -> float:
         check_llm_test_case_params(
-            test_case,
+            test_case,
+            extract_required_params(self.dag.root_nodes, self.dag.multiturn),
+            self,
         )

         self.evaluation_cost = 0 if self.using_native_model else None

@@ -91,7 +98,9 @@ class DAGMetric(BaseMetric):
         _in_component: bool = False,
     ) -> float:
         check_llm_test_case_params(
-            test_case,
+            test_case,
+            extract_required_params(self.dag.root_nodes, self.dag.multiturn),
+            self,
         )

         self.evaluation_cost = 0 if self.using_native_model else None
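The constructor now fails fast when the supplied DAG contains a cycle. A simplified, self-contained sketch of that construct-time validation pattern; _Node and MetricLike are hypothetical stand-ins, not the deepeval API:

class _Node:
    """Hypothetical stand-in for a DAG node; only the children list matters here."""
    def __init__(self):
        self.children = []


def _has_cycle(node, visited=None, stack=None):
    # Depth-first search; "stack" holds the current path, so a back-edge means a cycle.
    visited = set() if visited is None else visited
    stack = set() if stack is None else stack
    if node in stack:
        return True
    if node in visited:
        return False
    visited.add(node)
    stack.add(node)
    for child in node.children:
        if _has_cycle(child, visited, stack):
            return True
    stack.remove(node)
    return False


class MetricLike:
    """Hypothetical metric that validates its graph at construction, as DAGMetric now does."""
    def __init__(self, root):
        if _has_cycle(root):
            raise ValueError("Cycle detected in DAG graph.")
        self.root = root


a, b = _Node(), _Node()
a.children.append(b)
b.children.append(a)  # introduces a cycle
try:
    MetricLike(a)
except ValueError as e:
    print(e)  # Cycle detected in DAG graph.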
deepeval/metrics/dag/graph.py
CHANGED

@@ -1,39 +1,71 @@
 import asyncio
-from typing import List
+from typing import List, Union

 from deepeval.metrics.dag import (
     BaseNode,
     NonBinaryJudgementNode,
     BinaryJudgementNode,
 )
-from deepeval.
-
+from deepeval.metrics.conversational_dag import (
+    ConversationalBaseNode,
+    ConversationalBinaryJudgementNode,
+    ConversationalNonBinaryJudgementNode,
+)
+from deepeval.test_case import LLMTestCase, ConversationalTestCase
+from deepeval.metrics import BaseMetric, BaseConversationalMetric
+
+
+def validate_root_nodes(
+    root_nodes: Union[List[BaseNode], List[ConversationalBaseNode]],
+):
+    # see if all root nodes are of the same type, more verbose error message, actualy we should say we cannot mix multi and single turn nodes
+    if not all(isinstance(node, type(root_nodes[0])) for node in root_nodes):
+        raise ValueError("You cannot mix multi and single turn nodes")
+    return True


 class DeepAcyclicGraph:
+    multiturn: bool
+
     def __init__(
         self,
-        root_nodes: List[BaseNode],
+        root_nodes: Union[List[BaseNode], List[ConversationalBaseNode]],
     ):
-
-
-            root_node, BinaryJudgementNode
-        ):
-            if len(root_nodes) > 1:
-                raise ValueError(
-                    "You cannot provide more than one root node when using 'BinaryJudgementNode' or 'NonBinaryJudgementNode' in root_nodes."
-                )
+        validate_root_nodes(root_nodes)
+        self.multiturn = isinstance(root_nodes[0], ConversationalBaseNode)

+        if not self.multiturn:
+            for root_node in root_nodes:
+                if isinstance(root_node, NonBinaryJudgementNode) or isinstance(
+                    root_node, BinaryJudgementNode
+                ):
+                    if len(root_nodes) > 1:
+                        raise ValueError(
+                            "You cannot provide more than one root node when using 'BinaryJudgementNode' or 'NonBinaryJudgementNode' in root_nodes."
+                        )
+        else:
+            for root_node in root_nodes:
+                if isinstance(
+                    root_node, ConversationalNonBinaryJudgementNode
+                ) or isinstance(root_node, ConversationalBinaryJudgementNode):
+                    if len(root_nodes) > 1:
+                        raise ValueError(
+                            "You cannot provide more than one root node when using 'ConversationalBinaryJudgementNode' or 'ConversationalNonBinaryJudgementNode' in root_nodes."
+                        )
         self.root_nodes = root_nodes

-    def _execute(
+    def _execute(
+        self,
+        metric: Union[BaseMetric, BaseConversationalMetric],
+        test_case: Union[LLMTestCase, ConversationalTestCase],
+    ) -> None:
         for root_node in self.root_nodes:
             root_node._execute(metric=metric, test_case=test_case, depth=0)

     async def _a_execute(
         self,
-        metric: BaseMetric,
-        test_case: LLMTestCase,
+        metric: Union[BaseMetric, BaseConversationalMetric],
+        test_case: Union[LLMTestCase, ConversationalTestCase],
     ) -> None:
         await asyncio.gather(
             *(
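The new validate_root_nodes helper enforces that all root nodes share one type, so single-turn and conversational nodes cannot be mixed in a single graph. A standalone sketch of that check; SingleTurnNode and ConversationalNode are illustrative stand-ins, not the deepeval node classes:

class SingleTurnNode:
    """Illustrative stand-in for a single-turn DAG node."""


class ConversationalNode:
    """Illustrative stand-in for a multi-turn (conversational) DAG node."""


def validate_root_nodes(root_nodes):
    # Same check as in the diff above: every root must share the type of the first root.
    if not all(isinstance(node, type(root_nodes[0])) for node in root_nodes):
        raise ValueError("You cannot mix multi and single turn nodes")
    return True


print(validate_root_nodes([SingleTurnNode(), SingleTurnNode()]))  # True
try:
    validate_root_nodes([SingleTurnNode(), ConversationalNode()])
except ValueError as e:
    print(e)  # You cannot mix multi and single turn nodes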
deepeval/metrics/dag/utils.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Set, Dict, Optional
+from typing import Set, Dict, Optional, Union
 import inspect

 from deepeval.metrics.dag import (

@@ -9,18 +9,33 @@ from deepeval.metrics.dag import (
     TaskNode,
     DeepAcyclicGraph,
 )
-from deepeval.
+from deepeval.metrics.conversational_dag import (
+    ConversationalBaseNode,
+    ConversationalBinaryJudgementNode,
+    ConversationalNonBinaryJudgementNode,
+    ConversationalTaskNode,
+    ConversationalVerdictNode,
+)
+from deepeval.test_case import LLMTestCaseParams, TurnParams


-def is_valid_dag_from_roots(
+def is_valid_dag_from_roots(
+    root_nodes: Union[list[BaseNode], list[ConversationalBaseNode]],
+    multiturn: bool,
+) -> bool:
     visited = set()
     for root in root_nodes:
-        if not is_valid_dag(root, visited, set()):
+        if not is_valid_dag(root, multiturn, visited, set()):
             return False
     return True


-def is_valid_dag(
+def is_valid_dag(
+    node: Union[BaseNode, ConversationalBaseNode],
+    multiturn: bool,
+    visited=None,
+    stack=None,
+) -> bool:
     if visited is None:
         visited = set()
     if stack is None:

@@ -33,14 +48,24 @@ def is_valid_dag(node: BaseNode, visited=None, stack=None) -> bool:

     visited.add(node)
     stack.add(node)
-    if
-
-
-
-
-
-
-
+    if not multiturn:
+        if (
+            isinstance(node, TaskNode)
+            or isinstance(node, BinaryJudgementNode)
+            or isinstance(node, NonBinaryJudgementNode)
+        ):
+            for child in node.children:
+                if not is_valid_dag(child, multiturn, visited, stack):
+                    return False
+    else:
+        if (
+            isinstance(node, ConversationalTaskNode)
+            or isinstance(node, ConversationalBinaryJudgementNode)
+            or isinstance(node, ConversationalNonBinaryJudgementNode)
+        ):
+            for child in node.children:
+                if not is_valid_dag(child, multiturn, visited, stack):
+                    return False

     stack.remove(node)
     return True

@@ -48,29 +73,51 @@ def is_valid_dag(node: BaseNode, visited=None, stack=None) -> bool:

 def extract_required_params(
     nodes: list[BaseNode],
-
-
+    multiturn: bool,
+    required_params: Optional[
+        Union[Set[LLMTestCaseParams], Set[TurnParams]]
+    ] = None,
+) -> Union[Set[LLMTestCaseParams], Set[TurnParams]]:
     if required_params is None:
         required_params = set()

     for node in nodes:
-        if
-
-
-
-
-
-
-
+        if not multiturn:
+            if (
+                isinstance(node, TaskNode)
+                or isinstance(node, BinaryJudgementNode)
+                or isinstance(node, NonBinaryJudgementNode)
+            ):
+                if node.evaluation_params is not None:
+                    required_params.update(node.evaluation_params)
+                extract_required_params(
+                    node.children, multiturn, required_params
+                )
+        else:
+            if (
+                isinstance(node, ConversationalTaskNode)
+                or isinstance(node, ConversationalBinaryJudgementNode)
+                or isinstance(node, ConversationalNonBinaryJudgementNode)
+            ):
+                if node.evaluation_params is not None:
+                    required_params.update(node.evaluation_params)
+                extract_required_params(
+                    node.children, multiturn, required_params
+                )

     return required_params


 def copy_graph(original_dag: DeepAcyclicGraph) -> DeepAcyclicGraph:
     # This mapping avoids re-copying nodes that appear in multiple places.
-    visited:
-
-
+    visited: Union[
+        Dict[BaseNode, BaseNode],
+        Dict[ConversationalBaseNode, ConversationalBaseNode],
+    ] = {}
+
+    def copy_node(
+        node: Union[BaseNode, ConversationalBaseNode],
+    ) -> Union[BaseNode, ConversationalBaseNode]:
         if node in visited:
             return visited[node]


@@ -98,22 +145,40 @@ def copy_graph(original_dag: DeepAcyclicGraph) -> DeepAcyclicGraph:
             "_depth",
         ]
     }
-        if
-
-
-
-
-
-
-
-
+        if not original_dag.multiturn:
+            if (
+                isinstance(node, TaskNode)
+                or isinstance(node, BinaryJudgementNode)
+                or isinstance(node, NonBinaryJudgementNode)
+            ):
+                copied_node = node_class(
+                    **valid_args,
+                    children=[copy_node(child) for child in node.children]
+                )
+            else:
+                if isinstance(node, VerdictNode) and node.child:
+                    copied_node = node_class(
+                        **valid_args, child=copy_node(node.child)
+                    )
+                else:
+                    copied_node = node_class(**valid_args)
         else:
-            if
+            if (
+                isinstance(node, ConversationalTaskNode)
+                or isinstance(node, ConversationalBinaryJudgementNode)
+                or isinstance(node, ConversationalNonBinaryJudgementNode)
+            ):
                 copied_node = node_class(
-                    **valid_args,
+                    **valid_args,
+                    children=[copy_node(child) for child in node.children]
                 )
             else:
-
+                if isinstance(node, ConversationalVerdictNode) and node.child:
+                    copied_node = node_class(
+                        **valid_args, child=copy_node(node.child)
+                    )
+                else:
+                    copied_node = node_class(**valid_args)

         visited[node] = copied_node
         return copied_node
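copy_graph keeps a visited mapping so a node reachable from several parents is copied exactly once and the copies keep the shared structure. A generic sketch of that memoized deep-copy pattern; the Node class is illustrative, not the deepeval node types:

class Node:
    """Illustrative node; deepeval rebuilds richer node types via their constructors."""
    def __init__(self, name, children=None):
        self.name = name
        self.children = children or []


def copy_dag(roots):
    visited = {}  # original node -> its copy, so shared nodes are copied exactly once

    def copy_node(node):
        if node in visited:
            return visited[node]
        copied = Node(node.name, [copy_node(child) for child in node.children])
        visited[node] = copied
        return copied

    return [copy_node(root) for root in roots]


shared = Node("leaf")
a, b = Node("a", [shared]), Node("b", [shared])
new_a, new_b = copy_dag([a, b])
print(new_a.children[0] is new_b.children[0])  # True: the shared leaf stays shared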
deepeval/synthesizer/chunking/doc_chunker.py
CHANGED

@@ -1,47 +1,72 @@
-from typing import Optional, List, Dict, Union, Type
 import os
+
+from typing import Dict, List, Optional, Type, TYPE_CHECKING
+from types import SimpleNamespace
+
 from deepeval.models.base_model import DeepEvalBaseEmbeddingModel

-
-
+
+if TYPE_CHECKING:
+    from chromadb.api.models.Collection import Collection
     from langchain_core.documents import Document as LCDocument
-    from langchain_text_splitters import TokenTextSplitter
     from langchain_text_splitters.base import TextSplitter
-    from langchain_community.document_loaders import (
-        PyPDFLoader,
-        TextLoader,
-        Docx2txtLoader,
-    )
     from langchain_community.document_loaders.base import BaseLoader

-    langchain_available = True
-except ImportError:
-    langchain_available = False
-
-# check chromadb availability
-try:
-    import chromadb
-    from chromadb import Metadata
-    from chromadb.api.models.Collection import Collection
-
-    chroma_db_available = True
-except ImportError:
-    chroma_db_available = False

-
-
-
-
+# Lazy import caches
+_langchain_ns = None
+_chroma_mod = None
+_langchain_import_error = None
+_chroma_import_error = None
+
+
+def _get_langchain():
+    """Return a namespace of langchain classes, or raise ImportError with root cause."""
+    global _langchain_ns, _langchain_import_error
+    if _langchain_ns is not None:
+        return _langchain_ns
+    try:
+        from langchain_core.documents import Document as LCDocument  # type: ignore
+        from langchain_text_splitters import TokenTextSplitter  # type: ignore
+        from langchain_text_splitters.base import TextSplitter  # type: ignore
+        from langchain_community.document_loaders import (  # type: ignore
+            PyPDFLoader,
+            TextLoader,
+            Docx2txtLoader,
+        )
+        from langchain_community.document_loaders.base import BaseLoader  # type: ignore
+
+        _langchain_ns = SimpleNamespace(
+            LCDocument=LCDocument,
+            TokenTextSplitter=TokenTextSplitter,
+            TextSplitter=TextSplitter,
+            PyPDFLoader=PyPDFLoader,
+            TextLoader=TextLoader,
+            Docx2txtLoader=Docx2txtLoader,
+            BaseLoader=BaseLoader,
+        )
+        return _langchain_ns
+    except Exception as e:
+        _langchain_import_error = e
         raise ImportError(
-            "
+            f"langchain, langchain_community, and langchain_text_splitters are required. Root cause: {e}"
         )


-def
-
+def _get_chromadb():
+    """Return the chromadb module, or raise ImportError with root cause."""
+    global _chroma_mod, _chroma_import_error
+    if _chroma_mod is not None:
+        return _chroma_mod
+    try:
+        import chromadb
+
+        _chroma_mod = chromadb
+        return _chroma_mod
+    except Exception as e:
+        _chroma_import_error = e
         raise ImportError(
-            "
+            f"chromadb is required for this functionality. Root cause: {e}"
         )

@@ -50,22 +75,16 @@ class DocumentChunker:
         self,
         embedder: DeepEvalBaseEmbeddingModel,
     ):
-        _check_chromadb_available()
-        _check_langchain_available()
         self.text_token_count: Optional[int] = None  # set later

         self.source_file: Optional[str] = None
         self.chunks: Optional["Collection"] = None
-        self.sections: Optional[List[LCDocument]] = None
+        self.sections: Optional[List["LCDocument"]] = None
         self.embedder: DeepEvalBaseEmbeddingModel = embedder
         self.mean_embedding: Optional[float] = None

         # Mapping of file extensions to their respective loader classes
-        self.loader_mapping: Dict[str, Type[BaseLoader]] = {
-            ".pdf": PyPDFLoader,
-            ".txt": TextLoader,
-            ".docx": Docx2txtLoader,
-        }
+        self.loader_mapping: Dict[str, "Type[BaseLoader]"] = {}

         #########################################################
         ### Chunking Docs #######################################

@@ -74,7 +93,8 @@ class DocumentChunker:
     async def a_chunk_doc(
         self, chunk_size: int = 1024, chunk_overlap: int = 0
    ) -> "Collection":
-
+        lc = _get_langchain()
+        chroma = _get_chromadb()

         # Raise error if chunk_doc is called before load_doc
         if self.sections is None or self.source_file is None:

@@ -85,13 +105,13 @@ class DocumentChunker:
         # Create ChromaDB client
         full_document_path, _ = os.path.splitext(self.source_file)
         document_name = os.path.basename(full_document_path)
-        client =
+        client = chroma.PersistentClient(path=f".vector_db/{document_name}")

         collection_name = f"processed_chunks_{chunk_size}_{chunk_overlap}"
         try:
             collection = client.get_collection(name=collection_name)
         except Exception:
-            text_splitter: TextSplitter = TokenTextSplitter(
+            text_splitter: "TextSplitter" = lc.TokenTextSplitter(
                 chunk_size=chunk_size, chunk_overlap=chunk_overlap
             )
             # Collection doesn't exist, so create it and then add documents

@@ -108,7 +128,7 @@ class DocumentChunker:
                 batch_contents = contents[i:batch_end]
                 batch_embeddings = embeddings[i:batch_end]
                 batch_ids = ids[i:batch_end]
-                batch_metadatas: List[
+                batch_metadatas: List[dict] = [
                     {"source_file": self.source_file} for _ in batch_contents
                 ]

@@ -121,7 +141,8 @@ class DocumentChunker:
         return collection

     def chunk_doc(self, chunk_size: int = 1024, chunk_overlap: int = 0):
-
+        lc = _get_langchain()
+        chroma = _get_chromadb()

         # Raise error if chunk_doc is called before load_doc
         if self.sections is None or self.source_file is None:

@@ -132,13 +153,13 @@ class DocumentChunker:
         # Create ChromaDB client
         full_document_path, _ = os.path.splitext(self.source_file)
         document_name = os.path.basename(full_document_path)
-        client =
+        client = chroma.PersistentClient(path=f".vector_db/{document_name}")

         collection_name = f"processed_chunks_{chunk_size}_{chunk_overlap}"
         try:
             collection = client.get_collection(name=collection_name)
         except Exception:
-            text_splitter: TextSplitter = TokenTextSplitter(
+            text_splitter: "TextSplitter" = lc.TokenTextSplitter(
                 chunk_size=chunk_size, chunk_overlap=chunk_overlap
             )
             # Collection doesn't exist, so create it and then add documents

@@ -155,7 +176,7 @@ class DocumentChunker:
                 batch_contents = contents[i:batch_end]
                 batch_embeddings = embeddings[i:batch_end]
                 batch_ids = ids[i:batch_end]
-                batch_metadatas: List[
+                batch_metadatas: List[dict] = [
                     {"source_file": self.source_file} for _ in batch_contents
                 ]

@@ -172,17 +193,31 @@ class DocumentChunker:
     #########################################################

     def get_loader(self, path: str, encoding: Optional[str]) -> "BaseLoader":
+        lc = _get_langchain()
+        # set mapping lazily now that langchain classes exist
+        if not self.loader_mapping:
+            self.loader_mapping = {
+                ".pdf": lc.PyPDFLoader,
+                ".txt": lc.TextLoader,
+                ".docx": lc.Docx2txtLoader,
+                ".md": lc.TextLoader,
+                ".markdown": lc.TextLoader,
+                ".mdx": lc.TextLoader,
+            }
+
         # Find appropriate doc loader
         _, extension = os.path.splitext(path)
         extension = extension.lower()
-        loader: Optional[
+        loader: Optional["Type[BaseLoader]"] = self.loader_mapping.get(
+            extension
+        )
         if loader is None:
             raise ValueError(f"Unsupported file format: {extension}")

-        # Load doc into sections and calculate total
-        if loader is TextLoader:
+        # Load doc into sections and calculate total token count
+        if loader is lc.TextLoader:
             return loader(path, encoding=encoding, autodetect_encoding=True)
-        elif loader
+        elif loader in (lc.PyPDFLoader, lc.Docx2txtLoader):
             return loader(path)
         else:
             raise ValueError(f"Unsupported file format: {extension}")

@@ -200,5 +235,6 @@ class DocumentChunker:
         self.source_file = path

     def count_tokens(self, chunks: List["LCDocument"]):
-
+        lc = _get_langchain()
+        counter = lc.TokenTextSplitter(chunk_size=1, chunk_overlap=0)
         return len(counter.split_documents(chunks))