hyperforge 1.0.0.post19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. hyperforge/__init__.py +16 -0
  2. hyperforge/agent.py +81 -0
  3. hyperforge/api/__init__.py +20 -0
  4. hyperforge/api/app.py +155 -0
  5. hyperforge/api/authentication.py +271 -0
  6. hyperforge/api/commands.py +33 -0
  7. hyperforge/api/internal/__init__.py +4 -0
  8. hyperforge/api/internal/inspect.py +30 -0
  9. hyperforge/api/internal/router.py +3 -0
  10. hyperforge/api/logging.py +18 -0
  11. hyperforge/api/models.py +129 -0
  12. hyperforge/api/session.py +197 -0
  13. hyperforge/api/settings.py +38 -0
  14. hyperforge/api/utils.py +354 -0
  15. hyperforge/api/v1/__init__.py +23 -0
  16. hyperforge/api/v1/agents.py +531 -0
  17. hyperforge/api/v1/interaction.py +430 -0
  18. hyperforge/api/v1/mcp_content.py +311 -0
  19. hyperforge/api/v1/mcp_interaction.py +322 -0
  20. hyperforge/api/v1/oauth.py +60 -0
  21. hyperforge/api/v1/prompt.py +129 -0
  22. hyperforge/api/v1/router.py +3 -0
  23. hyperforge/api/v1/schema.py +56 -0
  24. hyperforge/api/v1/session.py +182 -0
  25. hyperforge/api/v1/utils.py +12 -0
  26. hyperforge/api/v1/workflows.py +643 -0
  27. hyperforge/arag.py +28 -0
  28. hyperforge/broker/__init__.py +52 -0
  29. hyperforge/broker/local.py +116 -0
  30. hyperforge/broker/redis.py +161 -0
  31. hyperforge/configure.py +571 -0
  32. hyperforge/context/__init__.py +0 -0
  33. hyperforge/context/agent.py +377 -0
  34. hyperforge/context/config.py +103 -0
  35. hyperforge/database.py +3 -0
  36. hyperforge/db/__init__.py +6 -0
  37. hyperforge/db/agents.py +1521 -0
  38. hyperforge/db/encryption.py +91 -0
  39. hyperforge/db/exceptions.py +26 -0
  40. hyperforge/db/settings.py +16 -0
  41. hyperforge/db/workflow_cleanup.py +69 -0
  42. hyperforge/definition.py +13 -0
  43. hyperforge/driver.py +31 -0
  44. hyperforge/dummy.py +28 -0
  45. hyperforge/engine.py +189 -0
  46. hyperforge/exceptions.py +14 -0
  47. hyperforge/feature_flag.py +105 -0
  48. hyperforge/fixtures.py +602 -0
  49. hyperforge/interaction.py +116 -0
  50. hyperforge/llm.py +75 -0
  51. hyperforge/manager.py +432 -0
  52. hyperforge/memory/__init__.py +5 -0
  53. hyperforge/memory/memory.py +974 -0
  54. hyperforge/minimal_fixtures.py +75 -0
  55. hyperforge/models.py +336 -0
  56. hyperforge/nua.py +336 -0
  57. hyperforge/openapi.py +63 -0
  58. hyperforge/prompts.py +188 -0
  59. hyperforge/pubsub.py +90 -0
  60. hyperforge/py.typed +0 -0
  61. hyperforge/redis_utils.py +82 -0
  62. hyperforge/retrieval/__init__.py +0 -0
  63. hyperforge/retrieval/agent.py +169 -0
  64. hyperforge/retrieval/config.py +94 -0
  65. hyperforge/server/__init__.py +5 -0
  66. hyperforge/server/cache.py +131 -0
  67. hyperforge/server/run.py +109 -0
  68. hyperforge/server/sandbox.py +60 -0
  69. hyperforge/server/session.py +421 -0
  70. hyperforge/server/settings.py +47 -0
  71. hyperforge/server/utils.py +57 -0
  72. hyperforge/server/web.py +31 -0
  73. hyperforge/settings.py +18 -0
  74. hyperforge/standalone/__init__.py +5 -0
  75. hyperforge/standalone/agent.py +189 -0
  76. hyperforge/standalone/app.py +264 -0
  77. hyperforge/standalone/config.py +137 -0
  78. hyperforge/standalone/const.py +1 -0
  79. hyperforge/standalone/run.py +60 -0
  80. hyperforge/standalone/settings.py +133 -0
  81. hyperforge/standalone/ui_router.py +241 -0
  82. hyperforge/trace.py +42 -0
  83. hyperforge/utils/__init__.py +112 -0
  84. hyperforge/utils/http.py +48 -0
  85. hyperforge/workflows.py +44 -0
  86. hyperforge-1.0.0.post19.dist-info/METADATA +95 -0
  87. hyperforge-1.0.0.post19.dist-info/RECORD +90 -0
  88. hyperforge-1.0.0.post19.dist-info/WHEEL +5 -0
  89. hyperforge-1.0.0.post19.dist-info/entry_points.txt +8 -0
  90. hyperforge-1.0.0.post19.dist-info/top_level.txt +1 -0
@@ -0,0 +1,91 @@
1
+ import base64
2
+ import os
3
+ from functools import cache
4
+ from typing import Any
5
+
6
+ from cryptography.fernet import Fernet, InvalidToken
7
+ from cryptography.hazmat.primitives import hashes
8
+ from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
9
+
10
+ from hyperforge.db.settings import EncryptionSettings
11
+ from hyperforge.driver import DriverConfig, EncryptedPayload
12
+
13
+
14
@cache
def get_fernet() -> Fernet:
    """Return the shared Fernet cipher.

    The key comes from ``EncryptionSettings().encryption_secret_key``. Because
    the function is wrapped in ``functools.cache``, the settings are read and
    the cipher is constructed on the first call only.
    """
    secret_key = EncryptionSettings().encryption_secret_key
    return Fernet(secret_key)
18
+
19
+
20
def encrypt_data(data: str) -> str:
    """Encrypt ``data`` with the shared Fernet cipher and return the token as text."""
    token = get_fernet().encrypt(data.encode())
    return token.decode()
23
+
24
+
25
def decrypt_data(data: str) -> str:
    """Decrypt a Fernet token and return the plaintext as text.

    Args:
        data: The token to decrypt.

    Returns:
        The decrypted, decoded string.

    Raises:
        ValueError: If Fernet rejects the token (``InvalidToken``). The
            original exception is attached as ``__cause__``.
    """
    f = get_fernet()
    try:
        # ttl=None: the token's age is not checked.
        return f.decrypt(token=data, ttl=None).decode()
    except InvalidToken as exc:
        # Chain explicitly so the traceback shows this as a translation of
        # the Fernet error rather than a failure inside the handler.
        raise ValueError("Invalid encryption token.") from exc
31
+
32
+
33
def fernet_key_from_passphrase(
    passphrase: str, salt: bytes | None = None
) -> tuple[bytes, bytes]:
    """Generate a Fernet key from a passphrase and salt (if provided).

    From https://cryptography.io/en/latest/fernet/#using-passwords-with-fernet

    Args:
        passphrase: Text the key is derived from.
        salt: Salt for the key derivation. When omitted or ``None``, 16 random
            bytes from ``os.urandom`` are used.

    Returns:
        ``(key, salt)``: the url-safe base64 encoded 32-byte derived key and
        the salt that was actually used, so the caller can keep it to derive
        the same key again.
    """
    if salt is None:
        salt = os.urandom(16)
    kdf = PBKDF2HMAC(
        algorithm=hashes.SHA256(),
        length=32,
        salt=salt,
        iterations=1_200_000,
    )
    key = base64.urlsafe_b64encode(kdf.derive(passphrase.encode()))
    return key, salt
49
+
50
+
51
def dump_without_encrypted_fields(
    model: DriverConfig[EncryptedPayload],
) -> dict[str, Any]:
    """Dump ``model`` with its encrypted config fields removed.

    Every name listed in ``model.config.encrypted_fields`` is dropped from the
    dumped ``"config"`` mapping when present. The complete dump is then
    returned nested under a top-level ``"config"`` key.
    """
    dumped = model.model_dump()  # type: ignore
    config_section = dumped["config"]
    for name in model.config.encrypted_fields:  # type: ignore
        config_section.pop(name, None)
    return {"config": dumped}
59
+
60
+
61
def encrypt_fields(model: EncryptedPayload) -> dict[str, Any]:
    """Dump ``model`` and encrypt the values of its ``encrypted_fields``.

    A string field is replaced by its encrypted token. For a dict field, each
    string value directly inside it is encrypted; other value types are left
    untouched. Fields missing from the dump are skipped.
    """
    payload = model.model_dump()  # type: ignore
    for name in model.encrypted_fields:  # type: ignore
        if name not in payload:
            continue
        current = payload[name]
        if isinstance(current, str):
            payload[name] = encrypt_data(data=current)
        elif isinstance(current, dict):
            # Only existing keys are reassigned, so the dict size never
            # changes while it is being iterated.
            for key, item in current.items():
                if isinstance(item, str):
                    current[key] = encrypt_data(data=item)
    return payload
72
+
73
+
74
def decrypt_fields(model: EncryptedPayload) -> None:
    """Decrypt, in place, every field listed in ``model.encrypted_fields``.

    A string field is replaced by its plaintext. For a dict field, each string
    value directly inside it is decrypted. Values that are not valid tokens
    are left as they are, to support data stored before encryption existed.

    Each dict value is handled independently: one undecryptable value no
    longer prevents the remaining values of the same dict from being
    decrypted.

    Raises:
        AttributeError: If a name in ``encrypted_fields`` is not an attribute
            of ``model``.
    """
    for field in model.encrypted_fields:
        if not hasattr(model, field):
            raise AttributeError(f"Field '{field}' not found in {type(model).__name__}")
        value = getattr(model, field)
        if value is None:
            continue

        if isinstance(value, str):
            try:
                value = decrypt_data(data=value)
            except ValueError:
                # We ignore the error to support current unencrypted data
                continue
        elif isinstance(value, dict):
            for k, v in value.items():
                if not isinstance(v, str):
                    continue
                try:
                    value[k] = decrypt_data(data=v)
                except ValueError:
                    # Unencrypted entry: keep it and carry on with the rest.
                    continue

        setattr(model, field, value)
@@ -0,0 +1,26 @@
1
class RetrievalAgentError(Exception):
    """Root of the Retrieval Agent exception hierarchy."""
3
+
4
+
5
class NotFoundError(RetrievalAgentError):
    """Raised when the resource that was asked for does not exist."""
7
+
8
+
9
class DriverNotFoundError(NotFoundError):
    """Raised when the specified driver does not exist."""
11
+
12
+
13
class ProtectedWorkflowError(RetrievalAgentError):
    """Raised on an attempt to modify a protected workflow."""
15
+
16
+
17
class ParseExportError(RetrievalAgentError):
    """Raised when an export file cannot be parsed."""
19
+
20
+
21
class ExportEncryptionError(RetrievalAgentError):
    """Raised when an encryption-related error occurs."""
23
+
24
+
25
class InvalidTargetAgentError(RetrievalAgentError):
    """Raised when an operation is given a target agent that is not valid for it."""
@@ -0,0 +1,16 @@
1
+ from pydantic_settings import BaseSettings
2
+
3
+
4
class DataManagerSettings(BaseSettings):
    # Settings for the data manager, loaded through pydantic-settings.

    # Required: PostgreSQL connection string (no default).
    postgresql_dsn: str
    # Export read sizes, in bytes.
    export_read_chunk_size: int = 1024**2  # 1 MB
    export_read_max_size: int = 10 * 1024**2  # 10 MB
8
+
9
+
10
class EncryptionSettings(BaseSettings):
    # Required secret key for encryption; there is no default value.
    encryption_secret_key: str
12
+
13
+
14
class IDPSettings(BaseSettings):
    # IDP settings, loaded through pydantic-settings.

    # Off by default. NOTE(review): presumably selects a dummy IDP utility
    # instead of the real gRPC one; confirm against the code reading it.
    dummy_idp: bool = False
    # Default address of the regional IDP gRPC service.
    idp_regional_grpc: str = "idp-grpc.idp-regional.svc.cluster.local:9090"
@@ -0,0 +1,69 @@
1
+ import asyncio
2
+ import datetime
3
+ from importlib.metadata import version
4
+
5
+ from nucliadb_telemetry import errors
6
+ from nucliadb_telemetry.errors import setup_error_handling
7
+ from nucliadb_telemetry.logs import setup_logging
8
+ from nucliadb_telemetry.settings import LogLevel, LogSettings
9
+
10
+ from hyperforge.db import logger
11
+ from hyperforge.db.agents import WORKFLOW_PURGE_RETENTION, AgentManager
12
+ from hyperforge.db.settings import DataManagerSettings
13
+
14
+
15
async def cleanup_deleted_workflows(
    manager: AgentManager,
    older_than: datetime.timedelta = WORKFLOW_PURGE_RETENTION,
):
    """Purge every deleted workflow that ``manager`` reports as expired.

    The workflows come from ``manager.get_expired_deleted_workflows`` and are
    purged one by one. A failure on one workflow is reported through
    ``errors.capture_exception`` and logged, and the loop moves on to the next
    one, so a single bad entry does not stop the run.
    """
    logger.info("Cleaning up deleted workflows")
    expired = await manager.get_expired_deleted_workflows(older_than=older_than)
    logger.info("Found deleted workflows to clean up", extra={"count": len(expired)})

    for entry in expired:
        try:
            account = entry["account"]
            agent_id = entry["agent_id"]
            workflow_id = entry["workflow_id"]
            logger.info(
                "Purging deleted workflow",
                extra={
                    "account": account,
                    "agent_id": agent_id,
                    "workflow_id": workflow_id,
                },
            )
            await manager.purge_deleted_workflow(
                account=account,
                agent_id=agent_id,
                workflow_id=workflow_id,
            )
        except Exception as exc:
            errors.capture_exception(exc)
            # Read the identifiers again here: the failure may have happened
            # before the locals above were bound.
            failure_details = {
                "account": entry["account"],
                "agent_id": entry["agent_id"],
                "workflow_id": entry["workflow_id"],
            }
            logger.error(
                "Failed to purge deleted workflow",
                exc_info=exc,
                extra=failure_details,
            )
49
+
50
+
51
async def cronjob(manager: AgentManager):
    """Run the scheduled job: clean up deleted workflows with the default retention."""
    await cleanup_deleted_workflows(manager=manager)
53
+
54
+
55
def run():  # pragma: no cover
    """Synchronous entry point: run ``_main`` on a new asyncio event loop."""
    entry_coroutine = _main()
    asyncio.run(entry_coroutine)
57
+
58
+
59
async def _main():
    """Set up logging and error reporting, then run the cron job once.

    The ``AgentManager`` is built from ``DataManagerSettings`` and is always
    finalized, even when the job raises.
    """
    # NOTE(review): the level override targets "hyperforge.server" while this
    # module logs through ``hyperforge.db.logger``; confirm the intended name.
    setup_logging(
        settings=LogSettings(logger_levels={"hyperforge.server": LogLevel.INFO})
    )
    setup_error_handling(version("hyperforge"))

    manager = await AgentManager.from_settings(settings=DataManagerSettings())
    await manager.initialize()
    try:
        await cronjob(manager)
    finally:
        await manager.finalize()
@@ -0,0 +1,13 @@
1
+ from typing import Any, Dict
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class NUAConfig(BaseModel):
    # Single required field holding the NUA key.
    key: str
8
+
9
+
10
class FunctionDefinition(BaseModel):
    # Describes a function by name, free-text description and parameters.
    name: str
    description: str
    # Free-form mapping; its schema is not constrained by this model.
    parameters: Dict[str, Any]
hyperforge/driver.py ADDED
@@ -0,0 +1,31 @@
1
+ from typing import Any, ClassVar, Generic, Self, TypeVar
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field
4
+
5
+
6
class EncryptedPayload(BaseModel):
    # Class-level list of field names that hold encrypted values. It is empty
    # by default; subclasses are presumably expected to override it.
    encrypted_fields: ClassVar[list[str]] = []
8
+
9
+
10
+ T = TypeVar("T", bound="EncryptedPayload")
11
+
12
+
13
class DriverConfig(BaseModel, Generic[T]):
    # Generic driver configuration; ``T`` is the EncryptedPayload subtype that
    # carries the driver-specific settings.

    # Optional id, ``None`` when not set.
    id: str | None = None
    identifier: str
    name: str
    # Required; typed as ``Any``, so no validation is applied to the value.
    provider: Any = Field(
        ...,
        description="The type of driver, e.g., 'google', 'marklogic', etc.",
    )
    # Required driver-specific payload.
    config: T = Field(
        ...,
        description="The configuration specific to the driver.",
    )
21
+
22
+
23
class Driver(BaseModel):
    # Base model for drivers. Subclasses are expected to override ``init``;
    # the implementation here only raises.

    name: str
    provider: str

    # Allow fields whose types pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @classmethod
    async def init(cls, driver: Any) -> Self:
        # Not implemented on the base class.
        raise NotImplementedError()
hyperforge/dummy.py ADDED
@@ -0,0 +1,28 @@
1
+ from typing import Any
2
+
3
+ from stashify_protos.protos import idp_pb2
4
+
5
+
6
class DummyIDPRegionalGRPCUtility:
    """In-memory stand-in for the regional IDP gRPC utility.

    Each call appends its request to ``requests`` and returns a canned ``OK``
    response. Nothing is actually sent.
    """

    # Every payload received, in call order.
    requests: list[Any]

    def __init__(self) -> None:
        self.requests = []

    async def SendRAOExportEmail(
        self, payload: idp_pb2.SendRAOExportEmailRequest
    ) -> idp_pb2.SendRAOExportEmailResponse:
        """Record ``payload`` and return a successful response."""
        self.requests.append(payload)
        response_cls = idp_pb2.SendRAOExportEmailResponse
        return response_cls(
            status=response_cls.Status.OK,
            message="Dummy IDP Regional GRPC Utility: Email sent successfully",
        )

    async def SendEmail(
        self, payload: idp_pb2.SendEmailRequest
    ) -> idp_pb2.SendEmailResponse:
        """Record ``payload`` and return a successful response."""
        self.requests.append(payload)
        response_cls = idp_pb2.SendEmailResponse
        return response_cls(
            status=response_cls.Status.OK,
            message="Dummy IDP Regional GRPC Utility: Email sent successfully",
        )
hyperforge/engine.py ADDED
@@ -0,0 +1,189 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, cast
4
+
5
+ from nuclia.lib.nua import AsyncNuaClient
6
+
7
+ from hyperforge.configure import GLOBAL_REGISTRY, load_all_configurations
8
+ from hyperforge.interaction import AragAnswer
9
+ from hyperforge.llm import NoopNuaClient, NuaBaseModel, NUAConnection
10
+ from hyperforge.manager import Manager
11
+ from hyperforge.memory.memory import BaseSessionMemory, QuestionMemory, SessionMemory
12
+ from hyperforge.retrieval.agent import RetrievalAgent
13
+ from hyperforge.retrieval.config import RetrievalAgentConfig
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
@dataclass
class State:
    """The agent and manager pair returned by ``get_state``."""

    agent: RetrievalAgent
    manager: Manager
22
+
23
+
24
async def init(
    config: Optional[Dict[str, Any]] = None,
    agent_id: str = "default",
    internal_nua: bool = False,
    internal_nua_api: str = "http://predict.learning.svc.cluster.local:8080",
    local_openai: Optional[str] = None,
    local_openai_model: Optional[str] = None,
    external_nua_api_key: Optional[str] = None,
    loaded_modules: Optional[list[str]] = None,
    retrieval_config: Optional[RetrievalAgentConfig] = None,
    session_id: str = "default_session",
    memory_klass: type[BaseSessionMemory] = SessionMemory,
) -> Tuple[State, SessionMemory]:
    """Build the agent state and an initialized session memory.

    Args:
        config: Raw configuration, validated into a ``RetrievalAgentConfig``
            when ``retrieval_config`` is not given.
        agent_id: Identifier passed to ``get_state`` and to the memory.
        internal_nua, internal_nua_api, local_openai, local_openai_model,
            external_nua_api_key: Forwarded unchanged to ``get_state``.
        loaded_modules: Module names to scan and load configurations from.
            ``None`` (the default) means no modules.
        retrieval_config: Already validated configuration. Takes precedence
            over ``config``.
        session_id: Session the memory is initialized with.
        memory_klass: Session memory class to instantiate via ``from_config``.

    Returns:
        ``(state, session_memory)``.

    Raises:
        ValueError: If neither ``config`` nor ``retrieval_config`` is given.
    """
    from hyperforge.configure import scan

    # ``None`` default instead of a shared mutable ``[]``.
    for load_module in loaded_modules or []:
        try:
            scan(load_module)
            load_all_configurations(load_module)
        except ImportError:
            # Best effort: a module that cannot be imported is reported and
            # skipped. exc_info keeps the traceback so the failing import can
            # be identified.
            logger.error(
                "Module %s could not be loaded", load_module, exc_info=True
            )

    if retrieval_config is None:
        if config is None:
            raise ValueError("Either config or retrieval_config must be provided")
        retrieval_config = RetrievalAgentConfig.model_validate(config)

    state = await get_state(
        agent_id=agent_id,
        config=retrieval_config,
        internal_nua=internal_nua,
        internal_nua_api=internal_nua_api,
        local_openai=local_openai,
        local_openai_model=local_openai_model,
        external_nua_api_key=external_nua_api_key,
    )
    session_memory = memory_klass.from_config(
        retrieval_config.memory,
        agent_id=agent_id,
        workflow_id="default",
        rules=retrieval_config.rules,
    )
    session_memory.init(session_id)
    return state, session_memory
68
+
69
+
70
async def main(
    config: Optional[Dict[str, Any]] = None,
    agent_id: str = "default",
    internal_nua: bool = False,
    internal_nua_api: str = "http://predict.learning.svc.cluster.local:8080",
    local_openai: Optional[str] = None,
    local_openai_model: Optional[str] = None,
    external_nua_api_key: Optional[str] = None,
    question: str = "",
    loaded_modules: Optional[list[str]] = None,
    retrieval_config: Optional[RetrievalAgentConfig] = None,
    callback: Optional[Callable[[AragAnswer], Awaitable[None]]] = None,
    session_id: str = "default_session",
    user_metadata: Optional[Dict[str, str]] = None,
    headers: Optional[Dict[str, str]] = None,
    memory_klass: type[BaseSessionMemory] = SessionMemory,
    streaming: bool = False,
) -> QuestionMemory:
    """Initialize an agent, run it on ``question`` and return the question memory.

    Args:
        question: The question to start in the session memory.
        loaded_modules: Module names forwarded to ``init``. ``None`` (the
            default) means no modules.
        callback: Optional async callback set on the question memory.
        user_metadata: Optional entries merged into the session's ``user_info``.
        headers: Optional entries merged into the question memory's ``headers``.
        streaming: Forwarded to ``start_question``.
        The remaining arguments are forwarded unchanged to ``init``.

    Returns:
        The ``QuestionMemory`` the agent ran on.

    Raises:
        ValueError: If the agent in the state is ``None``. Any exception from
            initialization or from the agent propagates unchanged.
    """
    try:
        state, session_memory = await init(
            config=config,
            agent_id=agent_id,
            internal_nua=internal_nua,
            internal_nua_api=internal_nua_api,
            local_openai=local_openai,
            local_openai_model=local_openai_model,
            external_nua_api_key=external_nua_api_key,
            # Always hand ``init`` a list, whatever default it declares.
            loaded_modules=loaded_modules if loaded_modules is not None else [],
            retrieval_config=retrieval_config,
            session_id=session_id,
            memory_klass=memory_klass,
        )
        question_memory = session_memory.start_question(question, streaming=streaming)
        if callback is not None:
            question_memory.set_callback_fn(callback)

        if user_metadata:
            question_memory.session.user_info.update(user_metadata)

        if headers:
            question_memory.headers.update(headers)

        if state.agent is None:
            raise ValueError("Agent could not be initialized")

        await state.agent(
            question_memory,
            state.manager,
        )
    finally:
        # The registry is cleared on success and on failure alike.
        GLOBAL_REGISTRY.clear()
    return question_memory
124
+
125
+
126
async def engine(
    manager: Manager,
    agent: RetrievalAgent,
    question_memory: QuestionMemory,
    user_metadata: Optional[Dict[str, str]] = None,
) -> None:
    """Run ``agent`` on ``question_memory`` using ``manager``.

    When ``user_metadata`` is provided (an empty dict included), it is merged
    into the session's ``user_info`` before the agent is awaited.
    """
    if user_metadata is not None:
        user_info = question_memory.session.user_info
        user_info.update(user_metadata)

    await agent(question_memory, manager)
139
+
140
+
141
async def get_state(
    agent_id: str,
    config: RetrievalAgentConfig,
    internal_nua_api: str = "http://predict.learning.svc.cluster.local:8080",
    internal_nua: bool = False,
    local_openai: Optional[str] = None,
    local_openai_model: Optional[str] = None,
    external_nua_api_key: Optional[str] = None,
    account: Optional[str] = None,
    kbid: Optional[str] = None,
    local_openai_model_klass: Optional[type[NuaBaseModel]] = None,
) -> State:
    """Pick an NUA client, then build the manager and agent from ``config``.

    The client is chosen by the first matching rule:

    1. ``internal_nua`` is true: internal connection to ``internal_nua_api``.
    2. ``local_openai`` and ``local_openai_model_klass`` are both given: a
       connection built from that model class.
    3. ``external_nua_api_key`` is given: a ``NUAConnection`` using that key.
    4. Otherwise: a ``NoopNuaClient``, with a warning logged.
    """
    client: AsyncNuaClient
    if internal_nua:
        internal_client = await NUAConnection.connect_internal(
            kbid=kbid, account=account, url=internal_nua_api
        )
        client = cast(AsyncNuaClient, internal_client)
    elif local_openai is not None and local_openai_model_klass is not None:
        local_payload = {
            "key": external_nua_api_key,
            "local_openai": local_openai,
            "local_openai_model": local_openai_model,
        }
        local_connection = local_openai_model_klass.model_validate(local_payload)
        client = await local_connection.connect()
    elif external_nua_api_key is not None:
        external_connection = NUAConnection.model_validate(
            {"key": external_nua_api_key}
        )
        client = await external_connection.connect()
    else:
        logger.warning(
            "No LLM backend configured — use a no-op client. Agents that don't"
            " require LLM calls (e.g. the built-in ``static`` context agent) will"
            " work fine; any agent that actually calls NUA will raise a clear error."
        )
        client = NoopNuaClient()

    # The manager is created before the agent, as in the original order.
    built_manager = await Manager.from_config(drivers=config.drivers, nua=client)
    built_agent = await RetrievalAgent.from_config_class(config)

    return State(agent=built_agent, manager=built_manager)
@@ -0,0 +1,14 @@
1
class AutheticationException(Exception):
    """Authentication-related error.

    The class name keeps its existing spelling, since callers refer to it by
    that name.
    """
3
+
4
+
5
class NoAvailableAgents(Exception):
    """Signals, going by its name, that no agents are available.

    The raise sites are outside this module.
    """
7
+
8
+
9
class MaxRetries(Exception):
    """Signals, going by its name, that a retry limit was reached.

    The raise sites are outside this module.
    """
11
+
12
+
13
class ModelRetry(Exception):
    """Signals, going by its name, that a model call should be retried.

    The raise and handling sites are outside this module.
    """
@@ -0,0 +1,105 @@
1
+ import hashlib
2
+ import json
3
+ import threading
4
+ from functools import cached_property
5
+ from typing import Any, Optional
6
+
7
+ import mrflagly # type: ignore
8
+ import pydantic
9
+ from pydantic_settings import BaseSettings, SettingsConfigDict
10
+
11
+ _flag_service = threading.local()
12
+
13
+
14
class FlagContext(pydantic.BaseModel):
    """Class to hold information that can then be passed to the feature flag service to determine if a flag is enabled or not.

    Having it as a class allows to simply pass this object down to any function that needs to check for a feature flag and have all the relevant information in one place, instead of having to pass multiple parameters or a dictionary.
    """

    account_id: Optional[str] = None
    kbid: str

    # Hex MD5 digest of ``account_id``, or None when no account id is set.
    # As a computed field it is included in ``model_dump()`` output. It is
    # documented with a comment rather than a docstring so the generated
    # schema stays unchanged.
    @pydantic.computed_field  # type: ignore[prop-decorator]
    @cached_property
    def account_id_md5(self) -> Optional[str]:
        if self.account_id is None:
            return None
        # MD5 serves only as a stable identifier digest here, not as a
        # security primitive. usedforsecurity=False keeps it working on
        # builds that restrict MD5 (e.g. FIPS mode); the digest is unchanged.
        digest = hashlib.md5(self.account_id.encode(), usedforsecurity=False)
        return digest.hexdigest()
29
+
30
+
31
class Features:
    """Feature flag keys.

    Entries containing ``{...}`` are templates that must be completed with
    ``str.format`` before being used as a flag key.
    """

    # RAO
    # Templates parameterised by ``environment``.
    FILTERED_AGENTS_FEATURE_FLAG = "rao_filtered_agents_{environment}"
    FILTERED_DRIVERS_FEATURE_FLAG = "rao_filtered_drivers_{environment}"
    # Plain key, no placeholder.
    AUDIT_RAO_ASK_ENDPOINT = "learning_audit_rao_ask"
    # Template parameterised by ``account_md5``.
    RAO_ACCOUNT_ENABLED_FEATURE_FLAG = "rao_account_enabled_{account_md5}"
37
+
38
+
39
class FlagService:
    """Wrapper around ``mrflagly.FlagService`` configured from ``Settings``.

    When ``Settings.flag_settings_url`` is ``None`` the service is built from
    the in-module ``DEFAULT_FLAG_DATA``. Otherwise it is built from that URL.
    """

    def __init__(self):
        self.settings = Settings()
        if self.settings.flag_settings_url is None:
            # No URL configured: use the bundled defaults.
            self.flag_service = mrflagly.FlagService(data=json.dumps(DEFAULT_FLAG_DATA))
        else:
            self.flag_service = mrflagly.FlagService(
                url=self.settings.flag_settings_url
            )

    def enabled(
        self, flag_key: str, default: bool = False, context: Optional[dict] = None
    ) -> bool:
        """Return whether ``flag_key`` is enabled for ``context``.

        Args:
            flag_key: Key of the flag to evaluate.
            default: Passed through to the underlying service as its default.
            context: Optional evaluation context. It is never modified.

        Returns:
            The result of the underlying service's ``enabled`` call. The
            context it receives always has ``"environment"`` set to the
            configured running environment, replacing any value the caller
            supplied under that key.
        """
        # Work on a copy so the caller's dict is not mutated.
        evaluation_context = dict(context) if context is not None else {}
        evaluation_context["environment"] = self.settings.running_environment

        return self.flag_service.enabled(
            flag_key, default=default, context=evaluation_context
        )
57
+
58
+
59
class Settings(BaseSettings):
    # Feature flag settings, loaded through pydantic-settings.

    # An environment value of the literal string "null" is parsed as None.
    model_config = SettingsConfigDict(env_parse_none_str="null")

    # Accepts either "environment" or "running_environment" as its name.
    running_environment: str = pydantic.Field(
        default="local",
        validation_alias=pydantic.AliasChoices("environment", "running_environment"),
    )
    # URL the flag definitions are loaded from. FlagService falls back to
    # the in-module defaults when this is None.
    flag_settings_url: str | None = "https://cdn.rag.progress.cloud/features/features-v2.json"
68
+
69
+
70
# These are just defaults to use for local dev and tests.
# Every entry has rollout 0 and differs only in its key and its variants.
DEFAULT_FLAG_DATA: dict[str, Any] = {
    flag_key: {"rollout": 0, "variants": variants}
    for flag_key, variants in (
        (
            Features.FILTERED_AGENTS_FEATURE_FLAG.format(environment="local"),
            {"agents": []},
        ),
        (
            Features.FILTERED_DRIVERS_FEATURE_FLAG.format(environment="local"),
            {"drivers": []},
        ),
        (
            Features.AUDIT_RAO_ASK_ENDPOINT,
            {"environment": ["local"]},
        ),
    )
}
85
+
86
+
87
def get_flag_service() -> FlagService:
    """Return the ``FlagService`` held in the thread-local ``_flag_service``.

    The service is created on first use and stored on the ``threading.local``
    object, so each thread gets its own instance.
    """
    service = getattr(_flag_service, "service", None)
    if service is None:
        service = FlagService()
        _flag_service.service = service
    return service
91
+
92
+
93
def has_feature(
    flag_key: str,
    default: bool = False,
    context: dict[str, str] | None | FlagContext = None,
) -> bool:
    """Check ``flag_key`` against the thread's flag service.

    ``context`` may be a plain dict, a ``FlagContext`` (dumped to a dict with
    ``model_dump()``), or ``None`` (treated as an empty dict).
    """
    service = get_flag_service()

    context_dict = (
        context.model_dump() if isinstance(context, FlagContext) else (context or {})
    )

    return service.enabled(flag_key, default=default, context=context_dict)