mb-rag 1.1.57.post1__py3-none-any.whl → 1.1.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mb-rag might be problematic; see the registry's linked advisory (the link was not preserved in this copy) for more details.

mb_rag/chatbot/chains.py CHANGED
@@ -1,206 +1,206 @@
1
- ## file for chaining functions in chatbot
2
-
3
- from typing import Optional, List, Dict, Any, Union
4
- from dataclasses import dataclass
5
- from langchain.schema.output_parser import StrOutputParser
6
- from mb_rag.chatbot.prompts import invoke_prompt
7
- from langchain.schema.runnable import RunnableLambda, RunnableSequence
8
- from mb_rag.utils.extra import check_package
9
-
10
# Names exported by `from mb_rag.chatbot.chains import *`.
__all__ = [
    "Chain",
    "ChainConfig",
]
11
-
12
def check_langchain_dependencies() -> None:
    """
    Verify that the LangChain distributions this module relies on are present.

    Packages are probed in order with ``check_package``; the first one that is
    missing aborts with an ``ImportError`` whose message carries an install hint.

    Raises:
        ImportError: If ``langchain`` or ``langchain_core`` cannot be found.
    """
    required_packages = (
        ("langchain",
         "LangChain package not found. Please install it using: pip install langchain"),
        ("langchain_core",
         "LangChain Core package not found. Please install it using: pip install langchain-core"),
    )
    for package_name, install_hint in required_packages:
        if not check_package(package_name):
            raise ImportError(install_hint)
22
-
23
# Sanity-check the LangChain installation when this module is imported.
# NOTE(review): this runs AFTER the `langchain.schema` imports at the top of
# the file, so a missing `langchain` package would already have raised at
# those import statements; consider moving this check above them if the
# friendlier error message is wanted.
check_langchain_dependencies()
25
-
26
@dataclass
class ChainConfig:
    """
    Settings bundle consumed by ``Chain``.

    Every field is optional and defaults to ``None``. When ``prompt_template``
    is set (truthy), ``Chain`` passes it together with ``input_dict`` to
    ``invoke_prompt``; otherwise ``prompt`` is used as-is.
    """
    # Direct prompt, used verbatim when no template is supplied.
    prompt: Optional[str] = None
    # Template text; takes precedence over ``prompt`` when truthy.
    prompt_template: Optional[str] = None
    # Template variables handed to ``invoke_prompt`` alongside the template.
    input_dict: Optional[Dict[str, Any]] = None
32
-
33
class Chain:
    """
    Compose and run LangChain pipelines around a configured prompt and model.

    The prompt is resolved once, at construction time, and stored on
    ``self.prompt``; ``self.input_dict`` keeps the template variables. The
    instance methods pipe that prompt through ``self.model`` and a
    ``StrOutputParser`` (or user-supplied steps).
    """

    def __init__(self, model: Any, config: Optional[ChainConfig] = None, **kwargs):
        """
        Initialize chain
        Args:
            model: The language model to use
            config: Chain configuration. When None, the ``prompt``,
                ``prompt_template`` and ``input_dict`` keys of ``kwargs``
                are used instead.
            **kwargs: Fallback configuration (ignored when ``config`` is given)
        """
        self.model = model
        self._output_parser = StrOutputParser()
        self._initialize_config(config, **kwargs)

    @classmethod
    def from_template(cls, model: Any, template: str, input_dict: Dict[str, Any], **kwargs) -> 'Chain':
        """
        Create chain from template
        Args:
            model: The language model
            template: Prompt template text, passed to ``invoke_prompt``
            input_dict: Template variables, passed to ``invoke_prompt``
            **kwargs: Forwarded to the constructor; ``_initialize_config``
                ignores them because a config object is always supplied here
        Returns:
            Chain: New chain instance
        """
        config = ChainConfig(prompt_template=template, input_dict=input_dict)
        return cls(model, config, **kwargs)

    @classmethod
    def from_prompt(cls, model: Any, prompt: str, **kwargs) -> 'Chain':
        """
        Create chain from direct prompt
        Args:
            model: The language model
            prompt: Direct prompt, stored as-is on ``self.prompt``
            **kwargs: Forwarded to the constructor; ``_initialize_config``
                ignores them because a config object is always supplied here
        Returns:
            Chain: New chain instance
        """
        config = ChainConfig(prompt=prompt)
        return cls(model, config, **kwargs)

    def _initialize_config(self, config: Optional[ChainConfig], **kwargs) -> None:
        """
        Resolve ``self.prompt`` and ``self.input_dict``.

        Values come from ``config`` when it is given, otherwise from the
        ``prompt_template``, ``prompt`` and ``input_dict`` keys of ``kwargs``.
        A truthy template is passed to ``invoke_prompt(template, input_dict)``
        and takes precedence over a direct ``prompt``.
        """
        if config is not None:
            self.input_dict = config.input_dict
            prompt_template = config.prompt_template
            prompt = config.prompt
        else:
            self.input_dict = kwargs.get('input_dict')
            prompt_template = kwargs.get('prompt_template')
            prompt = kwargs.get('prompt')

        if prompt_template:
            self.prompt = invoke_prompt(prompt_template, self.input_dict)
        else:
            self.prompt = prompt

    @property
    def output_parser(self) -> StrOutputParser:
        """Get the output parser"""
        return self._output_parser

    def _resolve_input(self, input_data: Any) -> Any:
        """Return ``input_data`` if given, else ``self.input_dict``, else ``{}``."""
        if input_data is not None:
            return input_data
        return self.input_dict if self.input_dict is not None else {}

    @staticmethod
    def _validate_chain_components(prompt: Any, middle_chain: Optional[List] = None) -> None:
        """
        Validate chain components
        Args:
            prompt: The prompt to validate
            middle_chain: Optional middle chain to validate
        Raises:
            ValueError: If the prompt is None or middle_chain is not a list
        """
        if prompt is None:
            raise ValueError("Prompt is not provided")
        if middle_chain is not None and not isinstance(middle_chain, list):
            raise ValueError("middle_chain should be a list")

    def invoke(self) -> Any:
        """
        Compose prompt, model and string output parser into one runnable.

        NOTE(review): despite its name, this method does not execute the
        pipeline. It returns ``self.prompt | self.model | self.output_parser``
        unexecuted; call ``.invoke(...)`` on the result to run it. The behavior
        is kept for backward compatibility with existing callers.

        Returns:
            Any: The composed (unexecuted) runnable.
        Raises:
            ValueError: If no prompt was configured.
        """
        self._validate_chain_components(self.prompt)
        return self.prompt | self.model | self.output_parser

    def chain_sequence_invoke(self,
                              middle_chain: Optional[List] = None,
                              final_chain: Optional[RunnableLambda] = None,
                              input_data: Any = None) -> Any:
        """
        Run ``prompt -> *middle_chain -> final`` as a RunnableSequence.

        Args:
            middle_chain: List of functions/Prompts/RunnableLambda to chain;
                each element becomes its own step
            final_chain: Final step; defaults to the string output parser
            input_data: Input passed to the sequence's ``invoke``; defaults to
                ``self.input_dict``, or ``{}`` if that is None
        Returns:
            Any: Output from the chain, or None when ``middle_chain`` is empty
        Raises:
            ValueError: If no prompt is configured or middle_chain is not a list
        """
        self._validate_chain_components(self.prompt, middle_chain)

        if not middle_chain:
            # Kept from the original: nothing is run without middle steps.
            return None

        final = final_chain if final_chain is not None else self.output_parser
        # RunnableSequence takes each step as a separate positional argument;
        # passing the list itself as one step is not coercible to a runnable.
        func_chain = RunnableSequence(self.prompt, *middle_chain, final)
        # Runnable.invoke requires an input argument.
        return func_chain.invoke(self._resolve_input(input_data))

    def chain_parallel_invoke(self, parallel_chain: Any, input_data: Any = None) -> Any:
        """
        Chain invoke in parallel
        Args:
            parallel_chain: A runnable (e.g. one built by
                ``create_parallel_chain``) exposing ``invoke``
            input_data: Input passed to ``invoke``; defaults to
                ``self.input_dict``, or ``{}`` if that is None
        Returns:
            Any: Output from the parallel chains
        Raises:
            ImportError: If LangChain is not installed
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return parallel_chain.invoke(self._resolve_input(input_data))

    def chain_branch_invoke(self, branch_chain: Any, input_data: Any = None) -> Any:
        """
        Chain invoke with branching
        Args:
            branch_chain: A runnable (e.g. one built by
                ``create_branch_chain``) exposing ``invoke``
            input_data: Input passed to ``invoke`` (what the branch conditions
                see); defaults to ``self.input_dict``, or ``{}`` if that is None
        Returns:
            Any: Output from the branch chain
        Raises:
            ImportError: If LangChain is not installed
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return branch_chain.invoke(self._resolve_input(input_data))

    @staticmethod
    def create_parallel_chain(prompt_template: Any, model: Any, branches: Dict[str, Any]) -> Any:
        """
        Create a parallel chain
        Args:
            prompt_template: Prompt runnable placed first in the pipe. It must
                support ``|`` with the model; a plain ``str`` does not.
            model: The language model
            branches: Mapping of branch name -> runnable, wrapped under the
                ``"branches"`` key of a RunnableParallel
        Returns:
            Any: Configured parallel chain
        """
        from langchain.schema.runnable import RunnableParallel
        return (
            prompt_template
            | model
            | StrOutputParser()
            | RunnableParallel(branches=branches)
        )

    @staticmethod
    def create_branch_chain(conditions: List[tuple], default_chain: Any) -> Any:
        """
        Create a branch chain
        Args:
            conditions: List of (condition, runnable) tuples, tried in order
            default_chain: Runnable used when no condition matches
        Returns:
            Any: Configured RunnableBranch
        """
        from langchain.schema.runnable import RunnableBranch
        return RunnableBranch(*conditions, default_chain)
1
+ ## file for chaining functions in chatbot
2
+
3
+ from typing import Optional, List, Dict, Any, Union
4
+ from dataclasses import dataclass
5
+ from langchain.schema.output_parser import StrOutputParser
6
+ from mb_rag.chatbot.prompts import invoke_prompt
7
+ from langchain.schema.runnable import RunnableLambda, RunnableSequence
8
+ from mb_rag.utils.extra import check_package
9
+
10
# Names exported by `from mb_rag.chatbot.chains import *`.
__all__ = [
    "Chain",
    "ChainConfig",
]
11
+
12
def check_langchain_dependencies() -> None:
    """
    Verify that the LangChain distributions this module relies on are present.

    Packages are probed in order with ``check_package``; the first one that is
    missing aborts with an ``ImportError`` whose message carries an install hint.

    Raises:
        ImportError: If ``langchain`` or ``langchain_core`` cannot be found.
    """
    required_packages = (
        ("langchain",
         "LangChain package not found. Please install it using: pip install langchain"),
        ("langchain_core",
         "LangChain Core package not found. Please install it using: pip install langchain-core"),
    )
    for package_name, install_hint in required_packages:
        if not check_package(package_name):
            raise ImportError(install_hint)
22
+
23
# Sanity-check the LangChain installation when this module is imported.
# NOTE(review): this runs AFTER the `langchain.schema` imports at the top of
# the file, so a missing `langchain` package would already have raised at
# those import statements; consider moving this check above them if the
# friendlier error message is wanted.
check_langchain_dependencies()
25
+
26
@dataclass
class ChainConfig:
    """
    Settings bundle consumed by ``Chain``.

    Every field is optional and defaults to ``None``. When ``prompt_template``
    is set (truthy), ``Chain`` passes it together with ``input_dict`` to
    ``invoke_prompt``; otherwise ``prompt`` is used as-is.
    """
    # Direct prompt, used verbatim when no template is supplied.
    prompt: Optional[str] = None
    # Template text; takes precedence over ``prompt`` when truthy.
    prompt_template: Optional[str] = None
    # Template variables handed to ``invoke_prompt`` alongside the template.
    input_dict: Optional[Dict[str, Any]] = None
32
+
33
class Chain:
    """
    Compose and run LangChain pipelines around a configured prompt and model.

    The prompt is resolved once, at construction time, and stored on
    ``self.prompt``; ``self.input_dict`` keeps the template variables. The
    instance methods pipe that prompt through ``self.model`` and a
    ``StrOutputParser`` (or user-supplied steps).
    """

    def __init__(self, model: Any, config: Optional[ChainConfig] = None, **kwargs):
        """
        Initialize chain
        Args:
            model: The language model to use
            config: Chain configuration. When None, the ``prompt``,
                ``prompt_template`` and ``input_dict`` keys of ``kwargs``
                are used instead.
            **kwargs: Fallback configuration (ignored when ``config`` is given)
        """
        self.model = model
        self._output_parser = StrOutputParser()
        self._initialize_config(config, **kwargs)

    @classmethod
    def from_template(cls, model: Any, template: str, input_dict: Dict[str, Any], **kwargs) -> 'Chain':
        """
        Create chain from template
        Args:
            model: The language model
            template: Prompt template text, passed to ``invoke_prompt``
            input_dict: Template variables, passed to ``invoke_prompt``
            **kwargs: Forwarded to the constructor; ``_initialize_config``
                ignores them because a config object is always supplied here
        Returns:
            Chain: New chain instance
        """
        config = ChainConfig(prompt_template=template, input_dict=input_dict)
        return cls(model, config, **kwargs)

    @classmethod
    def from_prompt(cls, model: Any, prompt: str, **kwargs) -> 'Chain':
        """
        Create chain from direct prompt
        Args:
            model: The language model
            prompt: Direct prompt, stored as-is on ``self.prompt``
            **kwargs: Forwarded to the constructor; ``_initialize_config``
                ignores them because a config object is always supplied here
        Returns:
            Chain: New chain instance
        """
        config = ChainConfig(prompt=prompt)
        return cls(model, config, **kwargs)

    def _initialize_config(self, config: Optional[ChainConfig], **kwargs) -> None:
        """
        Resolve ``self.prompt`` and ``self.input_dict``.

        Values come from ``config`` when it is given, otherwise from the
        ``prompt_template``, ``prompt`` and ``input_dict`` keys of ``kwargs``.
        A truthy template is passed to ``invoke_prompt(template, input_dict)``
        and takes precedence over a direct ``prompt``.
        """
        if config is not None:
            self.input_dict = config.input_dict
            prompt_template = config.prompt_template
            prompt = config.prompt
        else:
            self.input_dict = kwargs.get('input_dict')
            prompt_template = kwargs.get('prompt_template')
            prompt = kwargs.get('prompt')

        if prompt_template:
            self.prompt = invoke_prompt(prompt_template, self.input_dict)
        else:
            self.prompt = prompt

    @property
    def output_parser(self) -> StrOutputParser:
        """Get the output parser"""
        return self._output_parser

    def _resolve_input(self, input_data: Any) -> Any:
        """Return ``input_data`` if given, else ``self.input_dict``, else ``{}``."""
        if input_data is not None:
            return input_data
        return self.input_dict if self.input_dict is not None else {}

    @staticmethod
    def _validate_chain_components(prompt: Any, middle_chain: Optional[List] = None) -> None:
        """
        Validate chain components
        Args:
            prompt: The prompt to validate
            middle_chain: Optional middle chain to validate
        Raises:
            ValueError: If the prompt is None or middle_chain is not a list
        """
        if prompt is None:
            raise ValueError("Prompt is not provided")
        if middle_chain is not None and not isinstance(middle_chain, list):
            raise ValueError("middle_chain should be a list")

    def invoke(self) -> Any:
        """
        Compose prompt, model and string output parser into one runnable.

        NOTE(review): despite its name, this method does not execute the
        pipeline. It returns ``self.prompt | self.model | self.output_parser``
        unexecuted; call ``.invoke(...)`` on the result to run it. The behavior
        is kept for backward compatibility with existing callers.

        Returns:
            Any: The composed (unexecuted) runnable.
        Raises:
            ValueError: If no prompt was configured.
        """
        self._validate_chain_components(self.prompt)
        return self.prompt | self.model | self.output_parser

    def chain_sequence_invoke(self,
                              middle_chain: Optional[List] = None,
                              final_chain: Optional[RunnableLambda] = None,
                              input_data: Any = None) -> Any:
        """
        Run ``prompt -> *middle_chain -> final`` as a RunnableSequence.

        Args:
            middle_chain: List of functions/Prompts/RunnableLambda to chain;
                each element becomes its own step
            final_chain: Final step; defaults to the string output parser
            input_data: Input passed to the sequence's ``invoke``; defaults to
                ``self.input_dict``, or ``{}`` if that is None
        Returns:
            Any: Output from the chain, or None when ``middle_chain`` is empty
        Raises:
            ValueError: If no prompt is configured or middle_chain is not a list
        """
        self._validate_chain_components(self.prompt, middle_chain)

        if not middle_chain:
            # Kept from the original: nothing is run without middle steps.
            return None

        final = final_chain if final_chain is not None else self.output_parser
        # RunnableSequence takes each step as a separate positional argument;
        # passing the list itself as one step is not coercible to a runnable.
        func_chain = RunnableSequence(self.prompt, *middle_chain, final)
        # Runnable.invoke requires an input argument.
        return func_chain.invoke(self._resolve_input(input_data))

    def chain_parallel_invoke(self, parallel_chain: Any, input_data: Any = None) -> Any:
        """
        Chain invoke in parallel
        Args:
            parallel_chain: A runnable (e.g. one built by
                ``create_parallel_chain``) exposing ``invoke``
            input_data: Input passed to ``invoke``; defaults to
                ``self.input_dict``, or ``{}`` if that is None
        Returns:
            Any: Output from the parallel chains
        Raises:
            ImportError: If LangChain is not installed
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return parallel_chain.invoke(self._resolve_input(input_data))

    def chain_branch_invoke(self, branch_chain: Any, input_data: Any = None) -> Any:
        """
        Chain invoke with branching
        Args:
            branch_chain: A runnable (e.g. one built by
                ``create_branch_chain``) exposing ``invoke``
            input_data: Input passed to ``invoke`` (what the branch conditions
                see); defaults to ``self.input_dict``, or ``{}`` if that is None
        Returns:
            Any: Output from the branch chain
        Raises:
            ImportError: If LangChain is not installed
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return branch_chain.invoke(self._resolve_input(input_data))

    @staticmethod
    def create_parallel_chain(prompt_template: Any, model: Any, branches: Dict[str, Any]) -> Any:
        """
        Create a parallel chain
        Args:
            prompt_template: Prompt runnable placed first in the pipe. It must
                support ``|`` with the model; a plain ``str`` does not.
            model: The language model
            branches: Mapping of branch name -> runnable, wrapped under the
                ``"branches"`` key of a RunnableParallel
        Returns:
            Any: Configured parallel chain
        """
        from langchain.schema.runnable import RunnableParallel
        return (
            prompt_template
            | model
            | StrOutputParser()
            | RunnableParallel(branches=branches)
        )

    @staticmethod
    def create_branch_chain(conditions: List[tuple], default_chain: Any) -> Any:
        """
        Create a branch chain
        Args:
            conditions: List of (condition, runnable) tuples, tried in order
            default_chain: Runnable used when no condition matches
        Returns:
            Any: Configured RunnableBranch
        """
        from langchain.schema.runnable import RunnableBranch
        return RunnableBranch(*conditions, default_chain)