hyperforge 1.0.0.post19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperforge/__init__.py +16 -0
- hyperforge/agent.py +81 -0
- hyperforge/api/__init__.py +20 -0
- hyperforge/api/app.py +155 -0
- hyperforge/api/authentication.py +271 -0
- hyperforge/api/commands.py +33 -0
- hyperforge/api/internal/__init__.py +4 -0
- hyperforge/api/internal/inspect.py +30 -0
- hyperforge/api/internal/router.py +3 -0
- hyperforge/api/logging.py +18 -0
- hyperforge/api/models.py +129 -0
- hyperforge/api/session.py +197 -0
- hyperforge/api/settings.py +38 -0
- hyperforge/api/utils.py +354 -0
- hyperforge/api/v1/__init__.py +23 -0
- hyperforge/api/v1/agents.py +531 -0
- hyperforge/api/v1/interaction.py +430 -0
- hyperforge/api/v1/mcp_content.py +311 -0
- hyperforge/api/v1/mcp_interaction.py +322 -0
- hyperforge/api/v1/oauth.py +60 -0
- hyperforge/api/v1/prompt.py +129 -0
- hyperforge/api/v1/router.py +3 -0
- hyperforge/api/v1/schema.py +56 -0
- hyperforge/api/v1/session.py +182 -0
- hyperforge/api/v1/utils.py +12 -0
- hyperforge/api/v1/workflows.py +643 -0
- hyperforge/arag.py +28 -0
- hyperforge/broker/__init__.py +52 -0
- hyperforge/broker/local.py +116 -0
- hyperforge/broker/redis.py +161 -0
- hyperforge/configure.py +571 -0
- hyperforge/context/__init__.py +0 -0
- hyperforge/context/agent.py +377 -0
- hyperforge/context/config.py +103 -0
- hyperforge/database.py +3 -0
- hyperforge/db/__init__.py +6 -0
- hyperforge/db/agents.py +1521 -0
- hyperforge/db/encryption.py +91 -0
- hyperforge/db/exceptions.py +26 -0
- hyperforge/db/settings.py +16 -0
- hyperforge/db/workflow_cleanup.py +69 -0
- hyperforge/definition.py +13 -0
- hyperforge/driver.py +31 -0
- hyperforge/dummy.py +28 -0
- hyperforge/engine.py +189 -0
- hyperforge/exceptions.py +14 -0
- hyperforge/feature_flag.py +105 -0
- hyperforge/fixtures.py +602 -0
- hyperforge/interaction.py +116 -0
- hyperforge/llm.py +75 -0
- hyperforge/manager.py +432 -0
- hyperforge/memory/__init__.py +5 -0
- hyperforge/memory/memory.py +974 -0
- hyperforge/minimal_fixtures.py +75 -0
- hyperforge/models.py +336 -0
- hyperforge/nua.py +336 -0
- hyperforge/openapi.py +63 -0
- hyperforge/prompts.py +188 -0
- hyperforge/pubsub.py +90 -0
- hyperforge/py.typed +0 -0
- hyperforge/redis_utils.py +82 -0
- hyperforge/retrieval/__init__.py +0 -0
- hyperforge/retrieval/agent.py +169 -0
- hyperforge/retrieval/config.py +94 -0
- hyperforge/server/__init__.py +5 -0
- hyperforge/server/cache.py +131 -0
- hyperforge/server/run.py +109 -0
- hyperforge/server/sandbox.py +60 -0
- hyperforge/server/session.py +421 -0
- hyperforge/server/settings.py +47 -0
- hyperforge/server/utils.py +57 -0
- hyperforge/server/web.py +31 -0
- hyperforge/settings.py +18 -0
- hyperforge/standalone/__init__.py +5 -0
- hyperforge/standalone/agent.py +189 -0
- hyperforge/standalone/app.py +264 -0
- hyperforge/standalone/config.py +137 -0
- hyperforge/standalone/const.py +1 -0
- hyperforge/standalone/run.py +60 -0
- hyperforge/standalone/settings.py +133 -0
- hyperforge/standalone/ui_router.py +241 -0
- hyperforge/trace.py +42 -0
- hyperforge/utils/__init__.py +112 -0
- hyperforge/utils/http.py +48 -0
- hyperforge/workflows.py +44 -0
- hyperforge-1.0.0.post19.dist-info/METADATA +95 -0
- hyperforge-1.0.0.post19.dist-info/RECORD +90 -0
- hyperforge-1.0.0.post19.dist-info/WHEEL +5 -0
- hyperforge-1.0.0.post19.dist-info/entry_points.txt +8 -0
- hyperforge-1.0.0.post19.dist-info/top_level.txt +1 -0
hyperforge/configure.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import logging
|
|
3
|
+
import sys
|
|
4
|
+
import types
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import (
|
|
7
|
+
TYPE_CHECKING,
|
|
8
|
+
Any,
|
|
9
|
+
Dict,
|
|
10
|
+
Generic,
|
|
11
|
+
List,
|
|
12
|
+
Optional,
|
|
13
|
+
Tuple,
|
|
14
|
+
Type,
|
|
15
|
+
TypeVar,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
from hyperforge.feature_flag import Features, has_feature
|
|
19
|
+
|
|
20
|
+
# Module-level logger for this configuration module.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Imported for static type checking only; in this module these names are
    # referenced solely through string annotations (e.g. Type["AgentConfig"]).
    from hyperforge.agent import Agent, AgentConfig
    from hyperforge.driver import Driver, DriverConfig
# Constrained type variable for registration targets: a module, a plain
# function, or a class. NOTE(review): resolve_dotted_name() also accepts
# dotted-name strings, which this TypeVar does not describe.
ResolvableType = TypeVar("ResolvableType", types.ModuleType, types.FunctionType, type)
# Type variable for the registry dataclass a decorator subclass produces.
T = TypeVar("T", bound="BaseRegistry")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@dataclass
class BaseRegistry:
    """Common base for registry entries; each entry is indexed by its ``id``."""

    id: str


@dataclass
class AgentRegistry(BaseRegistry):
    """Registry entry describing one agent implementation."""

    # "preprocess", "postprocess", "generation" or "context" select a typed
    # index in load_agent(); any other value is only stored in Registry.agents.
    agent_type: str
    # Human-readable metadata, exposed by create_agent_schema().
    title: str
    description: str
    # Config model class; must provide model_validate() and
    # model_json_schema() (presumably a pydantic model — confirm).
    config_schema: Type["AgentConfig"]
    # Resolved implementation class; filled in by load_agent().
    klass: Optional[Type["Agent"]] = None


@dataclass
class DriverRegistry(BaseRegistry):
    """Registry entry describing one driver implementation."""

    # Human-readable metadata, exposed by create_driver_schema().
    title: str
    description: str
    # Config model class; must provide model_validate() and model_json_schema().
    config_schema: Type["DriverConfig"]
    # Resolved implementation class; filled in by load_driver().
    klass: Optional[Type["Driver"]] = None


@dataclass
class Registration:
    """A queued registration: the decorated target plus its registry entry."""

    # The decorated object, or a dotted name later resolved by resolve_dotted_name().
    klass: Any
    # The AgentRegistry / DriverRegistry (or other BaseRegistry) built by the decorator.
    config: BaseRegistry
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
# List of (configuration type, Registration) pairs, e.g. ("agent", Registration(...)).
ConfigurationType = List[Tuple[str, Registration]]

# Registrations queued by the decorators. They are stored as (type, registration)
# tuples in one list so that the order in which they were registered is kept,
# even when different types of registrations are mixed.
_registered_configurations: ConfigurationType = []

# Maps a configuration type ("agent", "driver", ...) to the handler that loads
# a registration of that type into a Registry (see register_configuration_handler).
_registered_configuration_handlers: Dict[str, Any] = {}
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
class Registry:
    """In-memory index of loaded agent and driver registrations.

    ``agents`` and ``drivers`` map every loaded id to its registry entry; the
    four ``*_agents`` maps additionally index agents by their ``agent_type``.
    """

    agents: Dict[str, AgentRegistry] = field(default_factory=dict)
    drivers: Dict[str, DriverRegistry] = field(default_factory=dict)
    preprocess_agents: Dict[str, AgentRegistry] = field(default_factory=dict)
    postprocess_agents: Dict[str, AgentRegistry] = field(default_factory=dict)
    generation_agents: Dict[str, AgentRegistry] = field(default_factory=dict)
    context_agents: Dict[str, AgentRegistry] = field(default_factory=dict)

    def clear(self):
        """Empty every index in place; the dict objects themselves are kept."""
        indexes = (
            self.agents,
            self.drivers,
            self.preprocess_agents,
            self.postprocess_agents,
            self.generation_agents,
            self.context_agents,
        )
        for index in indexes:
            index.clear()


# Default registry used whenever a function is not given an explicit context.
GLOBAL_REGISTRY = Registry()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def resolve_dotted_name(name: Any) -> Any:
    """Resolve a dotted path such as ``"package.module.attr"`` to an object.

    Non-string values are assumed to be resolved already and are returned as
    they are. Each segment after the first is looked up as an attribute of
    the object found so far; if that lookup fails, the dotted prefix is
    imported as a submodule and the lookup is retried.

    :param name: dotted name, or an already-resolved object
    :return: the resolved object
    """
    if not isinstance(name, str):
        return name

    head, *tail = name.split(".")
    current = __import__(head)
    dotted = head
    for part in tail:
        dotted = f"{dotted}.{part}"
        try:
            current = getattr(current, part)
        except AttributeError:
            # Submodules are not attributes until imported; import, then retry.
            __import__(dotted)
            current = getattr(current, part)
    return current
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def register_configuration_handler(type_, handler):
    """Install ``handler`` as the loader for registrations of kind ``type_``.

    A later call with the same ``type_`` replaces the earlier handler.
    """
    _registered_configuration_handlers.update({type_: handler})
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def register_configuration(klass: ResolvableType, config: BaseRegistry, type_: str):
    """Queue ``klass`` with its registry ``config`` under the kind ``type_``.

    Entries are appended in call order. An equal ``(type_, Registration)``
    pair that is already queued is not added a second time.
    """
    entry = (type_, Registration(klass=klass, config=config))
    if entry in _registered_configurations:
        return  # already queued; do not register twice
    _registered_configurations.append(entry)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def get_caller_module(
    level: int = 2, sys: types.ModuleType = sys
) -> Optional[types.ModuleType]:  # pylint: disable=W0621
    """Return the module whose code owns the frame ``level`` steps up the stack.

    Adapted from Pyramid. ``level=0`` is this function's own frame, so the
    default ``2`` is the caller of whoever called this function. ``sys`` is
    injectable and intentionally shadows the module import.

    Raises KeyError if the frame's module name is missing from ``sys.modules``.
    """
    frame_globals = sys._getframe(level).f_globals  # type: ignore
    name = frame_globals.get("__name__") or "__main__"
    return sys.modules[name]  # type: ignore
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def resolve_module_path(path: str) -> str:
    """Turn a possibly relative dotted module path into an absolute one.

    Paths that do not start with ``.`` are returned unchanged. For a relative
    path, the base is the dotted name of the module returned by
    ``get_caller_module()`` with its default ``level=2``. That is the module
    of the function that called ``resolve_module_path``. Each ``..`` pair in
    ``path`` strips one trailing component from that base. A single leading
    ``.`` strips none, so ``".sub"`` becomes ``"<base>.sub"``.

    NOTE(review): when reached through ``scan``, the calling function is
    ``scan`` itself, so the base is this module rather than scan's caller.
    Confirm that is intended.

    :param path: absolute or relative dotted module path
    :return: absolute dotted module path
    :raises ImportError: if ``path`` climbs above the top-level package
    """
    if not path.startswith("."):
        return path

    # Keep this call directly in this function: get_caller_module() counts
    # stack frames, so wrapping it in a helper would change the result.
    caller_mod = get_caller_module()
    caller_path = get_module_dotted_name(caller_mod)
    parts = caller_path.split(".")  # type: ignore
    levels = path.count("..")
    # Guard the slice: parts[:-0] is parts[:0] == [], which would discard the
    # whole base for single-dot paths and produce ".sub".
    if levels:
        if levels >= len(parts):
            raise ImportError(
                f"Relative path {path!r} climbs above the top-level package "
                f"of {caller_path!r}"
            )
        parts = parts[:-levels]
    base = ".".join(parts)
    remainder = path.split("..")[-1].strip(".")
    # An empty remainder (e.g. path == "..") names the base itself instead of
    # producing a trailing dot.
    return f"{base}.{remainder}" if remainder else base
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def get_module_dotted_name(ob) -> Optional[str]:
    """Return the dotted module name associated with ``ob``.

    Objects defined in a module (classes, functions) report their
    ``__module__``. Module objects fall back to their own ``__name__``.
    Returns ``None`` when neither attribute exists.
    """
    defining_module = getattr(ob, "__module__", None)
    if defining_module:
        return defining_module
    return getattr(ob, "__name__", None)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class _base_decorator(Generic[T]):
    """Class decorator that queues the decorated object for registration.

    The keyword arguments given to the decorator are stored. When the
    decorator is applied, they build a ``config_klass`` instance, which is
    queued via ``register_configuration`` together with the decorated object
    and ``configuration_type``. The decorated object is returned unchanged.
    """

    configuration_type: str = "base"
    config_klass: Type[T] = BaseRegistry  # type: ignore

    def __init__(self, **config):
        self.config = config

    def __call__(self, klass):
        registry_entry = self.config_klass(**self.config)
        register_configuration(
            klass=klass, config=registry_entry, type_=self.configuration_type
        )
        return klass
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class agent(_base_decorator[AgentRegistry]):
    """Decorator that queues a class for registration as an agent.

    Its keyword arguments become ``AgentRegistry`` fields (id, agent_type,
    title, description, config_schema).
    """

    configuration_type = "agent"
    config_klass = AgentRegistry


class driver(_base_decorator[DriverRegistry]):
    """Decorator that queues a class for registration as a driver.

    Its keyword arguments become ``DriverRegistry`` fields (id, title,
    description, config_schema).
    """

    configuration_type = "driver"
    config_klass = DriverRegistry
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def scan(path: str):
    """Import the module named by ``path`` so that its decorators run.

    ``resolve_module_path`` converts a relative ``path`` to an absolute one.
    It inspects the call stack, so it must stay a direct call from here.

    NOTE(review): with get_caller_module's default level the resolution base
    is the function calling resolve_module_path, i.e. this ``scan`` and thus
    this module, not scan's caller. Confirm that is intended.

    :param path: dotted module name (absolute or relative)
    """
    target_module = resolve_module_path(path)
    __import__(target_module)
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def clear():
    """Drop every registration queued by the decorators.

    The list is emptied in place, so other references to it stay valid.
    Registered handlers and any ``Registry`` contents are left untouched.
    """
    _registered_configurations.clear()
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def get_configurations(module_name, type_=None, excluded=None):
    """Return queued registrations whose target lives under ``module_name``.

    A registration's target (resolved with ``resolve_dotted_name``, which may
    import it) matches when its dotted module name equals ``module_name`` or
    is a submodule of it. Targets equal to or under any name in ``excluded``
    are dropped. If ``type_`` is given, only registrations of that kind are
    considered. Registration order is preserved.

    :return: list of ``(type, Registration)`` pairs
    """
    matches = []
    for reg_type, registration in _registered_configurations:
        if type_ is not None and reg_type != type_:
            continue
        normalized_name = get_module_dotted_name(
            resolve_dotted_name(registration.klass)
        )
        if normalized_name is None:
            continue
        # Compare with trailing dots so "pkg" does not match "pkg_other".
        candidate = normalized_name + "."
        if not candidate.startswith(module_name + "."):
            continue
        if any(candidate.startswith(skip + ".") for skip in excluded or []):
            continue
        matches.append((reg_type, registration))
    return matches
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def load_all_configurations(
    module_name, _context: Registry = GLOBAL_REGISTRY, excluded=None
):
    """Apply every queued registration under ``module_name`` to ``_context``.

    Each matching registration (see ``get_configurations``) is passed to the
    handler registered for its type, e.g. ``load_agent`` for ``"agent"``.

    :param module_name: dotted module prefix whose registrations are loaded
    :param _context: registry to populate (defaults to the global registry)
    :param excluded: dotted module prefixes to skip
    :return: the list of ``(type, Registration)`` pairs that were applied
    :raises KeyError: if no handler is registered for a registration's type
    """
    configurations = get_configurations(module_name, excluded=excluded)
    registration: Registration
    for type_, registration in configurations:
        handler = _registered_configuration_handlers.get(type_)
        if handler is None:
            # Same exception type as a direct dict lookup, but logged with context.
            logger.error("No configuration handler registered for type %r", type_)
            raise KeyError(type_)
        try:
            handler(registration.config, registration.klass, _context)
        except (ImportError, AttributeError, TypeError):
            # Resolving a dotted target fails with ImportError/AttributeError,
            # so log those too (not just TypeError) before propagating.
            logger.error("Can not find %s module", registration.klass)
            raise
    return configurations
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def load_agent(
    registration: AgentRegistry,
    klass: ResolvableType,
    _context: Registry = GLOBAL_REGISTRY,
) -> Any:
    """Handler for ``"agent"`` registrations: index one agent in ``_context``.

    The target (object or dotted name) is resolved and stored on
    ``registration.klass``. The entry is then added to ``_context.agents``.
    If its ``agent_type`` is ``"preprocess"``, ``"postprocess"``,
    ``"generation"`` or ``"context"``, it is also added to the matching typed
    index. An id that is already indexed is left as it is; only the passed-in
    registration's ``klass`` is updated.

    :raises Exception: if the registration has no id
    """
    agent_id = registration.id
    registration.klass = resolve_dotted_name(klass)
    if agent_id is None:
        raise Exception("Agent configuration must have an 'id' field")
    if agent_id in _context.agents:
        return None  # already registered; the first registration wins

    _context.agents[agent_id] = registration
    typed_indexes = (
        ("preprocess", _context.preprocess_agents),
        ("postprocess", _context.postprocess_agents),
        ("generation", _context.generation_agents),
        ("context", _context.context_agents),
    )
    for type_name, index in typed_indexes:
        if registration.agent_type == type_name:
            index[agent_id] = registration
            break
    return None


register_configuration_handler("agent", load_agent)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def load_driver(
    registration: DriverRegistry,
    klass: ResolvableType,
    _context: Registry = GLOBAL_REGISTRY,
) -> Any:
    """Handler for ``"driver"`` registrations: index one driver in ``_context``.

    The target is resolved and stored on ``registration.klass``. The entry is
    added to ``_context.drivers`` unless that id is already present (the
    first registration wins).

    :raises Exception: if the registration has no id
    """
    driver_id = registration.id
    registration.klass = resolve_dotted_name(klass)
    if driver_id is None:
        raise Exception("Driver configuration must have an 'id' field")
    if driver_id not in _context.drivers:
        _context.drivers[driver_id] = registration


register_configuration_handler("driver", load_driver)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def get_agent_config_instance(
    agent_config: Dict[str, Any], agent_type: str, _context: Registry = GLOBAL_REGISTRY
) -> "AgentConfig":
    """Validate a raw agent config dict against its registered schema.

    ``agent_config["module"]`` is the registered agent id. That agent must be
    registered with ``agent_type`` (checked by ``get_agent_config_klass``).

    :raises Exception: if ``"module"`` is missing or the agent type differs
    """
    module = agent_config.get("module")
    if module is None:
        raise Exception("Agent configuration must have a 'module' field")
    schema_cls = get_agent_config_klass(
        module=module, agent_type=agent_type, _context=_context
    )
    return schema_cls.model_validate(agent_config)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
async def create_agent_instance(
    agent_config: Dict[str, Any], agent_type: str, _config: Registry = GLOBAL_REGISTRY
) -> "Agent":
    """Build an agent instance from a raw config dict.

    ``agent_config["module"]`` is the registered agent id. It is looked up in
    the typed index of ``_config`` selected by ``agent_type`` (``"generation"``,
    ``"preprocess"``, ``"postprocess"`` or ``"context"``). The dict is then
    validated with that agent's ``config_schema`` and passed to the agent
    class's ``from_config`` coroutine.

    :param agent_config: raw agent configuration; must contain ``"module"``
    :param agent_type: which typed agent index to look in
    :param _config: registry to resolve against (defaults to the global one)
    :return: the awaited result of ``klass.from_config(config_object)``
    :raises ValueError: if ``agent_type`` is not one of the four known types.
        A bare ``raise`` here would only produce an unrelated
        "No active exception to reraise" RuntimeError.
    :raises Exception: if ``"module"`` is missing, the module is not
        registered for ``agent_type``, or it has no resolved class
    """
    module = agent_config.get("module")
    if module is None:
        raise Exception("Agent configuration must have a 'module' field")

    # agent_type -> (label used in error messages, typed index to search)
    typed_indexes = {
        "generation": ("Generation", _config.generation_agents),
        "preprocess": ("Preprocess", _config.preprocess_agents),
        "postprocess": ("Postprocess", _config.postprocess_agents),
        "context": ("Context", _config.context_agents),
    }
    if agent_type not in typed_indexes:
        raise ValueError(
            f"Unknown agent type '{agent_type}'; expected one of "
            f"{', '.join(typed_indexes)}"
        )
    label, index = typed_indexes[agent_type]

    registered = index.get(module)
    if registered is None:
        raise Exception(
            f"{label} agent module '{module}' is not registered in the "
            f"{agent_type} agents registry"
        )
    config_object = registered.config_schema.model_validate(agent_config)
    handler = registered.klass
    if handler is None:
        raise Exception(f"Klass not found for agent module '{module}'")
    return await handler.from_config(config_object)
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def get_agent_config_klass(
    module: str, agent_type: Optional[str] = None, _context: Registry = GLOBAL_REGISTRY
) -> Type["AgentConfig"]:
    """Return the config schema class registered for agent id ``module``.

    :param module: registered agent id
    :param agent_type: if given, the agent must be registered with this type
    :raises KeyError: if ``module`` is not a registered agent id
    :raises Exception: if ``agent_type`` does not match the registration
    """
    entry = _context.agents.get(module)
    if entry is None:
        # Same exception type as a plain ``agents[module]`` lookup, but with a
        # message naming the missing module (mirrors get_driver_config_klass).
        raise KeyError(f"Agent module '{module}' is not registered")
    if agent_type is not None and agent_type != entry.agent_type:
        raise Exception(
            f"Agent module '{module}' is registered as type '{entry.agent_type}', not '{agent_type}'"
        )
    return entry.config_schema
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
def get_agent_klass(module: str, _context: Registry = GLOBAL_REGISTRY) -> Type["Agent"]:
    """Return the implementation class registered for agent id ``module``.

    :raises KeyError: if ``module`` is not a registered agent id
    :raises Exception: if the registration has no resolved class
    """
    klass = _context.agents[module].klass
    if klass is not None:
        return klass
    raise Exception(f"Klass not found for agent module '{module}'")
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def validate_driver(item: Any, _context: Registry = GLOBAL_REGISTRY):
    """Validate a raw driver config dict against its provider's schema.

    :param item: dict whose ``"provider"`` key names a registered driver id
    :return: the validated driver config object
    :raises ValueError: if ``item`` is not a dict, has no ``"provider"``, or
        names a provider that is not registered
    """
    if not isinstance(item, dict):
        raise ValueError("Driver configuration must be a dictionary")

    provider = item.get("provider")
    if provider is None:
        raise ValueError("Driver configuration must have a 'provider' field")

    drivers = _context.drivers
    if provider not in drivers:
        raise ValueError(
            f"Driver module '{provider}' is not registered in the drivers registry"
        )
    schema_cls = drivers[provider].config_schema
    return schema_cls.model_validate(item)
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def get_driver_config_klass(
    provider: str, _context: Registry = GLOBAL_REGISTRY
) -> Type["DriverConfig"]:
    """Return the config schema class registered for driver id ``provider``.

    :raises Exception: if ``provider`` is not registered
    """
    registered = _context.drivers.get(provider)
    if registered is not None:
        return registered.config_schema
    raise Exception(f"Driver module '{provider}' is not registered")
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def get_driver_klass(
    provider: str, _context: Registry = GLOBAL_REGISTRY
) -> Type["Driver"]:
    """Return the implementation class registered for driver id ``provider``.

    :raises KeyError: if ``provider`` is not registered
    :raises Exception: if the registration has no resolved class
    """
    klass = _context.drivers[provider].klass
    if klass is None:
        raise Exception(f"Klass not found for driver module '{provider}'")
    return klass
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def get_driver_config_instance(
    driver_config: Dict[str, Any], _context: Registry = GLOBAL_REGISTRY
) -> "DriverConfig":
    """Validate a raw driver config dict with its provider's registered schema.

    :raises Exception: if ``"provider"`` is missing or not registered
    """
    provider = driver_config.get("provider")
    if provider is None:
        raise Exception("Driver configuration must have a 'provider' field")
    schema_cls = get_driver_config_klass(provider=provider, _context=_context)
    return schema_cls.model_validate(driver_config)
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def _validate_agent_item(item: Any, label: str, index_attr: str, _context: Registry):
    """Shared implementation of the ``validate_agent_*`` functions.

    :param item: raw agent config; must be a dict with a ``"module"`` key
    :param label: capitalised agent kind used in messages, e.g. ``"Generation"``
    :param index_attr: name of the ``Registry`` attribute holding that kind
    :return: ``item`` validated by the registered agent's ``config_schema``
    :raises ValueError: if ``item`` is not a dict, lacks ``"module"``, or
        names a module that is not registered in that index
    """
    if not isinstance(item, dict):
        raise ValueError(f"{label} agent configuration must be a dictionary")

    module = item.get("module")
    if module is None:
        raise ValueError(f"{label} agent configuration must have a 'module' field")

    index = getattr(_context, index_attr)
    if module not in index:
        raise ValueError(
            f"{label} agent module '{module}' is not registered in the "
            f"{label.lower()} agents registry"
        )
    return index[module].config_schema.model_validate(item)


def validate_agent_generation(item: Any, _context: Registry = GLOBAL_REGISTRY):
    """Validate a generation agent config dict against its registered schema."""
    return _validate_agent_item(item, "Generation", "generation_agents", _context)


def validate_agent_preprocess(item: Any, _context: Registry = GLOBAL_REGISTRY):
    """Validate a preprocess agent config dict against its registered schema."""
    return _validate_agent_item(item, "Preprocess", "preprocess_agents", _context)


def validate_agent_postprocess(item: Any, _context: Registry = GLOBAL_REGISTRY):
    """Validate a postprocess agent config dict against its registered schema."""
    return _validate_agent_item(item, "Postprocess", "postprocess_agents", _context)


def validate_agent_context(item: Any, _context: Registry = GLOBAL_REGISTRY):
    """Validate a context agent config dict against its registered schema."""
    return _validate_agent_item(item, "Context", "context_agents", _context)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def enabled_driver(
    driver: DriverRegistry, running_environment: str, account_md5: str
) -> bool:
    """Decide whether ``driver`` should be exposed.

    Returns ``True`` when ``has_feature`` reports the environment-level
    driver filter flag as off for this driver. Otherwise the account-level
    enable flag, formatted with ``account_md5``, decides (default ``False``).
    """
    env_filter_key = Features.FILTERED_DRIVERS_FEATURE_FLAG.format(
        environment=running_environment
    )
    if not has_feature(flag_key=env_filter_key, context={"drivers": driver.id}):
        return True

    account_key = Features.RAO_ACCOUNT_ENABLED_FEATURE_FLAG.format(
        account_md5=account_md5
    )
    return has_feature(
        flag_key=account_key,
        context={"drivers": driver.id},
        default=False,
    )
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def create_driver_schema(driver: DriverRegistry) -> Dict[str, Any]:
    """Describe ``driver`` as a dict: id, title, description and config JSON schema."""
    schema = dict(
        id=driver.id,
        title=driver.title,
        description=driver.description,
        config_schema=driver.config_schema.model_json_schema(),
    )
    return schema
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def get_driver_agent_schemas(
    running_environment: str = "production",
    account_id: str = "account_id",
    _context: Registry = GLOBAL_REGISTRY,
) -> Dict[str, Dict[str, Any]]:
    """Return schemas of the drivers enabled for an environment/account, keyed by id.

    ``enabled_driver`` formats its account-level flag key with an md5 digest
    (parameter ``account_md5``). The ``get_*_agent_schemas`` functions hash
    ``account_id`` before calling ``enabled_agent``, and this function does the
    same. Passing the raw id would never match the md5-keyed flag.
    """
    account_md5 = hashlib.md5(account_id.encode()).hexdigest()
    return {
        driver_id: create_driver_schema(driver)
        for driver_id, driver in _context.drivers.items()
        if enabled_driver(driver, running_environment, account_md5)
    }
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
def enabled_agent(
    agent: AgentRegistry, running_environment: str, account_md5: str
) -> bool:
    """Decide whether a single agent should be exposed in the schemas.

    Returns ``True`` when ``has_feature`` reports the environment-level agent
    filter flag as off for this agent. Otherwise the account-level enable
    flag, formatted with ``account_md5``, decides (default ``False``).

    This gives a per-environment kill switch for agents that specific
    accounts can still override.
    """
    env_filter_key = Features.FILTERED_AGENTS_FEATURE_FLAG.format(
        environment=running_environment
    )
    if not has_feature(flag_key=env_filter_key, context={"agents": agent.id}):
        return True

    account_key = Features.RAO_ACCOUNT_ENABLED_FEATURE_FLAG.format(
        account_md5=account_md5
    )
    return has_feature(
        flag_key=account_key,
        context={"agents": agent.id},
        default=False,
    )
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def create_agent_schema(agent: AgentRegistry) -> Dict[str, Any]:
    """Describe ``agent`` as a dict: id, type, title, description and config JSON schema."""
    schema = dict(
        id=agent.id,
        agent_type=agent.agent_type,
        title=agent.title,
        description=agent.description,
        config_schema=agent.config_schema.model_json_schema(),
    )
    return schema
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
def _enabled_agent_schemas(
    index: Dict[str, AgentRegistry], running_environment: str, account_id: str
) -> Dict[str, Dict[str, Any]]:
    """Schemas of the agents in ``index`` that pass ``enabled_agent``, keyed by id.

    ``account_id`` is md5-hashed because the account-level flag is keyed by
    the digest.
    """
    account_md5 = hashlib.md5(account_id.encode()).hexdigest()
    return {
        agent_id: create_agent_schema(entry)
        for agent_id, entry in index.items()
        if enabled_agent(entry, running_environment, account_md5)
    }


def get_context_agent_schemas(
    running_environment: str = "production",
    account_id: str = "account_id",
    _context: Registry = GLOBAL_REGISTRY,
) -> Dict[str, Dict[str, Any]]:
    """Schemas of the context agents enabled for this environment/account."""
    return _enabled_agent_schemas(
        _context.context_agents, running_environment, account_id
    )


def get_preprocess_agent_schemas(
    running_environment: str = "production",
    account_id: str = "account_id",
    _context: Registry = GLOBAL_REGISTRY,
) -> Dict[str, Dict[str, Any]]:
    """Schemas of the preprocess agents enabled for this environment/account."""
    return _enabled_agent_schemas(
        _context.preprocess_agents, running_environment, account_id
    )


def get_postprocess_agent_schemas(
    running_environment: str = "production",
    account_id: str = "account_id",
    _context: Registry = GLOBAL_REGISTRY,
) -> Dict[str, Dict[str, Any]]:
    """Schemas of the postprocess agents enabled for this environment/account."""
    return _enabled_agent_schemas(
        _context.postprocess_agents, running_environment, account_id
    )


def get_generation_agent_schemas(
    running_environment: str = "production",
    account_id: str = "account_id",
    _context: Registry = GLOBAL_REGISTRY,
) -> Dict[str, Dict[str, Any]]:
    """Schemas of the generation agents enabled for this environment/account."""
    return _enabled_agent_schemas(
        _context.generation_agents, running_environment, account_id
    )
|
|
File without changes
|