tavily-python 0.7.5__tar.gz → 0.7.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24):
  1. {tavily_python-0.7.5 → tavily_python-0.7.7}/PKG-INFO +1 -1
  2. {tavily_python-0.7.5 → tavily_python-0.7.7}/setup.py +1 -1
  3. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily/async_tavily.py +27 -21
  4. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily/tavily.py +27 -21
  5. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily_python.egg-info/PKG-INFO +1 -1
  6. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily_python.egg-info/SOURCES.txt +3 -1
  7. tavily_python-0.7.7/tests/test_async_search.py +219 -0
  8. {tavily_python-0.7.5 → tavily_python-0.7.7}/tests/test_crawl.py +2 -0
  9. {tavily_python-0.7.5 → tavily_python-0.7.7}/tests/test_map.py +2 -0
  10. {tavily_python-0.7.5 → tavily_python-0.7.7}/tests/test_search.py +2 -0
  11. tavily_python-0.7.7/tests/test_sync_search.py +219 -0
  12. {tavily_python-0.7.5 → tavily_python-0.7.7}/LICENSE +0 -0
  13. {tavily_python-0.7.5 → tavily_python-0.7.7}/README.md +0 -0
  14. {tavily_python-0.7.5 → tavily_python-0.7.7}/setup.cfg +0 -0
  15. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily/__init__.py +0 -0
  16. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily/config.py +0 -0
  17. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily/errors.py +0 -0
  18. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily/hybrid_rag/__init__.py +0 -0
  19. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily/hybrid_rag/hybrid_rag.py +0 -0
  20. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily/utils.py +0 -0
  21. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily_python.egg-info/dependency_links.txt +0 -0
  22. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily_python.egg-info/requires.txt +0 -0
  23. {tavily_python-0.7.5 → tavily_python-0.7.7}/tavily_python.egg-info/top_level.txt +0 -0
  24. {tavily_python-0.7.5 → tavily_python-0.7.7}/tests/test_errors.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tavily-python
3
- Version: 0.7.5
3
+ Version: 0.7.7
4
4
  Summary: Python wrapper for the Tavily API
5
5
  Home-page: https://github.com/tavily-ai/tavily-python
6
6
  Author: Tavily AI
@@ -5,7 +5,7 @@ with open('README.md', 'r', encoding='utf-8') as f:
5
5
 
6
6
  setup(
7
7
  name='tavily-python',
8
- version='0.7.5',
8
+ version='0.7.7',
9
9
  url='https://github.com/tavily-ai/tavily-python',
10
10
  author='Tavily AI',
11
11
  author_email='support@tavily.com',
@@ -42,7 +42,8 @@ class AsyncTavilyClient:
42
42
  self._client_creator = lambda: httpx.AsyncClient(
43
43
  headers={
44
44
  "Content-Type": "application/json",
45
- "Authorization": f"Bearer {api_key}"
45
+ "Authorization": f"Bearer {api_key}",
46
+ "X-Client-Source": "tavily-python"
46
47
  },
47
48
  base_url="https://api.tavily.com",
48
49
  mounts=proxy_mounts
@@ -52,16 +53,16 @@ class AsyncTavilyClient:
52
53
  async def _search(
53
54
  self,
54
55
  query: str,
55
- search_depth: Literal["basic", "advanced"] = "basic",
56
- topic: Literal["general", "news", "finance"] = "general",
56
+ search_depth: Literal["basic", "advanced"] = None,
57
+ topic: Literal["general", "news", "finance"] = None,
57
58
  time_range: Literal["day", "week", "month", "year"] = None,
58
- days: int = 7,
59
- max_results: int = 5,
59
+ days: int = None,
60
+ max_results: int = None,
60
61
  include_domains: Sequence[str] = None,
61
62
  exclude_domains: Sequence[str] = None,
62
- include_answer: Union[bool, Literal["basic", "advanced"]] = False,
63
- include_raw_content: Union[bool, Literal["markdown", "text"]] = False,
64
- include_images: bool = False,
63
+ include_answer: Union[bool, Literal["basic", "advanced"]] = None,
64
+ include_raw_content: Union[bool, Literal["markdown", "text"]] = None,
65
+ include_images: bool = None,
65
66
  timeout: int = 60,
66
67
  country: str = None,
67
68
  **kwargs,
@@ -84,6 +85,8 @@ class AsyncTavilyClient:
84
85
  "country": country,
85
86
  }
86
87
 
88
+ data = {k: v for k, v in data.items() if v is not None}
89
+
87
90
  if kwargs:
88
91
  data.update(kwargs)
89
92
 
@@ -117,16 +120,16 @@ class AsyncTavilyClient:
117
120
 
118
121
  async def search(self,
119
122
  query: str,
120
- search_depth: Literal["basic", "advanced"] = "basic",
121
- topic: Literal["general", "news", "finance"] = "general",
123
+ search_depth: Literal["basic", "advanced"] = None,
124
+ topic: Literal["general", "news", "finance"] = None,
122
125
  time_range: Literal["day", "week", "month", "year"] = None,
123
- days: int = 7,
124
- max_results: int = 5,
126
+ days: int = None,
127
+ max_results: int = None,
125
128
  include_domains: Sequence[str] = None,
126
129
  exclude_domains: Sequence[str] = None,
127
- include_answer: Union[bool, Literal["basic", "advanced"]] = False,
128
- include_raw_content: Union[bool, Literal["markdown", "text"]] = False,
129
- include_images: bool = False,
130
+ include_answer: Union[bool, Literal["basic", "advanced"]] = None,
131
+ include_raw_content: Union[bool, Literal["markdown", "text"]] = None,
132
+ include_images: bool = None,
130
133
  timeout: int = 60,
131
134
  country: str = None,
132
135
  **kwargs, # Accept custom arguments
@@ -160,9 +163,9 @@ class AsyncTavilyClient:
160
163
  async def _extract(
161
164
  self,
162
165
  urls: Union[List[str], str],
163
- include_images: bool = False,
164
- extract_depth: Literal["basic", "advanced"] = "basic",
165
- format: Literal["markdown", "text"] = "markdown",
166
+ include_images: bool = None,
167
+ extract_depth: Literal["basic", "advanced"] = None,
168
+ format: Literal["markdown", "text"] = None,
166
169
  timeout: int = 60,
167
170
  **kwargs
168
171
  ) -> dict:
@@ -175,6 +178,9 @@ class AsyncTavilyClient:
175
178
  "extract_depth": extract_depth,
176
179
  "format": format,
177
180
  }
181
+
182
+ data = {k: v for k, v in data.items() if v is not None}
183
+
178
184
  if kwargs:
179
185
  data.update(kwargs)
180
186
 
@@ -209,9 +215,9 @@ class AsyncTavilyClient:
209
215
 
210
216
  async def extract(self,
211
217
  urls: Union[List[str], str], # Accept a list of URLs or a single URL
212
- include_images: bool = False,
213
- extract_depth: Literal["basic", "advanced"] = "basic",
214
- format: Literal["markdown", "text"] = "markdown",
218
+ include_images: bool = None,
219
+ extract_depth: Literal["basic", "advanced"] = None,
220
+ format: Literal["markdown", "text"] = None,
215
221
  timeout: int = 60,
216
222
  **kwargs, # Accept custom arguments
217
223
  ) -> dict:
@@ -32,21 +32,22 @@ class TavilyClient:
32
32
  self.proxies = resolved_proxies
33
33
  self.headers = {
34
34
  "Content-Type": "application/json",
35
- "Authorization": f"Bearer {self.api_key}"
35
+ "Authorization": f"Bearer {self.api_key}",
36
+ "X-Client-Source": "tavily-python"
36
37
  }
37
38
 
38
39
  def _search(self,
39
40
  query: str,
40
- search_depth: Literal["basic", "advanced"] = "basic",
41
- topic: Literal["general", "news", "finance"] = "general",
41
+ search_depth: Literal["basic", "advanced"] = None,
42
+ topic: Literal["general", "news", "finance"] = None,
42
43
  time_range: Literal["day", "week", "month", "year"] = None,
43
- days: int = 7,
44
- max_results: int = 5,
44
+ days: int = None,
45
+ max_results: int = None,
45
46
  include_domains: Sequence[str] = None,
46
47
  exclude_domains: Sequence[str] = None,
47
- include_answer: Union[bool, Literal["basic", "advanced"]] = False,
48
- include_raw_content: Union[bool, Literal["markdown", "text"]] = False,
49
- include_images: bool = False,
48
+ include_answer: Union[bool, Literal["basic", "advanced"]] = None,
49
+ include_raw_content: Union[bool, Literal["markdown", "text"]] = None,
50
+ include_images: bool = None,
50
51
  timeout: int = 60,
51
52
  country: str = None,
52
53
  **kwargs
@@ -70,6 +71,8 @@ class TavilyClient:
70
71
  "country": country,
71
72
  }
72
73
 
74
+ data = {k: v for k, v in data.items() if v is not None}
75
+
73
76
  if kwargs:
74
77
  data.update(kwargs)
75
78
 
@@ -104,16 +107,16 @@ class TavilyClient:
104
107
 
105
108
  def search(self,
106
109
  query: str,
107
- search_depth: Literal["basic", "advanced"] = "basic",
108
- topic: Literal["general", "news", "finance" ] = "general",
110
+ search_depth: Literal["basic", "advanced"] = None,
111
+ topic: Literal["general", "news", "finance" ] = None,
109
112
  time_range: Literal["day", "week", "month", "year"] = None,
110
- days: int = 7,
111
- max_results: int = 5,
113
+ days: int = None,
114
+ max_results: int = None,
112
115
  include_domains: Sequence[str] = None,
113
116
  exclude_domains: Sequence[str] = None,
114
- include_answer: Union[bool, Literal["basic", "advanced"]] = False,
115
- include_raw_content: Union[bool, Literal["markdown", "text"]] = False,
116
- include_images: bool = False,
117
+ include_answer: Union[bool, Literal["basic", "advanced"]] = None,
118
+ include_raw_content: Union[bool, Literal["markdown", "text"]] = None,
119
+ include_images: bool = None,
117
120
  timeout: int = 60,
118
121
  country: str = None,
119
122
  **kwargs, # Accept custom arguments
@@ -146,9 +149,9 @@ class TavilyClient:
146
149
 
147
150
  def _extract(self,
148
151
  urls: Union[List[str], str],
149
- include_images: bool = False,
150
- extract_depth: Literal["basic", "advanced"] = "basic",
151
- format: Literal["markdown", "text"] = "markdown",
152
+ include_images: bool = None,
153
+ extract_depth: Literal["basic", "advanced"] = None,
154
+ format: Literal["markdown", "text"] = None,
152
155
  timeout: int = 60,
153
156
  **kwargs
154
157
  ) -> dict:
@@ -161,6 +164,9 @@ class TavilyClient:
161
164
  "extract_depth": extract_depth,
162
165
  "format": format,
163
166
  }
167
+
168
+ data = {k: v for k, v in data.items() if v is not None}
169
+
164
170
  if kwargs:
165
171
  data.update(kwargs)
166
172
 
@@ -193,9 +199,9 @@ class TavilyClient:
193
199
 
194
200
  def extract(self,
195
201
  urls: Union[List[str], str], # Accept a list of URLs or a single URL
196
- include_images: bool = False,
197
- extract_depth: Literal["basic", "advanced"] = "basic",
198
- format: Literal["markdown", "text"] = "markdown",
202
+ include_images: bool = None,
203
+ extract_depth: Literal["basic", "advanced"] = None,
204
+ format: Literal["markdown", "text"] = None,
199
205
  timeout: int = 60,
200
206
  **kwargs, # Accept custom arguments
201
207
  ) -> dict:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tavily-python
3
- Version: 0.7.5
3
+ Version: 0.7.7
4
4
  Summary: Python wrapper for the Tavily API
5
5
  Home-page: https://github.com/tavily-ai/tavily-python
6
6
  Author: Tavily AI
@@ -14,7 +14,9 @@ tavily_python.egg-info/SOURCES.txt
14
14
  tavily_python.egg-info/dependency_links.txt
15
15
  tavily_python.egg-info/requires.txt
16
16
  tavily_python.egg-info/top_level.txt
17
+ tests/test_async_search.py
17
18
  tests/test_crawl.py
18
19
  tests/test_errors.py
19
20
  tests/test_map.py
20
- tests/test_search.py
21
+ tests/test_search.py
22
+ tests/test_sync_search.py
@@ -0,0 +1,219 @@
1
+ import unittest
2
+ import os
3
+ from tavily import AsyncTavilyClient, MissingAPIKeyError, InvalidAPIKeyError
4
+ from urllib.parse import urlparse
5
+ import asyncio
6
+
7
+ from unit_tests import cases
8
+ class SearchTest(unittest.TestCase):
9
+
10
+ def setUp(self) -> None:
11
+ self.tavily_client = AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
12
+ return super().setUp()
13
+
14
+ def tearDown(self) -> None:
15
+ return super().tearDown()
16
+
17
+ # Every single search result should have these properties
18
+ def common_search_result_properties(self, result) -> None:
19
+ self.assertIsInstance(result, dict)
20
+ self.assertIn("title", result)
21
+ self.assertIn("url", result)
22
+ self.assertIn("content", result)
23
+ self.assertIn("score", result)
24
+
25
+ # General search results should have these properties
26
+ def general_search_result_properties(self, result) -> None:
27
+ self.common_search_result_properties(result)
28
+ self.assertIn("raw_content", result)
29
+
30
+ # News search results should have these properties
31
+ def news_search_result_properties(self, result) -> None:
32
+ self.common_search_result_properties(result)
33
+ self.assertIn("published_date", result)
34
+
35
+ # Topic-specific properties
36
+ def topic_specific_properties(self, result, **params) -> None:
37
+ if params.get("topic", "general") == "general":
38
+ self.assertIn("raw_content", result)
39
+ elif params.get("topic", "general") == "news":
40
+ self.assertIn("published_date", result)
41
+
42
+ # Domain inclusion/exclusion-dependent properties
43
+ def domain_dependent_properties(self, response, **params) -> None:
44
+ if params.get("topic", "general") != "general":
45
+ return
46
+
47
+ if params.get("include_domains", False) and len(params["include_domains"]) > 0:
48
+ for result in response["results"]:
49
+ self.assertTrue(any(domain in urlparse(result["url"]).netloc for domain in params["include_domains"]))
50
+
51
+ if params.get("exclude_domains", False) and len(params["exclude_domains"]) > 0:
52
+ for result in response["results"]:
53
+ self.assertFalse(any(domain in urlparse(result["url"]).netloc for domain in params["exclude_domains"]))
54
+
55
+ # Image-dependent properties
56
+ def image_properties(self, response, **params) -> None:
57
+ if params.get("topic", "general") == "general" and params.get("include_images", False):
58
+ self.assertIsNotNone(response["images"])
59
+ self.assertIsInstance(response["images"], list)
60
+ for image in response["images"]:
61
+ self.assertIsInstance(image, str)
62
+
63
+ # Answer-dependent properties
64
+ def answer_properties(self, response, **params) -> None:
65
+ if params.get("include_answer", False):
66
+ self.assertIn("answer", response)
67
+ self.assertIsInstance(response["answer"], str)
68
+
69
+
70
+ # Every single search response should have these properties
71
+ def common_response_properties(self, response) -> None:
72
+ self.assertIsNotNone(response)
73
+ self.assertIsInstance(response, dict)
74
+ self.assertIn("answer", response)
75
+ self.assertIn("query", response)
76
+ self.assertIn("results", response)
77
+ self.assertIn("images", response)
78
+ self.assertIn("response_time", response)
79
+ self.assertIn("follow_up_questions", response)
80
+
81
+ self.assertIsNotNone(response["query"])
82
+ self.assertIsNotNone(response["results"])
83
+ self.assertIsNotNone(response["response_time"])
84
+ self.assertIsNotNone(response["images"])
85
+
86
+ self.assertIsInstance(response["query"], str)
87
+ self.assertIsInstance(response["results"], list)
88
+ self.assertIsInstance(response["response_time"], float)
89
+ self.assertIsInstance(response["images"], list)
90
+
91
+ # Search responses also have properties that depend on the request params
92
+ def custom_response_properties(self, response, **params) -> None:
93
+ self.domain_dependent_properties(response, **params)
94
+ self.image_properties(response, **params)
95
+ self.answer_properties(response, **params)
96
+ for result in response["results"]:
97
+ self.topic_specific_properties(result, **params)
98
+
99
+ def test_internal_search(self) -> None:
100
+ for test_case in cases:
101
+ with self.subTest(msg=test_case["name"]):
102
+ response = asyncio.run(self.tavily_client._search(**test_case["params"]))
103
+ self.common_response_properties(response)
104
+ if test_case["params"].get("topic", "general") == "general":
105
+ for search_result in response["results"]:
106
+ self.general_search_result_properties(search_result)
107
+ elif test_case["params"].get("topic", "general") == "news":
108
+ for search_result in response["results"]:
109
+ self.news_search_result_properties(search_result)
110
+ self.custom_response_properties(response, **test_case["params"])
111
+
112
+ def test_external_search(self) -> None:
113
+ for test_case in cases:
114
+ with self.subTest(msg=test_case["name"]):
115
+ response = asyncio.run(self.tavily_client.search(**test_case["params"]))
116
+ self.common_response_properties(response)
117
+ if test_case["params"].get("topic", "general") == "general":
118
+ for search_result in response["results"]:
119
+ self.general_search_result_properties(search_result)
120
+ elif test_case["params"].get("topic", "general") == "news":
121
+ for search_result in response["results"]:
122
+ self.news_search_result_properties(search_result)
123
+ self.custom_response_properties(response, **test_case["params"])
124
+
125
+ class QNASearchTest(unittest.TestCase):
126
+
127
+ def setUp(self) -> None:
128
+ self.tavily_client = AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
129
+ return super().setUp()
130
+
131
+ def tearDown(self) -> None:
132
+ return super().tearDown()
133
+
134
+ def test_qna_search(self) -> None:
135
+ for test_case in cases:
136
+ if "include_answer" in test_case["params"]:
137
+ del test_case["params"]["include_answer"]
138
+ if "include_raw_content" in test_case["params"]:
139
+ del test_case["params"]["include_raw_content"]
140
+ if "include_images" in test_case["params"]:
141
+ del test_case["params"]["include_images"]
142
+ with self.subTest(msg=test_case["name"]):
143
+ response = asyncio.run(self.tavily_client.qna_search(**test_case["params"]))
144
+ self.assertIsNotNone(response)
145
+ self.assertIsInstance(response, str)
146
+ self.assertTrue(len(response) > 0)
147
+
148
+ class CompanyInfoSearchTest(unittest.TestCase):
149
+
150
+ def setUp(self) -> None:
151
+ self.tavily_client = AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
152
+ return super().setUp()
153
+
154
+ def tearDown(self) -> None:
155
+ return super().tearDown()
156
+
157
+ # Every single search result should have these properties
158
+ def common_search_result_properties(self, result) -> None:
159
+ self.assertIsInstance(result, dict)
160
+ self.assertIn("title", result)
161
+ self.assertIn("url", result)
162
+ self.assertIn("content", result)
163
+ self.assertIn("score", result)
164
+
165
+ def test_company_info_search(self) -> None:
166
+ for test_case in cases:
167
+ if "topic" in test_case["params"]:
168
+ del test_case["params"]["topic"]
169
+ if "include_domains" in test_case["params"]:
170
+ del test_case["params"]["include_domains"]
171
+ if "exclude_domains" in test_case["params"]:
172
+ del test_case["params"]["exclude_domains"]
173
+ if "include_raw_content" in test_case["params"]:
174
+ del test_case["params"]["include_raw_content"]
175
+ if "include_images" in test_case["params"]:
176
+ del test_case["params"]["include_images"]
177
+ if "include_answer" in test_case["params"]:
178
+ del test_case["params"]["include_answer"]
179
+ if "use_cache" in test_case["params"]:
180
+ del test_case["params"]["use_cache"]
181
+ with self.subTest(msg=test_case["name"]):
182
+ response = asyncio.run(self.tavily_client.get_company_info(**test_case["params"]))
183
+ self.assertIsNotNone(response)
184
+ self.assertIsInstance(response, list)
185
+ self.assertTrue(len(response) > 0)
186
+ for search_result in response:
187
+ self.common_search_result_properties(search_result)
188
+
189
+ class ErrorTest(unittest.TestCase):
190
+
191
+ def setUp(self) -> None:
192
+ return super().setUp()
193
+
194
+ def tearDown(self) -> None:
195
+ return super().tearDown()
196
+
197
+ # This test is here to ensure that no MissingAPIKeyError is raised when the API key is in the environment
198
+ def test_load_key_from_env(self) -> None:
199
+ self.assertIn('results', asyncio.run(AsyncTavilyClient().search("Why is Tavily the best search API?")))
200
+
201
+ def test_missing_api_key(self) -> None:
202
+ with self.assertRaises(MissingAPIKeyError):
203
+ AsyncTavilyClient(api_key='')
204
+
205
+ old_key = os.getenv("TAVILY_API_KEY")
206
+ del os.environ["TAVILY_API_KEY"]
207
+ with self.assertRaises(MissingAPIKeyError):
208
+ AsyncTavilyClient()
209
+
210
+ os.environ["TAVILY_API_KEY"] = old_key
211
+
212
+
213
+ def test_invalid_api_key(self) -> None:
214
+ with self.assertRaises(InvalidAPIKeyError):
215
+ asyncio.run(AsyncTavilyClient(api_key="invalid_api_key").search("Why is Tavily the best search API?"))
216
+
217
+ if __name__ == "__main__":
218
+
219
+ unittest.main()
@@ -16,6 +16,7 @@ def validate_default(request, response):
16
16
  assert request.method == "POST"
17
17
  assert request.url == "https://api.tavily.com/crawl"
18
18
  assert request.headers["Authorization"] == "Bearer tvly-test"
19
+ assert request.headers["X-Client-Source"] == "tavily-python"
19
20
  assert request.json().get('url') == "https://tavily.com"
20
21
  assert response == dummy_response
21
22
 
@@ -23,6 +24,7 @@ def validate_specific(request, response):
23
24
  assert request.method == "POST"
24
25
  assert request.url == "https://api.tavily.com/crawl"
25
26
  assert request.headers["Authorization"] == "Bearer tvly-test"
27
+ assert request.headers["X-Client-Source"] == "tavily-python"
26
28
  assert request.timeout == 10
27
29
 
28
30
  request_json = request.json()
@@ -15,6 +15,7 @@ def validate_default(request, response):
15
15
  assert request.method == "POST"
16
16
  assert request.url == "https://api.tavily.com/map"
17
17
  assert request.headers["Authorization"] == "Bearer tvly-test"
18
+ assert request.headers["X-Client-Source"] == "tavily-python"
18
19
  assert request.json().get('url') == "https://tavily.com"
19
20
  assert response == dummy_response
20
21
 
@@ -22,6 +23,7 @@ def validate_specific(request, response):
22
23
  assert request.method == "POST"
23
24
  assert request.url == "https://api.tavily.com/map"
24
25
  assert request.headers["Authorization"] == "Bearer tvly-test"
26
+ assert request.headers["X-Client-Source"] == "tavily-python"
25
27
  assert request.timeout == 10
26
28
 
27
29
  request_json = request.json()
@@ -22,6 +22,7 @@ def validate_default(request, response):
22
22
  assert request.method == "POST"
23
23
  assert request.url == "https://api.tavily.com/search"
24
24
  assert request.headers["Authorization"] == "Bearer tvly-test"
25
+ assert request.headers["X-Client-Source"] == "tavily-python"
25
26
  assert request.json().get('query') == "What is Tavily?"
26
27
  assert response == dummy_response
27
28
 
@@ -29,6 +30,7 @@ def validate_specific(request, response):
29
30
  assert request.method == "POST"
30
31
  assert request.url == "https://api.tavily.com/search"
31
32
  assert request.headers["Authorization"] == "Bearer tvly-test"
33
+ assert request.headers["X-Client-Source"] == "tavily-python"
32
34
  assert request.timeout == 10
33
35
 
34
36
  request_json = request.json()
@@ -0,0 +1,219 @@
1
+ import unittest
2
+ import os
3
+ from tavily import TavilyClient, InvalidAPIKeyError, MissingAPIKeyError
4
+ from urllib.parse import urlparse
5
+
6
+ from unit_tests import cases
7
+
8
+ class SearchTest(unittest.TestCase):
9
+
10
+ def setUp(self) -> None:
11
+ self.tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
12
+ return super().setUp()
13
+
14
+ def tearDown(self) -> None:
15
+ return super().tearDown()
16
+
17
+ # Every single search result should have these properties
18
+ def common_search_result_properties(self, result) -> None:
19
+ self.assertIsInstance(result, dict)
20
+ self.assertIn("title", result)
21
+ self.assertIn("url", result)
22
+ self.assertIn("content", result)
23
+ self.assertIn("score", result)
24
+
25
+ # General search results should have these properties
26
+ def general_search_result_properties(self, result) -> None:
27
+ self.common_search_result_properties(result)
28
+ self.assertIn("raw_content", result)
29
+
30
+ # News search results should have these properties
31
+ def news_search_result_properties(self, result) -> None:
32
+ self.common_search_result_properties(result)
33
+ self.assertIn("published_date", result)
34
+
35
+ # Topic-specific properties
36
+ def topic_specific_properties(self, result, **params) -> None:
37
+ if params.get("topic", "general") == "general":
38
+ self.assertIn("raw_content", result)
39
+ elif params.get("topic", "general") == "news":
40
+ self.assertIn("published_date", result)
41
+
42
+ # Domain inclusion/exclusion-dependent properties
43
+ def domain_dependent_properties(self, response, **params) -> None:
44
+ if params.get("topic", "general") != "general":
45
+ return
46
+
47
+ if params.get("include_domains", False) and len(params["include_domains"]) > 0:
48
+ for result in response["results"]:
49
+ self.assertTrue(any(domain in urlparse(result["url"]).netloc for domain in params["include_domains"]))
50
+
51
+ if params.get("exclude_domains", False) and len(params["exclude_domains"]) > 0:
52
+ for result in response["results"]:
53
+ self.assertFalse(any(domain in urlparse(result["url"]).netloc for domain in params["exclude_domains"]))
54
+
55
+ # Image-dependent properties
56
+ def image_properties(self, response, **params) -> None:
57
+ if params.get("topic", "general") == "general" and params.get("include_images", False):
58
+ self.assertIsNotNone(response["images"])
59
+ self.assertIsInstance(response["images"], list)
60
+ for image in response["images"]:
61
+ self.assertIsInstance(image, str)
62
+
63
+ # Answer-dependent properties
64
+ def answer_properties(self, response, **params) -> None:
65
+ if params.get("include_answer", False):
66
+ self.assertIn("answer", response)
67
+ self.assertIsInstance(response["answer"], str)
68
+
69
+
70
+ # Every single search response should have these properties
71
+ def common_response_properties(self, response) -> None:
72
+ self.assertIsNotNone(response)
73
+ self.assertIsInstance(response, dict)
74
+ self.assertIn("answer", response)
75
+ self.assertIn("query", response)
76
+ self.assertIn("results", response)
77
+ self.assertIn("images", response)
78
+ self.assertIn("response_time", response)
79
+ self.assertIn("follow_up_questions", response)
80
+
81
+ self.assertIsNotNone(response["query"])
82
+ self.assertIsNotNone(response["results"])
83
+ self.assertIsNotNone(response["response_time"])
84
+ self.assertIsNotNone(response["images"])
85
+
86
+ self.assertIsInstance(response["query"], str)
87
+ self.assertIsInstance(response["results"], list)
88
+ self.assertIsInstance(response["response_time"], float)
89
+ self.assertIsInstance(response["images"], list)
90
+
91
+ # Search responses also have properties that depend on the request params
92
+ def custom_response_properties(self, response, **params) -> None:
93
+ self.domain_dependent_properties(response, **params)
94
+ self.image_properties(response, **params)
95
+ self.answer_properties(response, **params)
96
+ for result in response["results"]:
97
+ self.topic_specific_properties(result, **params)
98
+
99
+ def test_internal_search(self) -> None:
100
+ for test_case in cases:
101
+ with self.subTest(msg=test_case["name"]):
102
+ result = self.tavily_client._search(**test_case["params"])
103
+ self.common_response_properties(result)
104
+ if test_case["params"].get("topic", "general") == "general":
105
+ for search_result in result["results"]:
106
+ self.general_search_result_properties(search_result)
107
+ elif test_case["params"].get("topic", "general") == "news":
108
+ for search_result in result["results"]:
109
+ self.news_search_result_properties(search_result)
110
+ self.custom_response_properties(result, **test_case["params"])
111
+
112
+ def test_external_search(self) -> None:
113
+ for test_case in cases:
114
+ with self.subTest(msg=test_case["name"]):
115
+ response = self.tavily_client._search(**test_case["params"])
116
+ self.common_response_properties(response)
117
+ if test_case["params"].get("topic", "general") == "general":
118
+ for search_result in response["results"]:
119
+ self.general_search_result_properties(search_result)
120
+ elif test_case["params"].get("topic", "general") == "news":
121
+ for search_result in response["results"]:
122
+ self.news_search_result_properties(search_result)
123
+ self.custom_response_properties(response, **test_case["params"])
124
+
125
+ class QNASearchTest(unittest.TestCase):
126
+
127
+ def setUp(self) -> None:
128
+ self.tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
129
+ return super().setUp()
130
+
131
+ def tearDown(self) -> None:
132
+ return super().tearDown()
133
+
134
+ def test_qna_search(self) -> None:
135
+ for test_case in cases:
136
+ if "include_answer" in test_case["params"]:
137
+ del test_case["params"]["include_answer"]
138
+ if "include_raw_content" in test_case["params"]:
139
+ del test_case["params"]["include_raw_content"]
140
+ if "include_images" in test_case["params"]:
141
+ del test_case["params"]["include_images"]
142
+ with self.subTest(msg=test_case["name"]):
143
+ response = self.tavily_client.qna_search(**test_case["params"])
144
+ self.assertIsNotNone(response)
145
+ self.assertIsInstance(response, str)
146
+ self.assertTrue(len(response) > 0)
147
+
148
+ class CompanyInfoSearchTest(unittest.TestCase):
149
+
150
+ def setUp(self) -> None:
151
+ self.tavily_client = TavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
152
+ return super().setUp()
153
+
154
+ def tearDown(self) -> None:
155
+ return super().tearDown()
156
+
157
+ # Every single search result should have these properties
158
+ def common_search_result_properties(self, result) -> None:
159
+ self.assertIsInstance(result, dict)
160
+ self.assertIn("title", result)
161
+ self.assertIn("url", result)
162
+ self.assertIn("content", result)
163
+ self.assertIn("score", result)
164
+
165
+ def test_company_info_search(self) -> None:
166
+ for test_case in cases:
167
+ if "topic" in test_case["params"]:
168
+ del test_case["params"]["topic"]
169
+ if "include_domains" in test_case["params"]:
170
+ del test_case["params"]["include_domains"]
171
+ if "exclude_domains" in test_case["params"]:
172
+ del test_case["params"]["exclude_domains"]
173
+ if "include_raw_content" in test_case["params"]:
174
+ del test_case["params"]["include_raw_content"]
175
+ if "include_images" in test_case["params"]:
176
+ del test_case["params"]["include_images"]
177
+ if "include_answer" in test_case["params"]:
178
+ del test_case["params"]["include_answer"]
179
+ if "use_cache" in test_case["params"]:
180
+ del test_case["params"]["use_cache"]
181
+ with self.subTest(msg=test_case["name"]):
182
+ response = self.tavily_client.get_company_info(**test_case["params"])
183
+ self.assertIsNotNone(response)
184
+ self.assertIsInstance(response, list)
185
+ self.assertTrue(len(response) > 0)
186
+ for search_result in response:
187
+ self.common_search_result_properties(search_result)
188
+
189
+ class ErrorTest(unittest.TestCase):
190
+
191
+ def setUp(self) -> None:
192
+ return super().setUp()
193
+
194
+ def tearDown(self) -> None:
195
+ return super().tearDown()
196
+
197
+ # This test is here to ensure that no MissingAPIKeyError is raised when the API key is in the environment
198
+ def test_load_key_from_env(self) -> None:
199
+ self.assertIn('results', TavilyClient().search("Why is Tavily the best search API?"))
200
+
201
+ def test_missing_api_key(self) -> None:
202
+ with self.assertRaises(MissingAPIKeyError):
203
+ TavilyClient(api_key='')
204
+
205
+ old_key = os.getenv("TAVILY_API_KEY")
206
+ del os.environ["TAVILY_API_KEY"]
207
+ with self.assertRaises(MissingAPIKeyError):
208
+ TavilyClient()
209
+
210
+ os.environ["TAVILY_API_KEY"] = old_key
211
+
212
+
213
+ def test_invalid_api_key(self) -> None:
214
+ with self.assertRaises(InvalidAPIKeyError):
215
+ TavilyClient(api_key="invalid_api_key").search("Why is Tavily the best search API?")
216
+
217
+ if __name__ == "__main__":
218
+
219
+ unittest.main()
File without changes
File without changes
File without changes