@elizaos/python 2.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +26 -0
- package/README.md +239 -0
- package/elizaos/__init__.py +280 -0
- package/elizaos/action_docs.py +149 -0
- package/elizaos/advanced_capabilities/__init__.py +85 -0
- package/elizaos/advanced_capabilities/actions/__init__.py +54 -0
- package/elizaos/advanced_capabilities/actions/add_contact.py +139 -0
- package/elizaos/advanced_capabilities/actions/follow_room.py +151 -0
- package/elizaos/advanced_capabilities/actions/image_generation.py +148 -0
- package/elizaos/advanced_capabilities/actions/mute_room.py +164 -0
- package/elizaos/advanced_capabilities/actions/remove_contact.py +145 -0
- package/elizaos/advanced_capabilities/actions/roles.py +207 -0
- package/elizaos/advanced_capabilities/actions/schedule_follow_up.py +154 -0
- package/elizaos/advanced_capabilities/actions/search_contacts.py +145 -0
- package/elizaos/advanced_capabilities/actions/send_message.py +187 -0
- package/elizaos/advanced_capabilities/actions/settings.py +151 -0
- package/elizaos/advanced_capabilities/actions/unfollow_room.py +164 -0
- package/elizaos/advanced_capabilities/actions/unmute_room.py +164 -0
- package/elizaos/advanced_capabilities/actions/update_contact.py +164 -0
- package/elizaos/advanced_capabilities/actions/update_entity.py +161 -0
- package/elizaos/advanced_capabilities/evaluators/__init__.py +18 -0
- package/elizaos/advanced_capabilities/evaluators/reflection.py +134 -0
- package/elizaos/advanced_capabilities/evaluators/relationship_extraction.py +203 -0
- package/elizaos/advanced_capabilities/providers/__init__.py +36 -0
- package/elizaos/advanced_capabilities/providers/agent_settings.py +60 -0
- package/elizaos/advanced_capabilities/providers/contacts.py +77 -0
- package/elizaos/advanced_capabilities/providers/facts.py +82 -0
- package/elizaos/advanced_capabilities/providers/follow_ups.py +113 -0
- package/elizaos/advanced_capabilities/providers/knowledge.py +83 -0
- package/elizaos/advanced_capabilities/providers/relationships.py +112 -0
- package/elizaos/advanced_capabilities/providers/roles.py +97 -0
- package/elizaos/advanced_capabilities/providers/settings.py +51 -0
- package/elizaos/advanced_capabilities/services/__init__.py +18 -0
- package/elizaos/advanced_capabilities/services/follow_up.py +138 -0
- package/elizaos/advanced_capabilities/services/rolodex.py +244 -0
- package/elizaos/advanced_memory/__init__.py +3 -0
- package/elizaos/advanced_memory/evaluators.py +97 -0
- package/elizaos/advanced_memory/memory_service.py +556 -0
- package/elizaos/advanced_memory/plugin.py +30 -0
- package/elizaos/advanced_memory/prompts.py +12 -0
- package/elizaos/advanced_memory/providers.py +90 -0
- package/elizaos/advanced_memory/types.py +65 -0
- package/elizaos/advanced_planning/__init__.py +10 -0
- package/elizaos/advanced_planning/actions.py +145 -0
- package/elizaos/advanced_planning/message_classifier.py +127 -0
- package/elizaos/advanced_planning/planning_service.py +712 -0
- package/elizaos/advanced_planning/plugin.py +40 -0
- package/elizaos/advanced_planning/prompts.py +4 -0
- package/elizaos/basic_capabilities/__init__.py +66 -0
- package/elizaos/basic_capabilities/actions/__init__.py +24 -0
- package/elizaos/basic_capabilities/actions/choice.py +140 -0
- package/elizaos/basic_capabilities/actions/ignore.py +66 -0
- package/elizaos/basic_capabilities/actions/none.py +56 -0
- package/elizaos/basic_capabilities/actions/reply.py +120 -0
- package/elizaos/basic_capabilities/providers/__init__.py +54 -0
- package/elizaos/basic_capabilities/providers/action_state.py +113 -0
- package/elizaos/basic_capabilities/providers/actions.py +263 -0
- package/elizaos/basic_capabilities/providers/attachments.py +76 -0
- package/elizaos/basic_capabilities/providers/capabilities.py +62 -0
- package/elizaos/basic_capabilities/providers/character.py +113 -0
- package/elizaos/basic_capabilities/providers/choice.py +73 -0
- package/elizaos/basic_capabilities/providers/context_bench.py +44 -0
- package/elizaos/basic_capabilities/providers/current_time.py +58 -0
- package/elizaos/basic_capabilities/providers/entities.py +99 -0
- package/elizaos/basic_capabilities/providers/evaluators.py +54 -0
- package/elizaos/basic_capabilities/providers/providers_list.py +55 -0
- package/elizaos/basic_capabilities/providers/recent_messages.py +85 -0
- package/elizaos/basic_capabilities/providers/time.py +45 -0
- package/elizaos/basic_capabilities/providers/world.py +93 -0
- package/elizaos/basic_capabilities/services/__init__.py +18 -0
- package/elizaos/basic_capabilities/services/embedding.py +122 -0
- package/elizaos/basic_capabilities/services/task.py +178 -0
- package/elizaos/bootstrap/__init__.py +12 -0
- package/elizaos/bootstrap/actions/__init__.py +68 -0
- package/elizaos/bootstrap/actions/add_contact.py +149 -0
- package/elizaos/bootstrap/actions/choice.py +147 -0
- package/elizaos/bootstrap/actions/follow_room.py +151 -0
- package/elizaos/bootstrap/actions/ignore.py +80 -0
- package/elizaos/bootstrap/actions/image_generation.py +135 -0
- package/elizaos/bootstrap/actions/mute_room.py +151 -0
- package/elizaos/bootstrap/actions/none.py +71 -0
- package/elizaos/bootstrap/actions/remove_contact.py +159 -0
- package/elizaos/bootstrap/actions/reply.py +140 -0
- package/elizaos/bootstrap/actions/roles.py +193 -0
- package/elizaos/bootstrap/actions/schedule_follow_up.py +164 -0
- package/elizaos/bootstrap/actions/search_contacts.py +159 -0
- package/elizaos/bootstrap/actions/send_message.py +173 -0
- package/elizaos/bootstrap/actions/settings.py +165 -0
- package/elizaos/bootstrap/actions/unfollow_room.py +151 -0
- package/elizaos/bootstrap/actions/unmute_room.py +151 -0
- package/elizaos/bootstrap/actions/update_contact.py +178 -0
- package/elizaos/bootstrap/actions/update_entity.py +175 -0
- package/elizaos/bootstrap/autonomy/__init__.py +18 -0
- package/elizaos/bootstrap/autonomy/action.py +197 -0
- package/elizaos/bootstrap/autonomy/providers.py +165 -0
- package/elizaos/bootstrap/autonomy/routes.py +171 -0
- package/elizaos/bootstrap/autonomy/service.py +562 -0
- package/elizaos/bootstrap/autonomy/types.py +18 -0
- package/elizaos/bootstrap/evaluators/__init__.py +19 -0
- package/elizaos/bootstrap/evaluators/reflection.py +118 -0
- package/elizaos/bootstrap/evaluators/relationship_extraction.py +192 -0
- package/elizaos/bootstrap/plugin.py +140 -0
- package/elizaos/bootstrap/providers/__init__.py +80 -0
- package/elizaos/bootstrap/providers/action_state.py +71 -0
- package/elizaos/bootstrap/providers/actions.py +256 -0
- package/elizaos/bootstrap/providers/agent_settings.py +63 -0
- package/elizaos/bootstrap/providers/attachments.py +76 -0
- package/elizaos/bootstrap/providers/capabilities.py +66 -0
- package/elizaos/bootstrap/providers/character.py +128 -0
- package/elizaos/bootstrap/providers/choice.py +77 -0
- package/elizaos/bootstrap/providers/contacts.py +78 -0
- package/elizaos/bootstrap/providers/context_bench.py +49 -0
- package/elizaos/bootstrap/providers/current_time.py +56 -0
- package/elizaos/bootstrap/providers/entities.py +99 -0
- package/elizaos/bootstrap/providers/evaluators.py +58 -0
- package/elizaos/bootstrap/providers/facts.py +86 -0
- package/elizaos/bootstrap/providers/follow_ups.py +116 -0
- package/elizaos/bootstrap/providers/knowledge.py +73 -0
- package/elizaos/bootstrap/providers/providers_list.py +59 -0
- package/elizaos/bootstrap/providers/recent_messages.py +85 -0
- package/elizaos/bootstrap/providers/relationships.py +106 -0
- package/elizaos/bootstrap/providers/roles.py +95 -0
- package/elizaos/bootstrap/providers/settings.py +55 -0
- package/elizaos/bootstrap/providers/time.py +45 -0
- package/elizaos/bootstrap/providers/world.py +97 -0
- package/elizaos/bootstrap/services/__init__.py +26 -0
- package/elizaos/bootstrap/services/embedding.py +122 -0
- package/elizaos/bootstrap/services/follow_up.py +138 -0
- package/elizaos/bootstrap/services/rolodex.py +244 -0
- package/elizaos/bootstrap/services/task.py +585 -0
- package/elizaos/bootstrap/types.py +54 -0
- package/elizaos/bootstrap/utils/__init__.py +7 -0
- package/elizaos/bootstrap/utils/xml.py +69 -0
- package/elizaos/character.py +149 -0
- package/elizaos/logger.py +179 -0
- package/elizaos/media/__init__.py +45 -0
- package/elizaos/media/mime.py +315 -0
- package/elizaos/media/search.py +161 -0
- package/elizaos/media/tests/__init__.py +1 -0
- package/elizaos/media/tests/test_mime.py +117 -0
- package/elizaos/media/tests/test_search.py +156 -0
- package/elizaos/plugin.py +191 -0
- package/elizaos/prompts.py +1071 -0
- package/elizaos/py.typed +0 -0
- package/elizaos/runtime.py +2572 -0
- package/elizaos/services/__init__.py +49 -0
- package/elizaos/services/hook_service.py +511 -0
- package/elizaos/services/message_service.py +1248 -0
- package/elizaos/settings.py +182 -0
- package/elizaos/streaming_context.py +159 -0
- package/elizaos/trajectory_context.py +18 -0
- package/elizaos/types/__init__.py +512 -0
- package/elizaos/types/agent.py +31 -0
- package/elizaos/types/components.py +208 -0
- package/elizaos/types/database.py +64 -0
- package/elizaos/types/environment.py +46 -0
- package/elizaos/types/events.py +47 -0
- package/elizaos/types/memory.py +45 -0
- package/elizaos/types/model.py +393 -0
- package/elizaos/types/plugin.py +188 -0
- package/elizaos/types/primitives.py +100 -0
- package/elizaos/types/runtime.py +460 -0
- package/elizaos/types/service.py +113 -0
- package/elizaos/types/service_interfaces.py +244 -0
- package/elizaos/types/state.py +188 -0
- package/elizaos/types/task.py +29 -0
- package/elizaos/utils/__init__.py +108 -0
- package/elizaos/utils/spec_examples.py +48 -0
- package/elizaos/utils/streaming.py +426 -0
- package/elizaos_atropos_shared/__init__.py +1 -0
- package/elizaos_atropos_shared/canonical_eliza.py +282 -0
- package/package.json +19 -0
- package/pyproject.toml +143 -0
- package/requirements-dev.in +11 -0
- package/requirements-dev.lock +134 -0
- package/requirements.in +9 -0
- package/requirements.lock +64 -0
- package/tests/__init__.py +0 -0
- package/tests/test_action_parameters.py +154 -0
- package/tests/test_actions_provider_examples.py +39 -0
- package/tests/test_advanced_memory_behavior.py +96 -0
- package/tests/test_advanced_memory_flag.py +30 -0
- package/tests/test_advanced_planning_behavior.py +225 -0
- package/tests/test_advanced_planning_flag.py +26 -0
- package/tests/test_autonomy.py +445 -0
- package/tests/test_bootstrap_initialize.py +37 -0
- package/tests/test_character.py +163 -0
- package/tests/test_character_provider.py +231 -0
- package/tests/test_dynamic_prompt_exec.py +561 -0
- package/tests/test_logger_redaction.py +43 -0
- package/tests/test_plugin.py +117 -0
- package/tests/test_runtime.py +422 -0
- package/tests/test_salt_production_enforcement.py +22 -0
- package/tests/test_settings_crypto.py +118 -0
- package/tests/test_streaming.py +295 -0
- package/tests/test_types.py +221 -0
- package/tests/test_uuid_parity.py +46 -0
|
@@ -0,0 +1,2572 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import re
|
|
5
|
+
import uuid
|
|
6
|
+
import xml.etree.ElementTree as ET
|
|
7
|
+
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from elizaos.action_docs import with_canonical_action_docs, with_canonical_evaluator_docs
|
|
12
|
+
from elizaos.logger import Logger, create_logger
|
|
13
|
+
from elizaos.settings import decrypt_secret, get_salt
|
|
14
|
+
from elizaos.types.agent import Character, TemplateType
|
|
15
|
+
from elizaos.types.components import (
|
|
16
|
+
Action,
|
|
17
|
+
ActionResult,
|
|
18
|
+
Evaluator,
|
|
19
|
+
HandlerCallback,
|
|
20
|
+
HandlerOptions,
|
|
21
|
+
Provider,
|
|
22
|
+
)
|
|
23
|
+
from elizaos.types.database import AgentRunSummaryResult, IDatabaseAdapter, Log
|
|
24
|
+
from elizaos.types.environment import Entity, Room, World
|
|
25
|
+
from elizaos.types.events import EventType
|
|
26
|
+
from elizaos.types.memory import Memory
|
|
27
|
+
from elizaos.types.model import GenerateTextOptions, GenerateTextResult, LLMMode, ModelType
|
|
28
|
+
from elizaos.types.plugin import Plugin, Route
|
|
29
|
+
from elizaos.types.primitives import UUID, Content, as_uuid, string_to_uuid
|
|
30
|
+
from elizaos.types.runtime import (
|
|
31
|
+
IAgentRuntime,
|
|
32
|
+
RuntimeSettings,
|
|
33
|
+
SendHandlerFunction,
|
|
34
|
+
StreamingModelHandler,
|
|
35
|
+
TargetInfo,
|
|
36
|
+
)
|
|
37
|
+
from elizaos.types.service import Service
|
|
38
|
+
from elizaos.types.state import RetryBackoffConfig, SchemaRow, State, StateData, StreamEvent
|
|
39
|
+
from elizaos.types.task import TaskWorker
|
|
40
|
+
from elizaos.utils import compose_prompt_from_state as _compose_prompt_from_state
|
|
41
|
+
from elizaos.utils import get_current_time_ms as _get_current_time_ms
|
|
42
|
+
from elizaos.utils.streaming import ValidationStreamExtractor, ValidationStreamExtractorConfig
|
|
43
|
+
|
|
44
|
+
_message_service_class: type | None = None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _get_message_service_class() -> type:
|
|
48
|
+
global _message_service_class
|
|
49
|
+
if _message_service_class is None:
|
|
50
|
+
from elizaos.services.message_service import DefaultMessageService
|
|
51
|
+
|
|
52
|
+
_message_service_class = DefaultMessageService
|
|
53
|
+
return _message_service_class
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class ModelHandler:
|
|
57
|
+
def __init__(
|
|
58
|
+
self,
|
|
59
|
+
handler: Callable[[IAgentRuntime, dict[str, Any]], Awaitable[Any]],
|
|
60
|
+
provider: str,
|
|
61
|
+
priority: int = 0,
|
|
62
|
+
) -> None:
|
|
63
|
+
self.handler = handler
|
|
64
|
+
self.provider = provider
|
|
65
|
+
self.priority = priority
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class StreamingModelHandlerWrapper:
|
|
69
|
+
"""Wrapper for streaming model handlers."""
|
|
70
|
+
|
|
71
|
+
def __init__(
|
|
72
|
+
self,
|
|
73
|
+
handler: StreamingModelHandler,
|
|
74
|
+
provider: str,
|
|
75
|
+
priority: int = 0,
|
|
76
|
+
) -> None:
|
|
77
|
+
self.handler = handler
|
|
78
|
+
self.provider = provider
|
|
79
|
+
self.priority = priority
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
_anonymous_agent_counter = 0
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class AgentRuntime(IAgentRuntime):
|
|
86
|
+
def __init__(
|
|
87
|
+
self,
|
|
88
|
+
character: Character | None = None,
|
|
89
|
+
agent_id: UUID | None = None,
|
|
90
|
+
adapter: IDatabaseAdapter | None = None,
|
|
91
|
+
plugins: list[Plugin] | None = None,
|
|
92
|
+
settings: RuntimeSettings | None = None,
|
|
93
|
+
conversation_length: int = 32,
|
|
94
|
+
log_level: str = "ERROR",
|
|
95
|
+
disable_basic_capabilities: bool = False,
|
|
96
|
+
enable_extended_capabilities: bool = False,
|
|
97
|
+
action_planning: bool | None = None,
|
|
98
|
+
llm_mode: LLMMode | None = None,
|
|
99
|
+
check_should_respond: bool | None = None,
|
|
100
|
+
enable_autonomy: bool = False,
|
|
101
|
+
) -> None:
|
|
102
|
+
global _anonymous_agent_counter
|
|
103
|
+
if character is not None:
|
|
104
|
+
resolved_character = character
|
|
105
|
+
is_anonymous = False
|
|
106
|
+
else:
|
|
107
|
+
_anonymous_agent_counter += 1
|
|
108
|
+
resolved_character = Character(
|
|
109
|
+
name=f"Agent-{_anonymous_agent_counter}",
|
|
110
|
+
bio="An anonymous agent",
|
|
111
|
+
)
|
|
112
|
+
is_anonymous = True
|
|
113
|
+
|
|
114
|
+
self._capability_disable_basic = disable_basic_capabilities
|
|
115
|
+
self._capability_enable_extended = enable_extended_capabilities
|
|
116
|
+
self._capability_enable_autonomy = enable_autonomy
|
|
117
|
+
self._is_anonymous_character = is_anonymous
|
|
118
|
+
self._action_planning_option = action_planning
|
|
119
|
+
self._llm_mode_option = llm_mode
|
|
120
|
+
self._check_should_respond_option = check_should_respond
|
|
121
|
+
self._agent_id = (
|
|
122
|
+
agent_id or resolved_character.id or string_to_uuid(resolved_character.name)
|
|
123
|
+
)
|
|
124
|
+
self._character = resolved_character
|
|
125
|
+
self._adapter = adapter
|
|
126
|
+
self._conversation_length = conversation_length
|
|
127
|
+
self._settings: RuntimeSettings = settings or {}
|
|
128
|
+
self._enable_autonomy = enable_autonomy or (
|
|
129
|
+
self._settings.get("ENABLE_AUTONOMY") in (True, "true")
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
self._providers: list[Provider] = []
|
|
133
|
+
self._actions: list[Action] = []
|
|
134
|
+
self._evaluators: list[Evaluator] = []
|
|
135
|
+
self._plugins: list[Plugin] = []
|
|
136
|
+
self._services: dict[str, list[Service]] = {}
|
|
137
|
+
self._routes: list[Route] = []
|
|
138
|
+
self._events: dict[str, list[Callable[[Any], Awaitable[None]]]] = {}
|
|
139
|
+
self._models: dict[str, list[ModelHandler]] = {}
|
|
140
|
+
self._streaming_models: dict[str, list[StreamingModelHandlerWrapper]] = {}
|
|
141
|
+
self._task_workers: dict[str, TaskWorker] = {}
|
|
142
|
+
self._send_handlers: dict[str, SendHandlerFunction] = {}
|
|
143
|
+
self._state_cache: dict[str, State] = {}
|
|
144
|
+
self._STATE_CACHE_MAX = 200
|
|
145
|
+
self._current_run_id: UUID | None = None
|
|
146
|
+
self._current_room_id: UUID | None = None
|
|
147
|
+
self._action_results: dict[str, list[ActionResult]] = {}
|
|
148
|
+
self._ACTION_RESULTS_MAX = 200
|
|
149
|
+
# Cached action lookup dict (name -> Action). Invalidated on action registration.
|
|
150
|
+
self._action_by_name: dict[str, Action] | None = None
|
|
151
|
+
self._logger = create_logger(namespace=resolved_character.name, level=log_level.upper())
|
|
152
|
+
self._initial_plugins = plugins or []
|
|
153
|
+
self._init_complete = False
|
|
154
|
+
self._init_event = asyncio.Event()
|
|
155
|
+
self._message_service: Any = None
|
|
156
|
+
|
|
157
|
+
@property
|
|
158
|
+
def logger(self) -> Logger:
|
|
159
|
+
return self._logger
|
|
160
|
+
|
|
161
|
+
@property
|
|
162
|
+
def message_service(self) -> Any:
|
|
163
|
+
if self._message_service is None:
|
|
164
|
+
service_class = _get_message_service_class()
|
|
165
|
+
self._message_service = service_class()
|
|
166
|
+
return self._message_service
|
|
167
|
+
|
|
168
|
+
@property
|
|
169
|
+
def enable_autonomy(self) -> bool:
|
|
170
|
+
return self._enable_autonomy
|
|
171
|
+
|
|
172
|
+
@enable_autonomy.setter
|
|
173
|
+
def enable_autonomy(self, value: bool) -> None:
|
|
174
|
+
self._enable_autonomy = value
|
|
175
|
+
|
|
176
|
+
@property
|
|
177
|
+
def agent_id(self) -> UUID:
|
|
178
|
+
return self._agent_id
|
|
179
|
+
|
|
180
|
+
@property
|
|
181
|
+
def character(self) -> Character:
|
|
182
|
+
return self._character
|
|
183
|
+
|
|
184
|
+
@property
|
|
185
|
+
def providers(self) -> list[Provider]:
|
|
186
|
+
return self._providers
|
|
187
|
+
|
|
188
|
+
@property
|
|
189
|
+
def actions(self) -> list[Action]:
|
|
190
|
+
return self._actions
|
|
191
|
+
|
|
192
|
+
@property
|
|
193
|
+
def evaluators(self) -> list[Evaluator]:
|
|
194
|
+
return self._evaluators
|
|
195
|
+
|
|
196
|
+
@property
|
|
197
|
+
def plugins(self) -> list[Plugin]:
|
|
198
|
+
return self._plugins
|
|
199
|
+
|
|
200
|
+
@property
|
|
201
|
+
def services(self) -> dict[str, list[Service]]:
|
|
202
|
+
return self._services
|
|
203
|
+
|
|
204
|
+
@property
|
|
205
|
+
def routes(self) -> list[Route]:
|
|
206
|
+
return self._routes
|
|
207
|
+
|
|
208
|
+
@property
|
|
209
|
+
def events(self) -> dict[str, list[Callable[[Any], Awaitable[None]]]]:
|
|
210
|
+
"""Get registered event handlers."""
|
|
211
|
+
return self._events
|
|
212
|
+
|
|
213
|
+
@property
|
|
214
|
+
def state_cache(self) -> dict[str, State]:
|
|
215
|
+
return self._state_cache
|
|
216
|
+
|
|
217
|
+
def register_database_adapter(self, adapter: IDatabaseAdapter) -> None:
|
|
218
|
+
self._adapter = adapter
|
|
219
|
+
|
|
220
|
+
@property
|
|
221
|
+
def db(self) -> Any:
|
|
222
|
+
if not self._adapter:
|
|
223
|
+
raise RuntimeError("Database adapter not set")
|
|
224
|
+
return self._adapter.db
|
|
225
|
+
|
|
226
|
+
async def initialize(self, config: dict[str, str | int | bool | None] | None = None) -> None:
|
|
227
|
+
_ = config
|
|
228
|
+
self.logger.info("Initializing AgentRuntime...")
|
|
229
|
+
|
|
230
|
+
if self._adapter:
|
|
231
|
+
await self._adapter.initialize()
|
|
232
|
+
self.logger.debug("Database adapter initialized")
|
|
233
|
+
|
|
234
|
+
has_bootstrap = any(p.name == "bootstrap" for p in self._initial_plugins)
|
|
235
|
+
if not has_bootstrap:
|
|
236
|
+
from elizaos.bootstrap import bootstrap_plugin
|
|
237
|
+
|
|
238
|
+
self._initial_plugins.insert(0, bootstrap_plugin)
|
|
239
|
+
|
|
240
|
+
# Advanced planning is built into core, but only loaded when enabled on the character.
|
|
241
|
+
if getattr(self._character, "advanced_planning", None) is True:
|
|
242
|
+
has_adv = any(p.name == "advanced-planning" for p in self._initial_plugins)
|
|
243
|
+
if not has_adv:
|
|
244
|
+
from elizaos.advanced_planning import advanced_planning_plugin
|
|
245
|
+
|
|
246
|
+
# Register after bootstrap so core providers/actions are available.
|
|
247
|
+
insert_at = (
|
|
248
|
+
1
|
|
249
|
+
if self._initial_plugins and self._initial_plugins[0].name == "bootstrap"
|
|
250
|
+
else 0
|
|
251
|
+
)
|
|
252
|
+
self._initial_plugins.insert(insert_at, advanced_planning_plugin)
|
|
253
|
+
|
|
254
|
+
# Advanced memory is built into core, but only loaded when enabled on the character.
|
|
255
|
+
if getattr(self._character, "advanced_memory", None) is True:
|
|
256
|
+
has_adv = any(p.name == "memory" for p in self._initial_plugins)
|
|
257
|
+
if not has_adv:
|
|
258
|
+
from elizaos.advanced_memory import advanced_memory_plugin
|
|
259
|
+
|
|
260
|
+
insert_at = (
|
|
261
|
+
1
|
|
262
|
+
if self._initial_plugins and self._initial_plugins[0].name == "bootstrap"
|
|
263
|
+
else 0
|
|
264
|
+
)
|
|
265
|
+
self._initial_plugins.insert(insert_at, advanced_memory_plugin)
|
|
266
|
+
|
|
267
|
+
for plugin in self._initial_plugins:
|
|
268
|
+
await self.register_plugin(plugin)
|
|
269
|
+
|
|
270
|
+
self._init_complete = True
|
|
271
|
+
self._init_event.set()
|
|
272
|
+
self.logger.info("AgentRuntime initialized successfully")
|
|
273
|
+
|
|
274
|
+
async def register_plugin(self, plugin: Plugin) -> None:
|
|
275
|
+
from elizaos.plugin import register_plugin
|
|
276
|
+
|
|
277
|
+
plugin_to_register = plugin
|
|
278
|
+
|
|
279
|
+
if plugin.name == "bootstrap":
|
|
280
|
+
char_settings_obj = self._character.settings
|
|
281
|
+
char_settings: dict[str, object] = {}
|
|
282
|
+
if hasattr(char_settings_obj, "DESCRIPTOR"):
|
|
283
|
+
from google.protobuf.json_format import MessageToDict
|
|
284
|
+
|
|
285
|
+
char_settings = MessageToDict(char_settings_obj, preserving_proto_field_name=True)
|
|
286
|
+
elif isinstance(char_settings_obj, dict):
|
|
287
|
+
char_settings = char_settings_obj
|
|
288
|
+
|
|
289
|
+
disable_basic = self._capability_disable_basic or (
|
|
290
|
+
char_settings.get("DISABLE_BASIC_CAPABILITIES") in (True, "true")
|
|
291
|
+
)
|
|
292
|
+
enable_extended = self._capability_enable_extended or (
|
|
293
|
+
char_settings.get("ENABLE_EXTENDED_CAPABILITIES") in (True, "true")
|
|
294
|
+
)
|
|
295
|
+
skip_character_provider = self._is_anonymous_character
|
|
296
|
+
|
|
297
|
+
enable_autonomy = self._capability_enable_autonomy or (
|
|
298
|
+
char_settings.get("ENABLE_AUTONOMY") in (True, "true")
|
|
299
|
+
)
|
|
300
|
+
|
|
301
|
+
if disable_basic or enable_extended or skip_character_provider or enable_autonomy:
|
|
302
|
+
from elizaos.bootstrap import CapabilityConfig, create_bootstrap_plugin
|
|
303
|
+
|
|
304
|
+
config = CapabilityConfig(
|
|
305
|
+
disable_basic=disable_basic,
|
|
306
|
+
enable_extended=enable_extended,
|
|
307
|
+
skip_character_provider=skip_character_provider,
|
|
308
|
+
enable_autonomy=enable_autonomy,
|
|
309
|
+
)
|
|
310
|
+
plugin_to_register = create_bootstrap_plugin(config)
|
|
311
|
+
|
|
312
|
+
await register_plugin(self, plugin_to_register)
|
|
313
|
+
self._plugins.append(plugin_to_register)
|
|
314
|
+
|
|
315
|
+
def get_service(self, service: str) -> Service | None:
|
|
316
|
+
services = self._services.get(service)
|
|
317
|
+
return services[0] if services else None
|
|
318
|
+
|
|
319
|
+
def get_services_by_type(self, service: str) -> list[Service]:
|
|
320
|
+
return self._services.get(service, [])
|
|
321
|
+
|
|
322
|
+
def get_all_services(self) -> dict[str, list[Service]]:
|
|
323
|
+
return self._services
|
|
324
|
+
|
|
325
|
+
async def register_service(self, service_class: type[Service]) -> None:
|
|
326
|
+
service_type = service_class.service_type
|
|
327
|
+
service = await service_class.start(self)
|
|
328
|
+
|
|
329
|
+
if service_type not in self._services:
|
|
330
|
+
self._services[service_type] = []
|
|
331
|
+
self._services[service_type].append(service)
|
|
332
|
+
|
|
333
|
+
self.logger.debug(f"Service registered: {service_type}")
|
|
334
|
+
|
|
335
|
+
async def get_service_load_promise(self, service_type: str) -> Service:
|
|
336
|
+
if not self._init_complete:
|
|
337
|
+
await self._init_event.wait()
|
|
338
|
+
|
|
339
|
+
service = self.get_service(service_type)
|
|
340
|
+
if not service:
|
|
341
|
+
raise RuntimeError(f"Service not found: {service_type}")
|
|
342
|
+
return service
|
|
343
|
+
|
|
344
|
+
def get_registered_service_types(self) -> list[str]:
|
|
345
|
+
return list(self._services.keys())
|
|
346
|
+
|
|
347
|
+
def has_service(self, service_type: str) -> bool:
|
|
348
|
+
return service_type in self._services and len(self._services[service_type]) > 0
|
|
349
|
+
|
|
350
|
+
def set_setting(self, key: str, value: object | None, secret: bool = False) -> None:
|
|
351
|
+
if value is None:
|
|
352
|
+
return
|
|
353
|
+
|
|
354
|
+
if secret:
|
|
355
|
+
if self._character.secrets is None:
|
|
356
|
+
self._character.secrets = {}
|
|
357
|
+
if isinstance(self._character.secrets, dict):
|
|
358
|
+
self._character.secrets[key] = value # type: ignore[assignment]
|
|
359
|
+
else:
|
|
360
|
+
# Fall back to internal settings dict for protobuf objects
|
|
361
|
+
self._settings[key] = value
|
|
362
|
+
return
|
|
363
|
+
|
|
364
|
+
# Try to set on character.settings if it's a dict
|
|
365
|
+
if isinstance(self._character.settings, dict):
|
|
366
|
+
self._character.settings[key] = value # type: ignore[assignment]
|
|
367
|
+
else:
|
|
368
|
+
# Fall back to internal settings dict for protobuf objects
|
|
369
|
+
self._settings[key] = value
|
|
370
|
+
|
|
371
|
+
def get_setting(self, key: str) -> object | None:
|
|
372
|
+
settings = self._character.settings
|
|
373
|
+
secrets = self._character.secrets
|
|
374
|
+
|
|
375
|
+
nested_secrets: dict[str, object] | None = None
|
|
376
|
+
if isinstance(settings, dict):
|
|
377
|
+
nested = settings.get("secrets")
|
|
378
|
+
if isinstance(nested, dict):
|
|
379
|
+
nested_secrets = nested
|
|
380
|
+
|
|
381
|
+
value: object | None
|
|
382
|
+
if isinstance(secrets, dict) and key in secrets:
|
|
383
|
+
value = secrets.get(key)
|
|
384
|
+
elif isinstance(settings, dict) and key in settings:
|
|
385
|
+
value = settings.get(key)
|
|
386
|
+
elif isinstance(nested_secrets, dict) and key in nested_secrets:
|
|
387
|
+
value = nested_secrets.get(key)
|
|
388
|
+
else:
|
|
389
|
+
value = self._settings.get(key)
|
|
390
|
+
|
|
391
|
+
if value is None:
|
|
392
|
+
return None
|
|
393
|
+
|
|
394
|
+
if isinstance(value, bool):
|
|
395
|
+
return value
|
|
396
|
+
if isinstance(value, (int, float)):
|
|
397
|
+
return value
|
|
398
|
+
if isinstance(value, str):
|
|
399
|
+
decrypted = decrypt_secret(value, get_salt())
|
|
400
|
+
if decrypted == "true":
|
|
401
|
+
return True
|
|
402
|
+
if decrypted == "false":
|
|
403
|
+
return False
|
|
404
|
+
# Cast to str since decrypt_secret returns object for type flexibility
|
|
405
|
+
return str(decrypted) if decrypted is not None else None
|
|
406
|
+
|
|
407
|
+
# Allow non-primitive runtime settings (e.g. objects used by providers/actions).
|
|
408
|
+
return value
|
|
409
|
+
|
|
410
|
+
def get_all_settings(self) -> dict[str, object | None]:
|
|
411
|
+
keys: set[str] = set(self._settings.keys())
|
|
412
|
+
if isinstance(self._character.settings, dict):
|
|
413
|
+
keys.update(self._character.settings.keys())
|
|
414
|
+
nested = self._character.settings.get("secrets")
|
|
415
|
+
if isinstance(nested, dict):
|
|
416
|
+
keys.update(nested.keys())
|
|
417
|
+
if isinstance(self._character.secrets, dict):
|
|
418
|
+
keys.update(self._character.secrets.keys())
|
|
419
|
+
|
|
420
|
+
return {k: self.get_setting(k) for k in keys}
|
|
421
|
+
|
|
422
|
+
def compose_prompt(self, *, state: State, template: TemplateType) -> str:
|
|
423
|
+
return _compose_prompt_from_state(state=state, template=template)
|
|
424
|
+
|
|
425
|
+
def compose_prompt_from_state(self, *, state: State, template: TemplateType) -> str:
|
|
426
|
+
return _compose_prompt_from_state(state=state, template=template)
|
|
427
|
+
|
|
428
|
+
def get_current_time_ms(self) -> int:
|
|
429
|
+
return _get_current_time_ms()
|
|
430
|
+
|
|
431
|
+
def get_conversation_length(self) -> int:
|
|
432
|
+
return self._conversation_length
|
|
433
|
+
|
|
434
|
+
def is_action_planning_enabled(self) -> bool:
|
|
435
|
+
if self._action_planning_option is not None:
|
|
436
|
+
return self._action_planning_option
|
|
437
|
+
|
|
438
|
+
setting = self.get_setting("ACTION_PLANNING")
|
|
439
|
+
if setting is not None:
|
|
440
|
+
if isinstance(setting, bool):
|
|
441
|
+
return setting
|
|
442
|
+
if isinstance(setting, str):
|
|
443
|
+
return setting.lower() == "true"
|
|
444
|
+
|
|
445
|
+
return True
|
|
446
|
+
|
|
447
|
+
def get_llm_mode(self) -> LLMMode:
|
|
448
|
+
if self._llm_mode_option is not None:
|
|
449
|
+
return self._llm_mode_option
|
|
450
|
+
|
|
451
|
+
setting = self.get_setting("LLM_MODE")
|
|
452
|
+
if setting is not None and isinstance(setting, str):
|
|
453
|
+
upper = setting.upper()
|
|
454
|
+
if upper == "SMALL":
|
|
455
|
+
return LLMMode.SMALL
|
|
456
|
+
elif upper == "LARGE":
|
|
457
|
+
return LLMMode.LARGE
|
|
458
|
+
elif upper == "DEFAULT":
|
|
459
|
+
return LLMMode.DEFAULT
|
|
460
|
+
|
|
461
|
+
# Default to DEFAULT (no override)
|
|
462
|
+
return LLMMode.DEFAULT
|
|
463
|
+
|
|
464
|
+
def is_check_should_respond_enabled(self) -> bool:
|
|
465
|
+
"""
|
|
466
|
+
Check if the shouldRespond evaluation is enabled.
|
|
467
|
+
|
|
468
|
+
When enabled (default: True), the agent evaluates whether to respond to each message.
|
|
469
|
+
When disabled, the agent always responds (ChatGPT mode) - useful for direct chat interfaces.
|
|
470
|
+
|
|
471
|
+
Priority: constructor option > character setting CHECK_SHOULD_RESPOND > default (True)
|
|
472
|
+
"""
|
|
473
|
+
# Constructor option takes precedence
|
|
474
|
+
if self._check_should_respond_option is not None:
|
|
475
|
+
return self._check_should_respond_option
|
|
476
|
+
|
|
477
|
+
setting = self.get_setting("CHECK_SHOULD_RESPOND")
|
|
478
|
+
if setting is not None:
|
|
479
|
+
if isinstance(setting, bool):
|
|
480
|
+
return setting
|
|
481
|
+
if isinstance(setting, str):
|
|
482
|
+
return setting.lower() != "false"
|
|
483
|
+
|
|
484
|
+
# Default to True (check should respond is enabled)
|
|
485
|
+
return True
|
|
486
|
+
|
|
487
|
+
# Component registration
|
|
488
|
+
def register_provider(self, provider: Provider) -> None:
|
|
489
|
+
self._providers.append(provider)
|
|
490
|
+
|
|
491
|
+
def register_action(self, action: Action) -> None:
|
|
492
|
+
self._actions.append(with_canonical_action_docs(action))
|
|
493
|
+
self._action_by_name = None # Invalidate cached lookup
|
|
494
|
+
|
|
495
|
+
def register_evaluator(self, evaluator: Evaluator) -> None:
|
|
496
|
+
self._evaluators.append(with_canonical_evaluator_docs(evaluator))
|
|
497
|
+
|
|
498
|
+
@staticmethod
|
|
499
|
+
def _parse_param_value(value: str) -> str | int | float | bool | None:
|
|
500
|
+
raw = value.strip()
|
|
501
|
+
if raw == "":
|
|
502
|
+
return None
|
|
503
|
+
lower = raw.lower()
|
|
504
|
+
if lower == "true":
|
|
505
|
+
return True
|
|
506
|
+
if lower == "false":
|
|
507
|
+
return False
|
|
508
|
+
if lower == "null":
|
|
509
|
+
return None
|
|
510
|
+
# Try int first, then float
|
|
511
|
+
try:
|
|
512
|
+
if re.fullmatch(r"-?\d+", raw):
|
|
513
|
+
return int(raw)
|
|
514
|
+
if re.fullmatch(r"-?\d+\.\d+", raw):
|
|
515
|
+
return float(raw)
|
|
516
|
+
except Exception:
|
|
517
|
+
return raw
|
|
518
|
+
return raw
|
|
519
|
+
|
|
520
|
+
def _parse_action_params(self, params_raw: object | None) -> dict[str, list[dict[str, object]]]:
|
|
521
|
+
"""
|
|
522
|
+
Parse action parameters from either:
|
|
523
|
+
- Nested dict structure (e.g. {"MOVE": {"direction": "north"}})
|
|
524
|
+
- XML string (inner content of <params> or full <params>...</params>)
|
|
525
|
+
"""
|
|
526
|
+
if params_raw is None:
|
|
527
|
+
return {}
|
|
528
|
+
|
|
529
|
+
if isinstance(params_raw, str):
|
|
530
|
+
xml_text = params_raw if "<params" in params_raw else f"<params>{params_raw}</params>"
|
|
531
|
+
try:
|
|
532
|
+
root = ET.fromstring(xml_text)
|
|
533
|
+
except ET.ParseError:
|
|
534
|
+
return {}
|
|
535
|
+
|
|
536
|
+
if root.tag.lower() != "params":
|
|
537
|
+
return {}
|
|
538
|
+
|
|
539
|
+
result: dict[str, list[dict[str, object]]] = {}
|
|
540
|
+
for action_elem in list(root):
|
|
541
|
+
action_name = action_elem.tag.upper()
|
|
542
|
+
action_params: dict[str, object] = {}
|
|
543
|
+
for param_elem in list(action_elem):
|
|
544
|
+
action_params[param_elem.tag] = self._parse_param_value(param_elem.text or "")
|
|
545
|
+
if action_params:
|
|
546
|
+
result.setdefault(action_name, []).append(action_params)
|
|
547
|
+
return result
|
|
548
|
+
|
|
549
|
+
if isinstance(params_raw, dict):
|
|
550
|
+
result_dict: dict[str, list[dict[str, object]]] = {}
|
|
551
|
+
for action_name, params_value in params_raw.items():
|
|
552
|
+
action_key = str(action_name).upper()
|
|
553
|
+
|
|
554
|
+
entries: list[dict[str, object]] = []
|
|
555
|
+
if isinstance(params_value, list):
|
|
556
|
+
for item in params_value:
|
|
557
|
+
if not isinstance(item, dict):
|
|
558
|
+
continue
|
|
559
|
+
inner_action_params: dict[str, object] = {}
|
|
560
|
+
for param_name, raw_value in item.items():
|
|
561
|
+
key = str(param_name)
|
|
562
|
+
if isinstance(raw_value, str):
|
|
563
|
+
inner_action_params[key] = self._parse_param_value(raw_value)
|
|
564
|
+
else:
|
|
565
|
+
inner_action_params[key] = raw_value
|
|
566
|
+
if inner_action_params:
|
|
567
|
+
entries.append(inner_action_params)
|
|
568
|
+
elif isinstance(params_value, dict):
|
|
569
|
+
inner_action_params = {}
|
|
570
|
+
for param_name, raw_value in params_value.items():
|
|
571
|
+
key = str(param_name)
|
|
572
|
+
if isinstance(raw_value, str):
|
|
573
|
+
inner_action_params[key] = self._parse_param_value(raw_value)
|
|
574
|
+
else:
|
|
575
|
+
inner_action_params[key] = raw_value
|
|
576
|
+
if inner_action_params:
|
|
577
|
+
entries.append(inner_action_params)
|
|
578
|
+
else:
|
|
579
|
+
continue
|
|
580
|
+
|
|
581
|
+
if entries:
|
|
582
|
+
result_dict[action_key] = entries
|
|
583
|
+
return result_dict
|
|
584
|
+
|
|
585
|
+
return {}
|
|
586
|
+
|
|
587
|
+
def _validate_action_params(
|
|
588
|
+
self, action: Action, extracted: dict[str, object] | None
|
|
589
|
+
) -> tuple[bool, dict[str, object] | None, list[str]]:
|
|
590
|
+
errors: list[str] = []
|
|
591
|
+
validated: dict[str, object] = {}
|
|
592
|
+
|
|
593
|
+
if not action.parameters:
|
|
594
|
+
return True, None, []
|
|
595
|
+
|
|
596
|
+
for param_def in action.parameters:
|
|
597
|
+
extracted_value = extracted.get(param_def.name) if extracted else None
|
|
598
|
+
if extracted_value is None and extracted:
|
|
599
|
+
# Be tolerant to parameter name casing produced by models (e.g. "Expression" vs "expression")
|
|
600
|
+
for k, v in extracted.items():
|
|
601
|
+
if isinstance(k, str) and k.lower() == param_def.name.lower():
|
|
602
|
+
extracted_value = v
|
|
603
|
+
break
|
|
604
|
+
|
|
605
|
+
# Treat explicit None as missing
|
|
606
|
+
if extracted_value is None:
|
|
607
|
+
if param_def.required:
|
|
608
|
+
errors.append(
|
|
609
|
+
f"Required parameter '{param_def.name}' was not provided for action {action.name}"
|
|
610
|
+
)
|
|
611
|
+
elif getattr(param_def.schema, "default_value", None):
|
|
612
|
+
validated[param_def.name] = param_def.schema.default_value
|
|
613
|
+
continue
|
|
614
|
+
|
|
615
|
+
schema_type = param_def.schema.type
|
|
616
|
+
|
|
617
|
+
if schema_type == "string":
|
|
618
|
+
# Parameters often come from XML and may be parsed into scalars
|
|
619
|
+
# (e.g., "200" -> int 200). For string-typed params, coerce
|
|
620
|
+
# scalars back to strings rather than failing validation.
|
|
621
|
+
if isinstance(extracted_value, bool):
|
|
622
|
+
extracted_value = "true" if extracted_value else "false"
|
|
623
|
+
elif isinstance(extracted_value, (int, float)):
|
|
624
|
+
extracted_value = str(extracted_value)
|
|
625
|
+
if not isinstance(extracted_value, str):
|
|
626
|
+
errors.append(
|
|
627
|
+
f"Parameter '{param_def.name}' expected string, got {type(extracted_value).__name__}"
|
|
628
|
+
)
|
|
629
|
+
continue
|
|
630
|
+
if (
|
|
631
|
+
param_def.schema.enum_values
|
|
632
|
+
and extracted_value not in param_def.schema.enum_values
|
|
633
|
+
):
|
|
634
|
+
errors.append(
|
|
635
|
+
f"Parameter '{param_def.name}' value '{extracted_value}' not in allowed values: {', '.join(param_def.schema.enum_values)}"
|
|
636
|
+
)
|
|
637
|
+
continue
|
|
638
|
+
if param_def.schema.pattern and not re.fullmatch(
|
|
639
|
+
param_def.schema.pattern, extracted_value
|
|
640
|
+
):
|
|
641
|
+
errors.append(
|
|
642
|
+
f"Parameter '{param_def.name}' value '{extracted_value}' does not match pattern: {param_def.schema.pattern}"
|
|
643
|
+
)
|
|
644
|
+
continue
|
|
645
|
+
validated[param_def.name] = extracted_value
|
|
646
|
+
continue
|
|
647
|
+
|
|
648
|
+
if schema_type == "number":
|
|
649
|
+
if isinstance(extracted_value, bool) or not isinstance(
|
|
650
|
+
extracted_value, (int, float)
|
|
651
|
+
):
|
|
652
|
+
errors.append(
|
|
653
|
+
f"Parameter '{param_def.name}' expected number, got {type(extracted_value).__name__}"
|
|
654
|
+
)
|
|
655
|
+
continue
|
|
656
|
+
if param_def.schema.minimum is not None and float(extracted_value) < float(
|
|
657
|
+
param_def.schema.minimum
|
|
658
|
+
):
|
|
659
|
+
errors.append(
|
|
660
|
+
f"Parameter '{param_def.name}' value {extracted_value} is below minimum {param_def.schema.minimum}"
|
|
661
|
+
)
|
|
662
|
+
continue
|
|
663
|
+
if param_def.schema.maximum is not None and float(extracted_value) > float(
|
|
664
|
+
param_def.schema.maximum
|
|
665
|
+
):
|
|
666
|
+
errors.append(
|
|
667
|
+
f"Parameter '{param_def.name}' value {extracted_value} is above maximum {param_def.schema.maximum}"
|
|
668
|
+
)
|
|
669
|
+
continue
|
|
670
|
+
validated[param_def.name] = extracted_value
|
|
671
|
+
continue
|
|
672
|
+
|
|
673
|
+
if schema_type == "boolean":
|
|
674
|
+
if not isinstance(extracted_value, bool):
|
|
675
|
+
errors.append(
|
|
676
|
+
f"Parameter '{param_def.name}' expected boolean, got {type(extracted_value).__name__}"
|
|
677
|
+
)
|
|
678
|
+
continue
|
|
679
|
+
validated[param_def.name] = extracted_value
|
|
680
|
+
continue
|
|
681
|
+
|
|
682
|
+
if schema_type == "array":
|
|
683
|
+
if not isinstance(extracted_value, list):
|
|
684
|
+
errors.append(
|
|
685
|
+
f"Parameter '{param_def.name}' expected array, got {type(extracted_value).__name__}"
|
|
686
|
+
)
|
|
687
|
+
continue
|
|
688
|
+
validated[param_def.name] = extracted_value
|
|
689
|
+
continue
|
|
690
|
+
|
|
691
|
+
if schema_type == "object":
|
|
692
|
+
if not isinstance(extracted_value, dict):
|
|
693
|
+
errors.append(
|
|
694
|
+
f"Parameter '{param_def.name}' expected object, got {type(extracted_value).__name__}"
|
|
695
|
+
)
|
|
696
|
+
continue
|
|
697
|
+
validated[param_def.name] = extracted_value
|
|
698
|
+
continue
|
|
699
|
+
|
|
700
|
+
validated[param_def.name] = extracted_value
|
|
701
|
+
|
|
702
|
+
return (len(errors) == 0, validated if validated else None, errors)
|
|
703
|
+
|
|
704
|
+
async def process_actions(
|
|
705
|
+
self,
|
|
706
|
+
message: Memory,
|
|
707
|
+
responses: list[Memory],
|
|
708
|
+
state: State | None = None,
|
|
709
|
+
callback: HandlerCallback | None = None,
|
|
710
|
+
_options: dict[str, Any] | None = None,
|
|
711
|
+
) -> None:
|
|
712
|
+
"""Process actions selected by the model response (supports optional <params>)."""
|
|
713
|
+
if not responses:
|
|
714
|
+
return
|
|
715
|
+
|
|
716
|
+
actions_to_process: list[str] = []
|
|
717
|
+
if self.is_action_planning_enabled():
|
|
718
|
+
for response in responses:
|
|
719
|
+
if response.content.actions:
|
|
720
|
+
actions_to_process.extend(
|
|
721
|
+
[a for a in response.content.actions if isinstance(a, str)]
|
|
722
|
+
)
|
|
723
|
+
else:
|
|
724
|
+
for response in responses:
|
|
725
|
+
if response.content.actions:
|
|
726
|
+
first = response.content.actions[0]
|
|
727
|
+
if isinstance(first, str):
|
|
728
|
+
actions_to_process = [first]
|
|
729
|
+
break
|
|
730
|
+
|
|
731
|
+
if not actions_to_process:
|
|
732
|
+
return
|
|
733
|
+
|
|
734
|
+
for response in responses:
|
|
735
|
+
if not response.content.actions:
|
|
736
|
+
continue
|
|
737
|
+
|
|
738
|
+
# Track Nth occurrence of each action within this response so repeated actions
|
|
739
|
+
# (e.g., multiple WRITE_FILE actions) consume the corresponding Nth params entry.
|
|
740
|
+
param_index: dict[str, int] = {}
|
|
741
|
+
|
|
742
|
+
for response_action in response.content.actions:
|
|
743
|
+
if not isinstance(response_action, str):
|
|
744
|
+
continue
|
|
745
|
+
|
|
746
|
+
# Respect single-action mode: only execute the first collected action
|
|
747
|
+
if not self.is_action_planning_enabled() and actions_to_process:
|
|
748
|
+
if response_action != actions_to_process[0]:
|
|
749
|
+
continue
|
|
750
|
+
|
|
751
|
+
action = self._get_action_by_name(response_action)
|
|
752
|
+
if not action:
|
|
753
|
+
self.logger.error(f"Action not found: {response_action}")
|
|
754
|
+
continue
|
|
755
|
+
|
|
756
|
+
options_obj = HandlerOptions()
|
|
757
|
+
valid = True
|
|
758
|
+
validated_params: dict[str, object] | None = None
|
|
759
|
+
errors: list[str] = []
|
|
760
|
+
|
|
761
|
+
if action.parameters:
|
|
762
|
+
params_raw = getattr(response.content, "params", None)
|
|
763
|
+
# Fallback: params may be stored in content.data["params"]
|
|
764
|
+
# when Content is a protobuf without a native params field.
|
|
765
|
+
if params_raw is None and response.content.data:
|
|
766
|
+
try:
|
|
767
|
+
# Protobuf Struct uses [] access, not .get()
|
|
768
|
+
if "params" in response.content.data:
|
|
769
|
+
data_params = response.content.data["params"]
|
|
770
|
+
if data_params is not None:
|
|
771
|
+
# Convert protobuf Struct to dict for _parse_action_params
|
|
772
|
+
from google.protobuf.json_format import MessageToDict
|
|
773
|
+
if hasattr(data_params, "DESCRIPTOR"):
|
|
774
|
+
params_raw = MessageToDict(data_params)
|
|
775
|
+
else:
|
|
776
|
+
params_raw = data_params
|
|
777
|
+
except (AttributeError, TypeError, KeyError):
|
|
778
|
+
pass
|
|
779
|
+
params_by_action = self._parse_action_params(params_raw)
|
|
780
|
+
action_key = response_action.upper()
|
|
781
|
+
extracted_list = params_by_action.get(action_key) or params_by_action.get(
|
|
782
|
+
action.name.upper()
|
|
783
|
+
)
|
|
784
|
+
|
|
785
|
+
idx = param_index.get(action_key, 0)
|
|
786
|
+
extracted: dict[str, object] | None = None
|
|
787
|
+
if isinstance(extracted_list, list):
|
|
788
|
+
if idx < len(extracted_list):
|
|
789
|
+
entry = extracted_list[idx]
|
|
790
|
+
if isinstance(entry, dict):
|
|
791
|
+
extracted = entry
|
|
792
|
+
param_index[action_key] = idx + 1
|
|
793
|
+
elif isinstance(extracted_list, dict):
|
|
794
|
+
extracted = extracted_list
|
|
795
|
+
valid, validated_params, errors = self._validate_action_params(
|
|
796
|
+
action, extracted
|
|
797
|
+
)
|
|
798
|
+
if not valid:
|
|
799
|
+
self.logger.warning(
|
|
800
|
+
"Action parameter validation incomplete",
|
|
801
|
+
src="runtime:actions",
|
|
802
|
+
actionName=action.name,
|
|
803
|
+
errors=errors,
|
|
804
|
+
)
|
|
805
|
+
try:
|
|
806
|
+
options_obj.parameter_errors = errors
|
|
807
|
+
except (AttributeError, ValueError):
|
|
808
|
+
# Protobuf HandlerOptions may not have parameter_errors field
|
|
809
|
+
pass
|
|
810
|
+
|
|
811
|
+
if validated_params:
|
|
812
|
+
from google.protobuf import struct_pb2
|
|
813
|
+
|
|
814
|
+
from elizaos.types.components import ActionParameters
|
|
815
|
+
|
|
816
|
+
struct_values = struct_pb2.Struct()
|
|
817
|
+
for k, v in validated_params.items():
|
|
818
|
+
if v is None:
|
|
819
|
+
struct_values.fields[k].null_value = 0
|
|
820
|
+
elif isinstance(v, bool):
|
|
821
|
+
struct_values.fields[k].bool_value = v
|
|
822
|
+
elif isinstance(v, (int, float)):
|
|
823
|
+
struct_values.fields[k].number_value = float(v)
|
|
824
|
+
elif isinstance(v, str):
|
|
825
|
+
struct_values.fields[k].string_value = v
|
|
826
|
+
else:
|
|
827
|
+
struct_values.fields[k].string_value = str(v)
|
|
828
|
+
options_obj.parameters.CopyFrom(ActionParameters(values=struct_values))
|
|
829
|
+
|
|
830
|
+
# Ensure options.parameters is always a plain dict for action handlers.
|
|
831
|
+
# Proto HandlerOptions.parameters is ActionParameters (not dict-like),
|
|
832
|
+
# but handlers universally call options.parameters.get("key").
|
|
833
|
+
_params_dict = validated_params or {}
|
|
834
|
+
if not _params_dict and hasattr(options_obj, "parameters"):
|
|
835
|
+
try:
|
|
836
|
+
pv = options_obj.parameters
|
|
837
|
+
if isinstance(pv, dict):
|
|
838
|
+
_params_dict = pv
|
|
839
|
+
elif hasattr(pv, "values") and hasattr(pv.values, "items"):
|
|
840
|
+
from google.protobuf.json_format import MessageToDict
|
|
841
|
+
|
|
842
|
+
_params_dict = MessageToDict(pv.values)
|
|
843
|
+
except Exception:
|
|
844
|
+
pass
|
|
845
|
+
options_obj = type(
|
|
846
|
+
"_Opts",
|
|
847
|
+
(),
|
|
848
|
+
{
|
|
849
|
+
"parameters": _params_dict,
|
|
850
|
+
"parameter_errors": errors,
|
|
851
|
+
},
|
|
852
|
+
)()
|
|
853
|
+
|
|
854
|
+
result = await action.handler(
|
|
855
|
+
self,
|
|
856
|
+
message,
|
|
857
|
+
state,
|
|
858
|
+
options_obj,
|
|
859
|
+
callback,
|
|
860
|
+
responses,
|
|
861
|
+
)
|
|
862
|
+
|
|
863
|
+
# Store result
|
|
864
|
+
if message.id:
|
|
865
|
+
message_id = str(message.id)
|
|
866
|
+
if message_id not in self._action_results:
|
|
867
|
+
self._action_results[message_id] = []
|
|
868
|
+
if result:
|
|
869
|
+
self._action_results[message_id].append(result)
|
|
870
|
+
# LRU eviction for action results to prevent unbounded growth
|
|
871
|
+
if len(self._action_results) > self._ACTION_RESULTS_MAX:
|
|
872
|
+
excess = len(self._action_results) - self._ACTION_RESULTS_MAX
|
|
873
|
+
keys_to_remove = list(self._action_results.keys())[:excess]
|
|
874
|
+
for k in keys_to_remove:
|
|
875
|
+
del self._action_results[k]
|
|
876
|
+
|
|
877
|
+
def _get_action_by_name(self, name: str) -> Action | None:
|
|
878
|
+
"""O(1) action lookup using cached name -> Action dict."""
|
|
879
|
+
if self._action_by_name is None:
|
|
880
|
+
self._action_by_name = {a.name: a for a in self._actions}
|
|
881
|
+
return self._action_by_name.get(name)
|
|
882
|
+
|
|
883
|
+
def get_action_results(self, message_id: UUID) -> list[ActionResult]:
|
|
884
|
+
return self._action_results.get(str(message_id), [])
|
|
885
|
+
|
|
886
|
+
def get_available_actions(self) -> list[Action]:
|
|
887
|
+
"""Get all registered actions."""
|
|
888
|
+
return self._actions
|
|
889
|
+
|
|
890
|
+
async def evaluate_pre(
|
|
891
|
+
self,
|
|
892
|
+
message: Memory,
|
|
893
|
+
state: State | None = None,
|
|
894
|
+
) -> "PreEvaluatorResult":
|
|
895
|
+
"""Run phase='pre' evaluators as middleware before memory storage.
|
|
896
|
+
|
|
897
|
+
Pre-evaluators can inspect, rewrite, or block a message before it
|
|
898
|
+
reaches the agent. If any pre-evaluator sets ``blocked=True``, the
|
|
899
|
+
message is dropped. If any sets ``rewritten_text``, the last rewrite
|
|
900
|
+
wins.
|
|
901
|
+
|
|
902
|
+
Returns:
|
|
903
|
+
A merged PreEvaluatorResult.
|
|
904
|
+
"""
|
|
905
|
+
from elizaos.types.components import PreEvaluatorResult
|
|
906
|
+
|
|
907
|
+
pre_evaluators = [e for e in self._evaluators if getattr(e, "phase", "post") == "pre"]
|
|
908
|
+
if not pre_evaluators:
|
|
909
|
+
return PreEvaluatorResult(blocked=False)
|
|
910
|
+
|
|
911
|
+
blocked = False
|
|
912
|
+
rewritten_text: str | None = None
|
|
913
|
+
reason: str | None = None
|
|
914
|
+
|
|
915
|
+
for evaluator in pre_evaluators:
|
|
916
|
+
try:
|
|
917
|
+
is_valid = await evaluator.validate(self, message, state)
|
|
918
|
+
if not is_valid:
|
|
919
|
+
continue
|
|
920
|
+
|
|
921
|
+
result = await evaluator.handler(
|
|
922
|
+
self,
|
|
923
|
+
message,
|
|
924
|
+
state,
|
|
925
|
+
HandlerOptions(),
|
|
926
|
+
None,
|
|
927
|
+
None,
|
|
928
|
+
)
|
|
929
|
+
|
|
930
|
+
# Handler may return a PreEvaluatorResult-like object or ActionResult
|
|
931
|
+
if result and hasattr(result, "success"):
|
|
932
|
+
# ActionResult — interpret success=False as blocked
|
|
933
|
+
if not result.success:
|
|
934
|
+
blocked = True
|
|
935
|
+
reason = result.error or result.text or reason
|
|
936
|
+
self.logger.warning(
|
|
937
|
+
f'Pre-evaluator "{evaluator.name}" blocked message: {reason}'
|
|
938
|
+
)
|
|
939
|
+
elif isinstance(result, dict):
|
|
940
|
+
if result.get("blocked"):
|
|
941
|
+
blocked = True
|
|
942
|
+
reason = result.get("reason", reason)
|
|
943
|
+
self.logger.warning(
|
|
944
|
+
f'Pre-evaluator "{evaluator.name}" blocked message: {reason}'
|
|
945
|
+
)
|
|
946
|
+
if "rewritten_text" in result and result["rewritten_text"] is not None:
|
|
947
|
+
rewritten_text = result["rewritten_text"]
|
|
948
|
+
|
|
949
|
+
except Exception as e:
|
|
950
|
+
self.logger.error(f'Pre-evaluator "{evaluator.name}" failed: {e}')
|
|
951
|
+
|
|
952
|
+
return PreEvaluatorResult(
|
|
953
|
+
blocked=blocked,
|
|
954
|
+
rewritten_text=rewritten_text,
|
|
955
|
+
reason=reason,
|
|
956
|
+
)
|
|
957
|
+
|
|
958
|
+
async def evaluate(
|
|
959
|
+
self,
|
|
960
|
+
message: Memory,
|
|
961
|
+
state: State | None = None,
|
|
962
|
+
did_respond: bool = False,
|
|
963
|
+
callback: HandlerCallback | None = None,
|
|
964
|
+
responses: list[Memory] | None = None,
|
|
965
|
+
) -> list[Evaluator] | None:
|
|
966
|
+
"""Run phase='post' (default) evaluators on a message."""
|
|
967
|
+
ran_evaluators: list[Evaluator] = []
|
|
968
|
+
|
|
969
|
+
for evaluator in self._evaluators:
|
|
970
|
+
# Skip pre-evaluators (they run via evaluate_pre)
|
|
971
|
+
if getattr(evaluator, "phase", "post") == "pre":
|
|
972
|
+
continue
|
|
973
|
+
|
|
974
|
+
should_run = evaluator.always_run or did_respond
|
|
975
|
+
|
|
976
|
+
if should_run:
|
|
977
|
+
try:
|
|
978
|
+
is_valid = await evaluator.validate(self, message, state)
|
|
979
|
+
if is_valid:
|
|
980
|
+
await evaluator.handler(
|
|
981
|
+
self,
|
|
982
|
+
message,
|
|
983
|
+
state,
|
|
984
|
+
HandlerOptions(),
|
|
985
|
+
callback,
|
|
986
|
+
responses,
|
|
987
|
+
)
|
|
988
|
+
ran_evaluators.append(evaluator)
|
|
989
|
+
except Exception as e:
|
|
990
|
+
self.logger.error(f"Evaluator {evaluator.name} failed: {e}")
|
|
991
|
+
|
|
992
|
+
return ran_evaluators if ran_evaluators else None
|
|
993
|
+
|
|
994
|
+
async def ensure_connections(
|
|
995
|
+
self,
|
|
996
|
+
entities: list[Entity],
|
|
997
|
+
rooms: list[Room],
|
|
998
|
+
_source: str,
|
|
999
|
+
world: World,
|
|
1000
|
+
) -> None:
|
|
1001
|
+
"""Ensure connections are set up."""
|
|
1002
|
+
# Ensure world exists
|
|
1003
|
+
await self.ensure_world_exists(world)
|
|
1004
|
+
|
|
1005
|
+
# Ensure rooms exist
|
|
1006
|
+
for room in rooms:
|
|
1007
|
+
await self.ensure_room_exists(room)
|
|
1008
|
+
|
|
1009
|
+
for entity in entities:
|
|
1010
|
+
if entity.id:
|
|
1011
|
+
await self.create_entities([entity])
|
|
1012
|
+
for room in rooms:
|
|
1013
|
+
await self.ensure_participant_in_room(entity.id, room.id)
|
|
1014
|
+
|
|
1015
|
+
async def ensure_connection(
|
|
1016
|
+
self,
|
|
1017
|
+
entity_id: UUID,
|
|
1018
|
+
room_id: UUID,
|
|
1019
|
+
world_id: UUID,
|
|
1020
|
+
user_name: str | None = None,
|
|
1021
|
+
name: str | None = None,
|
|
1022
|
+
world_name: str | None = None,
|
|
1023
|
+
source: str | None = None,
|
|
1024
|
+
channel_id: str | None = None,
|
|
1025
|
+
message_server_id: UUID | None = None,
|
|
1026
|
+
channel_type: str | None = None,
|
|
1027
|
+
user_id: UUID | None = None,
|
|
1028
|
+
metadata: dict[str, Any] | None = None,
|
|
1029
|
+
) -> None:
|
|
1030
|
+
"""Ensure a connection is set up."""
|
|
1031
|
+
# Implementation depends on database adapter
|
|
1032
|
+
pass
|
|
1033
|
+
|
|
1034
|
+
async def ensure_participant_in_room(self, entity_id: UUID, room_id: UUID) -> None:
|
|
1035
|
+
"""Ensure an entity is a participant in a room."""
|
|
1036
|
+
if self._adapter:
|
|
1037
|
+
is_participant = await self._adapter.is_room_participant(room_id, entity_id)
|
|
1038
|
+
if not is_participant:
|
|
1039
|
+
await self._adapter.add_participants_room([entity_id], room_id)
|
|
1040
|
+
|
|
1041
|
+
async def ensure_world_exists(self, world: World) -> None:
|
|
1042
|
+
if self._adapter:
|
|
1043
|
+
existing = await self._adapter.get_world(world.id)
|
|
1044
|
+
if not existing:
|
|
1045
|
+
await self._adapter.create_world(world)
|
|
1046
|
+
|
|
1047
|
+
async def ensure_room_exists(self, room: Room) -> None:
|
|
1048
|
+
"""Ensure a room exists."""
|
|
1049
|
+
if self._adapter:
|
|
1050
|
+
rooms = await self._adapter.get_rooms_by_ids([room.id])
|
|
1051
|
+
if not rooms or len(rooms) == 0:
|
|
1052
|
+
await self._adapter.create_rooms([room])
+
+    async def compose_state(
+        self,
+        message: Memory,
+        include_list: list[str] | None = None,
+        only_include: bool = False,
+        skip_cache: bool = False,
+    ) -> State:
+        # If we're running inside a trajectory step, always bypass the state cache
+        # so providers are executed and logged for training/benchmark traces.
+        traj_step_id: str | None = None
+        if message.metadata is not None:
+            maybe_step = getattr(message.metadata, "trajectoryStepId", None)
+            if isinstance(maybe_step, str) and maybe_step:
+                traj_step_id = maybe_step
+                skip_cache = True
+
+        cache_key = str(message.room_id)
+
+        if not skip_cache and cache_key in self._state_cache:
+            return self._state_cache[cache_key]
+
+        # Create new state
+        state = State(
+            values={},
+            data=StateData(),
+            text="",
+        )
+
+        providers_to_run = self._providers
+        if include_list and only_include:
+            # Exclusive mode: run ONLY providers in the include_list
+            providers_to_run = [p for p in self._providers if p.name in include_list]
+        elif include_list:
+            # Additive mode (TypeScript parity): run all non-private/non-dynamic providers
+            # PLUS any explicitly included providers (which may be private/dynamic)
+            include_set = set(include_list)
+            providers_to_run = [
+                p
+                for p in self._providers
+                if (not p.private and not getattr(p, "dynamic", False)) or p.name in include_set
+            ]
+
+        # Sort by position
+        providers_to_run.sort(key=lambda p: p.position or 0)
+
+        # Optional trajectory logging (end-to-end capture)
+
+        from typing import Protocol, runtime_checkable
+
+        @runtime_checkable
+        class _TrajectoryLogger(Protocol):
+            def log_provider_access(
+                self,
+                *,
+                step_id: str,
+                provider_name: str,
+                data: dict[str, str | int | float | bool | None],
+                purpose: str,
+                query: dict[str, str | int | float | bool | None] | None = None,
+            ) -> None: ...
+
+        traj_svc = self.get_service("trajectory_logger")
+        traj_logger = traj_svc if isinstance(traj_svc, _TrajectoryLogger) else None
+
+        def _as_json_scalar(value: object) -> str | int | float | bool | None:
+            if value is None:
+                return None
+            if isinstance(value, (str, int, float, bool)):
+                if isinstance(value, str):
+                    return value[:2000]
+                return value
+            return str(value)[:2000]
+
+        def _as_json_dict(data: object) -> dict[str, str | int | float | bool | None]:
+            if not isinstance(data, dict):
+                return {"value": _as_json_scalar(data)}
+            out: dict[str, str | int | float | bool | None] = {}
+            for k, v in data.items():
+                if isinstance(k, str):
+                    out[k] = _as_json_scalar(v)
+            return out
+
+        text_parts: list[str] = []
+        for provider in providers_to_run:
+            if provider.private:
+                continue
+
+            result = await provider.get(self, message, state)
+            if result.text:
+                text_parts.append(result.text)
+            if result.values:
+                for k, v in result.values.items():
+                    if hasattr(state.values, k):
+                        setattr(state.values, k, v)
+                    else:
+                        state.values.extra[k] = v
+            if result.data:
+                # Access map entry to create it, then update its data struct
+                entry = state.data.providers[provider.name]
+                for k, v in result.data.items():
+                    entry.data[k] = v
+
+            # Log provider access to trajectory service (if available)
+            if traj_step_id and traj_logger is not None:
+                try:
+                    user_text = message.content.text or ""
+                    traj_logger.log_provider_access(
+                        step_id=traj_step_id,
+                        provider_name=provider.name,
+                        data=_as_json_dict(result.data or {}),
+                        purpose="compose_state",
+                        query={"message": _as_json_scalar(user_text)},
+                    )
+                except Exception:
+                    # Trajectory logging must never break core message flow.
+                    pass
+
+        state.text = "\n".join(text_parts)
+        # Match TypeScript behavior: expose providers text under {{providers}}.
+        state.values.providers = state.text
+
+        if not skip_cache:
+            self._state_cache[cache_key] = state
+            # LRU eviction: remove oldest entries when cache exceeds limit
+            if len(self._state_cache) > self._STATE_CACHE_MAX:
+                excess = len(self._state_cache) - self._STATE_CACHE_MAX
+                keys_to_remove = list(self._state_cache.keys())[:excess]
+                for k in keys_to_remove:
+                    del self._state_cache[k]
+
+        return state
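compose_state supports two include_list modes: with only_include=True the list is exclusive, otherwise it is additive on top of the non-private, non-dynamic providers. Results are cached per room unless skip_cache is set or a trajectory step forces a fresh run. A rough sketch of both calling styles, assuming a constructed `runtime` and an incoming `message: Memory`; the provider names are placeholders, not names confirmed by this diff:

async def build_states(runtime, message):
    # Additive: default providers plus an explicitly named one.
    state = await runtime.compose_state(message, include_list=["FACTS"])

    # Exclusive: only this provider, bypassing the per-room cache.
    fresh = await runtime.compose_state(
        message, include_list=["RECENT_MESSAGES"], only_include=True, skip_cache=True
    )
    # Concatenated provider text, also exposed for {{providers}} templates.
    return state.text, fresh.values.providers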
+
+    # Model usage
+    def has_model(self, model_type: str | ModelType) -> bool:
+        """Check if a model handler is registered for the given model type."""
+
+        key = model_type.value if isinstance(model_type, ModelType) else model_type
+        handlers = self._models.get(key, [])
+        return len(handlers) > 0
+
+    async def use_model(
+        self,
+        model_type: str | ModelType,
+        params: dict[str, Any] | None = None,
+        provider: str | None = None,
+        **kwargs: Any,
+    ) -> Any:
+        effective_model_type = model_type.value if isinstance(model_type, ModelType) else model_type
+        if params is None:
+            params = dict(kwargs)
+        elif kwargs:
+            params = {**params, **kwargs}
+
+        # Apply LLM mode override for text generation models
+        llm_mode = self.get_llm_mode()
+        if llm_mode != LLMMode.DEFAULT:
+            # List of text generation model types that can be overridden
+            text_generation_models = [
+                ModelType.TEXT_SMALL,
+                ModelType.TEXT_LARGE,
+                ModelType.TEXT_REASONING_SMALL,
+                ModelType.TEXT_REASONING_LARGE,
+                ModelType.TEXT_COMPLETION,
+            ]
+            if effective_model_type in text_generation_models:
+                override_model_type = (
+                    ModelType.TEXT_SMALL.value
+                    if llm_mode == LLMMode.SMALL
+                    else ModelType.TEXT_LARGE.value
+                )
+                if effective_model_type != override_model_type:
+                    self.logger.debug(
+                        f"LLM mode override applied: {effective_model_type} -> {override_model_type} (mode: {llm_mode})"
+                    )
+                effective_model_type = override_model_type
+
+        handlers = self._models.get(effective_model_type, [])
+
+        if not handlers:
+            raise RuntimeError(f"No model handler registered for: {effective_model_type}")
+
+        handlers.sort(key=lambda h: h.priority, reverse=True)
+
+        if provider:
+            handlers = [h for h in handlers if h.provider == provider]
+            if not handlers:
+                raise RuntimeError(f"No model handler for provider: {provider}")
+
+        handler = handlers[0]
+        start_ms = self.get_current_time_ms()
+        result = await handler.handler(self, params)
+        end_ms = self.get_current_time_ms()
+
+        # Optional trajectory logging: associate model calls with the current trajectory step
+        try:
+            from elizaos.trajectory_context import CURRENT_TRAJECTORY_STEP_ID
+
+            step_id = CURRENT_TRAJECTORY_STEP_ID.get()
+            traj_svc = self.get_service("trajectory_logger")
+            if step_id and traj_svc is not None and hasattr(traj_svc, "log_llm_call"):
+                prompt = str(params.get("prompt", "")) if isinstance(params, dict) else ""
+                system_prompt = str(params.get("system", "")) if isinstance(params, dict) else ""
+                temperature_raw = params.get("temperature") if isinstance(params, dict) else None
+                temperature = (
+                    float(temperature_raw) if isinstance(temperature_raw, (int, float)) else 0.0
+                )
+                max_tokens_raw = params.get("maxTokens") if isinstance(params, dict) else None
+                max_tokens = int(max_tokens_raw) if isinstance(max_tokens_raw, int) else 0
+
+                traj_svc.log_llm_call(  # type: ignore[call-arg]
+                    step_id=step_id,
+                    model=str(effective_model_type),
+                    system_prompt=system_prompt,
+                    user_prompt=prompt,
+                    response=str(result),
+                    temperature=temperature,
+                    max_tokens=max_tokens,
+                    purpose="action",
+                    action_type="runtime.use_model",
+                    latency_ms=max(0, end_ms - start_ms),
+                )
+        except Exception:
+            pass
+
+        return result
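use_model picks the highest-priority handler for the (possibly LLM-mode-overridden) model type, optionally narrowed to one provider, and logs the call when a trajectory step is active. A minimal sketch; the `from elizaos import ModelType` import path and the "local" provider name are assumptions, not confirmed by this diff:

from elizaos import ModelType  # assumed import path

async def summarize(runtime, notes: str) -> str:
    result = await runtime.use_model(
        ModelType.TEXT_LARGE,
        {"prompt": f"Summarize:\n{notes}", "maxTokens": 256},
    )
    # Params may also be passed as kwargs, and a registered provider can be pinned:
    await runtime.use_model(ModelType.TEXT_SMALL, provider="local", prompt="ping")
    return str(result)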
+
+    async def generate_text(
+        self,
+        input_text: str,
+        options: GenerateTextOptions | None = None,
+    ) -> GenerateTextResult:
+        model_type: str | ModelType = ModelType.TEXT_LARGE
+        if options and options.model_type:
+            model_type = options.model_type
+
+        params: dict[str, str | int | float] = {
+            "prompt": input_text,
+        }
+        if options:
+            if options.temperature is not None:
+                params["temperature"] = options.temperature
+            if options.max_tokens is not None:
+                params["maxTokens"] = options.max_tokens
+
+        result = await self.use_model(model_type, params)
+        return GenerateTextResult(text=str(result))
+
+    def register_model(
+        self,
+        model_type: str | ModelType,
+        handler: Callable[[IAgentRuntime, dict[str, Any]], Awaitable[Any]],
+        provider: str,
+        priority: int = 0,
+    ) -> None:
+        key = model_type.value if isinstance(model_type, ModelType) else model_type
+        if key not in self._models:
+            self._models[key] = []
+
+        self._models[key].append(
+            ModelHandler(handler=handler, provider=provider, priority=priority)
+        )
+
+    def get_model(
+        self, model_type: str
+    ) -> Callable[[IAgentRuntime, dict[str, Any]], Awaitable[Any]] | None:
+        handlers = self._models.get(model_type, [])
+        if handlers:
+            handlers.sort(key=lambda h: h.priority, reverse=True)
+            return handlers[0].handler
+        return None
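register_model keeps a list of handlers per model type and re-sorts by priority on every call, so a later, higher-priority registration simply shadows the earlier one without unregistering it. A sketch with a trivial echo handler; the handler body and provider name are illustrative only, and the ModelType import path is assumed:

from elizaos import ModelType  # assumed import path

async def echo_handler(rt, params: dict) -> str:
    # Handlers receive the runtime and the merged params dict.
    return f"echo: {params.get('prompt', '')}"

def install_echo(runtime) -> None:
    runtime.register_model(ModelType.TEXT_LARGE, echo_handler, provider="echo", priority=10)

async def demo(runtime) -> str:
    out = await runtime.generate_text("hello")                # routes through use_model
    _handler = runtime.get_model(ModelType.TEXT_LARGE.value)  # highest priority, or None
    return out.text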
+
+    def register_streaming_model(
+        self,
+        model_type: str | ModelType,
+        handler: StreamingModelHandler,
+        provider: str,
+        priority: int = 0,
+    ) -> None:
+        """Register a streaming model handler."""
+        key = model_type.value if isinstance(model_type, ModelType) else model_type
+        if key not in self._streaming_models:
+            self._streaming_models[key] = []
+
+        self._streaming_models[key].append(
+            StreamingModelHandlerWrapper(handler=handler, provider=provider, priority=priority)
+        )
+
+    async def _use_model_stream_impl(
+        self,
+        model_type: str | ModelType,
+        params: dict[str, Any] | None = None,
+        provider: str | None = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        """Internal implementation for streaming model calls."""
+        effective_model_type = model_type.value if isinstance(model_type, ModelType) else model_type
+        if params is None:
+            params = dict(kwargs)
+        elif kwargs:
+            params = {**params, **kwargs}
+
+        # Apply LLM mode override for streaming text generation models
+        llm_mode = self.get_llm_mode()
+        if llm_mode != LLMMode.DEFAULT:
+            streaming_text_models = [
+                ModelType.TEXT_SMALL_STREAM.value,
+                ModelType.TEXT_LARGE_STREAM.value,
+            ]
+            if effective_model_type in streaming_text_models:
+                override_model_type = (
+                    ModelType.TEXT_SMALL_STREAM.value
+                    if llm_mode == LLMMode.SMALL
+                    else ModelType.TEXT_LARGE_STREAM.value
+                )
+                if effective_model_type != override_model_type:
+                    self.logger.debug(
+                        f"LLM mode override applied: {effective_model_type} -> {override_model_type} (mode: {llm_mode})"
+                    )
+                effective_model_type = override_model_type
+
+        handlers = self._streaming_models.get(effective_model_type, [])
+
+        if not handlers:
+            raise RuntimeError(f"No streaming model handler registered for: {effective_model_type}")
+
+        handlers.sort(key=lambda h: h.priority, reverse=True)
+
+        if provider:
+            handlers = [h for h in handlers if h.provider == provider]
+            if not handlers:
+                raise RuntimeError(f"No streaming model handler for provider: {provider}")
+
+        handler = handlers[0]
+        async for chunk in handler.handler(self, params):
+            yield chunk
+
+    def use_model_stream(
+        self,
+        model_type: str | ModelType,
+        params: dict[str, Any] | None = None,
+        provider: str | None = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        """
+        Use a streaming model handler to generate text token by token.
+
+        Args:
+            model_type: The model type (e.g., "TEXT_LARGE_STREAM")
+            params: Parameters for the model (prompt, system, temperature, etc.)
+            provider: Optional specific provider to use
+            **kwargs: Additional parameters merged into params
+
+        Returns:
+            An async iterator yielding text chunks as they are generated.
+        """
+        return self._use_model_stream_impl(model_type, params, provider, **kwargs)
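Because _use_model_stream_impl is an async generator, use_model_stream is consumed directly with `async for`; the call itself is not awaited. A sketch, assuming a *_STREAM handler has been registered and the ModelType import path from the earlier sketch:

from elizaos import ModelType  # assumed import path

async def stream_reply(runtime, prompt: str) -> str:
    parts: list[str] = []
    async for chunk in runtime.use_model_stream(
        ModelType.TEXT_LARGE_STREAM, {"prompt": prompt}
    ):
        parts.append(chunk)  # arrives incrementally, token by token
    return "".join(parts)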
+
+    # Event handling
+    def register_event(
+        self,
+        event: str,
+        handler: Callable[[Any], Awaitable[None]],
+    ) -> None:
+        if event not in self._events:
+            self._events[event] = []
+        self._events[event].append(handler)
+
+    def get_event(self, event: str) -> list[Callable[[Any], Awaitable[None]]] | None:
+        """Get event handlers for an event type."""
+        return self._events.get(event)
+
+    async def emit_event(
+        self,
+        event: str | list[str],
+        params: Any,
+    ) -> None:
+        events = [event] if isinstance(event, str) else event
+
+        for evt in events:
+            handlers = self._events.get(evt, [])
+            for handler in handlers:
+                await handler(params)
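Event handlers are plain async callables keyed by event name, and emit_event accepts a single name or a list, awaiting handlers in registration order. A small sketch; the event name is a placeholder:

async def on_custom_event(payload) -> None:
    print("event payload:", payload)

async def demo_events(runtime) -> None:
    runtime.register_event("CUSTOM_EVENT", on_custom_event)
    await runtime.emit_event("CUSTOM_EVENT", {"source": "example"})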
+
+    # Task management
+    def register_task_worker(self, task_handler: TaskWorker) -> None:
+        """Register a task worker."""
+        self._task_workers[task_handler.name] = task_handler
+
+    def get_task_worker(self, name: str) -> TaskWorker | None:
+        """Get a task worker by name."""
+        return self._task_workers.get(name)
+
+    # Lifecycle
+    async def stop(self) -> None:
+        """Stop the runtime."""
+        self.logger.info("Stopping AgentRuntime...")
+
+        # Stop all services
+        for service_type, services in self._services.items():
+            for service in services:
+                try:
+                    await service.stop()
+                except Exception as e:
+                    self.logger.error(f"Failed to stop service {service_type}: {e}")
+
+        if self._adapter:
+            await self._adapter.close()
+
+        self.logger.info("AgentRuntime stopped")
+
+    async def add_embedding_to_memory(self, memory: Memory) -> Memory:
+        return memory
+
+    async def queue_embedding_generation(self, memory: Memory, priority: str = "normal") -> None:
+        await self.emit_event(
+            EventType.EMBEDDING_GENERATION_REQUESTED.value,
+            {"runtime": self, "memory": memory, "priority": priority, "source": "runtime"},
+        )
+
+    async def get_all_memories(self) -> list[Memory]:
+        if not self._adapter:
+            return []
+        return await self._adapter.get_memories(
+            {"agentId": str(self._agent_id), "tableName": "memories"}
+        )
+
+    async def clear_all_agent_memories(self) -> None:
+        pass
+
+    def create_run_id(self) -> UUID:
+        return as_uuid(str(uuid.uuid4()))
+
+    def start_run(self, room_id: UUID | None = None) -> UUID:
+        self._current_run_id = self.create_run_id()
+        self._current_room_id = room_id
+        return self._current_run_id
+
+    def end_run(self) -> None:
+        self._current_run_id = None
+        self._current_room_id = None
+
+    def get_current_run_id(self) -> UUID:
+        if not self._current_run_id:
+            return self.start_run()
+        return self._current_run_id
+
+    async def get_entity_by_id(self, entity_id: UUID) -> Entity | None:
+        if not self._adapter:
+            return None
+        entities = await self._adapter.get_entities_by_ids([entity_id])
+        return entities[0] if entities else None
+
+    async def get_room(self, room_id: UUID) -> Room | None:
+        if not self._adapter:
+            return None
+        rooms = await self._adapter.get_rooms_by_ids([room_id])
+        return rooms[0] if rooms else None
+
+    async def create_entity(self, entity: Entity) -> bool:
+        if not self._adapter:
+            return False
+        return await self._adapter.create_entities([entity])
+
+    async def create_room(self, room: Room) -> UUID:
+        if not self._adapter:
+            raise RuntimeError("Database adapter not set")
+        ids = await self._adapter.create_rooms([room])
+        return ids[0]
+
+    async def add_participant(self, entity_id: UUID, room_id: UUID) -> bool:
+        if not self._adapter:
+            return False
+        return await self._adapter.add_participants_room([entity_id], room_id)
+
+    async def get_rooms(self, world_id: UUID) -> list[Room]:
+        if not self._adapter:
+            return []
+        return await self._adapter.get_rooms_by_world(world_id)
+
+    def register_send_handler(self, source: str, handler: SendHandlerFunction) -> None:
+        self._send_handlers[source] = handler
+
+    async def send_message_to_target(self, target: TargetInfo, content: Content) -> None:
+        if target.source and target.source in self._send_handlers:
+            await self._send_handlers[target.source](target, content)
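send_message_to_target dispatches only when a handler was registered for target.source; unknown sources are silently ignored. A sketch, assuming `target: TargetInfo` and `content: Content` objects built elsewhere; the "console" source name and the printed fields are purely illustrative:

async def console_send(target, content) -> None:
    # Deliver `content` wherever `target` points; here we just print it.
    print(f"[{target.source}] {content.text}")

async def demo_send(runtime, target, content) -> None:
    runtime.register_send_handler("console", console_send)
    await runtime.send_message_to_target(target, content)  # no-op for unregistered sources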
+
+    async def init(self) -> None:
+        if self._adapter:
+            await self._adapter.init()
+
+    async def is_ready(self) -> bool:
+        if not self._adapter:
+            return False
+        return await self._adapter.is_ready()
+
+    async def close(self) -> None:
+        if self._adapter:
+            await self._adapter.close()
+
+    async def get_connection(self) -> Any:
+        if not self._adapter:
+            raise RuntimeError("Database adapter not set")
+        return await self._adapter.get_connection()
+
+    async def get_agent(self, agent_id: UUID) -> Any | None:
+        if not self._adapter:
+            return None
+        return await self._adapter.get_agent(agent_id)
+
+    async def get_agents(self) -> list[Any]:
+        if not self._adapter:
+            return []
+        return await self._adapter.get_agents()
+
+    async def create_agent(self, agent: Any) -> bool:
+        if not self._adapter:
+            return False
+        return await self._adapter.create_agent(agent)
+
+    async def update_agent(self, agent_id: UUID, agent: Any) -> bool:
+        if not self._adapter:
+            return False
+        return await self._adapter.update_agent(agent_id, agent)
+
+    async def delete_agent(self, agent_id: UUID) -> bool:
+        if not self._adapter:
+            return False
+        return await self._adapter.delete_agent(agent_id)
+
+    async def ensure_embedding_dimension(self, dimension: int) -> None:
+        if self._adapter:
+            await self._adapter.ensure_embedding_dimension(dimension)
+
+    async def get_entity(self, entity_id: UUID) -> Any | None:
+        """Get a single entity by ID."""
+        if not self._adapter:
+            return None
+        entities = await self._adapter.get_entities_by_ids([entity_id])
+        return entities[0] if entities else None
+
+    async def get_entities_by_ids(self, entity_ids: list[UUID]) -> list[Any] | None:
+        if not self._adapter:
+            return None
+        return await self._adapter.get_entities_by_ids(entity_ids)
+
+    async def get_entities_for_room(
+        self, room_id: UUID, include_components: bool = False
+    ) -> list[Any]:
+        if not self._adapter:
+            return []
+        return await self._adapter.get_entities_for_room(room_id, include_components)
+
+    async def create_entities(self, entities: list[Any]) -> bool:
+        if not self._adapter:
+            return False
+        return await self._adapter.create_entities(entities)
+
+    async def update_entity(self, entity: Any) -> None:
+        if self._adapter:
+            await self._adapter.update_entity(entity)
+
+    async def get_component(
+        self,
+        entity_id: UUID,
+        component_type: str,
+        world_id: UUID | None = None,
+        source_entity_id: UUID | None = None,
+    ) -> Any | None:
+        if not self._adapter:
+            return None
+        return await self._adapter.get_component(
+            entity_id, component_type, world_id, source_entity_id
+        )
+
+    async def get_components(
+        self,
+        entity_id: UUID,
+        world_id: UUID | None = None,
+        source_entity_id: UUID | None = None,
+    ) -> list[Any]:
+        if not self._adapter:
+            return []
+        return await self._adapter.get_components(entity_id, world_id, source_entity_id)
+
+    async def create_component(self, component: Any) -> bool:
+        if not self._adapter:
+            return False
+        return await self._adapter.create_component(component)
+
+    async def update_component(self, component: Any) -> None:
+        if self._adapter:
+            await self._adapter.update_component(component)
+
+    async def delete_component(self, component_id: UUID) -> None:
+        if self._adapter:
+            await self._adapter.delete_component(component_id)
+
+    async def get_memories(
+        self,
+        params: dict[str, Any] | None = None,
+        *,
+        room_id: UUID | None = None,
+        limit: int | None = None,
+        order_by: str | None = None,
+        order_direction: str | None = None,
+        table_name: str | None = None,
+        **kwargs: Any,
+    ) -> list[Any]:
+        """
+        Get memories, supporting both dict-style and kwargs-style calling.
+
+        Can be called as:
+            get_memories({"roomId": room_id, "limit": 10})
+        or:
+            get_memories(room_id=room_id, limit=10)
+        """
+        if not self._adapter:
+            return []
+        # Start with provided params or empty dict
+        merged_params = dict(params) if params else {}
+        # Explicit keyword arguments take precedence over params dict
+        if room_id is not None:
+            merged_params["roomId"] = str(room_id)
+        if limit is not None:
+            merged_params["limit"] = limit
+        if order_by is not None:
+            merged_params["orderBy"] = order_by
+        if order_direction is not None:
+            merged_params["orderDirection"] = order_direction
+        if table_name is not None:
+            merged_params["tableName"] = table_name
+        # Additional kwargs also take precedence
+        merged_params.update(kwargs)
+        return await self._adapter.get_memories(merged_params)
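As the docstring above notes, get_memories accepts either an adapter-style params dict (camelCase keys) or keyword arguments, with keywords winning when both are supplied. A sketch; the "messages" table name is only an example value:

async def recent_memories(runtime, room_id):
    # Dict style, passed through to the adapter unchanged:
    a = await runtime.get_memories({"roomId": str(room_id), "limit": 10})
    # Keyword style; table_name maps to the adapter's "tableName" key:
    b = await runtime.get_memories(room_id=room_id, limit=10, table_name="messages")
    return a, b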
|
|
1688
|
+
|
|
1689
|
+
async def get_memory_by_id(self, id: UUID) -> Any | None:
|
|
1690
|
+
if not self._adapter:
|
|
1691
|
+
return None
|
|
1692
|
+
return await self._adapter.get_memory_by_id(id)
|
|
1693
|
+
|
|
1694
|
+
async def get_memories_by_ids(
|
|
1695
|
+
self, ids: list[UUID], table_name: str | None = None
|
|
1696
|
+
) -> list[Any]:
|
|
1697
|
+
if not self._adapter:
|
|
1698
|
+
return []
|
|
1699
|
+
return await self._adapter.get_memories_by_ids(ids, table_name)
|
|
1700
|
+
|
|
1701
|
+
async def get_memories_by_room_ids(self, params: dict[str, Any]) -> list[Any]:
|
|
1702
|
+
if not self._adapter:
|
|
1703
|
+
return []
|
|
1704
|
+
return await self._adapter.get_memories_by_room_ids(params)
|
|
1705
|
+
|
|
1706
|
+
async def get_cached_embeddings(self, params: dict[str, Any]) -> list[dict[str, Any]]:
|
|
1707
|
+
if not self._adapter:
|
|
1708
|
+
return []
|
|
1709
|
+
return await self._adapter.get_cached_embeddings(params)
|
|
1710
|
+
|
|
1711
|
+
async def log(self, params: dict[str, Any]) -> None:
|
|
1712
|
+
if self._adapter:
|
|
1713
|
+
await self._adapter.log(params)
|
|
1714
|
+
|
|
1715
|
+
async def get_logs(self, params: dict[str, Any]) -> list[Log]:
|
|
1716
|
+
if not self._adapter:
|
|
1717
|
+
return []
|
|
1718
|
+
return await self._adapter.get_logs(params)
|
|
1719
|
+
|
|
1720
|
+
async def delete_log(self, log_id: UUID) -> None:
|
|
1721
|
+
if self._adapter:
|
|
1722
|
+
await self._adapter.delete_log(log_id)
|
|
1723
|
+
|
|
1724
|
+
async def get_agent_run_summaries(self, params: dict[str, Any]) -> AgentRunSummaryResult:
|
|
1725
|
+
if not self._adapter:
|
|
1726
|
+
return AgentRunSummaryResult(runs=[], total=0, has_more=False)
|
|
1727
|
+
return await self._adapter.get_agent_run_summaries(params)
|
|
1728
|
+
|
|
1729
|
+
async def search_memories(self, params: dict[str, Any]) -> list[Any]:
|
|
1730
|
+
if not self._adapter:
|
|
1731
|
+
return []
|
|
1732
|
+
return await self._adapter.search_memories(params)
|
|
1733
|
+
|
|
1734
|
+
async def create_memory(
|
|
1735
|
+
self,
|
|
1736
|
+
memory: dict[str, object] | None = None,
|
|
1737
|
+
table_name: str | None = None,
|
|
1738
|
+
unique: bool | None = None,
|
|
1739
|
+
**kwargs: object,
|
|
1740
|
+
) -> UUID:
|
|
1741
|
+
if not self._adapter:
|
|
1742
|
+
raise RuntimeError("Database adapter not set")
|
|
1743
|
+
return await self._adapter.create_memory(
|
|
1744
|
+
memory, table_name, bool(unique) if unique is not None else False
|
|
1745
|
+
)
|
|
1746
|
+
|
|
1747
|
+
async def update_memory(self, memory: Memory | dict[str, Any]) -> bool:
|
|
1748
|
+
if not self._adapter:
|
|
1749
|
+
return False
|
|
1750
|
+
return await self._adapter.update_memory(memory)
|
|
1751
|
+
|
|
1752
|
+
async def delete_memory(self, memory_id: UUID) -> None:
|
|
1753
|
+
if self._adapter:
|
|
1754
|
+
await self._adapter.delete_memory(memory_id)
|
|
1755
|
+
|
|
1756
|
+
async def delete_many_memories(self, memory_ids: list[UUID]) -> None:
|
|
1757
|
+
if self._adapter:
|
|
1758
|
+
await self._adapter.delete_many_memories(memory_ids)
|
|
1759
|
+
|
|
1760
|
+
async def delete_all_memories(self, room_id: UUID, table_name: str) -> None:
|
|
1761
|
+
if self._adapter:
|
|
1762
|
+
await self._adapter.delete_all_memories(room_id, table_name)
|
|
1763
|
+
|
|
1764
|
+
async def count_memories(
|
|
1765
|
+
self, room_id: UUID, unique: bool = False, table_name: str | None = None
|
|
1766
|
+
) -> int:
|
|
1767
|
+
if not self._adapter:
|
|
1768
|
+
return 0
|
|
1769
|
+
return await self._adapter.count_memories(room_id, unique, table_name)
|
|
1770
|
+
|
|
1771
|
+
async def create_world(self, world: Any) -> UUID:
|
|
1772
|
+
if not self._adapter:
|
|
1773
|
+
raise RuntimeError("Database adapter not set")
|
|
1774
|
+
return await self._adapter.create_world(world)
|
|
1775
|
+
|
|
1776
|
+
async def get_world(self, id: UUID) -> Any | None:
|
|
1777
|
+
if not self._adapter:
|
|
1778
|
+
return None
|
|
1779
|
+
return await self._adapter.get_world(id)
|
|
1780
|
+
|
|
1781
|
+
async def remove_world(self, id: UUID) -> None:
|
|
1782
|
+
if self._adapter:
|
|
1783
|
+
await self._adapter.remove_world(id)
|
|
1784
|
+
|
|
1785
|
+
async def get_all_worlds(self) -> list[Any]:
|
|
1786
|
+
if not self._adapter:
|
|
1787
|
+
return []
|
|
1788
|
+
return await self._adapter.get_all_worlds()
|
|
1789
|
+
|
|
1790
|
+
async def update_world(self, world: Any) -> None:
|
|
1791
|
+
if self._adapter:
|
|
1792
|
+
await self._adapter.update_world(world)
|
|
1793
|
+
|
|
1794
|
+
async def get_rooms_by_ids(self, room_ids: list[UUID]) -> list[Any] | None:
|
|
1795
|
+
if not self._adapter:
|
|
1796
|
+
return None
|
|
1797
|
+
return await self._adapter.get_rooms_by_ids(room_ids)
|
|
1798
|
+
|
|
1799
|
+
async def create_rooms(self, rooms: list[Any]) -> list[UUID]:
|
|
1800
|
+
if not self._adapter:
|
|
1801
|
+
raise RuntimeError("Database adapter not set")
|
|
1802
|
+
return await self._adapter.create_rooms(rooms)
|
|
1803
|
+
|
|
1804
|
+
async def delete_room(self, room_id: UUID) -> None:
|
|
1805
|
+
if self._adapter:
|
|
1806
|
+
await self._adapter.delete_room(room_id)
|
|
1807
|
+
|
|
1808
|
+
async def delete_rooms_by_world_id(self, world_id: UUID) -> None:
|
|
1809
|
+
if self._adapter:
|
|
1810
|
+
await self._adapter.delete_rooms_by_world_id(world_id)
|
|
1811
|
+
|
|
1812
|
+
async def update_room(self, room: Any) -> None:
|
|
1813
|
+
if self._adapter:
|
|
1814
|
+
await self._adapter.update_room(room)
|
|
1815
|
+
|
|
1816
|
+
async def get_rooms_for_participant(self, entity_id: UUID) -> list[UUID]:
|
|
1817
|
+
if not self._adapter:
|
|
1818
|
+
return []
|
|
1819
|
+
return await self._adapter.get_rooms_for_participant(entity_id)
|
|
1820
|
+
|
|
1821
|
+
async def get_rooms_for_participants(self, user_ids: list[UUID]) -> list[UUID]:
|
|
1822
|
+
if not self._adapter:
|
|
1823
|
+
return []
|
|
1824
|
+
return await self._adapter.get_rooms_for_participants(user_ids)
|
|
1825
|
+
|
|
1826
|
+
async def get_rooms_by_world(self, world_id: UUID) -> list[Any]:
|
|
1827
|
+
if not self._adapter:
|
|
1828
|
+
return []
|
|
1829
|
+
return await self._adapter.get_rooms_by_world(world_id)
|
|
1830
|
+
|
|
1831
|
+
async def remove_participant(self, entity_id: UUID, room_id: UUID) -> bool:
|
|
1832
|
+
if not self._adapter:
|
|
1833
|
+
return False
|
|
1834
|
+
return await self._adapter.remove_participant(entity_id, room_id)
|
|
1835
|
+
|
|
1836
|
+
async def get_participants_for_entity(self, entity_id: UUID) -> list[Any]:
|
|
1837
|
+
if not self._adapter:
|
|
1838
|
+
return []
|
|
1839
|
+
return await self._adapter.get_participants_for_entity(entity_id)
|
|
1840
|
+
|
|
1841
|
+
async def get_participants_for_room(self, room_id: UUID) -> list[UUID]:
|
|
1842
|
+
if not self._adapter:
|
|
1843
|
+
return []
|
|
1844
|
+
return await self._adapter.get_participants_for_room(room_id)
|
|
1845
|
+
|
|
1846
|
+
async def is_room_participant(self, room_id: UUID, entity_id: UUID) -> bool:
|
|
1847
|
+
if not self._adapter:
|
|
1848
|
+
return False
|
|
1849
|
+
return await self._adapter.is_room_participant(room_id, entity_id)
|
|
1850
|
+
|
|
1851
|
+
async def add_participants_room(self, entity_ids: list[UUID], room_id: UUID) -> bool:
|
|
1852
|
+
if not self._adapter:
|
|
1853
|
+
return False
|
|
1854
|
+
return await self._adapter.add_participants_room(entity_ids, room_id)
|
|
1855
|
+
|
|
1856
|
+
async def get_participant_user_state(self, room_id: UUID, entity_id: UUID) -> str | None:
|
|
1857
|
+
if not self._adapter:
|
|
1858
|
+
return None
|
|
1859
|
+
return await self._adapter.get_participant_user_state(room_id, entity_id)
|
|
1860
|
+
|
|
1861
|
+
async def set_participant_user_state(
|
|
1862
|
+
self, room_id: UUID, entity_id: UUID, state: str | None
|
|
1863
|
+
) -> None:
|
|
1864
|
+
if self._adapter:
|
|
1865
|
+
await self._adapter.set_participant_user_state(room_id, entity_id, state)
|
|
1866
|
+
|
|
1867
|
+
async def create_relationship(self, params: dict[str, Any]) -> bool:
|
|
1868
|
+
if not self._adapter:
|
|
1869
|
+
return False
|
|
1870
|
+
return await self._adapter.create_relationship(params)
|
|
1871
|
+
|
|
1872
|
+
async def update_relationship(self, relationship: Any) -> None:
|
|
1873
|
+
if self._adapter:
|
|
1874
|
+
await self._adapter.update_relationship(relationship)
|
|
1875
|
+
|
|
1876
|
+
async def get_relationship(self, params: dict[str, Any]) -> Any | None:
|
|
1877
|
+
if not self._adapter:
|
|
1878
|
+
return None
|
|
1879
|
+
return await self._adapter.get_relationship(params)
|
|
1880
|
+
|
|
1881
|
+
async def get_relationships(self, params: dict[str, Any]) -> list[Any]:
|
|
1882
|
+
if not self._adapter:
|
|
1883
|
+
return []
|
|
1884
|
+
return await self._adapter.get_relationships(params)
|
|
1885
|
+
|
|
1886
|
+
async def get_cache(self, key: str) -> Any | None:
|
|
1887
|
+
if not self._adapter:
|
|
1888
|
+
return None
|
|
1889
|
+
return await self._adapter.get_cache(key)
|
|
1890
|
+
|
|
1891
|
+
async def set_cache(self, key: str, value: Any) -> bool:
|
|
1892
|
+
if not self._adapter:
|
|
1893
|
+
return False
|
|
1894
|
+
return await self._adapter.set_cache(key, value)
|
|
1895
|
+
|
|
1896
|
+
async def delete_cache(self, key: str) -> None:
|
|
1897
|
+
if not self._adapter:
|
|
1898
|
+
return
|
|
1899
|
+
await self._adapter.delete_cache(key)
|
|
1900
|
+
|
|
1901
|
+
async def create_task(self, task: Any) -> UUID:
|
|
1902
|
+
if not self._adapter:
|
|
1903
|
+
raise RuntimeError("Database adapter not set")
|
|
1904
|
+
return await self._adapter.create_task(task)
|
|
1905
|
+
|
|
1906
|
+
async def get_tasks(self, params: dict[str, Any]) -> list[Any]:
|
|
1907
|
+
if not self._adapter:
|
|
1908
|
+
return []
|
|
1909
|
+
return await self._adapter.get_tasks(params)
|
|
1910
|
+
|
|
1911
|
+
async def get_task(self, id: UUID) -> Any | None:
|
|
1912
|
+
if not self._adapter:
|
|
1913
|
+
return None
|
|
1914
|
+
return await self._adapter.get_task(id)
|
|
1915
|
+
|
|
1916
|
+
async def get_tasks_by_name(self, name: str) -> list[Any]:
|
|
1917
|
+
if not self._adapter:
|
|
1918
|
+
return []
|
|
1919
|
+
return await self._adapter.get_tasks_by_name(name)
|
|
1920
|
+
|
|
1921
|
+
async def update_task(self, id: UUID, task: dict[str, Any]) -> None:
|
|
1922
|
+
if self._adapter:
|
|
1923
|
+
await self._adapter.update_task(id, task)
|
|
1924
|
+
|
|
1925
|
+
async def delete_task(self, id: UUID) -> None:
|
|
1926
|
+
if self._adapter:
|
|
1927
|
+
await self._adapter.delete_task(id)
|
|
1928
|
+
|
|
1929
|
+
async def get_memories_by_world_id(self, params: dict[str, Any]) -> list[Any]:
|
|
1930
|
+
if not self._adapter:
|
|
1931
|
+
return []
|
|
1932
|
+
return await self._adapter.get_memories_by_world_id(params)
|
|
1933
|
+
|
|
1934
|
+
# ============================================================================
|
|
1935
|
+
# Dynamic Prompt Execution with Validation-Aware Streaming
|
|
1936
|
+
# ============================================================================
|
|
1937
|
+
|
|
1938
|
+
async def dynamic_prompt_exec_from_state(
|
|
1939
|
+
self,
|
|
1940
|
+
state: State,
|
|
1941
|
+
prompt: str | Callable[[dict[str, Any]], str],
|
|
1942
|
+
schema: list[SchemaRow],
|
|
1943
|
+
options: DynamicPromptOptions | None = None,
|
|
1944
|
+
) -> dict[str, Any] | None:
|
|
1945
|
+
"""Dynamic prompt execution with state injection, schema-based parsing, and validation.
|
|
1946
|
+
|
|
1947
|
+
WHY THIS EXISTS:
|
|
1948
|
+
LLMs are powerful but unreliable for structured outputs. They can:
|
|
1949
|
+
- Silently truncate output when hitting token limits
|
|
1950
|
+
- Skip fields or produce malformed structures
|
|
1951
|
+
- Hallucinate or ignore parts of the prompt
|
|
1952
|
+
|
|
1953
|
+
This method addresses these issues by:
|
|
1954
|
+
1. Validation codes: Injects UUID codes the LLM must echo back
|
|
1955
|
+
2. Retry with backoff: Automatic retries on validation failure
|
|
1956
|
+
3. Structured parsing: XML/JSON response parsing with nested support
|
|
1957
|
+
4. Streaming support: ValidationStreamExtractor for incremental output with validation
|
|
1958
|
+
|
|
1959
|
+
For streaming, provide `on_stream_chunk` in options. Streaming uses
|
|
1960
|
+
ValidationStreamExtractor which streams validated content in real-time
|
|
1961
|
+
while detecting truncation via validation codes.
|
|
1962
|
+
|
|
1963
|
+
Args:
|
|
1964
|
+
state: State object to inject into the prompt template
|
|
1965
|
+
prompt: Prompt template string or callable that takes state and returns string
|
|
1966
|
+
schema: Array of SchemaRow definitions for structured output
|
|
1967
|
+
options: Configuration for model size, validation level, retries, streaming, etc.
|
|
1968
|
+
|
|
1969
|
+
Returns:
|
|
1970
|
+
Parsed structured response as dict, or None on failure
|
|
1971
|
+
"""
|
|
1972
|
+
if options is None:
|
|
1973
|
+
options = DynamicPromptOptions()
|
|
1974
|
+
|
|
1975
|
+
# Determine model type - check options.model first, then model_size, then default
|
|
1976
|
+
if options.model:
|
|
1977
|
+
model_type_str = options.model
|
|
1978
|
+
elif options.model_size == "small":
|
|
1979
|
+
model_type_str = ModelType.TEXT_SMALL
|
|
1980
|
+
else:
|
|
1981
|
+
model_type_str = ModelType.TEXT_LARGE
|
|
1982
|
+
|
|
1983
|
+
schema_key = ",".join(s.field for s in schema)
|
|
1984
|
+
model_schema_key = f"{model_type_str}:{schema_key}"
|
|
1985
|
+
|
|
1986
|
+
# Get validation level from settings or options (mirrors TypeScript behavior)
|
|
1987
|
+
default_context_level = 2
|
|
1988
|
+
default_retries = 1
|
|
1989
|
+
|
|
1990
|
+
validation_setting = self.get_setting("VALIDATION_LEVEL")
|
|
1991
|
+
if validation_setting:
|
|
1992
|
+
level_str = str(validation_setting).lower()
|
|
1993
|
+
if level_str in ("trusted", "fast"):
|
|
1994
|
+
default_context_level = 0
|
|
1995
|
+
default_retries = 0
|
|
1996
|
+
elif level_str == "progressive":
|
|
1997
|
+
default_context_level = 1
|
|
1998
|
+
default_retries = 2
|
|
1999
|
+
elif level_str in ("strict", "safe"):
|
|
2000
|
+
default_context_level = 3
|
|
2001
|
+
default_retries = 3
|
|
2002
|
+
else:
|
|
2003
|
+
self.logger.warning(
|
|
2004
|
+
f'Unrecognized VALIDATION_LEVEL "{level_str}". '
|
|
2005
|
+
f"Valid values: trusted, fast, progressive, strict, safe. "
|
|
2006
|
+
f"Falling back to default (level 2)."
|
|
2007
|
+
)
|
|
2008
|
+
|
|
2009
|
+
validation_level = (
|
|
2010
|
+
options.context_check_level
|
|
2011
|
+
if options.context_check_level is not None
|
|
2012
|
+
else default_context_level
|
|
2013
|
+
)
|
|
2014
|
+
max_retries = options.max_retries if options.max_retries is not None else default_retries
|
|
2015
|
+
current_retry = 0
|
|
2016
|
+
|
|
2017
|
+
# Generate per-field validation codes for levels 0-1
|
|
2018
|
+
per_field_codes: dict[str, str] = {}
|
|
2019
|
+
if validation_level <= 1:
|
|
2020
|
+
for row in schema:
|
|
2021
|
+
default_validate = validation_level == 1
|
|
2022
|
+
needs_validation = (
|
|
2023
|
+
row.validate_field if row.validate_field is not None else default_validate
|
|
2024
|
+
)
|
|
2025
|
+
if needs_validation:
|
|
2026
|
+
per_field_codes[row.field] = str(uuid.uuid4())[:8]
|
|
2027
|
+
|
|
2028
|
+
# Streaming extractor (created on first iteration if streaming enabled)
|
|
2029
|
+
extractor: ValidationStreamExtractor | None = None
|
|
2030
|
+
|
|
2031
|
+
while current_retry <= max_retries:
|
|
2032
|
+
# Compile template with state values
|
|
2033
|
+
# Callable signature: def my_prompt(ctx: dict) -> str:
|
|
2034
|
+
# return f"Hello {ctx['state'].values.get('name')}"
|
|
2035
|
+
template_str = prompt({"state": state}) if callable(prompt) else prompt
|
|
2036
|
+
|
|
2037
|
+
# Template substitution (Handlebars-like)
|
|
2038
|
+
# Mirrors TypeScript behavior: { ...filteredState, ...state.values }
|
|
2039
|
+
rendered = template_str
|
|
2040
|
+
|
|
2041
|
+
# Helper to extract dict from protobuf message or dict-like object
|
|
2042
|
+
def extract_fields(obj: Any) -> dict[str, Any]:
|
|
2043
|
+
"""Extract fields from protobuf message, dict, or object."""
|
|
2044
|
+
if obj is None:
|
|
2045
|
+
return {}
|
|
2046
|
+
# If it's already a dict, return it
|
|
2047
|
+
if isinstance(obj, dict):
|
|
2048
|
+
return obj
|
|
2049
|
+
# Try MessageToDict for protobuf messages (most reliable)
|
|
2050
|
+
if hasattr(obj, "DESCRIPTOR"):
|
|
2051
|
+
try:
|
|
2052
|
+
from google.protobuf.json_format import MessageToDict
|
|
2053
|
+
|
|
2054
|
+
return MessageToDict(obj, preserving_proto_field_name=True)
|
|
2055
|
+
except Exception:
|
|
2056
|
+
pass
|
|
2057
|
+
# Fallback: try ListFields() for protobuf messages
|
|
2058
|
+
if hasattr(obj, "ListFields"):
|
|
2059
|
+
result = {}
|
|
2060
|
+
for field_desc, value in obj.ListFields():
|
|
2061
|
+
result[field_desc.name] = value
|
|
2062
|
+
return result
|
|
2063
|
+
# Fallback: try __dict__ for regular objects
|
|
2064
|
+
if hasattr(obj, "__dict__"):
|
|
2065
|
+
return {
|
|
2066
|
+
k: v
|
|
2067
|
+
for k, v in obj.__dict__.items()
|
|
2068
|
+
if not k.startswith("_") and v is not None
|
|
2069
|
+
}
|
|
2070
|
+
return {}
|
|
2071
|
+
|
|
2072
|
+
# Build context dict combining state properties and state.values
|
|
2073
|
+
context: dict[str, Any] = {}
|
|
2074
|
+
|
|
2075
|
+
# Add state-level properties (like filteredState in TypeScript)
|
|
2076
|
+
# Exclude 'text', 'values', 'data' like TypeScript does
|
|
2077
|
+
state_fields = extract_fields(state)
|
|
2078
|
+
for key, value in state_fields.items():
|
|
2079
|
+
if key not in ("text", "values", "data"):
|
|
2080
|
+
context[key] = value
|
|
2081
|
+
|
|
2082
|
+
# Add state.data properties
|
|
2083
|
+
if hasattr(state, "data"):
|
|
2084
|
+
data_fields = extract_fields(state.data)
|
|
2085
|
+
context.update(data_fields)
|
|
2086
|
+
|
|
2087
|
+
# Add state.values (these take precedence, like in TypeScript)
|
|
2088
|
+
if hasattr(state, "values"):
|
|
2089
|
+
values_fields = extract_fields(state.values)
|
|
2090
|
+
context.update(values_fields)
|
|
2091
|
+
|
|
2092
|
+
# Add smart retry context if present
|
|
2093
|
+
if "_smartRetryContext" in context:
|
|
2094
|
+
rendered += str(context.pop("_smartRetryContext"))
|
|
2095
|
+
|
|
2096
|
+
# Perform substitution
|
|
2097
|
+
for key, value in context.items():
|
|
2098
|
+
placeholder = f"{{{{{key}}}}}"
|
|
2099
|
+
rendered = rendered.replace(placeholder, str(value))
|
|
2100
|
+
|
|
2101
|
+
# Build format
|
|
2102
|
+
format_type = (options.force_format or "xml").upper()
|
|
2103
|
+
is_xml = format_type == "XML"
|
|
2104
|
+
container_start = "<response>" if is_xml else "{"
|
|
2105
|
+
container_end = "</response>" if is_xml else "}"
|
|
2106
|
+
|
|
2107
|
+
# Build extended schema with validation codes
|
|
2108
|
+
first = validation_level >= 2
|
|
2109
|
+
last = validation_level >= 3
|
|
2110
|
+
|
|
2111
|
+
ext_schema: list[tuple[str, str]] = []
|
|
2112
|
+
|
|
2113
|
+
def codes_schema(prefix: str) -> list[tuple[str, str]]:
|
|
2114
|
+
return [
|
|
2115
|
+
(f"{prefix}initial_code", "echo the initial UUID code from prompt"),
|
|
2116
|
+
(f"{prefix}middle_code", "echo the middle UUID code from prompt"),
|
|
2117
|
+
(f"{prefix}end_code", "echo the end UUID code from prompt"),
|
|
2118
|
+
]
|
|
2119
|
+
|
|
2120
|
+
if first:
|
|
2121
|
+
ext_schema.extend(codes_schema("one_"))
|
|
2122
|
+
|
|
2123
|
+
for row in schema:
|
|
2124
|
+
if row.field in per_field_codes:
|
|
2125
|
+
ext_schema.append(
|
|
2126
|
+
(f"code_{row.field}_start", f"output exactly: {per_field_codes[row.field]}")
|
|
2127
|
+
)
|
|
2128
|
+
ext_schema.append((row.field, row.description))
|
|
2129
|
+
if row.field in per_field_codes:
|
|
2130
|
+
ext_schema.append(
|
|
2131
|
+
(f"code_{row.field}_end", f"output exactly: {per_field_codes[row.field]}")
|
|
2132
|
+
)
|
|
2133
|
+
|
|
2134
|
+
if last:
|
|
2135
|
+
ext_schema.extend(codes_schema("two_"))
|
|
2136
|
+
|
|
2137
|
+
# Build example
|
|
2138
|
+
example_lines = [container_start]
|
|
2139
|
+
for i, (field, desc) in enumerate(ext_schema):
|
|
2140
|
+
is_last = i == len(ext_schema) - 1
|
|
2141
|
+
if is_xml:
|
|
2142
|
+
example_lines.append(f" <{field}>{desc}</{field}>")
|
|
2143
|
+
else:
|
|
2144
|
+
# No trailing comma on last field for valid JSON
|
|
2145
|
+
comma = "" if is_last else ","
|
|
2146
|
+
example_lines.append(f' "{field}": "{desc}"{comma}')
|
|
2147
|
+
example_lines.append(container_end)
|
|
2148
|
+
example = "\n".join(example_lines)
|
|
2149
|
+
|
|
2150
|
+
init_code = str(uuid.uuid4())
|
|
2151
|
+
mid_code = str(uuid.uuid4())
|
|
2152
|
+
final_code = str(uuid.uuid4())
|
|
2153
|
+
|
|
2154
|
+
section_start = "<output>" if is_xml else "# Strict Output instructions"
|
|
2155
|
+
section_end = "</output>" if is_xml else ""
|
|
2156
|
+
|
|
2157
|
+
full_prompt = f"""initial code: {init_code}
|
|
2158
|
+
{rendered}
|
|
2159
|
+
middle code: {mid_code}
|
|
2160
|
+
{section_start}
|
|
2161
|
+
Do NOT include any thinking, reasoning, or <think> sections in your response.
|
|
2162
|
+
Go directly to the {format_type} response format without any preamble or explanation.
|
|
2163
|
+
|
|
2164
|
+
Respond using {format_type} format like this:
|
|
2165
|
+
{example}
|
|
2166
|
+
|
|
2167
|
+
IMPORTANT: Your response must ONLY contain the {container_start}{container_end} {format_type} block above. Do not include any text, thinking, or reasoning before or after this {format_type} block. Start your response immediately with {container_start} and end with {container_end}.
|
|
2168
|
+
{section_end}
|
|
2169
|
+
end code: {final_code}
|
|
2170
|
+
"""
|
|
2171
|
+
|
|
2172
|
+
self.logger.debug(f"dynamic_prompt_exec_from_state: using format {format_type}")
|
|
2173
|
+
|
|
2174
|
+
# ── Prompt trimming safety net ─────────────────────────────────
|
|
2175
|
+
# If the prompt exceeds a character-based budget, trim it to
|
|
2176
|
+
# prevent context-limit errors from the model provider.
|
|
2177
|
+
MAX_PROMPT_CHARS = 256_000 # ~128K tokens at ~2 chars/token
|
|
2178
|
+
if len(full_prompt) > MAX_PROMPT_CHARS:
|
|
2179
|
+
est_tokens = len(full_prompt) // 2
|
|
2180
|
+
self.logger.warning(
|
|
2181
|
+
f"dynamic_prompt_exec_from_state: prompt too large "
|
|
2182
|
+
f"(~{est_tokens:,} est tokens), trimming to ~{MAX_PROMPT_CHARS // 2:,}"
|
|
2183
|
+
)
|
|
2184
|
+
# Keep the end of the prompt (most recent content + output instructions)
|
|
2185
|
+
full_prompt = full_prompt[-MAX_PROMPT_CHARS:]
|
|
2186
|
+
|
|
2187
|
+
# ── Cap maxTokens to fit within model context ──────────────────
|
|
2188
|
+
MODEL_CONTEXT_LIMIT = 200_000
|
|
2189
|
+
est_input = len(full_prompt) // 2 # pessimistic: ~2 chars/token
|
|
2190
|
+
max_tokens = 4096
|
|
2191
|
+
max_available_output = MODEL_CONTEXT_LIMIT - est_input - 1_000
|
|
2192
|
+
if max_tokens > max_available_output > 0:
|
|
2193
|
+
max_tokens = max(1_000, max_available_output)
|
|
2194
|
+
self.logger.warning(
|
|
2195
|
+
f"dynamic_prompt_exec_from_state: capping maxTokens to {max_tokens}"
|
|
2196
|
+
)
|
|
2197
|
+
|
|
2198
|
+
# Call model
|
|
2199
|
+
params = {
|
|
2200
|
+
"prompt": full_prompt,
|
|
2201
|
+
"maxTokens": max_tokens,
|
|
2202
|
+
}
|
|
2203
|
+
|
|
2204
|
+
# Check for cancellation before request
|
|
2205
|
+
if options.abort_signal and options.abort_signal():
|
|
2206
|
+
if extractor:
|
|
2207
|
+
extractor.signal_error("Cancelled by user")
|
|
2208
|
+
return None
|
|
2209
|
+
|
|
2210
|
+
# Create ValidationStreamExtractor on first iteration if streaming enabled (XML only)
|
|
2211
|
+
# JSON streaming bypasses the extractor since it parses XML tags
|
|
2212
|
+
if current_retry == 0 and options.on_stream_chunk and extractor is None and is_xml:
|
|
2213
|
+
has_rich_consumer = options.on_stream_event is not None
|
|
2214
|
+
|
|
2215
|
+
# Determine which fields to stream
|
|
2216
|
+
stream_fields = [
|
|
2217
|
+
row.field
|
|
2218
|
+
for row in schema
|
|
2219
|
+
if (row.stream_field if row.stream_field is not None else row.field == "text")
|
|
2220
|
+
]
|
|
2221
|
+
|
|
2222
|
+
# Default to "text" if no explicit stream fields
|
|
2223
|
+
if not stream_fields and any(row.field == "text" for row in schema):
|
|
2224
|
+
stream_fields = ["text"]
|
|
2225
|
+
|
|
2226
|
+
stream_message_id = f"stream-{uuid.uuid4().hex[:12]}"
|
|
2227
|
+
|
|
2228
|
+
# Capture stream_message_id in default parameter to avoid late binding
|
|
2229
|
+
extractor = ValidationStreamExtractor(
|
|
2230
|
+
ValidationStreamExtractorConfig(
|
|
2231
|
+
level=validation_level,
|
|
2232
|
+
schema=schema,
|
|
2233
|
+
stream_fields=stream_fields,
|
|
2234
|
+
expected_codes=per_field_codes,
|
|
2235
|
+
on_chunk=lambda chunk,
|
|
2236
|
+
_field,
|
|
2237
|
+
msg_id=stream_message_id: options.on_stream_chunk(chunk, msg_id)
|
|
2238
|
+
if options.on_stream_chunk
|
|
2239
|
+
else None,
|
|
2240
|
+
on_event=lambda event, msg_id=stream_message_id: options.on_stream_event(
|
|
2241
|
+
event, msg_id
|
|
2242
|
+
)
|
|
2243
|
+
if options.on_stream_event
|
|
2244
|
+
else None,
|
|
2245
|
+
abort_signal=options.abort_signal,
|
|
2246
|
+
has_rich_consumer=has_rich_consumer,
|
|
2247
|
+
)
|
|
2248
|
+
)
|
|
2249
|
+
|
|
2250
|
+
try:
|
|
2251
|
+
# Use streaming if extractor is active, otherwise use non-streaming
|
|
2252
|
+
if extractor:
|
|
2253
|
+
# Streaming mode: use use_model_stream and feed chunks to extractor
|
|
2254
|
+
response_parts: list[str] = []
|
|
2255
|
+
stream_model_type = (
|
|
2256
|
+
f"{model_type_str}_STREAM"
|
|
2257
|
+
if not model_type_str.endswith("_STREAM")
|
|
2258
|
+
else model_type_str
|
|
2259
|
+
)
|
|
2260
|
+
|
|
2261
|
+
async for chunk in self.use_model_stream(stream_model_type, params):
|
|
2262
|
+
if options.abort_signal and options.abort_signal():
|
|
2263
|
+
extractor.signal_error("Cancelled by user")
|
|
2264
|
+
return None
|
|
2265
|
+
response_parts.append(chunk)
|
|
2266
|
+
extractor.push(chunk)
|
|
2267
|
+
|
|
2268
|
+
# Flush extractor and get final state
|
|
2269
|
+
extractor.flush()
|
|
2270
|
+
response_str = "".join(response_parts)
|
|
2271
|
+
else:
|
|
2272
|
+
# Non-streaming mode: use use_model
|
|
2273
|
+
response = await self.use_model(model_type_str, params)
|
|
2274
|
+
response_str = str(response) if response else ""
|
|
2275
|
+
except Exception as e:
|
|
2276
|
+
self.logger.error(f"Model call failed: {e}")
|
|
2277
|
+
current_retry += 1
|
|
2278
|
+
if current_retry <= max_retries and options.retry_backoff:
|
|
2279
|
+
delay = options.retry_backoff.delay_for_retry(current_retry)
|
|
2280
|
+
self.logger.debug(
|
|
2281
|
+
f"Retry backoff: waiting {delay}ms before retry {current_retry}"
|
|
2282
|
+
)
|
|
2283
|
+
await asyncio.sleep(delay / 1000.0)
|
|
2284
|
+
if extractor:
|
|
2285
|
+
extractor.reset()
|
|
2286
|
+
continue
|
|
2287
|
+
|
|
2288
|
+
# Clean response (remove <think> blocks)
|
|
2289
|
+
clean_response = re.sub(r"<think>[\s\S]*?</think>", "", response_str)
|
|
2290
|
+
|
|
2291
|
+
# Parse response
|
|
2292
|
+
response_content: dict[str, Any] | None = None
|
|
2293
|
+
if is_xml:
|
|
2294
|
+
response_content = self._parse_xml_to_dict(clean_response)
|
|
2295
|
+
else:
|
|
2296
|
+
import contextlib
|
|
2297
|
+
import json
|
|
2298
|
+
|
|
2299
|
+
with contextlib.suppress(json.JSONDecodeError):
|
|
2300
|
+
# JSON parse may fail - response_content remains None if so
|
|
2301
|
+
# This triggers retry logic below via the all_good = False path
|
|
2302
|
+
response_content = json.loads(clean_response)
|
|
2303
|
+
|
|
2304
|
+
all_good = True
|
|
2305
|
+
|
|
2306
|
+
if response_content:
|
|
2307
|
+
# Validate codes based on context level
|
|
2308
|
+
if validation_level <= 1:
|
|
2309
|
+
# Per-field validation
|
|
2310
|
+
for field, expected_code in per_field_codes.items():
|
|
2311
|
+
start_code = response_content.get(f"code_{field}_start")
|
|
2312
|
+
end_code = response_content.get(f"code_{field}_end")
|
|
2313
|
+
if start_code != expected_code or end_code != expected_code:
|
|
2314
|
+
self.logger.warning(
|
|
2315
|
+
f"Per-field validation failed for {field}: expected={expected_code}, start={start_code}, end={end_code}"
|
|
2316
|
+
)
|
|
2317
|
+
all_good = False
|
|
2318
|
+
else:
|
|
2319
|
+
# Checkpoint validation
|
|
2320
|
+
validation_codes = [
|
|
2321
|
+
(first, "one_initial_code", init_code),
|
|
2322
|
+
(first, "one_middle_code", mid_code),
|
|
2323
|
+
(first, "one_end_code", final_code),
|
|
2324
|
+
(last, "two_initial_code", init_code),
|
|
2325
|
+
(last, "two_middle_code", mid_code),
|
|
2326
|
+
(last, "two_end_code", final_code),
|
|
2327
|
+
]
|
|
2328
|
+
for enabled, field, expected in validation_codes:
|
|
2329
|
+
if enabled:
|
|
2330
|
+
actual = response_content.get(field)
|
|
2331
|
+
if actual != expected:
|
|
2332
|
+
self.logger.warning(
|
|
2333
|
+
f"Checkpoint {field} mismatch: expected {expected}"
|
|
2334
|
+
)
|
|
2335
|
+
all_good = False
|
|
2336
|
+
|
|
2337
|
+
# Validate required fields
|
|
2338
|
+
if options.required_fields:
|
|
2339
|
+
for field in options.required_fields:
|
|
2340
|
+
value = response_content.get(field)
|
|
2341
|
+
is_missing = (
|
|
2342
|
+
+                        value is None
+                        or (isinstance(value, str) and not value.strip())
+                        or (isinstance(value, (list, dict)) and not value)
+                    )
+                    if is_missing:
+                        self.logger.warning(f"Missing required field: {field}")
+                        all_good = False
+
+                # Clean up validation code fields from result
+                for field in list(per_field_codes.keys()):
+                    response_content.pop(f"code_{field}_start", None)
+                    response_content.pop(f"code_{field}_end", None)
+                if first:
+                    response_content.pop("one_initial_code", None)
+                    response_content.pop("one_middle_code", None)
+                    response_content.pop("one_end_code", None)
+                if last:
+                    response_content.pop("two_initial_code", None)
+                    response_content.pop("two_middle_code", None)
+                    response_content.pop("two_end_code", None)
+            else:
+                self.logger.warning(
+                    f"dynamic_prompt_exec_from_state parse problem: {clean_response[:500]}"
+                )
+                all_good = False
+
+            if all_good and response_content:
+                self.logger.debug(f"dynamic_prompt_exec_from_state success [{model_schema_key}]")
+                # Clean up smart retry context from state
+                if hasattr(state, "values") and "_smartRetryContext" in getattr(
+                    state.values, "__dict__", state.values if isinstance(state.values, dict) else {}
+                ):
+                    with contextlib.suppress(KeyError, TypeError):
+                        del state.values["_smartRetryContext"]
+                return response_content
+
+            current_retry += 1
+
+            # Signal retry to extractor if present
+            if extractor:
+                extractor.signal_retry(current_retry)
+                extractor.reset()
+
+            # Build smart retry context for level 1 (per-field validation)
+            if validation_level == 1 and response_content:
+                # Find validated fields (those with correct codes)
+                validated_fields: list[str] = []
+                for field, expected_code in per_field_codes.items():
+                    start_code = response_content.get(f"code_{field}_start")
+                    end_code = response_content.get(f"code_{field}_end")
+                    if start_code == expected_code and end_code == expected_code:
+                        validated_fields.append(field)
+
+                if validated_fields:
+                    # Build retry context with validated fields
+                    validated_parts: list[str] = []
+                    for field in validated_fields:
+                        content = response_content.get(field, "")
+                        if content:
+                            truncated = (
+                                content[:500] + "..." if len(str(content)) > 500 else str(content)
+                            )
+                            validated_parts.append(f"<{field}>{truncated}</{field}>")
+
+                    if validated_parts:
+                        # Find missing/invalid fields
+                        all_fields = {row.field for row in schema}
+                        missing = [f for f in all_fields if f not in validated_fields]
+                        smart_retry_context = (
+                            f"\n\n[RETRY CONTEXT]\n"
+                            f"You previously produced these valid fields:\n"
+                            f"{chr(10).join(validated_parts)}\n\n"
+                            f"Please complete: {', '.join(missing) if missing else 'all fields'}"
+                        )
+                        # Store in state for next iteration (may fail on protobuf)
+                        if hasattr(state, "values"):
+                            import contextlib
+
+                            with contextlib.suppress(TypeError):
+                                # Protobuf messages don't support item assignment
+                                state.values["_smartRetryContext"] = smart_retry_context
+
+            self.logger.warn(
+                f"dynamic_prompt_exec_from_state retry {current_retry}/{max_retries} "
+                f"validated={','.join(validated_fields) or 'none'}"
+            )
+
+            if current_retry <= max_retries and options.retry_backoff:
+                delay = options.retry_backoff.delay_for_retry(current_retry)
+                self.logger.debug(f"Retry backoff: waiting {delay}ms before retry {current_retry}")
+                await asyncio.sleep(delay / 1000.0)
+
+        self.logger.error(
+            f"dynamic_prompt_exec_from_state failed after {max_retries} retries [{model_schema_key}]"
+        )
+
+        # Signal error to extractor if present
+        if extractor:
+            diagnosis = extractor.diagnose()
+            missing = diagnosis.missing_fields
+            invalid = diagnosis.invalid_fields
+            incomplete = diagnosis.incomplete_fields
+            extractor.signal_error(
+                f"Failed after {max_retries} retries. "
+                f"Missing: {missing}, Invalid: {invalid}, Incomplete: {incomplete}"
+            )
+
+        # Clean up smart retry context from state
+        if hasattr(state, "values") and "_smartRetryContext" in getattr(
+            state.values, "__dict__", state.values if isinstance(state.values, dict) else {}
+        ):
+            with contextlib.suppress(KeyError, TypeError):
+                del state.values["_smartRetryContext"]
+        return None
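As an aside for readers of this diff: the smart-retry context assembled above is just a plain string appended to the next prompt. A minimal standalone sketch of the resulting format, using made-up field names rather than anything from the package:

```python
# Sketch only: reproduces the [RETRY CONTEXT] string format built in the diff
# above, with hypothetical data (these field names are not from the package).
validated_parts = ["<thought>Summarize the thread</thought>"]
missing = ["actions", "text"]

smart_retry_context = (
    "\n\n[RETRY CONTEXT]\n"
    "You previously produced these valid fields:\n"
    f"{chr(10).join(validated_parts)}\n\n"
    f"Please complete: {', '.join(missing) if missing else 'all fields'}"
)
print(smart_retry_context)
# [RETRY CONTEXT]
# You previously produced these valid fields:
# <thought>Summarize the thread</thought>
#
# Please complete: actions, text
```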
+
+    def _parse_xml_to_dict(self, xml_text: str) -> dict[str, Any] | None:
+        """Parse XML-like response to dict using ElementTree for nested XML support."""
+
+        def element_to_dict(element: ET.Element) -> dict[str, Any] | str:
+            """Recursively convert an XML element to a dict."""
+            children = list(element)
+            if not children:
+                # Leaf node - return trimmed text
+                return (element.text or "").strip()
+
+            # Has children - build nested dict
+            result: dict[str, Any] = {}
+            for child in children:
+                child_value = element_to_dict(child)
+                if child.tag in result:
+                    # Handle duplicate tags by converting to list
+                    existing = result[child.tag]
+                    if isinstance(existing, list):
+                        existing.append(child_value)
+                    else:
+                        result[child.tag] = [existing, child_value]
+                else:
+                    result[child.tag] = child_value
+            return result
+
+        try:
+            # Try to find and parse the response element
+            # First try to find <response>...</response>
+            response_match = re.search(r"<response>([\s\S]*?)</response>", xml_text)
+            if response_match:
+                xml_content = f"<response>{response_match.group(1)}</response>"
+            else:
+                # Try to wrap content if it looks like XML tags
+                xml_content = f"<root>{xml_text}</root>"
+
+            root = ET.fromstring(xml_content)
+            result = element_to_dict(root)
+
+            if isinstance(result, dict) and result:
+                return result
+            return None
+        except ET.ParseError:
+            # Fall back to regex for malformed XML with recursive nested tag parsing
+            def parse_nested(text: str) -> dict[str, Any]:
+                """Recursively parse nested XML tags."""
+                result: dict[str, Any] = {}
+                pattern = r"<([\w-]+)>([\s\S]*?)</\1>"
+                matches = re.findall(pattern, text)
+                for tag_name, content in matches:
+                    content_stripped = content.strip()
+                    # Check if content has nested tags
+                    if re.search(r"<[\w-]+>", content_stripped):
+                        # Recursively parse nested content
+                        nested = parse_nested(content_stripped)
+                        if nested:
+                            result[tag_name] = nested
+                        else:
+                            result[tag_name] = content_stripped
+                    else:
+                        result[tag_name] = content_stripped
+                return result
+
+            # First try to unwrap <response> wrapper and parse inner content
+            # Use lazy *? to avoid matching too much if multiple response tags exist
+            response_match = re.search(r"<response>([\s\S]*?)</response>", xml_text)
+            if response_match:
+                inner_result = parse_nested(response_match.group(1))
+                if inner_result:
+                    return inner_result
+
+            # Otherwise parse the whole text
+            result = parse_nested(xml_text)
+            return result if result else None
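For orientation, the ElementTree walk above flattens a well-formed model reply into a plain dict and turns repeated tags into lists. A condensed, self-contained sketch of that behavior, written for this diff as an approximation rather than the package's own class:

```python
# Standalone illustration of the element_to_dict conversion shown above.
import xml.etree.ElementTree as ET
from typing import Any

def element_to_dict(element: ET.Element) -> dict[str, Any] | str:
    children = list(element)
    if not children:
        # Leaf node: trimmed text
        return (element.text or "").strip()
    result: dict[str, Any] = {}
    for child in children:
        value = element_to_dict(child)
        if child.tag in result:
            # Duplicate tags collapse into a list
            existing = result[child.tag]
            result[child.tag] = existing + [value] if isinstance(existing, list) else [existing, value]
        else:
            result[child.tag] = value
    return result

xml = "<response><thought>Reply briefly</thought><actions><action>REPLY</action><action>SEND_MESSAGE</action></actions></response>"
print(element_to_dict(ET.fromstring(xml)))
# {'thought': 'Reply briefly', 'actions': {'action': ['REPLY', 'SEND_MESSAGE']}}
```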
+
+    async def search_knowledge(self, query: str, limit: int = 5) -> list[object]:
+        """Search for knowledge matching the given query."""
+        if not self._adapter:
+            return []
+        return await self._adapter.search_memories({"query": query, "limit": limit})
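Since search_knowledge simply forwards to the configured memory adapter, a caller supplies only a query string and an optional limit, and receives an empty list when no adapter is set. A hypothetical call site (the `runtime` variable and wrapper function are assumptions for illustration, not part of the package):

```python
# Hypothetical usage; assumes `runtime` exposes the search_knowledge method shown above.
async def top_knowledge_snippets(runtime, topic: str) -> list[object]:
    # Returns [] when no memory adapter is configured, per the guard above.
    return await runtime.search_knowledge(topic, limit=3)
```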
+
+
+@dataclass
+class DynamicPromptOptions:
+    """Options for dynamic prompt execution."""
+
+    model_size: str | None = None
+    """Model size to use ('small' or 'large')"""
+
+    model: str | None = None
+    """Specific model identifier override"""
+
+    force_format: str | None = None
+    """Force output format ('json' or 'xml')"""
+
+    required_fields: list[str] | None = None
+    """Required fields that must be present and non-empty"""
+
+    context_check_level: int | None = None
+    """Validation level (0=trusted, 1=progressive, 2=checkpoint, 3=full)"""
+
+    max_retries: int | None = None
+    """Maximum retry attempts"""
+
+    retry_backoff: RetryBackoffConfig | None = None
+    """Retry backoff configuration"""
+
+    on_stream_chunk: Callable[[str, str | None], Any] | None = None
+    """Callback for streaming chunks (chunk, message_id) -> None.
+    If provided, enables streaming with validation-aware extraction."""
+
+    on_stream_event: Callable[[StreamEvent, str | None], Any] | None = None
+    """Callback for rich streaming events (event, message_id) -> None.
+    Provides detailed events for advanced UIs (field validation, retries, errors)."""
+
+    abort_signal: Callable[[], bool] | None = None
+    """Callable returning True if the operation should be aborted."""
|