aiagents4pharma 1.17.1__py3-none-any.whl → 1.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. aiagents4pharma/talk2biomodels/agents/t2b_agent.py +4 -4
  2. aiagents4pharma/talk2biomodels/configs/tools/ask_question/default.yaml +7 -15
  3. aiagents4pharma/talk2biomodels/states/state_talk2biomodels.py +4 -1
  4. aiagents4pharma/talk2biomodels/tests/test_ask_question.py +4 -2
  5. aiagents4pharma/talk2biomodels/tests/test_get_annotation.py +4 -2
  6. aiagents4pharma/talk2biomodels/tests/test_integration.py +34 -30
  7. aiagents4pharma/talk2biomodels/tests/test_query_article.py +7 -1
  8. aiagents4pharma/talk2biomodels/tests/test_search_models.py +3 -1
  9. aiagents4pharma/talk2biomodels/tests/test_steady_state.py +6 -3
  10. aiagents4pharma/talk2biomodels/tools/ask_question.py +1 -2
  11. aiagents4pharma/talk2biomodels/tools/custom_plotter.py +23 -10
  12. aiagents4pharma/talk2biomodels/tools/get_annotation.py +11 -10
  13. aiagents4pharma/talk2biomodels/tools/query_article.py +6 -2
  14. aiagents4pharma/talk2biomodels/tools/search_models.py +8 -2
  15. aiagents4pharma/talk2knowledgegraphs/__init__.py +3 -0
  16. aiagents4pharma/talk2knowledgegraphs/agents/__init__.py +4 -0
  17. aiagents4pharma/talk2knowledgegraphs/agents/t2kg_agent.py +85 -0
  18. aiagents4pharma/talk2knowledgegraphs/configs/__init__.py +7 -0
  19. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/__init__.py +3 -0
  20. aiagents4pharma/talk2knowledgegraphs/configs/agents/t2kg_agent/default.yaml +62 -0
  21. aiagents4pharma/talk2knowledgegraphs/configs/app/__init__.py +4 -0
  22. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/__init__.py +3 -0
  23. aiagents4pharma/talk2knowledgegraphs/configs/app/frontend/default.yaml +31 -0
  24. aiagents4pharma/talk2knowledgegraphs/configs/config.yaml +7 -0
  25. aiagents4pharma/talk2knowledgegraphs/configs/tools/__init__.py +6 -0
  26. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/__init__.py +3 -0
  27. aiagents4pharma/talk2knowledgegraphs/configs/tools/graphrag_reasoning/default.yaml +24 -0
  28. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/__init__.py +3 -0
  29. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_extraction/default.yaml +43 -0
  30. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/__init__.py +3 -0
  31. aiagents4pharma/talk2knowledgegraphs/configs/tools/subgraph_summarization/default.yaml +9 -0
  32. aiagents4pharma/talk2knowledgegraphs/states/__init__.py +4 -0
  33. aiagents4pharma/talk2knowledgegraphs/states/state_talk2knowledgegraphs.py +38 -0
  34. aiagents4pharma/talk2knowledgegraphs/tests/test_agents_t2kg_agent.py +110 -0
  35. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_graphrag_reasoning.py +210 -0
  36. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py +174 -0
  37. aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py +154 -0
  38. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py +0 -1
  39. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py +56 -0
  40. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py +18 -42
  41. aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py +79 -0
  42. aiagents4pharma/talk2knowledgegraphs/tools/__init__.py +6 -0
  43. aiagents4pharma/talk2knowledgegraphs/tools/graphrag_reasoning.py +143 -0
  44. aiagents4pharma/talk2knowledgegraphs/tools/load_arguments.py +22 -0
  45. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_extraction.py +305 -0
  46. aiagents4pharma/talk2knowledgegraphs/tools/subgraph_summarization.py +126 -0
  47. aiagents4pharma/talk2knowledgegraphs/utils/__init__.py +4 -2
  48. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/__init__.py +1 -0
  49. aiagents4pharma/talk2knowledgegraphs/utils/embeddings/ollama.py +81 -0
  50. aiagents4pharma/talk2knowledgegraphs/utils/extractions/__init__.py +4 -0
  51. aiagents4pharma/talk2knowledgegraphs/utils/extractions/pcst.py +225 -0
  52. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/METADATA +12 -3
  53. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/RECORD +56 -24
  54. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/LICENSE +0 -0
  55. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/WHEEL +0 -0
  56. {aiagents4pharma-1.17.1.dist-info → aiagents4pharma-1.19.0.dist-info}/top_level.txt +0 -0
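Most of the changes in 1.19.0 introduce the new talk2knowledgegraphs sub-package: an agent (t2kg_agent), tools for subgraph extraction, subgraph summarization, and GraphRAG reasoning, Ollama embedding utilities, PCST extraction helpers, and their configuration defaults and tests. The hunks below show the new and updated test modules. They all drive the agent through the same three-step pattern; the following is a minimal sketch of that pattern, assuming only the get_app signature and state keys visible in the tests (model choice and prompt are illustrative):

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from aiagents4pharma.talk2knowledgegraphs.agents.t2kg_agent import get_app

# 1. Compile the agent app for a session/thread id.
unique_id = 12345
app = get_app(unique_id, llm_model=ChatOpenAI(model="gpt-4o-mini", temperature=0.0))
config = {"configurable": {"thread_id": unique_id}}

# 2. Seed the agent state (source-graph paths, top-k budgets, uploaded files, ...).
app.update_state(config, {"topk_nodes": 3, "topk_edges": 3, "uploaded_files": []})

# 3. Invoke the agent with a user prompt; the LLM decides which tool to call.
response = app.invoke({"messages": [HumanMessage(content="...")]}, config=config)
print(response["messages"][-1].content)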
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_extraction.py
@@ -0,0 +1,174 @@
+ """
+ Test cases for tools/subgraph_extraction.py
+ """
+
+ import pytest
+ from langchain_core.messages import HumanMessage
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+ from ..agents.t2kg_agent import get_app
+
+ # Define the data path
+ DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
+
+
+ @pytest.fixture(name="input_dict")
+ def input_dict_fixture():
+     """
+     Input dictionary fixture.
+     """
+     input_dict = {
+         "llm_model": None, # TBA for each test case
+         "embedding_model": None, # TBA for each test case
+         "uploaded_files": [],
+         "topk_nodes": 3,
+         "topk_edges": 3,
+         "dic_source_graph": [
+             {
+                 "name": "PrimeKG",
+                 "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
+                 "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
+             }
+         ],
+     }
+
+     return input_dict
+
+
+ def test_extract_subgraph_wo_docs(input_dict):
+     """
+     Test the subgraph extraction tool without any documents using OpenAI model.
+
+     Args:
+         input_dict: Input dictionary.
+     """
+     # Prepare LLM and embedding model
+     input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
+     input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
+
+     # Setup the app
+     unique_id = 12345
+     app = get_app(unique_id, llm_model=input_dict["llm_model"])
+     config = {"configurable": {"thread_id": unique_id}}
+     # Update state
+     app.update_state(
+         config,
+         input_dict,
+     )
+     prompt = """
+     Please directly invoke `subgraph_extraction` tool without calling any other tools
+     to respond to the following prompt:
+
+     Extract all relevant information related to nodes of genes related to inflammatory bowel disease
+     (IBD) that existed in the knowledge graph.
+     Please set the extraction name for this process as `subkg_12345`.
+     """
+
+     # Test the tool subgraph_extraction
+     response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
+
+     # Check assistant message
+     assistant_msg = response["messages"][-1].content
+     assert isinstance(assistant_msg, str)
+
+     # Check tool message
+     tool_msg = response["messages"][-2]
+     assert tool_msg.name == "subgraph_extraction"
+
+     # Check extracted subgraph dictionary
+     current_state = app.get_state(config)
+     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
+     assert isinstance(dic_extracted_graph, dict)
+     assert dic_extracted_graph["name"] == "subkg_12345"
+     assert dic_extracted_graph["graph_source"] == "PrimeKG"
+     assert dic_extracted_graph["topk_nodes"] == 3
+     assert dic_extracted_graph["topk_edges"] == 3
+     assert isinstance(dic_extracted_graph["graph_dict"], dict)
+     assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
+     assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
+     assert isinstance(dic_extracted_graph["graph_text"], str)
+     # Check if the nodes are in the graph_text
+     assert all(
+         n[0] in dic_extracted_graph["graph_text"]
+         for n in dic_extracted_graph["graph_dict"]["nodes"]
+     )
+     # Check if the edges are in the graph_text
+     assert all(
+         ",".join([e[0], '"' + str(tuple(e[2]["relation"])) + '"', e[1]])
+         in dic_extracted_graph["graph_text"]
+         for e in dic_extracted_graph["graph_dict"]["edges"]
+     )
+
+
+ def test_extract_subgraph_w_docs(input_dict):
+     """
+     Test the subgraph extraction tool with a document as reference (i.e., endotype document)
+     using OpenAI model.
+
+     Args:
+         input_dict: Input dictionary.
+     """
+     # Prepare LLM and embedding model
+     input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
+     input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
+
+     # Setup the app
+     unique_id = 12345
+     app = get_app(unique_id, llm_model=input_dict["llm_model"])
+     config = {"configurable": {"thread_id": unique_id}}
+     # Update state
+     input_dict["uploaded_files"] = [
+         {
+             "file_name": "DGE_human_Colon_UC-vs-Colon_Control.pdf",
+             "file_path": f"{DATA_PATH}/DGE_human_Colon_UC-vs-Colon_Control.pdf",
+             "file_type": "endotype",
+             "uploaded_by": "VPEUser",
+             "uploaded_timestamp": "2024-11-05 00:00:00",
+         }
+     ]
+     app.update_state(
+         config,
+         input_dict,
+     )
+     prompt = """
+     Please ONLY invoke `subgraph_extraction` tool without calling any other tools
+     to respond to the following prompt:
+
+     Extract all relevant information related to nodes of genes related to inflammatory bowel disease
+     (IBD) that existed in the knowledge graph.
+     Please set the extraction name for this process as `subkg_12345`.
+     """
+
+     # Test the tool subgraph_extraction
+     response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
+
+     # Check assistant message
+     assistant_msg = response["messages"][-1].content
+     assert isinstance(assistant_msg, str)
+
+     # Check tool message
+     tool_msg = response["messages"][-2]
+     assert tool_msg.name == "subgraph_extraction"
+
+     # Check extracted subgraph dictionary
+     current_state = app.get_state(config)
+     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
+     assert isinstance(dic_extracted_graph, dict)
+     assert dic_extracted_graph["name"] == "subkg_12345"
+     assert dic_extracted_graph["graph_source"] == "PrimeKG"
+     assert dic_extracted_graph["topk_nodes"] == 3
+     assert dic_extracted_graph["topk_edges"] == 3
+     assert isinstance(dic_extracted_graph["graph_dict"], dict)
+     assert len(dic_extracted_graph["graph_dict"]["nodes"]) > 0
+     assert len(dic_extracted_graph["graph_dict"]["edges"]) > 0
+     assert isinstance(dic_extracted_graph["graph_text"], str)
+     # Check if the nodes are in the graph_text
+     assert all(
+         n[0] in dic_extracted_graph["graph_text"]
+         for n in dic_extracted_graph["graph_dict"]["nodes"]
+     )
+     # Check if the edges are in the graph_text
+     assert all(
+         ",".join([e[0], '"' + str(tuple(e[2]["relation"])) + '"', e[1]])
+         in dic_extracted_graph["graph_text"]
+         for e in dic_extracted_graph["graph_dict"]["edges"]
+     )
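The assertions above also pin down the record that subgraph_extraction writes into the agent state. A minimal sketch of one entry of state["dic_extracted_graph"], inferred only from those assertions (the values shown are the ones the test checks, not a complete schema):

# Hypothetical shape of one extracted-subgraph record, per the assertions above.
dic_extracted_graph = {
    "name": "subkg_12345",        # extraction name requested in the prompt
    "graph_source": "PrimeKG",    # source knowledge graph
    "topk_nodes": 3,              # node budget used for the extraction
    "topk_edges": 3,              # edge budget used for the extraction
    "graph_dict": {               # node/edge lists in networkx-like form
        "nodes": [("IFNG_(3495)", {})],
        "edges": [("IFNG_(3495)", "inflammatory bowel disease_(28158)",
                   {"relation": ["gene/protein", "associated with", "disease"]})],
    },
    "graph_text": "node_id,node_attr\n...",  # textualized subgraph used downstream
}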
aiagents4pharma/talk2knowledgegraphs/tests/test_tools_subgraph_summarization.py
@@ -0,0 +1,154 @@
+ """
+ Test cases for tools/subgraph_summarization.py
+ """
+
+ import pytest
+ from langchain_core.messages import HumanMessage
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+ from ..agents.t2kg_agent import get_app
+
+ # Define the data path
+ DATA_PATH = "aiagents4pharma/talk2knowledgegraphs/tests/files"
+
+
+ @pytest.fixture(name="input_dict")
+ def input_dict_fixture():
+     """
+     Input dictionary fixture.
+     """
+     input_dict = {
+         "llm_model": None, # TBA for each test case
+         "embedding_model": None, # TBA for each test case
+         "uploaded_files": [],
+         "topk_nodes": 3,
+         "topk_edges": 3,
+         "dic_source_graph": [
+             {
+                 "name": "PrimeKG",
+                 "kg_pyg_path": f"{DATA_PATH}/primekg_ibd_pyg_graph.pkl",
+                 "kg_text_path": f"{DATA_PATH}/primekg_ibd_text_graph.pkl",
+             }
+         ],
+         "dic_extracted_graph": [
+             {
+                 "name": "subkg_12345",
+                 "tool_call_id": "tool_12345",
+                 "graph_source": "PrimeKG",
+                 "topk_nodes": 3,
+                 "topk_edges": 3,
+                 "graph_dict": {
+                     'nodes': [('IFNG_(3495)', {}),
+                               ('IKBKG_(3672)', {}),
+                               ('ATG16L1_(6661)', {}),
+                               ('inflammatory bowel disease_(28158)', {}),
+                               ('Crohn ileitis and jejunitis_(35814)', {}),
+                               ("Crohn's colitis_(83770)", {})],
+                     'edges': [('IFNG_(3495)', 'inflammatory bowel disease_(28158)',
+                                {'relation': ['gene/protein', 'associated with', 'disease'],
+                                 'label': ['gene/protein', 'associated with', 'disease']}),
+                               ('IFNG_(3495)', "Crohn's colitis_(83770)",
+                                {'relation': ['gene/protein', 'associated with', 'disease'],
+                                 'label': ['gene/protein', 'associated with', 'disease']}),
+                               ('IFNG_(3495)', 'Crohn ileitis and jejunitis_(35814)',
+                                {'relation': ['gene/protein', 'associated with', 'disease'],
+                                 'label': ['gene/protein', 'associated with', 'disease']}),
+                               ('ATG16L1_(6661)', 'IKBKG_(3672)',
+                                {'relation': ['gene/protein', 'ppi', 'gene/protein'],
+                                 'label': ['gene/protein', 'ppi', 'gene/protein']}),
+                               ("Crohn's colitis_(83770)", 'ATG16L1_(6661)',
+                                {'relation': ['disease', 'associated with', 'gene/protein'],
+                                 'label': ['disease', 'associated with', 'gene/protein']})]},
+                 "graph_text": """
+ node_id,node_attr
+ IFNG_(3495),"IFNG belongs to gene/protein category.
+ This gene encodes a soluble cytokine that is a member of the type II interferon class.
+ The encoded protein is secreted by cells of both the innate and adaptive immune systems.
+ The active protein is a homodimer that binds to the interferon gamma receptor
+ which triggers a cellular response to viral and microbial infections.
+ Mutations in this gene are associated with an increased susceptibility to viral,
+ bacterial and parasitic infections and to several autoimmune diseases.
+ [provided by RefSeq, Dec 2015]."
+ IKBKG_(3672),"IKBKG belongs to gene/protein category. This gene encodes the regulatory
+ subunit of the inhibitor of kappaB kinase (IKK) complex, which activates NF-kappaB
+ resulting in activation of genes involved in inflammation, immunity, cell survival,
+ and other pathways. Mutations in this gene result in incontinentia pigmenti,
+ hypohidrotic ectodermal dysplasia, and several other types of immunodeficiencies.
+ A pseudogene highly similar to this locus is located in an adjacent region of the
+ X chromosome. [provided by RefSeq, Mar 2016]."
+ ATG16L1_(6661),"ATG16L1 belongs to gene/protein category. The protein encoded
+ by this gene is part of a large protein complex that is necessary for autophagy,
+ the major process by which intracellular components are targeted to lysosomes
+ for degradation. Defects in this gene are a cause of susceptibility to inflammatory
+ bowel disease type 10 (IBD10). Several transcript variants encoding different
+ isoforms have been found for this gene.[provided by RefSeq, Jun 2010]."
+ inflammatory bowel disease_(28158),inflammatory bowel disease belongs to disease
+ category. Any inflammatory bowel disease in which the cause of the disease
+ is a mutation in the NOD2 gene.
+ Crohn ileitis and jejunitis_(35814),Crohn ileitis and jejunitis belongs to
+ disease category. An Crohn disease involving a pathogenic inflammatory
+ response in the ileum.
+ Crohn's colitis_(83770),Crohn's colitis belongs to disease category.
+ Crohn's disease affecting the colon.
+
+ head_id,edge_type,tail_id
+ Crohn's colitis_(83770),"('disease', 'associated with', 'gene/protein')",
+ ATG16L1_(6661)
+ ATG16L1_(6661),"('gene/protein', 'ppi', 'gene/protein')",IKBKG_(3672)
+ IFNG_(3495),"('gene/protein', 'associated with', 'disease')",
+ inflammatory bowel disease_(28158)
+ IFNG_(3495),"('gene/protein', 'associated with', 'disease')",Crohn's colitis_(83770)
+ IFNG_(3495),"('gene/protein', 'associated with', 'disease')",
+ Crohn ileitis and jejunitis_(35814)
+ """,
+                 "graph_summary": None,
+             }
+         ],
+     }
+
+     return input_dict
+
+
+ def test_summarize_subgraph(input_dict):
+     """
+     Test the subgraph summarization tool without any documents using Ollama model.
+
+     Args:
+         input_dict: Input dictionary fixture.
+     """
+     # Prepare LLM and embedding model
+     input_dict["llm_model"] = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
+     input_dict["embedding_model"] = OpenAIEmbeddings(model="text-embedding-3-small")
+
+     # Setup the app
+     unique_id = 12345
+     app = get_app(unique_id, llm_model=input_dict["llm_model"])
+     config = {"configurable": {"thread_id": unique_id}}
+     # Update state
+     app.update_state(
+         config,
+         input_dict,
+     )
+     prompt = """
+     Please directly invoke `subgraph_summarization` tool without calling any other tools
+     to respond to the following prompt:
+
+     You are given a subgraph in the forms of textualized subgraph representing
+     nodes and edges (triples) obtained from extraction_name `subkg_12345`.
+     Summarize the given subgraph and higlight the importance nodes and edges.
+     """
+
+     # Test the tool subgraph_summarization
+     response = app.invoke({"messages": [HumanMessage(content=prompt)]}, config=config)
+
+     # Check assistant message
+     assistant_msg = response["messages"][-1].content
+     assert isinstance(assistant_msg, str)
+
+     # Check tool message
+     tool_msg = response["messages"][-2]
+     assert tool_msg.name == "subgraph_summarization"
+
+     # Check summarized subgraph
+     current_state = app.get_state(config)
+     dic_extracted_graph = current_state.values["dic_extracted_graph"][0]
+     assert isinstance(dic_extracted_graph["graph_summary"], str)
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_huggingface.py
@@ -31,7 +31,6 @@ def test_embedding_with_huggingface_embed_query(embedding_model):
      # Check the result
      assert len(result) == 768

-
  def test_embedding_with_huggingface_failed():
      """Test embedding documents using the EmbeddingWithHuggingFace class."""
      # Check if the model is available on HuggingFace Hub
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_embeddings_ollama.py
@@ -0,0 +1,56 @@
+ """
+ Test cases for utils/embeddings/ollama.py
+ """
+
+ import pytest
+ import ollama
+ from ..utils.embeddings.ollama import EmbeddingWithOllama
+
+ @pytest.fixture(name="ollama_config")
+ def fixture_ollama_config():
+     """Return a dictionary with Ollama configuration."""
+     return {
+         "model_name": "all-minilm", # Choose a small model
+     }
+
+ def test_no_model_ollama(ollama_config):
+     """Test the case when the Ollama model is not available."""
+     cfg = ollama_config
+
+     # Delete the Ollama model
+     try:
+         ollama.delete(cfg["model_name"])
+     except ollama.ResponseError:
+         pass
+
+     # Check if the model is available
+     with pytest.raises(
+         ValueError, match=f"Error: Pulled {cfg["model_name"]} model and restarted Ollama server."
+     ):
+         EmbeddingWithOllama(model_name=cfg["model_name"])
+
+ @pytest.fixture(name="embedding_model")
+ def embedding_model_fixture(ollama_config):
+     """Return the configuration object for the Ollama embedding model and model object"""
+     cfg = ollama_config
+     return EmbeddingWithOllama(model_name=cfg["model_name"])
+
+ def test_embedding_with_ollama_embed_documents(embedding_model):
+     """Test embedding documents using the EmbeddingWithOllama class."""
+     # Perform embedding
+     texts = ["Adalimumab", "Infliximab", "Vedolizumab"]
+     result = embedding_model.embed_documents(texts)
+     # Check the result
+     assert len(result) == 3
+     assert len(result[0]) == 384
+
+ def test_embedding_with_ollama_embed_query(embedding_model):
+     """Test embedding a query using the EmbeddingWithOllama class."""
+     # Perform embedding
+     text = "Adalimumab"
+     result = embedding_model.embed_query(text)
+     # Check the result
+     assert len(result) == 384
+
+     # Delete the Ollama model so that it will not be cached afterward
+     ollama.delete(embedding_model.model_name)
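Beyond the failure path, the test above documents the expected behaviour of the new EmbeddingWithOllama wrapper: it follows the embed_documents/embed_query interface and, for the all-minilm model, returns 384-dimensional vectors. A minimal usage sketch, assuming a local Ollama server with that model available:

from aiagents4pharma.talk2knowledgegraphs.utils.embeddings.ollama import EmbeddingWithOllama

emb = EmbeddingWithOllama(model_name="all-minilm")

# Embed several documents at once (one 384-dim vector per input string).
vectors = emb.embed_documents(["Adalimumab", "Infliximab", "Vedolizumab"])
assert len(vectors) == 3 and len(vectors[0]) == 384

# Embed a single query string.
query_vector = emb.embed_query("Adalimumab")
assert len(query_vector) == 384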
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_enrichments_ollama.py
@@ -6,20 +6,21 @@ import pytest
  import ollama
  from ..utils.enrichments.ollama import EnrichmentWithOllama

+
  @pytest.fixture(name="ollama_config")
  def fixture_ollama_config():
      """Return a dictionary with Ollama configuration."""
      return {
-         "model_name": "smollm2:360m",
+         "model_name": "llama3.2:1b",
          "prompt_enrichment": """
-         Given the input as a list of strings, please return the list of addditional information of
-         each input terms using your prior knowledge.
+         Given the input as a list of strings, please return the list of addditional information
+         of each input terms using your prior knowledge.

          Example:
          Input: ['acetaminophen', 'aspirin']
-         Ouput: ['acetaminophen is a medication used to treat pain and fever',
+         Ouput: ['acetaminophen is a medication used to treat pain and fever',
          'aspirin is a medication used to treat pain, fever, and inflammation']
-
+
          Do not include any pretext as the output, only the list of strings enriched.

          Input: {input}
@@ -28,10 +29,11 @@ def fixture_ollama_config():
          "streaming": False,
      }

+
  def test_no_model_ollama(ollama_config):
      """Test the case when the Ollama model is not available."""
      cfg = ollama_config
-     cfg_model = "smollm2:135m" # Choose a small model
+     cfg_model = "smollm2:135m"  # Choose a small model

      # Delete the Ollama model
      try:
@@ -41,7 +43,8 @@ def test_no_model_ollama(ollama_config):

      # Check if the model is available
      with pytest.raises(
-         ValueError, match=f"Error: Pulled {cfg_model} model and restarted Ollama server."
+         ValueError,
+         match=f"Error: Pulled {cfg_model} model and restarted Ollama server.",
      ):
          EnrichmentWithOllama(
              model_name=cfg_model,
@@ -51,7 +54,8 @@ def test_no_model_ollama(ollama_config):
          )
      ollama.delete(cfg_model)

- def test_enrich_nodes_ollama(ollama_config):
+
+ def test_enrich_ollama(ollama_config):
      """Test the Ollama textual enrichment class for node enrichment."""
      # Prepare enrichment model
      cfg = ollama_config
@@ -63,37 +67,11 @@ def test_enrich_nodes_ollama(ollama_config):
      )

      # Perform enrichment for nodes
-     nodes = ["Adalimumab", "Infliximab"]
+     nodes = ["acetaminophen"]
      enriched_nodes = enr_model.enrich_documents(nodes)
      # Check the enriched nodes
-     assert len(enriched_nodes) == 2
-     assert all(
-         enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-     )
-
-
- def test_enrich_relations_ollama(ollama_config):
-     """Test the Ollama textual enrichment class for relation enrichment."""
-     # Prepare enrichment model
-     cfg = ollama_config
-     enr_model = EnrichmentWithOllama(
-         model_name=cfg["model_name"],
-         prompt_enrichment=cfg["prompt_enrichment"],
-         temperature=cfg["temperature"],
-         streaming=cfg["streaming"],
-     )
-     # Perform enrichment for relations
-     relations = [
-         "IL23R-gene causation disease-inflammatory bowel diseases",
-         "NOD2-gene causation disease-inflammatory bowel diseases",
-     ]
-     enriched_relations = enr_model.enrich_documents(relations)
-     # Check the enriched relations
-     assert len(enriched_relations) == 2
-     assert all(
-         enriched_relations[i] != relations[i]
-         for i in range(len(relations))
-     )
+     assert len(enriched_nodes) == 1
+     assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))


  def test_enrich_ollama_rag(ollama_config):
@@ -107,11 +85,9 @@ def test_enrich_ollama_rag(ollama_config):
          streaming=cfg["streaming"],
      )
      # Perform enrichment for nodes
-     nodes = ["Adalimumab", "Infliximab"]
+     nodes = ["acetaminophen"]
      docs = [r"\path\to\doc1", r"\path\to\doc2"]
      enriched_nodes = enr_model.enrich_documents_with_rag(nodes, docs)
      # Check the enriched nodes
-     assert len(enriched_nodes) == 2
-     assert all(
-         enriched_nodes[i] != nodes[i] for i in range(len(nodes))
-     )
+     assert len(enriched_nodes) == 1
+     assert all(enriched_nodes[i] != nodes[i] for i in range(len(nodes)))
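The enrichment tests are tightened rather than extended: the fixture switches to llama3.2:1b, the separate relation-enrichment test is dropped, the node test is renamed test_enrich_ollama, and both remaining tests enrich a single term. The EnrichmentWithOllama call pattern itself is unchanged; a minimal sketch based only on the calls visible in this diff (prompt abbreviated, temperature value illustrative):

from aiagents4pharma.talk2knowledgegraphs.utils.enrichments.ollama import EnrichmentWithOllama

enr_model = EnrichmentWithOllama(
    model_name="llama3.2:1b",
    prompt_enrichment="Given the input as a list of strings, ... Input: {input}",
    temperature=0.0,   # illustrative; the fixture supplies this via cfg["temperature"]
    streaming=False,
)

# Enrich free-text terms with additional background knowledge.
enriched = enr_model.enrich_documents(["acetaminophen"])

# Variant that also takes reference documents (RAG-style enrichment).
enriched_rag = enr_model.enrich_documents_with_rag(["acetaminophen"], [r"\path\to\doc1"])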
aiagents4pharma/talk2knowledgegraphs/tests/test_utils_kg_utils.py
@@ -0,0 +1,79 @@
+ """
+ Test cases for utils/kg_utils.py
+ """
+
+ import pytest
+ import networkx as nx
+ import pandas as pd
+ from ..utils import kg_utils
+
+
+ @pytest.fixture(name="sample_graph")
+ def make_sample_graph():
+     """Return a sample graph"""
+     sg = nx.Graph()
+     sg.add_node(1, node_id=1, feature_id="A", feature_value="ValueA")
+     sg.add_node(2, node_id=2, feature_id="B", feature_value="ValueB")
+     sg.add_edge(1, 2, edge_id=1, feature_id="E", feature_value="EdgeValue")
+     return sg
+
+
+ def test_kg_to_df_pandas(sample_graph):
+     """Test the kg_to_df_pandas function"""
+     df_nodes, df_edges = kg_utils.kg_to_df_pandas(sample_graph)
+     print(df_nodes)
+     expected_nodes_data = {
+         "node_id": [1, 2],
+         "feature_id": ["A", "B"],
+         "feature_value": ["ValueA", "ValueB"],
+     }
+     expected_nodes_df = pd.DataFrame(expected_nodes_data, index=[1, 2])
+     print(expected_nodes_df)
+     expected_edges_data = {
+         "node_source": [1],
+         "node_target": [2],
+         "edge_id": [1],
+         "feature_id": ["E"],
+         "feature_value": ["EdgeValue"],
+     }
+     expected_edges_df = pd.DataFrame(expected_edges_data)
+
+     # Assert that the dataframes are equal but the order of columns may be different
+     # Ignore the index of the dataframes
+     pd.testing.assert_frame_equal(df_nodes, expected_nodes_df, check_like=True)
+     pd.testing.assert_frame_equal(df_edges, expected_edges_df, check_like=True)
+
+
+ def test_df_pandas_to_kg():
+     """Test the df_pandas_to_kg function"""
+     nodes_data = {
+         "node_id": [1, 2],
+         "feature_id": ["A", "B"],
+         "feature_value": ["ValueA", "ValueB"],
+     }
+     df_nodes_attrs = pd.DataFrame(nodes_data).set_index("node_id")
+
+     edges_data = {
+         "node_source": [1],
+         "node_target": [2],
+         "edge_id": [1],
+         "feature_id": ["E"],
+         "feature_value": ["EdgeValue"],
+     }
+     df_edges = pd.DataFrame(edges_data)
+
+     kg = kg_utils.df_pandas_to_kg(
+         df_edges, df_nodes_attrs, "node_source", "node_target"
+     )
+
+     assert len(kg.nodes) == 2
+     assert len(kg.edges) == 1
+
+     assert kg.nodes[1]["feature_id"] == "A"
+     assert kg.nodes[1]["feature_value"] == "ValueA"
+     assert kg.nodes[2]["feature_id"] == "B"
+     assert kg.nodes[2]["feature_value"] == "ValueB"
+
+     assert kg.edges[1, 2]["feature_id"] == "E"
+     assert kg.edges[1, 2]["feature_value"] == "EdgeValue"
+     assert kg.edges[1, 2]["edge_id"] == 1
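The new kg_utils helpers convert between a pair of pandas DataFrames (node attributes and edge list) and a NetworkX graph. A minimal sketch mirroring the fixtures above:

import pandas as pd
from aiagents4pharma.talk2knowledgegraphs.utils import kg_utils

# Node attributes, indexed by node id.
df_nodes_attrs = pd.DataFrame(
    {"node_id": [1, 2], "feature_id": ["A", "B"], "feature_value": ["ValueA", "ValueB"]}
).set_index("node_id")

# Edge table with explicit source/target columns.
df_edges = pd.DataFrame(
    {"node_source": [1], "node_target": [2], "edge_id": [1],
     "feature_id": ["E"], "feature_value": ["EdgeValue"]}
)

# Tables -> NetworkX graph ...
kg = kg_utils.df_pandas_to_kg(df_edges, df_nodes_attrs, "node_source", "node_target")

# ... and graph -> node/edge DataFrames.
df_nodes_out, df_edges_out = kg_utils.kg_to_df_pandas(kg)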
aiagents4pharma/talk2knowledgegraphs/tools/__init__.py
@@ -0,0 +1,6 @@
+ '''
+ This file is used to import all the models in the package.
+ '''
+ from . import subgraph_extraction
+ from . import subgraph_summarization
+ from . import graphrag_reasoning