aiagents4pharma 1.20.1__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +1 -0
  2. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +1 -0
  3. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_nim_molmim.py +64 -0
  4. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py +33 -0
  5. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py +16 -0
  6. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +1 -0
  7. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
  8. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/nim_molmim.py +54 -0
  9. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +1 -0
  10. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py +49 -0
  11. aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py +42 -0
  12. aiagents4pharma/talk2scholars/agents/main_agent.py +90 -91
  13. aiagents4pharma/talk2scholars/agents/s2_agent.py +61 -17
  14. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/main_agent/default.yaml +31 -10
  15. aiagents4pharma/talk2scholars/configs/agents/talk2scholars/s2_agent/default.yaml +8 -16
  16. aiagents4pharma/talk2scholars/configs/app/frontend/default.yaml +11 -9
  17. aiagents4pharma/talk2scholars/configs/config.yaml +1 -0
  18. aiagents4pharma/talk2scholars/configs/tools/multi_paper_recommendation/default.yaml +2 -0
  19. aiagents4pharma/talk2scholars/configs/tools/retrieve_semantic_scholar_paper_id/__init__.py +3 -0
  20. aiagents4pharma/talk2scholars/configs/tools/search/default.yaml +1 -0
  21. aiagents4pharma/talk2scholars/configs/tools/single_paper_recommendation/default.yaml +1 -0
  22. aiagents4pharma/talk2scholars/state/state_talk2scholars.py +36 -7
  23. aiagents4pharma/talk2scholars/tests/test_llm_main_integration.py +58 -0
  24. aiagents4pharma/talk2scholars/tests/test_main_agent.py +98 -122
  25. aiagents4pharma/talk2scholars/tests/test_s2_agent.py +95 -29
  26. aiagents4pharma/talk2scholars/tests/test_s2_tools.py +158 -22
  27. aiagents4pharma/talk2scholars/tools/s2/__init__.py +4 -2
  28. aiagents4pharma/talk2scholars/tools/s2/display_results.py +60 -21
  29. aiagents4pharma/talk2scholars/tools/s2/multi_paper_rec.py +35 -8
  30. aiagents4pharma/talk2scholars/tools/s2/query_results.py +61 -0
  31. aiagents4pharma/talk2scholars/tools/s2/retrieve_semantic_scholar_paper_id.py +79 -0
  32. aiagents4pharma/talk2scholars/tools/s2/search.py +34 -10
  33. aiagents4pharma/talk2scholars/tools/s2/single_paper_rec.py +39 -9
  34. {aiagents4pharma-1.20.1.dist-info → aiagents4pharma-1.22.0.dist-info}/METADATA +2 -1
  35. {aiagents4pharma-1.20.1.dist-info → aiagents4pharma-1.22.0.dist-info}/RECORD +38 -29
  36. aiagents4pharma/talk2scholars/tests/test_integration.py +0 -237
  37. {aiagents4pharma-1.20.1.dist-info → aiagents4pharma-1.22.0.dist-info}/LICENSE +0 -0
  38. {aiagents4pharma-1.20.1.dist-info → aiagents4pharma-1.22.0.dist-info}/WHEEL +0 -0
  39. {aiagents4pharma-1.20.1.dist-info → aiagents4pharma-1.22.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2knowledgegraphs/configs/__init__.py
@@ -5,3 +5,4 @@ Import all the modules in the package
  from . import agents
  from . import tools
  from . import app
+ from . import utils
aiagents4pharma/talk2knowledgegraphs/configs/config.yaml
@@ -4,4 +4,5 @@ defaults:
  - tools/subgraph_extraction: default
  - tools/subgraph_summarization: default
  - tools/graphrag_reasoning: default
+ - utils/pubchem_utils: default
  - app/frontend: default
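
The `utils/pubchem_utils` Hydra group wired in above is consumed by the new `utils/pubchem_utils.py` later in this diff. As a sanity check, a minimal sketch of composing the config with that group enabled; the override string and key name are taken from `pubchem_utils.py` below, while the `config_path` is a placeholder relative to the calling module:

    import hydra

    # Compose the package config with the new group; the override string
    # matches the one used inside pubchem_utils.py.
    with hydra.initialize(version_base=None, config_path="../configs"):
        cfg = hydra.compose(config_name="config",
                            overrides=["utils/pubchem_utils=default"])
    print(cfg.utils.pubchem_utils.drugbank_id_to_pubchem_cid_url)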
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_nim_molmim.py
@@ -0,0 +1,64 @@
+ #!/usr/bin/env python3
+
+ """
+ Test cases for utils/embeddings/nim_molmim.py
+ """
+
+ import unittest
+ from unittest.mock import patch, MagicMock
+ from ..utils.embeddings.nim_molmim import EmbeddingWithMOLMIM
+
+ class TestEmbeddingWithMOLMIM(unittest.TestCase):
+     """
+     Test cases for EmbeddingWithMOLMIM class.
+     """
+     def setUp(self):
+         self.base_url = "https://fake-nim-api.com/embeddings"
+         self.embeddings_model = EmbeddingWithMOLMIM(self.base_url)
+         self.test_texts = ["CCO", "CCC", "C=O"]
+         self.test_query = "CCO"
+         self.mock_response = {
+             "embeddings": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]
+         }
+
+     @patch("requests.post")
+     def test_embed_documents(self, mock_post):
+         '''
+         Test the embed_documents method.
+         '''
+         # Mock the response from requests.post
+         mock_post.return_value = MagicMock()
+         mock_post.return_value.json.return_value = self.mock_response
+         embeddings = self.embeddings_model.embed_documents(self.test_texts)
+         # Assertions
+         self.assertEqual(embeddings, self.mock_response["embeddings"])
+         mock_post.assert_called_once_with(
+             self.base_url,
+             headers={
+                 'accept': 'application/json',
+                 'Content-Type': 'application/json'
+             },
+             data='{"sequences": ["CCO", "CCC", "C=O"]}',
+             timeout=60
+         )
+
+     @patch("requests.post")
+     def test_embed_query(self, mock_post):
+         '''
+         Test the embed_query method.
+         '''
+         # Mock the response from requests.post
+         mock_post.return_value = MagicMock()
+         mock_post.return_value.json.return_value = {"embeddings": [[0.1, 0.2, 0.3]]}
+         embedding = self.embeddings_model.embed_query(self.test_query)
+         # Assertions
+         self.assertEqual(embedding, [[0.1, 0.2, 0.3]])
+         mock_post.assert_called_once_with(
+             self.base_url,
+             headers={
+                 'accept': 'application/json',
+                 'Content-Type': 'application/json'
+             },
+             data='{"sequences": ["CCO"]}',
+             timeout=60
+         )
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_pubchem.py
@@ -0,0 +1,33 @@
+ #!/usr/bin/env python3
+
+ """
+ Test cases for utils/enrichments/pubchem_strings.py
+ """
+
+ import pytest
+ from ..utils.enrichments.pubchem_strings import EnrichmentWithPubChem
+
+ # In this test, we will consider 2 examples:
+ # 1. PubChem ID: 5311000 (Alclometasone)
+ # 2. PubChem ID: 1X (Fake ID)
+ # The expected SMILES representation for the first PubChem ID is:
+ SMILES_FIRST = 'C[C@@H]1C[C@H]2[C@@H]3[C@@H](CC4=CC(=O)C=C[C@@]'
+ SMILES_FIRST += '4([C@H]3[C@H](C[C@@]2([C@]1(C(=O)CO)O)C)O)C)Cl'
+ # The expected SMILES representation for the second PubChem ID is None.
+
+ @pytest.fixture(name="enrich_obj")
+ def fixture_pubchem_config():
+     """Return an instance of the PubChem enrichment class."""
+     return EnrichmentWithPubChem()
+
+ def test_enrich_documents(enrich_obj):
+     """Test the enrich_documents method."""
+     pubchem_ids = ["5311000", "1X"]
+     enriched_strings = enrich_obj.enrich_documents(pubchem_ids)
+     assert enriched_strings == [SMILES_FIRST, None]
+
+ def test_enrich_documents_with_rag(enrich_obj):
+     """Test the enrich_documents_with_rag method."""
+     pubchem_ids = ["5311000", "1X"]
+     enriched_strings = enrich_obj.enrich_documents_with_rag(pubchem_ids, None)
+     assert enriched_strings == [SMILES_FIRST, None]
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_pubchem_utils.py
@@ -0,0 +1,16 @@
+ """
+ Test cases for utils/pubchem_utils.py
+ """
+
+ from ..utils import pubchem_utils
+
+ def test_drugbank_id2pubchem_cid():
+     """
+     Test the drugbank_id2pubchem_cid method.
+
+     The DrugBank ID for Alclometasone is DB00240.
+     The PubChem CID for Alclometasone is 5311000.
+     """
+     drugbank_id = "DB00240"
+     pubchem_cid = pubchem_utils.drugbank_id2pubchem_cid(drugbank_id)
+     assert pubchem_cid == 5311000
aiagents4pharma/talk2knowledgegraphs/utils/__init__.py
@@ -5,3 +5,4 @@ from . import embeddings
  from . import enrichments
  from . import extractions
  from . import kg_utils
+ from . import pubchem_utils
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py
@@ -5,3 +5,4 @@ from . import embeddings
  from . import sentence_transformer
  from . import huggingface
  from . import ollama
+ from . import nim_molmim
aiagents4pharma/talk2knowledgegraphs/utils/embeddings/nim_molmim.py
@@ -0,0 +1,54 @@
+ """
+ Embedding class using MOLMIM model from NVIDIA NIM.
+ """
+
+ import json
+ from typing import List
+ import requests
+ from .embeddings import Embeddings
+
+ class EmbeddingWithMOLMIM(Embeddings):
+     """
+     Embedding class using MOLMIM model from NVIDIA NIM
+     """
+     def __init__(self, base_url: str):
+         """
+         Initialize the EmbeddingWithMOLMIM class.
+
+         Args:
+             base_url: The base URL for the NIM/MOLMIM model.
+         """
+         # Set base URL
+         self.base_url = base_url
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         """
+         Generate embeddings for a list of SMILES strings using the MOLMIM model.
+
+         Args:
+             texts: The list of SMILES strings to be embedded.
+
+         Returns:
+             The list of embeddings for the given SMILES strings.
+         """
+         headers = {
+             'accept': 'application/json',
+             'Content-Type': 'application/json'
+         }
+         data = json.dumps({"sequences": texts})
+         response = requests.post(self.base_url, headers=headers, data=data, timeout=60)
+         embeddings = response.json()["embeddings"]
+         return embeddings
+
+     def embed_query(self, text: str) -> List[float]:
+         """
+         Generate embeddings for an input query using the MOLMIM model.
+
+         Args:
+             text: A query to be embedded.
+         Returns:
+             The embeddings for the given query.
+         """
+         # Generate the embedding
+         embeddings = self.embed_documents([text])
+         return embeddings
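
For orientation, a minimal usage sketch of the new class (not part of the diff): the endpoint URL below is a placeholder for a deployed MOLMIM NIM invocation URL, and network access is assumed.

    from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.nim_molmim import (
        EmbeddingWithMOLMIM,
    )

    # Placeholder endpoint; substitute your own NIM deployment's URL.
    embedder = EmbeddingWithMOLMIM("http://localhost:8000/embedding")
    vectors = embedder.embed_documents(["CCO", "CCC", "C=O"])  # one vector per SMILES
    query_vec = embedder.embed_query("CCO")  # a single-item list of vectors

Note that, as the tests above confirm, `embed_query` returns a one-element list of embeddings rather than a bare vector, since it simply delegates to `embed_documents`.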
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py
@@ -3,3 +3,4 @@ This package contains modules to use the enrichment model
  """
  from . import enrichments
  from . import ollama
+ from . import pubchem_strings
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/pubchem_strings.py
@@ -0,0 +1,49 @@
+ #!/usr/bin/env python3
+
+ """
+ Enrichment class for enriching PubChem IDs with their SMILES representation.
+ """
+
+ from typing import List
+ import pubchempy as pcp
+ from .enrichments import Enrichments
+
+ class EnrichmentWithPubChem(Enrichments):
+     """
+     Enrichment class using PubChem
+     """
+     def enrich_documents(self, texts: List[str]) -> List[str]:
+         """
+         Enrich a list of input PubChem IDs with their SMILES representation.
+
+         Args:
+             texts: The list of PubChem IDs to be enriched.
+
+         Returns:
+             The list of SMILES strings (None for IDs that cannot be resolved)
+         """
+
+         enriched_pubchem_ids = []
+         pubchem_cids = texts
+         for pubchem_cid in pubchem_cids:
+             try:
+                 c = pcp.Compound.from_cid(pubchem_cid)
+             except pcp.BadRequestError:
+                 enriched_pubchem_ids.append(None)
+                 continue
+             enriched_pubchem_ids.append(c.isomeric_smiles)
+
+         return enriched_pubchem_ids
+
+     def enrich_documents_with_rag(self, texts, docs):
+         """
+         Enrich a list of input PubChem IDs with their SMILES representation.
+
+         Args:
+             texts: The list of PubChem IDs to be enriched.
+             docs: Unused; accepted only for interface compatibility.
+
+         Returns:
+             The list of SMILES strings (None for IDs that cannot be resolved)
+         """
+         return self.enrich_documents(texts)
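
A short usage sketch (not part of the diff); it requires the `pubchempy` dependency and live access to PubChem, and the IDs mirror the ones used in the tests above:

    from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.pubchem_strings import (
        EnrichmentWithPubChem,
    )

    enricher = EnrichmentWithPubChem()
    # A valid CID yields its isomeric SMILES; an invalid ID yields None.
    print(enricher.enrich_documents(["5311000", "1X"]))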
aiagents4pharma/talk2knowledgegraphs/utils/pubchem_utils.py
@@ -0,0 +1,42 @@
+ #!/usr/bin/env python3
+
+ """
+ Utility functions for PubChem, including DrugBank ID to PubChem CID conversion.
+ """
+
+ import logging
+ import requests
+ import hydra
+
+ # Initialize logger
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ def drugbank_id2pubchem_cid(drugbank_id):
+     """
+     Convert DrugBank ID to PubChem CID.
+
+     Args:
+         drugbank_id: The DrugBank ID of the drug.
+
+     Returns:
+         The PubChem CID of the drug.
+     """
+     logger.log(logging.INFO, "Load Hydra configuration for PubChem ID conversion.")
+     with hydra.initialize(version_base=None, config_path="../configs"):
+         cfg = hydra.compose(config_name='config',
+                             overrides=['utils/pubchem_utils=default'])
+         cfg = cfg.utils.pubchem_utils
+     # Prepare the URL
+     pubchem_url_for_drug = cfg.drugbank_id_to_pubchem_cid_url + drugbank_id + '/JSON'
+     # Get the data
+     response = requests.get(pubchem_url_for_drug, timeout=60)
+     data = response.json()
+     # Extract the PubChem CID
+     cid = None
+     for substance in data.get("PC_Substances", []):
+         for compound in substance.get("compound", []):
+             if "id" in compound and "type" in compound["id"] and compound["id"]["type"] == 1:
+                 cid = compound["id"].get("id", {}).get("cid")
+                 break
+     return cid
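
Chaining the two new utilities gives a DrugBank-to-SMILES lookup. A hedged end-to-end sketch (it assumes the Hydra group added to config.yaml above resolves `drugbank_id_to_pubchem_cid_url`, and that both PubChem endpoints are reachable):

    from aiagents4pharma.talk2knowledgegraphs.utils import pubchem_utils
    from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.pubchem_strings import (
        EnrichmentWithPubChem,
    )

    # DrugBank DB00240 (Alclometasone) resolves to PubChem CID 5311000 per the tests.
    cid = pubchem_utils.drugbank_id2pubchem_cid("DB00240")
    smiles = EnrichmentWithPubChem().enrich_documents([str(cid)])[0]
    print(cid, smiles)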
aiagents4pharma/talk2scholars/agents/main_agent.py
@@ -6,28 +6,17 @@ Main agent for the talk2scholars app using ReAct pattern.
  This module implements a hierarchical agent system where a supervisor agent
  routes queries to specialized sub-agents. It follows the LangGraph patterns
  for multi-agent systems and implements proper state management.
-
- The main components are:
- 1. Supervisor node with ReAct pattern for intelligent routing.
- 2. S2 agent node for handling academic paper queries.
- 3. Shared state management via Talk2Scholars.
- 4. Hydra-based configuration system.
-
- Example:
-     app = get_app("thread_123", "gpt-4o-mini")
-     result = app.invoke({
-         "messages": [("human", "Find papers about AI agents")]
-     })
  """

  import logging
  from typing import Literal, Callable
+ from pydantic import BaseModel
  import hydra
  from langchain_core.language_models.chat_models import BaseChatModel
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
  from langchain_openai import ChatOpenAI
  from langgraph.checkpoint.memory import MemorySaver
  from langgraph.graph import END, START, StateGraph
- from langgraph.prebuilt import create_react_agent
  from langgraph.types import Command
  from ..agents import s2_agent
  from ..state.state_talk2scholars import Talk2Scholars
@@ -39,13 +28,13 @@ logger = logging.getLogger(__name__)

  def get_hydra_config():
      """
-     Loads and returns the Hydra configuration for the main agent.
+     Loads the Hydra configuration for the main agent.

-     This function fetches the configuration settings for the Talk2Scholars
-     agent, ensuring that all required parameters are properly initialized.
+     This function initializes the Hydra configuration system and retrieves the settings
+     for the `Talk2Scholars` agent, ensuring that all required parameters are loaded.

      Returns:
-         Any: The configuration object for the main agent.
+         DictConfig: The configuration object containing parameters for the main agent.
      """
      with hydra.initialize(version_base=None, config_path="../configs"):
          cfg = hydra.compose(
@@ -54,116 +43,127 @@ def get_hydra_config():
      return cfg.agents.talk2scholars.main_agent


- def make_supervisor_node(llm: BaseChatModel, thread_id: str) -> Callable:
+ def make_supervisor_node(llm_model: BaseChatModel, thread_id: str) -> Callable:
      """
-     Creates and returns a supervisor node for intelligent routing using the ReAct pattern.
+     Creates the supervisor node responsible for routing user queries to the appropriate sub-agents.

-     This function initializes a supervisor agent that processes user queries and
-     determines the appropriate sub-agent for further processing. It applies structured
-     reasoning to manage conversations and direct queries based on context.
+     This function initializes the routing logic by leveraging the system and router prompts defined
+     in the Hydra configuration. The supervisor determines whether to
+     call a sub-agent (like `s2_agent`)
+     or directly generate a response using the language model.

      Args:
-         llm (BaseChatModel): The language model used by the supervisor agent.
-         thread_id (str): Unique identifier for the conversation session.
+         llm_model (BaseChatModel): The language model used for decision-making.
+         thread_id (str): Unique identifier for the current conversation session.

      Returns:
-         Callable: A function that acts as the supervisor node in the LangGraph workflow.
-
-     Example:
-         supervisor = make_supervisor_node(llm, "thread_123")
-         workflow.add_node("supervisor", supervisor)
+         Callable: The supervisor node function that processes user queries and
+         decides the next step.
      """
-     logger.info("Loading Hydra configuration for Talk2Scholars main agent.")
      cfg = get_hydra_config()
-     logger.info("Hydra configuration loaded with values: %s", cfg)
+     logger.info("Hydra configuration for Talk2Scholars main agent loaded: %s", cfg)
+     members = ["s2_agent"]
+     options = ["FINISH"] + members
+     # Define system prompt for general interactions
+     system_prompt = cfg.system_prompt
+     # Define router prompt for routing to sub-agents
+     router_prompt = cfg.router_prompt
+
+     class Router(BaseModel):
+         """Worker to route to next. If no workers needed, route to FINISH."""

-     # Create the supervisor agent using the main agent's configuration
-     supervisor_agent = create_react_agent(
-         llm,
-         tools=[],  # Will add sub-agents later
-         state_modifier=cfg.main_agent,
-         state_schema=Talk2Scholars,
-         checkpointer=MemorySaver(),
-     )
+         next: Literal[*options]

      def supervisor_node(
          state: Talk2Scholars,
-     ) -> Command[Literal["s2_agent", "__end__"]]:
+     ) -> Command:
          """
-         Processes user queries and determines the next step in the conversation flow.
+         Handles the routing logic for the supervisor agent.

-         This function examines the conversation state and decides whether to forward
-         the query to a specialized sub-agent (e.g., S2 agent) or conclude the interaction.
+         This function determines the next agent to invoke based on the router prompt response.
+         If no further processing is required, it generates an AI response using the system prompt.

          Args:
-             state (Talk2Scholars): The current state of the conversation, containing
-                 messages, papers, and metadata.
+             state (Talk2Scholars): The current conversation state, including messages
+                 exchanged so far.

          Returns:
-             Command: The next action to be executed, along with updated state data.
-
-         Example:
-             result = supervisor_node(current_state)
-             next_step = result.goto
+             Command: A command dictating whether to invoke a sub-agent or generate a final response.
          """
-         logger.info(
-             "Supervisor node called - Messages count: %d",
-             len(state["messages"]),
-         )
-
-         # Invoke the supervisor agent with configurable thread_id
-         result = supervisor_agent.invoke(
-             state, {"configurable": {"thread_id": thread_id}}
-         )
-         goto = "s2_agent"
-         logger.info("Supervisor agent completed with result: %s", result)
-
+         messages = [SystemMessage(content=router_prompt)] + state["messages"]
+         structured_llm = llm_model.with_structured_output(Router)
+         response = structured_llm.invoke(messages)
+         goto = response.next
+         logger.info("Routing to: %s, Thread ID: %s", goto, thread_id)
+         if goto == "FINISH":
+             goto = END  # Using END from langgraph.graph
+             # If no agents were called, and the last message was
+             # from the user, call the LLM to respond to the user
+             # with a slightly different system prompt.
+             if isinstance(messages[-1], HumanMessage):
+                 response = llm_model.invoke(
+                     [
+                         SystemMessage(content=system_prompt),
+                     ]
+                     + messages[1:]
+                 )
+                 return Command(
+                     goto=goto, update={"messages": AIMessage(content=response.content)}
+                 )
+         # Go to the requested agent
          return Command(goto=goto)

      return supervisor_node


- def get_app(thread_id: str, llm_model: str = "gpt-4o-mini") -> StateGraph:
+ def get_app(
+     thread_id: str,
+     llm_model: BaseChatModel = ChatOpenAI(model="gpt-4o-mini", temperature=0),
+ ):
      """
-     Initializes and returns the LangGraph application with a hierarchical agent system.
+     Initializes and returns the LangGraph-based hierarchical agent system.

-     This function sets up the full agent architecture, including the supervisor
-     and sub-agents, and compiles the LangGraph workflow for handling user queries.
+     This function constructs the agent workflow by defining nodes for the supervisor
+     and sub-agents. It compiles the graph using `StateGraph` to enable structured
+     conversational workflows.

      Args:
-         thread_id (str): Unique identifier for the conversation session.
-         llm_model (str, optional): The language model to be used. Defaults to "gpt-4o-mini".
+         thread_id (str): A unique session identifier for tracking conversation state.
+         llm_model (BaseChatModel, optional): The language model used for query processing.
+             Defaults to `ChatOpenAI(model="gpt-4o-mini", temperature=0)`.

      Returns:
-         StateGraph: A compiled LangGraph application ready for query invocation.
+         StateGraph: A compiled LangGraph application that can process user queries.

      Example:
-         app = get_app("thread_123")
-         result = app.invoke(initial_state)
+         >>> app = get_app("thread_123")
+         >>> result = app.invoke(initial_state)
      """
      cfg = get_hydra_config()

      def call_s2_agent(
          state: Talk2Scholars,
-     ) -> Command[Literal["supervisor", "__end__"]]:
+     ) -> Command[Literal["supervisor"]]:
          """
-         Calls the Semantic Scholar (S2) agent to process academic paper queries.
+         Invokes the Semantic Scholar (S2) agent to retrieve relevant research papers.

-         This function invokes the S2 agent, retrieves relevant research papers,
-         and updates the conversation state accordingly.
+         This function calls the `s2_agent` and updates the conversation state with retrieved
+         academic papers. The agent uses Semantic Scholar's API to find papers based on
+         user queries.

          Args:
-             state (Talk2Scholars): The current conversation state, including user queries
-                 and any previously retrieved papers.
+             state (Talk2Scholars): The current state of the conversation, containing messages
+                 and any previous search results.

          Returns:
-             Command: The next action to execute, along with updated messages and papers.
+             Command: A command to update the conversation state with the retrieved papers
+                 and return control to the supervisor node.

          Example:
-             result = call_s2_agent(current_state)
-             next_step = result.goto
+             >>> result = call_s2_agent(current_state)
+             >>> next_step = result.goto
          """
-         logger.info("Calling S2 agent with state: %s", state)
+         logger.info("Calling S2 agent")
          app = s2_agent.get_app(thread_id, llm_model)

          # Invoke the S2 agent, passing state,
@@ -177,31 +177,30 @@ def get_app(thread_id: str, llm_model: str = "gpt-4o-mini") -> StateGraph:
                  }
              },
          )
-         logger.info("S2 agent completed with response: %s", response)
-
+         logger.info("S2 agent completed with response")
          return Command(
-             goto=END,
              update={
                  "messages": response["messages"],
                  "papers": response.get("papers", {}),
                  "multi_papers": response.get("multi_papers", {}),
+                 "last_displayed_papers": response.get("last_displayed_papers", {}),
              },
+             # Always return to supervisor
+             goto="supervisor",
          )

      # Initialize LLM
-     logger.info("Using OpenAI model %s with temperature %s", llm_model, cfg.temperature)
-     llm = ChatOpenAI(model=llm_model, temperature=cfg.temperature)
+     logger.info("Using model %s with temperature %s", llm_model, cfg.temperature)

      # Build the graph
      workflow = StateGraph(Talk2Scholars)
-     supervisor = make_supervisor_node(llm, thread_id)
-
+     supervisor = make_supervisor_node(llm_model, thread_id)
+     # Add nodes
      workflow.add_node("supervisor", supervisor)
      workflow.add_node("s2_agent", call_s2_agent)
+     # Add edges
      workflow.add_edge(START, "supervisor")
-     workflow.add_edge("s2_agent", END)
-
-     # Compile the graph without initial state
+     # Compile the workflow
      app = workflow.compile(checkpointer=MemorySaver())
      logger.info("Main agent workflow compiled")
      return app
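
For reference, a minimal invocation sketch of the reworked main agent (not part of the diff): the model mirrors the new default, the message reuses the example removed from the module docstring, and a `thread_id` must be passed in `config` because the compiled graph uses a `MemorySaver` checkpointer. A valid OPENAI_API_KEY is assumed.

    from langchain_openai import ChatOpenAI
    from aiagents4pharma.talk2scholars.agents.main_agent import get_app

    app = get_app("thread_123", llm_model=ChatOpenAI(model="gpt-4o-mini", temperature=0))
    result = app.invoke(
        {"messages": [("human", "Find papers about AI agents")]},
        config={"configurable": {"thread_id": "thread_123"}},
    )
    print(result["messages"][-1].content)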