not-again-ai 0.13.0__py3-none-any.whl → 0.15.0__py3-none-any.whl

Files changed (41)
  1. not_again_ai/data/__init__.py +7 -0
  2. not_again_ai/data/web.py +56 -0
  3. not_again_ai/llm/chat_completion/__init__.py +4 -0
  4. not_again_ai/llm/chat_completion/interface.py +32 -0
  5. not_again_ai/llm/chat_completion/providers/ollama_api.py +227 -0
  6. not_again_ai/llm/chat_completion/providers/openai_api.py +290 -0
  7. not_again_ai/llm/chat_completion/types.py +145 -0
  8. not_again_ai/llm/prompting/__init__.py +3 -0
  9. not_again_ai/llm/prompting/compile_messages.py +98 -0
  10. not_again_ai/llm/prompting/interface.py +46 -0
  11. not_again_ai/llm/prompting/providers/openai_tiktoken.py +122 -0
  12. not_again_ai/llm/prompting/types.py +43 -0
  13. {not_again_ai-0.13.0.dist-info → not_again_ai-0.15.0.dist-info}/METADATA +63 -58
  14. not_again_ai-0.15.0.dist-info/RECORD +32 -0
  15. {not_again_ai-0.13.0.dist-info → not_again_ai-0.15.0.dist-info}/WHEEL +1 -1
  16. not_again_ai/llm/gh_models/azure_ai_client.py +0 -20
  17. not_again_ai/llm/gh_models/chat_completion.py +0 -81
  18. not_again_ai/llm/openai_api/chat_completion.py +0 -200
  19. not_again_ai/llm/openai_api/context_management.py +0 -70
  20. not_again_ai/llm/openai_api/embeddings.py +0 -62
  21. not_again_ai/llm/openai_api/openai_client.py +0 -78
  22. not_again_ai/llm/openai_api/prompts.py +0 -191
  23. not_again_ai/llm/openai_api/tokens.py +0 -184
  24. not_again_ai/local_llm/__init__.py +0 -27
  25. not_again_ai/local_llm/chat_completion.py +0 -105
  26. not_again_ai/local_llm/huggingface/__init__.py +0 -0
  27. not_again_ai/local_llm/huggingface/chat_completion.py +0 -59
  28. not_again_ai/local_llm/huggingface/helpers.py +0 -23
  29. not_again_ai/local_llm/ollama/__init__.py +0 -0
  30. not_again_ai/local_llm/ollama/chat_completion.py +0 -111
  31. not_again_ai/local_llm/ollama/model_mapping.py +0 -17
  32. not_again_ai/local_llm/ollama/ollama_client.py +0 -24
  33. not_again_ai/local_llm/ollama/service.py +0 -81
  34. not_again_ai/local_llm/ollama/tokens.py +0 -104
  35. not_again_ai/local_llm/prompts.py +0 -38
  36. not_again_ai/local_llm/tokens.py +0 -90
  37. not_again_ai-0.13.0.dist-info/RECORD +0 -42
  38. not_again_ai-0.13.0.dist-info/entry_points.txt +0 -3
  39. /not_again_ai/llm/{gh_models → chat_completion/providers}/__init__.py +0 -0
  40. /not_again_ai/llm/{openai_api → prompting/providers}/__init__.py +0 -0
  41. {not_again_ai-0.13.0.dist-info → not_again_ai-0.15.0.dist-info}/LICENSE +0 -0
not_again_ai/data/__init__.py
@@ -0,0 +1,7 @@
+import importlib.util
+
+if importlib.util.find_spec("playwright") is None:
+    raise ImportError(
+        "not_again_ai.data requires the 'data' extra to be installed. "
+        "You can install it using 'pip install not_again_ai[data]'."
+    )
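The new not_again_ai/data/__init__.py is an import guard: without the 'data' extra (which brings in Playwright), importing the subpackage fails fast with install instructions. A minimal illustration of the behavior, assuming Playwright is not installed:

try:
    import not_again_ai.data  # triggers the find_spec guard above
except ImportError as e:
    print(e)  # points at: pip install not_again_ai[data]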
not_again_ai/data/web.py
@@ -0,0 +1,56 @@
+from loguru import logger
+from playwright.sync_api import Browser, Playwright, sync_playwright
+
+
+def create_browser(headless: bool = True) -> tuple[Playwright, Browser]:
+    """Creates and returns a new Playwright instance and browser.
+
+    Args:
+        headless (bool, optional): Whether to run the browser in headless mode. Defaults to True.
+
+    Returns:
+        tuple[Playwright, Browser]: A tuple containing the Playwright instance and browser.
+    """
+    pwright = sync_playwright().start()
+    browser = pwright.chromium.launch(
+        headless=headless,
+        chromium_sandbox=False,
+        timeout=15000,
+    )
+    return pwright, browser
+
+
+def get_raw_web_content(url: str, browser: Browser | None = None, headless: bool = True) -> str:
+    """Fetches raw web content from a given URL using Playwright.
+
+    Args:
+        url (str): The URL to fetch content from.
+        browser (Browser | None, optional): An existing browser instance to use. Defaults to None.
+        headless (bool, optional): Whether to run the browser in headless mode. Defaults to True.
+
+    Returns:
+        str: The raw web content.
+    """
+    p = None
+    try:
+        if browser is None:
+            p, browser = create_browser(headless)
+
+        page = browser.new_page(
+            accept_downloads=False,
+            java_script_enabled=True,
+            viewport={"width": 1366, "height": 768},
+            user_agent="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.3",
+        )
+        page.goto(url)
+        content = page.content()
+        page.close()
+        return content
+    except Exception as e:
+        logger.error(f"Failed to get web content: {e}")
+        return ""
+    finally:
+        if browser:
+            browser.close()
+        if p:
+            p.stop()
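A short usage sketch for the new web helper (hedged: the URL is a placeholder, and note that the finally block above closes whatever browser was used, including one passed in via the `browser` argument, so callers should not expect to reuse it):

from not_again_ai.data.web import get_raw_web_content

# One-shot fetch: a Playwright instance and Chromium browser are
# created internally and torn down after the page is read.
html = get_raw_web_content("https://example.com")
print(html[:200])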
not_again_ai/llm/chat_completion/__init__.py
@@ -0,0 +1,4 @@
+from not_again_ai.llm.chat_completion.interface import chat_completion
+from not_again_ai.llm.chat_completion.types import ChatCompletionRequest
+
+__all__ = ["ChatCompletionRequest", "chat_completion"]
not_again_ai/llm/chat_completion/interface.py
@@ -0,0 +1,32 @@
+from collections.abc import Callable
+from typing import Any
+
+from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion
+from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion
+from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, ChatCompletionResponse
+
+
+def chat_completion(
+    request: ChatCompletionRequest,
+    provider: str,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    """Get a chat completion response from the given provider. Currently supported providers:
+    - `openai` - OpenAI
+    - `azure_openai` - Azure OpenAI
+    - `ollama` - Ollama
+
+    Args:
+        request: Request parameter object
+        provider: The supported provider name
+        client: Client information, see the provider's implementation for what can be provided
+
+    Returns:
+        ChatCompletionResponse: The chat completion response.
+    """
+    if provider == "openai" or provider == "azure_openai":
+        return openai_chat_completion(request, client)
+    elif provider == "ollama":
+        return ollama_chat_completion(request, client)
+    else:
+        raise ValueError(f"Provider {provider} not supported")
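To see how the new interface is meant to be driven, a hedged sketch (types.py is not shown in full in this diff, so the exact ChatCompletionRequest fields are assumed from their use in the providers; the model name and message shape are illustrative):

from not_again_ai.llm.chat_completion import ChatCompletionRequest, chat_completion
from not_again_ai.llm.chat_completion.providers.openai_api import openai_client

client = openai_client(api_type="openai")  # picks up OPENAI_API_KEY from the environment
request = ChatCompletionRequest(
    model="gpt-4o-mini",  # illustrative model name, not taken from this diff
    messages=[{"role": "user", "content": "Say hello."}],  # assumed message shape
)
response = chat_completion(request, provider="openai", client=client)
print(response.choices[0].message.content)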
not_again_ai/llm/chat_completion/providers/ollama_api.py
@@ -0,0 +1,227 @@
+from collections.abc import Callable
+import json
+import os
+import re
+import time
+from typing import Any, Literal, cast
+
+from loguru import logger
+from ollama import ChatResponse, Client, ResponseError
+
+from not_again_ai.llm.chat_completion.types import (
+    AssistantMessage,
+    ChatCompletionChoice,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    Function,
+    ToolCall,
+)
+
+OLLAMA_PARAMETER_MAP = {
+    "frequency_penalty": "repeat_penalty",
+    "max_completion_tokens": "num_predict",
+    "context_window": "num_ctx",
+    "n": None,
+    "tool_choice": None,
+    "reasoning_effort": None,
+    "parallel_tool_calls": None,
+    "logit_bias": None,
+    "top_logprobs": None,
+    "presence_penalty": None,
+}
+
+
+def validate(request: ChatCompletionRequest) -> None:
+    if request.json_mode and request.structured_outputs is not None:
+        raise ValueError("json_schema and json_mode cannot be used together.")
+
+    # Warn about any request parameter that maps to None in OLLAMA_PARAMETER_MAP (unsupported by Ollama)
+    for key, value in OLLAMA_PARAMETER_MAP.items():
+        if value is None and getattr(request, key) is not None:
+            logger.warning(f"Parameter {key} is not supported by Ollama and will be ignored.")
+
+    # If "stop" is set, it must be a string, not a list
+    if isinstance(request.stop, list):
+        logger.warning("Parameter 'stop' needs to be a string and not a list. It will be ignored.")
+        request.stop = None
+
+
+def ollama_chat_completion(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    validate(request)
+
+    kwargs = request.model_dump(mode="json", exclude_none=True)
+
+    # For each key in OLLAMA_PARAMETER_MAP:
+    # if the mapped value is not None, rename the key in kwargs to that value;
+    # if it is None, remove the key from kwargs.
+    for key, value in OLLAMA_PARAMETER_MAP.items():
+        if value is not None and key in kwargs:
+            kwargs[value] = kwargs.pop(key)
+        elif value is None and key in kwargs:
+            del kwargs[key]
+
+    # If json_mode is True, set the format to json
+    json_mode = kwargs.get("json_mode", None)
+    if json_mode:
+        kwargs["format"] = "json"
+        kwargs.pop("json_mode")
+    elif json_mode is not None and not json_mode:
+        kwargs.pop("json_mode")
+
+    # If structured_outputs is not None, set the format to structured_outputs
+    if kwargs.get("structured_outputs", None):
+        # If the schema is in the OpenAI format, pull out the inner schema
+        if "schema" in kwargs["structured_outputs"]:
+            kwargs["format"] = kwargs["structured_outputs"]["schema"]
+            kwargs.pop("structured_outputs")
+        else:
+            kwargs["format"] = kwargs.pop("structured_outputs")
+
+    option_fields = [
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "repeat_last_n",
+        "repeat_penalty",
+        "temperature",
+        "seed",
+        "stop",
+        "tfs_z",
+        "num_predict",
+        "top_k",
+        "top_p",
+        "min_p",
+    ]
+    # Move each field in option_fields from kwargs into an options dictionary
+    options = {}
+    for field in option_fields:
+        if field in kwargs:
+            options[field] = kwargs.pop(field)
+    kwargs["options"] = options
+
+    for message in kwargs["messages"]:
+        role = message.get("role", None)
+        # For each ToolMessage, remove the name field
+        if role is not None and role == "tool":
+            message.pop("name")
+
+        # For each AssistantMessage with tool calls, remove the id field
+        if role is not None and role == "assistant" and message.get("tool_calls", None):
+            for tool_call in message["tool_calls"]:
+                tool_call.pop("id")
+
+        # Content and images need to be separated
+        images = []
+        content = ""
+        if isinstance(message["content"], list):
+            for item in message["content"]:
+                if item["type"] == "image_url":
+                    image_url = item["image_url"]["url"]
+                    # Remove the data URL prefix if present
+                    if image_url.startswith("data:"):
+                        image_url = image_url.split("base64,", 1)[1]
+                    images.append(image_url)
+                else:
+                    content += item["text"]
+        else:
+            content = message["content"]
+
+        message["content"] = content
+        if len(images) > 1:
+            images = images[:1]
+            logger.warning("Ollama models only support a single image per message. Using only the first image.")
+        message["images"] = images
+
+    try:
+        start_time = time.time()
+        response: ChatResponse = client(**kwargs)
+        end_time = time.time()
+        response_duration = round(end_time - start_time, 4)
+    except ResponseError as e:
+        # If the error matches "model '<model>' not found", raise a more specific error
+        expected_pattern = f"model '{request.model}' not found"
+        if re.search(expected_pattern, e.error):
+            raise ResponseError(f"Model '{request.model}' not found.") from e
+        else:
+            raise ResponseError(e.error) from e
+
+    errors = ""
+
+    # Handle tool calls
+    tool_calls: list[ToolCall] | None = None
+    if response.message.tool_calls:
+        parsed_tool_calls: list[ToolCall] = []
+        for tool_call in response.message.tool_calls:
+            tool_name = tool_call.function.name
+            if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
+                errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n"
+            tool_args = tool_call.function.arguments
+            parsed_tool_calls.append(
+                ToolCall(
+                    id="",
+                    function=Function(
+                        name=tool_name,
+                        arguments=tool_args,
+                    ),
+                )
+            )
+        tool_calls = parsed_tool_calls
+
+    json_message = None
+    if (request.json_mode or (request.structured_outputs is not None)) and response.message.content:
+        try:
+            json_message = json.loads(response.message.content)
+        except json.JSONDecodeError:
+            errors += "Message failed to parse into JSON\n"
+
+    finish_reason = cast(
+        Literal["stop", "length", "tool_calls", "content_filter"],
+        "stop" if response.done_reason is None else response.done_reason or "stop",
+    )
+
+    choice = ChatCompletionChoice(
+        message=AssistantMessage(
+            content=response.message.content or "",
+            tool_calls=tool_calls,
+        ),
+        finish_reason=finish_reason,
+        json_message=json_message,
+    )
+
+    return ChatCompletionResponse(
+        choices=[choice],
+        errors=errors.strip(),
+        completion_tokens=response.get("eval_count", -1),
+        prompt_tokens=response.get("prompt_eval_count", -1),
+        response_duration=response_duration,
+    )
+
+
+def ollama_client(host: str | None = None, timeout: float | None = None) -> Callable[..., Any]:
+    """Create an Ollama client callable based on the specified host, falling back to the OLLAMA_HOST environment variable.
+
+    Args:
+        host (str, optional): The host URL of the Ollama server.
+        timeout (float, optional): The timeout for requests.
+
+    Returns:
+        Callable[..., Any]: A callable that creates an Ollama client and returns chat results.
+
+    Examples:
+        >>> client = ollama_client(host="http://localhost:11434")
+    """
+    if host is None:
+        host = os.getenv("OLLAMA_HOST")
+        if host is None:
+            logger.warning("OLLAMA_HOST environment variable not set, using default host: http://localhost:11434")
+            host = "http://localhost:11434"
+
+    def client_callable(**kwargs: Any) -> Any:
+        client = Client(host=host, timeout=timeout)
+        return client.chat(**kwargs)
+
+    return client_callable
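The same request object can be routed through Ollama; a hedged sketch (the model tag is illustrative, and `context_window` is assumed to be a ChatCompletionRequest field since OLLAMA_PARAMETER_MAP renames it to `num_ctx` above):

from not_again_ai.llm.chat_completion import ChatCompletionRequest, chat_completion
from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client

client = ollama_client()  # falls back to OLLAMA_HOST, then http://localhost:11434
request = ChatCompletionRequest(
    model="llama3.2",  # illustrative model tag, not taken from this diff
    messages=[{"role": "user", "content": "Say hello."}],  # assumed message shape
    context_window=4096,  # renamed to num_ctx and nested under options by this provider
)
response = chat_completion(request, provider="ollama", client=client)
print(response.choices[0].message.content)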
not_again_ai/llm/chat_completion/providers/openai_api.py
@@ -0,0 +1,290 @@
+from collections.abc import Callable
+import json
+import time
+from typing import Any, Literal
+
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+from openai import AzureOpenAI, OpenAI
+
+from not_again_ai.llm.chat_completion.types import (
+    AssistantMessage,
+    ChatCompletionChoice,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    Function,
+    ToolCall,
+)
+
+OPENAI_PARAMETER_MAP = {
+    "context_window": None,
+    "mirostat": None,
+    "mirostat_eta": None,
+    "mirostat_tau": None,
+    "repeat_last_n": None,
+    "tfs_z": None,
+    "top_k": None,
+    "min_p": None,
+}
+
+
+def validate(request: ChatCompletionRequest) -> None:
+    if request.json_mode and request.structured_outputs is not None:
+        raise ValueError("json_schema and json_mode cannot be used together.")
+
+
+def openai_chat_completion(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    validate(request)
+
+    # Format the response format parameters to be compatible with the OpenAI API
+    if request.json_mode:
+        response_format: dict[str, Any] = {"type": "json_object"}
+    elif request.structured_outputs is not None:
+        response_format = {"type": "json_schema", "json_schema": request.structured_outputs}
+    else:
+        response_format = {"type": "text"}
+
+    kwargs = request.model_dump(mode="json", exclude_none=True)
+
+    # For each key in OPENAI_PARAMETER_MAP:
+    # if the mapped value is not None, rename the key in kwargs to that value;
+    # if it is None, remove the key from kwargs.
+    for key, value in OPENAI_PARAMETER_MAP.items():
+        if value is not None and key in kwargs:
+            kwargs[value] = kwargs.pop(key)
+        elif value is None and key in kwargs:
+            del kwargs[key]
+
+    # Iterate over each message and adjust its fields for the OpenAI API
+    for message in kwargs["messages"]:
+        role = message.get("role", None)
+        # For each ToolMessage, change the "name" field to be named "tool_call_id" instead
+        if role is not None and role == "tool":
+            message["tool_call_id"] = message.pop("name")
+
+        # For each AssistantMessage with tool calls, make the function arguments a string
+        if role is not None and role == "assistant" and message.get("tool_calls", None):
+            for tool_call in message["tool_calls"]:
+                tool_call["function"]["arguments"] = str(tool_call["function"]["arguments"])
+
+    # Delete json_mode and structured_outputs from kwargs
+    kwargs.pop("json_mode", None)
+    kwargs.pop("structured_outputs", None)
+
+    # Add the response_format to kwargs
+    kwargs["response_format"] = response_format
+
+    # Handle tool_choice when the provided tool_choice is the name of the required tool
+    if request.tool_choice is not None and request.tool_choice not in ["none", "auto", "required"]:
+        kwargs["tool_choice"] = {"type": "function", "function": {"name": request.tool_choice}}
+
+    start_time = time.time()
+    response = client(**kwargs)
+    end_time = time.time()
+    response_duration = round(end_time - start_time, 4)
+
+    errors = ""
+    extras: dict[str, Any] = {}
+    choices: list[ChatCompletionChoice] = []
+    for index, choice in enumerate(response["choices"]):
+        choice_extras: dict[str, Any] = {}
+        finish_reason = choice["finish_reason"]
+
+        message = choice["message"]
+        tool_calls: list[ToolCall] | None = None
+        if message.get("tool_calls", None):
+            parsed_tool_calls: list[ToolCall] = []
+            for tool_call in message["tool_calls"]:
+                tool_name = tool_call.get("function", {}).get("name", None)
+                # Check if the tool name is valid (one of the tool names in the request)
+                if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
+                    errors += f"Choice {index}: Tool call {tool_call} has an invalid tool name: {tool_name}\n"
+
+                tool_args = tool_call.get("function", {}).get("arguments", None)
+                try:
+                    tool_args = json.loads(tool_args)
+                except json.JSONDecodeError:
+                    errors += f"Choice {index}: Tool call {tool_call} failed to parse arguments into JSON\n"
+
+                parsed_tool_calls.append(
+                    ToolCall(
+                        id=tool_call["id"],
+                        function=Function(
+                            name=tool_name,
+                            arguments=tool_args,
+                        ),
+                    )
+                )
+            tool_calls = parsed_tool_calls
+
+        json_message = None
+        if request.json_mode or (request.structured_outputs is not None):
+            try:
+                json_message = json.loads(message.get("content", "{}"))
+            except json.JSONDecodeError:
+                errors += f"Choice {index}: Message failed to parse into JSON\n"
+
+        # Handle logprobs
+        logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = None
+        if choice.get("logprobs", None) and choice["logprobs"].get("content", None) is not None:
+            logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = []
+            for logprob in choice["logprobs"]["content"]:
+                if logprob.get("top_logprobs", None):
+                    curr_logprob_infos: list[dict[str, Any]] = []
+                    for top_logprob in logprob.get("top_logprobs", []):
+                        curr_logprob_infos.append(
+                            {
+                                "token": top_logprob.get("token", ""),
+                                "logprob": top_logprob.get("logprob", 0),
+                                "bytes": top_logprob.get("bytes", 0),
+                            }
+                        )
+                    logprobs_list.append(curr_logprob_infos)
+                else:
+                    logprobs_list.append(
+                        {
+                            "token": logprob.get("token", ""),
+                            "logprob": logprob.get("logprob", 0),
+                            "bytes": logprob.get("bytes", 0),
+                        }
+                    )
+            logprobs = logprobs_list
+
+        # Handle extras that OpenAI or Azure OpenAI return
+        if choice.get("content_filter_results", None):
+            choice_extras["content_filter_results"] = choice["content_filter_results"]
+
+        choices.append(
+            ChatCompletionChoice(
+                message=AssistantMessage(
+                    content=message.get("content") or "",
+                    refusal=message.get("refusal", None),
+                    tool_calls=tool_calls,
+                ),
+                finish_reason=finish_reason,
+                json_message=json_message,
+                logprobs=logprobs,
+                extras=choice_extras,
+            )
+        )
+
+    completion_tokens = response["usage"].get("completion_tokens", -1)
+    prompt_tokens = response["usage"].get("prompt_tokens", -1)
+    completion_detailed_tokens = response["usage"].get("completion_detailed_tokens", None)
+    prompt_detailed_tokens = response["usage"].get("prompt_detailed_tokens", None)
+    system_fingerprint = response.get("system_fingerprint", None)
+
+    extras["prompt_filter_results"] = response.get("prompt_filter_results", None)
+
+    return ChatCompletionResponse(
+        choices=choices,
+        errors=errors.strip(),
+        extras=extras,
+        completion_tokens=completion_tokens,
+        prompt_tokens=prompt_tokens,
+        completion_detailed_tokens=completion_detailed_tokens,
+        prompt_detailed_tokens=prompt_detailed_tokens,
+        system_fingerprint=system_fingerprint,
+        response_duration=response_duration,
+    )
+
+
+def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_args: Any) -> Callable[..., Any]:
+    """Creates a callable that instantiates and uses an OpenAI client.
+
+    Args:
+        client_class: The OpenAI client class to instantiate (OpenAI or AzureOpenAI)
+        **client_args: Arguments to pass to the client constructor
+
+    Returns:
+        A callable that creates a client and returns completion results
+    """
+    filtered_args = {k: v for k, v in client_args.items() if v is not None}
+
+    def client_callable(**kwargs: Any) -> Any:
+        client = client_class(**filtered_args)
+        completion = client.chat.completions.create(**kwargs)
+        return completion.to_dict()
+
+    return client_callable
+
+
+class InvalidOAIAPITypeError(Exception):
+    """Raised when an invalid OAIAPIType string is provided."""
+
+
+def openai_client(
+    api_type: Literal["openai", "azure_openai"] = "openai",
+    api_key: str | None = None,
+    organization: str | None = None,
+    aoai_api_version: str = "2024-06-01",
+    azure_endpoint: str | None = None,
+    timeout: float | None = None,
+    max_retries: int | None = None,
+) -> Callable[..., Any]:
+    """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters.
+
+    It is preferred to use RBAC authentication for Azure OpenAI. You must be signed in with the Azure CLI and have the correct role assigned.
+    See https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521
+
+    Args:
+        api_type (str, optional): Type of the API to be used. Accepted values are 'openai' or 'azure_openai'.
+            Defaults to 'openai'.
+        api_key (str, optional): The API key to authenticate the client. If not provided,
+            OpenAI automatically uses `OPENAI_API_KEY` from the environment.
+            If provided for Azure OpenAI, it will be used for authentication instead of the Azure AD token provider.
+        organization (str, optional): The ID of the organization. If not provided,
+            OpenAI automatically uses `OPENAI_ORG_ID` from the environment.
+        aoai_api_version (str, optional): Only applicable if using Azure OpenAI https://learn.microsoft.com/azure/ai-services/openai/reference#rest-api-versioning
+        azure_endpoint (str, optional): The endpoint to use for Azure OpenAI.
+        timeout (float, optional): By default requests time out after 10 minutes.
+        max_retries (int, optional): Certain errors are automatically retried 2 times by default,
+            with a short exponential backoff. Connection errors (for example, due to a network connectivity problem),
+            408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default.
+
+    Returns:
+        Callable[..., Any]: A callable that creates a client and returns completion results
+
+
+    Raises:
+        InvalidOAIAPITypeError: If an invalid API type string is provided.
+        NotImplementedError: If the specified API type is recognized but not yet supported.
+    """
+    if api_type not in ["openai", "azure_openai"]:
+        raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.")
+
+    if api_type == "openai":
+        return create_client_callable(
+            OpenAI,
+            api_key=api_key,
+            organization=organization,
+            timeout=timeout,
+            max_retries=max_retries,
+        )
+    elif api_type == "azure_openai":
+        if api_key:
+            return create_client_callable(
+                AzureOpenAI,
+                api_version=aoai_api_version,
+                azure_endpoint=azure_endpoint,
+                api_key=api_key,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+        else:
+            azure_credential = DefaultAzureCredential()
+            ad_token_provider = get_bearer_token_provider(
+                azure_credential, "https://cognitiveservices.azure.com/.default"
+            )
+            return create_client_callable(
+                AzureOpenAI,
+                api_version=aoai_api_version,
+                azure_endpoint=azure_endpoint,
+                azure_ad_token_provider=ad_token_provider,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+    else:
+        raise NotImplementedError(f"API type '{api_type}' is invalid.")
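Finally, the keyless Azure path the docstring recommends; a hedged sketch (the endpoint is a placeholder; leaving `api_key` unset makes the factory fall through to DefaultAzureCredential):

from not_again_ai.llm.chat_completion.providers.openai_api import openai_client

# Keyless (RBAC) path: with no api_key, the factory builds an Azure AD
# token provider from DefaultAzureCredential.
client = openai_client(
    api_type="azure_openai",
    azure_endpoint="https://<your-resource>.openai.azure.com",  # placeholder endpoint
)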