mb-rag 1.0.117__py3-none-any.whl → 1.0.124__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mb-rag has been flagged as potentially problematic; see the package's registry page for more details.
- mb_rag/chatbot/basic.py +329 -295
- mb_rag/chatbot/chains.py +163 -202
- mb_rag/utils/bounding_box.py +203 -76
- mb_rag/utils/extra.py +1 -1
- mb_rag/version.py +1 -1
- {mb_rag-1.0.117.dist-info → mb_rag-1.0.124.dist-info}/METADATA +1 -1
- mb_rag-1.0.124.dist-info/RECORD +15 -0
- mb_rag-1.0.117.dist-info/RECORD +0 -15
- {mb_rag-1.0.117.dist-info → mb_rag-1.0.124.dist-info}/WHEEL +0 -0
- {mb_rag-1.0.117.dist-info → mb_rag-1.0.124.dist-info}/top_level.txt +0 -0
mb_rag/chatbot/chains.py
CHANGED
|
@@ -1,21 +1,15 @@
|
|
|
1
1
|
## file for chaining functions in chatbot
|
|
2
2
|
|
|
3
|
-
import
|
|
3
|
+
from typing import Optional, List, Dict, Any, Union
|
|
4
|
+
from dataclasses import dataclass
|
|
4
5
|
from langchain.schema.output_parser import StrOutputParser
|
|
5
6
|
from mb_rag.chatbot.prompts import invoke_prompt
|
|
6
|
-
from langchain.schema.runnable import RunnableLambda, RunnableSequence
|
|
7
|
+
from langchain.schema.runnable import RunnableLambda, RunnableSequence
|
|
8
|
+
from mb_rag.utils.extra import check_package
|
|
7
9
|
|
|
8
|
-
|
|
9
|
-
"""
|
|
10
|
-
Check if a package is installed
|
|
11
|
-
Args:
|
|
12
|
-
package_name (str): Name of the package
|
|
13
|
-
Returns:
|
|
14
|
-
bool: True if package is installed, False otherwise
|
|
15
|
-
"""
|
|
16
|
-
return importlib.util.find_spec(package_name) is not None
|
|
10
|
+
# Names exported by `from mb_rag.chatbot.chains import *`.
__all__ = ['Chain', 'ChainConfig']
|
|
17
11
|
|
|
18
|
-
def check_langchain_dependencies():
|
|
12
|
+
def check_langchain_dependencies() -> None:
|
|
19
13
|
"""
|
|
20
14
|
Check if required LangChain packages are installed
|
|
21
15
|
Raises:
|
|
@@ -29,217 +23,184 @@ def check_langchain_dependencies():
|
|
|
29
23
|
# Check dependencies before importing
|
|
30
24
|
# Verify the required LangChain packages at import time, before anything below relies on them.
check_langchain_dependencies()
|
|
31
25
|
|
|
26
|
+
@dataclass
class ChainConfig:
    """Settings bundle handed to ``Chain``.

    Every field is optional and defaults to ``None``:
        prompt: a ready-made prompt used as-is.
        prompt_template: a template rendered together with ``input_dict``.
        input_dict: substitution values for ``prompt_template``.
    """

    # Literal prompt text.
    prompt: Union[str, None] = None
    # Template text; Chain renders it with input_dict when it is set.
    prompt_template: Union[str, None] = None
    # Values substituted into the template.
    input_dict: Union[Dict[str, Any], None] = None
|
|
32
32
|
|
|
33
|
-
class
|
|
33
|
+
class Chain:
    """
    Class to chain functions in chatbot with improved OOP design.

    A ``Chain`` pairs a language model with a prompt. The prompt is either
    supplied directly or rendered from a template via ``invoke_prompt``.
    Output is passed through a ``StrOutputParser`` unless another final step
    is given.
    """

    def __init__(self, model: Any, config: Optional[ChainConfig] = None, **kwargs):
        """
        Initialize chain
        Args:
            model: The language model to use (must expose ``invoke``)
            config: Chain configuration. When given, ``kwargs`` are not
                consulted for ``prompt`` / ``prompt_template`` / ``input_dict``.
            **kwargs: Fallback ``prompt``, ``prompt_template`` and
                ``input_dict`` values, used only when ``config`` is None
        """
        self.model = model
        self._output_parser = StrOutputParser()
        self._initialize_config(config, **kwargs)

    @classmethod
    def from_template(cls, model: Any, template: str, input_dict: Dict[str, Any], **kwargs) -> 'Chain':
        """
        Create chain from template
        Args:
            model: The language model
            template: Prompt template
            input_dict: Input dictionary for template
            **kwargs: Additional arguments
        Returns:
            Chain: New chain instance
        """
        config = ChainConfig(
            prompt_template=template,
            input_dict=input_dict
        )
        return cls(model, config, **kwargs)

    @classmethod
    def from_prompt(cls, model: Any, prompt: str, **kwargs) -> 'Chain':
        """
        Create chain from direct prompt
        Args:
            model: The language model
            prompt: Direct prompt
            **kwargs: Additional arguments
        Returns:
            Chain: New chain instance
        """
        config = ChainConfig(prompt=prompt)
        return cls(model, config, **kwargs)

    def _initialize_config(self, config: Optional[ChainConfig], **kwargs) -> None:
        """
        Populate ``self.input_dict`` and ``self.prompt``.

        A template, when present, takes precedence over a direct prompt and is
        rendered immediately with ``invoke_prompt``.
        """
        if config is not None:
            self.input_dict = config.input_dict
            if config.prompt_template:
                self.prompt = invoke_prompt(config.prompt_template, self.input_dict)
            else:
                self.prompt = config.prompt
        else:
            self.input_dict = kwargs.get('input_dict')
            if prompt_template := kwargs.get('prompt_template'):
                self.prompt = invoke_prompt(prompt_template, self.input_dict)
            else:
                self.prompt = kwargs.get('prompt')

    @property
    def output_parser(self) -> StrOutputParser:
        """Get the output parser"""
        return self._output_parser

    @staticmethod
    def _validate_chain_components(prompt: Any, middle_chain: Optional[List] = None) -> None:
        """
        Validate chain components
        Args:
            prompt: The prompt to validate
            middle_chain: Optional middle chain to validate
        Raises:
            ValueError: If the prompt is None or middle_chain is not a list
        """
        if prompt is None:
            raise ValueError("Prompt is not provided")
        if middle_chain is not None and not isinstance(middle_chain, list):
            raise ValueError("middle_chain should be a list")

    def invoke(self) -> Any:
        """
        Invoke the chain: run the model on the prompt and parse its output.
        Returns:
            Any: Parsed model output (a string with the default parser)
        Raises:
            ValueError: If prompt is not provided
        """
        self._validate_chain_components(self.prompt)
        # self.prompt is either a plain string or the value returned by
        # invoke_prompt (presumably an already-rendered prompt, not a
        # Runnable). `str | model` raises TypeError, and the previous code
        # returned the unexecuted pipeline. So call the model directly.
        response = self.model.invoke(self.prompt)
        return self.output_parser.invoke(response)

    def chain_sequence_invoke(self,
                              middle_chain: Optional[List] = None,
                              final_chain: Optional[RunnableLambda] = None) -> Any:
        """
        Chain invoke the sequence
        Args:
            middle_chain: List of functions/Prompts/RunnableLambda to chain
            final_chain: Final chain to run (defaults to the output parser)
        Returns:
            Any: Output of the final step, or None when middle_chain is empty
        Raises:
            ValueError: If prompt is missing or middle_chain is not a list
        """
        self._validate_chain_components(self.prompt, middle_chain)

        if not middle_chain:
            # Nothing to run between the prompt and the final step.
            return None

        final = final_chain if final_chain is not None else self.output_parser
        # Each step must be passed positionally: RunnableSequence coerces
        # callables to RunnableLambda but cannot coerce a list. At least two
        # steps are guaranteed here (>= 1 middle step + the final step).
        func_chain = RunnableSequence(*middle_chain, final)
        # Runnable.invoke requires an input; the prompt flows through the steps.
        return func_chain.invoke(self.prompt)

    def chain_parallel_invoke(self, parallel_chain: Any, input_data: Any = None) -> Any:
        """
        Chain invoke in parallel
        Args:
            parallel_chain: A Runnable (e.g. one built by
                ``create_parallel_chain``) or a dict mapping names to
                runnables/callables, which is wrapped in a RunnableParallel
            input_data: Input passed to the chain; defaults to ``self.input_dict``
        Returns:
            Any: Output from the parallel chains
        Raises:
            ImportError: If LangChain is not installed
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        if isinstance(parallel_chain, dict):
            from langchain.schema.runnable import RunnableParallel
            parallel_chain = RunnableParallel(parallel_chain)
        payload = self.input_dict if input_data is None else input_data
        # Runnable.invoke takes a required input; calling it bare raises TypeError.
        return parallel_chain.invoke(payload)

    def chain_branch_invoke(self, branch_chain: Any, input_data: Any = None) -> Any:
        """
        Chain invoke with branching
        Args:
            branch_chain: A Runnable, e.g. a RunnableBranch built by
                ``create_branch_chain``
            input_data: Input passed to the chain; defaults to ``self.input_dict``
        Returns:
            Any: Output from the branch chain
        Raises:
            ImportError: If LangChain is not installed
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        payload = self.input_dict if input_data is None else input_data
        return branch_chain.invoke(payload)

    @staticmethod
    def create_parallel_chain(prompt_template: Any, model: Any, branches: Dict[str, Any]) -> Any:
        """
        Create a parallel chain: prompt -> model -> str parser -> parallel branches
        Args:
            prompt_template: A prompt template Runnable, or a template string
                (converted with ``ChatPromptTemplate.from_template``)
            model: The language model
            branches: Mapping of branch name to runnable
        Returns:
            Any: Configured parallel chain; its output nests the branch
                results under the ``"branches"`` key
        """
        from langchain.schema.runnable import RunnableParallel
        if isinstance(prompt_template, str):
            # A raw string cannot be piped into a model; wrap it as a template.
            from langchain.prompts import ChatPromptTemplate
            prompt_template = ChatPromptTemplate.from_template(prompt_template)
        return (
            prompt_template
            | model
            | StrOutputParser()
            | RunnableParallel(branches=branches)
        )

    @staticmethod
    def create_branch_chain(conditions: List[tuple], default_chain: Any) -> Any:
        """
        Create a branch chain
        Args:
            conditions: ``(condition, runnable)`` pairs, checked in order
            default_chain: Runnable used when no condition matches
        Returns:
            Any: Configured RunnableBranch
        """
        from langchain.schema.runnable import RunnableBranch
        return RunnableBranch(*conditions, default_chain)
|