pidevkit-agent 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,15 @@
1
+ __pycache__/
2
+ .pytest_cache/
3
+ .ruff_cache/
4
+ .venv/
5
+ *.pyc
6
+ *.pyo
7
+ *.pyd
8
+ .coverage
9
+ .coverage.*
10
+ htmlcov/
11
+ dist/
12
+ build/
13
+ *.egg-info/
14
+ .mypy_cache/
15
+ .DS_Store
@@ -0,0 +1,22 @@
1
+ Metadata-Version: 2.4
2
+ Name: pidevkit-agent
3
+ Version: 0.1.0
4
+ Summary: Agent runtime primitives for pidevkit
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+
8
+ # pidevkit-agent
9
+
10
+ Agent orchestration/runtime package for `pidevkit`.
11
+
12
+ ## Install
13
+
14
+ ```bash
15
+ pip install pidevkit-agent
16
+ ```
17
+
18
+ ## Import
19
+
20
+ ```python
21
+ from pidevkit.agent import Agent
22
+ ```
@@ -0,0 +1,15 @@
1
+ # pidevkit-agent
2
+
3
+ Agent orchestration/runtime package for `pidevkit`.
4
+
5
+ ## Install
6
+
7
+ ```bash
8
+ pip install pidevkit-agent
9
+ ```
10
+
11
+ ## Import
12
+
13
+ ```python
14
+ from pidevkit.agent import Agent
15
+ ```
@@ -0,0 +1,14 @@
1
+ [build-system]
2
+ requires = ["hatchling>=1.24.0"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "pidevkit-agent"
7
+ version = "0.1.0"
8
+ description = "Agent runtime primitives for pidevkit"
9
+ readme = "README.md"
10
+ requires-python = ">=3.11"
11
+ dependencies = []
12
+
13
[tool.hatch.build.targets.wheel]
# `packages = ["src/pidevkit/agent"]` would collapse the shipped path to its last
# component and install a top-level `agent` package. Ship the directory as-is and
# strip only the `src/` prefix so the wheel contains `pidevkit/agent/...` and
# `from pidevkit.agent import Agent` works after installation.
only-include = ["src/pidevkit/agent"]
sources = ["src"]
@@ -0,0 +1,19 @@
1
+ from .agent import Agent
2
+ from .agent_loop import (
3
+ agent_loop,
4
+ agent_loop_continue,
5
+ agentLoop,
6
+ agentLoopContinue,
7
+ validate_tool_arguments,
8
+ )
9
+ from .types import AgentState
10
+
11
# Public API of ``pidevkit.agent``: the stateful ``Agent`` class, its state
# container, the loop entry points (exported under both snake_case and
# camelCase names) and the tool-argument validator.
__all__ = [
    "Agent",
    "AgentState",
    "agent_loop",
    "agent_loop_continue",
    "agentLoop",
    "agentLoopContinue",
    "validate_tool_arguments",
]
@@ -0,0 +1,418 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import copy
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable, Mapping
7
+
8
+ from .agent_loop import agent_loop, agent_loop_continue
9
+ from .types import (
10
+ AgentMessage,
11
+ AgentOptions,
12
+ AgentState,
13
+ DEFAULT_MODEL,
14
+ Message,
15
+ Model,
16
+ ThinkingBudgets,
17
+ ThinkingLevel,
18
+ Transport,
19
+ clone_message,
20
+ clone_messages,
21
+ get_value,
22
+ message_content,
23
+ message_role,
24
+ now_ms,
25
+ )
26
+
27
+
28
+ @dataclass(slots=True)
29
+ class AbortSignal:
30
+ aborted: bool = False
31
+
32
+
33
class AbortController:
    """Owns a single :class:`AbortSignal` and trips it on :meth:`abort`."""

    def __init__(self) -> None:
        # Each controller gets its own, initially un-aborted, signal.
        fresh_signal = AbortSignal()
        self.signal = fresh_signal

    def abort(self) -> None:
        """Mark the owned signal as aborted; calling it again is harmless."""
        owned = self.signal
        owned.aborted = True
39
+
40
+
41
class Agent:
    """Stateful front-end for ``agent_loop`` / ``agent_loop_continue``.

    The agent owns an :class:`AgentState` (system prompt, model, tools,
    message history, streaming flags), FIFO queues of steering and follow-up
    messages, and a list of event listeners. Every loop event is applied to
    the state and then forwarded to the listeners.

    Options may be passed as one positional mapping and/or keyword arguments
    (keyword arguments override mapping entries). Each option is looked up
    via ``get_value`` under both its snake_case and camelCase key, e.g.
    ``stream_fn`` / ``streamFn``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create an agent from an optional options mapping and keyword options.

        Raises:
            TypeError: if more than one positional argument is given, or the
                positional argument is not a mapping.
        """
        options: dict[str, Any] = {}
        if args:
            if len(args) > 1:
                raise TypeError("Agent accepts at most one positional options argument")
            if not isinstance(args[0], Mapping):
                raise TypeError("Agent positional argument must be a mapping")
            options.update(dict(args[0]))
        # Keyword arguments win over entries from the positional mapping.
        options.update(kwargs)

        self._listeners: list[Callable[[dict[str, Any]], None]] = []
        self._steering_queue: list[AgentMessage] = []
        self._follow_up_queue: list[AgentMessage] = []
        # Only non-None while a run is in progress (set/cleared by _run_loop).
        self._abort_controller: AbortController | None = None

        self._stream_fn = get_value(options, "stream_fn", "streamFn", default=None)
        self._convert_to_llm = get_value(
            options, "convert_to_llm", "convertToLlm", default=self._default_convert_to_llm
        )
        self._transform_context = get_value(options, "transform_context", "transformContext", default=None)
        self._get_api_key = get_value(options, "get_api_key", "getApiKey", default=None)

        # "all" drains a queue at once; any other value dequeues one message per poll.
        self._steering_mode = get_value(options, "steering_mode", "steeringMode", default="one-at-a-time")
        self._follow_up_mode = get_value(options, "follow_up_mode", "followUpMode", default="one-at-a-time")

        self.thinking_budgets: ThinkingBudgets = get_value(
            options,
            "thinking_budgets",
            "thinkingBudgets",
            default={"minimal": 0, "low": 0, "medium": 0, "high": 0},
        )
        self.transport: Transport = get_value(options, "transport", default="sse")
        # A falsy value (None, 0) falls back to 1000 ms.
        self.max_retry_delay_ms: int = int(
            get_value(options, "max_retry_delay_ms", "maxRetryDelayMs", default=1000) or 1000
        )

        self._session_id: str | None = get_value(options, "session_id", "sessionId", default=None)

        initial_state = get_value(options, "initial_state", "initialState", default=None)
        self._state = self._build_initial_state(initial_state)

    # ------------------------------------------------------------------ state

    @property
    def state(self) -> AgentState:
        """The live agent state (not a copy)."""
        return self._state

    def subscribe(self, listener: Callable[[dict[str, Any]], None]) -> Callable[[], None]:
        """Register *listener* for loop events and return a function that removes it."""
        self._listeners.append(listener)

        def unsubscribe() -> None:
            if listener in self._listeners:
                self._listeners.remove(listener)

        return unsubscribe

    def _emit(self, event: dict[str, Any]) -> None:
        """Deliver *event* to every registered listener.

        Iterates over a snapshot so listeners may (un)subscribe during
        delivery. A listener that raises is logged and skipped, so it can
        neither abort the run nor prevent later listeners from being called.
        """
        for listener in list(self._listeners):
            try:
                listener(event)
            except Exception:
                import logging

                logging.getLogger(__name__).exception("Agent event listener raised; ignoring")

    def set_system_prompt(self, system_prompt: str) -> None:
        """Replace the system prompt used by subsequent runs."""
        self._state.system_prompt = system_prompt

    def setSystemPrompt(self, system_prompt: str) -> None:
        """camelCase alias of :meth:`set_system_prompt`."""
        self.set_system_prompt(system_prompt)

    def set_model(self, model: Model) -> None:
        """Store a deep copy of *model* so later caller-side mutation does not leak in."""
        self._state.model = copy.deepcopy(model)

    def setModel(self, model: Model) -> None:
        """camelCase alias of :meth:`set_model`."""
        self.set_model(model)

    def set_thinking_level(self, level: ThinkingLevel) -> None:
        """Set the thinking level forwarded to the loop as ``reasoning``."""
        self._state.thinking_level = level

    def setThinkingLevel(self, level: ThinkingLevel) -> None:
        """camelCase alias of :meth:`set_thinking_level`."""
        self.set_thinking_level(level)

    def set_tools(self, tools: list[Any]) -> None:
        """Replace the tool list with a shallow copy of *tools*."""
        self._state.tools = list(tools)

    def setTools(self, tools: list[Any]) -> None:
        """camelCase alias of :meth:`set_tools`."""
        self.set_tools(tools)

    def replace_messages(self, messages: list[AgentMessage]) -> None:
        """Replace the message history with copies (via ``clone_messages``)."""
        self._state.messages = clone_messages(messages)

    def replaceMessages(self, messages: list[AgentMessage]) -> None:
        """camelCase alias of :meth:`replace_messages`."""
        self.replace_messages(messages)

    def append_message(self, message: AgentMessage) -> None:
        """Append a copy (via ``clone_message``) of *message* to the history."""
        self._state.messages.append(clone_message(message))

    def appendMessage(self, message: AgentMessage) -> None:
        """camelCase alias of :meth:`append_message`."""
        self.append_message(message)

    def clear_messages(self) -> None:
        """Drop the entire message history."""
        self._state.messages = []

    def clearMessages(self) -> None:
        """camelCase alias of :meth:`clear_messages`."""
        self.clear_messages()

    # --------------------------------------------------------------- queueing

    def steer(self, message: AgentMessage) -> None:
        """Queue a copy of *message* on the steering queue."""
        self._steering_queue.append(clone_message(message))

    def follow_up(self, message: AgentMessage) -> None:
        """Queue a copy of *message* on the follow-up queue."""
        self._follow_up_queue.append(clone_message(message))

    def followUp(self, message: AgentMessage) -> None:
        """camelCase alias of :meth:`follow_up`."""
        self.follow_up(message)

    def abort(self) -> None:
        """Signal the in-progress run (if any) to abort; no-op when idle."""
        if self._abort_controller is not None:
            self._abort_controller.abort()

    @property
    def session_id(self) -> str | None:
        """Session id passed to the loop config (``sessionId`` / ``session_id``)."""
        return self._session_id

    @session_id.setter
    def session_id(self, value: str | None) -> None:
        self._session_id = value

    @property
    def sessionId(self) -> str | None:
        """camelCase alias of :attr:`session_id`."""
        return self._session_id

    @sessionId.setter
    def sessionId(self, value: str | None) -> None:
        self._session_id = value

    # ------------------------------------------------------------------- runs

    async def prompt(self, input_value: str | AgentMessage, images: list[dict[str, Any]] | None = None) -> list[AgentMessage]:
        """Start a run with a user message built from *input_value* (and *images*).

        Returns:
            The new messages produced by the run.

        Raises:
            RuntimeError: if a run is already in progress.
            TypeError: if *input_value* is neither a string nor a message mapping.
        """
        self._ensure_not_streaming()
        user_message = self._make_user_message(input_value, images=images)
        return await self._run_loop(prompts=[user_message], is_continue=False)

    async def continue_(self) -> list[AgentMessage]:
        """Resume the conversation.

        Queued steering messages take priority, then queued follow-ups; each
        starts a new run with those messages as prompts. With both queues
        empty, the loop continues from the existing context.

        Raises:
            RuntimeError: if a run is already in progress.
        """
        self._ensure_not_streaming()

        queued_steering = self._dequeue(self._steering_queue, self._steering_mode)
        if queued_steering:
            return await self._run_loop(prompts=queued_steering, is_continue=False)

        queued_follow_up = self._dequeue(self._follow_up_queue, self._follow_up_mode)
        if queued_follow_up:
            return await self._run_loop(prompts=queued_follow_up, is_continue=False)

        return await self._run_loop(prompts=[], is_continue=True)

    async def resume(self) -> list[AgentMessage]:
        """Alias of :meth:`continue_`."""
        return await self.continue_()

    async def run_continue(self) -> list[AgentMessage]:
        """Alias of :meth:`continue_`."""
        return await self.continue_()

    async def wait_for_idle(self) -> None:
        """Poll every 5 ms until no run is in progress."""
        while self._state.is_streaming:
            await asyncio.sleep(0.005)

    async def _run_loop(self, prompts: list[AgentMessage], *, is_continue: bool) -> list[AgentMessage]:
        """Drive one loop run and append the messages it returns to the history.

        Args:
            prompts: Messages that start the run (unused when *is_continue*).
            is_continue: Use ``agent_loop_continue`` on the existing context
                instead of ``agent_loop`` with *prompts*.

        Returns:
            The messages returned by ``stream.result()``.

        Raises:
            Whatever the loop raises; its text is stored in ``state.error``
            first. Streaming flags are reset in every case.
        """
        self._state.is_streaming = True
        self._state.error = None
        self._state.stream_message = None
        self._state.pending_tool_calls.clear()

        self._abort_controller = AbortController()
        # When the run starts with explicit prompts, the first steering poll
        # returns [] and only later polls drain the steering queue.
        skip_initial_steering_poll = len(prompts) > 0
        steering_polled = False

        async def get_steering_messages() -> list[AgentMessage]:
            nonlocal steering_polled
            if skip_initial_steering_poll and not steering_polled:
                steering_polled = True
                return []
            steering_polled = True
            return self._dequeue(self._steering_queue, self._steering_mode)

        async def get_follow_up_messages() -> list[AgentMessage]:
            return self._dequeue(self._follow_up_queue, self._follow_up_mode)

        # The loop works on copies; state.messages is only extended on success.
        context: dict[str, Any] = {
            "systemPrompt": self._state.system_prompt,
            "messages": clone_messages(self._state.messages),
            "tools": list(self._state.tools),
        }
        config = self._build_loop_config(
            get_steering_messages=get_steering_messages,
            get_follow_up_messages=get_follow_up_messages,
        )

        try:
            if is_continue:
                stream = agent_loop_continue(
                    context,
                    config,
                    signal=self._abort_controller.signal,
                    stream_fn=self._stream_fn,
                )
            else:
                stream = agent_loop(
                    prompts,
                    context,
                    config,
                    signal=self._abort_controller.signal,
                    stream_fn=self._stream_fn,
                )

            async for event in stream:
                self._handle_stream_event(event)
                self._emit(event)

            new_messages = await stream.result()
            if not isinstance(new_messages, list):
                new_messages = list(new_messages)

            self._state.messages.extend(clone_messages(new_messages))
            return new_messages
        except Exception as exc:
            self._state.error = str(exc)
            raise
        finally:
            self._state.is_streaming = False
            self._state.stream_message = None
            self._state.pending_tool_calls.clear()
            self._abort_controller = None

    def _handle_stream_event(self, event: Mapping[str, Any]) -> None:
        """Mirror a loop event into ``stream_message`` / ``pending_tool_calls``."""
        event_type = str(get_value(event, "type", default=""))

        if event_type in {"message_start", "message_update"}:
            message = get_value(event, "message", default=None)
            if message_role(message) == "assistant":
                self._state.stream_message = clone_message(message)
                self._sync_pending_tool_calls(message)
        elif event_type == "message_end":
            message = get_value(event, "message", default=None)
            if message_role(message) == "assistant":
                self._state.stream_message = None
        elif event_type == "tool_execution_start":
            tool_call_id = get_value(event, "toolCallId", "tool_call_id", default=None)
            if isinstance(tool_call_id, str):
                self._state.pending_tool_calls.add(tool_call_id)
        elif event_type == "tool_execution_end":
            tool_call_id = get_value(event, "toolCallId", "tool_call_id", default=None)
            if isinstance(tool_call_id, str):
                self._state.pending_tool_calls.discard(tool_call_id)

    def _sync_pending_tool_calls(self, message: AgentMessage) -> None:
        """Add the ids of ``toolCall`` content blocks in *message* to the pending set."""
        content = message_content(message)
        if not isinstance(content, list):
            return

        for block in content:
            if isinstance(block, Mapping) and block.get("type") == "toolCall":
                tool_call_id = block.get("id")
                if isinstance(tool_call_id, str):
                    self._state.pending_tool_calls.add(tool_call_id)

    def _build_loop_config(
        self,
        *,
        get_steering_messages: Callable[[], Any] | None = None,
        get_follow_up_messages: Callable[[], Any] | None = None,
    ) -> dict[str, Any]:
        """Assemble the camelCase config mapping passed to the loop functions.

        Queue pollers default to the agent's own dequeuing coroutines. Optional
        hooks are included only when callable; the session id, when set, is
        passed under both ``sessionId`` and ``session_id``.
        """
        config: dict[str, Any] = {
            "model": self._state.model,
            "convertToLlm": self._convert_to_llm,
            "getSteeringMessages": get_steering_messages or self._get_steering_messages,
            "getFollowUpMessages": get_follow_up_messages or self._get_follow_up_messages,
            "reasoning": self._state.thinking_level,
            "transport": self.transport,
            "thinkingBudgets": self.thinking_budgets,
            "maxRetryDelayMs": self.max_retry_delay_ms,
        }

        if callable(self._transform_context):
            config["transformContext"] = self._transform_context

        if callable(self._get_api_key):
            config["getApiKey"] = self._get_api_key

        if self._session_id is not None:
            config["sessionId"] = self._session_id
            config["session_id"] = self._session_id

        return config

    async def _get_steering_messages(self) -> list[AgentMessage]:
        """Dequeue steering messages according to the steering mode."""
        return self._dequeue(self._steering_queue, self._steering_mode)

    async def _get_follow_up_messages(self) -> list[AgentMessage]:
        """Dequeue follow-up messages according to the follow-up mode."""
        return self._dequeue(self._follow_up_queue, self._follow_up_mode)

    @staticmethod
    def _dequeue(queue: list[AgentMessage], mode: str) -> list[AgentMessage]:
        """Remove and return queued messages (mutates *queue*).

        ``mode == "all"`` drains the whole queue; any other mode pops only the
        oldest message. An empty queue yields ``[]``.
        """
        if not queue:
            return []

        if mode == "all":
            values = queue[:]
            queue.clear()
            return values

        return [queue.pop(0)]

    @staticmethod
    def _default_convert_to_llm(messages: list[AgentMessage]) -> list[Message]:
        """Keep deep copies of ``user``/``assistant``/``toolResult`` messages; drop the rest."""
        converted: list[Message] = []
        for message in messages:
            role = message_role(message)
            if role in {"user", "assistant", "toolResult"}:
                converted.append(copy.deepcopy(message))
        return converted

    def _build_initial_state(self, initial_state: Mapping[str, Any] | AgentState | None) -> AgentState:
        """Build the agent's starting state from *initial_state*.

        An ``AgentState`` is deep-copied, a mapping is read field by field
        (snake_case or camelCase keys), and ``None`` yields ``AgentState()``.
        ``is_streaming`` is always reset to ``False``: no run is active on a
        freshly built agent, and a restored ``True`` would make ``prompt()``
        raise forever and ``wait_for_idle()`` never return.
        """
        if isinstance(initial_state, AgentState):
            state = copy.deepcopy(initial_state)
        elif initial_state is None:
            state = AgentState()
        else:
            state = self._state_from_mapping(initial_state)
        state.is_streaming = False
        return state

    @staticmethod
    def _state_from_mapping(initial_state: Mapping[str, Any]) -> AgentState:
        """Populate a fresh ``AgentState`` from a mapping snapshot (``is_streaming`` is not read)."""
        state = AgentState()
        state.system_prompt = str(get_value(initial_state, "system_prompt", "systemPrompt", default="") or "")

        model_value = get_value(initial_state, "model", default=None)
        # Only mapping-shaped models are taken over; anything else falls back to DEFAULT_MODEL.
        if isinstance(model_value, Mapping):
            state.model = copy.deepcopy(dict(model_value))
        else:
            state.model = copy.deepcopy(DEFAULT_MODEL)

        state.thinking_level = get_value(initial_state, "thinking_level", "thinkingLevel", default="off")

        tools_value = get_value(initial_state, "tools", default=[])
        state.tools = list(tools_value) if isinstance(tools_value, list) else []

        messages_value = get_value(initial_state, "messages", default=[])
        state.messages = clone_messages(messages_value if isinstance(messages_value, list) else [])

        state.stream_message = clone_message(get_value(initial_state, "stream_message", "streamMessage", default=None))

        pending_value = get_value(initial_state, "pending_tool_calls", "pendingToolCalls", default=None)
        if isinstance(pending_value, (set, frozenset, list, tuple)):
            state.pending_tool_calls = {str(item) for item in pending_value}
        else:
            state.pending_tool_calls = set()

        state.error = get_value(initial_state, "error", default=None)
        return state

    def _ensure_not_streaming(self) -> None:
        """Raise ``RuntimeError`` if a run is currently in progress."""
        if self._state.is_streaming:
            raise RuntimeError("Agent is already processing a request")

    @staticmethod
    def _make_user_message(input_value: str | AgentMessage, images: list[dict[str, Any]] | None = None) -> AgentMessage:
        """Turn ``prompt()`` input into a message.

        A mapping with a role is copied and stamped with ``now_ms()`` if it has
        no timestamp. A string becomes a ``user`` message; with *images*, its
        content is a text block followed by the image blocks.

        Raises:
            TypeError: for any other input.
        """
        if isinstance(input_value, Mapping) and message_role(input_value):
            message = clone_message(input_value)
            if get_value(message, "timestamp", default=None) is None:
                message["timestamp"] = now_ms()
            return message

        if isinstance(input_value, str):
            if images:
                content: list[dict[str, Any]] = [{"type": "text", "text": input_value}, *images]
            else:
                content = input_value
            return {"role": "user", "content": content, "timestamp": now_ms()}

        raise TypeError("prompt() expects a string or user message mapping")
413
+
414
+
415
+ def _resolve_options_map(options: AgentOptions | Mapping[str, Any]) -> Mapping[str, Any]:
416
+ if isinstance(options, Mapping):
417
+ return options
418
+ return dict(options)