mb-rag 1.0.117__tar.gz → 1.0.124__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mb-rag might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: mb_rag
3
- Version: 1.0.117
3
+ Version: 1.0.124
4
4
  Summary: RAG function file
5
5
  Author: ['Malav Bateriwala']
6
6
  Requires-Python: >=3.8
@@ -0,0 +1,395 @@
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
4
+ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
5
+ from IPython.display import display, HTML
6
+ from typing import Optional, List, Dict, Any, Union
7
+ from mb_rag.utils.extra import check_package
8
+ import base64
9
+
10
# Public API of this module.
__all__ = ["ChatbotBase", "ModelFactory", "ConversationModel", "IPythonStreamHandler"]
16
+
17
class ChatbotBase:
    """Shared chatbot helpers: environment loading, API keys and an S3 client."""

    @staticmethod
    def load_env(file_path: str) -> None:
        """
        Populate the process environment from a dotenv file
        Args:
            file_path (str): Location of the .env file to read
        """
        load_dotenv(file_path)

    @staticmethod
    def add_os_key(name: str, key: str) -> None:
        """
        Store an API key as an environment variable
        Args:
            name (str): Environment variable name
            key (str): Value to assign to it
        """
        os.environ.update({name: key})

    @staticmethod
    def get_client():
        """
        Build a boto3 client for S3
        Returns:
            A boto3 S3 client
        Raises:
            ImportError: If boto3 is not installed
        """
        boto3_available = check_package("boto3")
        if boto3_available:
            import boto3
            return boto3.client('s3')
        raise ImportError("Boto3 package not found. Please install it using: pip install boto3")
49
+
50
class ModelFactory:
    """Factory class for creating different types of chatbot models.

    The created LangChain model is stored on ``self.model``; ``invoke_query``
    sends text (optionally with images) to it.
    """

    def __init__(self, model_type: str = 'openai', model_name: str = "gpt-4o", **kwargs) -> None:
        """
        Factory method to create any type of model
        Args:
            model_type (str): Type of model to create ('openai', 'anthropic', 'google' or 'ollama')
            model_name (str): Name of the model
            **kwargs: Additional arguments forwarded to the provider's model class
        Raises:
            ValueError: If the model type is unsupported or the model could not be created
        """
        creators = {
            'openai': self.create_openai,
            'anthropic': self.create_anthropic,
            'google': self.create_google,
            'ollama': self.create_ollama,
        }

        creator = creators.get(model_type)
        if creator is None:
            raise ValueError(f"Unsupported model type: {model_type}")

        try:
            self.model = creator(model_name, **kwargs)
        except Exception as e:
            # Chain the original error so the real cause stays in the traceback.
            raise ValueError(f"Error creating {model_type} model: {str(e)}") from e

    @classmethod
    def create_openai(cls, model_name: str = "gpt-4o", **kwargs) -> Any:
        """
        Create OpenAI chatbot model
        Args:
            model_name (str): Name of the model
            **kwargs: Additional arguments
        Returns:
            ChatOpenAI: Chatbot model
        """
        if not check_package("openai"):
            raise ImportError("OpenAI package not found. Please install it using: pip install openai langchain-openai")

        from langchain_openai import ChatOpenAI
        kwargs["model_name"] = model_name
        return ChatOpenAI(**kwargs)

    @classmethod
    def create_anthropic(cls, model_name: str = "claude-3-opus-20240229", **kwargs) -> Any:
        """
        Create Anthropic chatbot model
        Args:
            model_name (str): Name of the model
            **kwargs: Additional arguments
        Returns:
            ChatAnthropic: Chatbot model
        """
        if not check_package("anthropic"):
            raise ImportError("Anthropic package not found. Please install it using: pip install anthropic langchain-anthropic")

        from langchain_anthropic import ChatAnthropic
        kwargs["model_name"] = model_name
        return ChatAnthropic(**kwargs)

    @classmethod
    def create_google(cls, model_name: str = "gemini-1.5-flash", **kwargs) -> Any:
        """
        Create Google chatbot model
        Args:
            model_name (str): Name of the model
            **kwargs: Additional arguments
        Returns:
            ChatGoogleGenerativeAI: Chatbot model
        """
        if not check_package("google.generativeai"):
            raise ImportError("Google Generative AI package not found. Please install it using: pip install google-generativeai langchain-google-genai")

        from langchain_google_genai import ChatGoogleGenerativeAI
        kwargs["model"] = model_name
        return ChatGoogleGenerativeAI(**kwargs)

    @classmethod
    def create_ollama(cls, model_name: str = "llama3", **kwargs) -> Any:
        """
        Create Ollama chatbot model
        Args:
            model_name (str): Name of the model
            **kwargs: Additional arguments
        Returns:
            Ollama: Chatbot model
        """
        if not check_package("langchain_community"):
            raise ImportError("Langchain Community package not found. Please install it using: pip install langchain-community")

        from langchain_community.llms import Ollama
        kwargs["model"] = model_name
        return Ollama(**kwargs)

    def invoke_query(self, query: str, get_content_only: bool = True, images: list = None, pydantic_model=None) -> str:
        """
        Invoke the model
        Args:
            query (str): Query (or list of messages) to send to the model
            get_content_only (bool): If True, return ``response.content`` when the response has it
            images (list): Optional list of image file paths sent along with the query
            pydantic_model: Optional Pydantic model for structured output. It applies to this
                call only; ``self.model`` is left unchanged.
        Returns:
            The response content, the raw response, or the structured object when
            ``pydantic_model`` is given. Image queries always return the content.
        Raises:
            ValueError: If the model cannot be configured for structured output
        """
        # Use a per-call model so structured output does not leak into later calls
        # (and is not wrapped a second time on the next call).
        model = self.model
        if pydantic_model is not None:
            try:
                model = model.with_structured_output(pydantic_model)
            except Exception as e:
                raise ValueError(f"Error with pydantic_model: {e}") from e

        if images:
            response = model.invoke([self._build_image_message(images, query)])
            # Image responses have always been reduced to their content.
            res = getattr(response, "content", response)
        else:
            res = model.invoke(query)

        if get_content_only:
            # Structured outputs / plain strings have no .content; return them as-is.
            return getattr(res, "content", res)
        return res

    def _image_to_base64(self, image):
        """Read an image file and return its bytes base64-encoded as a str."""
        with open(image, "rb") as f:
            return base64.b64encode(f.read()).decode('utf-8')

    @staticmethod
    def _image_mime_type(image) -> str:
        """Guess an image's MIME type from its file name, defaulting to image/jpeg."""
        import mimetypes
        mime, _ = mimetypes.guess_type(str(image))
        return mime if mime and mime.startswith("image/") else "image/jpeg"

    def _build_image_message(self, images: list, prompt):
        """Build a HumanMessage with the text prompt followed by each image as a data URL."""
        content = [{"type": "text", "text": prompt}]
        for image in images:
            encoded = self._image_to_base64(image)
            # Use the real media type; providers may reject a mismatched one (e.g. PNG sent as JPEG).
            url = f"data:{self._image_mime_type(image)};base64,{encoded}"
            content.append({"type": "image_url", "image_url": {"url": url}})
        return HumanMessage(content=content)

    def _model_invoke_images(self, images: list, prompt: str, pydantic_model=None):
        """
        Function to invoke the model with images
        Args:
            images (list): List of image file paths
            prompt (str): Prompt
            pydantic_model (PydanticModel): Optional Pydantic model for structured output.
                On failure, a message is printed and the call continues without it.
        Returns:
            The response content, or the structured object when structured output is used
        """
        model = self.model
        if pydantic_model is not None:
            try:
                model = model.with_structured_output(pydantic_model)
            except Exception as e:
                print(f"Error with pydantic_model: {e}")
                print("Continuing without structured output")
        response = model.invoke([self._build_image_message(images, prompt)])
        return getattr(response, "content", response)
203
+
204
+
205
class ConversationModel:
    """
    A class to handle conversation with AI models

    Attributes:
        chatbot (ModelFactory): Wrapper around the AI model used for the conversation
        message_list (List): List of conversation messages
        file_path (str): Path to save/load conversations. Either a local path or an
            S3 URI of the form "s3://bucket/key"
        context (str): System instruction used when the conversation was started
    """

    def __init__(self,
                 model_name: str = "gpt-4o",
                 model_type: str = 'openai',
                 **kwargs) -> None:
        """
        Initialize conversation model
        Args:
            model_name (str): Name of the model
            model_type (str): Model type understood by ModelFactory
            **kwargs: Extra arguments forwarded to ModelFactory
        """
        self.chatbot = ModelFactory(model_type, model_name, **kwargs)
        # Give the attributes read by the other methods a defined starting value,
        # so they do not hit AttributeError before initialize_conversation runs.
        self.message_list = []
        self.file_path = None
        self.context = None

    def initialize_conversation(self,
                                question: Optional[str],
                                context: Optional[str] = None,
                                file_path: Optional[str] = None) -> None:
        """
        Initialize conversation state
        Args:
            question: First user question (required unless file_path is given)
            context: System instruction; a default instruction is used when empty
            file_path: If given, load an existing conversation from this path instead
        Raises:
            ValueError: If neither file_path nor question is provided
        """
        if file_path:
            self.file_path = file_path
            self.load_conversation(file_path)
            return

        if not question:
            raise ValueError("Question is required.")

        self.context = context if context else "Answer question to the point and don't hallucinate."
        self.message_list = [
            # Use self.context so the default instruction applies when context is None.
            SystemMessage(content=self.context),
            HumanMessage(content=question),
        ]

        res = self._ask_question(self.message_list)
        print(res)
        self.message_list.append(AIMessage(content=res))

    def _ask_question(self, messages: List[Union[SystemMessage, HumanMessage, AIMessage]],
                      get_content_only: bool = True) -> str:
        """
        Ask a question and get response
        Args:
            messages: List of messages
            get_content_only: Whether to return only content
        Returns:
            str: Response from the model
        """
        res = self.chatbot.invoke_query(messages)
        if get_content_only:
            return getattr(res, "content", res)
        return res

    def add_message(self, message: str) -> str:
        """
        Add a message to the conversation
        Args:
            message (str): Message to add
        Returns:
            str: Response from the chatbot
        """
        self.message_list.append(HumanMessage(content=message))
        res = self._ask_question(self.message_list)
        self.message_list.append(AIMessage(content=res))
        return res

    @property
    def all_messages(self) -> List[Union[SystemMessage, HumanMessage, AIMessage]]:
        """Get all messages"""
        return self.message_list

    @property
    def last_message(self) -> str:
        """Get the last message"""
        return self.message_list[-1].content

    @property
    def all_messages_content(self) -> List[str]:
        """Get content of all messages"""
        return [message.content for message in self.message_list]

    def _is_s3_path(self, path: str) -> bool:
        """
        Check if path is an S3 path
        Args:
            path (str): Path to check
        Returns:
            bool: True if S3 path
        """
        return path.startswith("s3://")

    def _resolve_path(self, file_path: Optional[str]) -> str:
        """Return file_path, or the stored self.file_path; raise ValueError if neither is set."""
        path = file_path or getattr(self, "file_path", None)
        if not path:
            raise ValueError("No file_path provided and no default file_path set.")
        return path

    @staticmethod
    def _split_s3_path(path: str):
        """Split 's3://bucket/key' into (bucket, key); raise ValueError if either part is missing."""
        bucket, _, key = path[len("s3://"):].partition("/")
        if not bucket or not key:
            raise ValueError(f"Invalid S3 path: {path}")
        return bucket, key

    def _get_s3_client(self, **kwargs):
        """Return kwargs['client'], else self.client if set, else a new boto3 S3 client."""
        client = kwargs.get('client') or getattr(self, 'client', None)
        return client if client is not None else ChatbotBase.get_client()

    def _serialize_messages(self) -> str:
        """Render the conversation as one line of text per message content."""
        return "".join(f"{message.content}\n" for message in self.message_list)

    @staticmethod
    def _deserialize_messages(text: str) -> List[Any]:
        """Turn serialized text back into messages: one SystemMessage per line."""
        return [SystemMessage(content=line) for line in text.splitlines(keepends=True)]

    def save_conversation(self, file_path: Optional[str] = None, **kwargs) -> bool:
        """
        Save the conversation
        Args:
            file_path: Path to save the conversation (defaults to self.file_path)
            **kwargs: Additional arguments for S3 ('client', 'bucket')
        Returns:
            bool: Success status
        Raises:
            ValueError: If no path is available or saving fails
        """
        target = self._resolve_path(file_path)
        if self._is_s3_path(target):
            print("Saving conversation to S3.")
            self.save_file_path = target
            return self._save_to_s3(target, **kwargs)
        return self._save_to_file(target)

    def _save_to_s3(self, s3_path: Optional[str] = None, **kwargs) -> bool:
        """
        Save conversation to S3
        Args:
            s3_path: 's3://bucket/key' destination (defaults to self.save_file_path)
            **kwargs: Optional 'client' (boto3 S3 client) and 'bucket' override
        """
        s3_path = s3_path or getattr(self, "save_file_path", None)
        try:
            bucket, key = self._split_s3_path(s3_path)
            bucket = kwargs.get('bucket') or bucket
            client = self._get_s3_client(**kwargs)
            # Same line-per-message format as the local file, so both loaders agree.
            client.put_object(
                Body=self._serialize_messages().encode('utf-8'),
                Bucket=bucket,
                Key=key,
            )
        except Exception as e:
            raise ValueError(f"Error saving conversation to s3: {e}") from e
        print(f"Conversation saved to s3_path: {s3_path}")
        return True

    def _save_to_file(self, file_path: str) -> bool:
        """Save conversation to a local file, one message content per line"""
        try:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(self._serialize_messages())
        except Exception as e:
            raise ValueError(f"Error saving conversation to file: {e}") from e
        print(f"Conversation saved to file: {file_path}")
        return True

    def load_conversation(self, file_path: Optional[str] = None, **kwargs) -> List[Any]:
        """
        Load a conversation
        Args:
            file_path: Path to load from (defaults to self.file_path)
            **kwargs: Additional arguments for S3 ('client', 'bucket')
        Returns:
            List: Loaded messages (each line becomes a SystemMessage)
        Raises:
            ValueError: If no path is available or loading fails
        """
        source = self._resolve_path(file_path)
        self.message_list = []
        if self._is_s3_path(source):
            print("Loading conversation from S3.")
            self.file_path = source
            return self._load_from_s3(source, **kwargs)
        return self._load_from_file(source)

    def _load_from_s3(self, s3_path: Optional[str] = None, **kwargs) -> List[Any]:
        """
        Load conversation from S3
        Args:
            s3_path: 's3://bucket/key' source (defaults to self.file_path)
            **kwargs: Optional 'client' (boto3 S3 client) and 'bucket' override
        """
        s3_path = s3_path or self.file_path
        try:
            bucket, key = self._split_s3_path(s3_path)
            bucket = kwargs.get('bucket') or bucket
            client = self._get_s3_client(**kwargs)
            response = client.get_object(Bucket=bucket, Key=key)
            # Parse as plain text: the previous eval() on remote data allowed code execution.
            text = response['Body'].read().decode('utf-8')
            self.message_list = self._deserialize_messages(text)
        except Exception as e:
            raise ValueError(f"Error loading conversation from s3: {e}") from e
        print(f"Conversation loaded from s3_path: {s3_path}")
        return self.message_list

    def _load_from_file(self, file_path: str) -> List[Any]:
        """Load conversation from a local file, appending one SystemMessage per line"""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            self.message_list.extend(SystemMessage(content=line) for line in lines)
        except Exception as e:
            raise ValueError(f"Error loading conversation from file: {e}") from e
        print(f"Conversation loaded from file: {file_path}")
        return self.message_list
385
+
386
class IPythonStreamHandler(StreamingStdOutCallbackHandler):
    """Callback handler that renders streamed LLM tokens in an IPython display.

    Attributes:
        output (str): Raw text of all tokens received so far.
    """

    def __init__(self):
        super().__init__()
        self.output = ""

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append the new token and redraw the accumulated output, clearing the previous display."""
        import html

        self.output += token
        # Escape the model text so characters like "<" are shown literally instead of
        # being interpreted as HTML markup; keep line breaks visible.
        rendered = html.escape(self.output).replace("\n", "<br>")
        display(HTML(rendered), clear=True)
@@ -0,0 +1,206 @@
1
+ ## file for chaining functions in chatbot
2
+
3
+ from typing import Optional, List, Dict, Any, Union
4
+ from dataclasses import dataclass
5
+ from langchain.schema.output_parser import StrOutputParser
6
+ from mb_rag.chatbot.prompts import invoke_prompt
7
+ from langchain.schema.runnable import RunnableLambda, RunnableSequence
8
+ from mb_rag.utils.extra import check_package
9
+
10
# Public API of this module.
__all__ = [
    "Chain",
    "ChainConfig",
]
11
+
12
def check_langchain_dependencies() -> None:
    """
    Verify that the LangChain packages this module relies on are available
    Raises:
        ImportError: For the first required package that is missing
    """
    required = (
        ("langchain", "LangChain package not found. Please install it using: pip install langchain"),
        ("langchain_core", "LangChain Core package not found. Please install it using: pip install langchain-core"),
    )
    for package_name, missing_message in required:
        if not check_package(package_name):
            raise ImportError(missing_message)


# Run the dependency check as soon as the module is imported.
# NOTE(review): the `langchain.schema...` imports at the top of this file execute
# before this call, so a missing langchain surfaces there first.
check_langchain_dependencies()
25
+
26
@dataclass()
class ChainConfig:
    """Settings used to build a Chain.

    When ``prompt_template`` is set, Chain renders it with ``input_dict`` via
    invoke_prompt; otherwise ``prompt`` is used as given.
    """
    prompt: Union[str, None] = None  # prompt used when no template is given
    prompt_template: Union[str, None] = None  # template rendered with input_dict
    input_dict: Union[Dict[str, Any], None] = None  # values for the template variables
32
+
33
class Chain:
    """
    Class to chain functions in chatbot with improved OOP design

    Attributes:
        model: The language model used in the chain
        prompt: The prompt from the configuration (rendered with invoke_prompt
            when a template was supplied)
        input_dict: Template values; also the default input passed to runnables
            by the *_invoke methods
    """
    def __init__(self, model: Any, config: Optional[ChainConfig] = None, **kwargs):
        """
        Initialize chain
        Args:
            model: The language model to use
            config: Chain configuration
            **kwargs: Used instead of config when config is None
                ('prompt', 'prompt_template', 'input_dict')
        """
        self.model = model
        self._output_parser = StrOutputParser()
        self._initialize_config(config, **kwargs)

    @classmethod
    def from_template(cls, model: Any, template: str, input_dict: Dict[str, Any], **kwargs) -> 'Chain':
        """
        Create chain from template
        Args:
            model: The language model
            template: Prompt template
            input_dict: Input dictionary for template
            **kwargs: Additional arguments
        Returns:
            Chain: New chain instance
        """
        config = ChainConfig(
            prompt_template=template,
            input_dict=input_dict
        )
        return cls(model, config, **kwargs)

    @classmethod
    def from_prompt(cls, model: Any, prompt: str, **kwargs) -> 'Chain':
        """
        Create chain from direct prompt
        Args:
            model: The language model
            prompt: Direct prompt
            **kwargs: Additional arguments
        Returns:
            Chain: New chain instance
        """
        config = ChainConfig(prompt=prompt)
        return cls(model, config, **kwargs)

    def _initialize_config(self, config: Optional[ChainConfig], **kwargs) -> None:
        """Set self.input_dict and self.prompt from config, or from kwargs when config is None"""
        if config:
            self.input_dict = config.input_dict
            if config.prompt_template:
                self.prompt = invoke_prompt(config.prompt_template, self.input_dict)
            else:
                self.prompt = config.prompt
        else:
            self.input_dict = kwargs.get('input_dict')
            if prompt_template := kwargs.get('prompt_template'):
                self.prompt = invoke_prompt(prompt_template, self.input_dict)
            else:
                self.prompt = kwargs.get('prompt')

    @property
    def output_parser(self) -> StrOutputParser:
        """Get the output parser"""
        return self._output_parser

    def _default_input(self, input_data: Any) -> Any:
        """Return input_data if given, else self.input_dict, else an empty dict."""
        if input_data is not None:
            return input_data
        return self.input_dict if self.input_dict is not None else {}

    @staticmethod
    def _validate_chain_components(prompt: Any, middle_chain: Optional[List] = None) -> None:
        """
        Validate chain components
        Args:
            prompt: The prompt to validate
            middle_chain: Optional middle chain to validate
        Raises:
            ValueError: If validation fails
        """
        if prompt is None:
            raise ValueError("Prompt is not provided")
        if middle_chain is not None and not isinstance(middle_chain, list):
            raise ValueError("middle_chain should be a list")

    def invoke(self) -> Any:
        """
        Compose the basic chain
        Returns:
            Any: The composed pipeline ``prompt | model | output_parser``.
            NOTE(review): despite the name, the pipeline is not executed here;
            callers run it by calling ``.invoke(...)`` on the result.
        Raises:
            ValueError: If prompt is not provided
        """
        self._validate_chain_components(self.prompt)
        chain_output = self.prompt | self.model | self.output_parser
        return chain_output

    def chain_sequence_invoke(self,
                              middle_chain: Optional[List] = None,
                              final_chain: Optional[RunnableLambda] = None,
                              input_data: Any = None) -> Any:
        """
        Chain invoke the sequence
        Args:
            middle_chain: List of functions/Prompts/RunnableLambda to chain
            final_chain: Final chain to run (defaults to the string output parser)
            input_data: Input for the sequence (defaults to input_dict, or {})
        Returns:
            Any: Output from the chain, or None when middle_chain is empty/None
        Raises:
            ValueError: If prompt is missing or middle_chain is not a list
        """
        self._validate_chain_components(self.prompt, middle_chain)

        final = final_chain if final_chain is not None else self.output_parser

        if middle_chain:
            # Each step must be passed separately; a Python list is not a runnable step.
            func_chain = RunnableSequence(self.prompt, *middle_chain, final)
            # Runnable.invoke requires an input argument.
            return func_chain.invoke(self._default_input(input_data))
        return None

    def chain_parallel_invoke(self, parallel_chain: Any, input_data: Any = None) -> Any:
        """
        Chain invoke in parallel
        Args:
            parallel_chain: Runnable to run (e.g. built by create_parallel_chain)
            input_data: Input for the runnable (defaults to input_dict, or {})
        Returns:
            Any: Output from the parallel chains
        Raises:
            ImportError: If LangChain is not installed
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return parallel_chain.invoke(self._default_input(input_data))

    def chain_branch_invoke(self, branch_chain: Any, input_data: Any = None) -> Any:
        """
        Chain invoke with branching
        Args:
            branch_chain: Runnable to run (e.g. built by create_branch_chain)
            input_data: Input for the runnable (defaults to input_dict, or {})
        Returns:
            Any: Output from the branch chain
        Raises:
            ImportError: If LangChain is not installed
        """
        if not check_package("langchain"):
            raise ImportError("LangChain package not found. Please install it using: pip install langchain")
        return branch_chain.invoke(self._default_input(input_data))

    @staticmethod
    def create_parallel_chain(prompt_template: str, model: Any, branches: Dict[str, Any]) -> Any:
        """
        Create a parallel chain
        Args:
            prompt_template: Template string or prompt runnable
            model: The language model
            branches: Dictionary of branch runnables, exposed under the "branches" key
        Returns:
            Any: Configured parallel chain
        """
        from langchain.schema.runnable import RunnableParallel
        if isinstance(prompt_template, str):
            # A bare str is not a Runnable and cannot be piped into the model.
            from langchain_core.prompts import PromptTemplate
            prompt_template = PromptTemplate.from_template(prompt_template)
        return (
            prompt_template
            | model
            | StrOutputParser()
            | RunnableParallel(branches=branches)
        )

    @staticmethod
    def create_branch_chain(conditions: List[tuple], default_chain: Any) -> Any:
        """
        Create a branch chain
        Args:
            conditions: List of (condition, runnable) tuples
            default_chain: Runnable used when no condition matches
        Returns:
            Any: Configured branch chain
        """
        from langchain.schema.runnable import RunnableBranch
        return RunnableBranch(*conditions, default_chain)