exa-py 1.14.19__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.

Potentially problematic release: this version of exa-py might be problematic.

exa_py/api.py CHANGED
@@ -41,7 +41,7 @@ from exa_py.utils import (
  )
  from .websets import WebsetsClient
  from .websets.core.base import ExaJSONEncoder
- from .research.client import ResearchClient, AsyncResearchClient
+ from .research import ResearchClient, AsyncResearchClient


  is_beta = os.getenv("IS_BETA") == "True"
@@ -133,6 +133,7 @@ SEARCH_OPTIONS_TYPES = {
      "end_published_date": [
          str
      ], # Results before this publish date; excludes links with no date. ISO 8601 format.
+     "user_location": [str], # Two-letter ISO country code of the user (e.g. US).
      "include_text": [
          list
      ], # Must be present in webpage text. (One string, up to 5 words)
@@ -1186,23 +1187,37 @@ class Exa:
          # Otherwise, serialize the dictionary to JSON if it exists
          json_data = json.dumps(data, cls=ExaJSONEncoder) if data else None

-         if data and data.get("stream"):
-             res = requests.post(
-                 self.base_url + endpoint,
-                 data=json_data,
-                 headers=self.headers,
-                 stream=True,
-             )
-             return res
+         # Check if we need streaming (either from data for POST or params for GET)
+         needs_streaming = (data and isinstance(data, dict) and data.get("stream")) or (
+             params and params.get("stream") == "true"
+         )

          if method.upper() == "GET":
-             res = requests.get(
-                 self.base_url + endpoint, headers=self.headers, params=params
-             )
+             if needs_streaming:
+                 res = requests.get(
+                     self.base_url + endpoint,
+                     headers=self.headers,
+                     params=params,
+                     stream=True,
+                 )
+                 return res
+             else:
+                 res = requests.get(
+                     self.base_url + endpoint, headers=self.headers, params=params
+                 )
          elif method.upper() == "POST":
-             res = requests.post(
-                 self.base_url + endpoint, data=json_data, headers=self.headers
-             )
+             if needs_streaming:
+                 res = requests.post(
+                     self.base_url + endpoint,
+                     data=json_data,
+                     headers=self.headers,
+                     stream=True,
+                 )
+                 return res
+             else:
+                 res = requests.post(
+                     self.base_url + endpoint, data=json_data, headers=self.headers
+                 )
          elif method.upper() == "PATCH":
              res = requests.patch(
                  self.base_url + endpoint, data=json_data, headers=self.headers
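The request helper previously streamed only when a POST body contained "stream"; it now also streams GET requests whose query string carries stream=true, and in both cases hands back the raw requests.Response. A small standalone illustration of the new check (the helper name below is ours, for illustration only, not part of the package API):

    def needs_streaming(data=None, params=None):
        # Stream when the POST body asks for it, or when a GET passes stream=true as a query param.
        return bool(data and isinstance(data, dict) and data.get("stream")) or bool(
            params and params.get("stream") == "true"
        )

    print(needs_streaming(data={"query": "latest AI papers", "stream": True}))  # True (POST body flag)
    print(needs_streaming(params={"stream": "true"}))                           # True (GET query flag)
    print(needs_streaming(params={"stream": "false"}))                          # False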
@@ -1236,6 +1251,7 @@ class Exa:
          category: Optional[str] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
      ) -> SearchResponse[_Result]:
          """Perform a search with a prompt-engineered query to retrieve relevant results.

@@ -1251,10 +1267,11 @@ class Exa:
              include_text (List[str], optional): Strings that must appear in the page text.
              exclude_text (List[str], optional): Strings that must not appear in the page text.
              use_autoprompt (bool, optional): Convert query to Exa (default False).
-             type (str, optional): 'keyword', 'neural', 'hybrid', or 'fast' (default 'neural').
+             type (str, optional): 'keyword', 'neural', 'hybrid', 'fast', or 'auto' (default 'auto').
              category (str, optional): e.g. 'company'
              flags (List[str], optional): Experimental flags for Exa usage.
              moderation (bool, optional): If True, the search results will be moderated for safety.
+             user_location (str, optional): Two-letter ISO country code of the user (e.g. US).

          Returns:
              SearchResponse: The response containing search results, etc.
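Assuming the usual synchronous client setup, the new parameter is passed directly to search. A brief sketch (the API key and query are placeholders):

    from exa_py import Exa

    exa = Exa("YOUR-EXA-API-KEY")  # placeholder key

    # Bias results toward a US-based user; 'auto' is now the documented default type.
    response = exa.search(
        "best local coffee roasters",
        type="auto",
        user_location="US",
    )
    for result in response.results:
        print(result.title, result.url)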
@@ -1312,6 +1329,7 @@ class Exa:
          category: Optional[str] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
          livecrawl_timeout: Optional[int] = None,
          livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
          filter_empty_results: Optional[bool] = None,
@@ -1370,6 +1388,7 @@ class Exa:
          subpage_target: Optional[Union[str, List[str]]] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
          livecrawl_timeout: Optional[int] = None,
          livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
          filter_empty_results: Optional[bool] = None,
@@ -1399,6 +1418,7 @@ class Exa:
          subpage_target: Optional[Union[str, List[str]]] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
          livecrawl_timeout: Optional[int] = None,
          livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
          filter_empty_results: Optional[bool] = None,
@@ -1427,6 +1447,7 @@ class Exa:
          subpage_target: Optional[Union[str, List[str]]] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
          livecrawl_timeout: Optional[int] = None,
          livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
          filter_empty_results: Optional[bool] = None,
@@ -1456,6 +1477,7 @@ class Exa:
          subpage_target: Optional[Union[str, List[str]]] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
          livecrawl_timeout: Optional[int] = None,
          livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
          filter_empty_results: Optional[bool] = None,
@@ -1485,6 +1507,7 @@ class Exa:
          subpage_target: Optional[Union[str, List[str]]] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
          livecrawl_timeout: Optional[int] = None,
          livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
          filter_empty_results: Optional[bool] = None,
@@ -1513,6 +1536,7 @@ class Exa:
          category: Optional[str] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
          livecrawl_timeout: Optional[int] = None,
          livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
          subpages: Optional[int] = None,
@@ -2396,34 +2420,55 @@ class AsyncExa(Exa):
          # this may only be a
          if self._client is None:
              self._client = httpx.AsyncClient(
-                 base_url=self.base_url, headers=self.headers, timeout=60
+                 base_url=self.base_url, headers=self.headers, timeout=600
              )
          return self._client

-     async def async_request(self, endpoint: str, data):
-         """Send a POST request to the Exa API, optionally streaming if data['stream'] is True.
+     async def async_request(
+         self, endpoint: str, data=None, method: str = "POST", params=None
+     ):
+         """Send a request to the Exa API, optionally streaming if data['stream'] is True.

          Args:
              endpoint (str): The API endpoint (path).
-             data (dict): The JSON payload to send.
+             data (dict, optional): The JSON payload to send.
+             method (str, optional): The HTTP method to use. Defaults to "POST".
+             params (dict, optional): Query parameters.

          Returns:
-             Union[dict, requests.Response]: If streaming, returns the Response object.
+             Union[dict, httpx.Response]: If streaming, returns the Response object.
              Otherwise, returns the JSON-decoded response as a dict.

          Raises:
              ValueError: If the request fails (non-200 status code).
          """
-         if data.get("stream"):
-             request = httpx.Request(
-                 "POST", self.base_url + endpoint, json=data, headers=self.headers
-             )
-             res = await self.client.send(request, stream=True)
-             return res
-
-         res = await self.client.post(
-             self.base_url + endpoint, json=data, headers=self.headers
+         # Check if we need streaming (either from data for POST or params for GET)
+         needs_streaming = (data and isinstance(data, dict) and data.get("stream")) or (
+             params and params.get("stream") == "true"
          )
+
+         if method.upper() == "GET":
+             if needs_streaming:
+                 request = httpx.Request(
+                     "GET", self.base_url + endpoint, params=params, headers=self.headers
+                 )
+                 res = await self.client.send(request, stream=True)
+                 return res
+             else:
+                 res = await self.client.get(
+                     self.base_url + endpoint, params=params, headers=self.headers
+                 )
+         elif method.upper() == "POST":
+             if needs_streaming:
+                 request = httpx.Request(
+                     "POST", self.base_url + endpoint, json=data, headers=self.headers
+                 )
+                 res = await self.client.send(request, stream=True)
+                 return res
+             else:
+                 res = await self.client.post(
+                     self.base_url + endpoint, json=data, headers=self.headers
+                 )
          if res.status_code != 200 and res.status_code != 201:
              raise ValueError(
                  f"Request failed with status code {res.status_code}: {res.text}"
@@ -2448,6 +2493,7 @@ class AsyncExa(Exa):
          category: Optional[str] = None,
          flags: Optional[List[str]] = None,
          moderation: Optional[bool] = None,
+         user_location: Optional[str] = None,
      ) -> SearchResponse[_Result]:
          """Perform a search with a prompt-engineered query to retrieve relevant results.

@@ -2463,10 +2509,11 @@ class AsyncExa(Exa):
              include_text (List[str], optional): Strings that must appear in the page text.
              exclude_text (List[str], optional): Strings that must not appear in the page text.
              use_autoprompt (bool, optional): Convert query to Exa (default False).
-             type (str, optional): 'keyword', 'neural', 'hybrid', or 'fast' (default 'neural').
+             type (str, optional): 'keyword', 'neural', 'hybrid', 'fast', or 'auto' (default 'auto').
              category (str, optional): e.g. 'company'
              flags (List[str], optional): Experimental flags for Exa usage.
              moderation (bool, optional): If True, the search results will be moderated for safety.
+             user_location (str, optional): Two-letter ISO country code of the user (e.g. US).

          Returns:
              SearchResponse: The response containing search results, etc.
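The asynchronous client mirrors the same signature; a short sketch of the awaited call (key and query are placeholders):

    import asyncio
    from exa_py import AsyncExa

    async def main() -> None:
        exa = AsyncExa("YOUR-EXA-API-KEY")  # placeholder key
        response = await exa.search(
            "regional news about renewable energy",
            type="auto",
            user_location="DE",  # two-letter ISO country code
        )
        for result in response.results:
            print(result.url)

    asyncio.run(main())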
exa_py/research/__init__.py CHANGED
@@ -1,10 +1,39 @@
- from .client import ResearchClient, AsyncResearchClient
- from .models import ResearchTask, ListResearchTasksResponse, ResearchTaskId
+ """Research API client modules for Exa."""
+
+ from .sync_client import ResearchClient, ResearchTyped
+ from .async_client import AsyncResearchClient, AsyncResearchTyped
+ from .models import (
+     ResearchDto,
+     ResearchEvent,
+     ResearchDefinitionEvent,
+     ResearchOutputEvent,
+     ResearchPlanDefinitionEvent,
+     ResearchPlanOperationEvent,
+     ResearchPlanOutputEvent,
+     ResearchTaskDefinitionEvent,
+     ResearchTaskOperationEvent,
+     ResearchTaskOutputEvent,
+     ListResearchResponseDto,
+     CostDollars,
+     ResearchOutput,
+ )

  __all__ = [
      "ResearchClient",
      "AsyncResearchClient",
-     "ResearchTaskId",
-     "ResearchTask",
-     "ListResearchTasksResponse",
+     "ResearchTyped",
+     "AsyncResearchTyped",
+     "ResearchDto",
+     "ResearchEvent",
+     "ResearchDefinitionEvent",
+     "ResearchOutputEvent",
+     "ResearchPlanDefinitionEvent",
+     "ResearchPlanOperationEvent",
+     "ResearchPlanOutputEvent",
+     "ResearchTaskDefinitionEvent",
+     "ResearchTaskOperationEvent",
+     "ResearchTaskOutputEvent",
+     "ListResearchResponseDto",
+     "CostDollars",
+     "ResearchOutput",
  ]
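For downstream code, the practical effect of the re-exports above is that everything is now importable from the package root; the old .client module is no longer referenced in this __init__ (whether the module itself still exists is not shown in this diff). A small sketch of the updated imports, using only names listed in __all__:

    # 1.14.x style import, no longer wired up in this __init__:
    #   from exa_py.research.client import ResearchClient, AsyncResearchClient

    # 1.15.0 style import:
    from exa_py.research import (
        AsyncResearchClient,
        ResearchClient,
        ResearchDto,
        ResearchEvent,
    )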
exa_py/research/async_client.py ADDED
@@ -0,0 +1,310 @@
+ """Asynchronous Research API client."""
+
+ from __future__ import annotations
+
+ import asyncio
+ from typing import (
+     Any,
+     AsyncGenerator,
+     Dict,
+     Generic,
+     Literal,
+     Optional,
+     Type,
+     TypeVar,
+     Union,
+     overload,
+ )
+
+ from pydantic import BaseModel, TypeAdapter
+
+ from .base import AsyncResearchBaseClient
+ from .models import (
+     ResearchDto,
+     ResearchEvent,
+     ListResearchResponseDto,
+ )
+ from .utils import (
+     async_stream_sse_events,
+     is_pydantic_model,
+     pydantic_to_json_schema,
+ )
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class AsyncResearchTyped(Generic[T]):
+     """Wrapper for typed research responses in async context."""
+
+     def __init__(self, research: ResearchDto, parsed_output: T):
+         self.research = research
+         self.parsed_output = parsed_output
+         # Expose research fields
+         self.research_id = research.research_id
+         self.status = research.status
+         self.created_at = research.created_at
+         self.model = research.model
+         self.instructions = research.instructions
+         if hasattr(research, "events"):
+             self.events = research.events
+         if hasattr(research, "output"):
+             self.output = research.output
+         if hasattr(research, "cost_dollars"):
+             self.cost_dollars = research.cost_dollars
+         if hasattr(research, "error"):
+             self.error = research.error
+
+
+ class AsyncResearchClient(AsyncResearchBaseClient):
+     """Asynchronous client for the Research API."""
+
+     @overload
+     async def create(
+         self,
+         *,
+         instructions: str,
+         model: Literal["exa-research", "exa-research-pro"] = "exa-research",
+     ) -> ResearchDto: ...
+
+     @overload
+     async def create(
+         self,
+         *,
+         instructions: str,
+         model: Literal["exa-research", "exa-research-pro"] = "exa-research",
+         output_schema: Dict[str, Any],
+     ) -> ResearchDto: ...
+
+     @overload
+     async def create(
+         self,
+         *,
+         instructions: str,
+         model: Literal["exa-research", "exa-research-pro"] = "exa-research",
+         output_schema: Type[T],
+     ) -> ResearchDto: ...
+
+     async def create(
+         self,
+         *,
+         instructions: str,
+         model: Literal["exa-research", "exa-research-pro"] = "exa-research",
+         output_schema: Optional[Union[Dict[str, Any], Type[BaseModel]]] = None,
+     ) -> ResearchDto:
+         """Create a new research request.
+
+         Args:
+             instructions: The research instructions.
+             model: The model to use for research.
+             output_schema: Optional JSON schema or Pydantic model for structured output.
+
+         Returns:
+             The created research object.
+         """
+         payload = {
+             "instructions": instructions,
+             "model": model,
+         }
+
+         if output_schema is not None:
+             if is_pydantic_model(output_schema):
+                 payload["outputSchema"] = pydantic_to_json_schema(output_schema)
+             else:
+                 payload["outputSchema"] = output_schema
+
+         response = await self.request("", method="POST", data=payload)
+         adapter = TypeAdapter(ResearchDto)
+         return adapter.validate_python(response)
+
+     @overload
+     async def get(
+         self,
+         research_id: str,
+     ) -> ResearchDto: ...
+
+     @overload
+     async def get(
+         self,
+         research_id: str,
+         *,
+         stream: Literal[False] = False,
+         events: bool = False,
+     ) -> ResearchDto: ...
+
+     @overload
+     async def get(
+         self,
+         research_id: str,
+         *,
+         stream: Literal[True],
+         events: Optional[bool] = None,
+     ) -> AsyncGenerator[ResearchEvent, None]: ...
+
+     @overload
+     async def get(
+         self,
+         research_id: str,
+         *,
+         stream: Literal[False] = False,
+         events: bool = False,
+         output_schema: Type[T],
+     ) -> AsyncResearchTyped[T]: ...
+
+     async def get(
+         self,
+         research_id: str,
+         *,
+         stream: bool = False,
+         events: bool = False,
+         output_schema: Optional[Type[BaseModel]] = None,
+     ) -> Union[ResearchDto, AsyncResearchTyped, AsyncGenerator[ResearchEvent, None]]:
+         """Get a research request by ID.
+
+         Args:
+             research_id: The research ID.
+             stream: Whether to stream events.
+             events: Whether to include events in non-streaming response.
+             output_schema: Optional Pydantic model for typed output validation.
+
+         Returns:
+             Research object, typed research, or async event generator.
+         """
+         params = {}
+         if not stream:
+             params["stream"] = "false"
+             if events:
+                 params["events"] = "true"
+         else:
+             params["stream"] = "true"
+             if events is not None:
+                 params["events"] = str(events).lower()
+
+         if stream:
+             response = await self.request(
+                 f"/{research_id}", method="GET", params=params, stream=True
+             )
+             return async_stream_sse_events(response)
+         else:
+             response = await self.request(
+                 f"/{research_id}", method="GET", params=params
+             )
+             adapter = TypeAdapter(ResearchDto)
+             research = adapter.validate_python(response)
+
+             if output_schema and hasattr(research, "output") and research.output:
+                 try:
+                     if research.output.parsed:
+                         parsed = output_schema.model_validate(research.output.parsed)
+                     else:
+                         import json
+
+                         parsed_data = json.loads(research.output.content)
+                         parsed = output_schema.model_validate(parsed_data)
+                     return AsyncResearchTyped(research, parsed)
+                 except Exception:
+                     # If parsing fails, return the regular research object
+                     return research
+
+             return research
+
+     async def list(
+         self,
+         *,
+         cursor: Optional[str] = None,
+         limit: Optional[int] = None,
+     ) -> ListResearchResponseDto:
+         """List research requests.
+
+         Args:
+             cursor: Pagination cursor.
+             limit: Maximum number of results.
+
+         Returns:
+             List of research objects with pagination info.
+         """
+         params = self.build_pagination_params(cursor, limit)
+         response = await self.request("", method="GET", params=params)
+         return ListResearchResponseDto.model_validate(response)
+
+     @overload
+     async def poll_until_finished(
+         self,
+         research_id: str,
+         *,
+         poll_interval: int = 1000,
+         timeout_ms: int = 600000,
+         events: bool = False,
+     ) -> ResearchDto: ...
+
+     @overload
+     async def poll_until_finished(
+         self,
+         research_id: str,
+         *,
+         poll_interval: int = 1000,
+         timeout_ms: int = 600000,
+         events: bool = False,
+         output_schema: Type[T],
+     ) -> AsyncResearchTyped[T]: ...
+
+     async def poll_until_finished(
+         self,
+         research_id: str,
+         *,
+         poll_interval: int = 1000,
+         timeout_ms: int = 600000,
+         events: bool = False,
+         output_schema: Optional[Type[BaseModel]] = None,
+     ) -> Union[ResearchDto, AsyncResearchTyped]:
+         """Poll until research is finished.
+
+         Args:
+             research_id: The research ID.
+             poll_interval: Milliseconds between polls (default 1000).
+             timeout_ms: Maximum time to wait in milliseconds (default 600000).
+             events: Whether to include events in the response.
+             output_schema: Optional Pydantic model for typed output validation.
+
+         Returns:
+             Completed research object or typed research.
+
+         Raises:
+             TimeoutError: If research doesn't complete within timeout.
+             RuntimeError: If polling fails too many times.
+         """
+         poll_interval_sec = poll_interval / 1000
+         timeout_sec = timeout_ms / 1000
+         max_consecutive_failures = 5
+         start_time = asyncio.get_event_loop().time()
+         consecutive_failures = 0
+
+         while True:
+             try:
+                 if output_schema:
+                     result = await self.get(
+                         research_id, events=events, output_schema=output_schema
+                     )
+                 else:
+                     result = await self.get(research_id, events=events)
+
+                 consecutive_failures = 0
+
+                 # Check if research is finished
+                 status = result.status if hasattr(result, "status") else None
+                 if status in ["completed", "failed", "canceled"]:
+                     return result
+
+             except Exception as e:
+                 consecutive_failures += 1
+                 if consecutive_failures >= max_consecutive_failures:
+                     raise RuntimeError(
+                         f"Polling failed {max_consecutive_failures} times in a row "
+                         f"for research {research_id}: {e}"
+                     )
+
+             if asyncio.get_event_loop().time() - start_time > timeout_sec:
+                 raise TimeoutError(
+                     f"Research {research_id} did not complete within {timeout_ms}ms"
+                 )
+
+             await asyncio.sleep(poll_interval_sec)
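Taken together, the new asynchronous client covers creation, typed polling, and event streaming. The sketch below assumes the client is reachable as the research attribute of AsyncExa (the attribute name is not shown in this diff) and invents a schema and instructions for illustration:

    import asyncio
    from pydantic import BaseModel
    from exa_py import AsyncExa

    class CompanySummary(BaseModel):
        name: str
        founded_year: int

    async def main() -> None:
        exa = AsyncExa("YOUR-EXA-API-KEY")  # placeholder key
        research = exa.research              # assumed attribute name

        created = await research.create(
            instructions="Summarize the founding history of Exa.",
            model="exa-research",
            output_schema=CompanySummary,    # Pydantic model converted to a JSON schema
        )

        # Wait for a terminal state; returns AsyncResearchTyped when an
        # output_schema is supplied and the output parses successfully.
        finished = await research.poll_until_finished(
            created.research_id,
            output_schema=CompanySummary,
        )
        if hasattr(finished, "parsed_output"):
            print(finished.parsed_output.name, finished.parsed_output.founded_year)

        # Or stream events as they are emitted for a given research id.
        async for event in await research.get(created.research_id, stream=True):
            print(type(event).__name__)

    asyncio.run(main())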