calfkit 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- calfkit/__init__.py +4 -0
- calfkit/broker/__init__.py +3 -0
- calfkit/broker/broker.py +34 -0
- calfkit/broker/deployable.py +6 -0
- calfkit/broker/middleware.py +15 -0
- calfkit/experimental/rpc_worker.py +51 -0
- calfkit/messages/__init__.py +9 -0
- calfkit/messages/util.py +99 -0
- calfkit/models/event_envelope.py +43 -0
- calfkit/models/types.py +65 -0
- calfkit/nodes/__init__.py +16 -0
- calfkit/nodes/agent_router_node.py +194 -0
- calfkit/nodes/base_node.py +62 -0
- calfkit/nodes/base_tool_node.py +55 -0
- calfkit/nodes/chat_node.py +48 -0
- calfkit/nodes/registrator.py +11 -0
- calfkit/providers/__init__.py +5 -0
- calfkit/providers/pydantic_ai/__init__.py +3 -0
- calfkit/providers/pydantic_ai/openai.py +61 -0
- calfkit/runners/__init__.py +14 -0
- calfkit/runners/node_runner.py +38 -0
- calfkit/stores/__init__.py +38 -0
- calfkit/stores/base.py +60 -0
- calfkit/stores/in_memory.py +29 -0
- calfkit-0.1.1.dist-info/METADATA +129 -0
- calfkit-0.1.1.dist-info/RECORD +28 -0
- calfkit-0.1.1.dist-info/WHEEL +4 -0
- calfkit-0.1.1.dist-info/licenses/LICENSE +202 -0
calfkit/__init__.py
ADDED
calfkit/broker/broker.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from collections.abc import Iterable
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
from faststream import FastStream
|
|
6
|
+
from faststream.kafka import KafkaBroker
|
|
7
|
+
|
|
8
|
+
from calfkit.broker.deployable import Deployable
|
|
9
|
+
from calfkit.broker.middleware import ContextInjectionMiddleware
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Broker(KafkaBroker, Deployable):
    """Lightweight client wrapper connecting to Calf brokers.

    Resolves bootstrap servers from the ``CALF_HOST_URL`` environment
    variable when none are given, and installs the correlation-id
    context-injection middleware on every consumer.
    """

    # Transport marker for this broker.
    # BUG FIX: the original assigned the typing construct itself
    # (``mode = Literal["kafka_mode"]``); declare an annotated value instead.
    mode: Literal["kafka_mode"] = "kafka_mode"

    def __init__(self, bootstrap_servers: str | Iterable[str] | None = None, **broker_kwargs: Any):
        """Connect to Kafka.

        Args:
            bootstrap_servers: Kafka bootstrap server(s). Falls back to the
                ``CALF_HOST_URL`` environment variable, then ``"localhost"``.
            **broker_kwargs: Passed through to ``KafkaBroker``.
        """
        if not bootstrap_servers:
            bootstrap_servers = os.getenv("CALF_HOST_URL")
        super().__init__(
            bootstrap_servers or "localhost",
            middlewares=[ContextInjectionMiddleware],
            **broker_kwargs,
        )

    @property
    def app(self) -> FastStream:
        """Return a new FastStream application wrapping this broker."""
        # NOTE: a fresh FastStream instance is built on every access.
        return FastStream(self)

    async def run_app(self) -> None:
        """Run the FastStream application until it terminates."""
        await self.app.run()

    def __getattr__(self, name: str) -> Any:
        # BUG FIX: the original returned ``getattr(self, name)``, which
        # re-enters ``__getattr__`` and recurses infinitely for any missing
        # attribute (``__getattr__`` only runs after normal lookup fails).
        # Raising AttributeError restores standard attribute semantics.
        raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from faststream import BaseMiddleware
|
|
4
|
+
from faststream.message import StreamMessage
|
|
5
|
+
from faststream.types import AsyncFuncAny
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class ContextInjectionMiddleware(BaseMiddleware):
    """Middleware exposing the incoming message's correlation id via context.

    While a message is being consumed, the FastStream context key
    ``"correlation_id"`` resolves to the message's correlation id, so
    handlers can read it (e.g. via ``Annotated[str, Context()]``) and
    propagate it to downstream publishes.
    """

    async def consume_scope(
        self,
        call_next: AsyncFuncAny,
        msg: StreamMessage[Any],
    ) -> Any:
        """Wrap message consumption in a context scope carrying the correlation id."""
        # The scope (and the context key) is torn down once the handler returns.
        with self.context.scope("correlation_id", msg.correlation_id):
            return await super().consume_scope(call_next, msg)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from asyncio import Future, wait_for
|
|
2
|
+
from uuid import uuid4
|
|
3
|
+
|
|
4
|
+
from faststream.kafka import KafkaBroker, KafkaMessage
|
|
5
|
+
from faststream.types import SendableMessage
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RPCWorker:
    """Request/reply helper over Kafka topics.

    Publishes a request with a fresh correlation id and a ``reply_to``
    topic, then resolves the matching in-flight future when the response
    arrives on that topic.
    """

    # In-flight requests, keyed by correlation id.
    responses: dict[str, Future[bytes]]

    def __init__(self, broker: KafkaBroker, reply_topic: str) -> None:
        """Register this worker as the handler for *reply_topic* on *broker*."""
        self.responses = {}
        self.broker = broker
        self.reply_topic = reply_topic

        # All replies for this worker arrive on a single shared topic.
        self.subscriber = broker.subscriber(reply_topic)
        self.subscriber(self._handle_responses)

    async def start(self) -> None:
        """Start consuming replies."""
        await self.subscriber.start()

    async def stop(self) -> None:
        """Stop consuming replies."""
        await self.subscriber.stop()

    def _handle_responses(self, msg: KafkaMessage) -> None:
        """Resolve the pending future matching this reply, if any."""
        future = self.responses.pop(msg.correlation_id, None)
        # BUG FIX: the future may already be done — e.g. cancelled by the
        # ``wait_for`` timeout racing with this handler. Calling
        # ``set_result`` on a done future raises InvalidStateError and would
        # crash the subscriber; guard with ``done()``.
        if future is not None and not future.done():
            future.set_result(msg.body)

    async def request(
        self,
        data: SendableMessage,
        topic: str,
        timeout: float = 10.0,
    ) -> bytes:
        """Publish *data* to *topic* and await the reply body.

        Args:
            data: Payload to publish.
            topic: Destination topic for the request.
            timeout: Seconds to wait for the reply.

        Returns:
            The raw reply body.

        Raises:
            TimeoutError: If no reply arrives within *timeout* seconds.
        """
        correlation_id = str(uuid4())
        future = self.responses[correlation_id] = Future[bytes]()

        await self.broker.publish(
            data,
            topic,
            reply_to=self.reply_topic,
            correlation_id=correlation_id,
        )

        try:
            response: bytes = await wait_for(future, timeout=timeout)
        except TimeoutError:
            # Drop the stale entry so ``responses`` does not grow unboundedly.
            self.responses.pop(correlation_id, None)
            raise
        else:
            return response
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Message utilities for calf SDK.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for working with pydantic_ai ModelMessage types,
|
|
4
|
+
including message history manipulation and transformation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .util import patch_system_prompts, validate_tool_call_pairs
|
|
8
|
+
|
|
9
|
+
__all__ = ["patch_system_prompts", "validate_tool_call_pairs"]
|
calfkit/messages/util.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""Message utility functions for calf SDK.
|
|
2
|
+
|
|
3
|
+
This module provides utilities for working with pydantic_ai ModelMessage types,
|
|
4
|
+
including message history manipulation and transformation.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from pydantic_ai.messages import (
|
|
8
|
+
ModelMessage,
|
|
9
|
+
ModelRequest,
|
|
10
|
+
ModelResponse,
|
|
11
|
+
RetryPromptPart,
|
|
12
|
+
SystemPromptPart,
|
|
13
|
+
ToolCallPart,
|
|
14
|
+
ToolReturnPart,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def patch_system_prompts(
    base: list[ModelMessage],
    incoming: list[ModelMessage],
) -> list[ModelMessage]:
    """Replace the system prompts of *base* with those carried by *incoming*.

    When the incoming messages carry system prompts, every system prompt in
    *base* is discarded and the incoming ones are consolidated into a single
    request placed at the front of the returned history. When *incoming*
    carries no system prompts, *base* is returned unchanged.

    Args:
        base: The existing message history to patch.
        incoming: New messages that may contain replacement system prompts.

    Returns:
        A new message history with system prompts patched.

    Examples:
        >>> from pydantic_ai import ModelRequest, SystemPromptPart
        >>> base = [ModelRequest(parts=[SystemPromptPart("old system")])]
        >>> incoming = [ModelRequest(parts=[SystemPromptPart("new system")])]
        >>> result = patch_system_prompts(base, incoming)
        >>> len(result)
        1
        >>> result[0].parts[0].content
        'new system'
    """
    # Gather every system prompt carried by the incoming request messages.
    replacement_parts: list[SystemPromptPart] = [
        part
        for msg in incoming
        if not isinstance(msg, ModelResponse)
        for part in msg.parts
        if isinstance(part, SystemPromptPart)
    ]

    # Nothing to patch: hand the history back untouched.
    if not replacement_parts:
        return base

    # Rebuild the history without its old system prompts, dropping any
    # request message left with no parts at all.
    patched: list[ModelMessage] = [ModelRequest(parts=replacement_parts)]
    for msg in base:
        if not isinstance(msg, ModelRequest):
            patched.append(msg)
            continue
        remaining = [p for p in msg.parts if not isinstance(p, SystemPromptPart)]
        if remaining:
            patched.append(ModelRequest(parts=remaining))
    return patched
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def validate_tool_call_pairs(messages: list[ModelMessage]) -> bool:
    """Check that every tool call has a matching tool result.

    Walks the history newest-to-oldest, remembering the ``tool_call_id`` of
    each ToolReturnPart / RetryPromptPart seen so far; any ToolCallPart whose
    id was not seen later in the history means the call is unanswered.

    Args:
        messages: List of ModelMessage to validate.

    Returns:
        True if all tool calls have matching results, False otherwise.
    """
    answered: set[str] = set()

    for message in reversed(messages):
        if isinstance(message, ModelRequest):
            # Results appear in requests; record their ids.
            answered.update(
                part.tool_call_id
                for part in message.parts
                if isinstance(part, (ToolReturnPart, RetryPromptPart))
            )
        elif isinstance(message, ModelResponse):
            # Every call in this response must have been answered later on.
            if any(
                isinstance(part, ToolCallPart) and part.tool_call_id not in answered
                for part in message.parts
            ):
                return False

    return True
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from pydantic_ai import ModelMessage, ModelRequest
|
|
5
|
+
from pydantic_ai.models import ModelRequestParameters
|
|
6
|
+
|
|
7
|
+
from calfkit.models.types import CompactBaseModel, SerializableModelSettings, ToolCallRequest
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class EventEnvelope(CompactBaseModel):
    """Envelope passed between Calf nodes over the broker.

    Carries the running conversation state (message history, system message)
    together with per-hop routing metadata (event kind, tool call payload,
    inference-time parameter patches, and response topics).
    """

    # Discriminates what this envelope carries on the current hop.
    kind: Literal["tool_call_request", "user_prompt", "ai_response", "tool_result"]
    text: str | None = None
    trace_id: str | None = None

    # Used to surface the tool call from latest message so tool call workers do not have to dig
    tool_call_request: ToolCallRequest | None = None

    # Optional inference-time patch in settings and parameters
    patch_model_request_params: ModelRequestParameters | None = None
    patch_model_settings: SerializableModelSettings | None = None

    # Running message history
    # (mutable default is safe here: pydantic copies field defaults per instance)
    message_history: list[ModelMessage] = []

    @property
    def latest_message_in_history(self) -> ModelMessage | None:
        """Return the most recent message, or None when the history is empty."""
        return self.message_history[-1] if self.message_history else None

    # The result message from a node
    incoming_node_messages: Sequence[ModelMessage] = []

    # thread id / conversation identifier
    thread_id: str | None = None

    # Allow client to dynamically patch system message at runtime
    # Intentionally kept separate from message_history in order to simplify patch logic
    system_message: ModelRequest | None = None

    # Where the final response from AI should be published to
    final_response_topic: str

    # Whether the current message is the final response from the AI to the user
    final_response: bool = False
|
calfkit/models/types.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from typing import Any, TypeAlias
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
from pydantic_ai import ToolCallPart
|
|
5
|
+
from typing_extensions import TypedDict
|
|
6
|
+
|
|
7
|
+
ToolCallRequest: TypeAlias = ToolCallPart
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class SerializableModelSettings(TypedDict, total=False):
    """Serializable version of pydantic_ai.ModelSettings.

    This is a copy of ModelSettings with `timeout` narrowed
    to `float` only to ensure JSON serializability.

    All keys are optional (``total=False``); absent keys mean
    "use the provider's default".
    """

    max_tokens: int
    """The maximum number of tokens to generate before stopping."""

    temperature: float
    """Amount of randomness injected into the response."""

    top_p: float
    """Nucleus sampling parameter."""

    timeout: float
    """Request timeout in seconds (float only, not httpx.Timeout)."""

    parallel_tool_calls: bool
    """Whether to allow parallel tool calls."""

    seed: int
    """Random seed for deterministic results."""

    presence_penalty: float
    """Penalize tokens based on whether they have appeared in the text so far."""

    frequency_penalty: float
    """Penalize tokens based on their existing frequency in the text so far."""

    logit_bias: dict[str, int]
    """Modify the likelihood of specified tokens appearing in the completion."""

    stop_sequences: list[str]
    """Sequences that will cause the model to stop generating."""

    extra_headers: dict[str, str]
    """Extra headers to send to the model."""

    extra_body: dict[str, Any]
    """Extra body to send to the model."""
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class CompactBaseModel(BaseModel):
    """Base model that omits unset and ``None`` fields when serializing.

    Both ``model_dump`` and ``model_dump_json`` default ``exclude_unset``
    and ``exclude_none`` to ``True``; callers may still override either
    flag explicitly.
    """

    @staticmethod
    def _compact(kwargs: dict[str, Any]) -> dict[str, Any]:
        # Only fill in defaults — explicit caller choices always win.
        for flag in ("exclude_unset", "exclude_none"):
            kwargs.setdefault(flag, True)
        return kwargs

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Serialize to a dict, dropping unset/None fields by default."""
        return super().model_dump(**self._compact(kwargs))

    def model_dump_json(self, **kwargs: Any) -> str:
        """Serialize to JSON, dropping unset/None fields by default."""
        return super().model_dump_json(**self._compact(kwargs))
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from calfkit.nodes.agent_router_node import AgentRouterNode
|
|
2
|
+
from calfkit.nodes.base_node import BaseNode, publish_to, subscribe_to
|
|
3
|
+
from calfkit.nodes.base_tool_node import BaseToolNode, agent_tool
|
|
4
|
+
from calfkit.nodes.chat_node import ChatNode
|
|
5
|
+
from calfkit.nodes.registrator import Registrator
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"AgentRouterNode",
|
|
9
|
+
"BaseNode",
|
|
10
|
+
"BaseToolNode",
|
|
11
|
+
"ChatNode",
|
|
12
|
+
"Registrator",
|
|
13
|
+
"agent_tool",
|
|
14
|
+
"publish_to",
|
|
15
|
+
"subscribe_to",
|
|
16
|
+
]
|
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
from typing import Annotated, Any
|
|
2
|
+
|
|
3
|
+
import uuid_utils
|
|
4
|
+
from faststream import Context
|
|
5
|
+
from faststream.kafka.annotations import (
|
|
6
|
+
KafkaBroker as BrokerAnnotation,
|
|
7
|
+
)
|
|
8
|
+
from pydantic_ai import ModelRequest, ModelResponse, SystemPromptPart
|
|
9
|
+
from pydantic_ai.models import ModelRequestParameters
|
|
10
|
+
|
|
11
|
+
from calfkit.broker.broker import Broker
|
|
12
|
+
from calfkit.messages import patch_system_prompts, validate_tool_call_pairs
|
|
13
|
+
from calfkit.models.event_envelope import EventEnvelope
|
|
14
|
+
from calfkit.models.types import ToolCallRequest
|
|
15
|
+
from calfkit.nodes.base_node import BaseNode, publish_to, subscribe_to
|
|
16
|
+
from calfkit.nodes.base_tool_node import BaseToolNode
|
|
17
|
+
from calfkit.stores.base import MessageHistoryStore
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class AgentRouterNode(BaseNode):
    """Logic for the internal routing to operate agents.

    The router receives an ``EventEnvelope``, merges the incoming node
    messages into the running history (optionally via a persistent store),
    patches system prompts, and then routes the conversation:

    * model response with tool calls -> publish one request per tool node;
    * model response without tool calls -> publish the final answer to the
      envelope's ``final_response_topic``;
    * otherwise, if every tool call already has a result -> ask the chat
      node for the next model turn.
    """

    _router_sub_topic_name = "agent_router.input"
    _router_pub_topic_name = "agent_router.output"

    def __init__(
        self,
        chat_node: BaseNode,
        tool_nodes: list[BaseToolNode],
        system_prompt: str | None = None,
        handoff_nodes: list[type[BaseNode]] | None = None,
        message_history_store: MessageHistoryStore | None = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Wire the router to its chat node, tool nodes, and optional store.

        Args:
            chat_node: Node that performs the model inference turn.
            tool_nodes: Tool nodes the agent may dispatch tool calls to.
            system_prompt: Default system prompt used when the incoming
                envelope carries no system message of its own.
            handoff_nodes: Node types this agent can hand conversations off to.
            message_history_store: Optional per-thread history store; when set
                (and the envelope has a ``thread_id``) history is persisted
                there instead of riding inside the envelope.
        """
        self.chat = chat_node
        self.tools = tool_nodes
        # BUG FIX: the original declared a mutable default argument
        # (``handoff_nodes: list[...] = []``), which is shared across calls.
        self.handoffs = handoff_nodes if handoff_nodes is not None else []
        self.system_prompt = system_prompt
        self.system_message = (
            ModelRequest(parts=[SystemPromptPart(self.system_prompt)])
            if self.system_prompt
            else None
        )
        self.message_history_store = message_history_store

        # tool name -> topic that tool node subscribes to.
        self.tools_topic_registry: dict[str, str] = {
            tool.tool_schema().name: tool.subscribed_topic
            for tool in tool_nodes
            if tool.subscribed_topic is not None
        }

        self.tool_response_topics = [tool.publish_to_topic for tool in self.tools]

        super().__init__(*args, **kwargs)

    @subscribe_to(_router_sub_topic_name)
    @publish_to(_router_pub_topic_name)
    async def _router(
        self,
        ctx: EventEnvelope,
        correlation_id: Annotated[str, Context()],
        broker: BrokerAnnotation,
    ) -> EventEnvelope:
        """Main routing handler: update history, then dispatch the next hop.

        Raises:
            RuntimeError: If the envelope carries no incoming node messages.
        """
        if not ctx.incoming_node_messages:
            raise RuntimeError("There is no response message to process")

        # One central place where message history is updated
        if self.message_history_store is not None and ctx.thread_id is not None:
            await self.message_history_store.append_many(
                thread_id=ctx.thread_id, messages=ctx.incoming_node_messages
            )
            ctx.message_history = await self.message_history_store.get(thread_id=ctx.thread_id)
        else:
            ctx.message_history.extend(ctx.incoming_node_messages)

        # Apply system prompts with priority: incoming > self.system_message > existing history
        if ctx.system_message is not None:
            ctx.message_history = patch_system_prompts(ctx.message_history, [ctx.system_message])
        elif self.system_message is not None:
            ctx.message_history = patch_system_prompts(
                ctx.message_history,
                [self.system_message],
            )

        if isinstance(ctx.latest_message_in_history, ModelResponse):
            if (
                ctx.latest_message_in_history.finish_reason == "tool_call"
                or ctx.latest_message_in_history.tool_calls
            ):
                # Fan each generated tool call out to its tool node.
                for tool_call in ctx.latest_message_in_history.tool_calls:
                    await self._route_tool(ctx, tool_call, correlation_id, broker)
            else:
                # No tool calls: this model response is the final answer.
                await self._reply_to_sender(ctx, correlation_id, broker)
        elif validate_tool_call_pairs(ctx.message_history):
            # All tool calls are answered; request the next model turn.
            await self._call_model(ctx, correlation_id, broker)
        return ctx

    async def _route_tool(
        self,
        event_envelope: EventEnvelope,
        generated_tool_call: ToolCallRequest,
        correlation_id: str,
        broker: Any,
    ) -> None:
        """Publish a single tool call request to the matching tool topic."""
        tool_topic = self.tools_topic_registry.get(generated_tool_call.tool_name)
        if tool_topic is None:
            # TODO: implement a short circuit to respond with an
            # error message for when provided tool does not exist.
            return
        event_envelope = event_envelope.model_copy(
            update={"kind": "tool_call_request", "tool_call_request": generated_tool_call}
        )
        await broker.publish(
            event_envelope,
            topic=tool_topic,
            correlation_id=correlation_id,
            reply_to=self.subscribed_topic,
        )

    async def _reply_to_sender(
        self, event_envelope: EventEnvelope, correlation_id: str, broker: Any
    ) -> None:
        """Publish the final AI response to the envelope's response topic."""
        event_envelope = event_envelope.model_copy(
            update={"kind": "ai_response", "final_response": True}
        )
        await broker.publish(
            event_envelope,
            topic=event_envelope.final_response_topic,
            correlation_id=correlation_id,
        )

    async def _call_model(
        self,
        event_envelope: EventEnvelope,
        correlation_id: str,
        broker: Any,
    ) -> None:
        """Forward the envelope to the chat node for the next model turn."""
        patch_model_request_params = event_envelope.patch_model_request_params
        if patch_model_request_params is None:
            # Advertise this router's tools when the envelope carries no patch.
            patch_model_request_params = ModelRequestParameters(
                function_tools=[tool.tool_schema() for tool in self.tools]
            )
        event_envelope = event_envelope.model_copy(
            update={"patch_model_request_params": patch_model_request_params}
        )
        await broker.publish(
            event_envelope,
            topic=self.chat.subscribed_topic,
            correlation_id=correlation_id,
            reply_to=self.subscribed_topic,
        )

    async def invoke(
        self,
        user_prompt: str,
        broker: Broker,
        final_response_topic: str,
        thread_id: str | None = None,
        correlation_id: str | None = None,
    ) -> str:
        """Invoke the agent.

        Args:
            user_prompt (str): User prompt to request the model.
            broker (Broker): The broker to connect to.
            final_response_topic (str): Topic where the final AI response
                should be published.
            thread_id (str | None, optional): Conversation identifier used to
                persist history. Defaults to None.
            correlation_id (str | None, optional): Optionally provide a
                correlation ID for this request. Defaults to None
                (a UUIDv7 is generated).

        Returns:
            str: The correlation ID for this request.
        """
        patch_model_request_params = ModelRequestParameters(
            function_tools=[tool.tool_schema() for tool in self.tools]
        )
        if correlation_id is None:
            correlation_id = uuid_utils.uuid7().hex
        new_node_messages = [ModelRequest.user_text_prompt(user_prompt)]
        await broker.publish(
            EventEnvelope(
                kind="user_prompt",
                trace_id=correlation_id,
                patch_model_request_params=patch_model_request_params,
                thread_id=thread_id,
                incoming_node_messages=new_node_messages,
                system_message=self.system_message,
                final_response_topic=final_response_topic,
            ),
            topic=self.subscribed_topic or "",
            correlation_id=correlation_id,
        )
        return correlation_id
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from functools import cached_property
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def subscribe_to(topic_name: str) -> Callable[[Any], Any]:
    """Mark a handler as subscribing to *topic_name*.

    Stores the topic on the function as ``_subscribe_to_topic_name`` so that
    ``BaseNode.__init_subclass__`` can register the handler.
    """

    def mark(handler: Any) -> Any:
        handler._subscribe_to_topic_name = topic_name
        return handler

    return mark
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def publish_to(topic_name: str) -> Callable[[Any], Any]:
    """Mark a handler as publishing to *topic_name*.

    Stores the topic on the function as ``_publish_to_topic_name`` so that
    ``BaseNode.__init_subclass__`` can register the handler.
    """

    def mark(handler: Any) -> Any:
        handler._publish_to_topic_name = topic_name
        return handler

    return mark
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class BaseNode(ABC):
|
|
24
|
+
"""Effectively a node is the data plane, defining the internal wiring and logic.
|
|
25
|
+
When provided to a NodeRunner, node logic can be deployed."""
|
|
26
|
+
|
|
27
|
+
_handler_registry: dict[Callable[..., Any], dict[str, str]] = {}
|
|
28
|
+
|
|
29
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
30
|
+
self.bound_registry: dict[Callable[..., Any], dict[str, str]] = {
|
|
31
|
+
fn.__get__(self, type(self)): topics_dict
|
|
32
|
+
for fn, topics_dict in self._handler_registry.items()
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
def __init_subclass__(cls) -> None:
|
|
36
|
+
super().__init_subclass__()
|
|
37
|
+
|
|
38
|
+
cls._handler_registry = {}
|
|
39
|
+
|
|
40
|
+
for attr in cls.__dict__.values():
|
|
41
|
+
publish_to_topic_name = getattr(attr, "_publish_to_topic_name", None)
|
|
42
|
+
subscribe_to_topic_name = getattr(attr, "_subscribe_to_topic_name", None)
|
|
43
|
+
if publish_to_topic_name:
|
|
44
|
+
cls._handler_registry[attr] = {"publish_topic": publish_to_topic_name}
|
|
45
|
+
if subscribe_to_topic_name:
|
|
46
|
+
cls._handler_registry[attr] = cls._handler_registry.get(attr, {}) | {
|
|
47
|
+
"subscribe_topic": subscribe_to_topic_name
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
@cached_property
|
|
51
|
+
def subscribed_topic(self) -> str | None:
|
|
52
|
+
for topics_dict in self._handler_registry.values():
|
|
53
|
+
if "subscribe_topic" in topics_dict:
|
|
54
|
+
return topics_dict["subscribe_topic"]
|
|
55
|
+
return None
|
|
56
|
+
|
|
57
|
+
@cached_property
|
|
58
|
+
def publish_to_topic(self) -> str | None:
|
|
59
|
+
for topics_dict in self._handler_registry.values():
|
|
60
|
+
if "publish_topic" in topics_dict:
|
|
61
|
+
return topics_dict["publish_topic"]
|
|
62
|
+
return None
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from collections.abc import Awaitable, Callable
|
|
4
|
+
from typing import Any, cast
|
|
5
|
+
|
|
6
|
+
from pydantic_ai import ModelRequest, Tool, ToolDefinition, ToolReturnPart
|
|
7
|
+
|
|
8
|
+
from calfkit.models.event_envelope import EventEnvelope
|
|
9
|
+
from calfkit.nodes.base_node import BaseNode, publish_to, subscribe_to
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseToolNode(BaseNode, ABC):
    """A node that exposes a single agent tool.

    Subclasses must describe the tool they implement via ``tool_schema`` so
    routers can advertise it to the model when requesting a turn.
    """

    @classmethod
    @abstractmethod
    def tool_schema(cls) -> ToolDefinition: ...
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def agent_tool(func: Callable[..., Any] | Callable[..., Awaitable[Any]]) -> BaseToolNode:
    """Agent tool decorator to turn a function into a deployable node"""
    # Wrap the function so pydantic_ai derives its schema for us.
    tool = Tool(func)

    class ToolNode(BaseToolNode):
        @subscribe_to(f"tool_node.{func.__name__}.request")
        @publish_to(f"tool_node.{func.__name__}.result")
        async def on_enter(self, event_envelope: EventEnvelope) -> EventEnvelope:
            """Execute the wrapped function and wrap its result for the router."""
            call = event_envelope.tool_call_request
            if not call:
                raise RuntimeError("No tool call request found")
            outcome = func(**call.args_as_dict())
            # Support both sync and async tool functions.
            if inspect.isawaitable(outcome):
                outcome = await outcome
            return_part = ToolReturnPart(
                tool_name=call.tool_name,
                content=outcome,
                tool_call_id=call.tool_call_id,
            )
            return event_envelope.model_copy(
                update={
                    "kind": "tool_result",
                    "incoming_node_messages": [ModelRequest(parts=[return_part])],
                }
            )

        @classmethod
        def tool_schema(cls) -> ToolDefinition:
            return cast(ToolDefinition, tool.tool_def)

    # Make the generated node introspect like the wrapped function.
    ToolNode.__name__ = func.__name__
    ToolNode.__qualname__ = func.__qualname__
    ToolNode.__doc__ = func.__doc__
    ToolNode.__module__ = func.__module__

    return ToolNode()
|