kosong 0.13.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kosong-0.13.0/PKG-INFO ADDED
@@ -0,0 +1,14 @@
+ Metadata-Version: 2.3
+ Name: kosong
+ Version: 0.13.0
+ Summary: The building blocks of AI agent.
+ Requires-Dist: dotenv>=0.9.9
+ Requires-Dist: jsonschema>=4.25.1
+ Requires-Dist: openai>=1.109.1
+ Requires-Dist: pydantic>=2.11.9
+ Requires-Python: >=3.13
+ Description-Content-Type: text/markdown
+
+ # Kosong
+
+ > Kosong is the emptiness.
@@ -0,0 +1,3 @@
+ # Kosong
+
+ > Kosong is the emptiness.
@@ -0,0 +1,40 @@
+ [project]
+ name = "kosong"
+ version = "0.13.0"
+ description = "The building blocks of AI agent."
+ readme = "README.md"
+ requires-python = ">=3.13"
+ dependencies = [
+     "dotenv>=0.9.9",
+     "jsonschema>=4.25.1",
+     "openai>=1.109.1",
+     "pydantic>=2.11.9",
+ ]
+
+ [dependency-groups]
+ dev = [
+     "pyright>=1.1.405",
+     "pytest>=8.4.2",
+     "pytest-asyncio>=1.2.0",
+     "ruff>=0.13.1",
+ ]
+
+ [build-system]
+ requires = ["uv_build>=0.8.5,<0.9.0"]
+ build-backend = "uv_build"
+
+ [tool.uv.build-backend]
+ module-name = ["kosong"]
+
+ [tool.ruff]
+ line-length = 100
+
+ [tool.ruff.lint]
+ select = [
+     "E",   # pycodestyle
+     "F",   # Pyflakes
+     "UP",  # pyupgrade
+     "B",   # flake8-bugbear
+     "SIM", # flake8-simplify
+     "I",   # isort
+ ]
@@ -0,0 +1,100 @@
+ import asyncio
+ from collections.abc import Callable, Sequence
+ from dataclasses import dataclass
+
+ from kosong.base import generate
+ from kosong.base.chat_provider import ChatProvider, StreamedMessagePart, TokenUsage
+ from kosong.base.message import Message, ToolCall
+ from kosong.chat_provider import ChatProviderError
+ from kosong.tooling import ToolResult, ToolResultFuture, Toolset
+ from kosong.utils.aio import Callback
+
+
+ async def step(
+     chat_provider: ChatProvider,
+     system_prompt: str,
+     toolset: Toolset,
+     history: Sequence[Message],
+     *,
+     on_message_part: Callback[[StreamedMessagePart], None] | None = None,
+     on_tool_result: Callable[[ToolResult], None] | None = None,
+ ) -> "StepResult":
+     """
+     Run one "step". In one step, the function generates an LLM response based on the given
+     context exactly once. All new message parts will be streamed to `on_message_part` in real
+     time if provided. Tool calls will be handled by `toolset`. The combined message will be
+     returned in a `StepResult`. Depending on the toolset implementation, the tool calls may be
+     handled asynchronously and the results need to be fetched by `await step_result.tool_results()`.
+
+     The context will NOT be modified in this function.
+
+     The token usage will be returned in the `StepResult` if available.
+     """
+
+     tool_calls: list[ToolCall] = []
+     tool_result_futures: dict[str, ToolResultFuture] = {}
+
+     def future_done_callback(future: ToolResultFuture):
+         if on_tool_result:
+             try:
+                 result = future.result()
+                 on_tool_result(result)
+             except asyncio.CancelledError:
+                 return
+
+     async def on_tool_call(tool_call: ToolCall):
+         tool_calls.append(tool_call)
+         result = toolset.handle(tool_call)
+
+         if isinstance(result, ToolResult):
+             future = ToolResultFuture()
+             future.add_done_callback(future_done_callback)
+             future.set_result(result)
+             tool_result_futures[tool_call.id] = future
+         else:
+             result.add_done_callback(future_done_callback)
+             tool_result_futures[tool_call.id] = result
+
+     try:
+         message, usage = await generate(
+             chat_provider,
+             system_prompt,
+             toolset.tools,
+             history,
+             on_message_part=on_message_part,
+             on_tool_call=on_tool_call,
+         )
+     except (ChatProviderError, asyncio.CancelledError):
+         # cancel all the futures to avoid hanging tasks
+         for future in tool_result_futures.values():
+             future.remove_done_callback(future_done_callback)
+             future.cancel()
+         raise
+
+     return StepResult(message, usage, tool_calls, tool_result_futures)
+
+
+ @dataclass(frozen=True)
+ class StepResult:
+     message: Message
+     """The combined message generated in this step."""
+
+     usage: TokenUsage | None
+     """The token usage of the generated message."""
+
+     tool_calls: list[ToolCall]
+     """All the tool calls generated in this step."""
+
+     tool_result_futures: dict[str, ToolResultFuture]
+     """The futures of the results of the spawned tool calls."""
+
+     async def tool_results(self) -> list[ToolResult]:
+         """All the tool results returned by corresponding tool calls."""
+         if not self.tool_result_futures:
+             return []
+
+         results: list[ToolResult] = []
+         for tool_call in self.tool_calls:
+             result = await self.tool_result_futures[tool_call.id]
+             results.append(result)
+         return results
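Taken together, `step()` drives exactly one turn: it calls `generate()` once, routes each complete tool call through the toolset, and returns a `StepResult` whose tool results may still be pending. A minimal usage sketch follows; the import path for `step`, the lambda-as-callback, and the surrounding setup are assumptions for illustration, not part of this file.

```python
# Hypothetical usage sketch, not part of the package source.
from kosong import step  # assumed import path for the module above
from kosong.base.chat_provider import ChatProvider
from kosong.base.message import Message
from kosong.tooling import Toolset


async def run_once(provider: ChatProvider, toolset: Toolset) -> None:
    history = [Message(role="user", content="What's the weather in Jakarta?")]

    result = await step(
        provider,
        "You are a helpful assistant.",
        toolset,
        history,
        on_message_part=lambda part: print("part:", part),  # assumes plain callables are accepted
    )

    print(result.message)  # the combined assistant message
    print(result.usage)    # TokenUsage | None
    for tool_result in await result.tool_results():  # resolves any pending tool futures
        print(tool_result)
```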
@@ -0,0 +1,66 @@
+ from collections.abc import Sequence
+
+ from kosong.base.chat_provider import ChatProvider, StreamedMessagePart, TokenUsage
+ from kosong.base.message import ContentPart, Message, TextPart, ToolCall
+ from kosong.base.tool import Tool
+ from kosong.utils.aio import Callback, callback
+
+
+ async def generate(
+     chat_provider: ChatProvider,
+     system_prompt: str,
+     tools: Sequence[Tool],
+     history: Sequence[Message],
+     *,
+     on_message_part: Callback[[StreamedMessagePart], None] | None = None,
+     on_tool_call: Callback[[ToolCall], None] | None = None,
+ ) -> tuple[Message, TokenUsage | None]:
+     """
+     Generate one message based on the given context. The given context will remain untouched.
+
+     Parts of the message will be streamed to the given handlers:
+     - `on_message_part` will be called for each raw part, which may be incomplete.
+     - `on_tool_call` will be called for each complete tool call.
+
+     The generated message and the token usage will be returned. All parts in the message are
+     guaranteed to be complete and merged as much as possible.
+     """
+     message = Message(role="assistant", content=[])
+     pending_part: StreamedMessagePart | None = None  # message part that is currently incomplete
+
+     stream = await chat_provider.generate(system_prompt, tools, history)
+     async for part in stream:
+         if on_message_part:
+             await callback(on_message_part, part.model_copy(deep=True))
+
+         if pending_part is None:
+             pending_part = part
+         elif not pending_part.merge_in_place(part):  # try to merge into the pending part
+             # an unmergeable part pushes the pending part into the message
+             _message_append(message, pending_part)
+             if isinstance(pending_part, ToolCall) and on_tool_call:
+                 await callback(on_tool_call, pending_part)
+             pending_part = part
+
+     # end of message
+     if pending_part is not None:
+         _message_append(message, pending_part)
+         if isinstance(pending_part, ToolCall) and on_tool_call:
+             await callback(on_tool_call, pending_part)
+
+     return message, stream.usage
+
+
+ def _message_append(message: Message, part: StreamedMessagePart) -> None:
+     match part:
+         case ContentPart():
+             if isinstance(message.content, str):
+                 message.content = [TextPart(text=message.content)]
+             message.content.append(part)
+         case ToolCall():
+             if message.tool_calls is None:
+                 message.tool_calls = []
+             message.tool_calls.append(part)
+         case _:
+             # may be an orphaned `ToolCallPart`
+             return
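In other words, `generate()` forwards every raw chunk to `on_message_part`, keeps merging adjacent chunks into a pending part, and only announces a tool call once its argument fragments are complete. A short caller-side sketch (hypothetical handlers, assuming plain functions are accepted as callbacks):

```python
# Illustrative sketch, not part of the package source.
from kosong.base import generate
from kosong.base.chat_provider import ChatProvider, StreamedMessagePart
from kosong.base.message import Message, TextPart, ToolCall
from kosong.base.tool import Tool


async def stream_reply(provider: ChatProvider, tools: list[Tool]) -> None:
    def show_part(part: StreamedMessagePart) -> None:
        if isinstance(part, TextPart):
            print(part.text, end="", flush=True)  # raw, possibly partial text chunks

    def show_tool_call(call: ToolCall) -> None:
        print(f"\n[tool call] {call.function.name}({call.function.arguments})")

    message, usage = await generate(
        provider,
        "You are a helpful assistant.",
        tools,
        [Message(role="user", content="Summarize this repository.")],
        on_message_part=show_part,
        on_tool_call=show_tool_call,
    )
    # `message` now holds fully merged parts; `usage` may be None if the provider reports nothing.
```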
@@ -0,0 +1,56 @@
+ from collections.abc import AsyncIterator, Sequence
+ from typing import NamedTuple, Protocol, runtime_checkable
+
+ from kosong.base.message import ContentPart, Message, ToolCall, ToolCallPart
+ from kosong.base.tool import Tool
+
+
+ @runtime_checkable
+ class ChatProvider(Protocol):
+     name: str
+     """
+     The name of the chat provider.
+     """
+
+     @property
+     def model_name(self) -> str:
+         """
+         The model name to use for the chat provider.
+         """
+         ...
+
+     async def generate(
+         self,
+         system_prompt: str,
+         tools: Sequence[Tool],
+         history: Sequence[Message],
+     ) -> "StreamedMessage":
+         """
+         Generate a new message based on the given system prompt, tools, and history.
+         """
+         ...
+
+
+ type StreamedMessagePart = ContentPart | ToolCall | ToolCallPart
+
+
+ @runtime_checkable
+ class StreamedMessage(Protocol):
+     def __aiter__(self) -> AsyncIterator[StreamedMessagePart]:
+         """Create an async iterator from the stream."""
+         ...
+
+     @property
+     def usage(self) -> "TokenUsage | None":
+         """The usage of the streamed message."""
+         ...
+
+
+ class TokenUsage(NamedTuple):
+     input: int
+     output: int
+     # TODO: support `cached`
+
+     @property
+     def total(self) -> int:
+         return self.input + self.output
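`ChatProvider` and `StreamedMessage` are structural, runtime-checkable protocols, so any object with the right attributes qualifies and no subclassing is required. A minimal canned implementation sketch, purely hypothetical, to show what the protocols demand:

```python
# Hypothetical minimal implementation of the two protocols above.
from collections.abc import AsyncIterator, Sequence

from kosong.base.chat_provider import StreamedMessagePart, TokenUsage
from kosong.base.message import Message, TextPart
from kosong.base.tool import Tool


class CannedStream:
    """Yields pre-baked parts; structurally satisfies StreamedMessage."""

    def __init__(self, parts: list[StreamedMessagePart]):
        self._parts = parts

    def __aiter__(self) -> AsyncIterator[StreamedMessagePart]:
        async def iterate():
            for part in self._parts:
                yield part

        return iterate()

    @property
    def usage(self) -> TokenUsage | None:
        return TokenUsage(input=0, output=len(self._parts))


class CannedProvider:
    """Structurally satisfies ChatProvider; always answers with the same text."""

    name = "canned"

    @property
    def model_name(self) -> str:
        return "canned-1"

    async def generate(
        self, system_prompt: str, tools: Sequence[Tool], history: Sequence[Message]
    ) -> CannedStream:
        return CannedStream([TextPart(text="Hello, "), TextPart(text="world!")])
```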
@@ -0,0 +1,212 @@
+ from abc import ABC
+ from typing import Any, ClassVar, Literal, override
+
+ from pydantic import BaseModel, GetCoreSchemaHandler, field_serializer
+ from pydantic_core import core_schema
+
+
+ class MergableMixin:
+     def merge_in_place(self, other: Any) -> bool:
+         """Merge the other part into the current part. Return True if the merge is successful."""
+         return False
+
+
+ class ContentPart(BaseModel, ABC, MergableMixin):
+     """A part of a message content."""
+
+     __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
+
+     type: str
+     ...  # to be added by subclasses
+
+     def __init_subclass__(cls, **kwargs):
+         super().__init_subclass__(**kwargs)
+
+         invalid_subclass_error_msg = (
+             f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
+         )
+
+         if not hasattr(cls, "type"):
+             raise ValueError(invalid_subclass_error_msg)
+
+         type_value = cls.type
+         if not isinstance(type_value, str):
+             raise ValueError(invalid_subclass_error_msg)
+
+         cls.__content_part_registry[type_value] = cls
+
+     @classmethod
+     def __get_pydantic_core_schema__(
+         cls, source_type: Any, handler: GetCoreSchemaHandler
+     ) -> core_schema.CoreSchema:
+         # If we're dealing with the base ContentPart class, use custom validation
+         if cls.__name__ == "ContentPart":
+
+             def validate_content_part(value: Any) -> Any:
+                 # if it's already an instance of a ContentPart subclass, return it
+                 if hasattr(value, "__class__") and issubclass(value.__class__, cls):
+                     return value
+
+                 # if it's a dict with a type field, dispatch to the appropriate subclass
+                 if isinstance(value, dict) and "type" in value:
+                     type_value = value["type"]
+                     if not isinstance(type_value, str):
+                         raise ValueError(f"Cannot validate {value} as ContentPart")
+                     target_class = cls.__content_part_registry[type_value]
+                     return target_class.model_validate(value)
+
+                 raise ValueError(f"Cannot validate {value} as ContentPart")
+
+             return core_schema.no_info_plain_validator_function(validate_content_part)
+
+         # for subclasses, use the default schema
+         return handler(source_type)
+
+
+ class TextPart(ContentPart):
+     """
+     >>> TextPart(text="Hello, world!").model_dump()
+     {'type': 'text', 'text': 'Hello, world!'}
+     """
+
+     type: str = "text"
+     text: str
+
+     @override
+     def merge_in_place(self, other) -> bool:
+         if not isinstance(other, TextPart):
+             return False
+         self.text += other.text
+         return True
+
+
+ class ThinkPart(ContentPart):
+     """
+     >>> ThinkPart(think="I think I need to think about this.").model_dump()
+     {'type': 'think', 'think': 'I think I need to think about this.'}
+     """
+
+     type: str = "think"
+     think: str
+
+     @override
+     def merge_in_place(self, other) -> bool:
+         if not isinstance(other, ThinkPart):
+             return False
+         self.think += other.think
+         return True
+
+
+ class ImageURLPart(ContentPart):
+     """
+     >>> ImageURLPart(image_url=ImageURLPart.ImageURL(url="https://example.com/image.png")).model_dump()
+     {'type': 'image_url', 'image_url': {'url': 'https://example.com/image.png', 'id': None}}
+     """
+
+     class ImageURL(BaseModel):
+         url: str
+         """The URL of the image; can be a data URI like `data:image/png;base64,...`."""
+         id: str | None = None
+         """The ID of the image, to allow LLMs to distinguish different images."""
+
+     type: str = "image_url"
+     image_url: ImageURL
+
+
+ class AudioURLPart(ContentPart):
+     """
+     >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
+     {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
+     """
+
+     class AudioURL(BaseModel):
+         url: str
+         """The URL of the audio; can be a data URI like `data:audio/aac;base64,...`."""
+         id: str | None = None
+         """The ID of the audio, to allow LLMs to distinguish different audios."""
+
+     type: str = "audio_url"
+     audio_url: AudioURL
+
+
+ class ToolCall(BaseModel, MergableMixin):
+     """
+     A tool call requested by the assistant.
+
+     >>> ToolCall(
+     ...     id="123",
+     ...     function=ToolCall.FunctionBody(
+     ...         name="function",
+     ...         arguments="{}"
+     ...     ),
+     ... ).model_dump()
+     {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
+     """
+
+     class FunctionBody(BaseModel):
+         name: str
+         arguments: str | None
+
+     type: Literal["function"] = "function"
+
+     id: str
+     """The ID of the tool call."""
+     function: FunctionBody
+     """The function body of the tool call."""
+
+     @override
+     def merge_in_place(self, other) -> bool:
+         if not isinstance(other, ToolCallPart):
+             return False
+         if self.function.arguments is None:
+             self.function.arguments = other.arguments_part
+         else:
+             self.function.arguments += other.arguments_part or ""
+         return True
+
+
+ class ToolCallPart(BaseModel, MergableMixin):
+     """A part of the tool call."""
+
+     arguments_part: str | None = None
+     """A part of the arguments of the tool call."""
+
+     @override
+     def merge_in_place(self, other) -> bool:
+         if not isinstance(other, ToolCallPart):
+             return False
+         if self.arguments_part is None:
+             self.arguments_part = other.arguments_part
+         else:
+             self.arguments_part += other.arguments_part or ""
+         return True
+
+
+ class Message(BaseModel):
+     """A message in a conversation."""
+
+     role: Literal[
+         "system",
+         "developer",
+         "user",
+         "assistant",
+         "tool",
+     ]
+     name: str | None = None
+
+     content: str | list[ContentPart]
+     """The content of the message."""
+
+     tool_calls: list[ToolCall] | None = None
+     """In assistant messages, there can be tool calls."""
+
+     tool_call_id: str | None = None
+     """In tool messages, there can be a tool call ID."""
+
+     partial: bool | None = None
+
+     @field_serializer("content")
+     def serialize_content(self, content: str | list[ContentPart]) -> str | list[dict]:
+         if isinstance(content, str):
+             return content
+         return [part.model_dump() for part in content]
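The registry filled in by `__init_subclass__` is what lets the base `ContentPart` validator dispatch plain dicts to the right subclass by their `type` field, and `merge_in_place` is how streamed fragments are folded together. A small sketch of both behaviors (illustrative values only):

```python
# Sketch of the dispatch and merge behavior defined above.
from kosong.base.message import Message, TextPart, ThinkPart, ToolCall, ToolCallPart

# Dict content parts are dispatched by their `type` field to the registered subclass.
msg = Message.model_validate(
    {
        "role": "assistant",
        "content": [
            {"type": "think", "think": "Let me check."},
            {"type": "text", "text": "Hello"},
        ],
    }
)
assert isinstance(msg.content[0], ThinkPart)
assert isinstance(msg.content[1], TextPart)

# Adjacent fragments of the same kind merge in place; mismatched kinds do not.
part = TextPart(text="Hel")
assert part.merge_in_place(TextPart(text="lo")) and part.text == "Hello"
assert not part.merge_in_place(ThinkPart(think="nope"))

# A ToolCall absorbs streamed argument fragments delivered as ToolCallPart.
call = ToolCall(id="1", function=ToolCall.FunctionBody(name="search", arguments=None))
call.merge_in_place(ToolCallPart(arguments_part='{"query":'))
call.merge_in_place(ToolCallPart(arguments_part=' "weather"}'))
assert call.function.arguments == '{"query": "weather"}'
```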
@@ -0,0 +1,22 @@
+ from typing import Any, Self
+
+ import jsonschema
+ from pydantic import BaseModel, model_validator
+
+ type ParametersType = dict[str, Any]
+
+
+ class Tool(BaseModel):
+     name: str
+     """The name of the tool."""
+
+     description: str
+     """The description of the tool."""
+
+     parameters: ParametersType
+     """The parameters of the tool, in JSON Schema format."""
+
+     @model_validator(mode="after")
+     def validate_parameters(self) -> Self:
+         jsonschema.validate(self.parameters, jsonschema.Draft202012Validator.META_SCHEMA)
+         return self
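`validate_parameters` runs after construction and checks the `parameters` dict against the JSON Schema 2020-12 meta-schema, so a malformed tool schema fails at definition time rather than at call time. A brief sketch:

```python
# Sketch: a Tool's `parameters` must themselves be a valid JSON Schema (2020-12).
from kosong.base.tool import Tool

weather = Tool(
    name="get_weather",
    description="Look up the current weather for a city.",
    parameters={
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
)

# Something like `parameters={"type": 123}` would be rejected: `jsonschema.validate`
# raises because the value violates the Draft 2020-12 meta-schema.
```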
@@ -0,0 +1,53 @@
+ from kosong.base.chat_provider import ChatProvider
+
+ __all__ = [
+     "OpenAILegacy",
+     "Kimi",
+     # for testing
+     "MockChatProvider",
+     "ChaosChatProvider",
+ ]
+
+
+ def __static_check_types(
+     openai: "OpenAILegacy",
+     kimi: "Kimi",
+     mock: "MockChatProvider",
+     chaos: "ChaosChatProvider",
+ ):
+     """Use type checking to ensure the types are correctly implemented."""
+     _: ChatProvider = openai
+     _: ChatProvider = mock
+     _: ChatProvider = kimi
+     _: ChatProvider = chaos
+
+
+ class ChatProviderError(Exception):
+     """The error raised by a chat provider."""
+
+     def __init__(self, message: str):
+         super().__init__(message)
+
+
+ class APIConnectionError(ChatProviderError):
+     """The error raised when the API connection fails."""
+
+
+ class APITimeoutError(ChatProviderError):
+     """The error raised when the API request times out."""
+
+
+ class APIStatusError(ChatProviderError):
+     """The error raised when the API returns a status code of 4xx or 5xx."""
+
+     status_code: int
+
+     def __init__(self, status_code: int, message: str):
+         super().__init__(message)
+         self.status_code = status_code
+
+
+ from .chaos import ChaosChatProvider  # noqa: E402
+ from .kimi import Kimi  # noqa: E402
+ from .mock import MockChatProvider  # noqa: E402
+ from .openai_legacy import OpenAILegacy  # noqa: E402
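Because every provider failure is funneled through the `ChatProviderError` hierarchy, callers can branch on the failure mode. A hypothetical caller-side retry sketch (the backoff policy and the `make_request` helper are assumptions, not package code):

```python
# Hypothetical retry wrapper around any coroutine function that may raise ChatProviderError.
import asyncio

from kosong.chat_provider import APIStatusError, APITimeoutError, ChatProviderError


async def call_with_retry(make_request, attempts: int = 3):
    for attempt in range(attempts):
        try:
            return await make_request()
        except APIStatusError as e:
            if e.status_code == 429 and attempt + 1 < attempts:
                await asyncio.sleep(2**attempt)  # simple exponential backoff on rate limits
                continue
            raise
        except APITimeoutError:
            if attempt + 1 < attempts:
                continue  # retry timeouts immediately
            raise
        except ChatProviderError:
            raise  # other provider errors: give up right away
```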
@@ -0,0 +1,106 @@
+ import json
+ import os
+ import random
+
+ import httpx
+ from pydantic import BaseModel
+
+ from kosong.chat_provider.openai_legacy import OpenAILegacy
+
+
+ class ChaosConfig(BaseModel):
+     """Configuration for chaos provider."""
+
+     error_probability: float = 0.3
+     error_types: list[int] = [429, 500, 502, 503]
+     retry_after: int = 2
+     seed: int | None = None
+
+     @classmethod
+     def from_env(cls) -> "ChaosConfig":
+         """Create config from environment variables."""
+         seed_str = os.getenv("CHAOS_SEED")
+         return cls(
+             error_probability=float(os.getenv("CHAOS_ERROR_PROBABILITY", "0.3")),
+             error_types=[
+                 int(x.strip()) for x in os.getenv("CHAOS_ERROR_TYPES", "429,500,502,503").split(",")
+             ],
+             retry_after=int(os.getenv("CHAOS_RETRY_AFTER", "2")),
+             seed=int(seed_str) if seed_str else None,
+         )
+
+
+ class ChaosTransport(httpx.AsyncBaseTransport):
+     """HTTP transport that randomly injects errors."""
+
+     def __init__(self, wrapped_transport: httpx.AsyncBaseTransport, config: ChaosConfig):
+         self._wrapped = wrapped_transport
+         self._config = config
+         if config.seed is not None:
+             random.seed(config.seed)
+
+     async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
+         if self._should_inject_error():
+             error_code = random.choice(self._config.error_types)
+             return self._create_error_response(request, error_code)
+
+         return await self._wrapped.handle_async_request(request)
+
+     def _should_inject_error(self) -> bool:
+         return random.random() < self._config.error_probability
+
+     def _create_error_response(self, request: httpx.Request, status_code: int) -> httpx.Response:
+         error_messages = {
+             429: {"error": {"code": "rate_limit_exceeded", "message": "Rate limit exceeded"}},
+             500: {"error": {"code": "internal_error", "message": "Internal server error"}},
+             502: {"error": {"code": "bad_gateway", "message": "Bad gateway"}},
+             503: {
+                 "error": {
+                     "code": "service_unavailable",
+                     "message": "Service temporarily unavailable",
+                 }
+             },
+         }
+
+         content = json.dumps(
+             error_messages.get(status_code, {"error": {"message": "Unknown error"}})
+         )
+         headers = {"content-type": "application/json"}
+
+         if status_code == 429:
+             headers["retry-after"] = str(self._config.retry_after)
+
+         return httpx.Response(
+             status_code=status_code,
+             headers=headers,
+             content=content.encode(),
+             request=request,
+         )
+
+
+ class ChaosChatProvider(OpenAILegacy):
+     """OpenAI Legacy provider with chaos error injection."""
+
+     def __init__(
+         self,
+         model: str,
+         api_key: str | None = None,
+         base_url: str | None = None,
+         chaos_config: ChaosConfig | None = None,
+         **client_kwargs,
+     ):
+         super().__init__(model=model, api_key=api_key, base_url=base_url, **client_kwargs)
+         self._chaos_config = chaos_config or ChaosConfig.from_env()
+         self._monkey_patch_client()
+
+     def _monkey_patch_client(self):
+         """Inject chaos transport into the client."""
+         original_transport = self._client._client._transport
+         chaos_transport = ChaosTransport(original_transport, self._chaos_config)
+         self._client._client._transport = chaos_transport
+
+     @property
+     def model_name(self) -> str:
+         if self._chaos_config.error_probability > 0:
+             return f"chaos({super().model_name})"
+         return super().model_name
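`ChaosChatProvider` keeps the normal `OpenAILegacy` behavior but swaps the underlying httpx transport for one that fabricates 4xx/5xx responses with the configured probability, which makes it useful for exercising retry logic in tests. A seeded construction sketch (credentials and model name are placeholders):

```python
# Sketch: deterministic failure injection for tests; values are placeholders.
from kosong.chat_provider import ChaosChatProvider
from kosong.chat_provider.chaos import ChaosConfig

provider = ChaosChatProvider(
    model="some-model",
    api_key="test-key",                  # placeholder, not a real credential
    base_url="https://api.example.com/v1",
    chaos_config=ChaosConfig(
        error_probability=0.5,           # fail roughly half of the requests
        error_types=[429, 503],
        retry_after=1,
        seed=42,                         # seeds `random` for reproducible runs
    ),
)

# With no explicit config, ChaosConfig.from_env() reads CHAOS_ERROR_PROBABILITY,
# CHAOS_ERROR_TYPES, CHAOS_RETRY_AFTER, and CHAOS_SEED from the environment.
assert provider.model_name.startswith("chaos(")
```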