mb-rag 1.0.117__py3-none-any.whl → 1.0.124__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mb-rag might be problematic; see the package registry's advisory page for more details.

mb_rag/chatbot/chains.py CHANGED
@@ -1,21 +1,15 @@
1
1
  ## file for chaining functions in chatbot
2
2
 
3
- import importlib.util
3
+ from typing import Optional, List, Dict, Any, Union
4
+ from dataclasses import dataclass
4
5
  from langchain.schema.output_parser import StrOutputParser
5
6
  from mb_rag.chatbot.prompts import invoke_prompt
6
- from langchain.schema.runnable import RunnableLambda, RunnableSequence
7
+ from langchain.schema.runnable import RunnableLambda, RunnableSequence
8
+ from mb_rag.utils.extra import check_package
7
9
 
8
- def check_package(package_name):
9
- """
10
- Check if a package is installed
11
- Args:
12
- package_name (str): Name of the package
13
- Returns:
14
- bool: True if package is installed, False otherwise
15
- """
16
- return importlib.util.find_spec(package_name) is not None
10
+ __all__ = ['Chain', 'ChainConfig']
17
11
 
18
- def check_langchain_dependencies():
12
+ def check_langchain_dependencies() -> None:
19
13
  """
20
14
  Check if required LangChain packages are installed
21
15
  Raises:
@@ -29,217 +23,184 @@ def check_langchain_dependencies():
29
23
  # Check dependencies before importing
30
24
  check_langchain_dependencies()
31
25
 
26
+ @dataclass
27
+ class ChainConfig:
28
+ """Configuration for chain operations"""
29
+ prompt: Optional[str] = None
30
+ prompt_template: Optional[str] = None
31
+ input_dict: Optional[Dict[str, Any]] = None
32
32
 
33
- class chain:
33
+ class Chain:
34
34
  """
35
- Class to chain functions in chatbot
35
+ Class to chain functions in chatbot with improved OOP design
36
36
  """
37
- def __init__(self, model, prompt: str = None, prompt_template: str = None, input_dict: dict = None, **kwargs):
37
+ def __init__(self, model: Any, config: Optional[ChainConfig] = None, **kwargs):
38
+ """
39
+ Initialize chain
40
+ Args:
41
+ model: The language model to use
42
+ config: Chain configuration
43
+ **kwargs: Additional arguments
44
+ """
38
45
  self.model = model
39
- self.output_parser = StrOutputParser() ## self.output_parser = RunnableLambda(lambda x: x.content) - can use this also
40
- if input_dict is not None:
41
- self.input_dict = input_dict
42
- if prompt_template is not None:
43
- self.prompt = invoke_prompt(prompt_template, self.input_dict)
46
+ self._output_parser = StrOutputParser()
47
+ self._initialize_config(config, **kwargs)
48
+
49
+ @classmethod
50
+ def from_template(cls, model: Any, template: str, input_dict: Dict[str, Any], **kwargs) -> 'Chain':
51
+ """
52
+ Create chain from template
53
+ Args:
54
+ model: The language model
55
+ template: Prompt template
56
+ input_dict: Input dictionary for template
57
+ **kwargs: Additional arguments
58
+ Returns:
59
+ Chain: New chain instance
60
+ """
61
+ config = ChainConfig(
62
+ prompt_template=template,
63
+ input_dict=input_dict
64
+ )
65
+ return cls(model, config, **kwargs)
66
+
67
+ @classmethod
68
+ def from_prompt(cls, model: Any, prompt: str, **kwargs) -> 'Chain':
69
+ """
70
+ Create chain from direct prompt
71
+ Args:
72
+ model: The language model
73
+ prompt: Direct prompt
74
+ **kwargs: Additional arguments
75
+ Returns:
76
+ Chain: New chain instance
77
+ """
78
+ config = ChainConfig(prompt=prompt)
79
+ return cls(model, config, **kwargs)
80
+
81
+ def _initialize_config(self, config: Optional[ChainConfig], **kwargs) -> None:
82
+ """Initialize chain configuration"""
83
+ if config:
84
+ self.input_dict = config.input_dict
85
+ if config.prompt_template:
86
+ self.prompt = invoke_prompt(config.prompt_template, self.input_dict)
87
+ else:
88
+ self.prompt = config.prompt
44
89
  else:
45
- self.prompt = prompt
90
+ self.input_dict = kwargs.get('input_dict')
91
+ if prompt_template := kwargs.get('prompt_template'):
92
+ self.prompt = invoke_prompt(prompt_template, self.input_dict)
93
+ else:
94
+ self.prompt = kwargs.get('prompt')
95
+
96
+ @property
97
+ def output_parser(self) -> StrOutputParser:
98
+ """Get the output parser"""
99
+ return self._output_parser
100
+
101
+ @staticmethod
102
+ def _validate_chain_components(prompt: Any, middle_chain: Optional[List] = None) -> None:
103
+ """
104
+ Validate chain components
105
+ Args:
106
+ prompt: The prompt to validate
107
+ middle_chain: Optional middle chain to validate
108
+ Raises:
109
+ ValueError: If validation fails
110
+ """
111
+ if prompt is None:
112
+ raise ValueError("Prompt is not provided")
113
+ if middle_chain is not None and not isinstance(middle_chain, list):
114
+ raise ValueError("middle_chain should be a list")
46
115
 
47
- def invoke(self):
116
+ def invoke(self) -> Any:
48
117
  """
49
118
  Invoke the chain
50
119
  Returns:
51
- str: Output from the chain
120
+ Any: Output from the chain
121
+ Raises:
122
+ Exception: If prompt is not provided
52
123
  """
53
- if self.prompt is not None:
54
- chain_output = self.prompt | self.model | self.output_parser
55
- return chain_output
56
- else:
57
- return Exception("Prompt is not provided")
58
-
59
- def chain_seqeunce_invoke(self, middle_chain: list, final_chain: RunnableLambda = None):
124
+ self._validate_chain_components(self.prompt)
125
+ chain_output = self.prompt | self.model | self.output_parser
126
+ return chain_output
127
+
128
+ def chain_sequence_invoke(self,
129
+ middle_chain: Optional[List] = None,
130
+ final_chain: Optional[RunnableLambda] = None) -> Any:
60
131
  """
61
- Chain invoke the chain
132
+ Chain invoke the sequence
62
133
  Args:
63
- middle_chain (list): List of functions/Prompts/RunnableLambda to chain
64
- final_chain (RunnableLambda): Final chain to run. Default is self.output_parser
134
+ middle_chain: List of functions/Prompts/RunnableLambda to chain
135
+ final_chain: Final chain to run
65
136
  Returns:
66
- str: Output from the chain
137
+ Any: Output from the chain
67
138
  """
68
- if final_chain is not None:
69
- self.final_chain = final_chain
70
- else:
71
- self.final_chain = self.output_parser
72
- if self.prompt is not None:
73
- if middle_chain is not None:
74
- assert isinstance(middle_chain, list), "middle_chain should be a list"
75
- func_chain = RunnableSequence(self.prompt, middle_chain, self.final_chain)
76
- return func_chain.invoke()
77
- else:
78
- return Exception("Prompt is not provided")
79
-
80
- def chain_parrellel_invoke(self, parrellel_chain: list):
139
+ self._validate_chain_components(self.prompt, middle_chain)
140
+
141
+ final = final_chain if final_chain is not None else self.output_parser
142
+
143
+ if middle_chain:
144
+ func_chain = RunnableSequence(self.prompt, middle_chain, final)
145
+ return func_chain.invoke()
146
+ return None
147
+
148
+ def chain_parallel_invoke(self, parallel_chain: List) -> Any:
81
149
  """
82
- Chain invoke the chain #### better to use RunnableParallel outside the class
150
+ Chain invoke in parallel
83
151
  Args:
84
- parrellel_chain (list): List of functions/Prompts/RunnableLambda to chain
152
+ parallel_chain: List of chains to run in parallel
85
153
  Returns:
86
- str: Output from the chain
154
+ Any: Output from the parallel chains
155
+ Raises:
156
+ ImportError: If LangChain is not installed
87
157
  """
88
- return parrellel_chain.invoke()
89
-
90
- def chain_branch_invoke(self, branch_chain: dict):
158
+ if not check_package("langchain"):
159
+ raise ImportError("LangChain package not found. Please install it using: pip install langchain")
160
+ return parallel_chain.invoke()
161
+
162
+ def chain_branch_invoke(self, branch_chain: Dict) -> Any:
91
163
  """
92
- Chain invoke the chain #### better to use RunnableBranch outside the class
164
+ Chain invoke with branching
93
165
  Args:
94
- branch_chain (dict): Dictionary of functions/Prompts/RunnableLambda to chain
166
+ branch_chain: Dictionary of branch chains
95
167
  Returns:
96
- str: Output from the chain
168
+ Any: Output from the branch chain
169
+ Raises:
170
+ ImportError: If LangChain is not installed
97
171
  """
172
+ if not check_package("langchain"):
173
+ raise ImportError("LangChain package not found. Please install it using: pip install langchain")
98
174
  return branch_chain.invoke()
99
175
 
100
- # Example code is kept as comments for reference
101
- """
102
- ### Example of parrellel chain
103
- from langchain.schema.runnable import RunnableParallel
104
- from langchain.schema.output_parser import StrOutputParser
105
- from langchain_core.prompts import ChatPromptTemplate
106
-
107
- # Define prompt template
108
- prompt_template = ChatPromptTemplate.from_messages(
109
- [
110
- ("system", "You are an expert product reviewer."),
111
- ("human", "List the main features of the product {product_name}."),
112
- ]
113
- )
114
-
115
- # Define pros analysis step
116
- def analyze_pros(features):
117
- pros_template = ChatPromptTemplate.from_messages(
118
- [
119
- ("system", "You are an expert product reviewer."),
120
- (
121
- "human",
122
- "Given these features: {features}, list the pros of these features.",
123
- ),
124
- ]
125
- )
126
- return pros_template.format_prompt(features=features)
127
-
128
- # Define cons analysis step
129
- def analyze_cons(features):
130
- cons_template = ChatPromptTemplate.from_messages(
131
- [
132
- ("system", "You are an expert product reviewer."),
133
- (
134
- "human",
135
- "Given these features: {features}, list the cons of these features.",
136
- ),
137
- ]
138
- )
139
- return cons_template.format_prompt(features=features)
140
-
141
- # Combine pros and cons into a final review
142
- def combine_pros_cons(pros, cons):
143
- return f"Pros:\n{pros}\n\nCons:\n{cons}"
144
-
145
- # Simplify branches with LCEL
146
- pros_branch_chain = (
147
- RunnableLambda(lambda x: analyze_pros(x)) | model | StrOutputParser()
148
- )
149
-
150
- cons_branch_chain = (
151
- RunnableLambda(lambda x: analyze_cons(x)) | model | StrOutputParser()
152
- )
153
-
154
- # Create the combined chain using LangChain Expression Language (LCEL)
155
- chain = (
156
- prompt_template
157
- | model
158
- | StrOutputParser()
159
- | RunnableParallel(branches={"pros": pros_branch_chain, "cons": cons_branch_chain})
160
- | RunnableLambda(lambda x: combine_pros_cons(x["branches"]["pros"], x["branches"]["cons"]))
161
- )
162
-
163
- # Run the chain
164
- result = chain.invoke({"product_name": "MacBook Pro"})
165
-
166
- ### Example of branch chain
167
- from langchain.schema.runnable import RunnableBranch
168
-
169
- positive_feedback_template = ChatPromptTemplate.from_messages(
170
- [
171
- ("system", "You are a helpful assistant."),
172
- ("human",
173
- "Generate a thank you note for this positive feedback: {feedback}."),
174
- ]
175
- )
176
-
177
- negative_feedback_template = ChatPromptTemplate.from_messages(
178
- [
179
- ("system", "You are a helpful assistant."),
180
- ("human",
181
- "Generate a response addressing this negative feedback: {feedback}."),
182
- ]
183
- )
184
-
185
- neutral_feedback_template = ChatPromptTemplate.from_messages(
186
- [
187
- ("system", "You are a helpful assistant."),
188
- (
189
- "human",
190
- "Generate a request for more details for this neutral feedback: {feedback}.",
191
- ),
192
- ]
193
- )
194
-
195
- escalate_feedback_template = ChatPromptTemplate.from_messages(
196
- [
197
- ("system", "You are a helpful assistant."),
198
- (
199
- "human",
200
- "Generate a message to escalate this feedback to a human agent: {feedback}.",
201
- ),
202
- ]
203
- )
204
-
205
- # Define the feedback classification template
206
- classification_template = ChatPromptTemplate.from_messages(
207
- [
208
- ("system", "You are a helpful assistant."),
209
- ("human",
210
- "Classify the sentiment of this feedback as positive, negative, neutral, or escalate: {feedback}."),
211
- ]
212
- )
213
-
214
- # Define the runnable branches for handling feedback
215
- branches = RunnableBranch(
216
- (
217
- lambda x: "positive" in x,
218
- positive_feedback_template | model | StrOutputParser() # Positive feedback chain
219
- ),
220
- (
221
- lambda x: "negative" in x,
222
- negative_feedback_template | model | StrOutputParser() # Negative feedback chain
223
- ),
224
- (
225
- lambda x: "neutral" in x,
226
- neutral_feedback_template | model | StrOutputParser() # Neutral feedback chain
227
- ),
228
- escalate_feedback_template | model | StrOutputParser()
229
- )
230
-
231
- # Create the classification chain
232
- classification_chain = classification_template | model | StrOutputParser()
233
-
234
- # Combine classification and response generation into one chain
235
- chain = classification_chain | branches
236
-
237
- # Example usage:
238
- # Good review - "The product is excellent. I really enjoyed using it and found it very helpful."
239
- # Bad review - "The product is terrible. It broke after just one use and the quality is very poor."
240
- # Neutral review - "The product is okay. It works as expected but nothing exceptional."
241
- # Default - "I'm not sure about the product yet. Can you tell me more about its features and benefits?"
242
-
243
- review = "The product is terrible. It broke after just one use and the quality is very poor."
244
- result = chain.invoke({"feedback": review})
245
- """
176
+ @staticmethod
177
+ def create_parallel_chain(prompt_template: str, model: Any, branches: Dict[str, Any]) -> Any:
178
+ """
179
+ Create a parallel chain
180
+ Args:
181
+ prompt_template: Template for the prompt
182
+ model: The language model
183
+ branches: Dictionary of branch configurations
184
+ Returns:
185
+ Any: Configured parallel chain
186
+ """
187
+ from langchain.schema.runnable import RunnableParallel
188
+ return (
189
+ prompt_template
190
+ | model
191
+ | StrOutputParser()
192
+ | RunnableParallel(branches=branches)
193
+ )
194
+
195
+ @staticmethod
196
+ def create_branch_chain(conditions: List[tuple], default_chain: Any) -> Any:
197
+ """
198
+ Create a branch chain
199
+ Args:
200
+ conditions: List of condition-chain tuples
201
+ default_chain: Default chain to use
202
+ Returns:
203
+ Any: Configured branch chain
204
+ """
205
+ from langchain.schema.runnable import RunnableBranch
206
+ return RunnableBranch(*conditions, default_chain)