exa-py 1.14.19__py3-none-any.whl → 1.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of exa-py might be problematic. Click here for more details.
- exa_py/api.py +78 -31
- exa_py/research/__init__.py +34 -5
- exa_py/research/async_client.py +310 -0
- exa_py/research/base.py +165 -0
- exa_py/research/models.py +314 -113
- exa_py/research/sync_client.py +308 -0
- exa_py/research/utils.py +222 -0
- exa_py/utils.py +1 -4
- {exa_py-1.14.19.dist-info → exa_py-1.15.0.dist-info}/METADATA +1 -1
- {exa_py-1.14.19.dist-info → exa_py-1.15.0.dist-info}/RECORD +11 -8
- exa_py/research/client.py +0 -358
- {exa_py-1.14.19.dist-info → exa_py-1.15.0.dist-info}/WHEEL +0 -0
exa_py/api.py
CHANGED
|
@@ -41,7 +41,7 @@ from exa_py.utils import (
|
|
|
41
41
|
)
|
|
42
42
|
from .websets import WebsetsClient
|
|
43
43
|
from .websets.core.base import ExaJSONEncoder
|
|
44
|
-
from .research
|
|
44
|
+
from .research import ResearchClient, AsyncResearchClient
|
|
45
45
|
|
|
46
46
|
|
|
47
47
|
is_beta = os.getenv("IS_BETA") == "True"
|
|
@@ -133,6 +133,7 @@ SEARCH_OPTIONS_TYPES = {
|
|
|
133
133
|
"end_published_date": [
|
|
134
134
|
str
|
|
135
135
|
], # Results before this publish date; excludes links with no date. ISO 8601 format.
|
|
136
|
+
"user_location": [str], # Two-letter ISO country code of the user (e.g. US).
|
|
136
137
|
"include_text": [
|
|
137
138
|
list
|
|
138
139
|
], # Must be present in webpage text. (One string, up to 5 words)
|
|
@@ -1186,23 +1187,37 @@ class Exa:
|
|
|
1186
1187
|
# Otherwise, serialize the dictionary to JSON if it exists
|
|
1187
1188
|
json_data = json.dumps(data, cls=ExaJSONEncoder) if data else None
|
|
1188
1189
|
|
|
1189
|
-
if data
|
|
1190
|
-
|
|
1191
|
-
|
|
1192
|
-
|
|
1193
|
-
headers=self.headers,
|
|
1194
|
-
stream=True,
|
|
1195
|
-
)
|
|
1196
|
-
return res
|
|
1190
|
+
# Check if we need streaming (either from data for POST or params for GET)
|
|
1191
|
+
needs_streaming = (data and isinstance(data, dict) and data.get("stream")) or (
|
|
1192
|
+
params and params.get("stream") == "true"
|
|
1193
|
+
)
|
|
1197
1194
|
|
|
1198
1195
|
if method.upper() == "GET":
|
|
1199
|
-
|
|
1200
|
-
|
|
1201
|
-
|
|
1196
|
+
if needs_streaming:
|
|
1197
|
+
res = requests.get(
|
|
1198
|
+
self.base_url + endpoint,
|
|
1199
|
+
headers=self.headers,
|
|
1200
|
+
params=params,
|
|
1201
|
+
stream=True,
|
|
1202
|
+
)
|
|
1203
|
+
return res
|
|
1204
|
+
else:
|
|
1205
|
+
res = requests.get(
|
|
1206
|
+
self.base_url + endpoint, headers=self.headers, params=params
|
|
1207
|
+
)
|
|
1202
1208
|
elif method.upper() == "POST":
|
|
1203
|
-
|
|
1204
|
-
|
|
1205
|
-
|
|
1209
|
+
if needs_streaming:
|
|
1210
|
+
res = requests.post(
|
|
1211
|
+
self.base_url + endpoint,
|
|
1212
|
+
data=json_data,
|
|
1213
|
+
headers=self.headers,
|
|
1214
|
+
stream=True,
|
|
1215
|
+
)
|
|
1216
|
+
return res
|
|
1217
|
+
else:
|
|
1218
|
+
res = requests.post(
|
|
1219
|
+
self.base_url + endpoint, data=json_data, headers=self.headers
|
|
1220
|
+
)
|
|
1206
1221
|
elif method.upper() == "PATCH":
|
|
1207
1222
|
res = requests.patch(
|
|
1208
1223
|
self.base_url + endpoint, data=json_data, headers=self.headers
|
|
@@ -1236,6 +1251,7 @@ class Exa:
|
|
|
1236
1251
|
category: Optional[str] = None,
|
|
1237
1252
|
flags: Optional[List[str]] = None,
|
|
1238
1253
|
moderation: Optional[bool] = None,
|
|
1254
|
+
user_location: Optional[str] = None,
|
|
1239
1255
|
) -> SearchResponse[_Result]:
|
|
1240
1256
|
"""Perform a search with a prompt-engineered query to retrieve relevant results.
|
|
1241
1257
|
|
|
@@ -1251,10 +1267,11 @@ class Exa:
|
|
|
1251
1267
|
include_text (List[str], optional): Strings that must appear in the page text.
|
|
1252
1268
|
exclude_text (List[str], optional): Strings that must not appear in the page text.
|
|
1253
1269
|
use_autoprompt (bool, optional): Convert query to Exa (default False).
|
|
1254
|
-
type (str, optional): 'keyword', 'neural', 'hybrid', or '
|
|
1270
|
+
type (str, optional): 'keyword', 'neural', 'hybrid', 'fast', or 'auto' (default 'auto').
|
|
1255
1271
|
category (str, optional): e.g. 'company'
|
|
1256
1272
|
flags (List[str], optional): Experimental flags for Exa usage.
|
|
1257
1273
|
moderation (bool, optional): If True, the search results will be moderated for safety.
|
|
1274
|
+
user_location (str, optional): Two-letter ISO country code of the user (e.g. US).
|
|
1258
1275
|
|
|
1259
1276
|
Returns:
|
|
1260
1277
|
SearchResponse: The response containing search results, etc.
|
|
@@ -1312,6 +1329,7 @@ class Exa:
|
|
|
1312
1329
|
category: Optional[str] = None,
|
|
1313
1330
|
flags: Optional[List[str]] = None,
|
|
1314
1331
|
moderation: Optional[bool] = None,
|
|
1332
|
+
user_location: Optional[str] = None,
|
|
1315
1333
|
livecrawl_timeout: Optional[int] = None,
|
|
1316
1334
|
livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
|
|
1317
1335
|
filter_empty_results: Optional[bool] = None,
|
|
@@ -1370,6 +1388,7 @@ class Exa:
|
|
|
1370
1388
|
subpage_target: Optional[Union[str, List[str]]] = None,
|
|
1371
1389
|
flags: Optional[List[str]] = None,
|
|
1372
1390
|
moderation: Optional[bool] = None,
|
|
1391
|
+
user_location: Optional[str] = None,
|
|
1373
1392
|
livecrawl_timeout: Optional[int] = None,
|
|
1374
1393
|
livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
|
|
1375
1394
|
filter_empty_results: Optional[bool] = None,
|
|
@@ -1399,6 +1418,7 @@ class Exa:
|
|
|
1399
1418
|
subpage_target: Optional[Union[str, List[str]]] = None,
|
|
1400
1419
|
flags: Optional[List[str]] = None,
|
|
1401
1420
|
moderation: Optional[bool] = None,
|
|
1421
|
+
user_location: Optional[str] = None,
|
|
1402
1422
|
livecrawl_timeout: Optional[int] = None,
|
|
1403
1423
|
livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
|
|
1404
1424
|
filter_empty_results: Optional[bool] = None,
|
|
@@ -1427,6 +1447,7 @@ class Exa:
|
|
|
1427
1447
|
subpage_target: Optional[Union[str, List[str]]] = None,
|
|
1428
1448
|
flags: Optional[List[str]] = None,
|
|
1429
1449
|
moderation: Optional[bool] = None,
|
|
1450
|
+
user_location: Optional[str] = None,
|
|
1430
1451
|
livecrawl_timeout: Optional[int] = None,
|
|
1431
1452
|
livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
|
|
1432
1453
|
filter_empty_results: Optional[bool] = None,
|
|
@@ -1456,6 +1477,7 @@ class Exa:
|
|
|
1456
1477
|
subpage_target: Optional[Union[str, List[str]]] = None,
|
|
1457
1478
|
flags: Optional[List[str]] = None,
|
|
1458
1479
|
moderation: Optional[bool] = None,
|
|
1480
|
+
user_location: Optional[str] = None,
|
|
1459
1481
|
livecrawl_timeout: Optional[int] = None,
|
|
1460
1482
|
livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
|
|
1461
1483
|
filter_empty_results: Optional[bool] = None,
|
|
@@ -1485,6 +1507,7 @@ class Exa:
|
|
|
1485
1507
|
subpage_target: Optional[Union[str, List[str]]] = None,
|
|
1486
1508
|
flags: Optional[List[str]] = None,
|
|
1487
1509
|
moderation: Optional[bool] = None,
|
|
1510
|
+
user_location: Optional[str] = None,
|
|
1488
1511
|
livecrawl_timeout: Optional[int] = None,
|
|
1489
1512
|
livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
|
|
1490
1513
|
filter_empty_results: Optional[bool] = None,
|
|
@@ -1513,6 +1536,7 @@ class Exa:
|
|
|
1513
1536
|
category: Optional[str] = None,
|
|
1514
1537
|
flags: Optional[List[str]] = None,
|
|
1515
1538
|
moderation: Optional[bool] = None,
|
|
1539
|
+
user_location: Optional[str] = None,
|
|
1516
1540
|
livecrawl_timeout: Optional[int] = None,
|
|
1517
1541
|
livecrawl: Optional[LIVECRAWL_OPTIONS] = None,
|
|
1518
1542
|
subpages: Optional[int] = None,
|
|
@@ -2396,34 +2420,55 @@ class AsyncExa(Exa):
|
|
|
2396
2420
|
# this may only be a
|
|
2397
2421
|
if self._client is None:
|
|
2398
2422
|
self._client = httpx.AsyncClient(
|
|
2399
|
-
base_url=self.base_url, headers=self.headers, timeout=
|
|
2423
|
+
base_url=self.base_url, headers=self.headers, timeout=600
|
|
2400
2424
|
)
|
|
2401
2425
|
return self._client
|
|
2402
2426
|
|
|
2403
|
-
async def async_request(
|
|
2404
|
-
|
|
2427
|
+
async def async_request(
|
|
2428
|
+
self, endpoint: str, data=None, method: str = "POST", params=None
|
|
2429
|
+
):
|
|
2430
|
+
"""Send a request to the Exa API, optionally streaming if data['stream'] is True.
|
|
2405
2431
|
|
|
2406
2432
|
Args:
|
|
2407
2433
|
endpoint (str): The API endpoint (path).
|
|
2408
|
-
data (dict): The JSON payload to send.
|
|
2434
|
+
data (dict, optional): The JSON payload to send.
|
|
2435
|
+
method (str, optional): The HTTP method to use. Defaults to "POST".
|
|
2436
|
+
params (dict, optional): Query parameters.
|
|
2409
2437
|
|
|
2410
2438
|
Returns:
|
|
2411
|
-
Union[dict,
|
|
2439
|
+
Union[dict, httpx.Response]: If streaming, returns the Response object.
|
|
2412
2440
|
Otherwise, returns the JSON-decoded response as a dict.
|
|
2413
2441
|
|
|
2414
2442
|
Raises:
|
|
2415
2443
|
ValueError: If the request fails (non-200 status code).
|
|
2416
2444
|
"""
|
|
2417
|
-
if data
|
|
2418
|
-
|
|
2419
|
-
|
|
2420
|
-
)
|
|
2421
|
-
res = await self.client.send(request, stream=True)
|
|
2422
|
-
return res
|
|
2423
|
-
|
|
2424
|
-
res = await self.client.post(
|
|
2425
|
-
self.base_url + endpoint, json=data, headers=self.headers
|
|
2445
|
+
# Check if we need streaming (either from data for POST or params for GET)
|
|
2446
|
+
needs_streaming = (data and isinstance(data, dict) and data.get("stream")) or (
|
|
2447
|
+
params and params.get("stream") == "true"
|
|
2426
2448
|
)
|
|
2449
|
+
|
|
2450
|
+
if method.upper() == "GET":
|
|
2451
|
+
if needs_streaming:
|
|
2452
|
+
request = httpx.Request(
|
|
2453
|
+
"GET", self.base_url + endpoint, params=params, headers=self.headers
|
|
2454
|
+
)
|
|
2455
|
+
res = await self.client.send(request, stream=True)
|
|
2456
|
+
return res
|
|
2457
|
+
else:
|
|
2458
|
+
res = await self.client.get(
|
|
2459
|
+
self.base_url + endpoint, params=params, headers=self.headers
|
|
2460
|
+
)
|
|
2461
|
+
elif method.upper() == "POST":
|
|
2462
|
+
if needs_streaming:
|
|
2463
|
+
request = httpx.Request(
|
|
2464
|
+
"POST", self.base_url + endpoint, json=data, headers=self.headers
|
|
2465
|
+
)
|
|
2466
|
+
res = await self.client.send(request, stream=True)
|
|
2467
|
+
return res
|
|
2468
|
+
else:
|
|
2469
|
+
res = await self.client.post(
|
|
2470
|
+
self.base_url + endpoint, json=data, headers=self.headers
|
|
2471
|
+
)
|
|
2427
2472
|
if res.status_code != 200 and res.status_code != 201:
|
|
2428
2473
|
raise ValueError(
|
|
2429
2474
|
f"Request failed with status code {res.status_code}: {res.text}"
|
|
@@ -2448,6 +2493,7 @@ class AsyncExa(Exa):
|
|
|
2448
2493
|
category: Optional[str] = None,
|
|
2449
2494
|
flags: Optional[List[str]] = None,
|
|
2450
2495
|
moderation: Optional[bool] = None,
|
|
2496
|
+
user_location: Optional[str] = None,
|
|
2451
2497
|
) -> SearchResponse[_Result]:
|
|
2452
2498
|
"""Perform a search with a prompt-engineered query to retrieve relevant results.
|
|
2453
2499
|
|
|
@@ -2463,10 +2509,11 @@ class AsyncExa(Exa):
|
|
|
2463
2509
|
include_text (List[str], optional): Strings that must appear in the page text.
|
|
2464
2510
|
exclude_text (List[str], optional): Strings that must not appear in the page text.
|
|
2465
2511
|
use_autoprompt (bool, optional): Convert query to Exa (default False).
|
|
2466
|
-
type (str, optional): 'keyword', 'neural', 'hybrid', or '
|
|
2512
|
+
type (str, optional): 'keyword', 'neural', 'hybrid', 'fast', or 'auto' (default 'auto').
|
|
2467
2513
|
category (str, optional): e.g. 'company'
|
|
2468
2514
|
flags (List[str], optional): Experimental flags for Exa usage.
|
|
2469
2515
|
moderation (bool, optional): If True, the search results will be moderated for safety.
|
|
2516
|
+
user_location (str, optional): Two-letter ISO country code of the user (e.g. US).
|
|
2470
2517
|
|
|
2471
2518
|
Returns:
|
|
2472
2519
|
SearchResponse: The response containing search results, etc.
|
exa_py/research/__init__.py
CHANGED
|
@@ -1,10 +1,39 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""Research API client modules for Exa."""
|
|
2
|
+
|
|
3
|
+
from .sync_client import ResearchClient, ResearchTyped
|
|
4
|
+
from .async_client import AsyncResearchClient, AsyncResearchTyped
|
|
5
|
+
from .models import (
|
|
6
|
+
ResearchDto,
|
|
7
|
+
ResearchEvent,
|
|
8
|
+
ResearchDefinitionEvent,
|
|
9
|
+
ResearchOutputEvent,
|
|
10
|
+
ResearchPlanDefinitionEvent,
|
|
11
|
+
ResearchPlanOperationEvent,
|
|
12
|
+
ResearchPlanOutputEvent,
|
|
13
|
+
ResearchTaskDefinitionEvent,
|
|
14
|
+
ResearchTaskOperationEvent,
|
|
15
|
+
ResearchTaskOutputEvent,
|
|
16
|
+
ListResearchResponseDto,
|
|
17
|
+
CostDollars,
|
|
18
|
+
ResearchOutput,
|
|
19
|
+
)
|
|
3
20
|
|
|
4
21
|
__all__ = [
    # Synchronous and asynchronous Research API clients.
    "ResearchClient",
    "AsyncResearchClient",
    # Wrappers pairing a research object with its schema-validated output.
    "ResearchTyped",
    "AsyncResearchTyped",
    # Research object and event models re-exported from .models.
    "ResearchDto",
    "ResearchEvent",
    "ResearchDefinitionEvent",
    "ResearchOutputEvent",
    "ResearchPlanDefinitionEvent",
    "ResearchPlanOperationEvent",
    "ResearchPlanOutputEvent",
    "ResearchTaskDefinitionEvent",
    "ResearchTaskOperationEvent",
    "ResearchTaskOutputEvent",
    # Pagination and auxiliary response models.
    "ListResearchResponseDto",
    "CostDollars",
    "ResearchOutput",
]
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
"""Asynchronous Research API client."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from typing import (
|
|
7
|
+
Any,
|
|
8
|
+
AsyncGenerator,
|
|
9
|
+
Dict,
|
|
10
|
+
Generic,
|
|
11
|
+
Literal,
|
|
12
|
+
Optional,
|
|
13
|
+
Type,
|
|
14
|
+
TypeVar,
|
|
15
|
+
Union,
|
|
16
|
+
overload,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel, TypeAdapter
|
|
20
|
+
|
|
21
|
+
from .base import AsyncResearchBaseClient
|
|
22
|
+
from .models import (
|
|
23
|
+
ResearchDto,
|
|
24
|
+
ResearchEvent,
|
|
25
|
+
ListResearchResponseDto,
|
|
26
|
+
)
|
|
27
|
+
from .utils import (
|
|
28
|
+
async_stream_sse_events,
|
|
29
|
+
is_pydantic_model,
|
|
30
|
+
pydantic_to_json_schema,
|
|
31
|
+
)
|
|
32
|
+
|
|
33
|
+
T = TypeVar("T", bound=BaseModel)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class AsyncResearchTyped(Generic[T]):
    """Research object bundled with its output parsed into a Pydantic model.

    The async client returns this wrapper when an ``output_schema`` is given
    and the output could be validated. The underlying research object stays
    available as ``research`` and the validated model as ``parsed_output``.
    Frequently used research fields are mirrored as attributes for
    convenience.
    """

    def __init__(self, research: ResearchDto, parsed_output: T):
        self.research = research
        self.parsed_output = parsed_output

        # Read unconditionally: every research object passed in must have these.
        self.research_id = research.research_id
        self.status = research.status
        self.created_at = research.created_at
        self.model = research.model
        self.instructions = research.instructions

        # Mirrored only when the research object exposes them, so these
        # attributes may be missing on the wrapper as well.
        for field_name in ("events", "output", "cost_dollars", "error"):
            if hasattr(research, field_name):
                setattr(self, field_name, getattr(research, field_name))
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class AsyncResearchClient(AsyncResearchBaseClient):
    """Asynchronous client for the Research API.

    All HTTP calls go through ``self.request`` inherited from
    ``AsyncResearchBaseClient``. Paths passed to it are relative: ``""`` for
    the research collection and ``"/{research_id}"`` for a single request.
    """

    @overload
    async def create(
        self,
        *,
        instructions: str,
        model: Literal["exa-research", "exa-research-pro"] = "exa-research",
    ) -> ResearchDto: ...

    @overload
    async def create(
        self,
        *,
        instructions: str,
        model: Literal["exa-research", "exa-research-pro"] = "exa-research",
        output_schema: Dict[str, Any],
    ) -> ResearchDto: ...

    @overload
    async def create(
        self,
        *,
        instructions: str,
        model: Literal["exa-research", "exa-research-pro"] = "exa-research",
        output_schema: Type[T],
    ) -> ResearchDto: ...

    async def create(
        self,
        *,
        instructions: str,
        model: Literal["exa-research", "exa-research-pro"] = "exa-research",
        output_schema: Optional[Union[Dict[str, Any], Type[BaseModel]]] = None,
    ) -> ResearchDto:
        """Create a new research request.

        Args:
            instructions: The research instructions.
            model: The model to use for research.
            output_schema: Optional JSON schema dict or Pydantic model class
                describing the desired structured output. A Pydantic model is
                converted to a JSON schema before being sent as
                ``outputSchema``.

        Returns:
            The created research object.
        """
        payload: Dict[str, Any] = {
            "instructions": instructions,
            "model": model,
        }

        if output_schema is not None:
            if is_pydantic_model(output_schema):
                payload["outputSchema"] = pydantic_to_json_schema(output_schema)
            else:
                payload["outputSchema"] = output_schema

        response = await self.request("", method="POST", data=payload)
        adapter = TypeAdapter(ResearchDto)
        return adapter.validate_python(response)

    @overload
    async def get(
        self,
        research_id: str,
    ) -> ResearchDto: ...

    @overload
    async def get(
        self,
        research_id: str,
        *,
        stream: Literal[False] = False,
        events: bool = False,
    ) -> ResearchDto: ...

    @overload
    async def get(
        self,
        research_id: str,
        *,
        stream: Literal[True],
        events: Optional[bool] = None,
    ) -> AsyncGenerator[ResearchEvent, None]: ...

    @overload
    async def get(
        self,
        research_id: str,
        *,
        stream: Literal[False] = False,
        events: bool = False,
        output_schema: Type[T],
    ) -> AsyncResearchTyped[T]: ...

    async def get(
        self,
        research_id: str,
        *,
        stream: bool = False,
        events: bool = False,
        output_schema: Optional[Type[BaseModel]] = None,
    ) -> Union[ResearchDto, AsyncResearchTyped, AsyncGenerator[ResearchEvent, None]]:
        """Get a research request by ID.

        Args:
            research_id: The research ID.
            stream: Whether to stream events (returns an async generator of
                server-sent events instead of a research object).
            events: Whether to include events in the response.
            output_schema: Optional Pydantic model used to validate the
                research output. Only applied in non-streaming mode.

        Returns:
            Research object, typed research, or async event generator.
        """
        params: Dict[str, str] = {}
        if stream:
            params["stream"] = "true"
            # Forwarded whenever it is not None, so an explicit False is sent too.
            if events is not None:
                params["events"] = str(events).lower()
        else:
            params["stream"] = "false"
            if events:
                params["events"] = "true"

        if stream:
            response = await self.request(
                f"/{research_id}", method="GET", params=params, stream=True
            )
            return async_stream_sse_events(response)

        response = await self.request(f"/{research_id}", method="GET", params=params)
        adapter = TypeAdapter(ResearchDto)
        research = adapter.validate_python(response)

        if output_schema and hasattr(research, "output") and research.output:
            try:
                if research.output.parsed:
                    parsed = output_schema.model_validate(research.output.parsed)
                else:
                    import json

                    parsed_data = json.loads(research.output.content)
                    parsed = output_schema.model_validate(parsed_data)
                return AsyncResearchTyped(research, parsed)
            except Exception:
                # Deliberate best effort: if the output cannot be decoded or
                # validated against output_schema, fall back to the untyped
                # research object instead of failing the whole call.
                return research

        return research

    async def list(
        self,
        *,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> ListResearchResponseDto:
        """List research requests.

        Args:
            cursor: Pagination cursor.
            limit: Maximum number of results.

        Returns:
            List of research objects with pagination info.
        """
        params = self.build_pagination_params(cursor, limit)
        response = await self.request("", method="GET", params=params)
        return ListResearchResponseDto.model_validate(response)

    @overload
    async def poll_until_finished(
        self,
        research_id: str,
        *,
        poll_interval: int = 1000,
        timeout_ms: int = 600000,
        events: bool = False,
    ) -> ResearchDto: ...

    @overload
    async def poll_until_finished(
        self,
        research_id: str,
        *,
        poll_interval: int = 1000,
        timeout_ms: int = 600000,
        events: bool = False,
        output_schema: Type[T],
    ) -> AsyncResearchTyped[T]: ...

    async def poll_until_finished(
        self,
        research_id: str,
        *,
        poll_interval: int = 1000,
        timeout_ms: int = 600000,
        events: bool = False,
        output_schema: Optional[Type[BaseModel]] = None,
    ) -> Union[ResearchDto, AsyncResearchTyped]:
        """Poll until research is finished.

        The research is fetched repeatedly with ``get`` until its status is
        ``"completed"``, ``"failed"`` or ``"canceled"``.

        Args:
            research_id: The research ID.
            poll_interval: Milliseconds between polls (default 1000).
            timeout_ms: Maximum time to wait in milliseconds (default 600000).
            events: Whether to include events in the response.
            output_schema: Optional Pydantic model for typed output validation.

        Returns:
            Completed research object or typed research.

        Raises:
            TimeoutError: If research doesn't complete within timeout.
            RuntimeError: If polling fails 5 times in a row; the last
                underlying error is attached as ``__cause__``.
        """
        poll_interval_sec = poll_interval / 1000
        timeout_sec = timeout_ms / 1000
        max_consecutive_failures = 5
        terminal_statuses = ("completed", "failed", "canceled")

        # Inside a coroutine the running loop is the correct clock source;
        # get_event_loop() is the legacy spelling for this.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout_sec
        consecutive_failures = 0

        while True:
            try:
                if output_schema:
                    result = await self.get(
                        research_id, events=events, output_schema=output_schema
                    )
                else:
                    result = await self.get(research_id, events=events)
            except Exception as e:
                # Transient errors are tolerated; only a run of failures aborts.
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    raise RuntimeError(
                        f"Polling failed {max_consecutive_failures} times in a row "
                        f"for research {research_id}: {e}"
                    ) from e
            else:
                consecutive_failures = 0
                if getattr(result, "status", None) in terminal_statuses:
                    return result

            # Checked after every attempt (success or failure) so a research
            # that never reaches a terminal status still times out.
            remaining = deadline - loop.time()
            if remaining < 0:
                raise TimeoutError(
                    f"Research {research_id} did not complete within {timeout_ms}ms"
                )

            # Never sleep past the deadline; one last poll happens near it.
            await asyncio.sleep(min(poll_interval_sec, remaining))
|