exa-py 1.14.20__py3-none-any.whl → 1.15.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of exa-py might be problematic. Click here for more details.

@@ -0,0 +1,308 @@
1
+ """Synchronous Research API client."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from typing import (
7
+ Any,
8
+ Dict,
9
+ Generator,
10
+ Generic,
11
+ Literal,
12
+ Optional,
13
+ Type,
14
+ TypeVar,
15
+ Union,
16
+ overload,
17
+ )
18
+
19
+ from pydantic import BaseModel, TypeAdapter
20
+
21
+ from .base import ResearchBaseClient
22
+ from .models import (
23
+ ResearchDto,
24
+ ResearchEvent,
25
+ ListResearchResponseDto,
26
+ )
27
+ from .utils import (
28
+ is_pydantic_model,
29
+ pydantic_to_json_schema,
30
+ stream_sse_events,
31
+ )
32
+
33
+ T = TypeVar("T", bound=BaseModel)
34
+
35
+
36
class ResearchTyped(Generic[T]):
    """Pairs a research response with its schema-validated output."""

    def __init__(self, research: ResearchDto, parsed_output: T):
        self.research = research
        self.parsed_output = parsed_output
        # Mirror the core research fields so callers can read them directly.
        self.research_id = research.research_id
        self.status = research.status
        self.created_at = research.created_at
        self.model = research.model
        self.instructions = research.instructions
        # Optional fields are mirrored only when present on the DTO.
        for optional_field in ("events", "output", "cost_dollars", "error"):
            if hasattr(research, optional_field):
                setattr(self, optional_field, getattr(research, optional_field))
56
+
57
+
58
class ResearchClient(ResearchBaseClient):
    """Synchronous client for the Research API."""

    @overload
    def create(
        self,
        *,
        instructions: str,
        model: Literal["exa-research", "exa-research-pro"] = "exa-research",
    ) -> ResearchDto: ...

    @overload
    def create(
        self,
        *,
        instructions: str,
        model: Literal["exa-research", "exa-research-pro"] = "exa-research",
        output_schema: Dict[str, Any],
    ) -> ResearchDto: ...

    @overload
    def create(
        self,
        *,
        instructions: str,
        model: Literal["exa-research", "exa-research-pro"] = "exa-research",
        output_schema: Type[T],
    ) -> ResearchDto: ...

    def create(
        self,
        *,
        instructions: str,
        model: Literal["exa-research", "exa-research-pro"] = "exa-research",
        output_schema: Optional[Union[Dict[str, Any], Type[BaseModel]]] = None,
    ) -> ResearchDto:
        """Submit a new research request.

        Args:
            instructions: Natural-language research instructions.
            model: Research model to run.
            output_schema: Optional JSON schema dict or Pydantic model class
                describing the desired structured output.

        Returns:
            The newly created research object.
        """
        body: Dict[str, Any] = {"instructions": instructions, "model": model}
        if output_schema is not None:
            # Pydantic model classes are converted to plain JSON Schema first.
            body["outputSchema"] = (
                pydantic_to_json_schema(output_schema)
                if is_pydantic_model(output_schema)
                else output_schema
            )

        raw = self.request("", method="POST", data=body)
        return TypeAdapter(ResearchDto).validate_python(raw)

    @overload
    def get(
        self,
        research_id: str,
    ) -> ResearchDto: ...

    @overload
    def get(
        self,
        research_id: str,
        *,
        stream: Literal[False] = False,
        events: bool = False,
    ) -> ResearchDto: ...

    @overload
    def get(
        self,
        research_id: str,
        *,
        stream: Literal[True],
        events: Optional[bool] = None,
    ) -> Generator[ResearchEvent, None, None]: ...

    @overload
    def get(
        self,
        research_id: str,
        *,
        stream: Literal[False] = False,
        events: bool = False,
        output_schema: Type[T],
    ) -> ResearchTyped[T]: ...

    def get(
        self,
        research_id: str,
        *,
        stream: bool = False,
        events: bool = False,
        output_schema: Optional[Type[BaseModel]] = None,
    ) -> Union[ResearchDto, ResearchTyped, Generator[ResearchEvent, None, None]]:
        """Fetch a research request by id.

        Args:
            research_id: Identifier of the research request.
            stream: When True, return a generator of SSE events instead of a
                point-in-time snapshot.
            events: Include the event log in a non-streaming response.
            output_schema: Optional Pydantic model used to validate the output.

        Returns:
            A research snapshot, a typed wrapper, or an event generator.
        """
        query: Dict[str, str] = {}
        if stream:
            query["stream"] = "true"
            if events is not None:
                query["events"] = str(events).lower()
        else:
            query["stream"] = "false"
            if events:
                query["events"] = "true"

        if stream:
            raw = self.request(
                f"/{research_id}", method="GET", params=query, stream=True
            )
            return stream_sse_events(raw)

        raw = self.request(f"/{research_id}", method="GET", params=query)
        research = TypeAdapter(ResearchDto).validate_python(raw)

        if output_schema and getattr(research, "output", None):
            try:
                if research.output.parsed:
                    validated = output_schema.model_validate(research.output.parsed)
                else:
                    import json

                    validated = output_schema.model_validate(
                        json.loads(research.output.content)
                    )
                return ResearchTyped(research, validated)
            except Exception:
                # Best effort: on any validation/parse failure, fall back to
                # the plain research object.
                return research

        return research

    def list(
        self,
        *,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> ListResearchResponseDto:
        """Enumerate research requests.

        Args:
            cursor: Opaque pagination cursor from a previous page.
            limit: Maximum number of results per page.

        Returns:
            A page of research objects plus pagination metadata.
        """
        raw = self.request(
            "", method="GET", params=self.build_pagination_params(cursor, limit)
        )
        return ListResearchResponseDto.model_validate(raw)

    @overload
    def poll_until_finished(
        self,
        research_id: str,
        *,
        poll_interval: int = 1000,
        timeout_ms: int = 600000,
        events: bool = False,
    ) -> ResearchDto: ...

    @overload
    def poll_until_finished(
        self,
        research_id: str,
        *,
        poll_interval: int = 1000,
        timeout_ms: int = 600000,
        events: bool = False,
        output_schema: Type[T],
    ) -> ResearchTyped[T]: ...

    def poll_until_finished(
        self,
        research_id: str,
        *,
        poll_interval: int = 1000,
        timeout_ms: int = 600000,
        events: bool = False,
        output_schema: Optional[Type[BaseModel]] = None,
    ) -> Union[ResearchDto, ResearchTyped]:
        """Block until the research reaches a terminal status.

        Args:
            research_id: Identifier of the research request.
            poll_interval: Delay between polls, in milliseconds (default 1000).
            timeout_ms: Overall deadline, in milliseconds (default 600000).
            events: Include the event log in each polled response.
            output_schema: Optional Pydantic model used to validate the output.

        Returns:
            The finished research object (typed when a schema is supplied).

        Raises:
            TimeoutError: The research did not finish before the deadline.
            RuntimeError: Polling failed five times in a row.
        """
        sleep_sec = poll_interval / 1000
        deadline = time.time() + timeout_ms / 1000
        max_consecutive_failures = 5
        failures = 0

        while True:
            try:
                if output_schema:
                    result = self.get(
                        research_id, events=events, output_schema=output_schema
                    )
                else:
                    result = self.get(research_id, events=events)

                failures = 0

                # A terminal status ends the polling loop.
                if getattr(result, "status", None) in (
                    "completed",
                    "failed",
                    "canceled",
                ):
                    return result

            except Exception as e:
                failures += 1
                if failures >= max_consecutive_failures:
                    raise RuntimeError(
                        f"Polling failed {max_consecutive_failures} times in a row "
                        f"for research {research_id}: {e}"
                    )

            if time.time() > deadline:
                raise TimeoutError(
                    f"Research {research_id} did not complete within {timeout_ms}ms"
                )

            time.sleep(sleep_sec)
@@ -0,0 +1,222 @@
1
+ """Utilities for the Research API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from typing import (
7
+ Any,
8
+ AsyncGenerator,
9
+ Dict,
10
+ Generator,
11
+ Optional,
12
+ Type,
13
+ )
14
+
15
+ import httpx
16
+ import requests
17
+ from pydantic import BaseModel, ValidationError
18
+
19
+ from .models import (
20
+ ResearchEvent,
21
+ ResearchDefinitionEvent,
22
+ ResearchOutputEvent,
23
+ ResearchPlanDefinitionEvent,
24
+ ResearchPlanOperationEvent,
25
+ ResearchPlanOutputEvent,
26
+ ResearchTaskDefinitionEvent,
27
+ ResearchTaskOperationEvent,
28
+ ResearchTaskOutputEvent,
29
+ )
30
+
31
+
32
+ def is_pydantic_model(schema: Any) -> bool:
33
+ """Check if the given schema is a Pydantic model class.
34
+
35
+ Args:
36
+ schema: The schema to check.
37
+
38
+ Returns:
39
+ True if schema is a Pydantic model class, False otherwise.
40
+ """
41
+ try:
42
+ return isinstance(schema, type) and issubclass(schema, BaseModel)
43
+ except (TypeError, AttributeError):
44
+ return False
45
+
46
+
47
def pydantic_to_json_schema(model: Type[BaseModel]) -> Dict[str, Any]:
    """Serialize a Pydantic model class into a JSON Schema dictionary.

    Args:
        model: The Pydantic model class to convert.

    Returns:
        A JSON Schema dict with all ``$ref`` entries inlined.
    """
    # Imported locally to avoid a circular import with exa_py.utils.
    from exa_py.utils import _convert_schema_input

    # Delegate to the shared converter, which already inlines references.
    return _convert_schema_input(model)
61
+
62
+
63
def parse_sse_line(line: str) -> Optional[tuple[str, str]]:
    """Split one SSE line into a ``(field, value)`` pair.

    Args:
        line: A single line from an SSE stream.

    Returns:
        The stripped field name and value, or None for blank lines and
        lines that contain no ``:`` separator.
    """
    if not line or not line.strip() or ":" not in line:
        return None

    field_name, _, raw_value = line.partition(":")
    return field_name.strip(), raw_value.strip()
80
+
81
+
82
def parse_sse_event_raw(event_lines: list[str]) -> Optional[Dict[str, Any]]:
    """Parse SSE event lines into a raw event dictionary.

    Args:
        event_lines: List of lines that make up a single SSE event.

    Returns:
        The decoded event data (with ``eventType`` injected when the data is
        a dict), or None when no named event with data was present.
    """
    event_name: Optional[str] = None
    event_data: Any = None
    # Track data presence explicitly: a bare truthiness test on the decoded
    # value would wrongly drop valid falsy payloads (JSON `0`, `false`, `{}`).
    has_data = False

    for line in event_lines:
        parsed = parse_sse_line(line)
        if not parsed:
            continue

        field, value = parsed
        if field == "event":
            event_name = value
        elif field == "data":
            # NOTE(review): multiple `data:` lines overwrite each other here
            # rather than being joined per the SSE spec — assumes the server
            # sends one data line per event; TODO confirm.
            has_data = True
            try:
                event_data = json.loads(value)
            except json.JSONDecodeError:
                # Some events might have non-JSON data; keep the raw text.
                event_data = value

    if event_name and has_data:
        # Add eventType to the data for consistency
        if isinstance(event_data, dict):
            event_data["eventType"] = event_name
        return event_data

    return None
116
+
117
+
118
def parse_research_event(raw_event: Dict[str, Any]) -> Optional[ResearchEvent]:
    """Validate a raw event dictionary into its typed ResearchEvent variant.

    Args:
        raw_event: Raw event dictionary carrying an ``eventType`` field.

    Returns:
        The matching typed event, or None when the type is missing, unknown,
        or the payload fails validation.
    """
    event_type = raw_event.get("eventType")
    if not event_type:
        return None

    # Dispatch table from wire event names to their Pydantic models.
    dispatch = {
        "research-definition": ResearchDefinitionEvent,
        "research-output": ResearchOutputEvent,
        "plan-definition": ResearchPlanDefinitionEvent,
        "plan-operation": ResearchPlanOperationEvent,
        "plan-output": ResearchPlanOutputEvent,
        "task-definition": ResearchTaskDefinitionEvent,
        "task-operation": ResearchTaskOperationEvent,
        "task-output": ResearchTaskOutputEvent,
    }

    model_cls = dispatch.get(event_type)
    if model_cls is None:
        return None

    try:
        return model_cls.model_validate(raw_event)
    except ValidationError:
        # Malformed payload for a known event type: treat as unparseable.
        return None
152
+
153
+
154
def stream_sse_events(
    response: requests.Response,
) -> Generator[ResearchEvent, None, None]:
    """Stream SSE events from a requests Response.

    Args:
        response: The streaming response whose body carries SSE lines.

    Yields:
        Parsed ResearchEvent objects.
    """

    def _decode(lines: list[str]) -> Optional[ResearchEvent]:
        # Turn one buffered SSE event into a typed event, when possible.
        raw = parse_sse_event_raw(lines)
        return parse_research_event(raw) if raw else None

    pending: list[str] = []

    for line in response.iter_lines():
        if line:
            # requests may yield bytes; normalize to text before buffering.
            text = line.decode("utf-8") if isinstance(line, bytes) else line
            pending.append(text)
            continue

        # A blank line terminates the current event.
        if pending:
            event = _decode(pending)
            if event:
                yield event
            pending = []

    # Flush a trailing event that had no terminating blank line.
    if pending:
        event = _decode(pending)
        if event:
            yield event
188
+
189
+
190
async def async_stream_sse_events(
    response: httpx.Response,
) -> AsyncGenerator[ResearchEvent, None]:
    """Stream SSE events from an httpx Response.

    Args:
        response: The async streaming response whose body carries SSE lines.

    Yields:
        Parsed ResearchEvent objects.
    """
    pending: list[str] = []

    async for line in response.aiter_lines():
        if line:
            # httpx yields already-decoded text lines.
            pending.append(line)
            continue

        # A blank line terminates the current event.
        if pending:
            raw = parse_sse_event_raw(pending)
            event = parse_research_event(raw) if raw else None
            if event:
                yield event
            pending = []

    # Flush a trailing event that had no terminating blank line.
    if pending:
        raw = parse_sse_event_raw(pending)
        event = parse_research_event(raw) if raw else None
        if event:
            yield event
exa_py/utils.py CHANGED
@@ -1,10 +1,7 @@
1
1
  import json
2
- import os
3
- from typing import Any, Optional, Union
2
+ from typing import TYPE_CHECKING, Any, Optional, Union
4
3
  from openai.types.chat import ChatCompletion
5
4
 
6
- from typing import TYPE_CHECKING
7
-
8
5
  from pydantic import BaseModel
9
6
  from pydantic.json_schema import GenerateJsonSchema
10
7
 
@@ -4,6 +4,7 @@ from typing import Dict, Any, Union
4
4
 
5
5
  from ..types import (
6
6
  CreateEnrichmentParameters,
7
+ UpdateEnrichmentParameters,
7
8
  WebsetEnrichment,
8
9
  )
9
10
  from ..core.base import WebsetsBaseClient
@@ -40,6 +41,20 @@ class WebsetEnrichmentsClient(WebsetsBaseClient):
40
41
  response = self.request(f"/v0/websets/{webset_id}/enrichments/{id}", method="GET")
41
42
  return WebsetEnrichment.model_validate(response)
42
43
 
44
+ def update(self, webset_id: str, id: str, params: Union[Dict[str, Any], UpdateEnrichmentParameters]) -> WebsetEnrichment:
45
+ """Update an Enrichment.
46
+
47
+ Args:
48
+ webset_id (str): The id of the Webset.
49
+ id (str): The id of the Enrichment.
50
+ params (UpdateEnrichmentParameters): The parameters for updating an enrichment.
51
+
52
+ Returns:
53
+ WebsetEnrichment: The updated enrichment.
54
+ """
55
+ response = self.request(f"/v0/websets/{webset_id}/enrichments/{id}", data=params, method="PATCH")
56
+ return WebsetEnrichment.model_validate(response)
57
+
43
58
  def delete(self, webset_id: str, id: str) -> WebsetEnrichment:
44
59
  """Delete an Enrichment.
45
60
 
@@ -15,6 +15,13 @@ from ..types import (
15
15
  WebsetSearchUpdatedEvent,
16
16
  WebsetSearchCanceledEvent,
17
17
  WebsetSearchCompletedEvent,
18
+ ImportCreatedEvent,
19
+ ImportCompletedEvent,
20
+ MonitorCreatedEvent,
21
+ MonitorUpdatedEvent,
22
+ MonitorDeletedEvent,
23
+ MonitorRunCreatedEvent,
24
+ MonitorRunCompletedEvent,
18
25
  )
19
26
  from ..core.base import WebsetsBaseClient
20
27
 
@@ -30,6 +37,13 @@ Event = Union[
30
37
  WebsetSearchUpdatedEvent,
31
38
  WebsetSearchCanceledEvent,
32
39
  WebsetSearchCompletedEvent,
40
+ ImportCreatedEvent,
41
+ ImportCompletedEvent,
42
+ MonitorCreatedEvent,
43
+ MonitorUpdatedEvent,
44
+ MonitorDeletedEvent,
45
+ MonitorRunCreatedEvent,
46
+ MonitorRunCompletedEvent,
33
47
  ]
34
48
 
35
49
  class EventsClient(WebsetsBaseClient):
@@ -89,6 +103,13 @@ class EventsClient(WebsetsBaseClient):
89
103
  'webset.search.updated': WebsetSearchUpdatedEvent,
90
104
  'webset.search.canceled': WebsetSearchCanceledEvent,
91
105
  'webset.search.completed': WebsetSearchCompletedEvent,
106
+ 'import.created': ImportCreatedEvent,
107
+ 'import.completed': ImportCompletedEvent,
108
+ 'monitor.created': MonitorCreatedEvent,
109
+ 'monitor.updated': MonitorUpdatedEvent,
110
+ 'monitor.deleted': MonitorDeletedEvent,
111
+ 'monitor.run.created': MonitorRunCreatedEvent,
112
+ 'monitor.run.completed': MonitorRunCompletedEvent,
92
113
  }
93
114
 
94
115
  event_class = event_type_map.get(event_type)