not-again-ai 0.14.0-py3-none-any.whl → 0.16.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- not_again_ai/llm/chat_completion/__init__.py +4 -0
- not_again_ai/llm/chat_completion/interface.py +32 -0
- not_again_ai/llm/chat_completion/providers/ollama_api.py +227 -0
- not_again_ai/llm/chat_completion/providers/openai_api.py +290 -0
- not_again_ai/llm/chat_completion/types.py +145 -0
- not_again_ai/llm/embedding/__init__.py +4 -0
- not_again_ai/llm/embedding/interface.py +28 -0
- not_again_ai/llm/embedding/providers/ollama_api.py +87 -0
- not_again_ai/llm/embedding/providers/openai_api.py +126 -0
- not_again_ai/llm/embedding/types.py +23 -0
- not_again_ai/llm/prompting/__init__.py +3 -0
- not_again_ai/llm/prompting/compile_prompt.py +125 -0
- not_again_ai/llm/prompting/interface.py +46 -0
- not_again_ai/llm/prompting/providers/openai_tiktoken.py +122 -0
- not_again_ai/llm/prompting/types.py +43 -0
- {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/METADATA +24 -40
- not_again_ai-0.16.0.dist-info/RECORD +38 -0
- {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/WHEEL +1 -1
- not_again_ai/llm/gh_models/azure_ai_client.py +0 -20
- not_again_ai/llm/gh_models/chat_completion.py +0 -81
- not_again_ai/llm/openai_api/chat_completion.py +0 -339
- not_again_ai/llm/openai_api/context_management.py +0 -70
- not_again_ai/llm/openai_api/embeddings.py +0 -62
- not_again_ai/llm/openai_api/openai_client.py +0 -78
- not_again_ai/llm/openai_api/prompts.py +0 -191
- not_again_ai/llm/openai_api/tokens.py +0 -184
- not_again_ai/local_llm/__init__.py +0 -27
- not_again_ai/local_llm/chat_completion.py +0 -105
- not_again_ai/local_llm/huggingface/chat_completion.py +0 -59
- not_again_ai/local_llm/huggingface/helpers.py +0 -23
- not_again_ai/local_llm/ollama/__init__.py +0 -0
- not_again_ai/local_llm/ollama/chat_completion.py +0 -111
- not_again_ai/local_llm/ollama/model_mapping.py +0 -17
- not_again_ai/local_llm/ollama/ollama_client.py +0 -24
- not_again_ai/local_llm/ollama/service.py +0 -81
- not_again_ai/local_llm/ollama/tokens.py +0 -104
- not_again_ai/local_llm/prompts.py +0 -38
- not_again_ai/local_llm/tokens.py +0 -90
- not_again_ai-0.14.0.dist-info/RECORD +0 -44
- not_again_ai-0.14.0.dist-info/entry_points.txt +0 -3
- /not_again_ai/llm/{gh_models → chat_completion/providers}/__init__.py +0 -0
- /not_again_ai/llm/{openai_api → embedding/providers}/__init__.py +0 -0
- /not_again_ai/{local_llm/huggingface → llm/prompting/providers}/__init__.py +0 -0
- {not_again_ai-0.14.0.dist-info → not_again_ai-0.16.0.dist-info}/LICENSE +0 -0
not_again_ai/llm/chat_completion/interface.py
@@ -0,0 +1,32 @@
+from collections.abc import Callable
+from typing import Any
+
+from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion
+from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion
+from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, ChatCompletionResponse
+
+
+def chat_completion(
+    request: ChatCompletionRequest,
+    provider: str,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    """Get a chat completion response from the given provider. Currently supported providers:
+    - `openai` - OpenAI
+    - `azure_openai` - Azure OpenAI
+    - `ollama` - Ollama
+
+    Args:
+        request: Request parameter object
+        provider: The supported provider name
+        client: Client information, see the provider's implementation for what can be provided
+
+    Returns:
+        ChatCompletionResponse: The chat completion response.
+    """
+    if provider == "openai" or provider == "azure_openai":
+        return openai_chat_completion(request, client)
+    elif provider == "ollama":
+        return ollama_chat_completion(request, client)
+    else:
+        raise ValueError(f"Provider {provider} not supported")
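For orientation, a minimal usage sketch of the dispatch interface added above. It is hedged: the `model` and `messages` fields on `ChatCompletionRequest` are inferred from how the provider modules read the request (`request.model`, `kwargs["messages"]`), and passing messages as plain dicts assumes the pydantic model coerces them; the model name and message content are illustrative only.

from not_again_ai.llm.chat_completion.interface import chat_completion
from not_again_ai.llm.chat_completion.providers.openai_api import openai_client
from not_again_ai.llm.chat_completion.types import ChatCompletionRequest

# Hypothetical request; field names inferred from the provider code above.
request = ChatCompletionRequest(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello."}],
)

# With no api_key, openai_client lets the SDK read OPENAI_API_KEY from the environment.
client = openai_client(api_type="openai")
response = chat_completion(request, provider="openai", client=client)
print(response.choices[0].message.content)

An unknown provider string raises ValueError, so a typo fails fast instead of silently hitting the wrong backend.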
not_again_ai/llm/chat_completion/providers/ollama_api.py
@@ -0,0 +1,227 @@
+from collections.abc import Callable
+import json
+import os
+import re
+import time
+from typing import Any, Literal, cast
+
+from loguru import logger
+from ollama import ChatResponse, Client, ResponseError
+
+from not_again_ai.llm.chat_completion.types import (
+    AssistantMessage,
+    ChatCompletionChoice,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    Function,
+    ToolCall,
+)
+
+OLLAMA_PARAMETER_MAP = {
+    "frequency_penalty": "repeat_penalty",
+    "max_completion_tokens": "num_predict",
+    "context_window": "num_ctx",
+    "n": None,
+    "tool_choice": None,
+    "reasoning_effort": None,
+    "parallel_tool_calls": None,
+    "logit_bias": None,
+    "top_logprobs": None,
+    "presence_penalty": None,
+}
+
+
+def validate(request: ChatCompletionRequest) -> None:
+    if request.json_mode and request.structured_outputs is not None:
+        raise ValueError("json_schema and json_mode cannot be used together.")
+
+    # Warn about request parameters that map to None in OLLAMA_PARAMETER_MAP, since Ollama does not support them
+    for key, value in OLLAMA_PARAMETER_MAP.items():
+        if value is None and getattr(request, key) is not None:
+            logger.warning(f"Parameter {key} is not supported by Ollama and will be ignored.")
+
+    # If "stop" is not None, check that it is just a string
+    if isinstance(request.stop, list):
+        logger.warning("Parameter 'stop' needs to be a string and not a list. It will be ignored.")
+        request.stop = None
+
+
+def ollama_chat_completion(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    validate(request)
+
+    kwargs = request.model_dump(mode="json", exclude_none=True)
+
+    # For each key in OLLAMA_PARAMETER_MAP:
+    # if the mapped value is not None, rename the key in kwargs to that value;
+    # if the mapped value is None, remove the key from kwargs
+    for key, value in OLLAMA_PARAMETER_MAP.items():
+        if value is not None and key in kwargs:
+            kwargs[value] = kwargs.pop(key)
+        elif value is None and key in kwargs:
+            del kwargs[key]
+
+    # If json_mode is True, set the format to json
+    json_mode = kwargs.get("json_mode", None)
+    if json_mode:
+        kwargs["format"] = "json"
+        kwargs.pop("json_mode")
+    elif json_mode is not None and not json_mode:
+        kwargs.pop("json_mode")
+
+    # If structured_outputs is not None, set the format to structured_outputs
+    if kwargs.get("structured_outputs", None):
+        # If the schema is wrapped in the OpenAI structured-outputs format, pull out the inner schema
+        if "schema" in kwargs["structured_outputs"]:
+            kwargs["format"] = kwargs["structured_outputs"]["schema"]
+            kwargs.pop("structured_outputs")
+        else:
+            kwargs["format"] = kwargs.pop("structured_outputs")
+
+    option_fields = [
+        "mirostat",
+        "mirostat_eta",
+        "mirostat_tau",
+        "num_ctx",
+        "repeat_last_n",
+        "repeat_penalty",
+        "temperature",
+        "seed",
+        "stop",
+        "tfs_z",
+        "num_predict",
+        "top_k",
+        "top_p",
+        "min_p",
+    ]
+    # For each field in option_fields, if it is in kwargs, move it under an options dictionary
+    options = {}
+    for field in option_fields:
+        if field in kwargs:
+            options[field] = kwargs.pop(field)
+    kwargs["options"] = options
+
+    for message in kwargs["messages"]:
+        role = message.get("role", None)
+        # For each ToolMessage, remove the name field
+        if role is not None and role == "tool":
+            message.pop("name")
+
+        # For each AssistantMessage with tool calls, remove the id field
+        if role is not None and role == "assistant" and message.get("tool_calls", None):
+            for tool_call in message["tool_calls"]:
+                tool_call.pop("id")
+
+        # Content and images need to be separated
+        images = []
+        content = ""
+        if isinstance(message["content"], list):
+            for item in message["content"]:
+                if item["type"] == "image_url":
+                    image_url = item["image_url"]["url"]
+                    # Remove the data URL prefix if present
+                    if image_url.startswith("data:"):
+                        image_url = image_url.split("base64,", 1)[1]
+                    images.append(image_url)
+                else:
+                    content += item["text"]
+        else:
+            content = message["content"]
+
+        message["content"] = content
+        if len(images) > 1:
+            images = images[:1]
+            logger.warning("Ollama model only supports a single image per message. Using only the first image.")
+        message["images"] = images
+
+    try:
+        start_time = time.time()
+        response: ChatResponse = client(**kwargs)
+        end_time = time.time()
+        response_duration = round(end_time - start_time, 4)
+    except ResponseError as e:
+        # If the error says "model '<model>' not found", detect it with regex and raise a more specific error
+        expected_pattern = f"model '{request.model}' not found"
+        if re.search(expected_pattern, e.error):
+            raise ResponseError(f"Model '{request.model}' not found.") from e
+        else:
+            raise ResponseError(e.error) from e
+
+    errors = ""
+
+    # Handle tool calls
+    tool_calls: list[ToolCall] | None = None
+    if response.message.tool_calls:
+        parsed_tool_calls: list[ToolCall] = []
+        for tool_call in response.message.tool_calls:
+            tool_name = tool_call.function.name
+            if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
+                errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n"
+            tool_args = tool_call.function.arguments
+            parsed_tool_calls.append(
+                ToolCall(
+                    id="",
+                    function=Function(
+                        name=tool_name,
+                        arguments=tool_args,
+                    ),
+                )
+            )
+        tool_calls = parsed_tool_calls
+
+    json_message = None
+    if (request.json_mode or (request.structured_outputs is not None)) and response.message.content:
+        try:
+            json_message = json.loads(response.message.content)
+        except json.JSONDecodeError:
+            errors += "Message failed to parse into JSON\n"
+
+    finish_reason = cast(
+        Literal["stop", "length", "tool_calls", "content_filter"],
+        "stop" if response.done_reason is None else response.done_reason or "stop",
+    )
+
+    choice = ChatCompletionChoice(
+        message=AssistantMessage(
+            content=response.message.content or "",
+            tool_calls=tool_calls,
+        ),
+        finish_reason=finish_reason,
+        json_message=json_message,
+    )
+
+    return ChatCompletionResponse(
+        choices=[choice],
+        errors=errors.strip(),
+        completion_tokens=response.get("eval_count", -1),
+        prompt_tokens=response.get("prompt_eval_count", -1),
+        response_duration=response_duration,
+    )
+
+
+def ollama_client(host: str | None = None, timeout: float | None = None) -> Callable[..., Any]:
+    """Create an Ollama client instance based on the specified host or, if not provided, the OLLAMA_HOST environment variable.
+
+    Args:
+        host (str, optional): The host URL of the Ollama server.
+        timeout (float, optional): The timeout for requests.
+
+    Returns:
+        Callable[..., Any]: A callable that calls Client.chat with the configured host and timeout.
+
+    Examples:
+        >>> client = ollama_client(host="http://localhost:11434")
+    """
+    if host is None:
+        host = os.getenv("OLLAMA_HOST")
+        if host is None:
+            logger.warning("OLLAMA_HOST environment variable not set, using default host: http://localhost:11434")
+            host = "http://localhost:11434"
+
+    def client_callable(**kwargs: Any) -> Any:
+        client = Client(host=host, timeout=timeout)
+        return client.chat(**kwargs)
+
+    return client_callable
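A hedged sketch of driving this provider end to end. It assumes a local Ollama server and a pulled `llama3.2` model (both assumptions; any available model works), and the same inferred `ChatCompletionRequest` field names as in the sketch above. Note how `max_completion_tokens` is renamed to Ollama's `num_predict` by `OLLAMA_PARAMETER_MAP` and then nested under `options`.

from not_again_ai.llm.chat_completion.interface import chat_completion
from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client
from not_again_ai.llm.chat_completion.types import ChatCompletionRequest

# Falls back to http://localhost:11434 (with a warning) when OLLAMA_HOST is unset.
client = ollama_client()

request = ChatCompletionRequest(  # field names inferred from the provider code
    model="llama3.2",  # assumed pulled already, e.g. via `ollama pull llama3.2`
    messages=[{"role": "user", "content": "Reply with one short sentence."}],
    max_completion_tokens=64,  # becomes options["num_predict"] in the Ollama call
)

response = chat_completion(request, provider="ollama", client=client)
print(response.choices[0].message.content)
print(f"{response.prompt_tokens} prompt / {response.completion_tokens} completion tokens")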
not_again_ai/llm/chat_completion/providers/openai_api.py
@@ -0,0 +1,290 @@
+from collections.abc import Callable
+import json
+import time
+from typing import Any, Literal
+
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
+from openai import AzureOpenAI, OpenAI
+
+from not_again_ai.llm.chat_completion.types import (
+    AssistantMessage,
+    ChatCompletionChoice,
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    Function,
+    ToolCall,
+)
+
+OPENAI_PARAMETER_MAP = {
+    "context_window": None,
+    "mirostat": None,
+    "mirostat_eta": None,
+    "mirostat_tau": None,
+    "repeat_last_n": None,
+    "tfs_z": None,
+    "top_k": None,
+    "min_p": None,
+}
+
+
+def validate(request: ChatCompletionRequest) -> None:
+    if request.json_mode and request.structured_outputs is not None:
+        raise ValueError("json_schema and json_mode cannot be used together.")
+
+
+def openai_chat_completion(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    validate(request)
+
+    # Format the response_format parameters to be compatible with the OpenAI API
+    if request.json_mode:
+        response_format: dict[str, Any] = {"type": "json_object"}
+    elif request.structured_outputs is not None:
+        response_format = {"type": "json_schema", "json_schema": request.structured_outputs}
+    else:
+        response_format = {"type": "text"}
+
+    kwargs = request.model_dump(mode="json", exclude_none=True)
+
+    # For each key in OPENAI_PARAMETER_MAP:
+    # if the mapped value is not None, rename the key in kwargs to that value;
+    # if the mapped value is None, remove the key from kwargs
+    for key, value in OPENAI_PARAMETER_MAP.items():
+        if value is not None and key in kwargs:
+            kwargs[value] = kwargs.pop(key)
+        elif value is None and key in kwargs:
+            del kwargs[key]
+
+    # Iterate over each message and adjust it to match the format the OpenAI API expects
+    for message in kwargs["messages"]:
+        role = message.get("role", None)
+        # For each ToolMessage, change the "name" field to be named "tool_call_id" instead
+        if role is not None and role == "tool":
+            message["tool_call_id"] = message.pop("name")
+
+        # For each AssistantMessage with tool calls, make the function arguments a string
+        if role is not None and role == "assistant" and message.get("tool_calls", None):
+            for tool_call in message["tool_calls"]:
+                tool_call["function"]["arguments"] = str(tool_call["function"]["arguments"])
+
+    # Delete the json_mode and structured_outputs from kwargs
+    kwargs.pop("json_mode", None)
+    kwargs.pop("structured_outputs", None)
+
+    # Add the response_format to kwargs
+    kwargs["response_format"] = response_format
+
+    # Handle tool_choice when the provided tool_choice is the name of the required tool.
+    if request.tool_choice is not None and request.tool_choice not in ["none", "auto", "required"]:
+        kwargs["tool_choice"] = {"type": "function", "function": {"name": request.tool_choice}}
+
+    start_time = time.time()
+    response = client(**kwargs)
+    end_time = time.time()
+    response_duration = round(end_time - start_time, 4)
+
+    errors = ""
+    extras: dict[str, Any] = {}
+    choices: list[ChatCompletionChoice] = []
+    for index, choice in enumerate(response["choices"]):
+        choice_extras: dict[str, Any] = {}
+        finish_reason = choice["finish_reason"]
+
+        message = choice["message"]
+        tool_calls: list[ToolCall] | None = None
+        if message.get("tool_calls", None):
+            parsed_tool_calls: list[ToolCall] = []
+            for tool_call in message["tool_calls"]:
+                tool_name = tool_call.get("function", {}).get("name", None)
+                # Check if the tool name is valid (one of the tool names in the request)
+                if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
+                    errors += f"Choice {index}: Tool call {tool_call} has an invalid tool name: {tool_name}\n"
+
+                tool_args = tool_call.get("function", {}).get("arguments", None)
+                try:
+                    tool_args = json.loads(tool_args)
+                except json.JSONDecodeError:
+                    errors += f"Choice {index}: Tool call {tool_call} failed to parse arguments into JSON\n"
+
+                parsed_tool_calls.append(
+                    ToolCall(
+                        id=tool_call["id"],
+                        function=Function(
+                            name=tool_name,
+                            arguments=tool_args,
+                        ),
+                    )
+                )
+            tool_calls = parsed_tool_calls
+
+        json_message = None
+        if request.json_mode or (request.structured_outputs is not None):
+            try:
+                json_message = json.loads(message.get("content", "{}"))
+            except json.JSONDecodeError:
+                errors += f"Choice {index}: Message failed to parse into JSON\n"
+
+        # Handle logprobs
+        logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = None
+        if choice.get("logprobs", None) and choice["logprobs"].get("content", None) is not None:
+            logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = []
+            for logprob in choice["logprobs"]["content"]:
+                if logprob.get("top_logprobs", None):
+                    curr_logprob_infos: list[dict[str, Any]] = []
+                    for top_logprob in logprob.get("top_logprobs", []):
+                        curr_logprob_infos.append(
+                            {
+                                "token": top_logprob.get("token", ""),
+                                "logprob": top_logprob.get("logprob", 0),
+                                "bytes": top_logprob.get("bytes", 0),
+                            }
+                        )
+                    logprobs_list.append(curr_logprob_infos)
+                else:
+                    logprobs_list.append(
+                        {
+                            "token": logprob.get("token", ""),
+                            "logprob": logprob.get("logprob", 0),
+                            "bytes": logprob.get("bytes", 0),
+                        }
+                    )
+            logprobs = logprobs_list
+
+        # Handle extras that OpenAI or Azure OpenAI return
+        if choice.get("content_filter_results", None):
+            choice_extras["content_filter_results"] = choice["content_filter_results"]
+
+        choices.append(
+            ChatCompletionChoice(
+                message=AssistantMessage(
+                    content=message.get("content") or "",
+                    refusal=message.get("refusal", None),
+                    tool_calls=tool_calls,
+                ),
+                finish_reason=finish_reason,
+                json_message=json_message,
+                logprobs=logprobs,
+                extras=choice_extras,
+            )
+        )
+
+    completion_tokens = response["usage"].get("completion_tokens", -1)
+    prompt_tokens = response["usage"].get("prompt_tokens", -1)
+    completion_detailed_tokens = response["usage"].get("completion_detailed_tokens", None)
+    prompt_detailed_tokens = response["usage"].get("prompt_detailed_tokens", None)
+    system_fingerprint = response.get("system_fingerprint", None)
+
+    extras["prompt_filter_results"] = response.get("prompt_filter_results", None)
+
+    return ChatCompletionResponse(
+        choices=choices,
+        errors=errors.strip(),
+        extras=extras,
+        completion_tokens=completion_tokens,
+        prompt_tokens=prompt_tokens,
+        completion_detailed_tokens=completion_detailed_tokens,
+        prompt_detailed_tokens=prompt_detailed_tokens,
+        system_fingerprint=system_fingerprint,
+        response_duration=response_duration,
+    )
+
+
+def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_args: Any) -> Callable[..., Any]:
+    """Creates a callable that instantiates and uses an OpenAI client.
+
+    Args:
+        client_class: The OpenAI client class to instantiate (OpenAI or AzureOpenAI)
+        **client_args: Arguments to pass to the client constructor
+
+    Returns:
+        A callable that creates a client and returns completion results
+    """
+    filtered_args = {k: v for k, v in client_args.items() if v is not None}
+
+    def client_callable(**kwargs: Any) -> Any:
+        client = client_class(**filtered_args)
+        completion = client.chat.completions.create(**kwargs)
+        return completion.to_dict()
+
+    return client_callable
+
+
+class InvalidOAIAPITypeError(Exception):
+    """Raised when an invalid OAIAPIType string is provided."""
+
+
+def openai_client(
+    api_type: Literal["openai", "azure_openai"] = "openai",
+    api_key: str | None = None,
+    organization: str | None = None,
+    aoai_api_version: str = "2024-06-01",
+    azure_endpoint: str | None = None,
+    timeout: float | None = None,
+    max_retries: int | None = None,
+) -> Callable[..., Any]:
+    """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters.
+
+    It is preferred to use RBAC authentication for Azure OpenAI. You must be signed in with the Azure CLI and have the correct role assigned.
+    See https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521
+
+    Args:
+        api_type (str, optional): Type of the API to be used. Accepted values are 'openai' or 'azure_openai'.
+            Defaults to 'openai'.
+        api_key (str, optional): The API key to authenticate the client. If not provided,
+            OpenAI automatically uses `OPENAI_API_KEY` from the environment.
+            If provided for Azure OpenAI, it will be used for authentication instead of the Azure AD token provider.
+        organization (str, optional): The ID of the organization. If not provided,
+            OpenAI automatically uses `OPENAI_ORG_ID` from the environment.
+        aoai_api_version (str, optional): Only applicable if using Azure OpenAI https://learn.microsoft.com/azure/ai-services/openai/reference#rest-api-versioning
+        azure_endpoint (str, optional): The endpoint to use for Azure OpenAI.
+        timeout (float, optional): By default requests time out after 10 minutes.
+        max_retries (int, optional): Certain errors are automatically retried 2 times by default,
+            with a short exponential backoff. Connection errors (for example, due to a network connectivity problem),
+            408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default.
+
+    Returns:
+        Callable[..., Any]: A callable that creates a client and returns completion results
+
+
+    Raises:
+        InvalidOAIAPITypeError: If an invalid API type string is provided.
+        NotImplementedError: If the specified API type is recognized but not supported.
+    """
+    if api_type not in ["openai", "azure_openai"]:
+        raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.")
+
+    if api_type == "openai":
+        return create_client_callable(
+            OpenAI,
+            api_key=api_key,
+            organization=organization,
+            timeout=timeout,
+            max_retries=max_retries,
+        )
+    elif api_type == "azure_openai":
+        if api_key:
+            return create_client_callable(
+                AzureOpenAI,
+                api_version=aoai_api_version,
+                azure_endpoint=azure_endpoint,
+                api_key=api_key,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+        else:
+            azure_credential = DefaultAzureCredential()
+            ad_token_provider = get_bearer_token_provider(
+                azure_credential, "https://cognitiveservices.azure.com/.default"
+            )
+            return create_client_callable(
+                AzureOpenAI,
+                api_version=aoai_api_version,
+                azure_endpoint=azure_endpoint,
+                azure_ad_token_provider=ad_token_provider,
+                timeout=timeout,
+                max_retries=max_retries,
+            )
+    else:
+        raise NotImplementedError(f"API type '{api_type}' is invalid.")
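Finally, a sketch of the keyless Azure OpenAI path: when api_type="azure_openai" and no api_key is given, the factory above authenticates through DefaultAzureCredential, so being signed in via the Azure CLI with the correct role is assumed. The endpoint and deployment name below are placeholders, and the request field names are the same inferred ones as in the earlier sketches.

from not_again_ai.llm.chat_completion.interface import chat_completion
from not_again_ai.llm.chat_completion.providers.openai_api import openai_client
from not_again_ai.llm.chat_completion.types import ChatCompletionRequest

# Keyless auth: no api_key, so the DefaultAzureCredential bearer token provider is used.
client = openai_client(
    api_type="azure_openai",
    azure_endpoint="https://my-resource.openai.azure.com",  # placeholder endpoint
)

request = ChatCompletionRequest(  # field names inferred from the provider code
    model="my-gpt-4o-deployment",  # for Azure OpenAI, the model is the deployment name
    messages=[{"role": "user", "content": "Summarize RBAC in one sentence."}],
)

response = chat_completion(request, provider="azure_openai", client=client)
print(response.choices[0].message.content)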