not-again-ai 0.13.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. not_again_ai/data/__init__.py +7 -0
  2. not_again_ai/data/web.py +56 -0
  3. not_again_ai/llm/chat_completion/__init__.py +4 -0
  4. not_again_ai/llm/chat_completion/interface.py +32 -0
  5. not_again_ai/llm/chat_completion/providers/ollama_api.py +227 -0
  6. not_again_ai/llm/chat_completion/providers/openai_api.py +290 -0
  7. not_again_ai/llm/chat_completion/types.py +145 -0
  8. not_again_ai/llm/prompting/__init__.py +3 -0
  9. not_again_ai/llm/prompting/compile_messages.py +98 -0
  10. not_again_ai/llm/prompting/interface.py +46 -0
  11. not_again_ai/llm/prompting/providers/openai_tiktoken.py +122 -0
  12. not_again_ai/llm/prompting/types.py +43 -0
  13. {not_again_ai-0.13.0.dist-info → not_again_ai-0.15.0.dist-info}/METADATA +63 -58
  14. not_again_ai-0.15.0.dist-info/RECORD +32 -0
  15. {not_again_ai-0.13.0.dist-info → not_again_ai-0.15.0.dist-info}/WHEEL +1 -1
  16. not_again_ai/llm/gh_models/azure_ai_client.py +0 -20
  17. not_again_ai/llm/gh_models/chat_completion.py +0 -81
  18. not_again_ai/llm/openai_api/chat_completion.py +0 -200
  19. not_again_ai/llm/openai_api/context_management.py +0 -70
  20. not_again_ai/llm/openai_api/embeddings.py +0 -62
  21. not_again_ai/llm/openai_api/openai_client.py +0 -78
  22. not_again_ai/llm/openai_api/prompts.py +0 -191
  23. not_again_ai/llm/openai_api/tokens.py +0 -184
  24. not_again_ai/local_llm/__init__.py +0 -27
  25. not_again_ai/local_llm/chat_completion.py +0 -105
  26. not_again_ai/local_llm/huggingface/__init__.py +0 -0
  27. not_again_ai/local_llm/huggingface/chat_completion.py +0 -59
  28. not_again_ai/local_llm/huggingface/helpers.py +0 -23
  29. not_again_ai/local_llm/ollama/__init__.py +0 -0
  30. not_again_ai/local_llm/ollama/chat_completion.py +0 -111
  31. not_again_ai/local_llm/ollama/model_mapping.py +0 -17
  32. not_again_ai/local_llm/ollama/ollama_client.py +0 -24
  33. not_again_ai/local_llm/ollama/service.py +0 -81
  34. not_again_ai/local_llm/ollama/tokens.py +0 -104
  35. not_again_ai/local_llm/prompts.py +0 -38
  36. not_again_ai/local_llm/tokens.py +0 -90
  37. not_again_ai-0.13.0.dist-info/RECORD +0 -42
  38. not_again_ai-0.13.0.dist-info/entry_points.txt +0 -3
  39. /not_again_ai/llm/{gh_models → chat_completion/providers}/__init__.py +0 -0
  40. /not_again_ai/llm/{openai_api → prompting/providers}/__init__.py +0 -0
  41. {not_again_ai-0.13.0.dist-info → not_again_ai-0.15.0.dist-info}/LICENSE +0 -0
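
In short: 0.15.0 deletes the `gh_models`, `openai_api`, and `local_llm` packages and replaces them with provider-based `llm.chat_completion` and `llm.prompting` packages (with Ollama and OpenAI providers). As a rough orientation, the import paths shift along these lines; this is a hypothetical sketch inferred from the new file names above, since the contents of the new modules are not shown in this diff:

```
# Hypothetical migration sketch (inferred from the new module layout; not confirmed by this diff).

# 0.13.0 (removed in 0.15.0):
# from not_again_ai.llm.openai_api.chat_completion import chat_completion

# 0.15.0 (new provider-based layout):
# from not_again_ai.llm.chat_completion import chat_completion
```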

{not_again_ai-0.13.0.dist-info → not_again_ai-0.15.0.dist-info}/WHEEL
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: poetry-core 1.9.0
+ Generator: poetry-core 2.0.1
  Root-Is-Purelib: true
  Tag: py3-none-any
@@ -1,20 +0,0 @@
1
- import os
2
-
3
- from azure.ai.inference import ChatCompletionsClient
4
- from azure.core.credentials import AzureKeyCredential
5
-
6
-
7
- def azure_ai_client(
8
- token: str | None = None,
9
- endpoint: str = "https://models.inference.ai.azure.com",
10
- ) -> ChatCompletionsClient:
11
- if not token:
12
- token = os.getenv("GITHUB_TOKEN")
13
- if not token:
14
- raise ValueError("Token must be provided or GITHUB_TOKEN environment variable must be set")
15
-
16
- client = ChatCompletionsClient(
17
- endpoint=endpoint,
18
- credential=AzureKeyCredential(token),
19
- )
20
- return client
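
For reference, a minimal sketch of how this removed helper was called in 0.13.0, assuming `azure-ai-inference` is installed and a `GITHUB_TOKEN` environment variable is set:

```
# Sketch of 0.13.0 usage; this module no longer exists in 0.15.0.
from not_again_ai.llm.gh_models.azure_ai_client import azure_ai_client

# With no token argument, the helper falls back to the GITHUB_TOKEN environment
# variable and targets the GitHub Models endpoint shown in the default above.
client = azure_ai_client()
```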
@@ -1,81 +0,0 @@
1
- import contextlib
2
- import json
3
- import time
4
- from typing import Any
5
-
6
- from azure.ai.inference import ChatCompletionsClient
7
- from azure.ai.inference.models import ChatCompletionsToolDefinition, ChatRequestMessage
8
-
9
-
10
- def chat_completion(
11
- messages: list[ChatRequestMessage],
12
- model: str,
13
- client: ChatCompletionsClient,
14
- tools: list[ChatCompletionsToolDefinition] | None = None,
15
- max_tokens: int | None = None,
16
- temperature: float | None = None,
17
- json_mode: bool = False,
18
- seed: int | None = None,
19
- ) -> dict[str, Any]:
20
- """Gets a response from GitHub Models using the Azure AI Inference SDK.
21
- See the available models at https://github.com/marketplace/models
22
- Full documentation of the SDK is at: https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-chat-completions
23
- And samples at: https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/ai/azure-ai-inference/samples
24
-
25
- Returns:
26
- dict[str, Any]: A dictionary with the following keys
27
- message (str | dict): The content of the generated assistant message.
28
- If json_mode is True, this will be a dictionary.
29
- tool_names (list[str], optional): The names of the tools called by the model.
30
- If the model does not support tools, a ResponseError is raised.
31
- tool_args_list (list[dict], optional): The arguments of the tools called by the model.
32
- prompt_tokens (int): The number of tokens in the messages sent to the model.
33
- completion_tokens (int): The number of tokens used by the model to generate the completion.
34
- response_duration (float): The time, in seconds, taken to generate the response by using the model.
35
- system_fingerprint (str, optional): If seed is set, a unique identifier for the model used to generate the response.
36
- """
37
- response_format = {"type": "json_object"} if json_mode else None
38
- start_time = time.time()
39
- response = client.complete( # type: ignore
40
- messages=messages,
41
- model=model,
42
- response_format=response_format, # type: ignore
43
- max_tokens=max_tokens,
44
- temperature=temperature,
45
- tools=tools,
46
- seed=seed,
47
- )
48
- end_time = time.time()
49
- response_duration = end_time - start_time
50
-
51
- response_data = {}
52
- finish_reason = response.choices[0].finish_reason
53
- response_data["finish_reason"] = finish_reason.value # type: ignore
54
-
55
- message = response.choices[0].message.content
56
- if message and json_mode:
57
- with contextlib.suppress(json.JSONDecodeError):
58
- message = json.loads(message)
59
- response_data["message"] = message
60
-
61
- # Check for tool calls because even if the finish_reason is stop, the model may have called a tool
62
- tool_calls = response.choices[0].message.tool_calls
63
- if tool_calls:
64
- tool_names = []
65
- tool_args_list = []
66
- for tool_call in tool_calls:
67
- tool_names.append(tool_call.function.name)
68
- tool_args_list.append(json.loads(tool_call.function.arguments))
69
- response_data["tool_names"] = tool_names
70
- response_data["tool_args_list"] = tool_args_list
71
-
72
- if seed is not None and hasattr(response, "system_fingerprint"):
73
- response_data["system_fingerprint"] = response.system_fingerprint
74
-
75
- usage = response.usage
76
- if usage is not None:
77
- response_data["completion_tokens"] = usage.completion_tokens
78
- response_data["prompt_tokens"] = usage.prompt_tokens
79
- response_data["response_duration"] = round(response_duration, 4)
80
-
81
- return response_data
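
A hedged usage sketch of the removed function, based on the signature above; `SystemMessage` and `UserMessage` are the `azure-ai-inference` message types, and the model name is an illustrative placeholder from the GitHub Models catalog:

```
from azure.ai.inference.models import SystemMessage, UserMessage

from not_again_ai.llm.gh_models.azure_ai_client import azure_ai_client
from not_again_ai.llm.gh_models.chat_completion import chat_completion

client = azure_ai_client()
response = chat_completion(
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What is the capital of France?"),
    ],
    model="gpt-4o-mini",  # placeholder; see https://github.com/marketplace/models
    client=client,
    max_tokens=200,
)
print(response["message"], response["response_duration"])
```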
@@ -1,200 +0,0 @@
1
- import contextlib
2
- import json
3
- import time
4
- from typing import Any
5
-
6
- from openai import AzureOpenAI, OpenAI
7
-
8
-
9
- def chat_completion(
10
- messages: list[dict[str, Any]],
11
- model: str,
12
- client: OpenAI | AzureOpenAI | Any,
13
- tools: list[dict[str, Any]] | None = None,
14
- tool_choice: str = "auto",
15
- max_tokens: int | None = None,
16
- temperature: float = 0.7,
17
- json_mode: bool = False,
18
- json_schema: dict[str, Any] | None = None,
19
- seed: int | None = None,
20
- logprobs: tuple[bool, int | None] | None = None,
21
- n: int = 1,
22
- **kwargs: Any,
23
- ) -> dict[str, Any]:
24
- """Get an OpenAI chat completion response: https://platform.openai.com/docs/api-reference/chat/create
25
-
26
- NOTE: Depending on the model, certain parameters may not be supported,
27
- particularly for older vision-enabled models like gpt-4-1106-vision-preview.
28
- Be sure to check the documentation: https://platform.openai.com/docs/models/gpt-4-turbo-and-gpt-4
29
-
30
- Args:
31
- messages (list): A list of messages comprising the conversation so far.
32
- model (str): ID of the model to use. See the model endpoint compatibility table:
33
- https://platform.openai.com/docs/models/model-endpoint-compatibility
34
- for details on which models work with the Chat API.
35
- client (OpenAI): An instance of the OpenAI or AzureOpenAI client.
36
- If anything else is provided, we assume that it follows the OpenAI spec and call it by passing kwargs directly.
37
- For example you can provide something like:
38
- ```
39
- def custom_client(**kwargs):
40
- client = openai_client()
41
- completion = client.chat.completions.create(**kwargs)
42
- return completion.to_dict()
43
- ```
44
- tools (list[dict[str, Any]], optional):A list of tools the model may call.
45
- Use this to provide a list of functions the model may generate JSON inputs for. Defaults to None.
46
- tool_choice (str, optional): The tool choice to use. Can be "auto", "required", "none", or a specific function name.
47
- Note the function name cannot be any of "auto", "required", or "none". Defaults to "auto".
48
- max_tokens (int, optional): The maximum number of tokens to generate in the chat completion.
49
- Defaults to None, which automatically limits to the model's maximum context length.
50
- temperature (float, optional): What sampling temperature to use, between 0 and 2.
51
- Higher values like 0.8 will make the output more random,
52
- while lower values like 0.2 will make it more focused and deterministic. Defaults to 0.7.
53
- json_mode (bool, optional): When JSON mode is enabled, the model is constrained to only
54
- generate strings that parse into valid JSON object and will return a dictionary.
55
- See https://platform.openai.com/docs/guides/text-generation/json-mode
56
- json_schema (dict, optional): Enables Structured Outputs which ensures the model will
57
- always generate responses that adhere to your supplied JSON Schema.
58
- See https://platform.openai.com/docs/guides/structured-outputs/structured-outputs
59
- seed (int, optional): If specified, OpenAI will make a best effort to sample deterministically,
60
- such that repeated requests with the same `seed` and parameters should return the same result.
61
- Determinism is not guaranteed, and you should refer to the `system_fingerprint` response
62
- parameter to monitor changes in the backend.
63
- logprobs (tuple[bool, int], optional): Whether to return log probabilities of the output tokens or not.
64
- If `logprobs[0]` is true, returns the log probabilities of each output token returned in the content of message.
65
- `logprobs[1]` is an integer between 0 and 5 specifying the number of most likely tokens to return at each token position,
66
- each with an associated log probability. `logprobs[0]` must be set to true if this parameter is used.
67
- n (int, optional): How many chat completion choices to generate for each input message.
68
- Defaults to 1.
69
- **kwargs: Additional keyword arguments to pass to the OpenAI client chat completion.
70
-
71
- Returns:
72
- dict[str, Any]: A dictionary with the following keys:
73
- finish_reason (str): The reason the model stopped generating further tokens.
74
- Can be 'stop', 'length', or 'tool_calls'.
75
- tool_names (list[str], optional): The names of the tools called by the model.
76
- tool_args_list (list[dict], optional): The arguments of the tools called by the model.
77
- message (str | dict): The content of the generated assistant message.
78
- If json_mode is True, this will be a dictionary.
79
- logprobs (list[dict[str, Any] | list[dict[str, Any]]]): If logprobs[1] is between 1 and 5, each element in the list
80
- will be a list of dictionaries containing the token, logprob, and bytes for the top `logprobs[1]` logprobs. Otherwise,
81
- this will be a list of dictionaries containing the token, logprob, and bytes for each token in the message.
82
- choices (list[dict], optional): A list of chat completion choices if n > 1 where each dict contains the above fields.
83
- completion_tokens (int): The number of tokens used by the model to generate the completion.
84
- NOTE: If n > 1 this is the sum of all completions.
85
- prompt_tokens (int): The number of tokens in the messages sent to the model.
86
- system_fingerprint (str, optional): If seed is set, a unique identifier for the model used to generate the response.
87
- response_duration (float): The time, in seconds, taken to generate the response from the API.
88
- """
89
-
90
- if json_mode and json_schema is not None:
91
- raise ValueError("json_schema and json_mode cannot be used together.")
92
-
93
- if json_mode:
94
- response_format: dict[str, Any] = {"type": "json_object"}
95
- elif json_schema is not None:
96
- if isinstance(json_schema, dict):
97
- response_format = {"type": "json_schema", "json_schema": json_schema}
98
- else:
99
- response_format = {"type": "text"}
100
-
101
- kwargs.update(
102
- {
103
- "messages": messages,
104
- "model": model,
105
- "max_tokens": max_tokens,
106
- "temperature": temperature,
107
- "response_format": response_format,
108
- "n": n,
109
- }
110
- )
111
-
112
- if tools is not None:
113
- kwargs["tools"] = tools
114
- if tool_choice not in ["none", "auto", "required"]:
115
- kwargs["tool_choice"] = {"type": "function", "function": {"name": tool_choice}}
116
- else:
117
- kwargs["tool_choice"] = tool_choice
118
-
119
- if seed is not None:
120
- kwargs["seed"] = seed
121
-
122
- if logprobs is not None:
123
- kwargs["logprobs"] = logprobs[0]
124
- if logprobs[0] and logprobs[1] is not None:
125
- kwargs["top_logprobs"] = logprobs[1]
126
-
127
- start_time = time.time()
128
- if isinstance(client, OpenAI | AzureOpenAI):
129
- response = client.chat.completions.create(**kwargs)
130
- response = response.to_dict()
131
- else:
132
- response = client(**kwargs)
133
- end_time = time.time()
134
- response_duration = end_time - start_time
135
-
136
- response_data: dict[str, Any] = {"choices": []}
137
- for response_choice in response["choices"]:
138
- response_data_curr = {}
139
- finish_reason = response_choice["finish_reason"]
140
- response_data_curr["finish_reason"] = finish_reason
141
-
142
- # We first check for tool calls because even if the finish_reason is stop, the model may have called a tool
143
- tool_calls = response_choice["message"].get("tool_calls", None)
144
- if tool_calls:
145
- tool_names = []
146
- tool_args_list = []
147
- for tool_call in tool_calls:
148
- tool_names.append(tool_call["function"]["name"])
149
- tool_args_list.append(json.loads(tool_call["function"]["arguments"]))
150
- response_data_curr["message"] = response_choice["message"]["content"]
151
- response_data_curr["tool_names"] = tool_names
152
- response_data_curr["tool_args_list"] = tool_args_list
153
- elif finish_reason == "stop" or finish_reason == "length":
154
- message = response_choice["message"]["content"]
155
- if json_mode or json_schema is not None:
156
- with contextlib.suppress(json.JSONDecodeError):
157
- message = json.loads(message)
158
- response_data_curr["message"] = message
159
-
160
- if response_choice["logprobs"] and response_choice["logprobs"]["content"] is not None:
161
- logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = []
162
- for logprob in response_choice["logprobs"]["content"]:
163
- if logprob["top_logprobs"]:
164
- curr_logprob_infos = []
165
- for top_logprob in logprob["top_logprobs"]:
166
- curr_logprob_infos.append(
167
- {
168
- "token": top_logprob["token"],
169
- "logprob": top_logprob["logprob"],
170
- "bytes": top_logprob["bytes"],
171
- }
172
- )
173
- logprobs_list.append(curr_logprob_infos)
174
- else:
175
- logprobs_list.append(
176
- {
177
- "token": logprob["token"],
178
- "logprob": logprob["logprob"],
179
- "bytes": logprob["bytes"],
180
- }
181
- )
182
-
183
- response_data_curr["logprobs"] = logprobs_list
184
- response_data["choices"].append(response_data_curr)
185
-
186
- usage = response["usage"]
187
- if usage is not None:
188
- response_data["completion_tokens"] = usage["completion_tokens"]
189
- response_data["prompt_tokens"] = usage["prompt_tokens"]
190
-
191
- if seed is not None and response["system_fingerprint"] is not None:
192
- response_data["system_fingerprint"] = response["system_fingerprint"]
193
-
194
- response_data["response_duration"] = round(response_duration, 4)
195
-
196
- if len(response_data["choices"]) == 1:
197
- response_data.update(response_data["choices"][0])
198
- del response_data["choices"]
199
-
200
- return response_data
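
For comparison with the new interface, a minimal sketch of a 0.13.0-style call, assembled from the signature and docstring above (the model name is an illustrative placeholder):

```
from not_again_ai.llm.openai_api.chat_completion import chat_completion
from not_again_ai.llm.openai_api.openai_client import openai_client

client = openai_client()  # reads OPENAI_API_KEY from the environment
messages = [
    {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
    {"role": "user", "content": "Return a JSON object with a single 'greeting' key."},
]
response = chat_completion(messages, model="gpt-4o-mini", client=client, json_mode=True)
# With json_mode=True, "message" is parsed into a dict when the output is valid JSON.
print(response["message"], response["prompt_tokens"], response["completion_tokens"])
```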
@@ -1,70 +0,0 @@
1
- import copy
2
-
3
- from not_again_ai.llm.openai_api.tokens import load_tokenizer, num_tokens_from_messages, truncate_str
4
-
5
-
6
- def _inject_variable(
7
- messages_unformatted: list[dict[str, str]], variable_name: str, variable_text: str
8
- ) -> list[dict[str, str]]:
9
- """Injects variables into the messages using Python string formatting."""
10
- messages_formatted = copy.deepcopy(messages_unformatted)
11
- for message in messages_formatted:
12
- message["content"] = message["content"].replace("{{" + variable_name + "}}", variable_text)
13
- return messages_formatted
14
-
15
-
16
- def priority_truncation(
17
- messages_unformatted: list[dict[str, str]],
18
- variables: dict[str, str],
19
- priority: list[str],
20
- token_limit: int,
21
- model: str = "gpt-3.5-turbo-0125",
22
- ) -> list[dict[str, str]]:
23
- """Formats messages_unformatted and injects variables into the messages in the order of priority, truncating the messages to fit the token limit.
24
-
25
- Algorithm:
26
- 0. Checks if all variables in the priority list are in the variables dict. If not, adds the missing variables into priority in any order.
27
- 1. Iterating over priority:
28
- a. Count the current number of tokens in messages_formatted and compute how many tokens remain.
29
- b. Count the number of times the variable occurs in messages_formatted.
30
- c. Truncate the variable to fit the remaining token budget taking into account the number of times it occurs in the messages.
31
- d. Inject the variable text into messages_formatted.
32
-
33
- Args:
34
- messages_unformatted: A list of dictionaries where each dictionary
35
- represents a message. Each message must have 'role' and 'content'
36
- keys with string values, where content is a string with any number of occurrences of {{variable_name}}.
37
- variables: A dictionary where each key-value pair represents a variable name and its value to inject.
38
- priority: A list of variable names in their order of priority.
39
- token_limit: The maximum number of tokens allowed in the messages.
40
- model: The model to use for tokenization. Defaults to "gpt-3.5-turbo-0125".
41
- """
42
- tokenizer = load_tokenizer(model)
43
-
44
- # Check if all variables in the priority list are in the variables dict.
45
- # If not, add the missing variables into priority in any order.
46
- for var in variables:
47
- if var not in priority:
48
- priority.append(var)
49
-
50
- messages_formatted = copy.deepcopy(messages_unformatted)
51
- for var in priority:
52
- # Count the current number of tokens in messages_formatted and compute a remaining token budget.
53
- tokenizer = load_tokenizer(model)
54
- num_tokens = num_tokens_from_messages(messages_formatted, tokenizer=tokenizer, model=model)
55
- remaining_tokens = token_limit - num_tokens
56
- if remaining_tokens <= 0:
57
- break
58
-
59
- # Count the number of times the variable occurs in messages_formatted (including within the same message).
60
- num_var_occurrences = 0
61
- for message in messages_formatted:
62
- num_var_occurrences += message["content"].count("{{" + var + "}}")
63
-
64
- # Truncate the variable to fit the remaining token budget taking into account the number of times it occurs in the messages.
65
- truncated_var = truncate_str(variables[var], remaining_tokens // num_var_occurrences, tokenizer=tokenizer)
66
-
67
- # Inject the variable text into messages_formatted.
68
- messages_formatted = _inject_variable(messages_formatted, var, truncated_var)
69
-
70
- return messages_formatted
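
A short sketch of how the removed `priority_truncation` was invoked, following the docstring above; the variable values stand in for much longer strings in practice:

```
from not_again_ai.llm.openai_api.context_management import priority_truncation

messages = [
    {"role": "system", "content": "Summarize this document: {{document}}"},
    {"role": "user", "content": "Pay special attention to: {{question}}"},
]
truncated_messages = priority_truncation(
    messages_unformatted=messages,
    variables={"document": "<long document text>", "question": "<user question>"},
    priority=["question", "document"],  # inject "question" first; "document" gets the leftover budget
    token_limit=4096,
)
```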
@@ -1,62 +0,0 @@
1
- from typing import Any
2
-
3
- from openai import OpenAI
4
-
5
-
6
- def embed_text(
7
- text: str | list[str],
8
- client: OpenAI,
9
- model: str = "text-embedding-3-large",
10
- dimensions: int | None = None,
11
- encoding_format: str = "float",
12
- **kwargs: Any,
13
- ) -> list[float] | str | list[list[float]] | list[str]:
14
- """Generates an embedding vector for a given text using OpenAI's API.
15
-
16
- Args:
17
- text (str | list[str]): The input text to be embedded. Each text should not exceed 8191 tokens, which is the max for V2 and V3 models
18
- client (OpenAI): The OpenAI client used to interact with the API.
19
- model (str, optional): The ID of the model to use for embedding.
20
- Defaults to "text-embedding-3-large".
21
- Choose from text-embedding-3-small, text-embedding-3-large, text-embedding-ada-002.
22
- See https://platform.openai.com/docs/models/embeddings for more details.
23
- dimensions (int | None, optional): The number of dimensions for the output embeddings.
24
- This is only supported in "text-embedding-3" and later models. Defaults to None.
25
- encoding_format (str, optional): The format for the returned embeddings. Can be either "float" or "base64".
26
- Defaults to "float".
27
-
28
- Returns:
29
- list[float] | str | list[list[float]] | list[str]: The embedding vector represented as a list of floats or base64 encoded string.
30
- If multiple text inputs are provided, a list of embedding vectors is returned.
31
- The length and format of the vector depend on the model, encoding_format, and dimensions.
32
-
33
- Raises:
34
- ValueError: If 'text-embedding-ada-002' model is used and dimensions are specified,
35
- as this model does not support specifying dimensions.
36
-
37
- Example:
38
- client = OpenAI()
39
- embedding = embed_text("Example text", client, model="text-embedding-ada-002")
40
- """
41
- if model == "text-embedding-ada-002" and dimensions:
42
- # text-embedding-ada-002 does not support dimensions
43
- raise ValueError("text-embedding-ada-002 does not support dimensions")
44
-
45
- kwargs = {
46
- "model": model,
47
- "input": text,
48
- "encoding_format": encoding_format,
49
- }
50
- if dimensions:
51
- kwargs["dimensions"] = dimensions
52
-
53
- response = client.embeddings.create(**kwargs)
54
-
55
- responses = []
56
- for embedding in response.data:
57
- responses.append(embedding.embedding)
58
-
59
- if len(responses) == 1:
60
- return responses[0]
61
-
62
- return responses
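
A brief sketch of the removed helper with a batched input, per the docstring above (note that `dimensions` applies only to the text-embedding-3 models):

```
from openai import OpenAI

from not_again_ai.llm.openai_api.embeddings import embed_text

client = OpenAI()  # reads OPENAI_API_KEY from the environment
# A single string returns one vector; a list of strings returns a list of vectors.
vectors = embed_text(
    ["first passage", "second passage"],
    client,
    model="text-embedding-3-large",
    dimensions=256,
)
```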
@@ -1,78 +0,0 @@
1
- from typing import Literal
2
-
3
- from azure.identity import DefaultAzureCredential, get_bearer_token_provider
4
- from openai import AzureOpenAI, OpenAI
5
-
6
-
7
- class InvalidOAIAPITypeError(Exception):
8
- """Raised when an invalid OAIAPIType string is provided."""
9
-
10
- pass
11
-
12
-
13
- def openai_client(
14
- api_type: Literal["openai", "azure_openai"] = "openai",
15
- api_key: str | None = None,
16
- organization: str | None = None,
17
- aoai_api_version: str = "2024-06-01",
18
- azure_endpoint: str | None = None,
19
- timeout: float | None = None,
20
- max_retries: int | None = None,
21
- ) -> OpenAI | AzureOpenAI:
22
- """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters.
23
-
24
- Azure OpenAI requires RBAC authentication. You must be signed in with the Azure CLI and have correct role assigned.
25
- See https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521
26
-
27
- Args:
28
- api_type (str, optional): Type of the API to be used. Accepted values are 'openai' or 'azure_openai'.
29
- Defaults to 'openai'.
30
- api_key (str, optional): The API key to authenticate the client. If not provided,
31
- OpenAI automatically uses `OPENAI_API_KEY` from the environment.
32
- organization (str, optional): The ID of the organization. If not provided,
33
- OpenAI automotically uses `OPENAI_ORG_ID` from the environment.
34
- aoai_api_version (str, optional): Only applicable if using Azure OpenAI https://learn.microsoft.com/azure/ai-services/openai/reference#rest-api-versioning
35
- azure_endpoint (str, optional): The endpoint to use for Azure OpenAI.
36
- If not provided, will be read from the `AZURE_OPENAI_ENDPOINT` environment variable.
37
- timeout (float, optional): By default requests time out after 10 minutes.
38
- max_retries (int, optional): Certain errors are automatically retried 2 times by default,
39
- with a short exponential backoff. Connection errors (for example, due to a network connectivity problem),
40
- 408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default.
41
-
42
- Returns:
43
- OpenAI: An instance of the OpenAI client.
44
-
45
- Raises:
46
- InvalidOAIAPITypeError: If an invalid API type string is provided.
47
- NotImplementedError: If the specified API type is recognized but not yet supported (e.g., 'azure_openai').
48
-
49
- Examples:
50
- >>> client = openai_client(api_type="openai", api_key="YOUR_API_KEY")
51
- """
52
- if api_type not in ["openai", "azure_openai"]:
53
- raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.")
54
-
55
- if api_type == "openai":
56
- args = {
57
- "api_key": api_key,
58
- "organization": organization,
59
- "timeout": timeout,
60
- "max_retries": max_retries,
61
- }
62
- # Remove any None values in order to use the default values.
63
- filtered_args = {k: v for k, v in args.items() if v is not None}
64
- return OpenAI(**filtered_args) # type: ignore
65
- elif api_type == "azure_openai":
66
- azure_credential = DefaultAzureCredential()
67
- ad_token_provider = get_bearer_token_provider(azure_credential, "https://cognitiveservices.azure.com/.default")
68
- args = {
69
- "api_version": aoai_api_version,
70
- "azure_endpoint": azure_endpoint,
71
- "azure_ad_token_provider": ad_token_provider, # type: ignore
72
- "timeout": timeout,
73
- "max_retries": max_retries,
74
- }
75
- filtered_args = {k: v for k, v in args.items() if v is not None}
76
- return AzureOpenAI(**filtered_args) # type: ignore
77
- else:
78
- raise NotImplementedError(f"API type '{api_type}' is invalid.")
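
And a sketch of the Azure branch of the removed factory, assuming an Azure CLI login with the role assignment described in the docstring; the endpoint URL is an illustrative placeholder:

```
from not_again_ai.llm.openai_api.openai_client import openai_client

# Keyless auth: DefaultAzureCredential supplies a bearer token, so no api_key is passed.
client = openai_client(
    api_type="azure_openai",
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder endpoint
)
```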