aiagents4pharma 1.9.0__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. aiagents4pharma/__init__.py +9 -6
  2. aiagents4pharma/configs/config.yaml +2 -1
  3. aiagents4pharma/configs/talk2biomodels/__init__.py +1 -0
  4. aiagents4pharma/configs/talk2biomodels/agents/t2b_agent/default.yaml +9 -3
  5. aiagents4pharma/configs/talk2biomodels/tools/__init__.py +4 -0
  6. aiagents4pharma/configs/talk2biomodels/tools/ask_question/__init__.py +3 -0
  7. aiagents4pharma/talk2biomodels/__init__.py +1 -0
  8. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +14 -11
  9. aiagents4pharma/talk2biomodels/api/__init__.py +6 -0
  10. aiagents4pharma/talk2biomodels/api/kegg.py +83 -0
  11. aiagents4pharma/talk2biomodels/api/ols.py +72 -0
  12. aiagents4pharma/talk2biomodels/api/uniprot.py +35 -0
  13. aiagents4pharma/talk2biomodels/models/basico_model.py +29 -32
  14. aiagents4pharma/talk2biomodels/models/sys_bio_model.py +9 -6
  15. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +24 -7
  16. aiagents4pharma/talk2biomodels/tests/test_api.py +57 -0
  17. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +44 -0
  18. aiagents4pharma/talk2biomodels/tests/test_basico_model.py +7 -8
  19. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +171 -0
  20. aiagents4pharma/talk2biomodels/tests/test_getmodelinfo.py +26 -0
  21. aiagents4pharma/talk2biomodels/tests/test_integration.py +126 -0
  22. aiagents4pharma/talk2biomodels/tests/test_param_scan.py +68 -0
  23. aiagents4pharma/talk2biomodels/tests/test_query_article.py +76 -0
  24. aiagents4pharma/talk2biomodels/tests/test_search_models.py +28 -0
  25. aiagents4pharma/talk2biomodels/tests/test_simulate_model.py +39 -0
  26. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +90 -0
  27. aiagents4pharma/talk2biomodels/tests/test_sys_bio_model.py +13 -7
  28. aiagents4pharma/talk2biomodels/tools/__init__.py +4 -0
  29. aiagents4pharma/talk2biomodels/tools/ask_question.py +59 -25
  30. aiagents4pharma/talk2biomodels/tools/get_annotation.py +304 -0
  31. aiagents4pharma/talk2biomodels/tools/get_modelinfo.py +5 -3
  32. aiagents4pharma/talk2biomodels/tools/load_arguments.py +114 -0
  33. aiagents4pharma/talk2biomodels/tools/parameter_scan.py +287 -0
  34. aiagents4pharma/talk2biomodels/tools/query_article.py +59 -0
  35. aiagents4pharma/talk2biomodels/tools/simulate_model.py +20 -89
  36. aiagents4pharma/talk2biomodels/tools/steady_state.py +167 -0
  37. aiagents4pharma/talk2competitors/__init__.py +5 -0
  38. aiagents4pharma/talk2competitors/agents/__init__.py +6 -0
  39. aiagents4pharma/talk2competitors/agents/main_agent.py +130 -0
  40. aiagents4pharma/talk2competitors/agents/s2_agent.py +75 -0
  41. aiagents4pharma/talk2competitors/config/__init__.py +5 -0
  42. aiagents4pharma/talk2competitors/config/config.py +110 -0
  43. aiagents4pharma/talk2competitors/state/__init__.py +5 -0
  44. aiagents4pharma/talk2competitors/state/state_talk2competitors.py +32 -0
  45. aiagents4pharma/talk2competitors/tests/__init__.py +3 -0
  46. aiagents4pharma/talk2competitors/tests/test_langgraph.py +274 -0
  47. aiagents4pharma/talk2competitors/tools/__init__.py +7 -0
  48. aiagents4pharma/talk2competitors/tools/s2/__init__.py +8 -0
  49. aiagents4pharma/talk2competitors/tools/s2/display_results.py +25 -0
  50. aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py +132 -0
  51. aiagents4pharma/talk2competitors/tools/s2/search.py +119 -0
  52. aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py +138 -0
  53. aiagents4pharma/talk2knowledgegraphs/__init__.py +2 -1
  54. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py +39 -0
  55. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +117 -0
  56. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +5 -0
  57. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py +5 -0
  58. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py +36 -0
  59. aiagents4pharma/talk2knowledgegraphs/utils/enrichments/ollama.py +123 -0
  60. {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.15.0.dist-info}/METADATA +42 -23
  61. aiagents4pharma-1.15.0.dist-info/RECORD +102 -0
  62. aiagents4pharma/talk2biomodels/tests/test_langgraph.py +0 -240
  63. aiagents4pharma-1.9.0.dist-info/RECORD +0 -62
  64. {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.15.0.dist-info}/LICENSE +0 -0
  65. {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.15.0.dist-info}/WHEEL +0 -0
  66. {aiagents4pharma-1.9.0.dist-info → aiagents4pharma-1.15.0.dist-info}/top_level.txt +0 -0
aiagents4pharma/talk2competitors/tools/s2/multi_paper_rec.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+
+"""
+multi_paper_rec: Tool for getting recommendations
+based on multiple papers
+"""
+
+import json
+import logging
+from typing import Annotated, Any, Dict, List, Optional
+
+import pandas as pd
+import requests
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import tool
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+
+
+class MultiPaperRecInput(BaseModel):
+    """Input schema for multiple paper recommendations tool."""
+
+    paper_ids: List[str] = Field(
+        description="List of Semantic Scholar Paper IDs to get recommendations for"
+    )
+    limit: int = Field(
+        default=2,
+        description="Maximum total number of recommendations to return",
+        ge=1,
+        le=500,
+    )
+    year: Optional[str] = Field(
+        default=None,
+        description="Year range in format: YYYY for specific year, "
+        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+
+    model_config = {"arbitrary_types_allowed": True}
+
+
+@tool(args_schema=MultiPaperRecInput)
+def get_multi_paper_recommendations(
+    paper_ids: List[str],
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    limit: int = 2,
+    year: Optional[str] = None,
+) -> Dict[str, Any]:
+    """
+    Get paper recommendations based on multiple papers.
+
+    Args:
+        paper_ids (List[str]): The list of paper IDs to base recommendations on.
+        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
+        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
+        year (str, optional): Year range for papers.
+            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.
+
+    Returns:
+        Dict[str, Any]: The recommendations and related information.
+    """
+    logging.info("Starting multi-paper recommendations search.")
+
+    endpoint = "https://api.semanticscholar.org/recommendations/v1/papers"
+    headers = {"Content-Type": "application/json"}
+    payload = {"positivePaperIds": paper_ids, "negativePaperIds": []}
+    params = {
+        "limit": min(limit, 500),
+        "fields": "paperId,title,abstract,year,authors,citationCount,url",
+    }
+
+    # Add year parameter if provided
+    if year:
+        params["year"] = year
+
+    # Getting recommendations
+    response = requests.post(
+        endpoint,
+        headers=headers,
+        params=params,
+        data=json.dumps(payload),
+        timeout=10,
+    )
+    logging.info(
+        "API Response Status for multi-paper recommendations: %s", response.status_code
+    )
+
+    data = response.json()
+    recommendations = data.get("recommendedPapers", [])
+
+    # Create a dictionary to store the papers
+    filtered_papers = {
+        paper["paperId"]: {
+            "Title": paper.get("title", "N/A"),
+            "Abstract": paper.get("abstract", "N/A"),
+            "Year": paper.get("year", "N/A"),
+            "Citation Count": paper.get("citationCount", "N/A"),
+            "URL": paper.get("url", "N/A"),
+        }
+        for paper in recommendations
+        if paper.get("title") and paper.get("paperId")
+    }
+
+    # Create a DataFrame from the dictionary
+    df = pd.DataFrame.from_dict(filtered_papers, orient="index")
+    # print("Created DataFrame with results:")
+    logging.info("Created DataFrame with results: %s", df)
+
+    # Format papers for state update
+    papers = [
+        f"Paper ID: {paper_id}\n"
+        f"Title: {paper_data['Title']}\n"
+        f"Abstract: {paper_data['Abstract']}\n"
+        f"Year: {paper_data['Year']}\n"
+        f"Citations: {paper_data['Citation Count']}\n"
+        f"URL: {paper_data['URL']}\n"
+        for paper_id, paper_data in filtered_papers.items()
+    ]
+
+    # Convert DataFrame to markdown table
+    markdown_table = df.to_markdown(tablefmt="grid")
+    logging.info("Search results: %s", papers)
+
+    return Command(
+        update={
+            "papers": filtered_papers,  # Now sending the dictionary directly
+            "messages": [
+                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
+            ],
+        }
+    )
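For orientation, a minimal sketch of invoking this tool directly (outside the s2_agent) follows. The paper ID and tool_call_id are hypothetical placeholders, and the call hits the live Semantic Scholar API; inside the agent, LangGraph injects tool_call_id automatically via InjectedToolCallId.

```python
# A minimal sketch, assuming the package is installed and the Semantic
# Scholar API is reachable; the paper ID and tool_call_id are placeholders.
from aiagents4pharma.talk2competitors.tools.s2.multi_paper_rec import (
    get_multi_paper_recommendations,
)

command = get_multi_paper_recommendations.invoke(
    {
        "paper_ids": ["649def34f8be52c8b66281af98ae884c09aef38b"],
        "limit": 2,
        "tool_call_id": "demo_call_001",  # normally injected by LangGraph
    }
)
# The returned Command carries the state update: the filtered papers dict
# and a ToolMessage whose content is the grid-formatted markdown table.
print(command.update["messages"][0].content)
```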
aiagents4pharma/talk2competitors/tools/s2/search.py
@@ -0,0 +1,119 @@
+#!/usr/bin/env python3
+
+"""
+This tool is used to search for academic papers on Semantic Scholar.
+"""
+
+import logging
+from typing import Annotated, Any, Dict, Optional
+
+import pandas as pd
+import requests
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import tool
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+
+
+class SearchInput(BaseModel):
+    """Input schema for the search papers tool."""
+
+    query: str = Field(
+        description="Search query string to find academic papers. "
+        "Be specific and include relevant academic terms."
+    )
+    limit: int = Field(
+        default=2, description="Maximum number of results to return", ge=1, le=100
+    )
+    year: Optional[str] = Field(
+        default=None,
+        description="Year range in format: YYYY for specific year, "
+        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+
+
+@tool(args_schema=SearchInput)
+def search_tool(
+    query: str,
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    limit: int = 2,
+    year: Optional[str] = None,
+) -> Dict[str, Any]:
+    """
+    Search for academic papers on Semantic Scholar.
+
+    Args:
+        query (str): The search query string to find academic papers.
+        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
+        limit (int, optional): The maximum number of results to return. Defaults to 2.
+        year (str, optional): Year range for papers.
+            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.
+
+    Returns:
+        Dict[str, Any]: The search results and related information.
+    """
+    logging.info("Starting paper search.")
+    endpoint = "https://api.semanticscholar.org/graph/v1/paper/search"
+    params = {
+        "query": query,
+        "limit": min(limit, 100),
+        # "fields": "paperId,title,abstract,year,authors,
+        # citationCount,url,publicationTypes,openAccessPdf",
+        "fields": "paperId,title,abstract,year,authors,citationCount,url",
+    }
+
+    # Add year parameter if provided
+    if year:
+        params["year"] = year
+
+    response = requests.get(endpoint, params=params, timeout=10)
+    data = response.json()
+    papers = data.get("data", [])
+
+    # Create a dictionary to store the papers
+    filtered_papers = {
+        paper["paperId"]: {
+            "Title": paper.get("title", "N/A"),
+            "Abstract": paper.get("abstract", "N/A"),
+            "Year": paper.get("year", "N/A"),
+            "Citation Count": paper.get("citationCount", "N/A"),
+            "URL": paper.get("url", "N/A"),
+            # "Publication Type": paper.get("publicationTypes", ["N/A"])[0]
+            # if paper.get("publicationTypes")
+            # else "N/A",
+            # "Open Access PDF": paper.get("openAccessPdf", {}).get("url", "N/A")
+            # if paper.get("openAccessPdf") is not None
+            # else "N/A",
+        }
+        for paper in papers
+        if paper.get("title") and paper.get("authors")
+    }
+
+    df = pd.DataFrame(filtered_papers)
+
+    # Format papers for state update
+    papers = [
+        f"Paper ID: {paper_id}\n"
+        f"Title: {paper_data['Title']}\n"
+        f"Abstract: {paper_data['Abstract']}\n"
+        f"Year: {paper_data['Year']}\n"
+        f"Citations: {paper_data['Citation Count']}\n"
+        f"URL: {paper_data['URL']}\n"
+        # f"Publication Type: {paper_data['Publication Type']}\n"
+        # f"Open Access PDF: {paper_data['Open Access PDF']}"
+        for paper_id, paper_data in filtered_papers.items()
+    ]
+
+    markdown_table = df.to_markdown(tablefmt="grid")
+    logging.info("Search results: %s", papers)
+
+    return Command(
+        update={
+            "papers": filtered_papers,  # Now sending the dictionary directly
+            "messages": [
+                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
+            ],
+        }
+    )
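search.py follows the same pattern as multi_paper_rec.py above: instead of returning results directly, it returns a Command(update=...) that LangGraph merges into the shared agent state. A rough sketch of the state shape these updates assume is shown below; the field names are inferred from the update keys, and the packaged definition lives in state/state_talk2competitors.py (file 44 in the list above).

```python
# Rough sketch of the shared state the s2 tools write into (assumed shape,
# inferred from the Command update keys; see state/state_talk2competitors.py).
from typing import Annotated, Any, Dict

from langgraph.graph.message import add_messages
from typing_extensions import TypedDict


class Talk2CompetitorsState(TypedDict):
    """Shared agent state: a papers dict plus an append-only message log."""

    papers: Dict[str, Any]  # paperId -> {"Title": ..., "Abstract": ..., ...}
    messages: Annotated[list, add_messages]  # ToolMessages appended by reducer
```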
aiagents4pharma/talk2competitors/tools/s2/single_paper_rec.py
@@ -0,0 +1,138 @@
+#!/usr/bin/env python3
+
+"""
+This tool is used to return recommendations for a single paper.
+"""
+
+import logging
+from typing import Annotated, Any, Dict, Optional
+
+import pandas as pd
+import requests
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import tool
+from langchain_core.tools.base import InjectedToolCallId
+from langgraph.types import Command
+from pydantic import BaseModel, Field
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class SinglePaperRecInput(BaseModel):
+    """Input schema for single paper recommendation tool."""
+
+    paper_id: str = Field(
+        description="Semantic Scholar Paper ID to get recommendations for (40-character string)"
+    )
+    limit: int = Field(
+        default=2,
+        description="Maximum number of recommendations to return",
+        ge=1,
+        le=500,
+    )
+    year: Optional[str] = Field(
+        default=None,
+        description="Year range in format: YYYY for specific year, "
+        "YYYY- for papers after year, -YYYY for papers before year, or YYYY:YYYY for range",
+    )
+    tool_call_id: Annotated[str, InjectedToolCallId]
+    model_config = {"arbitrary_types_allowed": True}
+
+
+@tool(args_schema=SinglePaperRecInput)
+def get_single_paper_recommendations(
+    paper_id: str,
+    tool_call_id: Annotated[str, InjectedToolCallId],
+    limit: int = 2,
+    year: Optional[str] = None,
+) -> Dict[str, Any]:
+    """
+    Get paper recommendations based on a single paper.
+
+    Args:
+        paper_id (str): The Semantic Scholar Paper ID to get recommendations for.
+        tool_call_id (Annotated[str, InjectedToolCallId]): The tool call ID.
+        limit (int, optional): The maximum number of recommendations to return. Defaults to 2.
+        year (str, optional): Year range for papers.
+            Supports formats like "2024-", "-2024", "2024:2025". Defaults to None.
+
+    Returns:
+        Dict[str, Any]: The recommendations and related information.
+    """
+    logger.info("Starting single paper recommendations search.")
+
+    endpoint = (
+        f"https://api.semanticscholar.org/recommendations/v1/papers/forpaper/{paper_id}"
+    )
+    params = {
+        "limit": min(limit, 500),  # Max 500 per API docs
+        "fields": "paperId,title,abstract,year,authors,citationCount,url",
+        "from": "all-cs",  # Using all-cs pool as specified in docs
+    }
+
+    # Add year parameter if provided
+    if year:
+        params["year"] = year
+
+    response = requests.get(endpoint, params=params, timeout=10)
+    # print(f"API Response Status: {response.status_code}")
+    logger.info(
+        "API Response Status for recommendations of paper %s: %s",
+        paper_id,
+        response.status_code,
+    )
+    # print(f"Request params: {params}")
+    logger.info("Request params: %s", params)
+
+    data = response.json()
+    recommendations = data.get("recommendedPapers", [])
+
+    # Extract paper ID and title from recommendations
+    filtered_papers = {
+        paper["paperId"]: {
+            "Title": paper.get("title", "N/A"),
+            "Abstract": paper.get("abstract", "N/A"),
+            "Year": paper.get("year", "N/A"),
+            "Citation Count": paper.get("citationCount", "N/A"),
+            "URL": paper.get("url", "N/A"),
+            # "Publication Type": paper.get("publicationTypes", ["N/A"])[0]
+            # if paper.get("publicationTypes")
+            # else "N/A",
+            # "Open Access PDF": paper.get("openAccessPdf", {}).get("url", "N/A")
+            # if paper.get("openAccessPdf") is not None
+            # else "N/A",
+        }
+        for paper in recommendations
+        if paper.get("title") and paper.get("authors")
+    }
+
+    # Create a DataFrame for pretty printing
+    df = pd.DataFrame(filtered_papers)
+
+    # Format papers for state update
+    papers = [
+        f"Paper ID: {paper_id}\n"
+        f"Title: {paper_data['Title']}\n"
+        f"Abstract: {paper_data['Abstract']}\n"
+        f"Year: {paper_data['Year']}\n"
+        f"Citations: {paper_data['Citation Count']}\n"
+        f"URL: {paper_data['URL']}\n"
+        # f"Publication Type: {paper_data['Publication Type']}\n"
+        # f"Open Access PDF: {paper_data['Open Access PDF']}"
+        for paper_id, paper_data in filtered_papers.items()
+    ]
+
+    # Convert DataFrame to markdown table
+    markdown_table = df.to_markdown(tablefmt="grid")
+    logger.info("Search results: %s", papers)
+
+    return Command(
+        update={
+            "papers": filtered_papers,  # Now sending the dictionary directly
+            "messages": [
+                ToolMessage(content=markdown_table, tool_call_id=tool_call_id)
+            ],
+        }
+    )
aiagents4pharma/talk2knowledgegraphs/__init__.py
@@ -1,4 +1,5 @@
 '''
-This file is used to import the datasets, utils, and tools.
+This file is used to import the datasets and utils.
 '''
 from . import datasets
+from . import utils
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_enrichments.py
@@ -0,0 +1,39 @@
+"""
+Test cases for utils/enrichments/enrichments.py
+"""
+
+from ..utils.enrichments.enrichments import Enrichments
+
+class TestEnrichments(Enrichments):
+    """Test implementation of the Enrichments interface for testing purposes."""
+
+    def enrich_documents(self, texts: list[str]) -> list[str]:
+        return [
+            f"Additional text description of {text} as the input." for text in texts
+        ]
+
+    def enrich_documents_with_rag(self, texts, docs):
+        # Currently we don't have a RAG model to test this method.
+        # Thus, we will just call the enrich_documents method instead.
+        return self.enrich_documents(texts)
+
+def test_enrich_documents():
+    """Test enriching documents using the Enrichments interface."""
+    enrichments = TestEnrichments()
+    texts = ["text1", "text2"]
+    result = enrichments.enrich_documents(texts)
+    assert result == [
+        "Additional text description of text1 as the input.",
+        "Additional text description of text2 as the input.",
+    ]
+
+def test_enrich_documents_with_rag():
+    """Test enriching documents with RAG using the Enrichments interface."""
+    enrichments = TestEnrichments()
+    texts = ["text1", "text2"]
+    docs = ["doc1", "doc2"]
+    result = enrichments.enrich_documents_with_rag(texts, docs)
+    assert result == [
+        "Additional text description of text1 as the input.",
+        "Additional text description of text2 as the input.",
+    ]
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py
@@ -0,0 +1,117 @@
+"""
+Test cases for utils/enrichments/ollama.py
+"""
+
+import pytest
+import ollama
+from ..utils.enrichments.ollama import EnrichmentWithOllama
+
+@pytest.fixture(name="ollama_config")
+def fixture_ollama_config():
+    """Return a dictionary with Ollama configuration."""
+    return {
+        "model_name": "smollm2:360m",
+        "prompt_enrichment": """
+            Given the input as a list of strings, please return the list of additional information of
+            each input term using your prior knowledge.
+
+            Example:
+            Input: ['acetaminophen', 'aspirin']
+            Output: ['acetaminophen is a medication used to treat pain and fever',
+            'aspirin is a medication used to treat pain, fever, and inflammation']
+
+            Do not include any pretext as the output, only the list of strings enriched.
+
+            Input: {input}
+        """,
+        "temperature": 0.0,
+        "streaming": False,
+    }
+
+def test_no_model_ollama(ollama_config):
+    """Test the case when the Ollama model is not available."""
+    cfg = ollama_config
+    cfg_model = "smollm2:135m"  # Choose a small model
+
+    # Delete the Ollama model
+    try:
+        ollama.delete(cfg_model)
+    except ollama.ResponseError:
+        pass
+
+    # Check if the model is available
+    with pytest.raises(
+        ValueError, match=f"Error: Pulled {cfg_model} model and restarted Ollama server."
+    ):
+        EnrichmentWithOllama(
+            model_name=cfg_model,
+            prompt_enrichment=cfg["prompt_enrichment"],
+            temperature=cfg["temperature"],
+            streaming=cfg["streaming"],
+        )
+    ollama.delete(cfg_model)
+
+def test_enrich_nodes_ollama(ollama_config):
+    """Test the Ollama textual enrichment class for node enrichment."""
+    # Prepare enrichment model
+    cfg = ollama_config
+    enr_model = EnrichmentWithOllama(
+        model_name=cfg["model_name"],
+        prompt_enrichment=cfg["prompt_enrichment"],
+        temperature=cfg["temperature"],
+        streaming=cfg["streaming"],
+    )
+
+    # Perform enrichment for nodes
+    nodes = ["Adalimumab", "Infliximab"]
+    enriched_nodes = enr_model.enrich_documents(nodes)
+    # Check the enriched nodes
+    assert len(enriched_nodes) == 2
+    assert all(
+        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
+    )
+
+
+def test_enrich_relations_ollama(ollama_config):
+    """Test the Ollama textual enrichment class for relation enrichment."""
+    # Prepare enrichment model
+    cfg = ollama_config
+    enr_model = EnrichmentWithOllama(
+        model_name=cfg["model_name"],
+        prompt_enrichment=cfg["prompt_enrichment"],
+        temperature=cfg["temperature"],
+        streaming=cfg["streaming"],
+    )
+    # Perform enrichment for relations
+    relations = [
+        "IL23R-gene causation disease-inflammatory bowel diseases",
+        "NOD2-gene causation disease-inflammatory bowel diseases",
+    ]
+    enriched_relations = enr_model.enrich_documents(relations)
+    # Check the enriched relations
+    assert len(enriched_relations) == 2
+    assert all(
+        enriched_relations[i] != relations[i]
+        for i in range(len(relations))
+    )
+
+
+def test_enrich_ollama_rag(ollama_config):
+    """Test the Ollama textual enrichment class for enrichment with RAG (not implemented)."""
+    # Prepare enrichment model
+    cfg = ollama_config
+    enr_model = EnrichmentWithOllama(
+        model_name=cfg["model_name"],
+        prompt_enrichment=cfg["prompt_enrichment"],
+        temperature=cfg["temperature"],
+        streaming=cfg["streaming"],
+    )
+    # Perform enrichment for nodes
+    nodes = ["Adalimumab", "Infliximab"]
+    docs = [r"\path\to\doc1", r"\path\to\doc2"]
+    enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
+    # Check the enriched nodes
+    assert len(enriched_nodes) == 2
+    assert all(
+        enriched_nodes[i] != nodes[i] for i in range(len(nodes))
+    )
aiagents4pharma/talk2knowledgegraphs/utils/__init__.py
@@ -0,0 +1,5 @@
+'''
+This file is used to import utilities.
+'''
+from . import enrichments
+from . import embeddings
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/__init__.py
@@ -0,0 +1,5 @@
+"""
+This package contains modules to use the enrichment model.
+"""
+from . import enrichments
+from . import ollama
aiagents4pharma/talk2knowledgegraphs/utils/enrichments/enrichments.py
@@ -0,0 +1,36 @@
+"""
+Enrichments interface
+"""
+
+from abc import ABC, abstractmethod
+
+class Enrichments(ABC):
+    """Interface for enrichment models.
+
+    This is an interface meant for implementing text enrichment models.
+
+    Enrichment models are used to enrich node or relation features in a given knowledge graph.
+    """
+
+    @abstractmethod
+    def enrich_documents(self, texts: list[str]) -> list[str]:
+        """Enrich documents.
+
+        Args:
+            texts: List of documents to enrich.
+
+        Returns:
+            List of enriched documents.
+        """
+
+    @abstractmethod
+    def enrich_documents_with_rag(self, texts: list[str], docs: list[str]) -> list[str]:
+        """Enrich documents with RAG.
+
+        Args:
+            texts: List of documents to enrich.
+            docs: List of reference documents to enrich the input texts.
+
+        Returns:
+            List of enriched documents with RAG.
+        """