not-again-ai 0.16.1__tar.gz → 0.17.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/PKG-INFO +1 -1
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/pyproject.toml +4 -1
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/__init__.py +2 -2
- not_again_ai-0.17.0/src/not_again_ai/llm/chat_completion/interface.py +61 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/providers/ollama_api.py +80 -12
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/providers/openai_api.py +180 -38
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/types.py +44 -0
- not_again_ai-0.16.1/src/not_again_ai/llm/chat_completion/interface.py +0 -32
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/LICENSE +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/README.md +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/base/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/base/file_system.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/base/parallel.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/data/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/data/web.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/providers/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/interface.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/providers/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/providers/ollama_api.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/providers/openai_api.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/types.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/compile_prompt.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/interface.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/providers/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/providers/openai_tiktoken.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/types.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/py.typed +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/statistics/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/statistics/dependence.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/__init__.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/barplots.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/distributions.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/scatterplot.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/time_series.py +0 -0
- {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/utils.py +0 -0
{not_again_ai-0.16.1 → not_again_ai-0.17.0}/pyproject.toml
RENAMED
@@ -1,6 +1,6 @@
 [project]
 name = "not-again-ai"
-version = "0.16.1"
+version = "0.17.0"
 description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place."
 authors = [
     { name = "DaveCoDev", email = "dave.co.dev@gmail.com" }
@@ -70,6 +70,7 @@ nox-poetry = "*"

 [tool.poetry.group.test.dependencies]
 pytest = "*"
+pytest-asyncio = "*"
 pytest-cov = "*"
 pytest-randomly = "*"

@@ -153,6 +154,8 @@ filterwarnings = [
     # "ignore::DeprecationWarning:typer",
     "ignore::pytest.PytestUnraisableExceptionWarning"
 ]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"

 [tool.coverage.run]
 branch = true
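With `asyncio_mode = "auto"`, pytest-asyncio collects plain `async def` test functions without an explicit marker, which is what the new streaming tests rely on. A minimal sketch of such a test; the fake client, its chunk shape, and the `UserMessage` usage are illustrative assumptions, not taken from this diff:

# test_chat_completion_stream.py (illustrative only)
from collections.abc import AsyncGenerator
from typing import Any

from not_again_ai.llm.chat_completion import ChatCompletionRequest, chat_completion_stream
from not_again_ai.llm.chat_completion.types import UserMessage


async def fake_openai_client(**kwargs: Any) -> AsyncGenerator[dict[str, Any], None]:
    """Stand-in for openai_client(async_client=True) that yields dict-shaped chunks."""

    async def stream() -> AsyncGenerator[dict[str, Any], None]:
        for piece in ("Hello", ", ", "world"):
            yield {"choices": [{"delta": {"content": piece, "role": "assistant"}, "index": 0, "finish_reason": None}]}

    return stream()


async def test_stream_concatenates_content() -> None:
    # No @pytest.mark.asyncio marker needed because asyncio_mode = "auto".
    request = ChatCompletionRequest(model="fake-model", messages=[UserMessage(content="hi")])
    text = ""
    async for chunk in chat_completion_stream(request, "openai", fake_openai_client):
        text += chunk.choices[0].delta.content
    assert text == "Hello, world"

The fake client needs no API key or network access, so a test like this exercises only the chunk-parsing path added below in openai_api.py.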
{not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/__init__.py
RENAMED
@@ -1,4 +1,4 @@
-from not_again_ai.llm.chat_completion.interface import chat_completion
+from not_again_ai.llm.chat_completion.interface import chat_completion, chat_completion_stream
 from not_again_ai.llm.chat_completion.types import ChatCompletionRequest

-__all__ = ["ChatCompletionRequest", "chat_completion"]
+__all__ = ["ChatCompletionRequest", "chat_completion", "chat_completion_stream"]
not_again_ai-0.17.0/src/not_again_ai/llm/chat_completion/interface.py
ADDED
@@ -0,0 +1,61 @@
+from collections.abc import AsyncGenerator, Callable
+from typing import Any
+
+from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion, ollama_chat_completion_stream
+from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion, openai_chat_completion_stream
+from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse
+
+
+def chat_completion(
+    request: ChatCompletionRequest,
+    provider: str,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    """Get a chat completion response from the given provider. Currently supported providers:
+    - `openai` - OpenAI
+    - `azure_openai` - Azure OpenAI
+    - `ollama` - Ollama
+
+    Args:
+        request: Request parameter object
+        provider: The supported provider name
+        client: Client information, see the provider's implementation for what can be provided
+
+    Returns:
+        ChatCompletionResponse: The chat completion response.
+    """
+    if provider == "openai" or provider == "azure_openai":
+        return openai_chat_completion(request, client)
+    elif provider == "ollama":
+        return ollama_chat_completion(request, client)
+    else:
+        raise ValueError(f"Provider {provider} not supported")
+
+
+async def chat_completion_stream(
+    request: ChatCompletionRequest,
+    provider: str,
+    client: Callable[..., Any],
+) -> AsyncGenerator[ChatCompletionChunk, None]:
+    """Stream a chat completion response from the given provider. Currently supported providers:
+    - `openai` - OpenAI
+    - `azure_openai` - Azure OpenAI
+    - `ollama` - Ollama
+
+    Args:
+        request: Request parameter object
+        provider: The supported provider name
+        client: Client information, see the provider's implementation for what can be provided
+
+    Returns:
+        AsyncGenerator[ChatCompletionChunk, None]
+    """
+    request.stream = True
+    if provider == "openai" or provider == "azure_openai":
+        async for chunk in openai_chat_completion_stream(request, client):
+            yield chunk
+    elif provider == "ollama":
+        async for chunk in ollama_chat_completion_stream(request, client):
+            yield chunk
+    else:
+        raise ValueError(f"Provider {provider} not supported")
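The new interface module keeps the synchronous `chat_completion` entry point and adds an async `chat_completion_stream` that forces `request.stream = True` before delegating to the provider implementations. A usage sketch, assuming `OPENAI_API_KEY` is set; the model name is a placeholder and the `choices[0].message.content` access follows the non-streaming response types, which are not shown in this diff:

import asyncio

from not_again_ai.llm.chat_completion import ChatCompletionRequest, chat_completion, chat_completion_stream
from not_again_ai.llm.chat_completion.providers.openai_api import openai_client
from not_again_ai.llm.chat_completion.types import UserMessage

request = ChatCompletionRequest(
    model="gpt-4o-mini",  # placeholder model name
    messages=[UserMessage(content="Name three prime numbers.")],
)

# Blocking call: a synchronous client callable returns one ChatCompletionResponse.
response = chat_completion(request, provider="openai", client=openai_client())
print(response.choices[0].message.content)  # assumed field layout of the non-streaming types


async def stream_it() -> None:
    # Streaming call: the wrapper sets request.stream = True and needs a client built with async_client=True.
    async for chunk in chat_completion_stream(request, provider="openai", client=openai_client(async_client=True)):
        for choice in chunk.choices:
            print(choice.delta.content, end="", flush=True)
    print()


asyncio.run(stream_it())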
{not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/providers/ollama_api.py
RENAMED
@@ -1,4 +1,4 @@
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 import json
 import os
 import re
@@ -6,14 +6,20 @@ import time
 from typing import Any, Literal, cast

 from loguru import logger
-from ollama import ChatResponse, Client, ResponseError
+from ollama import AsyncClient, ChatResponse, Client, ResponseError

 from not_again_ai.llm.chat_completion.types import (
     AssistantMessage,
     ChatCompletionChoice,
+    ChatCompletionChoiceStream,
+    ChatCompletionChunk,
+    ChatCompletionDelta,
     ChatCompletionRequest,
     ChatCompletionResponse,
     Function,
+    PartialFunction,
+    PartialToolCall,
+    Role,
     ToolCall,
 )

@@ -51,14 +57,8 @@ def validate(request: ChatCompletionRequest) -> None:
         raise ValueError("`max_tokens` and `max_completion_tokens` cannot both be provided.")


-def ollama_chat_completion(
-    request: ChatCompletionRequest,
-    client: Callable[..., Any],
-) -> ChatCompletionResponse:
-    validate(request)
-
+def format_kwargs(request: ChatCompletionRequest) -> dict[str, Any]:
     kwargs = request.model_dump(mode="json", exclude_none=True)
-
     # For each key in OLLAMA_PARAMETER_MAP
     # If it is not None, set the key in kwargs to the value of the corresponding value in OLLAMA_PARAMETER_MAP
     # If it is None, remove that key from kwargs
@@ -141,6 +141,16 @@ def ollama_chat_completion(
             logger.warning("Ollama model only supports a single image per message. Using only the first images.")
             message["images"] = images

+    return kwargs
+
+
+def ollama_chat_completion(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    validate(request)
+    kwargs = format_kwargs(request)
+
     try:
         start_time = time.time()
         response: ChatResponse = client(**kwargs)
@@ -164,7 +174,7 @@ def ollama_chat_completion(
             tool_name = tool_call.function.name
             if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
                 errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n"
-            tool_args = tool_call.function.arguments
+            tool_args = dict(tool_call.function.arguments)
             parsed_tool_calls.append(
                 ToolCall(
                     id="",
@@ -206,7 +216,65 @@ def ollama_chat_completion(
     )


-def ollama_client(host: str | None = None, timeout: float | None = None) -> Callable[..., Any]:
+async def ollama_chat_completion_stream(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> AsyncGenerator[ChatCompletionChunk, None]:
+    validate(request)
+    kwargs = format_kwargs(request)
+
+    start_time = time.time()
+    stream = await client(**kwargs)
+
+    async for chunk in stream:
+        errors = ""
+        # Handle tool calls
+        tool_calls: list[PartialToolCall] | None = None
+        if chunk.message.tool_calls:
+            parsed_tool_calls: list[PartialToolCall] = []
+            for tool_call in chunk.message.tool_calls:
+                tool_name = tool_call.function.name
+                if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
+                    errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n"
+                tool_args = tool_call.function.arguments
+
+                parsed_tool_calls.append(
+                    PartialToolCall(
+                        id="",
+                        function=PartialFunction(
+                            name=tool_name,
+                            arguments=tool_args,
+                        ),
+                    )
+                )
+            tool_calls = parsed_tool_calls
+
+        current_time = time.time()
+        response_duration = round(current_time - start_time, 4)
+
+        delta = ChatCompletionDelta(
+            content=chunk.message.content or "",
+            role=Role.ASSISTANT,
+            tool_calls=tool_calls,
+        )
+        choice_obj = ChatCompletionChoiceStream(
+            delta=delta,
+            finish_reason=chunk.done_reason,
+            index=0,
+        )
+        chunk_obj = ChatCompletionChunk(
+            choices=[choice_obj],
+            errors=errors.strip(),
+            completion_tokens=chunk.get("eval_count", None),
+            prompt_tokens=chunk.get("prompt_eval_count", None),
+            response_duration=response_duration,
+        )
+        yield chunk_obj
+
+
+def ollama_client(
+    host: str | None = None, timeout: float | None = None, async_client: bool = False
+) -> Callable[..., Any]:
     """Create an Ollama client instance based on the specified host or will read from the OLLAMA_HOST environment variable.

     Args:
@@ -226,7 +294,7 @@ def ollama_client(host: str | None = None, timeout: float | None = None) -> Call
         host = "http://localhost:11434"

     def client_callable(**kwargs: Any) -> Any:
-        client = Client(host=host, timeout=timeout)
+        client = AsyncClient(host=host, timeout=timeout) if async_client else Client(host=host, timeout=timeout)
         return client.chat(**kwargs)

     return client_callable
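`ollama_client` keeps its synchronous behavior by default and returns an `ollama.AsyncClient`-backed callable when `async_client=True`, which is what `ollama_chat_completion_stream` awaits. A small sketch, assuming a locally running Ollama server; the host and model name are placeholders:

import asyncio

from not_again_ai.llm.chat_completion import ChatCompletionRequest, chat_completion_stream
from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client
from not_again_ai.llm.chat_completion.types import UserMessage


async def main() -> None:
    # host falls back to OLLAMA_HOST and then http://localhost:11434; async_client=True wraps AsyncClient.
    client = ollama_client(host="http://localhost:11434", async_client=True)
    request = ChatCompletionRequest(
        model="llama3.2",  # placeholder; any locally pulled Ollama model
        messages=[UserMessage(content="Stream one short sentence.")],
    )
    async for chunk in chat_completion_stream(request, "ollama", client):
        print(chunk.choices[0].delta.content, end="", flush=True)
    print()


asyncio.run(main())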
{not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/providers/openai_api.py
RENAMED
@@ -1,17 +1,23 @@
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable, Coroutine
 import json
 import time
 from typing import Any, Literal

 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
-from openai import AzureOpenAI, OpenAI
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

 from not_again_ai.llm.chat_completion.types import (
     AssistantMessage,
     ChatCompletionChoice,
+    ChatCompletionChoiceStream,
+    ChatCompletionChunk,
+    ChatCompletionDelta,
     ChatCompletionRequest,
     ChatCompletionResponse,
     Function,
+    PartialFunction,
+    PartialToolCall,
+    Role,
     ToolCall,
 )

@@ -36,12 +42,7 @@ def validate(request: ChatCompletionRequest) -> None:
         raise ValueError("`max_tokens` and `max_completion_tokens` cannot both be provided.")


-def openai_chat_completion(
-    request: ChatCompletionRequest,
-    client: Callable[..., Any],
-) -> ChatCompletionResponse:
-    validate(request)
-
+def format_kwargs(request: ChatCompletionRequest) -> dict[str, Any]:
     # Format the response format parameters to be compatible with OpenAI API
     if request.json_mode:
         response_format: dict[str, Any] = {"type": "json_object"}
@@ -61,7 +62,6 @@ def openai_chat_completion(
         elif value is None and key in kwargs:
             del kwargs[key]

-    # Iterate over each message and
     for message in kwargs["messages"]:
         role = message.get("role", None)
         # For each ToolMessage, change the "name" field to be named "tool_call_id" instead
@@ -84,6 +84,49 @@ def openai_chat_completion(
     if request.tool_choice is not None and request.tool_choice not in ["none", "auto", "required"]:
         kwargs["tool_choice"] = {"type": "function", "function": {"name": request.tool_choice}}

+    return kwargs
+
+
+def process_logprobs(logprobs_content: list[dict[str, Any]]) -> list[dict[str, Any] | list[dict[str, Any]]]:
+    """Process logprobs content from OpenAI API response.
+
+    Args:
+        logprobs_content: List of logprob entries from the API response
+
+    Returns:
+        Processed logprobs list containing either single token info or lists of top token infos
+    """
+    logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = []
+    for logprob in logprobs_content:
+        if logprob.get("top_logprobs", None):
+            curr_logprob_infos: list[dict[str, Any]] = []
+            for top_logprob in logprob.get("top_logprobs", []):
+                curr_logprob_infos.append(
+                    {
+                        "token": top_logprob.get("token", ""),
+                        "logprob": top_logprob.get("logprob", 0),
+                        "bytes": top_logprob.get("bytes", 0),
+                    }
+                )
+            logprobs_list.append(curr_logprob_infos)
+        else:
+            logprobs_list.append(
+                {
+                    "token": logprob.get("token", ""),
+                    "logprob": logprob.get("logprob", 0),
+                    "bytes": logprob.get("bytes", 0),
+                }
+            )
+    return logprobs_list
+
+
+def openai_chat_completion(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    validate(request)
+    kwargs = format_kwargs(request)
+
     start_time = time.time()
     response = client(**kwargs)
     end_time = time.time()
@@ -133,28 +176,7 @@ def openai_chat_completion(
         # Handle logprobs
         logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = None
         if choice.get("logprobs", None) and choice["logprobs"].get("content", None) is not None:
-            logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = []
-            for logprob in choice["logprobs"]["content"]:
-                if logprob.get("top_logprobs", None):
-                    curr_logprob_infos: list[dict[str, Any]] = []
-                    for top_logprob in logprob.get("top_logprobs", []):
-                        curr_logprob_infos.append(
-                            {
-                                "token": top_logprob.get("token", ""),
-                                "logprob": top_logprob.get("logprob", 0),
-                                "bytes": top_logprob.get("bytes", 0),
-                            }
-                        )
-                    logprobs_list.append(curr_logprob_infos)
-                else:
-                    logprobs_list.append(
-                        {
-                            "token": logprob.get("token", ""),
-                            "logprob": logprob.get("logprob", 0),
-                            "bytes": logprob.get("bytes", 0),
-                        }
-                    )
-            logprobs = logprobs_list
+            logprobs = process_logprobs(choice["logprobs"]["content"])

         # Handle extras that OpenAI or Azure OpenAI return
         if choice.get("content_filter_results", None):
@@ -195,6 +217,107 @@ def openai_chat_completion(
     )


+async def openai_chat_completion_stream(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> AsyncGenerator[ChatCompletionChunk, None]:
+    validate(request)
+    kwargs = format_kwargs(request)
+
+    start_time = time.time()
+    stream = await client(**kwargs)
+
+    async for chunk in stream:
+        errors = ""
+        # This kind of a hack. To make this processing generic for clients that do not return the correct
+        # data structure, we convert the chunk to a dict
+        if not isinstance(chunk, dict):
+            chunk = chunk.to_dict()
+
+        choices: list[ChatCompletionChoiceStream] = []
+        for choice in chunk["choices"]:
+            content = choice.get("delta", {}).get("content", "")
+            if not content:
+                content = ""
+
+            role = Role.ASSISTANT
+            if choice.get("delta", {}).get("role", None):
+                role = Role(choice["delta"]["role"])
+
+            # Handle tool calls
+            tool_calls: list[PartialToolCall] | None = None
+            if choice["delta"].get("tool_calls", None):
+                parsed_tool_calls: list[PartialToolCall] = []
+                for tool_call in choice["delta"]["tool_calls"]:
+                    tool_name = tool_call.get("function", {}).get("name", None)
+                    if not tool_name:
+                        tool_name = ""
+                    tool_args = tool_call.get("function", {}).get("arguments", "")
+                    if not tool_args:
+                        tool_args = ""
+
+                    tool_id = tool_call.get("id", None)
+                    parsed_tool_calls.append(
+                        PartialToolCall(
+                            id=tool_id,
+                            function=PartialFunction(
+                                name=tool_name,
+                                arguments=tool_args,
+                            ),
+                        )
+                    )
+                tool_calls = parsed_tool_calls
+
+            refusal = None
+            if choice["delta"].get("refusal", None):
+                refusal = choice["delta"]["refusal"]
+
+            delta = ChatCompletionDelta(
+                content=content,
+                role=role,
+                tool_calls=tool_calls,
+                refusal=refusal,
+            )
+
+            index = choice.get("index", 0)
+            finish_reason = choice.get("finish_reason", None)
+
+            # Handle logprobs
+            logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = None
+            if choice.get("logprobs", None) and choice["logprobs"].get("content", None) is not None:
+                logprobs = process_logprobs(choice["logprobs"]["content"])
+
+            choice_obj = ChatCompletionChoiceStream(
+                delta=delta,
+                finish_reason=finish_reason,
+                logprobs=logprobs,
+                index=index,
+            )
+            choices.append(choice_obj)
+
+        current_time = time.time()
+        response_duration = round(current_time - start_time, 4)
+
+        if "usage" in chunk and chunk["usage"] is not None:
+            completion_tokens = chunk["usage"].get("completion_tokens", None)
+            prompt_tokens = chunk["usage"].get("prompt_tokens", None)
+            system_fingerprint = chunk.get("system_fingerprint", None)
+        else:
+            completion_tokens = None
+            prompt_tokens = None
+            system_fingerprint = None
+
+        chunk_obj = ChatCompletionChunk(
+            choices=choices,
+            errors=errors.strip(),
+            completion_tokens=completion_tokens,
+            prompt_tokens=prompt_tokens,
+            response_duration=response_duration,
+            system_fingerprint=system_fingerprint,
+        )
+        yield chunk_obj
+
+
 def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_args: Any) -> Callable[..., Any]:
     """Creates a callable that instantiates and uses an OpenAI client.

@@ -215,6 +338,20 @@ def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_ar
     return client_callable


+def create_client_callable_stream(
+    client_class: type[AsyncOpenAI | AsyncAzureOpenAI], **client_args: Any
+) -> Callable[..., Any]:
+    filtered_args = {k: v for k, v in client_args.items() if v is not None}
+
+    def client_callable(**kwargs: Any) -> Coroutine[Any, Any, Any]:
+        client = client_class(**filtered_args)
+        kwargs["stream_options"] = {"include_usage": True}
+        stream = client.chat.completions.create(**kwargs)
+        return stream
+
+    return client_callable
+
+
 class InvalidOAIAPITypeError(Exception):
     """Raised when an invalid OAIAPIType string is provided."""

@@ -227,6 +364,7 @@ def openai_client(
     azure_endpoint: str | None = None,
     timeout: float | None = None,
     max_retries: int | None = None,
+    async_client: bool = False,
 ) -> Callable[..., Any]:
     """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters.

@@ -247,11 +385,11 @@ def openai_client(
         max_retries (int, optional): Certain errors are automatically retried 2 times by default,
             with a short exponential backoff. Connection errors (for example, due to a network connectivity problem),
             408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default.
+        async_client (bool, optional): Whether to return an async client. Defaults to False.

     Returns:
         Callable[..., Any]: A callable that creates a client and returns completion results

-
     Raises:
         InvalidOAIAPITypeError: If an invalid API type string is provided.
         NotImplementedError: If the specified API type is recognized but not yet supported (e.g., 'azure_openai').
@@ -260,17 +398,21 @@ def openai_client(
         raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.")

     if api_type == "openai":
-        return create_client_callable(
-            OpenAI,
+        client_class = AsyncOpenAI if async_client else OpenAI
+        callable_creator = create_client_callable_stream if async_client else create_client_callable
+        return callable_creator(
+            client_class,  # type: ignore
             api_key=api_key,
             organization=organization,
             timeout=timeout,
             max_retries=max_retries,
         )
     elif api_type == "azure_openai":
+        azure_client_class = AsyncAzureOpenAI if async_client else AzureOpenAI
+        callable_creator = create_client_callable_stream if async_client else create_client_callable
         if api_key:
-            return create_client_callable(
-                AzureOpenAI,
+            return callable_creator(
+                azure_client_class,  # type: ignore
                 api_version=aoai_api_version,
                 azure_endpoint=azure_endpoint,
                 api_key=api_key,
@@ -282,8 +424,8 @@ def openai_client(
             ad_token_provider = get_bearer_token_provider(
                 azure_credential, "https://cognitiveservices.azure.com/.default"
             )
-            return create_client_callable(
-                AzureOpenAI,
+            return callable_creator(
+                azure_client_class,  # type: ignore
                 api_version=aoai_api_version,
                 azure_endpoint=azure_endpoint,
                 azure_ad_token_provider=ad_token_provider,
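`create_client_callable_stream` always sets `stream_options={"include_usage": True}`, so when streaming against the real OpenAI API the final chunk carries prompt and completion token counts while holding no choices. A sketch of reading both the streamed text and the usage, assuming `OPENAI_API_KEY` is set and using a placeholder model name:

import asyncio

from not_again_ai.llm.chat_completion import ChatCompletionRequest, chat_completion_stream
from not_again_ai.llm.chat_completion.providers.openai_api import openai_client
from not_again_ai.llm.chat_completion.types import UserMessage


async def main() -> None:
    client = openai_client(api_type="openai", async_client=True)  # reads OPENAI_API_KEY from the environment
    request = ChatCompletionRequest(
        model="gpt-4o-mini",  # placeholder model name
        messages=[UserMessage(content="Say hello in five words.")],
    )
    usage: tuple[int | None, int | None] | None = None
    async for chunk in chat_completion_stream(request, "openai", client):
        for choice in chunk.choices:  # the final usage-only chunk has an empty choices list
            print(choice.delta.content, end="", flush=True)
        if chunk.completion_tokens is not None:
            usage = (chunk.prompt_tokens, chunk.completion_tokens)
    print(f"\nprompt/completion tokens: {usage}")


asyncio.run(main())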
{not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/types.py
RENAMED
@@ -52,12 +52,23 @@ class Function(BaseModel):
     arguments: dict[str, Any]


+class PartialFunction(BaseModel):
+    name: str
+    arguments: str | dict[str, Any]
+
+
 class ToolCall(BaseModel):
     id: str
     function: Function
     type: Literal["function"] = "function"


+class PartialToolCall(BaseModel):
+    id: str | None
+    function: PartialFunction
+    type: Literal["function"] = "function"
+
+
 class DeveloperMessage(BaseMessage[str]):
     role: Literal[Role.DEVELOPER] = Role.DEVELOPER

@@ -87,6 +98,7 @@ MessageT = AssistantMessage | DeveloperMessage | SystemMessage | ToolMessage | U
 class ChatCompletionRequest(BaseModel):
     messages: list[MessageT]
     model: str
+    stream: bool = Field(default=False)

     max_completion_tokens: int | None = Field(default=None)
     context_window: int | None = Field(default=None)
@@ -148,3 +160,35 @@ class ChatCompletionResponse(BaseModel):
     system_fingerprint: str | None = Field(default=None)

     extras: Any | None = Field(default=None)
+
+
+class ChatCompletionDelta(BaseModel):
+    content: str
+    role: Role = Field(default=Role.ASSISTANT)
+
+    tool_calls: list[PartialToolCall] | None = Field(default=None)
+
+    refusal: str | None = Field(default=None)
+
+
+class ChatCompletionChoiceStream(BaseModel):
+    delta: ChatCompletionDelta
+    index: int
+    finish_reason: Literal["stop", "length", "tool_calls", "content_filter"] | None
+
+    logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = Field(default=None)
+
+    extras: Any | None = Field(default=None)
+
+
+class ChatCompletionChunk(BaseModel):
+    choices: list[ChatCompletionChoiceStream]
+
+    errors: str = Field(default="")
+
+    completion_tokens: int | None = Field(default=None)
+    prompt_tokens: int | None = Field(default=None)
+    response_duration: float | None = Field(default=None)
+
+    system_fingerprint: str | None = Field(default=None)
+    extras: Any | None = Field(default=None)
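`ChatCompletionChunk` mirrors `ChatCompletionResponse` for streaming: each chunk holds `ChatCompletionChoiceStream` entries whose `ChatCompletionDelta` may carry `PartialToolCall` fragments, and with OpenAI the tool-call arguments arrive as string pieces that only parse once concatenated. A hedged sketch of folding a finished stream back into plain text and complete `ToolCall` objects; the helper and its start-of-call heuristic are mine, not part of the package:

import json

from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, Function, ToolCall


def accumulate_chunks(chunks: list[ChatCompletionChunk]) -> tuple[str, list[ToolCall]]:
    """Illustrative helper: fold streamed chunks into final text plus fully parsed tool calls."""
    text_parts: list[str] = []
    names: list[str] = []
    arg_fragments: list[str] = []

    for chunk in chunks:
        for choice in chunk.choices:
            text_parts.append(choice.delta.content)
            for partial in choice.delta.tool_calls or []:
                if partial.function.name:
                    # Simplification: treat a non-empty name as the start of a new tool call.
                    names.append(partial.function.name)
                    arg_fragments.append("")
                if isinstance(partial.function.arguments, str) and arg_fragments:
                    # OpenAI streams arguments as string pieces; Ollama already sends a complete dict.
                    arg_fragments[-1] += partial.function.arguments

    tool_calls = [
        ToolCall(id="", function=Function(name=name, arguments=json.loads(args or "{}")))
        for name, args in zip(names, arg_fragments, strict=True)
    ]
    return "".join(text_parts), tool_calls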
not_again_ai-0.16.1/src/not_again_ai/llm/chat_completion/interface.py
DELETED
@@ -1,32 +0,0 @@
-from collections.abc import Callable
-from typing import Any
-
-from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion
-from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion
-from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, ChatCompletionResponse
-
-
-def chat_completion(
-    request: ChatCompletionRequest,
-    provider: str,
-    client: Callable[..., Any],
-) -> ChatCompletionResponse:
-    """Get a chat completion response from the given provider. Currently supported providers:
-    - `openai` - OpenAI
-    - `azure_openai` - Azure OpenAI
-    - `ollama` - Ollama
-
-    Args:
-        request: Request parameter object
-        provider: The supported provider name
-        client: Client information, see the provider's implementation for what can be provided
-
-    Returns:
-        ChatCompletionResponse: The chat completion response.
-    """
-    if provider == "openai" or provider == "azure_openai":
-        return openai_chat_completion(request, client)
-    elif provider == "ollama":
-        return ollama_chat_completion(request, client)
-    else:
-        raise ValueError(f"Provider {provider} not supported")