mb-rag 1.1.46__py3-none-any.whl → 1.1.56.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mb-rag might be problematic; see the registry's advisory page for this release for more details.

mb_rag/chatbot/chains.py CHANGED
@@ -1,206 +1,206 @@
1
- ## file for chaining functions in chatbot
2
-
3
- from typing import Optional, List, Dict, Any, Union
4
- from dataclasses import dataclass
5
- from langchain.schema.output_parser import StrOutputParser
6
- from mb_rag.chatbot.prompts import invoke_prompt
7
- from langchain.schema.runnable import RunnableLambda, RunnableSequence
8
- from mb_rag.utils.extra import check_package
9
-
10
__all__ = ["Chain", "ChainConfig"]


def check_langchain_dependencies() -> None:
    """
    Verify that the LangChain distributions this module relies on are installed.

    Each package is probed with ``check_package``; the first missing one
    aborts with an install hint.

    Raises:
        ImportError: If ``langchain`` or ``langchain_core`` is not available.
    """
    required_packages = (
        ("langchain",
         "LangChain package not found. Please install it using: pip install langchain"),
        ("langchain_core",
         "LangChain Core package not found. Please install it using: pip install langchain-core"),
    )
    for package_name, install_hint in required_packages:
        if not check_package(package_name):
            raise ImportError(install_hint)


# Fail fast at import time when the LangChain stack is missing.
check_langchain_dependencies()
25
-
26
@dataclass
class ChainConfig:
    """
    Settings consumed by ``Chain`` when it builds its prompt.

    Attributes:
        prompt: Prompt used as-is when no template is supplied.
        prompt_template: Template handed to ``invoke_prompt`` together with
            ``input_dict``; takes priority over ``prompt``.
        input_dict: Values for the template.
    """

    prompt: Optional[str] = None  # direct prompt
    prompt_template: Optional[str] = None  # template, wins over ``prompt``
    input_dict: Optional[Dict[str, Any]] = None  # template variables
32
-
33
class Chain:
    """
    Compose and run LangChain runnables around a single language model.

    The prompt is either supplied directly or produced by ``invoke_prompt``
    from a template and an input dict. A ``StrOutputParser`` is the default
    final step of the pipelines built here.
    """

    def __init__(self, model: Any, config: Optional[ChainConfig] = None, **kwargs):
        """
        Initialize chain.

        Args:
            model: The language model to use.
            config: Chain configuration. When provided, the ``prompt``,
                ``prompt_template`` and ``input_dict`` keys in ``kwargs`` are ignored.
            **kwargs: ``prompt``, ``prompt_template`` and ``input_dict``, read
                only when ``config`` is not provided.
        """
        self.model = model
        self._output_parser = StrOutputParser()
        self._initialize_config(config, **kwargs)

    @classmethod
    def from_template(cls, model: Any, template: str, input_dict: Dict[str, Any], **kwargs) -> 'Chain':
        """
        Create chain from template.

        Args:
            model: The language model.
            template: Prompt template.
            input_dict: Input dictionary for template.
            **kwargs: Additional arguments forwarded to ``__init__``.

        Returns:
            Chain: New chain instance.
        """
        config = ChainConfig(
            prompt_template=template,
            input_dict=input_dict
        )
        return cls(model, config, **kwargs)

    @classmethod
    def from_prompt(cls, model: Any, prompt: str, **kwargs) -> 'Chain':
        """
        Create chain from direct prompt.

        Args:
            model: The language model.
            prompt: Direct prompt.
            **kwargs: Additional arguments forwarded to ``__init__``.

        Returns:
            Chain: New chain instance.
        """
        config = ChainConfig(prompt=prompt)
        return cls(model, config, **kwargs)

    def _initialize_config(self, config: Optional[ChainConfig], **kwargs) -> None:
        """
        Set ``self.input_dict`` and ``self.prompt`` from ``config`` or ``kwargs``.

        A template, when present, takes priority over a direct prompt and is
        passed to ``invoke_prompt`` together with the input dict.
        """
        if config:
            self.input_dict = config.input_dict
            if config.prompt_template:
                self.prompt = invoke_prompt(config.prompt_template, self.input_dict)
            else:
                self.prompt = config.prompt
        else:
            self.input_dict = kwargs.get('input_dict')
            if prompt_template := kwargs.get('prompt_template'):
                self.prompt = invoke_prompt(prompt_template, self.input_dict)
            else:
                self.prompt = kwargs.get('prompt')

    @property
    def output_parser(self) -> StrOutputParser:
        """Get the output parser (a ``StrOutputParser`` created in ``__init__``)."""
        return self._output_parser

    @staticmethod
    def _validate_chain_components(prompt: Any, middle_chain: Optional[List] = None) -> None:
        """
        Validate chain components.

        Args:
            prompt: The prompt to validate.
            middle_chain: Optional middle chain to validate.

        Raises:
            ValueError: If the prompt is None or ``middle_chain`` is not a list.
        """
        if prompt is None:
            raise ValueError("Prompt is not provided")
        if middle_chain is not None and not isinstance(middle_chain, list):
            raise ValueError("middle_chain should be a list")

    def _resolve_input(self, input_data: Any) -> Any:
        """Return ``input_data``, or ``self.input_dict`` (falling back to ``{}``) when it is None."""
        if input_data is not None:
            return input_data
        return self.input_dict if self.input_dict is not None else {}

    def invoke(self) -> Any:
        """
        Compose ``prompt | model | output_parser`` and return the composition.

        NOTE(review): despite its name, this method returns the composed
        pipeline without executing it; callers presumably call ``.invoke(...)``
        on the result. Confirm before changing.

        Returns:
            Any: The composed pipeline.

        Raises:
            ValueError: If prompt is not provided.
        """
        self._validate_chain_components(self.prompt)
        chain_output = self.prompt | self.model | self.output_parser
        return chain_output

    def chain_sequence_invoke(self,
                              middle_chain: Optional[List] = None,
                              final_chain: Optional[RunnableLambda] = None,
                              input_data: Any = None) -> Any:
        """
        Run ``prompt -> *middle_chain -> final`` as one ``RunnableSequence``.

        Args:
            middle_chain: List of functions/Prompts/RunnableLambda steps placed
                between the prompt and the final step; each element becomes its
                own step of the sequence.
            final_chain: Final step; defaults to ``self.output_parser``.
            input_data: Input passed to the sequence's ``invoke``. Defaults to
                ``self.input_dict`` (or ``{}`` when that is None).

        Returns:
            Any: Output from the chain, or None when ``middle_chain`` is empty/None.

        Raises:
            ValueError: If the prompt is missing or ``middle_chain`` is not a list.
        """
        self._validate_chain_components(self.prompt, middle_chain)
        if not middle_chain:
            return None

        final = final_chain if final_chain is not None else self.output_parser
        # RunnableSequence takes each step as its own positional argument, so the
        # list must be unpacked (the list object itself is not a valid step).
        func_chain = RunnableSequence(self.prompt, *middle_chain, final)
        # Runnable.invoke requires an input; calling it with none raises TypeError.
        return func_chain.invoke(self._resolve_input(input_data))

    def chain_parallel_invoke(self, parallel_chain: Any, input_data: Any = None) -> Any:
        """
        Invoke a parallel chain (e.g. one built by ``create_parallel_chain``).

        Args:
            parallel_chain: Runnable to invoke.
            input_data: Input for ``invoke``; defaults to ``self.input_dict``
                (or ``{}`` when that is None).

        Returns:
            Any: Output from the parallel chains.

        Raises:
            ImportError: If LangChain is not installed.
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return parallel_chain.invoke(self._resolve_input(input_data))

    def chain_branch_invoke(self, branch_chain: Any, input_data: Any = None) -> Any:
        """
        Invoke a branch chain (e.g. one built by ``create_branch_chain``).

        Args:
            branch_chain: Runnable to invoke.
            input_data: Input for ``invoke``; defaults to ``self.input_dict``
                (or ``{}`` when that is None).

        Returns:
            Any: Output from the branch chain.

        Raises:
            ImportError: If LangChain is not installed.
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return branch_chain.invoke(self._resolve_input(input_data))

    @staticmethod
    def create_parallel_chain(prompt_template: str, model: Any, branches: Dict[str, Any]) -> Any:
        """
        Create ``prompt_template | model | StrOutputParser() | RunnableParallel(branches=...)``.

        NOTE(review): ``prompt_template`` presumably needs to be a prompt
        template runnable rather than a plain ``str`` for ``|`` to work; confirm.

        Args:
            prompt_template: Template for the prompt.
            model: The language model.
            branches: Dictionary of branch configurations.

        Returns:
            Any: Configured parallel chain.
        """
        from langchain.schema.runnable import RunnableParallel
        return (
            prompt_template
            | model
            | StrOutputParser()
            | RunnableParallel(branches=branches)
        )

    @staticmethod
    def create_branch_chain(conditions: List[tuple], default_chain: Any) -> Any:
        """
        Create a branch chain.

        Args:
            conditions: List of ``(condition, chain)`` tuples.
            default_chain: Chain used when no condition matches.

        Returns:
            Any: Configured ``RunnableBranch``.
        """
        from langchain.schema.runnable import RunnableBranch
        return RunnableBranch(*conditions, default_chain)
1
+ ## file for chaining functions in chatbot
2
+
3
+ from typing import Optional, List, Dict, Any, Union
4
+ from dataclasses import dataclass
5
+ from langchain.schema.output_parser import StrOutputParser
6
+ from mb_rag.chatbot.prompts import invoke_prompt
7
+ from langchain.schema.runnable import RunnableLambda, RunnableSequence
8
+ from mb_rag.utils.extra import check_package
9
+
10
__all__ = ["Chain", "ChainConfig"]


def check_langchain_dependencies() -> None:
    """
    Verify that the LangChain distributions this module relies on are installed.

    Each package is probed with ``check_package``; the first missing one
    aborts with an install hint.

    Raises:
        ImportError: If ``langchain`` or ``langchain_core`` is not available.
    """
    required_packages = (
        ("langchain",
         "LangChain package not found. Please install it using: pip install langchain"),
        ("langchain_core",
         "LangChain Core package not found. Please install it using: pip install langchain-core"),
    )
    for package_name, install_hint in required_packages:
        if not check_package(package_name):
            raise ImportError(install_hint)


# Fail fast at import time when the LangChain stack is missing.
check_langchain_dependencies()
25
+
26
@dataclass
class ChainConfig:
    """
    Settings consumed by ``Chain`` when it builds its prompt.

    Attributes:
        prompt: Prompt used as-is when no template is supplied.
        prompt_template: Template handed to ``invoke_prompt`` together with
            ``input_dict``; takes priority over ``prompt``.
        input_dict: Values for the template.
    """

    prompt: Optional[str] = None  # direct prompt
    prompt_template: Optional[str] = None  # template, wins over ``prompt``
    input_dict: Optional[Dict[str, Any]] = None  # template variables
32
+
33
class Chain:
    """
    Compose and run LangChain runnables around a single language model.

    The prompt is either supplied directly or produced by ``invoke_prompt``
    from a template and an input dict. A ``StrOutputParser`` is the default
    final step of the pipelines built here.
    """

    def __init__(self, model: Any, config: Optional[ChainConfig] = None, **kwargs):
        """
        Initialize chain.

        Args:
            model: The language model to use.
            config: Chain configuration. When provided, the ``prompt``,
                ``prompt_template`` and ``input_dict`` keys in ``kwargs`` are ignored.
            **kwargs: ``prompt``, ``prompt_template`` and ``input_dict``, read
                only when ``config`` is not provided.
        """
        self.model = model
        self._output_parser = StrOutputParser()
        self._initialize_config(config, **kwargs)

    @classmethod
    def from_template(cls, model: Any, template: str, input_dict: Dict[str, Any], **kwargs) -> 'Chain':
        """
        Create chain from template.

        Args:
            model: The language model.
            template: Prompt template.
            input_dict: Input dictionary for template.
            **kwargs: Additional arguments forwarded to ``__init__``.

        Returns:
            Chain: New chain instance.
        """
        config = ChainConfig(
            prompt_template=template,
            input_dict=input_dict
        )
        return cls(model, config, **kwargs)

    @classmethod
    def from_prompt(cls, model: Any, prompt: str, **kwargs) -> 'Chain':
        """
        Create chain from direct prompt.

        Args:
            model: The language model.
            prompt: Direct prompt.
            **kwargs: Additional arguments forwarded to ``__init__``.

        Returns:
            Chain: New chain instance.
        """
        config = ChainConfig(prompt=prompt)
        return cls(model, config, **kwargs)

    def _initialize_config(self, config: Optional[ChainConfig], **kwargs) -> None:
        """
        Set ``self.input_dict`` and ``self.prompt`` from ``config`` or ``kwargs``.

        A template, when present, takes priority over a direct prompt and is
        passed to ``invoke_prompt`` together with the input dict.
        """
        if config:
            self.input_dict = config.input_dict
            if config.prompt_template:
                self.prompt = invoke_prompt(config.prompt_template, self.input_dict)
            else:
                self.prompt = config.prompt
        else:
            self.input_dict = kwargs.get('input_dict')
            if prompt_template := kwargs.get('prompt_template'):
                self.prompt = invoke_prompt(prompt_template, self.input_dict)
            else:
                self.prompt = kwargs.get('prompt')

    @property
    def output_parser(self) -> StrOutputParser:
        """Get the output parser (a ``StrOutputParser`` created in ``__init__``)."""
        return self._output_parser

    @staticmethod
    def _validate_chain_components(prompt: Any, middle_chain: Optional[List] = None) -> None:
        """
        Validate chain components.

        Args:
            prompt: The prompt to validate.
            middle_chain: Optional middle chain to validate.

        Raises:
            ValueError: If the prompt is None or ``middle_chain`` is not a list.
        """
        if prompt is None:
            raise ValueError("Prompt is not provided")
        if middle_chain is not None and not isinstance(middle_chain, list):
            raise ValueError("middle_chain should be a list")

    def _resolve_input(self, input_data: Any) -> Any:
        """Return ``input_data``, or ``self.input_dict`` (falling back to ``{}``) when it is None."""
        if input_data is not None:
            return input_data
        return self.input_dict if self.input_dict is not None else {}

    def invoke(self) -> Any:
        """
        Compose ``prompt | model | output_parser`` and return the composition.

        NOTE(review): despite its name, this method returns the composed
        pipeline without executing it; callers presumably call ``.invoke(...)``
        on the result. Confirm before changing.

        Returns:
            Any: The composed pipeline.

        Raises:
            ValueError: If prompt is not provided.
        """
        self._validate_chain_components(self.prompt)
        chain_output = self.prompt | self.model | self.output_parser
        return chain_output

    def chain_sequence_invoke(self,
                              middle_chain: Optional[List] = None,
                              final_chain: Optional[RunnableLambda] = None,
                              input_data: Any = None) -> Any:
        """
        Run ``prompt -> *middle_chain -> final`` as one ``RunnableSequence``.

        Args:
            middle_chain: List of functions/Prompts/RunnableLambda steps placed
                between the prompt and the final step; each element becomes its
                own step of the sequence.
            final_chain: Final step; defaults to ``self.output_parser``.
            input_data: Input passed to the sequence's ``invoke``. Defaults to
                ``self.input_dict`` (or ``{}`` when that is None).

        Returns:
            Any: Output from the chain, or None when ``middle_chain`` is empty/None.

        Raises:
            ValueError: If the prompt is missing or ``middle_chain`` is not a list.
        """
        self._validate_chain_components(self.prompt, middle_chain)
        if not middle_chain:
            return None

        final = final_chain if final_chain is not None else self.output_parser
        # RunnableSequence takes each step as its own positional argument, so the
        # list must be unpacked (the list object itself is not a valid step).
        func_chain = RunnableSequence(self.prompt, *middle_chain, final)
        # Runnable.invoke requires an input; calling it with none raises TypeError.
        return func_chain.invoke(self._resolve_input(input_data))

    def chain_parallel_invoke(self, parallel_chain: Any, input_data: Any = None) -> Any:
        """
        Invoke a parallel chain (e.g. one built by ``create_parallel_chain``).

        Args:
            parallel_chain: Runnable to invoke.
            input_data: Input for ``invoke``; defaults to ``self.input_dict``
                (or ``{}`` when that is None).

        Returns:
            Any: Output from the parallel chains.

        Raises:
            ImportError: If LangChain is not installed.
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return parallel_chain.invoke(self._resolve_input(input_data))

    def chain_branch_invoke(self, branch_chain: Any, input_data: Any = None) -> Any:
        """
        Invoke a branch chain (e.g. one built by ``create_branch_chain``).

        Args:
            branch_chain: Runnable to invoke.
            input_data: Input for ``invoke``; defaults to ``self.input_dict``
                (or ``{}`` when that is None).

        Returns:
            Any: Output from the branch chain.

        Raises:
            ImportError: If LangChain is not installed.
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return branch_chain.invoke(self._resolve_input(input_data))

    @staticmethod
    def create_parallel_chain(prompt_template: str, model: Any, branches: Dict[str, Any]) -> Any:
        """
        Create ``prompt_template | model | StrOutputParser() | RunnableParallel(branches=...)``.

        NOTE(review): ``prompt_template`` presumably needs to be a prompt
        template runnable rather than a plain ``str`` for ``|`` to work; confirm.

        Args:
            prompt_template: Template for the prompt.
            model: The language model.
            branches: Dictionary of branch configurations.

        Returns:
            Any: Configured parallel chain.
        """
        from langchain.schema.runnable import RunnableParallel
        return (
            prompt_template
            | model
            | StrOutputParser()
            | RunnableParallel(branches=branches)
        )

    @staticmethod
    def create_branch_chain(conditions: List[tuple], default_chain: Any) -> Any:
        """
        Create a branch chain.

        Args:
            conditions: List of ``(condition, chain)`` tuples.
            default_chain: Chain used when no condition matches.

        Returns:
            Any: Configured ``RunnableBranch``.
        """
        from langchain.schema.runnable import RunnableBranch
        return RunnableBranch(*conditions, default_chain)
@@ -0,0 +1,185 @@
1
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
2
+ from typing import Optional, List, Any, Union
3
+
4
__all__ = [
    'ConversationModel'
]


class ConversationModel:
    """
    A class to handle conversation with AI models.

    Messages are stored as LangChain ``SystemMessage``/``HumanMessage``/``AIMessage``
    objects. They can be persisted as plain text, one message content per line,
    either to a local file or to S3.

    Attributes:
        chatbot: The AI model for conversation; must provide
            ``invoke_query(query, images=..., get_content_only=...)``.
        message_list (List): List of conversation messages.
        file_path (str): Default path to save/load conversations. Can be local
            or S3 (``s3://bucket/key``).
        client: S3 client for ``s3://`` paths (``client`` keyword of ``__init__``);
            presumably a boto3 S3 client. Verify against callers.
        bucket: Bucket overriding the one parsed from the ``s3://`` path
            (``bucket`` keyword of ``__init__``).
    """

    def __init__(self,
                 llm: Any,
                 message_list: Optional[List[Any]] = None,
                 file_path: Optional[str] = None,
                 **kwargs) -> None:
        """
        Initialize conversation model.

        Args:
            llm: Chat model wrapper providing ``invoke_query``.
            message_list: Messages to continue from. A non-empty list is used
                as-is (not copied); otherwise a new empty list is created.
            file_path: Default local or ``s3://`` path for save/load.
            **kwargs: ``client`` and ``bucket`` are stored for S3 access;
                other keys are ignored.
        """
        self.chatbot = llm
        self.message_list = message_list if message_list else []
        self.file_path = file_path if file_path else None
        # The S3 helpers read these; previously they were never set, so any
        # S3 call without explicit kwargs failed with AttributeError.
        self.client = kwargs.get('client')
        self.bucket = kwargs.get('bucket')

    def initialize_conversation(self, context_message: str = "") -> str:
        """
        Initialize conversation state and ask the model about it.

        When ``file_path`` is set, the saved conversation is loaded first and
        replaces ``message_list``. A system message (``context_message``, or a
        default instruction when empty) is then appended. The concatenated
        content of all messages is sent to the model via ``add_message``.

        Args:
            context_message: Custom system message.

        Returns:
            str: The model's response.
        """
        if self.file_path:
            self.load_conversation()

        if context_message:
            self.message_list.append(SystemMessage(content=context_message))
        else:
            self.message_list.append(SystemMessage(content=(
                "This is conversation model.\n"
                "Look into the conversation history and answer the question if provided.\n"
                "Give a brief introduction of the conversation history.")))
        # Loaded lines keep their trailing newline, so they are joined with "".
        message_list_content = "".join(self.all_messages_content)
        return self.add_message(message_list_content, get_content_only=True)

    def _ask_question(self, query: str, images: Optional[list] = None,
                      get_content_only: bool = True) -> Any:
        """
        Ask a question and get response.

        Args:
            query: Question to ask.
            images: Optional images forwarded to ``invoke_query``.
            get_content_only: Whether to return only content.

        Returns:
            Response from ``self.chatbot.invoke_query``.
        """
        if images:
            return self.chatbot.invoke_query(query, images=images, get_content_only=get_content_only)
        return self.chatbot.invoke_query(query, get_content_only=get_content_only)

    def add_message(self, query: str, images: Optional[list] = None, get_content_only: bool = True) -> Any:
        """
        Add a message to the conversation.

        Appends the query as a ``HumanMessage``, asks the model, and appends the
        reply as an ``AIMessage``.

        Args:
            query (str): Question to ask.
            images (list): List of images to send to the model.
            get_content_only (bool): Whether to return only content.

        Returns:
            Response from the chatbot, unchanged.
        """
        self.message_list.append(HumanMessage(content=query))
        res = self._ask_question(query, images=images, get_content_only=get_content_only)
        # With get_content_only=False the model presumably returns a message
        # object; store only its text in the history rather than the object.
        self.message_list.append(AIMessage(content=getattr(res, 'content', res)))
        return res

    @property
    def all_messages(self) -> List[Union[SystemMessage, HumanMessage, AIMessage]]:
        """Get all messages."""
        return self.message_list

    @property
    def last_message(self) -> str:
        """Get the content of the last message (IndexError if there are none)."""
        return self.message_list[-1].content

    @property
    def all_messages_content(self) -> List[str]:
        """Get content of all messages."""
        return [message.content for message in self.message_list]

    def _is_s3_path(self, path: str) -> bool:
        """
        Check if path is an S3 path.

        Args:
            path (str): Path to check.

        Returns:
            bool: True if ``path`` is a string starting with ``s3://``.
        """
        return isinstance(path, str) and path.startswith("s3://")

    @staticmethod
    def _split_s3_path(path: str) -> tuple:
        """Split ``s3://bucket/some/key`` into ``("bucket", "some/key")``."""
        bucket, _, key = path[len("s3://"):].partition("/")
        return bucket, key

    def _resolve_path(self, file_path: Optional[str]) -> str:
        """Return ``file_path`` or the default ``self.file_path``; raise if neither is set."""
        path = file_path or self.file_path
        if not path:
            raise ValueError("No file path provided and no default file_path is set")
        return path

    def _conversation_text(self) -> str:
        """Serialize the conversation: each message's content followed by a newline."""
        return "".join(f"{message.content}\n" for message in self.message_list)

    def save_conversation(self, file_path: Optional[str] = None, **kwargs) -> bool:
        """
        Save the conversation.

        Args:
            file_path: Path to save the conversation; defaults to ``self.file_path``.
            **kwargs: ``client`` / ``bucket`` overrides for S3.

        Returns:
            bool: Success status.

        Raises:
            ValueError: If no path is available or saving fails.
        """
        path = self._resolve_path(file_path)
        if self._is_s3_path(path):
            print("Saving conversation to S3.")
            self.save_file_path = path
            return self._save_to_s3(**kwargs)
        return self._save_to_file(path)

    def _save_to_s3(self, **kwargs) -> bool:
        """
        Upload the conversation to ``self.save_file_path`` on S3.

        The body uses the same text format as ``_save_to_file``. The bucket is
        ``kwargs['bucket']`` / ``self.bucket`` when set, otherwise the one in
        the ``s3://`` path. The object key is the part of the path after the bucket.
        """
        try:
            client = kwargs.get('client', self.client)
            if client is None:
                raise ValueError("An S3 client is required (pass client=...)")
            path_bucket, key = self._split_s3_path(self.save_file_path)
            bucket = kwargs.get('bucket', self.bucket) or path_bucket
            client.put_object(
                Body=self._conversation_text(),
                Bucket=bucket,
                Key=key
            )
            print(f"Conversation saved to s3_path: {self.save_file_path}")
            return True
        except Exception as e:
            raise ValueError(f"Error saving conversation to s3: {e}") from e

    def _save_to_file(self, file_path: str) -> bool:
        """
        Save conversation to a local file, one message content per line.

        NOTE: contents containing newlines are split into several messages
        when the file is loaded back.
        """
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(self._conversation_text())
            print(f"Conversation saved to file: {file_path}")
            return True
        except Exception as e:
            raise ValueError(f"Error saving conversation to file: {e}") from e

    def load_conversation(self, file_path: Optional[str] = None, **kwargs) -> List[Any]:
        """
        Load a conversation, replacing ``message_list``.

        Args:
            file_path: Path to load from; defaults to ``self.file_path``.
            **kwargs: ``client`` / ``bucket`` overrides for S3.

        Returns:
            List: Loaded messages (each line becomes a ``SystemMessage``).

        Raises:
            ValueError: If no path is available or loading fails.
        """
        path = self._resolve_path(file_path)
        self.message_list = []
        if self._is_s3_path(path):
            print("Loading conversation from S3.")
            # Previously this stored file_path even when it was None.
            self.file_path = path
            return self._load_from_s3(**kwargs)
        return self._load_from_file(path)

    def _load_from_s3(self, **kwargs) -> List[Any]:
        """
        Load conversation text from ``self.file_path`` on S3.

        Parses the same line-per-message text format as ``_load_from_file``.
        The previous implementation ``eval()``'d the downloaded body, which
        executes arbitrary code from the bucket; the body is now treated as
        plain text.
        """
        import io
        try:
            client = kwargs.get('client', self.client)
            if client is None:
                raise ValueError("An S3 client is required (pass client=...)")
            path_bucket, key = self._split_s3_path(self.file_path)
            bucket = kwargs.get('bucket', self.bucket) or path_bucket
            res = client.get_object(Bucket=bucket, Key=key)
            text = res['Body'].read().decode('utf-8')
            self.message_list = [SystemMessage(content=line)
                                 for line in io.StringIO(text).readlines()]
            print(f"Conversation loaded from s3_path: {self.file_path}")
            return self.message_list
        except Exception as e:
            raise ValueError(f"Error loading conversation from s3: {e}") from e

    def _load_from_file(self, file_path: str) -> List[Any]:
        """Append each line of a local file (newline kept) as a ``SystemMessage``."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            self.message_list.extend(SystemMessage(content=line) for line in lines)
            print(f"Conversation loaded from file: {file_path}")
            return self.message_list
        except Exception as e:
            raise ValueError(f"Error loading conversation from file: {e}") from e