aiagents4pharma 1.8.0__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiagents4pharma/__init__.py +9 -5
- aiagents4pharma/configs/__init__.py +5 -0
- aiagents4pharma/configs/config.yaml +4 -0
- aiagents4pharma/configs/talk2biomodels/__init__.py +6 -0
- aiagents4pharma/configs/talk2biomodels/agents/__init__.py +5 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/__init__.py +3 -0
- aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +14 -0
- aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
- aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/agents/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/agents/t2b_agent.py +96 -0
- aiagents4pharma/talk2biomodels/api/__init__.py +6 -0
- aiagents4pharma/talk2biomodels/api/kegg.py +83 -0
- aiagents4pharma/talk2biomodels/api/ols.py +72 -0
- aiagents4pharma/talk2biomodels/api/uniprot.py +35 -0
- aiagents4pharma/talk2biomodels/models/basico_model.py +29 -32
- aiagents4pharma/talk2biomodels/models/sys_bio_model.py +9 -6
- aiagents4pharma/talk2biomodels/states/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +41 -0
- aiagents4pharma/talk2biomodels/tests/__init__.py +3 -0
- aiagents4pharma/talk2biomodels/tests/test_api.py +57 -0
- aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
- aiagents4pharma/talk2biomodels/tests/test_basico_model.py +54 -0
- aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +171 -0
- aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
- aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
- aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
- aiagents4pharma/talk2biomodels/tests/test_query_article.py +76 -0
- aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
- aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
- aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
- aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +63 -0
- aiagents4pharma/talk2biomodels/tools/__init__.py +5 -0
- aiagents4pharma/talk2biomodels/tools/ask_question.py +61 -18
- aiagents4pharma/talk2biomodels/tools/custom_plotter.py +20 -14
- aiagents4pharma/talk2biomodels/tools/get_annotation.py +304 -0
- aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +11 -9
- aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
- aiagents4pharma/talk2biomodels/tools/load_biomodel.py +0 -1
- aiagents4pharma/talk2biomodels/tools/parameter_scan.py +287 -0
- aiagents4pharma/talk2biomodels/tools/query_article.py +59 -0
- aiagents4pharma/talk2biomodels/tools/simulate_model.py +35 -90
- aiagents4pharma/talk2biomodels/tools/steady_state.py +167 -0
- aiagents4pharma/talk2cells/tests/scp_agent/test_scp_agent.py +23 -0
- aiagents4pharma/talk2cells/tools/scp_agent/__init__.py +6 -0
- aiagents4pharma/talk2cells/tools/scp_agent/display_studies.py +25 -0
- aiagents4pharma/talk2cells/tools/scp_agent/search_studies.py +79 -0
- aiagents4pharma/talk2competitors/__init__.py +5 -0
- aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
- aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
- aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
- aiagents4pharma/talk2competitors/config/__init__.py +5 -0
- aiagents4pharma/talk2competitors/config/config.py +110 -0
- aiagents4pharma/talk2competitors/state/__init__.py +5 -0
- aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
- aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
- aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
- aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
- aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
- aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
- aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
- aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
- aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +141 -0
- aiagents4pharma/talk2knowledgegraphs/__init__.py +2 -1
- aiagents4pharma/talk2knowledgegraphs/tests/__init__.py +0 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_biobridge_primekg.py +242 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_dataset.py +29 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_primekg.py +73 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_datasets_starkqa_primekg.py +116 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_embeddings.py +47 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +45 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_sentencetransformer.py +40 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +39 -0
- aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +117 -0
- aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +5 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +36 -0
- aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +123 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/METADATA +44 -25
- aiagents4pharma-1.15.0.dist-info/RECORD +102 -0
- aiagents4pharma-1.8.0.dist-info/RECORD +0 -35
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/LICENSE +0 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/WHEEL +0 -0
- {aiagents4pharma-1.8.0.dist-info → aiagents4pharma-1.15.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,167 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
Tool for steady state analysis.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Type, Annotated
|
9
|
+
import basico
|
10
|
+
from pydantic import BaseModel, Field
|
11
|
+
from langgraph.types import Command
|
12
|
+
from langgraph.prebuilt import InjectedState
|
13
|
+
from langchain_core.tools import BaseTool
|
14
|
+
from langchain_core.messages import ToolMessage
|
15
|
+
from langchain_core.tools.base import InjectedToolCallId
|
16
|
+
from .load_biomodel import ModelData, load_biomodel
|
17
|
+
from .load_arguments import ArgumentData, add_rec_events
|
18
|
+
|
19
|
+
# Module logger. Note: basicConfig configures the *root* logger as a side
# effect of importing this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
22
|
+
|
23
|
+
def run_steady_state(model_object,
                     dic_species_to_be_analyzed_before_experiment):
    """
    Bring the model to steady state and tabulate the species values.

    The user-supplied initial conditions are written into the model first,
    then basico's steady-state task is executed on the model's COPASI object.

    Args:
        model_object: The model object; must provide ``update_parameters``
            and a ``copasi_model`` attribute.
        dic_species_to_be_analyzed_before_experiment: Mapping of species
            names to the values to set before the analysis.

    Returns:
        DataFrame: One row per species with the steady state results
        (columns renamed to ``species_name``, ``steady_state_concentration``
        and ``steady_state_transition_time``).

    Raises:
        ValueError: If basico reports that no steady state was found.
    """
    # Apply the user's initial conditions before solving.
    model_object.update_parameters(dic_species_to_be_analyzed_before_experiment)
    logger.log(logging.INFO, "Running steady state analysis")
    copasi_model = model_object.copasi_model
    # A return value of 0 from run_steadystate is treated as "no steady state".
    if basico.task_steadystate.run_steadystate(model=copasi_model) == 0:
        logger.error("Steady state analysis failed")
        raise ValueError("A steady state was not found")
    logger.log(logging.INFO, "Steady state analysis successful")

    column_renames = {
        'name': 'species_name',
        'concentration': 'steady_state_concentration',
        'transition_time': 'steady_state_transition_time',
    }
    # Columns that are not useful in the reported table.
    unused_columns = [
        'initial_particle_number',
        'initial_expression',
        'expression',
        'particle_number',
        'type',
        'particle_number_rate',
        'key',
        'sbml_id',
        'display_name',
    ]
    df_steady_state = (
        basico.model_info.get_species(model=copasi_model)
        .reset_index()
        .rename(columns=column_renames)
        .drop(columns=unused_columns)
    )
    logger.log(logging.INFO, "Steady state results with shape %s", df_steady_state.shape)
    return df_steady_state
|
73
|
+
|
74
|
+
class SteadyStateInput(BaseModel):
    """
    Input schema for the steady state tool.
    """
    # The docstring above is kept verbatim: Pydantic may expose it as the
    # schema description. Field notes:
    # - sys_bio_model / arg_data are optional and default to None.
    # - tool_call_id / state are injected at call time via
    #   InjectedToolCallId / InjectedState.
    sys_bio_model: ModelData = Field(default=None, description="model data")
    arg_data: ArgumentData = Field(
        default=None,
        description=("time, species, and reocurring data"
                     " that must be set before the steady state analysis"
                     " as well as the experiment name"),
    )
    tool_call_id: Annotated[str, InjectedToolCallId]
    state: Annotated[dict, InjectedState]
|
86
|
+
|
87
|
+
# Note: It's important that every field has type hints. BaseTool is a
# Pydantic class and not having type hints can lead to unexpected behavior.
class SteadyStateTool(BaseTool):
    """
    Tool to bring a model to steady state.
    """
    name: str = "steady_state"
    description: str = "A tool to bring a model to steady state."
    args_schema: Type[BaseModel] = SteadyStateInput

    def _run(self,
             tool_call_id: Annotated[str, InjectedToolCallId],
             state: Annotated[dict, InjectedState],
             sys_bio_model: ModelData = None,
             arg_data: ArgumentData = None
             ) -> Command:
        """
        Run the steady state analysis and record the result in the graph state.

        Args:
            tool_call_id (str): The tool call ID. This is injected by the system.
            state (dict): The state of the tool. The last entry of its
                'sbml_file_path' list (if any) is used as the model file.
            sys_bio_model (ModelData): The model data. Optional.
            arg_data (ArgumentData): Initial species values, recurring events
                and the experiment name. Optional.

        Returns:
            Command: State update with model_id, sbml_file_path,
            dic_steady_state_data and a ToolMessage.

        Raises:
            ValueError: Propagated from run_steady_state when no steady state
                is found.
        """
        logger.log(logging.INFO, "Calling the steady_state tool %s, %s",
                   sys_bio_model, arg_data)
        # Use the most recent SBML file path, if any. .get() avoids a KeyError
        # when the key has never been written to the state.
        sbml_file_paths = state.get('sbml_file_path') or []
        sbml_file_path = sbml_file_paths[-1] if len(sbml_file_paths) > 0 else None
        model_object = load_biomodel(sys_bio_model,
                                     sbml_file_path=sbml_file_path)
        # Both optional arguments default to None; read their fields
        # defensively instead of failing with AttributeError.
        biomodel_id = sys_bio_model.biomodel_id if sys_bio_model is not None else None
        experiment_name = arg_data.experiment_name if arg_data is not None else None
        # Species values applied to the model before the analysis.
        dic_species_to_be_analyzed_before_experiment = {}
        if arg_data:
            species_data = arg_data.species_to_be_analyzed_before_experiment
            if species_data is not None:
                dic_species_to_be_analyzed_before_experiment = dict(
                    zip(species_data.species_name,
                        species_data.species_concentration))
            # Add reocurring events (if any) to the model
            if arg_data.reocurring_data is not None:
                add_rec_events(model_object, arg_data.reocurring_data)
        # Run the steady state analysis
        df_steady_state = run_steady_state(model_object,
                                           dic_species_to_be_analyzed_before_experiment)
        # Steady state record stored in the graph state
        dic_steady_state_data = {
            'name': experiment_name,
            'source': biomodel_id if biomodel_id else 'upload',
            'tool_call_id': tool_call_id,
            'data': df_steady_state.to_dict(orient='records')
        }
        # Every value is a one-element list, so all three keys are always
        # included in the update.
        dic_updated_state_for_model = {
            "model_id": [biomodel_id],
            "sbml_file_path": [sbml_file_path],
            "dic_steady_state_data": [dic_steady_state_data],
        }
        # Return the updated state together with the message history update
        return Command(
            update=dic_updated_state_for_model | {
                "messages": [
                    ToolMessage(
                        content=f"Steady state analysis of"
                                f" {experiment_name}"
                                " was successful.",
                        tool_call_id=tool_call_id
                    )
                ],
            }
        )
|
@@ -0,0 +1,23 @@
|
|
1
|
+
'''
|
2
|
+
Test cases for the SCP agent (search_studies tool).
|
3
|
+
'''
|
4
|
+
|
5
|
+
# from ..tools.search_studies import search_studies
|
6
|
+
from aiagents4pharma.talk2cells.agents.scp_agent import get_app
|
7
|
+
from langchain_core.messages import HumanMessage
|
8
|
+
|
9
|
+
def test_agent_scp():
    '''
    Ask the SCP agent to search for studies and check that it answers with text.
    '''
    thread_id = 12345
    scp_app = get_app(thread_id)
    result = scp_app.invoke(
        {"messages": [HumanMessage(content="Search for studies on Crohns Disease.")]},
        config={"configurable": {"thread_id": thread_id}},
    )
    # The last message in the returned history is the assistant's reply.
    last_message = result["messages"][-1]
    assert isinstance(last_message.content, str)
|
@@ -0,0 +1,25 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
'''
|
4
|
+
This tool is used to display the table of studies.
|
5
|
+
'''
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated
|
9
|
+
from langchain_core.tools import tool
|
10
|
+
from langgraph.prebuilt import InjectedState
|
11
|
+
|
12
|
+
# Module logger. Note: basicConfig configures the *root* logger as a side
# effect of importing this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
15
|
+
|
16
|
+
@tool('display_studies')
def display_studies(state: Annotated[dict, InjectedState]):
    """
    Display the table of studies.

    Args:
        state (dict): The state of the agent.
    """
    # The docstring above is the tool description given to the LLM, so it is
    # kept unchanged; implementation notes live in comments.
    logger.log(logging.INFO, "Calling the tool display_studies")
    # 'search_table' is written by search_studies; if no search has run yet,
    # return a hint instead of raising KeyError.
    return state.get("search_table",
                     "No studies to display. Please search for studies first.")
|
@@ -0,0 +1,79 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
'''
|
4
|
+
A tool to fetch studies from the Single Cell Portal.
|
5
|
+
'''
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Annotated
|
9
|
+
import requests
|
10
|
+
from langchain_core.tools import tool
|
11
|
+
from langchain_core.tools.base import InjectedToolCallId
|
12
|
+
from langchain_core.messages import ToolMessage
|
13
|
+
from langgraph.types import Command
|
14
|
+
import pandas as pd
|
15
|
+
|
16
|
+
# Module logger. Note: basicConfig configures the *root* logger as a side
# effect of importing this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
19
|
+
|
20
|
+
def _fetch_scp_studies(search_term, max_attempts=3):
    """Query the SCP search API, retrying a bounded number of times on non-200 replies."""
    scp_endpoint = 'https://singlecell.broadinstitute.org/single_cell/api/v1/search?type=study'
    params = {'terms': search_term}
    search_response = None
    for _ in range(max_attempts):
        # NOTE(review): verify=False disables TLS certificate verification.
        # Kept for compatibility with the existing deployment; consider enabling it.
        search_response = requests.get(scp_endpoint,
                                       params=params,
                                       timeout=10,
                                       verify=False)
        logger.log(logging.INFO, "Status code %s received from SCP",
                   search_response.status_code)
        if search_response.status_code == 200:
            return search_response.json().get('studies', [])
    raise requests.HTTPError(
        f"SCP search failed after {max_attempts} attempts "
        f"(last status code {search_response.status_code})",
        response=search_response)


def _studies_to_dataframe(studies, limit):
    """Build the display table: first `limit` studies, linked names, row numbers."""
    selected_columns = ["study_source", "name", "study_url", "gene_count", "cell_count"]
    # Passing `columns` selects the wanted keys and still yields an (empty)
    # frame with those columns when the search returned no studies.
    df = pd.DataFrame(list(studies)[:limit], columns=selected_columns)
    # Convert column 'name' into clickable hyperlinks built from 'study_url'
    scp_api_url = 'https://singlecell.broadinstitute.org'
    df['name'] = [
        f"<a href=\"{scp_api_url}/{url}\">{name}</a>"
        for url, name in zip(df['study_url'], df['name'])
    ]
    # Exclude the column 'study_url' from the dataframe
    df = df.drop(columns=['study_url'])
    # Add a new column at the beginning of the dataframe with row numbers
    df.insert(0, 'S/N', range(1, 1 + len(df)))
    return df


@tool('search_studies')
def search_studies(search_term: str,
                   tool_call_id: Annotated[str, InjectedToolCallId],
                   limit: int = 5):
    """
    Fetch studies from single cell portal

    Args:
        search_term (str): The search term to use. Example: "COVID-19", "cancer", etc.
        limit (int): The number of papers to return. Default is 5.

    """
    # The docstring above is the tool description given to the LLM and is kept
    # unchanged. Raises requests.HTTPError if SCP never answers with status 200.
    logger.log(logging.INFO, "Calling the tool search_studies")
    df = _studies_to_dataframe(_fetch_scp_studies(search_term), limit)
    # Update the state key 'search_table' with the dataframe in markdown format
    return Command(
        update={
            # update the state keys
            "search_table": df.to_markdown(tablefmt="grid"),
            # update the message history
            "messages": [
                ToolMessage(
                    f"Successfully fetched {len(df)} studies on {search_term}.",
                    tool_call_id=tool_call_id
                )
            ],
        }
    )
|
@@ -0,0 +1,130 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
"""
|
4
|
+
Main agent for the talk2competitors app.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from typing import Literal
|
9
|
+
from dotenv import load_dotenv
|
10
|
+
from langchain_core.language_models.chat_models import BaseChatModel
|
11
|
+
from langchain_core.messages import AIMessage
|
12
|
+
from langchain_openai import ChatOpenAI
|
13
|
+
from langgraph.checkpoint.memory import MemorySaver
|
14
|
+
from langgraph.graph import END, START, StateGraph
|
15
|
+
from langgraph.types import Command
|
16
|
+
from ..agents import s2_agent
|
17
|
+
from ..config.config import config
|
18
|
+
from ..state.state_talk2competitors import Talk2Competitors
|
19
|
+
|
20
|
+
# Module logger. Note: basicConfig configures the *root* logger on import.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Load environment variables (e.g. API keys) from a .env file, if present.
load_dotenv()
|
24
|
+
|
25
|
+
def make_supervisor_node(llm: BaseChatModel):
    """
    Creates a supervisor node following LangGraph patterns.

    Args:
        llm (BaseChatModel): The language model used to answer queries that
            are not routed to the S2 agent.

    Returns:
        Callable[[Talk2Competitors], Command]: The supervisor node function.
    """
    # Queries whose last message mentions any of these words go to the S2 agent.
    routing_keywords = ("search", "paper", "find")

    def supervisor_node(state: Talk2Competitors) -> Command[Literal["s2_agent", "__end__"]]:
        """
        Supervisor node that routes to appropriate sub-agents.

        Args:
            state (Talk2Competitors): The current state of the conversation.

        Returns:
            Command[Literal["s2_agent", "__end__"]]: The command to execute next.
        """
        logger.info("Supervisor node called")

        last_content = state["messages"][-1].content
        # Message content may be a list of content parts rather than a string.
        query = last_content if isinstance(last_content, str) else str(last_content)

        if any(kw in query.lower() for kw in routing_keywords):
            return Command(
                goto="s2_agent",
                update={
                    "messages": state["messages"],
                    "is_last_step": False,
                    "current_agent": "s2_agent",
                },
            )

        # Only call the LLM when the supervisor answers directly, so routed
        # turns do not pay for a completion that would be thrown away.
        messages = [{"role": "system", "content": config.MAIN_AGENT_PROMPT}] + state[
            "messages"
        ]
        response = llm.invoke(messages)
        return Command(
            goto=END,
            update={
                "messages": state["messages"]
                + [AIMessage(content=response.content)],
                "is_last_step": True,
                "current_agent": None,
            },
        )

    return supervisor_node
|
83
|
+
|
84
|
+
def get_app(thread_id: str, llm_model='gpt-4o-mini') -> StateGraph:
    """
    Build and compile the main (supervisor + S2) LangGraph app.

    Args:
        thread_id (str): The thread ID for the conversation; forwarded to the
            S2 sub-agent.
        llm_model (str): OpenAI chat model name used by both agents.

    Returns:
        The compiled langraph app (checkpointed in memory).
    """
    def call_s2_agent(state: Talk2Competitors) -> Command[Literal["__end__"]]:
        """
        Run the S2 sub-agent on the current state and end the turn.

        Args:
            state (Talk2Competitors): The current state of the conversation.

        Returns:
            Command[Literal["__end__"]]: Update with the sub-agent's messages
            and papers.
        """
        logger.info("Calling S2 agent")
        # A fresh S2 app is built on every call.
        s2_result = s2_agent.get_app(thread_id, llm_model).invoke(state)
        logger.info("S2 agent completed")
        # NOTE(review): the fallback for "papers" is a list, while the state
        # schema declares papers as a dict; confirm which is intended.
        s2_update = {
            "messages": s2_result["messages"],
            "papers": s2_result.get("papers", []),
            "is_last_step": True,
            "current_agent": "s2_agent",
        }
        return Command(goto=END, update=s2_update)

    supervisor_llm = ChatOpenAI(model=llm_model, temperature=0)
    workflow = StateGraph(Talk2Competitors)
    workflow.add_node("supervisor", make_supervisor_node(supervisor_llm))
    workflow.add_node("s2_agent", call_s2_agent)

    # START -> supervisor; the supervisor's Command picks s2_agent or END.
    for source, target in ((START, "supervisor"), ("s2_agent", END)):
        workflow.add_edge(source, target)

    compiled_app = workflow.compile(checkpointer=MemorySaver())
    logger.info("Main agent workflow compiled")
    return compiled_app
|
@@ -0,0 +1,75 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
|
3
|
+
'''
|
4
|
+
Agent for interacting with Semantic Scholar
|
5
|
+
'''
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from dotenv import load_dotenv
|
9
|
+
from langchain_openai import ChatOpenAI
|
10
|
+
from langgraph.graph import START, StateGraph
|
11
|
+
from langgraph.prebuilt import create_react_agent
|
12
|
+
from langgraph.checkpoint.memory import MemorySaver
|
13
|
+
from ..config.config import config
|
14
|
+
from ..state.state_talk2competitors import Talk2Competitors
|
15
|
+
# from ..tools.s2 import s2_tools
|
16
|
+
from ..tools.s2.search import search_tool
|
17
|
+
from ..tools.s2.display_results import display_results
|
18
|
+
from ..tools.s2.single_paper_rec import get_single_paper_recommendations
|
19
|
+
from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
|
20
|
+
|
21
|
+
# Load environment variables (e.g. API keys) from a .env file, if present.
load_dotenv()

# Module logger. Note: basicConfig configures the *root* logger on import.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
26
|
+
|
27
|
+
def get_app(uniq_id, llm_model='gpt-4o-mini'):
    '''
    Build and compile the Semantic Scholar (S2) agent graph.

    The graph has a single node that runs a ReAct agent equipped with the
    S2 search, display and recommendation tools.

    Args:
        uniq_id: Thread id used when invoking the inner ReAct agent.
        llm_model (str): OpenAI chat model name.

    Returns:
        The compiled langraph app (checkpointed in memory).
    '''
    s2_tools = [
        search_tool,
        display_results,
        get_single_paper_recommendations,
        get_multi_paper_recommendations,
    ]
    react_agent = create_react_agent(
        ChatOpenAI(model=llm_model, temperature=0),
        tools=s2_tools,
        state_schema=Talk2Competitors,
        state_modifier=config.S2_AGENT_PROMPT,
        checkpointer=MemorySaver()
    )

    def agent_s2_node(state: Talk2Competitors):
        '''
        Run the ReAct agent on the current state under this app's thread id.
        '''
        logger.log(logging.INFO, "Creating Agent_S2 node with thread_id %s", uniq_id)
        return react_agent.invoke(state, {"configurable": {"thread_id": uniq_id}})

    # Single-node graph: START -> agent_s2.
    workflow = StateGraph(Talk2Competitors)
    workflow.add_node("agent_s2", agent_s2_node)
    workflow.add_edge(START, "agent_s2")

    # The outer graph gets its own in-memory checkpointer.
    app = workflow.compile(checkpointer=MemorySaver())
    logger.log(logging.INFO, "Compiled the graph")

    return app
|
@@ -0,0 +1,110 @@
|
|
1
|
+
"""Configuration module for AI agents handling paper searches and recommendations."""
|
2
|
+
|
3
|
+
|
4
|
+
# pylint: disable=R0903
class Config:
    """Configuration class containing prompts for AI agents.

    This class stores prompt templates used by various AI agents in the system,
    particularly for academic paper searches and recommendations.

    Attributes are plain string constants; the module exposes a shared
    instance as ``config``.
    """

    # System prompt prepended to the conversation by the main-agent supervisor
    # when it answers with the LLM directly.
    # NOTE(review): the text says "always use this exact format:" but no format
    # template follows (it goes straight to "Remember to:"); confirm whether a
    # template was dropped.
    # NOTE(review): the prompt names a "semantic_scholar_agent" tool, while the
    # main graph's node is "s2_agent" and routing there is keyword-based;
    # confirm the naming is intentional.
    MAIN_AGENT_PROMPT = (
        "You are a supervisory AI agent that routes user queries to specialized tools.\n"
        "Your task is to select the most appropriate tool based on the user's request.\n\n"
        "Available tools and their capabilities:\n\n"
        "1. semantic_scholar_agent:\n"
        " - Search for academic papers and research\n"
        " - Get paper recommendations\n"
        " - Find similar papers\n"
        " USE FOR: Any queries about finding papers, academic research, "
        "or getting paper recommendations\n\n"
        "ROUTING GUIDELINES:\n\n"
        "ALWAYS route to semantic_scholar_agent for:\n"
        "- Finding academic papers\n"
        "- Searching research topics\n"
        "- Getting paper recommendations\n"
        "- Finding similar papers\n"
        "- Any query about academic literature\n\n"
        "Approach:\n"
        "1. Identify the core need in the user's query\n"
        "2. Select the most appropriate tool based on the guidelines above\n"
        "3. If unclear, ask for clarification\n"
        "4. For multi-step tasks, focus on the immediate next step\n\n"
        "Remember:\n"
        "- Be decisive in your tool selection\n"
        "- Focus on the immediate task\n"
        "- Default to semantic_scholar_agent for any paper-finding tasks\n"
        "- Ask for clarification if the request is ambiguous\n\n"
        "When presenting paper search results, always use this exact format:\n\n"
        "Remember to:\n"
        "- Always remember to add the url\n"
        "- Put URLs on the title line itself as markdown\n"
        "- Maintain consistent spacing and formatting"
    )

    # System prompt (state_modifier) for the S2 ReAct agent.
    # NOTE(review): the prompt refers to "search_papers",
    # "single_paper_recommendations" and "multi_paper_recommendations" as well
    # as the get_* names; confirm these match the registered tool names.
    S2_AGENT_PROMPT = (
        "You are a specialized academic research assistant with access to the following tools:\n\n"
        "1. search_papers:\n"
        " USE FOR: General paper searches\n"
        " - Enhances search terms automatically\n"
        " - Adds relevant academic keywords\n"
        " - Focuses on recent research when appropriate\n\n"
        "2. get_single_paper_recommendations:\n"
        " USE FOR: Finding papers similar to a specific paper\n"
        " - Takes a single paper ID\n"
        " - Returns related papers\n\n"
        "3. get_multi_paper_recommendations:\n"
        " USE FOR: Finding papers similar to multiple papers\n"
        " - Takes multiple paper IDs\n"
        " - Finds papers related to all inputs\n\n"
        "GUIDELINES:\n\n"
        "For paper searches:\n"
        "- Enhance search terms with academic language\n"
        "- Include field-specific terminology\n"
        '- Add "recent" or "latest" when appropriate\n'
        "- Keep queries focused and relevant\n\n"
        "For paper recommendations:\n"
        "- Identify paper IDs (40-character hexadecimal strings)\n"
        "- Use single_paper_recommendations for one ID\n"
        "- Use multi_paper_recommendations for multiple IDs\n\n"
        "Best practices:\n"
        "1. Start with a broad search if no paper IDs are provided\n"
        "2. Look for paper IDs in user input\n"
        "3. Enhance search terms for better results\n"
        "4. Consider the academic context\n"
        "5. Be prepared to refine searches based on feedback\n\n"
        "Remember:\n"
        "- Always select the most appropriate tool\n"
        "- Enhance search queries naturally\n"
        "- Consider academic context\n"
        "- Focus on delivering relevant results\n\n"
        "IMPORTANT GUIDELINES FOR PAPER RECOMMENDATIONS:\n\n"
        "For Multiple Papers:\n"
        "- When getting recommendations for multiple papers, always use "
        "get_multi_paper_recommendations tool\n"
        "- DO NOT call get_single_paper_recommendations multiple times\n"
        "- Always pass all paper IDs in a single call to get_multi_paper_recommendations\n"
        '- Use for queries like "find papers related to both/all papers" or '
        '"find similar papers to these papers"\n\n'
        "For Single Paper:\n"
        "- Use get_single_paper_recommendations when focusing on one specific paper\n"
        "- Pass only one paper ID at a time\n"
        '- Use for queries like "find papers similar to this paper" or '
        '"get recommendations for paper X"\n'
        "- Do not use for multiple papers\n\n"
        "Examples:\n"
        '- For "find related papers for both papers":\n'
        " ✓ Use get_multi_paper_recommendations with both paper IDs\n"
        " × Don't make multiple calls to get_single_paper_recommendations\n\n"
        '- For "find papers related to the first paper":\n'
        " ✓ Use get_single_paper_recommendations with just that paper's ID\n"
        " × Don't use get_multi_paper_recommendations\n\n"
        "Remember:\n"
        "- Be precise in identifying which paper ID to use for single recommendations\n"
        "- Don't reuse previous paper IDs unless specifically requested\n"
        "- For fresh paper recommendations, always use the original paper ID"
    )


# Shared instance imported by the agents (``from ..config.config import config``).
config = Config()
|
@@ -0,0 +1,32 @@
|
|
1
|
+
"""
|
2
|
+
This is the state file for the talk2comp agent.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Annotated, Any, Dict, Optional
|
7
|
+
|
8
|
+
from langgraph.prebuilt.chat_agent_executor import AgentState
|
9
|
+
from typing_extensions import NotRequired, Required
|
10
|
+
|
11
|
+
# Module logger. Note: basicConfig configures the *root* logger as a side
# effect of importing this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
14
|
+
|
15
|
+
|
16
|
+
def replace_dict(existing: Dict[str, Any], new: Dict[str, Any]) -> Dict[str, Any]:
    """State reducer that discards the previous value and keeps the incoming one.

    Args:
        existing: The value currently stored in the state (only logged).
        new: The value produced by the latest update.

    Returns:
        ``new``, unchanged.
    """
    logger.info("Updating existing state %s with the state dict: %s", existing, new)
    replacement = new
    return replacement
|
20
|
+
|
21
|
+
|
22
|
+
class Talk2Competitors(AgentState):
    """
    The state for the talk2comp agent, inheriting from AgentState.

    Adds the keys used by the supervisor node and the S2 sub-agent on top of
    LangGraph's AgentState (which provides the message history).
    """

    # Paper results; every update replaces the whole dict (replace_dict reducer).
    # NOTE(review): the main agent falls back to [] for "papers"; confirm the
    # intended type is dict.
    papers: Annotated[Dict[str, Any], replace_dict]
    # Rendered results table, present only after a search has populated it.
    search_table: NotRequired[str]
    next: str  # Required for routing in LangGraph
    # Name of the agent that handled the last turn ("s2_agent") or None.
    current_agent: NotRequired[Optional[str]]
    # NOTE(review): AgentState already declares is_last_step as a managed
    # value; redeclaring it as a plain bool may override that — confirm.
    is_last_step: Required[bool]  # Required field for LangGraph
    llm_model: str
|