exa-py 1.0.9__py3-none-any.whl → 1.0.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of exa-py might be problematic. Click here for more details.

exa_py/api.py CHANGED
@@ -1,9 +1,12 @@
1
1
  from __future__ import annotations
2
2
  from dataclasses import dataclass
3
3
  import dataclasses
4
+ from functools import wraps
4
5
  import re
5
6
  import requests
6
7
  from typing import (
8
+ Callable,
9
+ Iterable,
7
10
  List,
8
11
  Optional,
9
12
  Dict,
@@ -15,6 +18,28 @@ from typing import (
15
18
  )
16
19
  from typing_extensions import TypedDict
17
20
 
21
+ import httpx
22
+ from openai import NOT_GIVEN, NotGiven, OpenAI
23
+ from openai.types.chat.chat_completion_stream_options_param import (
24
+ ChatCompletionStreamOptionsParam,
25
+ )
26
+ from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
27
+ from openai.types.chat_model import ChatModel
28
+ from openai.types.chat.chat_completion_tool_choice_option_param import (
29
+ ChatCompletionToolChoiceOptionParam,
30
+ )
31
+ from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
32
+ from openai._types import Headers, Query, Body
33
+ from openai.types.chat import completion_create_params
34
+ from exa_py.utils import (
35
+ ExaOpenAICompletion,
36
+ add_message_to_messages,
37
+ format_exa_result,
38
+ maybe_get_query,
39
+ )
40
+
41
+
42
+
18
43
 
19
44
  def snake_to_camel(snake_str: str) -> str:
20
45
  """Convert snake_case string to camelCase.
@@ -319,7 +344,7 @@ class Exa:
319
344
  self,
320
345
  api_key: Optional[str],
321
346
  base_url: str = "https://api.exa.ai",
322
- user_agent: str = "exa-py 1.0.9",
347
+ user_agent: str = "exa-py 1.0.11",
323
348
  ):
324
349
  """Initialize the Exa client with the provided API key and optional base URL and user agent.
325
350
 
@@ -646,3 +671,177 @@ class Exa:
646
671
  [Result(**to_snake_case(result)) for result in data["results"]],
647
672
  data["autopromptString"] if "autopromptString" in data else None,
648
673
  )
674
def wrap(self, client: OpenAI):
    """Wrap an OpenAI client with Exa functionality.

    After wrapping, any call to `client.chat.completions.create` will be
    intercepted and enhanced with Exa functionality: the model is offered a
    `search` tool and, when it calls it, the Exa results are fed back to the
    model for a second completion.

    To disable Exa functionality for a specific call, set `use_exa="none"` in
    the call to `client.chat.completions.create`. In that case the request is
    forwarded to the original `create` unchanged (including any `tools` the
    caller supplied) and no search is run.

    Args:
        client (OpenAI): The OpenAI client to wrap. It is patched in place.

    Returns:
        OpenAI: The same client object, with `chat.completions.create`
        replaced by the Exa-aware wrapper.
    """

    func = client.chat.completions.create

    @wraps(func)
    def create_with_rag(
        # Mandatory OpenAI args
        messages: Iterable[ChatCompletionMessageParam],
        model: Union[str, ChatModel],
        # Optional OpenAI args
        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
        function_call: completion_create_params.FunctionCall | NotGiven = NOT_GIVEN,
        functions: (
            Iterable[completion_create_params.Function] | NotGiven
        ) = NOT_GIVEN,
        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
        n: Optional[int] | NotGiven = NOT_GIVEN,
        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
        response_format: (
            completion_create_params.ResponseFormat | NotGiven
        ) = NOT_GIVEN,
        seed: Optional[int] | NotGiven = NOT_GIVEN,
        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
        stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
        stream_options: (
            Optional[ChatCompletionStreamOptionsParam] | NotGiven
        ) = NOT_GIVEN,
        temperature: Optional[float] | NotGiven = NOT_GIVEN,
        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
        top_p: Optional[float] | NotGiven = NOT_GIVEN,
        user: str | NotGiven = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
        # Exa args
        use_exa: Optional[Literal["required", "none", "auto"]] = "auto",
        highlights: Union[HighlightsContentsOptions, Literal[True], None] = None,
        num_results: Optional[int] = 3,
        include_domains: Optional[List[str]] = None,
        exclude_domains: Optional[List[str]] = None,
        start_crawl_date: Optional[str] = None,
        end_crawl_date: Optional[str] = None,
        start_published_date: Optional[str] = None,
        end_published_date: Optional[str] = None,
        use_autoprompt: Optional[bool] = True,
        type: Optional[str] = None,
        category: Optional[str] = None,
        result_max_len: int = 2048,
    ):
        # Keyword arguments forwarded to `self.search_and_contents`.
        exa_kwargs = {
            "num_results": num_results,
            "include_domains": include_domains,
            "exclude_domains": exclude_domains,
            "highlights": highlights,
            "start_crawl_date": start_crawl_date,
            "end_crawl_date": end_crawl_date,
            "start_published_date": start_published_date,
            "end_published_date": end_published_date,
            "use_autoprompt": use_autoprompt,
            "type": type,
            "category": category,
        }

        # Keyword arguments forwarded to the original OpenAI `create`.
        create_kwargs = {
            "model": model,
            "frequency_penalty": frequency_penalty,
            "function_call": function_call,
            "functions": functions,
            "logit_bias": logit_bias,
            "logprobs": logprobs,
            "max_tokens": max_tokens,
            "n": n,
            "presence_penalty": presence_penalty,
            "response_format": response_format,
            "seed": seed,
            "stop": stop,
            "stream": stream,
            "stream_options": stream_options,
            "temperature": temperature,
            "tool_choice": tool_choice,
            "tools": tools,
            "top_logprobs": top_logprobs,
            "top_p": top_p,
            "user": user,
            "extra_headers": extra_headers,
            "extra_query": extra_query,
            "extra_body": extra_body,
            "timeout": timeout,
        }

        if use_exa == "none":
            # Honour the documented opt-out: do not inject the search tool
            # (which would overwrite the caller's own `tools`) and do not run
            # a search. The result is still wrapped so callers always get an
            # object exposing `exa_result`.
            completion = func(messages=list(messages), **create_kwargs)
            return ExaOpenAICompletion.from_completion(
                completion=completion, exa_result=None
            )

        # The search tool occupies the `tools` slot, so caller-supplied tools
        # cannot be combined with Exa.
        assert tools is NOT_GIVEN, "Tool use is not supported with Exa"
        # "auto" / "required" map directly onto OpenAI's `tool_choice` values.
        create_kwargs["tool_choice"] = use_exa

        return self._create_with_tool(
            create_fn=func,
            messages=list(messages),
            max_len=result_max_len,
            create_kwargs=create_kwargs,
            exa_kwargs=exa_kwargs,
        )

    client.chat.completions.create = create_with_rag  # type: ignore

    return client
799
+
800
def _create_with_tool(
    self,
    create_fn: Callable,
    messages: List[ChatCompletionMessageParam],
    max_len,
    create_kwargs,
    exa_kwargs,
) -> ExaOpenAICompletion:
    """Run a chat completion that may call the Exa `search` tool once.

    The first request offers the model a single `search` function tool. If
    the model does not request a search, that completion is returned as-is.
    Otherwise the query is run through `self.search_and_contents`, the
    formatted results are appended to the conversation as a tool message, and
    a second request is made with `tool_choice="none"`.

    Args:
        create_fn: The original `chat.completions.create` callable.
        messages: Conversation history sent with the first request.
        max_len: Maximum number of characters of each result's text to
            include in the tool message.
        create_kwargs: Keyword arguments forwarded to `create_fn`. This dict
            is not modified.
        exa_kwargs: Keyword arguments forwarded to `self.search_and_contents`.

    Returns:
        ExaOpenAICompletion: The final completion, with `exa_result` set to
        the search response, or `None` when no search was performed.
    """
    search_tool = {
        "type": "function",
        "function": {
            "name": "search",
            "description": "Search the web for relevant information.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The query to search for.",
                    },
                },
                "required": ["query"],
            },
        },
    }

    # Work on a copy so the caller's dict is left untouched.
    request_kwargs = dict(create_kwargs)
    request_kwargs["tools"] = [search_tool]

    completion = create_fn(messages=messages, **request_kwargs)

    query = maybe_get_query(completion)

    if not query:
        return ExaOpenAICompletion.from_completion(completion=completion, exa_result=None)

    exa_result = self.search_and_contents(query, **exa_kwargs)
    exa_str = format_exa_result(exa_result, max_len=max_len)
    new_messages = add_message_to_messages(completion, messages, exa_str)
    # For now, don't allow recursive tool calls
    request_kwargs["tool_choice"] = "none"
    completion = create_fn(messages=new_messages, **request_kwargs)

    return ExaOpenAICompletion.from_completion(
        completion=completion, exa_result=exa_result
    )
exa_py/utils.py ADDED
@@ -0,0 +1,78 @@
1
+ import json
2
+ from typing import Optional
3
+ from openai.types.chat import ChatCompletion
4
+
5
+ from typing import TYPE_CHECKING
6
+ if TYPE_CHECKING:
7
+ from exa_py.api import ResultWithText, SearchResponse
8
+
9
+
10
+
11
def maybe_get_query(completion) -> Optional[str]:
    """Extract the search query from a completion, if it requested one.

    Looks at the first choice only and returns the `query` argument of the
    first tool call whose function name is `search`.

    Args:
        completion: A chat completion object exposing `choices`, each with a
            `message.tool_calls` list (or `None`).

    Returns:
        The query string, or `None` when there are no choices, no tool calls,
        or no `search` tool call.

    Raises:
        KeyError: If a `search` tool call carries no `query` argument.
        json.JSONDecodeError: If the tool call arguments are not valid JSON.
    """
    # `Optional[str]` rather than `str | None`: the annotation is evaluated at
    # definition time here, and the `|` form fails on Python < 3.10.
    if not completion.choices:
        return None

    tool_calls = completion.choices[0].message.tool_calls
    for tool_call in tool_calls or ():
        if tool_call.function.name == "search":
            return json.loads(tool_call.function.arguments)["query"]
    return None
19
+
20
+
21
def add_message_to_messages(completion, messages, exa_result) -> list:
    """Build the follow-up message list after a `search` tool call.

    Returns a new list containing `messages` minus any entries whose role is
    `"function"`, followed by the assistant message that requested the tool
    call and a `"tool"` message carrying `exa_result`. The input list is not
    modified.

    Args:
        completion: The completion whose first choice requested a tool call.
        messages: The conversation so far. Entries may be dicts or message
            objects with a `role` attribute.
        exa_result: The formatted search results to hand back to the model.

    Returns:
        The new message list.

    Raises:
        AssertionError: If the completion's first choice has no tool calls.
    """
    assistant_message = completion.choices[0].message
    assert assistant_message.tool_calls, "Must use this with a tool call request"

    def _role(message):
        # History can contain plain dicts as well as message objects (such as
        # the assistant message appended below); read the role from either.
        if hasattr(message, "get"):
            return message.get("role")
        return getattr(message, "role", None)

    # Remove previous function-role entries to prevent blowing up history
    updated = [message for message in messages if _role(message) != "function"]

    updated.extend([
        assistant_message,
        {
            "role": "tool",
            "name": "search",
            "tool_call_id": assistant_message.tool_calls[0].id,
            "content": exa_result,
        },
    ])

    return updated
43
+
44
+
45
def format_exa_result(exa_result, max_len: int = -1):
    """Format exa result for pasting into chat.

    Each result becomes a `Url:` line, a `Title:` line and the result text,
    and the per-result chunks are joined with a newline.

    Args:
        exa_result: Object exposing `results`, each with `url`, `title` and
            `text` attributes.
        max_len: Maximum number of characters of each result's text to keep.
            A negative value (the default) means no limit.

    Returns:
        The formatted string; empty when there are no results.
    """
    # A negative bound used directly in the slice would chop characters off
    # the end of the text (`text[:-1]`), so map it to "no limit".
    limit = max_len if max_len >= 0 else None

    return "\n".join(
        f"Url: {result.url}\nTitle: {result.title}\n{result.text[:limit]}\n"
        for result in exa_result.results
    )
53
+
54
+
55
class ExaOpenAICompletion(ChatCompletion):
    """A `ChatCompletion` that also carries the Exa search response.

    `exa_result` holds the search response passed in at construction, or
    `None` when there was none.
    """

    def __init__(self, exa_result: Optional["SearchResponse[ResultWithText]"], **kwargs):
        # Everything except `exa_result` is handed to `ChatCompletion`.
        super().__init__(**kwargs)
        self.exa_result = exa_result

    @classmethod
    def from_completion(
        cls,
        exa_result: Optional["SearchResponse[ResultWithText]"],
        completion: ChatCompletion,
    ):
        """Build an instance from an existing completion plus an Exa result.

        Copies `id`, `choices`, `created`, `model`, `object`,
        `system_fingerprint` and `usage` from `completion`.
        """
        copied_fields = (
            "id",
            "choices",
            "created",
            "model",
            "object",
            "system_fingerprint",
            "usage",
        )
        payload = {name: getattr(completion, name) for name in copied_fields}
        return cls(exa_result=exa_result, **payload)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: exa-py
3
- Version: 1.0.9
3
+ Version: 1.0.11
4
4
  Summary: Python SDK for Exa API.
5
5
  Home-page: https://github.com/exa-labs/exa-py
6
6
  Author: Exa
@@ -18,6 +18,9 @@ Classifier: Programming Language :: Python :: 3.12
18
18
  Description-Content-Type: text/markdown
19
19
  Requires-Dist: requests
20
20
  Requires-Dist: typing-extensions
21
+ Requires-Dist: openai
22
+ Provides-Extra: openai
23
+ Requires-Dist: openai ; extra == 'openai'
21
24
 
22
25
  # Exa
23
26
 
@@ -0,0 +1,8 @@
1
+ exa_py/__init__.py,sha256=aVF1zB_UV3dagJ5Vn2WrdcInzibdIW61M89sjwRCU_g,29
2
+ exa_py/api.py,sha256=l1gg_Sp2c9R9SpUpPI7G6M8mrnYULfsxZBv8lHVoxPU,31032
3
+ exa_py/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ exa_py/utils.py,sha256=pF91TKGGg4SbSVrdWroQze90-D9AlvqLoKWlHIeyLaE,2402
5
+ exa_py-1.0.11.dist-info/METADATA,sha256=jdepNNb0qBPgzPzHZCQ33Ocyfe0dIe4A3DhppxbwQQQ,3137
6
+ exa_py-1.0.11.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
7
+ exa_py-1.0.11.dist-info/top_level.txt,sha256=Mfkmscdw9HWR1PtVhU1gAiVo6DHu_tyiVdb89gfZBVI,7
8
+ exa_py-1.0.11.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- exa_py/__init__.py,sha256=aVF1zB_UV3dagJ5Vn2WrdcInzibdIW61M89sjwRCU_g,29
2
- exa_py/api.py,sha256=wV3JOtEDN6yggi4TS2XaAhcvHU85LIRine_mioZ049E,22932
3
- exa_py/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- exa_py-1.0.9.dist-info/METADATA,sha256=XqcyExdupc9LOfn595tOlb3Y7LLHdhY7yQ08j4UPnz0,3049
5
- exa_py-1.0.9.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
6
- exa_py-1.0.9.dist-info/top_level.txt,sha256=Mfkmscdw9HWR1PtVhU1gAiVo6DHu_tyiVdb89gfZBVI,7
7
- exa_py-1.0.9.dist-info/RECORD,,