mb-rag 1.1.46__py3-none-any.whl → 1.1.56.post0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mb-rag might be problematic. Click here for more details.

mb_rag/version.py CHANGED
@@ -1,5 +1,5 @@
1
1
  MAJOR_VERSION = 1
2
2
  MINOR_VERSION = 1
3
- PATCH_VERSION = 46
3
+ PATCH_VERSION = 56
4
4
  version = '{}.{}.{}'.format(MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION)
5
5
  __all__ = ['MAJOR_VERSION', 'MINOR_VERSION', 'PATCH_VERSION', 'version']
@@ -1,11 +1,11 @@
1
- Metadata-Version: 2.4
2
- Name: mb_rag
3
- Version: 1.1.46
4
- Summary: RAG function file
5
- Author: ['Malav Bateriwala']
6
- Requires-Python: >=3.8
7
- Requires-Dist: mb_base
8
- Dynamic: author
9
- Dynamic: requires-dist
10
- Dynamic: requires-python
11
- Dynamic: summary
1
+ Metadata-Version: 2.4
2
+ Name: mb_rag
3
+ Version: 1.1.56.post0
4
+ Summary: RAG function file
5
+ Author: ['Malav Bateriwala']
6
+ Requires-Python: >=3.8
7
+ Requires-Dist: mb_base
8
+ Dynamic: author
9
+ Dynamic: requires-dist
10
+ Dynamic: requires-python
11
+ Dynamic: summary
@@ -0,0 +1,19 @@
1
+ mb_rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ mb_rag/basic.py,sha256=v-D27FAJZTpUHsKHTNw2-cZiLHbwlKYRzrOlFNyAudY,12216
3
+ mb_rag/version.py,sha256=pCIU6hqXriGFDx2MivR_HZSvVwlwBMSh-2M7HpwDDyI,207
4
+ mb_rag/chatbot/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ mb_rag/chatbot/chains.py,sha256=0Ir1nUlNdmzs7jPFiES2z3EaGErDTmIMZ1Lx7R4ajLg,7429
6
+ mb_rag/chatbot/conversation.py,sha256=uboW4p_MBXbD3u4b79O0cnuKRrr-6fyeQgnv4C4DrA0,7179
7
+ mb_rag/chatbot/prompts.py,sha256=-DGEV-a9kaPhsRTMrRPeg0kJxOo9um315SHpHoAijbg,1807
8
+ mb_rag/rag/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ mb_rag/rag/embeddings.py,sha256=vqU9Yt-6aXu9XN8gBcXkhYHBTZvERH87HiFTzYtJHKM,31325
10
+ mb_rag/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ mb_rag/utils/all_data_extract.py,sha256=CCxEPwD52qAAbMESW9QbPNtXlQsSTviSlcUOVQqPOC8,2596
12
+ mb_rag/utils/bounding_box.py,sha256=vjy1BzqebEz0dtM5H5S_94UkmZemCnM-_AnBItOWK3o,8597
13
+ mb_rag/utils/document_extract.py,sha256=AHvC2YGo60snXXB9C15QyMRi3ZfkfcRAfV8t_IRD4IM,12912
14
+ mb_rag/utils/extra.py,sha256=SWbZmTttK20GNkt9AM0xoOzhQEYwJiPmR9rgHnt1_RU,2565
15
+ mb_rag/utils/pdf_extract.py,sha256=aHk61kdzSM6uptzQHuGin8Kgf9DfaQjcLX5-q9hNQok,16026
16
+ mb_rag-1.1.56.post0.dist-info/METADATA,sha256=bM5gnocs1dQczGE0UPl3fWvJmPTQwD-WagHQYzW05zM,251
17
+ mb_rag-1.1.56.post0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
18
+ mb_rag-1.1.56.post0.dist-info/top_level.txt,sha256=FIK1eAa5uYnurgXZquBG-s3PIy-HDTC5yJBW4lTH_pM,7
19
+ mb_rag-1.1.56.post0.dist-info/RECORD,,
mb_rag/chatbot/basic.py DELETED
@@ -1,644 +0,0 @@
1
- import os
2
- from dotenv import load_dotenv
3
- from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
4
- from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
5
- from langchain.document_loaders import TextLoader
6
- from IPython.display import display, HTML
7
- from typing import Optional, List, Dict, Any, Union
8
- from mb_rag.utils.extra import check_package
9
- import base64
10
-
11
- __all__ = [
12
- 'ChatbotBase',
13
- 'ModelFactory',
14
- 'ConversationModel',
15
- 'IPythonStreamHandler',
16
- 'AgentFactory'
17
- ]
18
-
19
- class ChatbotBase:
20
- """Base class for chatbot functionality"""
21
-
22
- @staticmethod
23
- def load_env(file_path: str) -> None:
24
- """
25
- Load environment variables from a file
26
- Args:
27
- file_path (str): Path to the environment file
28
- """
29
- load_dotenv(file_path)
30
-
31
- @staticmethod
32
- def add_os_key(name: str, key: str) -> None:
33
- """
34
- Add an API key to the environment
35
- Args:
36
- name (str): Name of the API key
37
- key (str): API key
38
- """
39
- os.environ[name] = key
40
-
41
- @staticmethod
42
- def get_client():
43
- """
44
- Returns a boto3 client for S3
45
- """
46
- if not check_package("boto3"):
47
- raise ImportError("Boto3 package not found. Please install it using: pip install boto3")
48
-
49
- import boto3
50
- return boto3.client('s3')
51
-
52
- class ModelFactory:
53
- """Factory class for creating different types of chatbot models"""
54
-
55
- def __init__(self, model_type: str = 'openai', model_name: str = "gpt-4o", **kwargs) -> Any:
56
- """
57
- Factory method to create any type of model
58
- Args:
59
- model_type (str): Type of model to create. Default is OpenAI. Options are openai, anthropic, google, ollama , groq
60
- model_name (str): Name of the model
61
- **kwargs: Additional arguments
62
- Returns:
63
- Any: Chatbot model
64
- """
65
- creators = {
66
- 'openai': self.create_openai,
67
- 'anthropic': self.create_anthropic,
68
- 'google': self.create_google,
69
- 'ollama': self.create_ollama,
70
- 'groq': self.create_groq,
71
- 'deepseek': self.create_deepseek,
72
- 'qwen' : self.create_qwen,
73
- 'hugging_face': self.create_hugging_face
74
- }
75
-
76
- self.model_type = model_type
77
- self.model_name = model_name
78
- model_data = creators.get(model_type)
79
- if not model_data:
80
- raise ValueError(f"Unsupported model type: {model_type}")
81
-
82
- try:
83
- self.model = model_data(model_name, **kwargs)
84
- except Exception as e:
85
- raise ValueError(f"Error creating {model_type} model: {str(e)}")
86
-
87
- @classmethod
88
- def create_openai(cls, model_name: str = "gpt-4o", **kwargs) -> Any:
89
- """
90
- Create OpenAI chatbot model
91
- Args:
92
- model_name (str): Name of the model
93
- **kwargs: Additional arguments
94
- Returns:
95
- ChatOpenAI: Chatbot model
96
- """
97
- if not check_package("openai"):
98
- raise ImportError("OpenAI package not found. Please install it using: pip install openai langchain-openai")
99
-
100
- from langchain_openai import ChatOpenAI
101
- kwargs["model_name"] = model_name
102
- return ChatOpenAI(**kwargs)
103
-
104
- @classmethod
105
- def create_anthropic(cls, model_name: str = "claude-3-opus-20240229", **kwargs) -> Any:
106
- """
107
- Create Anthropic chatbot model
108
- Args:
109
- model_name (str): Name of the model
110
- **kwargs: Additional arguments
111
- Returns:
112
- ChatAnthropic: Chatbot model
113
- """
114
- if not check_package("anthropic"):
115
- raise ImportError("Anthropic package not found. Please install it using: pip install anthropic langchain-anthropic")
116
-
117
- from langchain_anthropic import ChatAnthropic
118
- kwargs["model_name"] = model_name
119
- return ChatAnthropic(**kwargs)
120
-
121
- @classmethod
122
- def create_google(cls, model_name: str = "gemini-2.0-flash", **kwargs) -> Any:
123
- """
124
- Create Google chatbot model
125
- Args:
126
- model_name (str): Name of the model
127
- **kwargs: Additional arguments
128
- Returns:
129
- ChatGoogleGenerativeAI: Chatbot model
130
- """
131
- if not check_package("langchain_google_genai"):
132
- raise ImportError("langchain_google_genai package not found. Please install it using: pip install google-generativeai langchain-google-genai")
133
-
134
- from langchain_google_genai import ChatGoogleGenerativeAI
135
- kwargs["model"] = model_name
136
- return ChatGoogleGenerativeAI(**kwargs)
137
-
138
- @classmethod
139
- def create_ollama(cls, model_name: str = "llama3", **kwargs) -> Any:
140
- """
141
- Create Ollama chatbot model
142
- Args:
143
- model_name (str): Name of the model
144
- **kwargs: Additional arguments
145
- Returns:
146
- Ollama: Chatbot model
147
- """
148
- if not check_package("langchain_ollama"):
149
- raise ImportError("Langchain Community package not found. Please install it using: pip install langchain_ollama")
150
-
151
- from langchain_ollama import ChatOllama
152
-
153
- print(f"Current Ollama serve model is {os.system('ollama ps')}")
154
- kwargs["model"] = model_name
155
- return ChatOllama(**kwargs)
156
-
157
- @classmethod
158
- def create_groq(cls, model_name: str = "llama-3.3-70b-versatile", **kwargs) -> Any:
159
- """
160
- Create Groq chatbot model
161
- Args:
162
- model_name (str): Name of the model
163
- **kwargs: Additional arguments. Options are: temperature, groq_api_key, model_name
164
- Returns:
165
- ChatGroq: Chatbot model
166
- """
167
- if not check_package("langchain_groq"):
168
- raise ImportError("Langchain Groq package not found. Please install it using: pip install langchain-groq")
169
-
170
- from langchain_groq import ChatGroq
171
- kwargs["model"] = model_name
172
- return ChatGroq(**kwargs)
173
-
174
- @classmethod
175
- def create_deepseek(cls, model_name: str = "deepseek-chat", **kwargs) -> Any:
176
- """
177
- Create Deepseek chatbot model
178
- Args:
179
- model_name (str): Name of the model
180
- **kwargs: Additional arguments
181
- Returns:
182
- ChatDeepseek: Chatbot model
183
- """
184
- if not check_package("langchain_deepseek"):
185
- raise ImportError("Langchain Deepseek package not found. Please install it using: pip install langchain-deepseek")
186
-
187
- from langchain_deepseek import ChatDeepSeek
188
- kwargs["model"] = model_name
189
- return ChatDeepSeek(**kwargs)
190
-
191
- @classmethod
192
- def create_qwen(cls, model_name: str = "qwen", **kwargs) -> Any:
193
- """
194
- Create Qwen chatbot model
195
- Args:
196
- model_name (str): Name of the model
197
- **kwargs: Additional arguments
198
- Returns:
199
- ChatQwen: Chatbot model
200
- """
201
- if not check_package("langchain_community"):
202
- raise ImportError("Langchain Qwen package not found. Please install it using: pip install langchain_community")
203
-
204
- from langchain_community.chat_models.tongyi import ChatTongyi
205
- kwargs["model"] = model_name
206
- return ChatTongyi(streaming=True,**kwargs)
207
-
208
- @classmethod
209
- def create_hugging_face(cls, model_name: str = "Qwen/Qwen2.5-VL-7B-Instruct",model_function: str = "image-text-to-text",
210
- device='cpu',**kwargs) -> Any:
211
- """
212
- Create and load hugging face model.
213
- Args:
214
- model_name (str): Name of the model
215
- model_function (str): model function. Default is image-text-to-text.
216
- device (str): Device to use. Default is cpu
217
- **kwargs: Additional arguments
218
- Returns:
219
- ChatHuggingFace: Chatbot model
220
- """
221
- if not check_package("transformers"):
222
- raise ImportError("Transformers package not found. Please install it using: pip install transformers")
223
- if not check_package("langchain_huggingface"):
224
- raise ImportError("langchain_huggingface package not found. Please install it using: pip install langchain_huggingface")
225
- if not check_package("torch"):
226
- raise ImportError("Torch package not found. Please install it using: pip install torch")
227
-
228
- from langchain_huggingface import HuggingFacePipeline
229
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModelForImageTextToText,AutoProcessor
230
- import torch
231
-
232
- device = torch.device(device) if torch.cuda.is_available() else torch.device("cpu")
233
-
234
- temperature = kwargs.pop("temperature", 0.7)
235
- max_length = kwargs.pop("max_length", 1024)
236
-
237
- if model_function == "image-text-to-text":
238
- tokenizer = AutoProcessor.from_pretrained(model_name,trust_remote_code=True)
239
- model = AutoModelForImageTextToText.from_pretrained(
240
- model_name,
241
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
242
- device_map=device,
243
- trust_remote_code=True,
244
- **kwargs
245
- )
246
- else:
247
- tokenizer = AutoTokenizer.from_pretrained(model_name,trust_remote_code=True)
248
- model = AutoModelForCausalLM.from_pretrained(
249
- model_name,
250
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
251
- device_map=device,
252
- trust_remote_code=True,
253
- **kwargs)
254
-
255
- # Create pipeline
256
- pipe = pipeline(
257
- model_function,
258
- model=model,
259
- tokenizer=tokenizer,
260
- max_length=max_length,
261
- temperature=temperature
262
- )
263
-
264
- # Create and return LangChain HuggingFacePipeline
265
- return HuggingFacePipeline(pipeline=pipe)
266
-
267
- def _reset_model(self):
268
- """Reset the model"""
269
- self.model = self.model.reset()
270
-
271
- def invoke_query(self,query: str,file_path: str = None,get_content_only: bool = True,images: list = None,pydantic_model = None) -> str:
272
- """
273
- Invoke the model
274
- Args:
275
- query (str): Query to send to the model
276
- file_path (str): Path to text file to load. Default is None
277
- get_content_only (bool): Whether to return only content
278
- images (list): List of images to send to the model
279
- pydantic_model: Pydantic model for structured output
280
- Returns:
281
- str: Response from the model
282
- """
283
- if file_path:
284
- loader = TextLoader(file_path)
285
- document = loader.load()
286
- query = document.content
287
-
288
- if pydantic_model is not None:
289
- if hasattr(self.model, 'with_structured_output'):
290
- try:
291
- self.model = self.model.with_structured_output(pydantic_model)
292
- except Exception as e:
293
- raise ValueError(f"Error with pydantic_model: {e}")
294
- if images:
295
- res = self._model_invoke_images(images=images,prompt=query,pydantic_model=pydantic_model,get_content_only=get_content_only)
296
- else:
297
- res = self.model.invoke(query)
298
- if get_content_only:
299
- try:
300
- return res.content
301
- except Exception:
302
- return res
303
- return res
304
-
305
- def _image_to_base64(self,image):
306
- with open(image, "rb") as f:
307
- return base64.b64encode(f.read()).decode('utf-8')
308
-
309
- def _model_invoke_images(self,images: list, prompt: str,pydantic_model = None,get_content_only: bool = True) -> str:
310
- """
311
- Function to invoke the model with images
312
- Args:
313
- images (list): List of images
314
- prompt (str): Prompt
315
- pydantic_model (PydanticModel): Pydantic model
316
- Returns:
317
- str: Output from the model
318
- """
319
- base64_images = [self._image_to_base64(image) for image in images]
320
- image_prompt_create = [{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_images[i]}"}} for i in range(len(images))]
321
- prompt_new = [{"type": "text", "text": prompt},
322
- *image_prompt_create,]
323
- if pydantic_model is not None:
324
- try:
325
- self.model = self.model.with_structured_output(pydantic_model)
326
- except Exception as e:
327
- print(f"Error with pydantic_model: {e}")
328
- print("Continuing without structured output")
329
- message= HumanMessage(content=prompt_new,)
330
- response = self.model.invoke([message])
331
-
332
- if get_content_only:
333
- try:
334
- return response.content
335
- except Exception:
336
- print("Failed to get content from response. Returning response object")
337
- return response
338
- else:
339
- return response
340
-
341
- def _get_llm_metadata(self):
342
- """
343
- Returns Basic metadata about the LLM
344
- """
345
- print("Model Name: ", self.model)
346
- print("Model Temperature: ", self.model.temperature)
347
- print("Model Max Tokens: ", self.model.max_output_tokens)
348
- print("Model Top P: ", self.model.top_p)
349
- print("Model Top K: ", self.model.top_k)
350
- print("Model Input Schema:",self.model.input_schema)
351
-
352
- class ConversationModel:
353
- """
354
- A class to handle conversation with AI models
355
-
356
- Attributes:
357
- chatbot: The AI model for conversation
358
- message_list (List): List of conversation messages
359
- file_path (str): Path to save/load conversations. Can be local or S3
360
- """
361
-
362
- def __init__(self,
363
- model_name: str = "gpt-4o",
364
- model_type: str = 'openai',
365
- **kwargs) -> None:
366
- """Initialize conversation model"""
367
- self.chatbot = ModelFactory(model_type, model_name, **kwargs)
368
-
369
- def initialize_conversation(self,
370
- question: Optional[str],
371
- context: Optional[str] = None,
372
- file_path: Optional[str]=None) -> None:
373
- """Initialize conversation state"""
374
- if file_path:
375
- self.file_path = file_path
376
- self.load_conversation(file_path)
377
-
378
- else:
379
- if not question:
380
- raise ValueError("Question is required.")
381
-
382
- if context:
383
- self.context = context
384
- else:
385
- self.context = "Answer question to the point and don't hallucinate."
386
- self.message_list = [
387
- SystemMessage(content=context),
388
- HumanMessage(content=question)
389
- ]
390
-
391
- res = self._ask_question(self.message_list)
392
- print(res)
393
- self.message_list.append(AIMessage(content=res))
394
-
395
- def _ask_question(self,messages: List[Union[SystemMessage, HumanMessage, AIMessage]],
396
- get_content_only: bool = True) -> str:
397
- """
398
- Ask a question and get response
399
- Args:
400
- messages: List of messages
401
- get_content_only: Whether to return only content
402
- Returns:
403
- str: Response from the model
404
- """
405
- res = self.chatbot.invoke_query(messages)
406
- if get_content_only:
407
- try:
408
- return res.content
409
- except Exception:
410
- return res
411
- return res
412
-
413
- def add_message(self, message: str) -> str:
414
- """
415
- Add a message to the conversation
416
- Args:
417
- message (str): Message to add
418
- Returns:
419
- str: Response from the chatbot
420
- """
421
- self.message_list.append(HumanMessage(content=message))
422
- res = self._ask_question(self.message_list)
423
- self.message_list.append(AIMessage(content=res))
424
- return res
425
-
426
- @property
427
- def all_messages(self) -> List[Union[SystemMessage, HumanMessage, AIMessage]]:
428
- """Get all messages"""
429
- return self.message_list
430
-
431
- @property
432
- def last_message(self) -> str:
433
- """Get the last message"""
434
- return self.message_list[-1].content
435
-
436
- @property
437
- def all_messages_content(self) -> List[str]:
438
- """Get content of all messages"""
439
- return [message.content for message in self.message_list]
440
-
441
- def _is_s3_path(self, path: str) -> bool:
442
- """
443
- Check if path is an S3 path
444
- Args:
445
- path (str): Path to check
446
- Returns:
447
- bool: True if S3 path
448
- """
449
- return path.startswith("s3://")
450
-
451
- def save_conversation(self, file_path: Optional[str] = None, **kwargs) -> bool:
452
- """
453
- Save the conversation
454
- Args:
455
- file_path: Path to save the conversation
456
- **kwargs: Additional arguments for S3
457
- Returns:
458
- bool: Success status
459
- """
460
- if self._is_s3_path(file_path or self.file_path):
461
- print("Saving conversation to S3.")
462
- self.save_file_path = file_path
463
- return self._save_to_s3(self.file_path,**kwargs)
464
- return self._save_to_file(file_path or self.file_path)
465
-
466
- def _save_to_s3(self,**kwargs) -> bool:
467
- """Save conversation to S3"""
468
- try:
469
- client = kwargs.get('client', self.client)
470
- bucket = kwargs.get('bucket', self.bucket)
471
- client.put_object(
472
- Body=str(self.message_list),
473
- Bucket=bucket,
474
- Key=self.save_file_path
475
- )
476
- print(f"Conversation saved to s3_path: {self.s3_path}")
477
- return True
478
- except Exception as e:
479
- raise ValueError(f"Error saving conversation to s3: {e}")
480
-
481
- def _save_to_file(self, file_path: str) -> bool:
482
- """Save conversation to file"""
483
- try:
484
- with open(file_path, 'w') as f:
485
- for message in self.message_list:
486
- f.write(f"{message.content}\n")
487
- print(f"Conversation saved to file: {file_path}")
488
- return True
489
- except Exception as e:
490
- raise ValueError(f"Error saving conversation to file: {e}")
491
-
492
- def load_conversation(self, file_path: Optional[str] = None, **kwargs) -> List[Any]:
493
- """
494
- Load a conversation
495
- Args:
496
- file_path: Path to load from
497
- **kwargs: Additional arguments for S3
498
- Returns:
499
- List: Loaded messages
500
- """
501
- self.message_list = []
502
- if self._is_s3_path(file_path or self.file_path):
503
- print("Loading conversation from S3.")
504
- self.file_path = file_path
505
- return self._load_from_s3(**kwargs)
506
- return self._load_from_file(file_path or self.file_path)
507
-
508
- def _load_from_s3(self, **kwargs) -> List[Any]:
509
- """Load conversation from S3"""
510
- try:
511
- client = kwargs.get('client', self.client)
512
- bucket = kwargs.get('bucket', self.bucket)
513
- res = client.get_response(client, bucket, self.s3_path)
514
- res_str = eval(res['Body'].read().decode('utf-8'))
515
- self.message_list = [SystemMessage(content=res_str)]
516
- print(f"Conversation loaded from s3_path: {self.file_path}")
517
- return self.message_list
518
- except Exception as e:
519
- raise ValueError(f"Error loading conversation from s3: {e}")
520
-
521
- def _load_from_file(self, file_path: str) -> List[Any]:
522
- """Load conversation from file"""
523
- try:
524
- with open(file_path, 'r') as f:
525
- lines = f.readlines()
526
- for line in lines:
527
- self.message_list.append(SystemMessage(content=line))
528
- print(f"Conversation loaded from file: {file_path}")
529
- return self.message_list
530
- except Exception as e:
531
- raise ValueError(f"Error loading conversation from file: {e}")
532
-
533
- class IPythonStreamHandler(StreamingStdOutCallbackHandler):
534
- """Handler for IPython display"""
535
-
536
- def __init__(self):
537
- self.output = ""
538
-
539
- def on_llm_new_token(self, token: str, **kwargs) -> None:
540
- """Handle new token"""
541
- self.output += token
542
- display(HTML(self.output), clear=True)
543
-
544
-
545
- class AgentFactory:
546
- """Factory class for creating different types of agents"""
547
-
548
- def __init__(self, agent_type: str = 'basic', model_name: str = "gpt-4o", **kwargs) -> Any:
549
- """
550
- Factory method to create any type of agent
551
- Args:
552
- agent_type (str): Type of agent to create. Default is basic.
553
- model_name (str): Name of the model
554
- **kwargs: Additional arguments
555
- Returns:
556
- Any: Agent
557
- """
558
- creators = {
559
- 'basic': self.create_basic_agent,
560
- 'langgraph': self.create_langgraph_agent,
561
- }
562
-
563
- agent_data = creators.get(agent_type)
564
- if not agent_data:
565
- raise ValueError(f"Unsupported agent type: {agent_type}")
566
-
567
- try:
568
- self.agent = agent_data(model_name, **kwargs)
569
- except Exception as e:
570
- raise ValueError(f"Error creating {agent_type} agent: {str(e)}")
571
-
572
- @classmethod
573
- def create_basic_agent(cls, model_name: str = "gpt-4o", **kwargs) -> Any:
574
- """
575
- Create basic agent
576
- Args:
577
- model_name (str): Name of the model
578
- **kwargs: Additional arguments
579
- Returns:
580
- Runnable: Agent
581
- """
582
- # Basic agent creation logic here
583
- llm = ModelFactory(model_name=model_name, **kwargs).model
584
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
585
- prompt = ChatPromptTemplate.from_messages([
586
- ("system", "You are a helpful AI assistant"),
587
- MessagesPlaceholder(variable_name="messages")
588
- ])
589
- from langchain_core.runnables import chain
590
- agent = prompt | llm
591
- return agent
592
-
593
- @classmethod
594
- def create_langgraph_agent(cls, model_name: str = "gpt-4o", **kwargs) -> Any:
595
- """
596
- Create LangGraph agent
597
- Args:
598
- model_name (str): Name of the model
599
- **kwargs: Additional arguments
600
- Returns:
601
- Graph: LangGraph agent
602
- """
603
- if not check_package("langgraph"):
604
- raise ImportError("LangGraph package not found. Please install it using: pip install langgraph")
605
-
606
- from langgraph.graph import StateGraph, MessageGraph
607
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
608
- from langchain_core.runnables import chain
609
- from langchain_core.messages import BaseMessage
610
-
611
- llm = ModelFactory(model_name=model_name, **kwargs).model
612
-
613
- # Define the state of the graph
614
- class GraphState:
615
- messages: List[BaseMessage]
616
- agent_state: Dict[str, Any]
617
-
618
- # Define the nodes
619
- def agent(state: GraphState):
620
- prompt = ChatPromptTemplate.from_messages([
621
- ("system", "You are a helpful AI assistant"),
622
- MessagesPlaceholder(variable_name="messages")
623
- ])
624
- return (prompt | llm).invoke({"messages": state.messages})
625
-
626
- def user(state: GraphState, input: str):
627
- return HumanMessage(content=input)
628
-
629
- # Define the graph
630
- graph = MessageGraph()
631
-
632
- # Add the nodes
633
- graph.add_node("agent", agent)
634
- graph.add_node("user", user)
635
-
636
- # Set the entrypoint
637
- graph.set_entry_point("user")
638
-
639
- # Add the edges
640
- graph.add_edge("user", "agent")
641
- graph.add_edge("agent", "user")
642
-
643
- # Compile the graph
644
- return graph.compile()