kosong 0.13.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kosong-0.13.0/PKG-INFO +14 -0
- kosong-0.13.0/README.md +3 -0
- kosong-0.13.0/pyproject.toml +40 -0
- kosong-0.13.0/src/kosong/__init__.py +100 -0
- kosong-0.13.0/src/kosong/base/__init__.py +66 -0
- kosong-0.13.0/src/kosong/base/chat_provider.py +56 -0
- kosong-0.13.0/src/kosong/base/message.py +212 -0
- kosong-0.13.0/src/kosong/base/tool.py +22 -0
- kosong-0.13.0/src/kosong/chat_provider/__init__.py +53 -0
- kosong-0.13.0/src/kosong/chat_provider/chaos.py +106 -0
- kosong-0.13.0/src/kosong/chat_provider/kimi.py +153 -0
- kosong-0.13.0/src/kosong/chat_provider/mock.py +52 -0
- kosong-0.13.0/src/kosong/chat_provider/openai_legacy.py +233 -0
- kosong-0.13.0/src/kosong/context/__init__.py +6 -0
- kosong-0.13.0/src/kosong/context/linear.py +145 -0
- kosong-0.13.0/src/kosong/py.typed +0 -0
- kosong-0.13.0/src/kosong/tooling/__init__.py +193 -0
- kosong-0.13.0/src/kosong/tooling/empty.py +13 -0
- kosong-0.13.0/src/kosong/tooling/error.py +41 -0
- kosong-0.13.0/src/kosong/tooling/simple.py +72 -0
- kosong-0.13.0/src/kosong/utils/__init__.py +0 -0
- kosong-0.13.0/src/kosong/utils/aio.py +13 -0
- kosong-0.13.0/src/kosong/utils/typing.py +3 -0
kosong-0.13.0/PKG-INFO
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: kosong
|
|
3
|
+
Version: 0.13.0
|
|
4
|
+
Summary: The building blocks of AI agent.
|
|
5
|
+
Requires-Dist: dotenv>=0.9.9
|
|
6
|
+
Requires-Dist: jsonschema>=4.25.1
|
|
7
|
+
Requires-Dist: openai>=1.109.1
|
|
8
|
+
Requires-Dist: pydantic>=2.11.9
|
|
9
|
+
Requires-Python: >=3.13
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
12
|
+
# Kosong
|
|
13
|
+
|
|
14
|
+
> Kosong is the emptiness.
|
kosong-0.13.0/pyproject.toml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "kosong"
|
|
3
|
+
version = "0.13.0"
|
|
4
|
+
description = "The building blocks of AI agent."
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.13"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"dotenv>=0.9.9",
|
|
9
|
+
"jsonschema>=4.25.1",
|
|
10
|
+
"openai>=1.109.1",
|
|
11
|
+
"pydantic>=2.11.9",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[dependency-groups]
|
|
15
|
+
dev = [
|
|
16
|
+
"pyright>=1.1.405",
|
|
17
|
+
"pytest>=8.4.2",
|
|
18
|
+
"pytest-asyncio>=1.2.0",
|
|
19
|
+
"ruff>=0.13.1",
|
|
20
|
+
]
|
|
21
|
+
|
|
22
|
+
[build-system]
|
|
23
|
+
requires = ["uv_build>=0.8.5,<0.9.0"]
|
|
24
|
+
build-backend = "uv_build"
|
|
25
|
+
|
|
26
|
+
[tool.uv.build-backend]
|
|
27
|
+
module-name = ["kosong"]
|
|
28
|
+
|
|
29
|
+
[tool.ruff]
|
|
30
|
+
line-length = 100
|
|
31
|
+
|
|
32
|
+
[tool.ruff.lint]
|
|
33
|
+
select = [
|
|
34
|
+
"E", # pycodestyle
|
|
35
|
+
"F", # Pyflakes
|
|
36
|
+
"UP", # pyupgrade
|
|
37
|
+
"B", # flake8-bugbear
|
|
38
|
+
"SIM", # flake8-simplify
|
|
39
|
+
"I", # isort
|
|
40
|
+
]
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from collections.abc import Callable, Sequence
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
|
|
5
|
+
from kosong.base import generate
|
|
6
|
+
from kosong.base.chat_provider import ChatProvider, StreamedMessagePart, TokenUsage
|
|
7
|
+
from kosong.base.message import Message, ToolCall
|
|
8
|
+
from kosong.chat_provider import ChatProviderError
|
|
9
|
+
from kosong.tooling import ToolResult, ToolResultFuture, Toolset
|
|
10
|
+
from kosong.utils.aio import Callback
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
async def step(
    chat_provider: ChatProvider,
    system_prompt: str,
    toolset: Toolset,
    history: Sequence[Message],
    *,
    on_message_part: Callback[[StreamedMessagePart], None] | None = None,
    on_tool_result: Callable[[ToolResult], None] | None = None,
) -> "StepResult":
    """
    Run one "step". In one step, the function generates LLM response based on the given context for
    exactly one time. All new message parts will be streamed to `on_message_part` in real-time if
    provided. Tool calls will be handled by `context.toolset`. The combined message will be returned
    in a `StepResult`. Depending on the toolset implementation, the tool calls may be handled
    asynchronously and the results need to be fetched by `await step_result.tool_results()`.

    The context will NOT be modified in this function.

    The token usage will be returned in the `StepResult` if available.
    """

    # Accumulated across the single `generate` call below.
    tool_calls: list[ToolCall] = []
    tool_result_futures: dict[str, ToolResultFuture] = {}

    def future_done_callback(future: ToolResultFuture):
        # Runs on the event loop once a tool's future settles; forwards the
        # result to the caller's `on_tool_result` hook if one was given.
        if on_tool_result:
            try:
                result = future.result()
                on_tool_result(result)
            except asyncio.CancelledError:
                # Cancelled tools are intentionally not reported.
                # NOTE(review): exceptions other than CancelledError stored in the
                # future would propagate out of this callback — confirm intended.
                return

    async def on_tool_call(tool_call: ToolCall):
        # Invoked by `generate` for every complete tool call in the stream.
        tool_calls.append(tool_call)
        result = toolset.handle(tool_call)

        if isinstance(result, ToolResult):
            # Synchronous result: wrap it in an already-resolved future so both
            # paths expose the same future-based API.
            future = ToolResultFuture()
            future.add_done_callback(future_done_callback)
            future.set_result(result)
            tool_result_futures[tool_call.id] = future
        else:
            # Asynchronous result: the toolset already returned a future.
            result.add_done_callback(future_done_callback)
            tool_result_futures[tool_call.id] = result

    try:
        message, usage = await generate(
            chat_provider,
            system_prompt,
            toolset.tools,
            history,
            on_message_part=on_message_part,
            on_tool_call=on_tool_call,
        )
    except (ChatProviderError, asyncio.CancelledError):
        # cancel all the futures to avoid hanging tasks
        for future in tool_result_futures.values():
            # Detach the callback first so cancellation is not reported through
            # `on_tool_result`.
            future.remove_done_callback(future_done_callback)
            future.cancel()
        raise

    return StepResult(message, usage, tool_calls, tool_result_futures)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@dataclass(frozen=True)
class StepResult:
    """Immutable outcome of a single `step` call."""

    message: Message
    """The combined message generated in this step."""

    usage: TokenUsage | None
    """The token usage of the generated message."""

    tool_calls: list[ToolCall]
    """All the tool calls generated in this step."""

    tool_result_futures: dict[str, ToolResultFuture]
    """The futures of the results of the spawned tool calls."""

    async def tool_results(self) -> list[ToolResult]:
        """All the tool results returned by corresponding tool calls."""
        if not self.tool_result_futures:
            return []

        # Await each future in the order the tool calls were issued.
        return [await self.tool_result_futures[call.id] for call in self.tool_calls]
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
from kosong.base.chat_provider import ChatProvider, StreamedMessagePart, TokenUsage
|
|
4
|
+
from kosong.base.message import ContentPart, Message, TextPart, ToolCall
|
|
5
|
+
from kosong.base.tool import Tool
|
|
6
|
+
from kosong.utils.aio import Callback, callback
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
async def generate(
    chat_provider: ChatProvider,
    system_prompt: str,
    tools: Sequence[Tool],
    history: Sequence[Message],
    *,
    on_message_part: Callback[[StreamedMessagePart], None] | None = None,
    on_tool_call: Callback[[ToolCall], None] | None = None,
) -> tuple[Message, TokenUsage | None]:
    """
    Generate one message based on the given context. The given context will remain untouched.

    Parts of the message will be streamed to the given handlers:
    - `on_message_part` will be called for each raw part which may be incomplete.
    - `on_tool_call` will be called for each complete tool call.

    The generated message and the token usage will be returned. All parts in the message are
    guaranteed to be complete and merged as much as possible.
    """
    message = Message(role="assistant", content=[])
    pending_part: StreamedMessagePart | None = None  # message part that is currently incomplete

    async def flush(part: StreamedMessagePart) -> None:
        # Commit a completed part to the message and announce it if it is a tool call.
        _message_append(message, part)
        if isinstance(part, ToolCall) and on_tool_call:
            await callback(on_tool_call, part)

    stream = await chat_provider.generate(system_prompt, tools, history)
    async for part in stream:
        if on_message_part:
            # Hand out a deep copy so the handler cannot mutate our working state.
            await callback(on_message_part, part.model_copy(deep=True))

        if pending_part is None:
            pending_part = part
        elif not pending_part.merge_in_place(part):  # try merge into the pending part
            # An unmergeable part means the pending part is complete: commit it
            # and start accumulating the new one.
            await flush(pending_part)
            pending_part = part

    # end of message: commit the final pending part, if any
    if pending_part is not None:
        await flush(pending_part)

    return message, stream.usage
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _message_append(message: Message, part: StreamedMessagePart) -> None:
    """Append a completed part to `message`, normalizing string content first."""
    if isinstance(part, ContentPart):
        if isinstance(message.content, str):
            # Promote plain-string content to a part list before appending.
            message.content = [TextPart(text=message.content)]
        message.content.append(part)
    elif isinstance(part, ToolCall):
        if message.tool_calls is None:
            message.tool_calls = []
        message.tool_calls.append(part)
    # Anything else (e.g. an orphaned `ToolCallPart`) is dropped on purpose.
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from collections.abc import AsyncIterator, Sequence
|
|
2
|
+
from typing import NamedTuple, Protocol, runtime_checkable
|
|
3
|
+
|
|
4
|
+
from kosong.base.message import ContentPart, Message, ToolCall, ToolCallPart
|
|
5
|
+
from kosong.base.tool import Tool
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@runtime_checkable
class ChatProvider(Protocol):
    """Structural interface for an LLM backend that can generate one streamed message."""

    name: str
    """
    The name of the chat provider.
    """

    @property
    def model_name(self) -> str:
        """
        The model name to use for the chat provider.
        """
        ...

    async def generate(
        self,
        system_prompt: str,
        tools: Sequence[Tool],
        history: Sequence[Message],
    ) -> "StreamedMessage":
        """
        Generate a new message based on the given system prompt, tools, and history.
        """
        ...
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# A raw streamed chunk: a content part, a complete tool call, or a tool-call fragment.
type StreamedMessagePart = ContentPart | ToolCall | ToolCallPart
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@runtime_checkable
class StreamedMessage(Protocol):
    """Structural interface for an in-flight assistant message stream."""

    def __aiter__(self) -> AsyncIterator[StreamedMessagePart]:
        """Create an async iterator from the stream."""
        ...

    @property
    def usage(self) -> "TokenUsage | None":
        """The usage of the streamed message."""
        ...
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class TokenUsage(NamedTuple):
    """Token counts recorded for one generation."""

    input: int
    output: int
    # TODO: support `cached`

    @property
    def total(self) -> int:
        """Combined input and output token count."""
        return self.output + self.input
|
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from typing import Any, ClassVar, Literal, override
|
|
3
|
+
|
|
4
|
+
from pydantic import BaseModel, GetCoreSchemaHandler, field_serializer
|
|
5
|
+
from pydantic_core import core_schema
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MergableMixin:
    """Opt-in hook for parts that can absorb a subsequently streamed part."""

    def merge_in_place(self, other: Any) -> bool:
        """Merge the other part into the current part. Return True if the merge is successful."""
        # The base implementation never merges; subclasses override this.
        return False
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ContentPart(BaseModel, ABC, MergableMixin):
    """A part of a message content.

    Subclasses must declare a `type: str` field with a string default; that value is
    used as a discriminator to dispatch validation of plain dicts to the right subclass.
    """

    # Maps each subclass's `type` default to the subclass itself.
    __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}

    type: str
    ...  # to be added by subclasses

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)

        invalid_subclass_error_msg = (
            f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
        )

        if not hasattr(cls, "type"):
            raise ValueError(invalid_subclass_error_msg)

        type_value = cls.type
        if not isinstance(type_value, str):
            raise ValueError(invalid_subclass_error_msg)

        # Register the subclass under its discriminator value.
        cls.__content_part_registry[type_value] = cls

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # If we're dealing with the base ContentPart class, use custom validation
        if cls.__name__ == "ContentPart":

            def validate_content_part(value: Any) -> Any:
                # if it's already an instance of a ContentPart subclass, return it
                if hasattr(value, "__class__") and issubclass(value.__class__, cls):
                    return value

                # if it's a dict with a type field, dispatch to the appropriate subclass
                if isinstance(value, dict) and "type" in value:
                    type_value = value["type"]
                    if not isinstance(type_value, str):
                        raise ValueError(f"Cannot validate {value} as ContentPart")
                    target_class = cls.__content_part_registry.get(type_value)
                    if target_class is None:
                        # An unknown discriminator used to surface as a bare KeyError;
                        # raise ValueError so pydantic reports it as a validation
                        # error, consistent with the other failure modes here.
                        raise ValueError(f"Cannot validate {value} as ContentPart")
                    return target_class.model_validate(value)

                raise ValueError(f"Cannot validate {value} as ContentPart")

            return core_schema.no_info_plain_validator_function(validate_content_part)

        # for subclasses, use the default schema
        return handler(source_type)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class TextPart(ContentPart):
    """
    A plain-text chunk of message content.

    >>> TextPart(text="Hello, world!").model_dump()
    {'type': 'text', 'text': 'Hello, world!'}
    """

    type: str = "text"
    text: str

    @override
    def merge_in_place(self, other) -> bool:
        if isinstance(other, TextPart):
            # Concatenate consecutive streamed text fragments.
            self.text = self.text + other.text
            return True
        return False
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class ThinkPart(ContentPart):
    """
    A chunk of the model's reasoning text.

    >>> ThinkPart(think="I think I need to think about this.").model_dump()
    {'type': 'think', 'think': 'I think I need to think about this.'}
    """

    type: str = "think"
    think: str

    @override
    def merge_in_place(self, other) -> bool:
        if isinstance(other, ThinkPart):
            # Concatenate consecutive streamed reasoning fragments.
            self.think = self.think + other.think
            return True
        return False
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class ImageURLPart(ContentPart):
    """
    An image referenced by URL inside message content.

    >>> ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png")).model_dump()
    {'type': 'image_url', 'image_url': {'url': 'https://example.com/image.png', 'id': None}}
    """

    class ImageURL(BaseModel):
        url: str
        """The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
        id: str | None = None
        """The ID of the image, to allow LLMs to distinguish different images."""

    type: str = "image_url"
    image_url: ImageURL
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class AudioURLPart(ContentPart):
    """
    An audio clip referenced by URL inside message content.

    >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
    {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
    """

    class AudioURL(BaseModel):
        url: str
        """The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
        id: str | None = None
        """The ID of the audio, to allow LLMs to distinguish different audios."""

    type: str = "audio_url"
    audio_url: AudioURL
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class ToolCall(BaseModel, MergableMixin):
    """
    A tool call requested by the assistant.

    >>> ToolCall(
    ...     id="123",
    ...     function=ToolCall.FunctionBody(
    ...         name="function",
    ...         arguments="{}"
    ...     ),
    ... ).model_dump()
    {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
    """

    class FunctionBody(BaseModel):
        name: str
        arguments: str | None

    type: Literal["function"] = "function"

    id: str
    """The ID of the tool call."""
    function: FunctionBody
    """The function body of the tool call."""

    @override
    def merge_in_place(self, other) -> bool:
        # Only `ToolCallPart` argument fragments merge into a tool call;
        # another complete `ToolCall` never merges.
        if not isinstance(other, ToolCallPart):
            return False
        if self.function.arguments is None:
            # First fragment is taken as-is (it may itself be None).
            self.function.arguments = other.arguments_part
        else:
            self.function.arguments += other.arguments_part or ""
        return True
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
class ToolCallPart(BaseModel, MergableMixin):
    """A part of the tool call."""

    arguments_part: str | None = None
    """A part of the arguments of the tool call."""

    @override
    def merge_in_place(self, other) -> bool:
        if not isinstance(other, ToolCallPart):
            return False
        incoming = other.arguments_part
        if self.arguments_part is None:
            # The first fragment is taken as-is (it may itself be None).
            self.arguments_part = incoming
        else:
            self.arguments_part += incoming or ""
        return True
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class Message(BaseModel):
    """A message in a conversation."""

    role: Literal[
        "system",
        "developer",
        "user",
        "assistant",
        "tool",
    ]
    # Optional name of the message author; semantics depend on the provider —
    # presumably maps to the OpenAI-style `name` field; confirm against providers.
    name: str | None = None

    content: str | list[ContentPart]
    """The content of the message."""

    tool_calls: list[ToolCall] | None = None
    """In assistant messages, there can be tool calls."""

    tool_call_id: str | None = None
    """In tool messages, there can be a tool call ID."""

    # NOTE(review): presumably marks a partial (still-streaming) assistant
    # message — confirm against the chat providers that read it.
    partial: bool | None = None

    @field_serializer("content")
    def serialize_content(self, content: str | list[ContentPart]) -> str | list[dict]:
        # Dump each part explicitly so list content serializes to plain dicts.
        if isinstance(content, str):
            return content
        return [part.model_dump() for part in content]
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from typing import Any, Self
|
|
2
|
+
|
|
3
|
+
import jsonschema
|
|
4
|
+
from pydantic import BaseModel, model_validator
|
|
5
|
+
|
|
6
|
+
type ParametersType = dict[str, Any]
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Tool(BaseModel):
    """A tool definition exposed to the LLM."""

    name: str
    """The name of the tool."""

    description: str
    """The description of the tool."""

    parameters: ParametersType
    """The parameters of the tool, in JSON Schema format."""

    @model_validator(mode="after")
    def validate_parameters(self) -> Self:
        # Reject parameters that are not themselves a valid JSON Schema by
        # validating against the Draft 2020-12 meta-schema.
        jsonschema.validate(self.parameters, jsonschema.Draft202012Validator.META_SCHEMA)
        return self
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
from kosong.base.chat_provider import ChatProvider
|
|
2
|
+
|
|
3
|
+
__all__ = [
|
|
4
|
+
"OpenAILegacy",
|
|
5
|
+
"Kimi",
|
|
6
|
+
# for testing
|
|
7
|
+
"MockChatProvider",
|
|
8
|
+
"ChaosChatProvider",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def __static_check_types(
    openai: "OpenAILegacy",
    kimi: "Kimi",
    mock: "MockChatProvider",
    chaos: "ChaosChatProvider",
):
    """Use type checking to ensure the types are correctly implemented."""
    # These assignments exist purely for the static type checker: each one fails
    # type checking if the provider class does not structurally satisfy the
    # `ChatProvider` protocol. The function is never called at runtime.
    _: ChatProvider = openai
    _: ChatProvider = mock
    _: ChatProvider = kimi
    _: ChatProvider = chaos
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ChatProviderError(Exception):
    """Base class for errors raised by a chat provider."""

    def __init__(self, message: str) -> None:
        super().__init__(message)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class APIConnectionError(ChatProviderError):
    """The error raised when the API connection fails."""
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class APITimeoutError(ChatProviderError):
    """The error raised when the API request times out."""
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class APIStatusError(ChatProviderError):
    """The error raised when the API returns a status code of 4xx or 5xx."""

    # HTTP status code reported by the provider.
    status_code: int

    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(message)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
from .chaos import ChaosChatProvider # noqa: E402
|
|
51
|
+
from .kimi import Kimi # noqa: E402
|
|
52
|
+
from .mock import MockChatProvider # noqa: E402
|
|
53
|
+
from .openai_legacy import OpenAILegacy # noqa: E402
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import random
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
|
|
8
|
+
from kosong.chat_provider.openai_legacy import OpenAILegacy
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ChaosConfig(BaseModel):
    """Configuration for chaos provider."""

    # Probability in [0, 1] that any given request is replaced with an error.
    error_probability: float = 0.3
    # HTTP status codes to pick from (pydantic copies mutable defaults per instance).
    error_types: list[int] = [429, 500, 502, 503]
    # Seconds advertised in the `retry-after` header of injected 429 responses.
    retry_after: int = 2
    # Optional RNG seed for reproducible chaos runs.
    seed: int | None = None

    @classmethod
    def from_env(cls) -> "ChaosConfig":
        """Create config from environment variables (CHAOS_* prefixed)."""
        seed_str = os.getenv("CHAOS_SEED")
        return cls(
            error_probability=float(os.getenv("CHAOS_ERROR_PROBABILITY", "0.3")),
            error_types=[
                int(x.strip()) for x in os.getenv("CHAOS_ERROR_TYPES", "429,500,502,503").split(",")
            ],
            retry_after=int(os.getenv("CHAOS_RETRY_AFTER", "2")),
            seed=int(seed_str) if seed_str else None,
        )
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class ChaosTransport(httpx.AsyncBaseTransport):
    """HTTP transport that randomly injects errors.

    Wraps another async transport; with probability `config.error_probability` a
    request is answered with a synthetic error response instead of being forwarded.
    """

    def __init__(self, wrapped_transport: httpx.AsyncBaseTransport, config: ChaosConfig):
        self._wrapped = wrapped_transport
        self._config = config
        # Use a private RNG instance: calling `random.seed()` here would mutate
        # the module-level generator shared by the whole process. `Random(None)`
        # self-seeds, matching the unseeded behavior when no seed is configured.
        self._rng = random.Random(config.seed)

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        if self._should_inject_error():
            error_code = self._rng.choice(self._config.error_types)
            return self._create_error_response(request, error_code)

        # No chaos this time: forward to the real transport.
        return await self._wrapped.handle_async_request(request)

    def _should_inject_error(self) -> bool:
        # One roll per request.
        return self._rng.random() < self._config.error_probability

    def _create_error_response(self, request: httpx.Request, status_code: int) -> httpx.Response:
        """Build a synthetic JSON error response for the given status code."""
        error_messages = {
            429: {"error": {"code": "rate_limit_exceeded", "message": "Rate limit exceeded"}},
            500: {"error": {"code": "internal_error", "message": "Internal server error"}},
            502: {"error": {"code": "bad_gateway", "message": "Bad gateway"}},
            503: {
                "error": {
                    "code": "service_unavailable",
                    "message": "Service temporarily unavailable",
                }
            },
        }

        content = json.dumps(
            error_messages.get(status_code, {"error": {"message": "Unknown error"}})
        )
        headers = {"content-type": "application/json"}

        if status_code == 429:
            # Mimic real rate-limit responses so client back-off logic triggers.
            headers["retry-after"] = str(self._config.retry_after)

        return httpx.Response(
            status_code=status_code,
            headers=headers,
            content=content.encode(),
            request=request,
        )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class ChaosChatProvider(OpenAILegacy):
    """OpenAI Legacy provider with chaos error injection."""

    def __init__(
        self,
        model: str,
        api_key: str | None = None,
        base_url: str | None = None,
        chaos_config: ChaosConfig | None = None,
        **client_kwargs,
    ):
        super().__init__(model=model, api_key=api_key, base_url=base_url, **client_kwargs)
        # Fall back to environment-driven configuration when none is given.
        self._chaos_config = chaos_config or ChaosConfig.from_env()
        self._monkey_patch_client()

    def _monkey_patch_client(self):
        """Inject chaos transport into the client."""
        # HACK: reaches into private attributes of the underlying SDK client
        # (`_client._client._transport`) — may break across SDK versions.
        original_transport = self._client._client._transport
        chaos_transport = ChaosTransport(original_transport, self._chaos_config)
        self._client._client._transport = chaos_transport

    @property
    def model_name(self) -> str:
        # Advertise the chaos wrapper in the model name only when errors can
        # actually be injected.
        if self._chaos_config.error_probability > 0:
            return f"chaos({super().model_name})"
        return super().model_name