aiagents4pharma 1.19.1__py3-none-any.whl → 1.20.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,59 +1,21 @@
1
1
  """
2
- Unit and integration tests for Talk2Scholars system.
3
- Each test focuses on a single, specific functionality.
4
- Tests are deterministic and independent of each other.
2
+ Unit tests for S2 tools functionality.
5
3
  """
6
4
 
7
- from unittest.mock import Mock, patch
5
+ # pylint: disable=redefined-outer-name
6
+ from unittest.mock import patch
7
+ from langchain_core.messages import ToolMessage
8
8
  import pytest
9
- from langchain_core.messages import AIMessage, HumanMessage
10
- import hydra
11
- from hydra.core.global_hydra import GlobalHydra
12
- from omegaconf import DictConfig, OmegaConf
13
-
14
- from ..agents.main_agent import get_app, make_supervisor_node
15
- from ..state.state_talk2scholars import replace_dict, Talk2Scholars
16
- from ..tools.s2.display_results import display_results
9
+ from ..tools.s2.display_results import display_results, NoPapersFoundError
17
10
  from ..tools.s2.multi_paper_rec import get_multi_paper_recommendations
18
11
  from ..tools.s2.search import search_tool
19
12
  from ..tools.s2.single_paper_rec import get_single_paper_recommendations
20
13
 
21
- # pylint: disable=redefined-outer-name
22
-
23
-
24
- @pytest.fixture(autouse=True)
25
- def hydra_setup():
26
- """Setup and cleanup Hydra for tests."""
27
- GlobalHydra.instance().clear()
28
- with hydra.initialize(version_base=None, config_path="../configs"):
29
- yield
30
-
31
14
 
32
15
  @pytest.fixture
33
- def mock_cfg() -> DictConfig:
34
- """Create a mock configuration for testing."""
35
- config = {
36
- "agents": {
37
- "talk2scholars": {
38
- "main_agent": {
39
- "state_modifier": "Test prompt for main agent",
40
- "temperature": 0,
41
- },
42
- "s2_agent": {
43
- "temperature": 0,
44
- "s2_agent": "Test prompt for s2 agent",
45
- },
46
- }
47
- },
48
- "tools": {
49
- "search": {
50
- "api_endpoint": "https://api.semanticscholar.org/graph/v1/paper/search",
51
- "default_limit": 2,
52
- "api_fields": ["paperId", "title", "abstract", "year", "authors"],
53
- }
54
- },
55
- }
56
- return OmegaConf.create(config)
16
+ def initial_state():
17
+ """Provides an empty initial state for tests."""
18
+ return {"papers": {}, "multi_papers": {}}
57
19
 
58
20
 
59
21
  # Fixed test data for deterministic results
@@ -82,70 +44,26 @@ MOCK_STATE_PAPER = {
82
44
  }
83
45
 
84
46
 
85
- @pytest.fixture
86
- def initial_state() -> Talk2Scholars:
87
- """Create a base state for tests"""
88
- return Talk2Scholars(
89
- messages=[],
90
- papers={},
91
- is_last_step=False,
92
- current_agent=None,
93
- llm_model="gpt-4o-mini",
94
- next="",
95
- )
96
-
97
-
98
- class TestMainAgent:
99
- """Unit tests for main agent functionality"""
100
-
101
- def test_supervisor_routes_search_to_s2(
102
- self, initial_state: Talk2Scholars, mock_cfg
103
- ):
104
- """Verifies that search-related queries are routed to S2 agent"""
105
- llm_mock = Mock()
106
- llm_mock.invoke.return_value = AIMessage(content="Search initiated")
107
-
108
- # Extract the main_agent config
109
- supervisor = make_supervisor_node(
110
- llm_mock, mock_cfg.agents.talk2scholars.main_agent
111
- )
112
- state = initial_state
113
- state["messages"] = [HumanMessage(content="search for papers")]
114
-
115
- result = supervisor(state)
116
- assert result.goto == "s2_agent"
117
- assert not result.update["is_last_step"]
118
- assert result.update["current_agent"] == "s2_agent"
119
-
120
- def test_supervisor_routes_general_to_end(
121
- self, initial_state: Talk2Scholars, mock_cfg
122
- ):
123
- """Verifies that non-search queries end the conversation"""
124
- llm_mock = Mock()
125
- llm_mock.invoke.return_value = AIMessage(content="General response")
126
-
127
- # Extract the main_agent config
128
- supervisor = make_supervisor_node(
129
- llm_mock, mock_cfg.agents.talk2scholars.main_agent
130
- )
131
- state = initial_state
132
- state["messages"] = [HumanMessage(content="What is ML?")]
133
-
134
- result = supervisor(state)
135
- assert result.goto == "__end__"
136
- assert result.update["is_last_step"]
137
-
138
-
139
47
  class TestS2Tools:
140
48
  """Unit tests for individual S2 tools"""
141
49
 
142
- def test_display_results_shows_papers(self, initial_state: Talk2Scholars):
50
+ def test_display_results_empty_state(self, initial_state):
51
+ """Verifies display_results tool behavior when state is empty and raises an exception"""
52
+ with pytest.raises(
53
+ NoPapersFoundError,
54
+ match="No papers found. A search needs to be performed first.",
55
+ ):
56
+ display_results.invoke({"state": initial_state})
57
+
58
+ def test_display_results_shows_papers(self, initial_state):
143
59
  """Verifies display_results tool correctly returns papers from state"""
144
- state = initial_state
60
+ state = initial_state.copy()
145
61
  state["papers"] = MOCK_STATE_PAPER
62
+ state["multi_papers"] = {}
146
63
  result = display_results.invoke(input={"state": state})
147
- assert result == MOCK_STATE_PAPER
148
64
  assert isinstance(result, dict)
65
+ assert result["papers"] == MOCK_STATE_PAPER
66
+ assert result["multi_papers"] == {}
149
67
 
150
68
  @patch("requests.get")
151
69
  def test_search_finds_papers(self, mock_get):
@@ -171,6 +89,60 @@ class TestS2Tools:
171
89
  assert paper["Title"] == "Machine Learning Basics"
172
90
  assert paper["Year"] == 2023
173
91
 
92
+ @patch("requests.get")
93
+ def test_search_finds_papers_with_year(self, mock_get):
94
+ """Verifies search tool works with year parameter"""
95
+ mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
96
+ mock_get.return_value.status_code = 200
97
+
98
+ result = search_tool.invoke(
99
+ input={
100
+ "query": "machine learning",
101
+ "limit": 1,
102
+ "year": "2023-",
103
+ "tool_call_id": "test123",
104
+ "id": "test123",
105
+ }
106
+ )
107
+
108
+ assert "papers" in result.update
109
+ assert "messages" in result.update
110
+ papers = result.update["papers"]
111
+ assert isinstance(papers, dict)
112
+ assert len(papers) > 0
113
+
114
+ @patch("requests.get")
115
+ def test_search_filters_invalid_papers(self, mock_get):
116
+ """Verifies search tool properly filters papers without title or authors"""
117
+ mock_response = {
118
+ "data": [
119
+ {
120
+ "paperId": "123",
121
+ "abstract": "An introduction to ML",
122
+ "year": 2023,
123
+ "citationCount": 100,
124
+ "url": "https://example.com/paper1",
125
+ # Missing title and authors
126
+ },
127
+ MOCK_SEARCH_RESPONSE["data"][0], # This one is valid
128
+ ]
129
+ }
130
+ mock_get.return_value.json.return_value = mock_response
131
+ mock_get.return_value.status_code = 200
132
+
133
+ result = search_tool.invoke(
134
+ input={
135
+ "query": "machine learning",
136
+ "limit": 2,
137
+ "tool_call_id": "test123",
138
+ "id": "test123",
139
+ }
140
+ )
141
+
142
+ assert "papers" in result.update
143
+ papers = result.update["papers"]
144
+ assert len(papers) == 1 # Only the valid paper should be included
145
+
174
146
  @patch("requests.get")
175
147
  def test_single_paper_rec_basic(self, mock_get):
176
148
  """Tests basic single paper recommendation functionality"""
@@ -184,11 +156,10 @@ class TestS2Tools:
184
156
  "paper_id": "123",
185
157
  "limit": 1,
186
158
  "tool_call_id": "test123",
187
- "id": "test123",
188
159
  }
189
160
  )
190
161
  assert "papers" in result.update
191
- assert len(result.update["messages"]) == 1
162
+ assert isinstance(result.update["messages"][0], ToolMessage)
192
163
 
193
164
  @patch("requests.get")
194
165
  def test_single_paper_rec_with_optional_params(self, mock_get):
@@ -222,11 +193,10 @@ class TestS2Tools:
222
193
  "paper_ids": ["123", "456"],
223
194
  "limit": 1,
224
195
  "tool_call_id": "test123",
225
- "id": "test123",
226
196
  }
227
197
  )
228
- assert "papers" in result.update
229
- assert len(result.update["messages"]) == 1
198
+ assert "multi_papers" in result.update
199
+ assert isinstance(result.update["messages"][0], ToolMessage)
230
200
 
231
201
  @patch("requests.post")
232
202
  def test_multi_paper_rec_with_optional_params(self, mock_post):
@@ -245,47 +215,5 @@ class TestS2Tools:
245
215
  "id": "test123",
246
216
  }
247
217
  )
248
- assert "papers" in result.update
218
+ assert "multi_papers" in result.update
249
219
  assert len(result.update["messages"]) == 1
250
-
251
-
252
- def test_state_replace_dict():
253
- """Verifies state dictionary replacement works correctly"""
254
- existing = {"key1": "value1", "key2": "value2"}
255
- new = {"key3": "value3"}
256
- result = replace_dict(existing, new)
257
- assert result == new
258
- assert isinstance(result, dict)
259
-
260
-
261
- @pytest.mark.integration
262
- def test_end_to_end_search_workflow(initial_state: Talk2Scholars, mock_cfg):
263
- """Integration test: Complete search workflow"""
264
- with (
265
- patch("requests.get") as mock_get,
266
- patch("langchain_openai.ChatOpenAI") as mock_llm,
267
- patch("hydra.compose", return_value=mock_cfg),
268
- patch("hydra.initialize"),
269
- ):
270
- mock_get.return_value.json.return_value = MOCK_SEARCH_RESPONSE
271
- mock_get.return_value.status_code = 200
272
-
273
- llm_instance = Mock()
274
- llm_instance.invoke.return_value = AIMessage(content="Search completed")
275
- mock_llm.return_value = llm_instance
276
-
277
- app = get_app("test_integration")
278
- test_state = initial_state
279
- test_state["messages"] = [HumanMessage(content="search for ML papers")]
280
-
281
- config = {
282
- "configurable": {
283
- "thread_id": "test_integration",
284
- "checkpoint_ns": "test",
285
- "checkpoint_id": "test123",
286
- }
287
- }
288
-
289
- response = app.invoke(test_state, config)
290
- assert "papers" in response
291
- assert len(response["messages"]) > 0
@@ -0,0 +1,14 @@
1
+ """
2
+ Tests for state management functionality.
3
+ """
4
+
5
+ from ..state.state_talk2scholars import replace_dict
6
+
7
+
8
+ def test_state_replace_dict():
9
+ """Verifies state dictionary replacement works correctly"""
10
+ existing = {"key1": "value1", "key2": "value2"}
11
+ new = {"key3": "value3"}
12
+ result = replace_dict(existing, new)
13
+ assert result == new
14
+ assert isinstance(result, dict)
@@ -1,11 +1,11 @@
1
1
  #!/usr/bin/env python3
2
2
 
3
- '''
3
+ """
4
4
  This tool is used to display the table of studies.
5
- '''
5
+ """
6
6
 
7
7
  import logging
8
- from typing import Annotated
8
+ from typing import Annotated, Dict, Any
9
9
  from langchain_core.tools import tool
10
10
  from langgraph.prebuilt import InjectedState
11
11
 
@@ -13,13 +13,38 @@ from langgraph.prebuilt import InjectedState
13
13
  logging.basicConfig(level=logging.INFO)
14
14
  logger = logging.getLogger(__name__)
15
15
 
16
- @tool('display_results')
17
- def display_results(state: Annotated[dict, InjectedState]):
16
+
17
+ class NoPapersFoundError(Exception):
18
+ """Exception raised when no papers are found in the state."""
19
+
20
+
21
+ @tool("display_results")
22
+ def display_results(state: Annotated[dict, InjectedState]) -> Dict[str, Any]:
18
23
  """
19
- Display the papers in the state.
24
+ Display the papers in the state. If no papers are found, raises an exception
25
+ indicating that a search is needed.
20
26
 
21
27
  Args:
22
- state (dict): The state of the agent.
28
+ state (dict): The state of the agent containing the papers.
29
+
30
+ Returns:
31
+ dict: A dictionary containing the papers and multi_papers from the state.
32
+
33
+ Raises:
34
+ NoPapersFoundError: If no papers are found in the state.
35
+
36
+ Note:
37
+ The exception allows the LLM to make a more informed decision about initiating a search.
23
38
  """
24
39
  logger.info("Displaying papers from the state")
25
- return state["papers"]
40
+
41
+ if not state.get("papers") and not state.get("multi_papers"):
42
+ logger.info("No papers found in state, raising NoPapersFoundError")
43
+ raise NoPapersFoundError(
44
+ "No papers found. A search needs to be performed first."
45
+ )
46
+
47
+ return {
48
+ "papers": state.get("papers"),
49
+ "multi_papers": state.get("multi_papers"),
50
+ }
@@ -9,7 +9,6 @@ import json
9
9
  import logging
10
10
  from typing import Annotated, Any, Dict, List, Optional
11
11
  import hydra
12
- import pandas as pd
13
12
  import requests
14
13
  from langchain_core.messages import ToolMessage
15
14
  from langchain_core.tools import tool
@@ -18,6 +17,11 @@ from langgraph.types import Command
18
17
  from pydantic import BaseModel, Field
19
18
 
20
19
 
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+
21
25
  class MultiPaperRecInput(BaseModel):
22
26
  """Input schema for multiple paper recommendations tool."""
23
27
 
@@ -110,31 +114,14 @@ def get_multi_paper_recommendations(
110
114
  if paper.get("title") and paper.get("paperId")
111
115
  }
112
116
 
113
- # Create a DataFrame from the dictionary
114
- df = pd.DataFrame.from_dict(filtered_papers, orient="index")
115
- # print("Created DataFrame with results:")
116
- logging.info("Created DataFrame with results: %s", df)
117
-
118
- # Format papers for state update
119
- papers = [
120
- f"Paper ID: {paper_id}\n"
121
- f"Title: {paper_data['Title']}\n"
122
- f"Abstract: {paper_data['Abstract']}\n"
123
- f"Year: {paper_data['Year']}\n"
124
- f"Citations: {paper_data['Citation Count']}\n"
125
- f"URL: {paper_data['URL']}\n"
126
- for paper_id, paper_data in filtered_papers.items()
127
- ]
128
-
129
- # Convert DataFrame to markdown table
130
- markdown_table = df.to_markdown(tablefmt="grid")
131
- logging.info("Search results: %s", papers)
132
-
133
117
  return Command(
134
118
  update={
135
- "papers": filtered_papers, # Now sending the dictionary directly
119
+ "multi_papers": filtered_papers, # Now sending the dictionary directly
136
120
  "messages": [
137
- ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
121
+ ToolMessage(
122
+ content=f"Search Successful: {filtered_papers}",
123
+ tool_call_id=tool_call_id
124
+ )
138
125
  ],
139
126
  }
140
127
  )
@@ -7,7 +7,6 @@ This tool is used to search for academic papers on Semantic Scholar.
7
7
  import logging
8
8
  from typing import Annotated, Any, Dict, Optional
9
9
  import hydra
10
- import pandas as pd
11
10
  import requests
12
11
  from langchain_core.messages import ToolMessage
13
12
  from langchain_core.tools import tool
@@ -16,6 +15,11 @@ from langgraph.types import Command
16
15
  from pydantic import BaseModel, Field
17
16
 
18
17
 
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+
19
23
  class SearchInput(BaseModel):
20
24
  """Input schema for the search papers tool."""
21
25
 
@@ -65,8 +69,6 @@ def search_tool(
65
69
  params = {
66
70
  "query": query,
67
71
  "limit": min(limit, 100),
68
- # "fields": "paperId,title,abstract,year,authors,
69
- # citationCount,url,publicationTypes,openAccessPdf",
70
72
  "fields": ",".join(cfg.api_fields),
71
73
  }
72
74
 
@@ -77,7 +79,7 @@ def search_tool(
77
79
  response = requests.get(endpoint, params=params, timeout=10)
78
80
  data = response.json()
79
81
  papers = data.get("data", [])
80
-
82
+ logger.info("Received %d papers", len(papers))
81
83
  # Create a dictionary to store the papers
82
84
  filtered_papers = {
83
85
  paper["paperId"]: {
@@ -86,40 +88,19 @@ def search_tool(
86
88
  "Year": paper.get("year", "N/A"),
87
89
  "Citation Count": paper.get("citationCount", "N/A"),
88
90
  "URL": paper.get("url", "N/A"),
89
- # "Publication Type": paper.get("publicationTypes", ["N/A"])[0]
90
- # if paper.get("publicationTypes")
91
- # else "N/A",
92
- # "Open Access PDF": paper.get("openAccessPdf", {}).get("url", "N/A")
93
- # if paper.get("openAccessPdf") is not None
94
- # else "N/A",
95
91
  }
96
92
  for paper in papers
97
93
  if paper.get("title") and paper.get("authors")
98
94
  }
99
95
 
100
- df = pd.DataFrame(filtered_papers)
101
-
102
- # Format papers for state update
103
- papers = [
104
- f"Paper ID: {paper_id}\n"
105
- f"Title: {paper_data['Title']}\n"
106
- f"Abstract: {paper_data['Abstract']}\n"
107
- f"Year: {paper_data['Year']}\n"
108
- f"Citations: {paper_data['Citation Count']}\n"
109
- f"URL: {paper_data['URL']}\n"
110
- # f"Publication Type: {paper_data['Publication Type']}\n"
111
- # f"Open Access PDF: {paper_data['Open Access PDF']}"
112
- for paper_id, paper_data in filtered_papers.items()
113
- ]
114
-
115
- markdown_table = df.to_markdown(tablefmt="grid")
116
- logging.info("Search results: %s", papers)
117
-
118
96
  return Command(
119
97
  update={
120
98
  "papers": filtered_papers, # Now sending the dictionary directly
121
99
  "messages": [
122
- ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
100
+ ToolMessage(
101
+ content=f"Search Successful: {filtered_papers}",
102
+ tool_call_id=tool_call_id
103
+ )
123
104
  ],
124
105
  }
125
106
  )
@@ -7,7 +7,6 @@ This tool is used to return recommendations for a single paper.
7
7
  import logging
8
8
  from typing import Annotated, Any, Dict, Optional
9
9
  import hydra
10
- import pandas as pd
11
10
  import requests
12
11
  from langchain_core.messages import ToolMessage
13
12
  from langchain_core.tools import tool
@@ -84,7 +83,6 @@ def get_single_paper_recommendations(
84
83
 
85
84
  response = requests.get(endpoint, params=params, timeout=cfg.request_timeout)
86
85
  data = response.json()
87
- papers = data.get("data", [])
88
86
  response = requests.get(endpoint, params=params, timeout=10)
89
87
  # print(f"API Response Status: {response.status_code}")
90
88
  logging.info(
@@ -106,42 +104,19 @@ def get_single_paper_recommendations(
106
104
  "Year": paper.get("year", "N/A"),
107
105
  "Citation Count": paper.get("citationCount", "N/A"),
108
106
  "URL": paper.get("url", "N/A"),
109
- # "Publication Type": paper.get("publicationTypes", ["N/A"])[0]
110
- # if paper.get("publicationTypes")
111
- # else "N/A",
112
- # "Open Access PDF": paper.get("openAccessPdf", {}).get("url", "N/A")
113
- # if paper.get("openAccessPdf") is not None
114
- # else "N/A",
115
107
  }
116
108
  for paper in recommendations
117
109
  if paper.get("title") and paper.get("authors")
118
110
  }
119
111
 
120
- # Create a DataFrame for pretty printing
121
- df = pd.DataFrame(filtered_papers)
122
-
123
- # Format papers for state update
124
- papers = [
125
- f"Paper ID: {paper_id}\n"
126
- f"Title: {paper_data['Title']}\n"
127
- f"Abstract: {paper_data['Abstract']}\n"
128
- f"Year: {paper_data['Year']}\n"
129
- f"Citations: {paper_data['Citation Count']}\n"
130
- f"URL: {paper_data['URL']}\n"
131
- # f"Publication Type: {paper_data['Publication Type']}\n"
132
- # f"Open Access PDF: {paper_data['Open Access PDF']}"
133
- for paper_id, paper_data in filtered_papers.items()
134
- ]
135
-
136
- # Convert DataFrame to markdown table
137
- markdown_table = df.to_markdown(tablefmt="grid")
138
- logging.info("Search results: %s", papers)
139
-
140
112
  return Command(
141
113
  update={
142
114
  "papers": filtered_papers, # Now sending the dictionary directly
143
115
  "messages": [
144
- ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
116
+ ToolMessage(
117
+ content=f"Search Successful: {filtered_papers}",
118
+ tool_call_id=tool_call_id
119
+ )
145
120
  ],
146
121
  }
147
122
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: aiagents4pharma
3
- Version: 1.19.1
3
+ Version: 1.20.1
4
4
  Summary: AI Agents for drug discovery, drug development, and other pharmaceutical R&D.
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: License :: OSI Approved :: MIT License
@@ -56,6 +56,7 @@ Requires-Dist: streamlit-feedback
56
56
  [![TESTS Talk2Scholars](https://github.com/VirtualPatientEngine/AIAgents4Pharma/actions/workflows/tests_talk2scholars.yml/badge.svg)](https://github.com/VirtualPatientEngine/AIAgents4Pharma/actions/workflows/tests_talk2scholars.yml)
57
57
  ![GitHub Release](https://img.shields.io/github/v/release/VirtualPatientEngine/AIAgents4Pharma)
58
58
  ![Python Version from PEP 621 TOML](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2FVirtualPatientEngine%2FAIAgents4Pharma%2Frefs%2Fheads%2Fmain%2Fpyproject.toml)
59
+ ![Docker Pulls](https://img.shields.io/docker/pulls/virtualpatientengine/talk2biomodels?link=https%3A%2F%2Fhub.docker.com%2Frepository%2Fdocker%2Fvirtualpatientengine%2Ftalk2biomodels%2Fgeneral)
59
60
 
60
61
 
61
62
  ## Introduction
@@ -85,7 +86,22 @@ pip install aiagents4pharma
85
86
 
86
87
  Check out the tutorials on each agent for detailed instrcutions.
87
88
 
88
- #### Option 2: git
89
+ #### Option 2: docker hub
90
+
91
+ _Please note that this option is currently available only for Talk2Biomodels._
92
+
93
+ 1. **Pull the image**
94
+ ```
95
+ docker pull virtualpatientengine/talk2biomodels
96
+ ```
97
+ 2. **Run a container**
98
+ ```
99
+ docker run -e OPENAI_API_KEY=<openai_api_key> -e NVIDIA_API_KEY=<nvidia_api_key> -p 8501:8501 virtualpatientengine/talk2biomodels
100
+ ```
101
+ _You can create a free account at NVIDIA and apply for their
102
+ free credits [here](https://build.nvidia.com/explore/discover)._
103
+
104
+ #### Option 3: git
89
105
 
90
106
  1. **Clone the repository:**
91
107
  ```bash
@@ -94,7 +110,7 @@ Check out the tutorials on each agent for detailed instrcutions.
94
110
  ```
95
111
  2. **Install dependencies:**
96
112
  ```bash
97
- pip install .
113
+ pip install -r requirements.txt
98
114
  ```
99
115
  3. **Initialize OPENAI_API_KEY and NVIDIA_API_KEY**
100
116
  ```bash