not-again-ai 0.16.1.tar.gz → 0.17.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/PKG-INFO +1 -1
  2. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/pyproject.toml +4 -1
  3. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/__init__.py +2 -2
  4. not_again_ai-0.17.0/src/not_again_ai/llm/chat_completion/interface.py +61 -0
  5. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/providers/ollama_api.py +80 -12
  6. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/providers/openai_api.py +180 -38
  7. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/types.py +44 -0
  8. not_again_ai-0.16.1/src/not_again_ai/llm/chat_completion/interface.py +0 -32
  9. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/LICENSE +0 -0
  10. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/README.md +0 -0
  11. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/__init__.py +0 -0
  12. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/base/__init__.py +0 -0
  13. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/base/file_system.py +0 -0
  14. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/base/parallel.py +0 -0
  15. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/data/__init__.py +0 -0
  16. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/data/web.py +0 -0
  17. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/__init__.py +0 -0
  18. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/chat_completion/providers/__init__.py +0 -0
  19. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/__init__.py +0 -0
  20. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/interface.py +0 -0
  21. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/providers/__init__.py +0 -0
  22. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/providers/ollama_api.py +0 -0
  23. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/providers/openai_api.py +0 -0
  24. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/embedding/types.py +0 -0
  25. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/__init__.py +0 -0
  26. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/compile_prompt.py +0 -0
  27. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/interface.py +0 -0
  28. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/providers/__init__.py +0 -0
  29. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/providers/openai_tiktoken.py +0 -0
  30. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/llm/prompting/types.py +0 -0
  31. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/py.typed +0 -0
  32. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/statistics/__init__.py +0 -0
  33. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/statistics/dependence.py +0 -0
  34. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/__init__.py +0 -0
  35. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/barplots.py +0 -0
  36. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/distributions.py +0 -0
  37. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/scatterplot.py +0 -0
  38. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/time_series.py +0 -0
  39. {not_again_ai-0.16.1 → not_again_ai-0.17.0}/src/not_again_ai/viz/utils.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: not-again-ai
-Version: 0.16.1
+Version: 0.17.0
 Summary: Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place.
 License: MIT
 Author: DaveCoDev
pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "not-again-ai"
-version = "0.16.1"
+version = "0.17.0"
 description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place."
 authors = [
     { name = "DaveCoDev", email = "dave.co.dev@gmail.com" }
@@ -70,6 +70,7 @@ nox-poetry = "*"
 
 [tool.poetry.group.test.dependencies]
 pytest = "*"
+pytest-asyncio = "*"
 pytest-cov = "*"
 pytest-randomly = "*"
 
@@ -153,6 +154,8 @@ filterwarnings = [
     # "ignore::DeprecationWarning:typer",
     "ignore::pytest.PytestUnraisableExceptionWarning"
 ]
+asyncio_mode = "auto"
+asyncio_default_fixture_loop_scope = "function"
 
 [tool.coverage.run]
 branch = true
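
The two new `asyncio_*` keys belong to pytest-asyncio, which the test group now depends on. With `asyncio_mode = "auto"`, bare `async def` test functions are collected and run without an explicit `@pytest.mark.asyncio` marker, which the new streaming code paths presumably rely on for their tests. A minimal sketch of what that enables (the test body is hypothetical, not from this package):

```python
import asyncio


# Under asyncio_mode = "auto", pytest-asyncio runs this coroutine as a test
# with no decorator needed; the event loop is function-scoped per the
# asyncio_default_fixture_loop_scope setting above.
async def test_streaming_sketch() -> None:
    await asyncio.sleep(0)  # stand-in for awaiting a streamed chunk
    assert True
```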
src/not_again_ai/llm/chat_completion/__init__.py
@@ -1,4 +1,4 @@
-from not_again_ai.llm.chat_completion.interface import chat_completion
+from not_again_ai.llm.chat_completion.interface import chat_completion, chat_completion_stream
 from not_again_ai.llm.chat_completion.types import ChatCompletionRequest
 
-__all__ = ["ChatCompletionRequest", "chat_completion"]
+__all__ = ["ChatCompletionRequest", "chat_completion", "chat_completion_stream"]
not_again_ai-0.17.0/src/not_again_ai/llm/chat_completion/interface.py (new file)
@@ -0,0 +1,61 @@
+from collections.abc import AsyncGenerator, Callable
+from typing import Any
+
+from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion, ollama_chat_completion_stream
+from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion, openai_chat_completion_stream
+from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse
+
+
+def chat_completion(
+    request: ChatCompletionRequest,
+    provider: str,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    """Get a chat completion response from the given provider. Currently supported providers:
+    - `openai` - OpenAI
+    - `azure_openai` - Azure OpenAI
+    - `ollama` - Ollama
+
+    Args:
+        request: Request parameter object
+        provider: The supported provider name
+        client: Client information, see the provider's implementation for what can be provided
+
+    Returns:
+        ChatCompletionResponse: The chat completion response.
+    """
+    if provider == "openai" or provider == "azure_openai":
+        return openai_chat_completion(request, client)
+    elif provider == "ollama":
+        return ollama_chat_completion(request, client)
+    else:
+        raise ValueError(f"Provider {provider} not supported")
+
+
+async def chat_completion_stream(
+    request: ChatCompletionRequest,
+    provider: str,
+    client: Callable[..., Any],
+) -> AsyncGenerator[ChatCompletionChunk, None]:
+    """Stream a chat completion response from the given provider. Currently supported providers:
+    - `openai` - OpenAI
+    - `azure_openai` - Azure OpenAI
+    - `ollama` - Ollama
+
+    Args:
+        request: Request parameter object
+        provider: The supported provider name
+        client: Client information, see the provider's implementation for what can be provided
+
+    Returns:
+        AsyncGenerator[ChatCompletionChunk, None]
+    """
+    request.stream = True
+    if provider == "openai" or provider == "azure_openai":
+        async for chunk in openai_chat_completion_stream(request, client):
+            yield chunk
+    elif provider == "ollama":
+        async for chunk in ollama_chat_completion_stream(request, client):
+            yield chunk
+    else:
+        raise ValueError(f"Provider {provider} not supported")
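
A quick aside on how the new entry point is driven end to end. This is a hypothetical usage sketch, not code from the diff: it assumes a local Ollama server, the model name and message text are placeholders, and `UserMessage` is assumed from the `MessageT` union in types.py.

```python
import asyncio

from not_again_ai.llm.chat_completion import ChatCompletionRequest, chat_completion_stream
from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client
from not_again_ai.llm.chat_completion.types import UserMessage


async def main() -> None:
    # Streaming requires the async client (see the ollama_api.py changes below).
    client = ollama_client(async_client=True)
    request = ChatCompletionRequest(
        model="llama3.2",  # placeholder model name
        messages=[UserMessage(content="Write a haiku about diffs.")],
    )
    # chat_completion_stream sets request.stream = True before dispatching.
    async for chunk in chat_completion_stream(request, "ollama", client):
        for choice in chunk.choices:
            print(choice.delta.content, end="", flush=True)


asyncio.run(main())
```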
src/not_again_ai/llm/chat_completion/providers/ollama_api.py
@@ -1,4 +1,4 @@
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable
 import json
 import os
 import re
@@ -6,14 +6,20 @@ import time
 from typing import Any, Literal, cast
 
 from loguru import logger
-from ollama import ChatResponse, Client, ResponseError
+from ollama import AsyncClient, ChatResponse, Client, ResponseError
 
 from not_again_ai.llm.chat_completion.types import (
     AssistantMessage,
     ChatCompletionChoice,
+    ChatCompletionChoiceStream,
+    ChatCompletionChunk,
+    ChatCompletionDelta,
     ChatCompletionRequest,
     ChatCompletionResponse,
     Function,
+    PartialFunction,
+    PartialToolCall,
+    Role,
     ToolCall,
 )
 
@@ -51,14 +57,8 @@ def validate(request: ChatCompletionRequest) -> None:
         raise ValueError("`max_tokens` and `max_completion_tokens` cannot both be provided.")
 
 
-def ollama_chat_completion(
-    request: ChatCompletionRequest,
-    client: Callable[..., Any],
-) -> ChatCompletionResponse:
-    validate(request)
-
+def format_kwargs(request: ChatCompletionRequest) -> dict[str, Any]:
     kwargs = request.model_dump(mode="json", exclude_none=True)
-
     # For each key in OLLAMA_PARAMETER_MAP
     # If it is not None, set the key in kwargs to the value of the corresponding value in OLLAMA_PARAMETER_MAP
     # If it is None, remove that key from kwargs
@@ -141,6 +141,16 @@
             logger.warning("Ollama model only supports a single image per message. Using only the first images.")
         message["images"] = images
 
+    return kwargs
+
+
+def ollama_chat_completion(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    validate(request)
+    kwargs = format_kwargs(request)
+
     try:
         start_time = time.time()
         response: ChatResponse = client(**kwargs)
@@ -164,7 +174,7 @@
             tool_name = tool_call.function.name
             if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
                 errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n"
-            tool_args = tool_call.function.arguments
+            tool_args = dict(tool_call.function.arguments)
             parsed_tool_calls.append(
                 ToolCall(
                     id="",
@@ -206,7 +216,65 @@
     )
 
 
-def ollama_client(host: str | None = None, timeout: float | None = None) -> Callable[..., Any]:
+async def ollama_chat_completion_stream(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> AsyncGenerator[ChatCompletionChunk, None]:
+    validate(request)
+    kwargs = format_kwargs(request)
+
+    start_time = time.time()
+    stream = await client(**kwargs)
+
+    async for chunk in stream:
+        errors = ""
+        # Handle tool calls
+        tool_calls: list[PartialToolCall] | None = None
+        if chunk.message.tool_calls:
+            parsed_tool_calls: list[PartialToolCall] = []
+            for tool_call in chunk.message.tool_calls:
+                tool_name = tool_call.function.name
+                if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]:
+                    errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n"
+                tool_args = tool_call.function.arguments
+
+                parsed_tool_calls.append(
+                    PartialToolCall(
+                        id="",
+                        function=PartialFunction(
+                            name=tool_name,
+                            arguments=tool_args,
+                        ),
+                    )
+                )
+            tool_calls = parsed_tool_calls
+
+        current_time = time.time()
+        response_duration = round(current_time - start_time, 4)
+
+        delta = ChatCompletionDelta(
+            content=chunk.message.content or "",
+            role=Role.ASSISTANT,
+            tool_calls=tool_calls,
+        )
+        choice_obj = ChatCompletionChoiceStream(
+            delta=delta,
+            finish_reason=chunk.done_reason,
+            index=0,
+        )
+        chunk_obj = ChatCompletionChunk(
+            choices=[choice_obj],
+            errors=errors.strip(),
+            completion_tokens=chunk.get("eval_count", None),
+            prompt_tokens=chunk.get("prompt_eval_count", None),
+            response_duration=response_duration,
+        )
+        yield chunk_obj
+
+
+def ollama_client(
+    host: str | None = None, timeout: float | None = None, async_client: bool = False
+) -> Callable[..., Any]:
     """Create an Ollama client instance based on the specified host or will read from the OLLAMA_HOST environment variable.
 
     Args:
@@ -226,7 +294,7 @@ def ollama_client(host: str | None = None, timeout: float | None = None) -> Call
         host = "http://localhost:11434"
 
     def client_callable(**kwargs: Any) -> Any:
-        client = Client(host=host, timeout=timeout)
+        client = AsyncClient(host=host, timeout=timeout) if async_client else Client(host=host, timeout=timeout)
        return client.chat(**kwargs)
 
     return client_callable
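
For orientation, a small sketch of what the new `async_client` flag changes at the call site. The host shown is the documented default; this example is illustrative, not from the diff:

```python
from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client

# Returns a callable wrapping ollama.Client; suitable for chat_completion().
sync_chat = ollama_client(host="http://localhost:11434")

# Returns a callable wrapping ollama.AsyncClient; required by
# ollama_chat_completion_stream(), which awaits the callable.
async_chat = ollama_client(host="http://localhost:11434", async_client=True)
```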
src/not_again_ai/llm/chat_completion/providers/openai_api.py
@@ -1,17 +1,23 @@
-from collections.abc import Callable
+from collections.abc import AsyncGenerator, Callable, Coroutine
 import json
 import time
 from typing import Any, Literal
 
 from azure.identity import DefaultAzureCredential, get_bearer_token_provider
-from openai import AzureOpenAI, OpenAI
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
 
 from not_again_ai.llm.chat_completion.types import (
     AssistantMessage,
     ChatCompletionChoice,
+    ChatCompletionChoiceStream,
+    ChatCompletionChunk,
+    ChatCompletionDelta,
     ChatCompletionRequest,
     ChatCompletionResponse,
     Function,
+    PartialFunction,
+    PartialToolCall,
+    Role,
     ToolCall,
 )
 
@@ -36,12 +42,7 @@ def validate(request: ChatCompletionRequest) -> None:
         raise ValueError("`max_tokens` and `max_completion_tokens` cannot both be provided.")
 
 
-def openai_chat_completion(
-    request: ChatCompletionRequest,
-    client: Callable[..., Any],
-) -> ChatCompletionResponse:
-    validate(request)
-
+def format_kwargs(request: ChatCompletionRequest) -> dict[str, Any]:
     # Format the response format parameters to be compatible with OpenAI API
     if request.json_mode:
         response_format: dict[str, Any] = {"type": "json_object"}
@@ -61,7 +62,6 @@
         elif value is None and key in kwargs:
             del kwargs[key]
 
-    # Iterate over each message and
     for message in kwargs["messages"]:
         role = message.get("role", None)
         # For each ToolMessage, change the "name" field to be named "tool_call_id" instead
@@ -84,6 +84,49 @@
     if request.tool_choice is not None and request.tool_choice not in ["none", "auto", "required"]:
         kwargs["tool_choice"] = {"type": "function", "function": {"name": request.tool_choice}}
 
+    return kwargs
+
+
+def process_logprobs(logprobs_content: list[dict[str, Any]]) -> list[dict[str, Any] | list[dict[str, Any]]]:
+    """Process logprobs content from OpenAI API response.
+
+    Args:
+        logprobs_content: List of logprob entries from the API response
+
+    Returns:
+        Processed logprobs list containing either single token info or lists of top token infos
+    """
+    logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = []
+    for logprob in logprobs_content:
+        if logprob.get("top_logprobs", None):
+            curr_logprob_infos: list[dict[str, Any]] = []
+            for top_logprob in logprob.get("top_logprobs", []):
+                curr_logprob_infos.append(
+                    {
+                        "token": top_logprob.get("token", ""),
+                        "logprob": top_logprob.get("logprob", 0),
+                        "bytes": top_logprob.get("bytes", 0),
+                    }
+                )
+            logprobs_list.append(curr_logprob_infos)
+        else:
+            logprobs_list.append(
+                {
+                    "token": logprob.get("token", ""),
+                    "logprob": logprob.get("logprob", 0),
+                    "bytes": logprob.get("bytes", 0),
+                }
+            )
+    return logprobs_list
+
+
+def openai_chat_completion(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> ChatCompletionResponse:
+    validate(request)
+    kwargs = format_kwargs(request)
+
     start_time = time.time()
     response = client(**kwargs)
     end_time = time.time()
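
As an aside, the two shapes `process_logprobs` returns are easiest to see on concrete input. This is a hypothetical call with made-up values, not data from the diff:

```python
# One entry without top_logprobs (flattened to a single dict) and one with
# them (flattened to a list of dicts, one per alternative token).
logprobs_content = [
    {"token": "Hello", "logprob": -0.1, "bytes": [72, 101, 108, 108, 111]},
    {
        "token": ",",
        "logprob": -0.5,
        "top_logprobs": [
            {"token": ",", "logprob": -0.5, "bytes": [44]},
            {"token": "!", "logprob": -1.2, "bytes": [33]},
        ],
    },
]

processed = process_logprobs(logprobs_content)
# processed[0] -> {"token": "Hello", "logprob": -0.1, "bytes": [72, 101, 108, 108, 111]}
# processed[1] -> [{"token": ",", ...}, {"token": "!", ...}]
```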
@@ -133,28 +176,7 @@
         # Handle logprobs
         logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = None
         if choice.get("logprobs", None) and choice["logprobs"].get("content", None) is not None:
-            logprobs_list: list[dict[str, Any] | list[dict[str, Any]]] = []
-            for logprob in choice["logprobs"]["content"]:
-                if logprob.get("top_logprobs", None):
-                    curr_logprob_infos: list[dict[str, Any]] = []
-                    for top_logprob in logprob.get("top_logprobs", []):
-                        curr_logprob_infos.append(
-                            {
-                                "token": top_logprob.get("token", ""),
-                                "logprob": top_logprob.get("logprob", 0),
-                                "bytes": top_logprob.get("bytes", 0),
-                            }
-                        )
-                    logprobs_list.append(curr_logprob_infos)
-                else:
-                    logprobs_list.append(
-                        {
-                            "token": logprob.get("token", ""),
-                            "logprob": logprob.get("logprob", 0),
-                            "bytes": logprob.get("bytes", 0),
-                        }
-                    )
-            logprobs = logprobs_list
+            logprobs = process_logprobs(choice["logprobs"]["content"])
 
         # Handle extras that OpenAI or Azure OpenAI return
         if choice.get("content_filter_results", None):
@@ -195,6 +217,107 @@
     )
 
 
+async def openai_chat_completion_stream(
+    request: ChatCompletionRequest,
+    client: Callable[..., Any],
+) -> AsyncGenerator[ChatCompletionChunk, None]:
+    validate(request)
+    kwargs = format_kwargs(request)
+
+    start_time = time.time()
+    stream = await client(**kwargs)
+
+    async for chunk in stream:
+        errors = ""
+        # This is kind of a hack. To make this processing generic for clients that do not
+        # return the correct data structure, we convert the chunk to a dict.
+        if not isinstance(chunk, dict):
+            chunk = chunk.to_dict()
+
+        choices: list[ChatCompletionChoiceStream] = []
+        for choice in chunk["choices"]:
+            content = choice.get("delta", {}).get("content", "")
+            if not content:
+                content = ""
+
+            role = Role.ASSISTANT
+            if choice.get("delta", {}).get("role", None):
+                role = Role(choice["delta"]["role"])
+
+            # Handle tool calls
+            tool_calls: list[PartialToolCall] | None = None
+            if choice["delta"].get("tool_calls", None):
+                parsed_tool_calls: list[PartialToolCall] = []
+                for tool_call in choice["delta"]["tool_calls"]:
+                    tool_name = tool_call.get("function", {}).get("name", None)
+                    if not tool_name:
+                        tool_name = ""
+                    tool_args = tool_call.get("function", {}).get("arguments", "")
+                    if not tool_args:
+                        tool_args = ""
+
+                    tool_id = tool_call.get("id", None)
+                    parsed_tool_calls.append(
+                        PartialToolCall(
+                            id=tool_id,
+                            function=PartialFunction(
+                                name=tool_name,
+                                arguments=tool_args,
+                            ),
+                        )
+                    )
+                tool_calls = parsed_tool_calls
+
+            refusal = None
+            if choice["delta"].get("refusal", None):
+                refusal = choice["delta"]["refusal"]
+
+            delta = ChatCompletionDelta(
+                content=content,
+                role=role,
+                tool_calls=tool_calls,
+                refusal=refusal,
+            )
+
+            index = choice.get("index", 0)
+            finish_reason = choice.get("finish_reason", None)
+
+            # Handle logprobs
+            logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = None
+            if choice.get("logprobs", None) and choice["logprobs"].get("content", None) is not None:
+                logprobs = process_logprobs(choice["logprobs"]["content"])
+
+            choice_obj = ChatCompletionChoiceStream(
+                delta=delta,
+                finish_reason=finish_reason,
+                logprobs=logprobs,
+                index=index,
+            )
+            choices.append(choice_obj)
+
+        current_time = time.time()
+        response_duration = round(current_time - start_time, 4)
+
+        if "usage" in chunk and chunk["usage"] is not None:
+            completion_tokens = chunk["usage"].get("completion_tokens", None)
+            prompt_tokens = chunk["usage"].get("prompt_tokens", None)
+            system_fingerprint = chunk.get("system_fingerprint", None)
+        else:
+            completion_tokens = None
+            prompt_tokens = None
+            system_fingerprint = None
+
+        chunk_obj = ChatCompletionChunk(
+            choices=choices,
+            errors=errors.strip(),
+            completion_tokens=completion_tokens,
+            prompt_tokens=prompt_tokens,
+            response_duration=response_duration,
+            system_fingerprint=system_fingerprint,
+        )
+        yield chunk_obj
+
+
 def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_args: Any) -> Callable[..., Any]:
     """Creates a callable that instantiates and uses an OpenAI client.
 
@@ -215,6 +338,20 @@
     return client_callable
 
 
+def create_client_callable_stream(
+    client_class: type[AsyncOpenAI | AsyncAzureOpenAI], **client_args: Any
+) -> Callable[..., Any]:
+    filtered_args = {k: v for k, v in client_args.items() if v is not None}
+
+    def client_callable(**kwargs: Any) -> Coroutine[Any, Any, Any]:
+        client = client_class(**filtered_args)
+        kwargs["stream_options"] = {"include_usage": True}
+        stream = client.chat.completions.create(**kwargs)
+        return stream
+
+    return client_callable
+
+
 class InvalidOAIAPITypeError(Exception):
     """Raised when an invalid OAIAPIType string is provided."""
 
@@ -227,6 +364,7 @@ def openai_client(
     azure_endpoint: str | None = None,
     timeout: float | None = None,
     max_retries: int | None = None,
+    async_client: bool = False,
 ) -> Callable[..., Any]:
     """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters.
 
@@ -247,11 +385,11 @@
         max_retries (int, optional): Certain errors are automatically retried 2 times by default,
             with a short exponential backoff. Connection errors (for example, due to a network connectivity problem),
             408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default.
+        async_client (bool, optional): Whether to return an async client. Defaults to False.
 
     Returns:
         Callable[..., Any]: A callable that creates a client and returns completion results
-
     Raises:
         InvalidOAIAPITypeError: If an invalid API type string is provided.
         NotImplementedError: If the specified API type is recognized but not yet supported (e.g., 'azure_openai').
@@ -260,17 +398,21 @@
         raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.")
 
     if api_type == "openai":
-        return create_client_callable(
-            OpenAI,
+        client_class = AsyncOpenAI if async_client else OpenAI
+        callable_creator = create_client_callable_stream if async_client else create_client_callable
+        return callable_creator(
+            client_class,  # type: ignore
             api_key=api_key,
             organization=organization,
             timeout=timeout,
             max_retries=max_retries,
         )
     elif api_type == "azure_openai":
+        azure_client_class = AsyncAzureOpenAI if async_client else AzureOpenAI
+        callable_creator = create_client_callable_stream if async_client else create_client_callable
         if api_key:
-            return create_client_callable(
-                AzureOpenAI,
+            return callable_creator(
+                azure_client_class,  # type: ignore
                 api_version=aoai_api_version,
                 azure_endpoint=azure_endpoint,
                 api_key=api_key,
@@ -282,8 +424,8 @@
             ad_token_provider = get_bearer_token_provider(
                 azure_credential, "https://cognitiveservices.azure.com/.default"
             )
-            return create_client_callable(
-                AzureOpenAI,
+            return callable_creator(
+                azure_client_class,  # type: ignore
                 api_version=aoai_api_version,
                 azure_endpoint=azure_endpoint,
                 azure_ad_token_provider=ad_token_provider,
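
To tie the OpenAI-side changes together, here is a hypothetical streaming round trip, not code from the diff. Assumptions: `OPENAI_API_KEY` is set in the environment (the OpenAI SDK reads it when no key is passed), the model name is a placeholder, and `UserMessage` comes from the package's types module.

```python
import asyncio

from not_again_ai.llm.chat_completion import ChatCompletionRequest, chat_completion_stream
from not_again_ai.llm.chat_completion.providers.openai_api import openai_client
from not_again_ai.llm.chat_completion.types import UserMessage


async def main() -> None:
    # async_client=True selects AsyncOpenAI and create_client_callable_stream,
    # which also sets stream_options={"include_usage": True} so the final
    # chunk reports token counts.
    client = openai_client(api_type="openai", async_client=True)
    request = ChatCompletionRequest(
        model="gpt-4o-mini",  # placeholder model name
        messages=[UserMessage(content="Stream one short sentence.")],
    )
    async for chunk in chat_completion_stream(request, "openai", client):
        if chunk.choices:
            print(chunk.choices[0].delta.content, end="", flush=True)
        if chunk.completion_tokens is not None:
            print(f"\n[usage] completion_tokens={chunk.completion_tokens}")


asyncio.run(main())
```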
src/not_again_ai/llm/chat_completion/types.py
@@ -52,12 +52,23 @@ class Function(BaseModel):
     arguments: dict[str, Any]
 
 
+class PartialFunction(BaseModel):
+    name: str
+    arguments: str | dict[str, Any]
+
+
 class ToolCall(BaseModel):
     id: str
     function: Function
     type: Literal["function"] = "function"
 
 
+class PartialToolCall(BaseModel):
+    id: str | None
+    function: PartialFunction
+    type: Literal["function"] = "function"
+
+
 class DeveloperMessage(BaseMessage[str]):
     role: Literal[Role.DEVELOPER] = Role.DEVELOPER
 
@@ -87,6 +98,7 @@ MessageT = AssistantMessage | DeveloperMessage | SystemMessage | ToolMessage | U
 class ChatCompletionRequest(BaseModel):
     messages: list[MessageT]
     model: str
+    stream: bool = Field(default=False)
 
     max_completion_tokens: int | None = Field(default=None)
     context_window: int | None = Field(default=None)
@@ -148,3 +160,35 @@
     system_fingerprint: str | None = Field(default=None)
 
     extras: Any | None = Field(default=None)
+
+
+class ChatCompletionDelta(BaseModel):
+    content: str
+    role: Role = Field(default=Role.ASSISTANT)
+
+    tool_calls: list[PartialToolCall] | None = Field(default=None)
+
+    refusal: str | None = Field(default=None)
+
+
+class ChatCompletionChoiceStream(BaseModel):
+    delta: ChatCompletionDelta
+    index: int
+    finish_reason: Literal["stop", "length", "tool_calls", "content_filter"] | None
+
+    logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = Field(default=None)
+
+    extras: Any | None = Field(default=None)
+
+
+class ChatCompletionChunk(BaseModel):
+    choices: list[ChatCompletionChoiceStream]
+
+    errors: str = Field(default="")
+
+    completion_tokens: int | None = Field(default=None)
+    prompt_tokens: int | None = Field(default=None)
+    response_duration: float | None = Field(default=None)
+
+    system_fingerprint: str | None = Field(default=None)
+    extras: Any | None = Field(default=None)
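
One design note worth surfacing: `PartialFunction.arguments` is `str | dict[str, Any]` because the two providers stream tool-call arguments differently; the OpenAI path emits incremental string fragments while the Ollama path yields already-parsed dicts. A minimal client-side accumulation sketch under that assumption (the helper name is hypothetical):

```python
from not_again_ai.llm.chat_completion.types import ChatCompletionChunk


def accumulate_tool_args(chunks: list[ChatCompletionChunk]) -> str:
    """Concatenate streamed string fragments of tool-call arguments (OpenAI-style).

    Dict-valued arguments (Ollama-style) arrive whole in a single chunk,
    so they are skipped here.
    """
    parts: list[str] = []
    for chunk in chunks:
        for choice in chunk.choices:
            for tool_call in choice.delta.tool_calls or []:
                if isinstance(tool_call.function.arguments, str):
                    parts.append(tool_call.function.arguments)
    return "".join(parts)
```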
not_again_ai-0.16.1/src/not_again_ai/llm/chat_completion/interface.py (removed)
@@ -1,32 +0,0 @@
-from collections.abc import Callable
-from typing import Any
-
-from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion
-from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion
-from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, ChatCompletionResponse
-
-
-def chat_completion(
-    request: ChatCompletionRequest,
-    provider: str,
-    client: Callable[..., Any],
-) -> ChatCompletionResponse:
-    """Get a chat completion response from the given provider. Currently supported providers:
-    - `openai` - OpenAI
-    - `azure_openai` - Azure OpenAI
-    - `ollama` - Ollama
-
-    Args:
-        request: Request parameter object
-        provider: The supported provider name
-        client: Client information, see the provider's implementation for what can be provided
-
-    Returns:
-        ChatCompletionResponse: The chat completion response.
-    """
-    if provider == "openai" or provider == "azure_openai":
-        return openai_chat_completion(request, client)
-    elif provider == "ollama":
-        return ollama_chat_completion(request, client)
-    else:
-        raise ValueError(f"Provider {provider} not supported")