hyperforge 1.0.0.post19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperforge/__init__.py +16 -0
- hyperforge/agent.py +81 -0
- hyperforge/api/__init__.py +20 -0
- hyperforge/api/app.py +155 -0
- hyperforge/api/authentication.py +271 -0
- hyperforge/api/commands.py +33 -0
- hyperforge/api/internal/__init__.py +4 -0
- hyperforge/api/internal/inspect.py +30 -0
- hyperforge/api/internal/router.py +3 -0
- hyperforge/api/logging.py +18 -0
- hyperforge/api/models.py +129 -0
- hyperforge/api/session.py +197 -0
- hyperforge/api/settings.py +38 -0
- hyperforge/api/utils.py +354 -0
- hyperforge/api/v1/__init__.py +23 -0
- hyperforge/api/v1/agents.py +531 -0
- hyperforge/api/v1/interaction.py +430 -0
- hyperforge/api/v1/mcp_content.py +311 -0
- hyperforge/api/v1/mcp_interaction.py +322 -0
- hyperforge/api/v1/oauth.py +60 -0
- hyperforge/api/v1/prompt.py +129 -0
- hyperforge/api/v1/router.py +3 -0
- hyperforge/api/v1/schema.py +56 -0
- hyperforge/api/v1/session.py +182 -0
- hyperforge/api/v1/utils.py +12 -0
- hyperforge/api/v1/workflows.py +643 -0
- hyperforge/arag.py +28 -0
- hyperforge/broker/__init__.py +52 -0
- hyperforge/broker/local.py +116 -0
- hyperforge/broker/redis.py +161 -0
- hyperforge/configure.py +571 -0
- hyperforge/context/__init__.py +0 -0
- hyperforge/context/agent.py +377 -0
- hyperforge/context/config.py +103 -0
- hyperforge/database.py +3 -0
- hyperforge/db/__init__.py +6 -0
- hyperforge/db/agents.py +1521 -0
- hyperforge/db/encryption.py +91 -0
- hyperforge/db/exceptions.py +26 -0
- hyperforge/db/settings.py +16 -0
- hyperforge/db/workflow_cleanup.py +69 -0
- hyperforge/definition.py +13 -0
- hyperforge/driver.py +31 -0
- hyperforge/dummy.py +28 -0
- hyperforge/engine.py +189 -0
- hyperforge/exceptions.py +14 -0
- hyperforge/feature_flag.py +105 -0
- hyperforge/fixtures.py +602 -0
- hyperforge/interaction.py +116 -0
- hyperforge/llm.py +75 -0
- hyperforge/manager.py +432 -0
- hyperforge/memory/__init__.py +5 -0
- hyperforge/memory/memory.py +974 -0
- hyperforge/minimal_fixtures.py +75 -0
- hyperforge/models.py +336 -0
- hyperforge/nua.py +336 -0
- hyperforge/openapi.py +63 -0
- hyperforge/prompts.py +188 -0
- hyperforge/pubsub.py +90 -0
- hyperforge/py.typed +0 -0
- hyperforge/redis_utils.py +82 -0
- hyperforge/retrieval/__init__.py +0 -0
- hyperforge/retrieval/agent.py +169 -0
- hyperforge/retrieval/config.py +94 -0
- hyperforge/server/__init__.py +5 -0
- hyperforge/server/cache.py +131 -0
- hyperforge/server/run.py +109 -0
- hyperforge/server/sandbox.py +60 -0
- hyperforge/server/session.py +421 -0
- hyperforge/server/settings.py +47 -0
- hyperforge/server/utils.py +57 -0
- hyperforge/server/web.py +31 -0
- hyperforge/settings.py +18 -0
- hyperforge/standalone/__init__.py +5 -0
- hyperforge/standalone/agent.py +189 -0
- hyperforge/standalone/app.py +264 -0
- hyperforge/standalone/config.py +137 -0
- hyperforge/standalone/const.py +1 -0
- hyperforge/standalone/run.py +60 -0
- hyperforge/standalone/settings.py +133 -0
- hyperforge/standalone/ui_router.py +241 -0
- hyperforge/trace.py +42 -0
- hyperforge/utils/__init__.py +112 -0
- hyperforge/utils/http.py +48 -0
- hyperforge/workflows.py +44 -0
- hyperforge-1.0.0.post19.dist-info/METADATA +95 -0
- hyperforge-1.0.0.post19.dist-info/RECORD +90 -0
- hyperforge-1.0.0.post19.dist-info/WHEEL +5 -0
- hyperforge-1.0.0.post19.dist-info/entry_points.txt +8 -0
- hyperforge-1.0.0.post19.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
import os
|
|
3
|
+
from functools import cache
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from cryptography.fernet import Fernet, InvalidToken
|
|
7
|
+
from cryptography.hazmat.primitives import hashes
|
|
8
|
+
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
9
|
+
|
|
10
|
+
from hyperforge.db.settings import EncryptionSettings
|
|
11
|
+
from hyperforge.driver import DriverConfig, EncryptedPayload
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@cache
def get_fernet() -> Fernet:
    """Return the Fernet instance built from ``EncryptionSettings``.

    The result is memoized with ``functools.cache``, so the settings are read
    and the Fernet object is constructed only on the first call.
    """
    return Fernet(EncryptionSettings().encryption_secret_key)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def encrypt_data(data: str) -> str:
    """Encrypt ``data`` with the shared Fernet instance and return the token as text."""
    token = get_fernet().encrypt(data.encode())
    return token.decode()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def decrypt_data(data: str) -> str:
    """Decrypt a Fernet token produced by ``encrypt_data`` and return the plaintext.

    No time-to-live is enforced (``ttl=None``).

    Raises:
        ValueError: if ``data`` is not a valid token for the configured key.
            The underlying ``InvalidToken`` is attached as ``__cause__``.
    """
    f = get_fernet()
    try:
        return f.decrypt(token=data, ttl=None).decode()
    except InvalidToken as exc:
        # Chain explicitly so the traceback shows this is a translation of the
        # library error, not a second failure raised while handling it.
        raise ValueError("Invalid encryption token.") from exc
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def fernet_key_from_passphrase(
    passphrase: str, salt: bytes | None
) -> tuple[bytes, bytes]:
    """Derive a Fernet key from ``passphrase``.

    When ``salt`` is ``None`` a fresh 16-byte random salt is generated. The key
    is derived with PBKDF2-HMAC-SHA256 (32 bytes, 1,200,000 iterations) and then
    URL-safe base64 encoded.

    From https://cryptography.io/en/latest/fernet/#using-passwords-with-fernet

    Returns:
        ``(key, salt)`` -- the salt actually used is returned so the caller can
        keep it for deriving the same key again.
    """
    effective_salt = os.urandom(16) if salt is None else salt
    derived = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=effective_salt,
        iterations=1_200_000,
    ).derive(passphrase.encode())
    return base64.urlsafe_b64encode(derived), effective_salt
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def dump_without_encrypted_fields(
    model: DriverConfig[EncryptedPayload],
) -> dict[str, Any]:
    """Dump ``model`` with its encrypted config fields removed.

    Every name in ``model.config.encrypted_fields`` that is present in the
    dumped ``"config"`` section is dropped. The complete dump is then returned
    wrapped under a top-level ``"config"`` key.
    """
    dumped = model.model_dump()  # type: ignore
    for secret_name in model.config.encrypted_fields:  # type: ignore
        # pop() with a default is a no-op when the key is absent.
        dumped["config"].pop(secret_name, None)
    return {"config": dumped}
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def encrypt_fields(model: EncryptedPayload) -> dict[str, Any]:
    """Dump ``model`` and encrypt the values of its ``encrypted_fields``.

    A string value is encrypted directly. For a dict value, each string entry
    is encrypted and other entries are left alone. Values of any other type are
    left as they are. The model itself is not modified.
    """
    dumped = model.model_dump()  # type: ignore
    for name in model.encrypted_fields:  # type: ignore
        if name not in dumped:
            continue
        current = dumped[name]
        if isinstance(current, str):
            dumped[name] = encrypt_data(data=current)
        elif isinstance(current, dict):
            # Only existing keys are reassigned, which is safe while iterating.
            for key, item in current.items():
                if isinstance(item, str):
                    current[key] = encrypt_data(data=item)
    return dumped
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def decrypt_fields(model: EncryptedPayload) -> None:
    """Decrypt, in place, the fields named in ``model.encrypted_fields``.

    A string value is replaced by its plaintext. For a dict value, every string
    entry is decrypted independently. Values that cannot be decrypted are kept
    unchanged so that data stored before encryption was introduced still loads.

    Raises:
        AttributeError: if a name in ``encrypted_fields`` is not an attribute
            of ``model``.
    """
    for field in model.encrypted_fields:
        if not hasattr(model, field):
            raise AttributeError(f"Field '{field}' not found in {type(model).__name__}")
        value = getattr(model, field)
        if isinstance(value, str):
            try:
                value = decrypt_data(data=value)
            except ValueError:
                # We ignore the error to support current unencrypted data
                continue
        elif isinstance(value, dict):
            for k, v in value.items():
                if not isinstance(v, str):
                    continue
                try:
                    value[k] = decrypt_data(data=v)
                except ValueError:
                    # Same tolerance, applied per entry: one unencrypted entry
                    # must not leave the remaining entries as ciphertext.
                    continue

        setattr(model, field, value)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
class RetrievalAgentError(Exception):
    """Root of the exception hierarchy used by Retrieval Agents."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class NotFoundError(RetrievalAgentError):
    """Raised when a requested resource cannot be found."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DriverNotFoundError(NotFoundError):
    """Raised when the specified driver cannot be found."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ProtectedWorkflowError(RetrievalAgentError):
    """Raised on an attempt to modify a protected workflow."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ParseExportError(RetrievalAgentError):
    """Raised when an export file cannot be parsed."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ExportEncryptionError(RetrievalAgentError):
    """Raised when encryption fails or is misused."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class InvalidTargetAgentError(RetrievalAgentError):
    """Raised when the target agent is not valid for the requested operation."""
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from pydantic_settings import BaseSettings
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DataManagerSettings(BaseSettings):
    """Settings for the data manager, loaded through ``pydantic_settings``."""

    # Required: there is no default, so a value must be supplied.
    postgresql_dsn: str
    # Presumably the chunk size used when reading an export -- confirm at the call site.
    export_read_chunk_size: int = 1024 * 1024  # 1 MB
    # Presumably the largest export that will be read -- confirm at the call site.
    export_read_max_size: int = 10 * 1024 * 1024  # 10 MB
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class EncryptionSettings(BaseSettings):
    """Settings holding the secret key used for encryption."""

    # Required: there is no default, so a value must be supplied.
    encryption_secret_key: str
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class IDPSettings(BaseSettings):
    """Settings for reaching the regional IDP gRPC service."""

    # NOTE(review): presumably switches to a dummy IDP implementation -- confirm where it is read.
    dummy_idp: bool = False
    # Default host:port of the regional IDP gRPC endpoint.
    idp_regional_grpc: str = "idp-grpc.idp-regional.svc.cluster.local:9090"
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import datetime
|
|
3
|
+
from importlib.metadata import version
|
|
4
|
+
|
|
5
|
+
from nucliadb_telemetry import errors
|
|
6
|
+
from nucliadb_telemetry.errors import setup_error_handling
|
|
7
|
+
from nucliadb_telemetry.logs import setup_logging
|
|
8
|
+
from nucliadb_telemetry.settings import LogLevel, LogSettings
|
|
9
|
+
|
|
10
|
+
from hyperforge.db import logger
|
|
11
|
+
from hyperforge.db.agents import WORKFLOW_PURGE_RETENTION, AgentManager
|
|
12
|
+
from hyperforge.db.settings import DataManagerSettings
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
async def cleanup_deleted_workflows(
    manager: AgentManager,
    older_than: datetime.timedelta = WORKFLOW_PURGE_RETENTION,
):
    """Purge every deleted workflow that ``manager`` reports as expired.

    Workflows are purged one at a time. A failure for one workflow is sent to
    ``errors.capture_exception`` and logged, and the loop continues with the
    next workflow.
    """

    def _log_extra(item) -> dict:
        # Identifying fields attached to both the "purging" and "failed" records.
        return {
            "account": item["account"],
            "agent_id": item["agent_id"],
            "workflow_id": item["workflow_id"],
        }

    logger.info("Cleaning up deleted workflows")
    expired = await manager.get_expired_deleted_workflows(older_than=older_than)
    logger.info("Found deleted workflows to clean up", extra={"count": len(expired)})

    for item in expired:
        try:
            logger.info("Purging deleted workflow", extra=_log_extra(item))
            await manager.purge_deleted_workflow(
                account=item["account"],
                agent_id=item["agent_id"],
                workflow_id=item["workflow_id"],
            )
        except Exception as exc:
            errors.capture_exception(exc)
            logger.error(
                "Failed to purge deleted workflow",
                exc_info=exc,
                extra=_log_extra(item),
            )
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
async def cronjob(manager: AgentManager):
    """Run the scheduled maintenance work: clean up deleted workflows."""
    await cleanup_deleted_workflows(manager)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def run():  # pragma: no cover
    """Synchronous entry point: run ``_main`` to completion with ``asyncio.run``."""
    asyncio.run(_main())
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
async def _main():
    """Set up logging and error reporting, then run the cron job once.

    The manager is built from ``DataManagerSettings`` and initialized, and
    ``finalize`` is always awaited afterwards, even when the job raises.
    """
    setup_logging(
        settings=LogSettings(logger_levels={"hyperforge.server": LogLevel.INFO})
    )
    setup_error_handling(version("hyperforge"))
    manager = await AgentManager.from_settings(settings=DataManagerSettings())
    await manager.initialize()
    try:
        await cronjob(manager)
    finally:
        await manager.finalize()
|
hyperforge/definition.py
ADDED
hyperforge/driver.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from typing import Any, ClassVar, Generic, Self, TypeVar
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class EncryptedPayload(BaseModel):
    """Base model for payloads that declare some of their fields as encrypted."""

    # Class-level list of field names; empty here. Presumably overridden by
    # subclasses that have fields to encrypt -- none are visible in this file.
    encrypted_fields: ClassVar[list[str]] = []
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
# Type variable for a driver's configuration model; bounded to EncryptedPayload.
T = TypeVar("T", bound="EncryptedPayload")
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DriverConfig(BaseModel, Generic[T]):
    """Generic description of a driver, parameterized by its config model ``T``."""

    # Optional id; defaults to None.
    id: str | None = None
    identifier: str
    name: str
    provider: Any = Field(
        ...,
        description="The type of driver, e.g., 'google', 'marklogic', etc.",
    )
    config: T = Field(
        ...,
        description="The configuration specific to the driver.",
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class Driver(BaseModel):
    """Base model for a driver, identified by ``name`` and ``provider``."""

    name: str
    provider: str

    # Allow fields whose types pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @classmethod
    async def init(cls, driver: Any) -> Self:
        """Build an instance from ``driver``; not implemented on this base class.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()
|
hyperforge/dummy.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from stashify_protos.protos import idp_pb2
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class DummyIDPRegionalGRPCUtility:
    """Dummy IDP regional gRPC utility that records payloads and answers OK.

    No email is sent: each payload is appended to ``requests`` and a
    response with status ``OK`` is returned.
    """

    # Every payload received, in call order.
    requests: list[Any]

    def __init__(self) -> None:
        self.requests = []

    async def SendRAOExportEmail(
        self, payload: idp_pb2.SendRAOExportEmailRequest
    ) -> idp_pb2.SendRAOExportEmailResponse:
        """Record ``payload`` and return an OK ``SendRAOExportEmailResponse``."""
        self.requests.append(payload)
        response_cls = idp_pb2.SendRAOExportEmailResponse
        return response_cls(
            status=response_cls.Status.OK,
            message="Dummy IDP Regional GRPC Utility: Email sent successfully",
        )

    async def SendEmail(
        self, payload: idp_pb2.SendEmailRequest
    ) -> idp_pb2.SendEmailResponse:
        """Record ``payload`` and return an OK ``SendEmailResponse``."""
        self.requests.append(payload)
        response_cls = idp_pb2.SendEmailResponse
        return response_cls(
            status=response_cls.Status.OK,
            message="Dummy IDP Regional GRPC Utility: Email sent successfully",
        )
|
hyperforge/engine.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, cast
|
|
4
|
+
|
|
5
|
+
from nuclia.lib.nua import AsyncNuaClient
|
|
6
|
+
|
|
7
|
+
from hyperforge.configure import GLOBAL_REGISTRY, load_all_configurations
|
|
8
|
+
from hyperforge.interaction import AragAnswer
|
|
9
|
+
from hyperforge.llm import NoopNuaClient, NuaBaseModel, NUAConnection
|
|
10
|
+
from hyperforge.manager import Manager
|
|
11
|
+
from hyperforge.memory.memory import BaseSessionMemory, QuestionMemory, SessionMemory
|
|
12
|
+
from hyperforge.retrieval.agent import RetrievalAgent
|
|
13
|
+
from hyperforge.retrieval.config import RetrievalAgentConfig
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class State:
    """Pair of objects needed to answer a question: the agent and its manager."""

    # Retrieval agent that is called with the question memory and the manager.
    agent: RetrievalAgent
    # Manager built from the configured drivers and the NUA client.
    manager: Manager
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
async def init(
    config: Optional[Dict[str, Any]] = None,
    agent_id: str = "default",
    internal_nua: bool = False,
    internal_nua_api: str = "http://predict.learning.svc.cluster.local:8080",
    local_openai: Optional[str] = None,
    local_openai_model: Optional[str] = None,
    external_nua_api_key: Optional[str] = None,
    loaded_modules: Optional[list[str]] = None,
    retrieval_config: Optional[RetrievalAgentConfig] = None,
    session_id: str = "default_session",
    memory_klass: type[BaseSessionMemory] = SessionMemory,
) -> Tuple[State, SessionMemory]:
    """Build the agent state and an initialized session memory.

    Each module in ``loaded_modules`` is scanned and its configurations loaded.
    A module that fails to import is logged and skipped. If
    ``retrieval_config`` is not given, ``config`` is validated into a
    ``RetrievalAgentConfig``.

    Args:
        loaded_modules: Module names to scan. ``None`` (the default) means no
            extra modules.

    Returns:
        ``(state, session_memory)``. The memory is created with
        ``memory_klass.from_config`` and initialized for ``session_id``.

    Raises:
        ValueError: if both ``config`` and ``retrieval_config`` are ``None``.
    """
    from hyperforge.configure import scan

    # ``None`` replaces the former shared ``[]`` default.
    for load_module in loaded_modules or ():
        try:
            scan(load_module)
            load_all_configurations(load_module)
        except ImportError:
            logger.error(f"Module {load_module} could not be loaded")

    if retrieval_config is None:
        if config is None:
            raise ValueError("Either config or retrieval_config must be provided")
        retrieval_config = RetrievalAgentConfig.model_validate(config)

    state = await get_state(
        agent_id=agent_id,
        config=retrieval_config,
        internal_nua=internal_nua,
        internal_nua_api=internal_nua_api,
        local_openai=local_openai,
        local_openai_model=local_openai_model,
        external_nua_api_key=external_nua_api_key,
    )
    session_memory = memory_klass.from_config(
        retrieval_config.memory,
        agent_id=agent_id,
        workflow_id="default",
        rules=retrieval_config.rules,
    )
    session_memory.init(session_id)
    return state, session_memory
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
async def main(
    config: Optional[Dict[str, Any]] = None,
    agent_id: str = "default",
    internal_nua: bool = False,
    internal_nua_api: str = "http://predict.learning.svc.cluster.local:8080",
    local_openai: Optional[str] = None,
    local_openai_model: Optional[str] = None,
    external_nua_api_key: Optional[str] = None,
    question: str = "",
    loaded_modules: Optional[list[str]] = None,
    retrieval_config: Optional[RetrievalAgentConfig] = None,
    callback: Optional[Callable[[AragAnswer], Awaitable[None]]] = None,
    session_id: str = "default_session",
    user_metadata: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
    memory_klass: type[BaseSessionMemory] = SessionMemory,
    streaming: bool = False,
) -> QuestionMemory:
    """Initialize an agent, run it on ``question`` and return the question memory.

    The state and session memory come from ``init``. ``callback``,
    ``user_metadata`` and ``headers`` are applied to the question memory when
    given, then the agent is awaited with the question memory and the manager.
    ``GLOBAL_REGISTRY`` is cleared afterwards whether or not the run succeeded.

    Args:
        loaded_modules: Module names forwarded to ``init``. ``None`` (the
            default) means no extra modules.

    Raises:
        ValueError: if the agent could not be initialized, or (from ``init``)
            if neither ``config`` nor ``retrieval_config`` is provided.
    """
    try:
        state, session_memory = await init(
            config=config,
            agent_id=agent_id,
            internal_nua=internal_nua,
            internal_nua_api=internal_nua_api,
            local_openai=local_openai,
            local_openai_model=local_openai_model,
            external_nua_api_key=external_nua_api_key,
            # Always hand ``init`` a list, never ``None``.
            loaded_modules=[] if loaded_modules is None else loaded_modules,
            retrieval_config=retrieval_config,
            session_id=session_id,
            memory_klass=memory_klass,
        )
        question_memory = session_memory.start_question(question, streaming=streaming)
        if callback is not None:
            question_memory.set_callback_fn(callback)

        if user_metadata:
            question_memory.session.user_info.update(user_metadata)

        if headers:
            question_memory.headers.update(headers)

        if state.agent is None:
            raise ValueError("Agent could not be initialized")

        await state.agent(
            question_memory,
            state.manager,
        )
    finally:
        # Exceptions propagate unchanged; the registry is reset either way.
        GLOBAL_REGISTRY.clear()
    return question_memory
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
async def engine(
    manager: Manager,
    agent: RetrievalAgent,
    question_memory: QuestionMemory,
    user_metadata: Optional[Dict[str, str]] = None,
) -> None:
    """Run ``agent`` on an existing question memory.

    When ``user_metadata`` is given (even if empty) it is merged into the
    session's ``user_info`` before the agent is awaited.
    """
    if user_metadata is not None:
        session_user_info = question_memory.session.user_info
        session_user_info.update(user_metadata)

    await agent(question_memory, manager)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
async def get_state(
    agent_id: str,
    config: RetrievalAgentConfig,
    internal_nua_api: str = "http://predict.learning.svc.cluster.local:8080",
    internal_nua: bool = False,
    local_openai: Optional[str] = None,
    local_openai_model: Optional[str] = None,
    external_nua_api_key: Optional[str] = None,
    account: Optional[str] = None,
    kbid: Optional[str] = None,
    local_openai_model_klass: Optional[type[NuaBaseModel]] = None,
) -> State:
    """Pick a NUA client and build the manager and agent around it.

    The client is chosen in this order:

    1. ``internal_nua`` is true: an internal connection to ``internal_nua_api``.
    2. Both ``local_openai`` and ``local_openai_model_klass`` are set: that
       class is validated from the key and local OpenAI settings and connected.
    3. ``external_nua_api_key`` is set: a ``NUAConnection`` using that key.
    4. Otherwise a warning is logged and a ``NoopNuaClient`` is used.

    ``agent_id`` is accepted but not read by this function.
    """
    nua: AsyncNuaClient
    if internal_nua:
        internal_client = await NUAConnection.connect_internal(
            kbid=kbid, account=account, url=internal_nua_api
        )
        nua = cast(AsyncNuaClient, internal_client)
    elif local_openai is not None and local_openai_model_klass is not None:
        local_settings = {
            "key": external_nua_api_key,
            "local_openai": local_openai,
            "local_openai_model": local_openai_model,
        }
        nua = await local_openai_model_klass.model_validate(local_settings).connect()
    elif external_nua_api_key is not None:
        nua = await NUAConnection.model_validate(
            {"key": external_nua_api_key}
        ).connect()
    else:
        logger.warning(
            "No LLM backend configured — use a no-op client. Agents that don't"
            " require LLM calls (e.g. the built-in ``static`` context agent) will"
            " work fine; any agent that actually calls NUA will raise a clear error."
        )
        nua = NoopNuaClient()

    manager = await Manager.from_config(drivers=config.drivers, nua=nua)
    agent = await RetrievalAgent.from_config_class(config)
    return State(manager=manager, agent=agent)
|
hyperforge/exceptions.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import json
|
|
3
|
+
import threading
|
|
4
|
+
from functools import cached_property
|
|
5
|
+
from typing import Any, Optional
|
|
6
|
+
|
|
7
|
+
import mrflagly # type: ignore
|
|
8
|
+
import pydantic
|
|
9
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
10
|
+
|
|
11
|
+
# Thread-local holder for the FlagService instance (stored as ``.service``).
_flag_service = threading.local()
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FlagContext(pydantic.BaseModel):
    """Information passed to the feature flag service to decide whether a flag is enabled.

    Having it as a class means a single object can be passed down to any
    function that checks a feature flag, instead of several parameters or a
    dictionary.
    """

    account_id: Optional[str] = None
    kbid: str

    @pydantic.computed_field  # type: ignore[prop-decorator]
    @cached_property
    def account_id_md5(self) -> Optional[str]:
        """Hex MD5 digest of ``account_id``, or ``None`` when no account is set."""
        if self.account_id is None:
            return None
        # The digest is only an identifier, not a security measure. Saying so
        # keeps this working on builds that restrict MD5 for security use.
        return hashlib.md5(
            self.account_id.encode(), usedforsecurity=False
        ).hexdigest()
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class Features:
    """Feature flag keys; the ``{...}`` placeholders are filled with ``str.format``."""

    # RAO
    FILTERED_AGENTS_FEATURE_FLAG = "rao_filtered_agents_{environment}"
    FILTERED_DRIVERS_FEATURE_FLAG = "rao_filtered_drivers_{environment}"
    AUDIT_RAO_ASK_ENDPOINT = "learning_audit_rao_ask"
    RAO_ACCOUNT_ENABLED_FEATURE_FLAG = "rao_account_enabled_{account_md5}"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class FlagService:
    """Wrapper around ``mrflagly.FlagService`` that adds the running environment.

    When ``Settings.flag_settings_url`` is ``None`` the flags come from
    ``DEFAULT_FLAG_DATA``; otherwise they are loaded from that URL.
    """

    def __init__(self):
        self.settings = Settings()
        if self.settings.flag_settings_url is None:
            self.flag_service = mrflagly.FlagService(data=json.dumps(DEFAULT_FLAG_DATA))
        else:
            self.flag_service = mrflagly.FlagService(
                url=self.settings.flag_settings_url
            )

    def enabled(
        self, flag_key: str, default: bool = False, context: Optional[dict] = None
    ) -> bool:
        """Return whether ``flag_key`` is enabled for ``context``.

        The ``"environment"`` entry is always set to the configured running
        environment, replacing any value supplied in ``context``. The caller's
        dict is left unmodified.
        """
        # Copy rather than write into the caller's dict.
        effective_context = {
            **(context or {}),
            "environment": self.settings.running_environment,
        }
        return self.flag_service.enabled(
            flag_key, default=default, context=effective_context
        )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class Settings(BaseSettings):
    """Feature flag settings, loaded through ``pydantic_settings``."""

    # The literal string "null" in an environment value is parsed as None.
    model_config = SettingsConfigDict(env_parse_none_str="null")

    # Accepts either "environment" or "running_environment" as the input name.
    running_environment: str = pydantic.Field(
        default="local",
        validation_alias=pydantic.AliasChoices("environment", "running_environment"),
    )
    # Where the flag definitions are fetched from; None selects the built-in defaults.
    flag_settings_url: str | None = (
        "https://cdn.rag.progress.cloud/features/features-v2.json"
    )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# Flag definitions used when no flag settings URL is configured.
# These are just defaults to use for local dev and tests
DEFAULT_FLAG_DATA: dict[str, Any] = {
    Features.FILTERED_AGENTS_FEATURE_FLAG.format(environment="local"): {
        "rollout": 0,
        "variants": {"agents": []},
    },
    Features.FILTERED_DRIVERS_FEATURE_FLAG.format(environment="local"): {
        "rollout": 0,
        "variants": {"drivers": []},
    },
    Features.AUDIT_RAO_ASK_ENDPOINT: {
        "rollout": 0,
        "variants": {"environment": ["local"]},
    },
}
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def get_flag_service() -> FlagService:
    """Return the calling thread's ``FlagService``, creating it on first use."""
    service = getattr(_flag_service, "service", None)
    if service is None:
        service = FlagService()
        _flag_service.service = service
    return service
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def has_feature(
    flag_key: str,
    default: bool = False,
    context: dict[str, str] | None | FlagContext = None,
) -> bool:
    """Return whether ``flag_key`` is enabled, using the current flag service.

    ``context`` may be a plain dict, a ``FlagContext`` (dumped to a dict first),
    or ``None`` (treated as an empty dict).
    """
    fs = get_flag_service()
    context_dict = (
        context.model_dump() if isinstance(context, FlagContext) else (context or {})
    )
    return fs.enabled(flag_key, default=default, context=context_dict)
|