aiagents4pharma 1.8.3__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30):
  1. aiagents4pharma/__init__.py +9 -6
  2. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +3 -1
  3. aiagents4pharma/talk2biomodels/__init__.py +1 -1
  4. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +1 -1
  5. aiagents4pharma/talk2biomodels/tests/test_langgraph.py +71 -20
  6. aiagents4pharma/talk2biomodels/tools/ask_question.py +16 -7
  7. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +20 -14
  8. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +6 -6
  9. aiagents4pharma/talk2biomodels/tools/simulate_model.py +26 -12
  10. aiagents4pharma/talk2competitors/__init__.py +5 -0
  11. aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
  12. aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
  13. aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
  14. aiagents4pharma/talk2competitors/config/__init__.py +5 -0
  15. aiagents4pharma/talk2competitors/config/config.py +110 -0
  16. aiagents4pharma/talk2competitors/state/__init__.py +5 -0
  17. aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
  18. aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
  19. aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
  20. aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
  21. aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
  22. aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
  23. aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
  24. aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
  25. aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +141 -0
  26. {aiagents4pharma-1.8.3.dist-info → aiagents4pharma-1.10.0.dist-info}/METADATA +37 -18
  27. {aiagents4pharma-1.8.3.dist-info → aiagents4pharma-1.10.0.dist-info}/RECORD +30 -15
  28. {aiagents4pharma-1.8.3.dist-info → aiagents4pharma-1.10.0.dist-info}/LICENSE +0 -0
  29. {aiagents4pharma-1.8.3.dist-info → aiagents4pharma-1.10.0.dist-info}/WHEEL +0 -0
  30. {aiagents4pharma-1.8.3.dist-info → aiagents4pharma-1.10.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,11 @@
1
- '''
1
+ """
2
2
  This file is used to import aiagents4pharma modules.
3
- '''
3
+ """
4
4
 
5
- from . import talk2biomodels
6
- from . import talk2cells
7
- from . import talk2knowledgegraphs
8
- from . import configs
5
+ from . import (
6
+ configs,
7
+ talk2biomodels,
8
+ talk2cells,
9
+ talk2competitors,
10
+ talk2knowledgegraphs,
11
+ )
@@ -3,4 +3,6 @@ state_modifier: >
3
3
  You are Talk2BioModels agent.
4
4
  If the user asks for the uploaded model,
5
5
  then pass the use_uploaded_model argument
6
- as True.
6
+ as True. If the user asks for simulation,
7
+ then suggest a value for the `simulation_name`
8
+ argument.
@@ -4,4 +4,4 @@ This file is used to import the models and tools.
4
4
  from . import models
5
5
  from . import tools
6
6
  from . import agents
7
- from . import states
7
+ from . import states
@@ -20,5 +20,5 @@ class Talk2Biomodels(AgentState):
20
20
  # the operator for the sbml_file_path field.
21
21
  # https://langchain-ai.github.io/langgraph/troubleshooting/errors/INVALID_CONCURRENT_GRAPH_UPDATE/
22
22
  sbml_file_path: Annotated[list, operator.add]
23
- dic_simulated_data: dict
23
+ dic_simulated_data: Annotated[list[dict], operator.add]
24
24
  llm_model: str
@@ -1,7 +1,8 @@
1
1
  '''
2
- Test cases
2
+ Test cases for Talk2Biomodels.
3
3
  '''
4
4
 
5
+ import pandas as pd
5
6
  from langchain_core.messages import HumanMessage, ToolMessage
6
7
  from ..agents.t2b_agent import get_app
7
8
 
@@ -13,7 +14,8 @@ def test_get_modelinfo_tool():
13
14
  app = get_app(unique_id)
14
15
  config = {"configurable": {"thread_id": unique_id}}
15
16
  # Update state
16
- app.update_state(config,{"sbml_file_path": ["BIOMD0000000537.xml"]})
17
+ app.update_state(config,
18
+ {"sbml_file_path": ["aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml"]})
17
19
  prompt = "Extract all relevant information from the uploaded model."
18
20
  # Test the tool get_modelinfo
19
21
  response = app.invoke(
@@ -56,26 +58,70 @@ def test_ask_question_tool():
56
58
 
57
59
  ##########################################
58
60
  # Test ask_question tool when simulation
59
- # results are not available
61
+ # results are not available i.e. the
62
+ # simulation has not been run. In this
63
+ # case, the tool should return an error
60
64
  ##########################################
61
65
  # Update state
62
66
  app.update_state(config, {"llm_model": "gpt-4o-mini"})
67
+ # Define the prompt
63
68
  prompt = "Call the ask_question tool to answer the "
64
69
  prompt += "question: What is the concentration of CRP "
65
- prompt += "in serum at 1000 hours?"
66
-
67
- # Test the tool get_modelinfo
68
- response = app.invoke(
69
- {"messages": [HumanMessage(content=prompt)]},
70
- config=config
71
- )
72
- assistant_msg = response["messages"][-1].content
73
- # Check if the assistant message is a string
74
- assert isinstance(assistant_msg, str)
70
+ prompt += "in serum at 1000 hours? The simulation name "
71
+ prompt += "is `simulation_name`."
72
+ # Invoke the tool
73
+ app.invoke(
74
+ {"messages": [HumanMessage(content=prompt)]},
75
+ config=config
76
+ )
77
+ # Get the messages from the current state
78
+ # and reverse the order
79
+ current_state = app.get_state(config)
80
+ reversed_messages = current_state.values["messages"][::-1]
81
+ # Loop through the reversed messages until a
82
+ # ToolMessage is found.
83
+ for msg in reversed_messages:
84
+ # Assert that the message is a ToolMessage
85
+ # and its status is "error"
86
+ if isinstance(msg, ToolMessage):
87
+ assert msg.status == "error"
75
88
 
76
89
  def test_simulate_model_tool():
77
90
  '''
78
- Test the simulate_model tool.
91
+ Test the simulate_model tool when simulating
92
+ multiple models.
93
+ '''
94
+ unique_id = 123
95
+ app = get_app(unique_id)
96
+ config = {"configurable": {"thread_id": unique_id}}
97
+ app.update_state(config, {"llm_model": "gpt-4o-mini"})
98
+ # Upload a model to the state
99
+ app.update_state(config,
100
+ {"sbml_file_path": ["aiagents4pharma/talk2biomodels/tests/BIOMD0000000449_url.xml"]})
101
+ prompt = "Simulate models 64 and the uploaded model"
102
+ # Invoke the agent
103
+ app.invoke(
104
+ {"messages": [HumanMessage(content=prompt)]},
105
+ config=config
106
+ )
107
+ current_state = app.get_state(config)
108
+ dic_simulated_data = current_state.values["dic_simulated_data"]
109
+ # Check if the dic_simulated_data is a list
110
+ assert isinstance(dic_simulated_data, list)
111
+ # Check if the length of the dic_simulated_data is 2
112
+ assert len(dic_simulated_data) == 2
113
+ # Check if the source of the first model is 64
114
+ assert dic_simulated_data[0]['source'] == 64
115
+ # Check if the source of the second model is upload
116
+ assert dic_simulated_data[1]['source'] == "upload"
117
+ # Check if the data of the first model contains
118
+ assert '1,3-bisphosphoglycerate' in dic_simulated_data[0]['data']
119
+ # Check if the data of the second model contains
120
+ assert 'mTORC2' in dic_simulated_data[1]['data']
121
+
122
+ def test_integration():
123
+ '''
124
+ Test the integration of the tools.
79
125
  '''
80
126
  unique_id = 123
81
127
  app = get_app(unique_id)
@@ -138,9 +184,9 @@ def test_simulate_model_tool():
138
184
  reversed_messages = current_state.values["messages"][::-1]
139
185
  # Loop through the reversed messages
140
186
  # until a ToolMessage is found.
141
- expected_artifact = ['CRP[serum]', 'CRPExtracellular']
142
- expected_artifact += ['CRP Suppression (%)', 'CRP (% of baseline)']
143
- expected_artifact += ['CRP[liver]']
187
+ expected_header = ['Time', 'CRP[serum]', 'CRPExtracellular']
188
+ expected_header += ['CRP Suppression (%)', 'CRP (% of baseline)']
189
+ expected_header += ['CRP[liver]']
144
190
  predicted_artifact = []
145
191
  for msg in reversed_messages:
146
192
  if isinstance(msg, ToolMessage):
@@ -150,9 +196,14 @@ def test_simulate_model_tool():
150
196
  if msg.name == "custom_plotter":
151
197
  predicted_artifact = msg.artifact
152
198
  break
153
- # Check if the two artifacts are equal
154
- # assert expected_artifact in predicted_artifact
155
- assert set(expected_artifact).issubset(set(predicted_artifact))
199
+ # Convert the artifact into a pandas dataframe
200
+ # for easy comparison
201
+ df = pd.DataFrame(predicted_artifact)
202
+ # Extract the headers from the dataframe
203
+ predicted_header = df.columns.tolist()
204
+ # Check if the header is in the expected_header
205
+ # assert expected_header in predicted_artifact
206
+ assert set(expected_header).issubset(set(predicted_header))
156
207
  ##########################################
157
208
  # Test custom_plotter tool when the
158
209
  # simulation results are available but
@@ -23,6 +23,8 @@ class AskQuestionInput(BaseModel):
23
23
  Input schema for the AskQuestion tool.
24
24
  """
25
25
  question: str = Field(description="question about the simulation results")
26
+ simulation_name: str = Field(description="""Name assigned to the simulation
27
+ when the tool simulate_model was invoked.""")
26
28
  state: Annotated[dict, InjectedState]
27
29
 
28
30
  # Note: It's important that every field has type hints.
@@ -39,6 +41,7 @@ class AskQuestionTool(BaseTool):
39
41
 
40
42
  def _run(self,
41
43
  question: str,
44
+ simulation_name: str,
42
45
  state: Annotated[dict, InjectedState]) -> str:
43
46
  """
44
47
  Run the tool.
@@ -46,18 +49,24 @@ class AskQuestionTool(BaseTool):
46
49
  Args:
47
50
  question (str): The question to ask about the simulation results.
48
51
  state (dict): The state of the graph.
49
- run_manager (Optional[CallbackManagerForToolRun]): The CallbackManagerForToolRun object.
52
+ simulation_name (str): The name assigned to the simulation.
50
53
 
51
54
  Returns:
52
55
  str: The answer to the question.
53
56
  """
54
57
  logger.log(logging.INFO,
55
- "Calling ask_question tool %s", question)
56
- # Check if the simulation results are available
57
- if 'dic_simulated_data' not in state:
58
- return "Please run the simulation first before \
59
- asking a question about the simulation results."
60
- df = pd.DataFrame.from_dict(state['dic_simulated_data'])
58
+ "Calling ask_question tool %s, %s", question, simulation_name)
59
+ dic_simulated_data = {}
60
+ for data in state["dic_simulated_data"]:
61
+ for key in data:
62
+ if key not in dic_simulated_data:
63
+ dic_simulated_data[key] = []
64
+ dic_simulated_data[key] += [data[key]]
65
+ # print (dic_simulated_data)
66
+ df_simulated_data = pd.DataFrame.from_dict(dic_simulated_data)
67
+ df = pd.DataFrame(
68
+ df_simulated_data[df_simulated_data['name'] == simulation_name]['data'].iloc[0]
69
+ )
61
70
  prompt_content = None
62
71
  # if run_manager and 'prompt' in run_manager.metadata:
63
72
  # prompt_content = run_manager.metadata['prompt']
@@ -6,14 +6,11 @@ Tool for plotting a custom figure.
6
6
 
7
7
  import logging
8
8
  from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
9
- from typing import Type, List, TypedDict, Annotated, Tuple, Union, Literal
10
9
  from pydantic import BaseModel, Field
11
10
  import pandas as pd
12
- import pandas as pd
13
11
  from langchain_openai import ChatOpenAI
14
12
  from langchain_core.tools import BaseTool
15
13
  from langgraph.prebuilt import InjectedState
16
- from langgraph.prebuilt import InjectedState
17
14
 
18
15
  # Initialize logger
19
16
  logging.basicConfig(level=logging.INFO)
@@ -24,7 +21,7 @@ class CustomPlotterInput(BaseModel):
24
21
  Input schema for the PlotImage tool.
25
22
  """
26
23
  question: str = Field(description="Description of the plot")
27
- state: Annotated[dict, InjectedState]
24
+ simulation_name: str = Field(description="Name assigned to the simulation")
28
25
  state: Annotated[dict, InjectedState]
29
26
 
30
27
  # Note: It's important that every field has type hints.
@@ -41,10 +38,10 @@ class CustomPlotterTool(BaseTool):
41
38
  description: str = "A tool to make custom plots of the simulation results"
42
39
  args_schema: Type[BaseModel] = CustomPlotterInput
43
40
  response_format: str = "content_and_artifact"
44
- response_format: str = "content_and_artifact"
45
41
 
46
42
  def _run(self,
47
43
  question: str,
44
+ simulation_name: str,
48
45
  state: Annotated[dict, InjectedState]
49
46
  ) -> Tuple[str, Union[None, List[str]]]:
50
47
  """
@@ -53,17 +50,24 @@ class CustomPlotterTool(BaseTool):
53
50
  Args:
54
51
  question (str): The question about the custom plot.
55
52
  state (dict): The state of the graph.
56
- question (str): The question about the custom plot.
57
- state (dict): The state of the graph.
58
53
 
59
54
  Returns:
60
55
  str: The answer to the question
61
56
  """
62
57
  logger.log(logging.INFO, "Calling custom_plotter tool %s", question)
63
- # Check if the simulation results are available
64
- # if 'dic_simulated_data' not in state:
65
- # return "Please run the simulation first before plotting the figure.", None
66
- df = pd.DataFrame.from_dict(state['dic_simulated_data'])
58
+ dic_simulated_data = {}
59
+ for data in state["dic_simulated_data"]:
60
+ for key in data:
61
+ if key not in dic_simulated_data:
62
+ dic_simulated_data[key] = []
63
+ dic_simulated_data[key] += [data[key]]
64
+ # Create a pandas dataframe from the dictionary
65
+ df = pd.DataFrame.from_dict(dic_simulated_data)
66
+ # Get the simulated data for the current tool call
67
+ df = pd.DataFrame(
68
+ df[df['name'] == simulation_name]['data'].iloc[0]
69
+ )
70
+ # df = pd.DataFrame.from_dict(state['dic_simulated_data'])
67
71
  species_names = df.columns.tolist()
68
72
  # Exclude the time column
69
73
  species_names.remove('Time')
@@ -76,7 +80,8 @@ class CustomPlotterTool(BaseTool):
76
80
  A list of species based on user question.
77
81
  """
78
82
  relevant_species: Union[None, List[Literal[*species_names]]] = Field(
79
- description="List of species based on user question. If no relevant species are found, it will be None.")
83
+ description="""List of species based on user question.
84
+ If no relevant species are found, it will be None.""")
80
85
  # Create an instance of the LLM model
81
86
  llm = ChatOpenAI(model=state['llm_model'], temperature=0)
82
87
  llm_with_structured_output = llm.with_structured_output(CustomHeader)
@@ -90,5 +95,6 @@ class CustomPlotterTool(BaseTool):
90
95
  logger.info("Extracted species: %s", extracted_species)
91
96
  if len(extracted_species) == 0:
92
97
  return "No species found in the simulation results that matches the user prompt.", None
93
- content = f"Plotted custom figure with species: {', '.join(extracted_species)}"
94
- return content, extracted_species
98
+ # Include the time column
99
+ extracted_species.insert(0, 'Time')
100
+ return f"Custom plot {simulation_name}", df[extracted_species].to_dict(orient='records')
@@ -25,12 +25,12 @@ class RequestedModelInfo:
25
25
  """
26
26
  Dataclass for storing the requested model information.
27
27
  """
28
- species: bool = Field(description="Get species from the model.")
29
- parameters: bool = Field(description="Get parameters from the model.")
30
- compartments: bool = Field(description="Get compartments from the model.")
31
- units: bool = Field(description="Get units from the model.")
32
- description: bool = Field(description="Get description from the model.")
33
- name: bool = Field(description="Get name from the model.")
28
+ species: bool = Field(description="Get species from the model.", default=False)
29
+ parameters: bool = Field(description="Get parameters from the model.", default=False)
30
+ compartments: bool = Field(description="Get compartments from the model.", default=False)
31
+ units: bool = Field(description="Get units from the model.", default=False)
32
+ description: bool = Field(description="Get description from the model.", default=False)
33
+ name: bool = Field(description="Get name from the model.", default=False)
34
34
 
35
35
  class GetModelInfoInput(BaseModel):
36
36
  """
@@ -52,10 +52,10 @@ class TimeSpeciesNameConcentration:
52
52
  class RecurringData:
53
53
  """
54
54
  Dataclass for storing the species and time data
55
- on recurring basis.
55
+ on reocurring basis.
56
56
  """
57
57
  data: List[TimeSpeciesNameConcentration] = Field(
58
- description="species and time data on recurring basis",
58
+ description="species and time data on reocurring basis",
59
59
  default=None)
60
60
 
61
61
  @dataclass
@@ -68,12 +68,15 @@ class ArgumentData:
68
68
  description="species name and initial concentration data",
69
69
  default=None)
70
70
  recurring_data: RecurringData = Field(
71
- description="species and time data on recurring basis",
71
+ description="species and time data on reocurring basis",
72
72
  default=None)
73
+ simulation_name: str = Field(
74
+ description="""An AI assigned `_` separated name of
75
+ the simulation based on human query""")
73
76
 
74
77
  def add_rec_events(model_object, recurring_data):
75
78
  """
76
- Add recurring events to the model.
79
+ Add reocurring events to the model.
77
80
  """
78
81
  for row in recurring_data.data:
79
82
  tp, sn, sc = row.time, row.species_name, row.species_concentration
@@ -86,9 +89,12 @@ class SimulateModelInput(BaseModel):
86
89
  """
87
90
  Input schema for the SimulateModel tool.
88
91
  """
89
- sys_bio_model: ModelData = Field(description="model data", default=None)
90
- arg_data: ArgumentData = Field(description="time, species, and recurring data",
91
- default=None)
92
+ sys_bio_model: ModelData = Field(description="model data",
93
+ default=None)
94
+ arg_data: ArgumentData = Field(description=
95
+ """time, species, and reocurring data
96
+ as well as the simulation name""",
97
+ default=None)
92
98
  tool_call_id: Annotated[str, InjectedToolCallId]
93
99
  state: Annotated[dict, InjectedState]
94
100
 
@@ -153,12 +159,20 @@ class SimulateModelTool(BaseTool):
153
159
  interval=interval
154
160
  )
155
161
 
162
+ dic_simulated_data = {
163
+ 'name': arg_data.simulation_name,
164
+ 'source': sys_bio_model.biomodel_id if sys_bio_model.biomodel_id else 'upload',
165
+ 'tool_call_id': tool_call_id,
166
+ 'data': df.to_dict()
167
+ }
168
+
156
169
  # Prepare the dictionary of updated state for the model
157
170
  dic_updated_state_for_model = {}
158
171
  for key, value in {
159
- "model_id": [sys_bio_model.biomodel_id],
160
- "sbml_file_path": [sbml_file_path],
161
- }.items():
172
+ "model_id": [sys_bio_model.biomodel_id],
173
+ "sbml_file_path": [sbml_file_path],
174
+ "dic_simulated_data": [dic_simulated_data],
175
+ }.items():
162
176
  if value:
163
177
  dic_updated_state_for_model[key] = value
164
178
 
@@ -166,11 +180,11 @@ class SimulateModelTool(BaseTool):
166
180
  return Command(
167
181
  update=dic_updated_state_for_model|{
168
182
  # update the state keys
169
- "dic_simulated_data": df.to_dict(),
183
+ # "dic_simulated_data": df.to_dict(),
170
184
  # update the message history
171
185
  "messages": [
172
186
  ToolMessage(
173
- content="Simulation results are ready.",
187
+ content=f"Simulation results of {arg_data.simulation_name}",
174
188
  tool_call_id=tool_call_id
175
189
  )
176
190
  ],
@@ -0,0 +1,5 @@
1
+ """
2
+ This file is used to import all the modules in the package.
3
+ """
4
+
5
+ from . import agents, config, state, tests, tools
@@ -0,0 +1,6 @@
1
+ '''
2
+ This file is used to import all the modules in the package.
3
+ '''
4
+
5
+ from . import main_agent
6
+ from . import s2_agent
@@ -0,0 +1,130 @@
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Main agent for the talk2competitors app.
5
+ """
6
+
7
+ import logging
8
+ from typing import Callable, Literal
9
+ from dotenv import load_dotenv
10
+ from langchain_core.language_models.chat_models import BaseChatModel
11
+ from langchain_core.messages import AIMessage
12
+ from langchain_openai import ChatOpenAI
13
+ from langgraph.checkpoint.memory import MemorySaver
14
+ from langgraph.graph import END, START, StateGraph
15
+ from langgraph.types import Command
16
+ from ..agents import s2_agent
17
+ from ..config.config import config
18
+ from ..state.state_talk2competitors import Talk2Competitors
19
+
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ load_dotenv()
24
+
25
def make_supervisor_node(llm: BaseChatModel) -> Callable[[Talk2Competitors], Command]:
    """
    Create the supervisor node for the main talk2competitors graph.

    The returned node routes on keywords in the latest message. If the
    message mentions "search", "paper", or "find" (case-insensitive),
    control is handed to the ``s2_agent`` node. Otherwise the LLM answers
    directly (with ``config.MAIN_AGENT_PROMPT`` as the system prompt) and
    the graph ends.

    Args:
        llm (BaseChatModel): The language model used to answer requests that
            are not routed to the S2 agent.

    Returns:
        Callable[[Talk2Competitors], Command]: The supervisor node function.
    """
    # Keywords that send a request to the Semantic Scholar sub-agent.
    s2_keywords = ("search", "paper", "find")

    def supervisor_node(state: Talk2Competitors) -> Command[Literal["s2_agent", "__end__"]]:
        """
        Supervisor node that routes to appropriate sub-agents.

        Args:
            state (Talk2Competitors): The current state of the conversation.

        Returns:
            Command[Literal["s2_agent", "__end__"]]: The command to execute next.
        """
        logger.info("Supervisor node called")

        last_content = state["messages"][-1].content
        # Message content may be a list of content parts rather than a
        # plain string; fall back to its string form for keyword matching.
        if not isinstance(last_content, str):
            last_content = str(last_content)
        last_content = last_content.lower()

        if any(kw in last_content for kw in s2_keywords):
            # The sub-agent produces the answer, so the supervisor LLM is
            # not called here (its response would be discarded).
            return Command(
                goto="s2_agent",
                update={
                    "messages": state["messages"],
                    "is_last_step": False,
                    "current_agent": "s2_agent",
                },
            )

        messages = [{"role": "system", "content": config.MAIN_AGENT_PROMPT}] + state[
            "messages"
        ]
        response = llm.invoke(messages)
        return Command(
            goto=END,
            update={
                "messages": state["messages"]
                + [AIMessage(content=response.content)],
                "is_last_step": True,
                "current_agent": None,
            },
        )

    return supervisor_node
83
+
84
def get_app(thread_id: str, llm_model='gpt-4o-mini') -> StateGraph:
    """
    Build and compile the hierarchical talk2competitors graph.

    Execution starts at a ``supervisor`` node. That node either answers the
    user directly or hands off to an ``s2_agent`` node, which wraps the
    Semantic Scholar sub-agent.

    Args:
        thread_id (str): Thread ID forwarded to the S2 sub-agent factory.
        llm_model (str): Name of the OpenAI chat model used by both the
            supervisor and the S2 sub-agent.

    Returns:
        The compiled graph returned by ``StateGraph.compile``, backed by an
        in-memory checkpointer. NOTE(review): the ``StateGraph`` return
        annotation does not match this compiled object.
    """
    def call_s2_agent(state: Talk2Competitors) -> Command[Literal["__end__"]]:
        """
        Run the S2 sub-agent on the current state, then end the graph.

        Args:
            state (Talk2Competitors): The current state of the conversation.

        Returns:
            Command[Literal["__end__"]]: Ends the graph. The update carries
            the sub-agent's messages and papers (an empty list if it
            returned none).
        """
        logger.info("Calling S2 agent")
        # A fresh sub-agent graph is built on every call.
        s2_app = s2_agent.get_app(thread_id, llm_model)
        result = s2_app.invoke(state)
        logger.info("S2 agent completed")
        return Command(
            goto=END,
            update={
                "messages": result["messages"],
                "papers": result.get("papers", []),
                "is_last_step": True,
                "current_agent": "s2_agent",
            },
        )

    graph = StateGraph(Talk2Competitors)
    supervisor_llm = ChatOpenAI(model=llm_model, temperature=0)
    graph.add_node("supervisor", make_supervisor_node(supervisor_llm))
    graph.add_node("s2_agent", call_s2_agent)

    # Entry goes to the supervisor; the S2 branch always terminates.
    graph.add_edge(START, "supervisor")
    graph.add_edge("s2_agent", END)

    compiled = graph.compile(checkpointer=MemorySaver())
    logger.info("Main agent workflow compiled")
    return compiled
@@ -0,0 +1,75 @@
1
+ #/usr/bin/env python3
2
+
3
+ '''
4
+ Agent for interacting with Semantic Scholar
5
+ '''
6
+
7
+ import logging
8
+ from dotenv import load_dotenv
9
+ from langchain_openai import ChatOpenAI
10
+ from langgraph.graph import START, StateGraph
11
+ from langgraph.prebuilt import create_react_agent
12
+ from langgraph.checkpoint.memory import MemorySaver
13
+ from ..config.config import config
14
+ from ..state.state_talk2competitors import Talk2Competitors
15
+ # from ..tools.s2 import s2_tools
16
+ from ..tools.s2.search import search_tool
17
+ from ..tools.s2.display_results import display_results
18
+ from ..tools.s2.single_paper_rec import get_single_paper_recommendations
19
+ from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
20
+
21
+ load_dotenv()
22
+
23
+ # Initialize logger
24
+ logging.basicConfig(level=logging.INFO)
25
+ logger = logging.getLogger(__name__)
26
+
27
def get_app(uniq_id, llm_model='gpt-4o-mini'):
    '''
    Build and compile the Semantic Scholar (S2) agent graph.

    The graph has a single node, ``agent_s2``. It invokes a ReAct agent
    equipped with the S2 search, display, and recommendation tools.

    Args:
        uniq_id: Thread ID placed in the config when the ReAct agent is invoked.
        llm_model (str): Name of the OpenAI chat model driving the ReAct agent.

    Returns:
        The compiled graph, checkpointed with an in-memory saver.
    '''
    s2_tools = [
        search_tool,
        display_results,
        get_single_paper_recommendations,
        get_multi_paper_recommendations,
    ]

    # The ReAct agent keeps its own in-memory checkpointer.
    react_agent = create_react_agent(
        ChatOpenAI(model=llm_model, temperature=0),
        tools=s2_tools,
        state_schema=Talk2Competitors,
        state_modifier=config.S2_AGENT_PROMPT,
        checkpointer=MemorySaver(),
    )

    def agent_s2_node(state: Talk2Competitors):
        '''
        Invoke the ReAct agent on the current state under ``uniq_id``.
        '''
        logger.log(logging.INFO, "Creating Agent_S2 node with thread_id %s", uniq_id)
        return react_agent.invoke(state, {"configurable": {"thread_id": uniq_id}})

    graph = StateGraph(Talk2Competitors)
    # One node only, wired directly from START as the entry point.
    graph.add_node("agent_s2", agent_s2_node)
    graph.add_edge(START, "agent_s2")

    # Persist the outer graph's state between runs in memory.
    compiled = graph.compile(checkpointer=MemorySaver())
    logger.log(logging.INFO, "Compiled the graph")
    return compiled
@@ -0,0 +1,5 @@
1
+ """
2
+ This package contains configuration settings and prompts used by various AI agents
3
+ """
4
+
5
+ from . import config