aiverify-moonshot 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiverify_moonshot-0.4.0.dist-info/METADATA +249 -0
- aiverify_moonshot-0.4.0.dist-info/RECORD +163 -0
- aiverify_moonshot-0.4.0.dist-info/WHEEL +4 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/AUTHORS.md +5 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/LICENSE.md +201 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/NOTICES.md +3340 -0
- moonshot/__init__.py +0 -0
- moonshot/__main__.py +198 -0
- moonshot/api.py +155 -0
- moonshot/integrations/__init__.py +0 -0
- moonshot/integrations/cli/__init__.py +0 -0
- moonshot/integrations/cli/__main__.py +25 -0
- moonshot/integrations/cli/active_session_cfg.py +1 -0
- moonshot/integrations/cli/benchmark/__init__.py +0 -0
- moonshot/integrations/cli/benchmark/benchmark.py +186 -0
- moonshot/integrations/cli/benchmark/cookbook.py +545 -0
- moonshot/integrations/cli/benchmark/datasets.py +164 -0
- moonshot/integrations/cli/benchmark/metrics.py +141 -0
- moonshot/integrations/cli/benchmark/recipe.py +598 -0
- moonshot/integrations/cli/benchmark/result.py +216 -0
- moonshot/integrations/cli/benchmark/run.py +140 -0
- moonshot/integrations/cli/benchmark/runner.py +174 -0
- moonshot/integrations/cli/cli.py +64 -0
- moonshot/integrations/cli/common/__init__.py +0 -0
- moonshot/integrations/cli/common/common.py +72 -0
- moonshot/integrations/cli/common/connectors.py +325 -0
- moonshot/integrations/cli/common/display_helper.py +42 -0
- moonshot/integrations/cli/common/prompt_template.py +94 -0
- moonshot/integrations/cli/initialisation/__init__.py +0 -0
- moonshot/integrations/cli/initialisation/initialisation.py +14 -0
- moonshot/integrations/cli/redteam/__init__.py +0 -0
- moonshot/integrations/cli/redteam/attack_module.py +70 -0
- moonshot/integrations/cli/redteam/context_strategy.py +147 -0
- moonshot/integrations/cli/redteam/prompt_template.py +67 -0
- moonshot/integrations/cli/redteam/redteam.py +90 -0
- moonshot/integrations/cli/redteam/session.py +467 -0
- moonshot/integrations/web_api/.env.dev +7 -0
- moonshot/integrations/web_api/__init__.py +0 -0
- moonshot/integrations/web_api/__main__.py +56 -0
- moonshot/integrations/web_api/app.py +125 -0
- moonshot/integrations/web_api/container.py +146 -0
- moonshot/integrations/web_api/log/.gitkeep +0 -0
- moonshot/integrations/web_api/logging_conf.py +114 -0
- moonshot/integrations/web_api/routes/__init__.py +0 -0
- moonshot/integrations/web_api/routes/attack_modules.py +66 -0
- moonshot/integrations/web_api/routes/benchmark.py +116 -0
- moonshot/integrations/web_api/routes/benchmark_result.py +175 -0
- moonshot/integrations/web_api/routes/context_strategy.py +129 -0
- moonshot/integrations/web_api/routes/cookbook.py +225 -0
- moonshot/integrations/web_api/routes/dataset.py +120 -0
- moonshot/integrations/web_api/routes/endpoint.py +282 -0
- moonshot/integrations/web_api/routes/metric.py +78 -0
- moonshot/integrations/web_api/routes/prompt_template.py +128 -0
- moonshot/integrations/web_api/routes/recipe.py +219 -0
- moonshot/integrations/web_api/routes/redteam.py +609 -0
- moonshot/integrations/web_api/routes/runner.py +239 -0
- moonshot/integrations/web_api/schemas/__init__.py +0 -0
- moonshot/integrations/web_api/schemas/benchmark_runner_dto.py +13 -0
- moonshot/integrations/web_api/schemas/cookbook_create_dto.py +19 -0
- moonshot/integrations/web_api/schemas/cookbook_response_model.py +9 -0
- moonshot/integrations/web_api/schemas/dataset_response_dto.py +9 -0
- moonshot/integrations/web_api/schemas/endpoint_create_dto.py +21 -0
- moonshot/integrations/web_api/schemas/endpoint_response_model.py +11 -0
- moonshot/integrations/web_api/schemas/prompt_response_model.py +14 -0
- moonshot/integrations/web_api/schemas/prompt_template_response_model.py +10 -0
- moonshot/integrations/web_api/schemas/recipe_create_dto.py +32 -0
- moonshot/integrations/web_api/schemas/recipe_response_model.py +7 -0
- moonshot/integrations/web_api/schemas/session_create_dto.py +16 -0
- moonshot/integrations/web_api/schemas/session_prompt_dto.py +7 -0
- moonshot/integrations/web_api/schemas/session_response_model.py +38 -0
- moonshot/integrations/web_api/services/__init__.py +0 -0
- moonshot/integrations/web_api/services/attack_module_service.py +34 -0
- moonshot/integrations/web_api/services/auto_red_team_test_manager.py +86 -0
- moonshot/integrations/web_api/services/auto_red_team_test_state.py +57 -0
- moonshot/integrations/web_api/services/base_service.py +8 -0
- moonshot/integrations/web_api/services/benchmark_result_service.py +25 -0
- moonshot/integrations/web_api/services/benchmark_test_manager.py +106 -0
- moonshot/integrations/web_api/services/benchmark_test_state.py +56 -0
- moonshot/integrations/web_api/services/benchmarking_service.py +31 -0
- moonshot/integrations/web_api/services/context_strategy_service.py +22 -0
- moonshot/integrations/web_api/services/cookbook_service.py +194 -0
- moonshot/integrations/web_api/services/dataset_service.py +20 -0
- moonshot/integrations/web_api/services/endpoint_service.py +65 -0
- moonshot/integrations/web_api/services/metric_service.py +14 -0
- moonshot/integrations/web_api/services/prompt_template_service.py +39 -0
- moonshot/integrations/web_api/services/recipe_service.py +155 -0
- moonshot/integrations/web_api/services/runner_service.py +147 -0
- moonshot/integrations/web_api/services/session_service.py +350 -0
- moonshot/integrations/web_api/services/utils/exceptions_handler.py +41 -0
- moonshot/integrations/web_api/services/utils/results_formatter.py +47 -0
- moonshot/integrations/web_api/status_updater/interface/benchmark_progress_callback.py +14 -0
- moonshot/integrations/web_api/status_updater/interface/redteam_progress_callback.py +14 -0
- moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +72 -0
- moonshot/integrations/web_api/types/types.py +99 -0
- moonshot/src/__init__.py +0 -0
- moonshot/src/api/__init__.py +0 -0
- moonshot/src/api/api_connector.py +58 -0
- moonshot/src/api/api_connector_endpoint.py +162 -0
- moonshot/src/api/api_context_strategy.py +57 -0
- moonshot/src/api/api_cookbook.py +160 -0
- moonshot/src/api/api_dataset.py +46 -0
- moonshot/src/api/api_environment_variables.py +17 -0
- moonshot/src/api/api_metrics.py +51 -0
- moonshot/src/api/api_prompt_template.py +43 -0
- moonshot/src/api/api_recipe.py +182 -0
- moonshot/src/api/api_red_teaming.py +59 -0
- moonshot/src/api/api_result.py +84 -0
- moonshot/src/api/api_run.py +74 -0
- moonshot/src/api/api_runner.py +132 -0
- moonshot/src/api/api_session.py +290 -0
- moonshot/src/configs/__init__.py +0 -0
- moonshot/src/configs/env_variables.py +187 -0
- moonshot/src/connectors/__init__.py +0 -0
- moonshot/src/connectors/connector.py +327 -0
- moonshot/src/connectors/connector_prompt_arguments.py +17 -0
- moonshot/src/connectors_endpoints/__init__.py +0 -0
- moonshot/src/connectors_endpoints/connector_endpoint.py +211 -0
- moonshot/src/connectors_endpoints/connector_endpoint_arguments.py +54 -0
- moonshot/src/cookbooks/__init__.py +0 -0
- moonshot/src/cookbooks/cookbook.py +225 -0
- moonshot/src/cookbooks/cookbook_arguments.py +34 -0
- moonshot/src/datasets/__init__.py +0 -0
- moonshot/src/datasets/dataset.py +255 -0
- moonshot/src/datasets/dataset_arguments.py +50 -0
- moonshot/src/metrics/__init__.py +0 -0
- moonshot/src/metrics/metric.py +192 -0
- moonshot/src/metrics/metric_interface.py +95 -0
- moonshot/src/prompt_templates/__init__.py +0 -0
- moonshot/src/prompt_templates/prompt_template.py +103 -0
- moonshot/src/recipes/__init__.py +0 -0
- moonshot/src/recipes/recipe.py +340 -0
- moonshot/src/recipes/recipe_arguments.py +111 -0
- moonshot/src/redteaming/__init__.py +0 -0
- moonshot/src/redteaming/attack/__init__.py +0 -0
- moonshot/src/redteaming/attack/attack_module.py +618 -0
- moonshot/src/redteaming/attack/attack_module_arguments.py +44 -0
- moonshot/src/redteaming/attack/context_strategy.py +131 -0
- moonshot/src/redteaming/context_strategy/__init__.py +0 -0
- moonshot/src/redteaming/context_strategy/context_strategy_interface.py +46 -0
- moonshot/src/redteaming/session/__init__.py +0 -0
- moonshot/src/redteaming/session/chat.py +209 -0
- moonshot/src/redteaming/session/red_teaming_progress.py +128 -0
- moonshot/src/redteaming/session/red_teaming_type.py +6 -0
- moonshot/src/redteaming/session/session.py +775 -0
- moonshot/src/results/__init__.py +0 -0
- moonshot/src/results/result.py +119 -0
- moonshot/src/results/result_arguments.py +44 -0
- moonshot/src/runners/__init__.py +0 -0
- moonshot/src/runners/runner.py +476 -0
- moonshot/src/runners/runner_arguments.py +46 -0
- moonshot/src/runners/runner_type.py +6 -0
- moonshot/src/runs/__init__.py +0 -0
- moonshot/src/runs/run.py +344 -0
- moonshot/src/runs/run_arguments.py +162 -0
- moonshot/src/runs/run_progress.py +145 -0
- moonshot/src/runs/run_status.py +10 -0
- moonshot/src/storage/__init__.py +0 -0
- moonshot/src/storage/db_interface.py +128 -0
- moonshot/src/storage/io_interface.py +31 -0
- moonshot/src/storage/storage.py +525 -0
- moonshot/src/utils/__init__.py +0 -0
- moonshot/src/utils/import_modules.py +96 -0
- moonshot/src/utils/timeit.py +25 -0
|
@@ -0,0 +1,775 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import time
|
|
5
|
+
from ast import literal_eval
|
|
6
|
+
from datetime import datetime
|
|
7
|
+
from typing import Any, Callable
|
|
8
|
+
|
|
9
|
+
from slugify import slugify
|
|
10
|
+
|
|
11
|
+
from moonshot.src.configs.env_variables import EnvVariables
|
|
12
|
+
from moonshot.src.redteaming.session.chat import Chat
|
|
13
|
+
from moonshot.src.redteaming.session.red_teaming_progress import RedTeamingProgress
|
|
14
|
+
from moonshot.src.redteaming.session.red_teaming_type import RedTeamingType
|
|
15
|
+
from moonshot.src.runners.runner_type import RunnerType
|
|
16
|
+
from moonshot.src.runs.run_status import RunStatus
|
|
17
|
+
from moonshot.src.storage.db_interface import DBInterface
|
|
18
|
+
from moonshot.src.storage.storage import Storage
|
|
19
|
+
from moonshot.src.utils.import_modules import get_instance
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SessionMetadata:
    """
    Type-checked record describing one red teaming session.

    The ten fields mirror, in order, the columns of session_metadata_table, so an
    instance can be written to the database with to_tuple() and rebuilt from a
    database row with from_tuple().
    """

    # TODO: convert this into a pydantic model
    def __init__(
        self,
        session_id: str,
        endpoints: list[str],
        created_epoch: float,
        created_datetime: str,
        prompt_template: str,
        context_strategy: str,
        cs_num_of_prev_prompts: int,
        attack_module: str,
        metric: str,
        system_prompt: str,
    ):
        """
        Stores every field after verifying that it has the expected type.

        Args:
            session_id (str): The identifier of the session.
            endpoints (list[str]): The endpoint identifiers used in the session.
            created_epoch (float): The creation time as an epoch timestamp.
            created_datetime (str): The creation time as a formatted string.
            prompt_template (str): The prompt template name ("" when unset).
            context_strategy (str): The context strategy name ("" when unset).
            cs_num_of_prev_prompts (int): The number of previous prompts for the context strategy.
            attack_module (str): The attack module name ("" when unset).
            metric (str): The metric name ("" when unset).
            system_prompt (str): The system prompt ("" when unset).

        Raises:
            TypeError: If any argument does not have the expected type.
        """
        self.session_id = self.check_type(session_id, str)
        self.endpoints = self.check_type(endpoints, list)
        self.created_epoch = self.check_type(created_epoch, float)
        self.created_datetime = self.check_type(created_datetime, str)
        self.prompt_template = self.check_type(prompt_template, str)
        self.context_strategy = self.check_type(context_strategy, str)
        self.cs_num_of_prev_prompts = self.check_type(cs_num_of_prev_prompts, int)
        self.attack_module = self.check_type(attack_module, str)
        self.metric = self.check_type(metric, str)
        self.system_prompt = self.check_type(system_prompt, str)

    def to_dict(self) -> dict:
        """
        Converts the SessionMetadata instance into a dictionary.

        Note that created_epoch is emitted as a string; every other field keeps its type.

        Returns:
            dict: A dictionary representation of the SessionMetadata instance.
        """
        return {
            "session_id": self.session_id,
            "endpoints": self.endpoints,
            "created_epoch": str(self.created_epoch),
            "created_datetime": self.created_datetime,
            "prompt_template": self.prompt_template,
            "context_strategy": self.context_strategy,
            "cs_num_of_prev_prompts": self.cs_num_of_prev_prompts,
            "attack_module": self.attack_module,
            "metric": self.metric,
            "system_prompt": self.system_prompt,
        }

    def to_tuple(self) -> tuple:
        """
        Converts the SessionMetadata instance into a tuple.

        The endpoints list is serialised with str() so that it fits a single text
        column; from_tuple() reverses this with literal_eval.

        Returns:
            tuple: A tuple representation of the SessionMetadata instance.
        """
        return (
            self.session_id,
            str(self.endpoints),
            self.created_epoch,
            self.created_datetime,
            self.prompt_template,
            self.context_strategy,
            self.cs_num_of_prev_prompts,
            self.attack_module,
            self.metric,
            self.system_prompt,
        )

    @classmethod
    def from_tuple(cls, data_tuple: tuple) -> "SessionMetadata":
        """
        Creates a SessionMetadata instance from a tuple, the inverse of to_tuple().

        Args:
            data_tuple (tuple): A 10-tuple containing, in order: session_id, endpoints
                (the str() form of a list), created_epoch, created_datetime,
                prompt_template, context_strategy, cs_num_of_prev_prompts,
                attack_module, metric and system_prompt.

        Returns:
            SessionMetadata: An instance of SessionMetadata.

        Raises:
            ValueError: If data_tuple does not contain exactly ten items.
            TypeError: If a field does not have the expected type.
        """
        (
            session_id,
            endpoints,
            created_epoch,
            created_datetime,
            prompt_template,
            context_strategy,
            cs_num_of_prev_prompts,
            attack_module,
            metric,
            system_prompt,
        ) = data_tuple

        # The created_epoch column is declared INTEGER, so SQLite hands back an int
        # whenever the stored timestamp happens to have no fractional part. Coerce
        # it so that reloading such a row does not fail the strict float check.
        if isinstance(created_epoch, int) and not isinstance(created_epoch, bool):
            created_epoch = float(created_epoch)

        return cls(
            session_id,
            literal_eval(endpoints),
            created_epoch,
            created_datetime,
            prompt_template,
            context_strategy,
            cs_num_of_prev_prompts,
            attack_module,
            metric,
            system_prompt,
        )

    def check_type(self, checked_attribute: Any, expected_type: type) -> Any:
        """
        Checks if the type of the given attribute matches the expected type.

        Args:
            checked_attribute (Any): The attribute to be checked.
            expected_type (type): The expected type of the attribute.

        Returns:
            Any: The checked attribute if its type matches the expected type.

        Raises:
            TypeError: If the type of the checked attribute does not match the expected type.
        """
        if not isinstance(checked_attribute, expected_type):
            raise TypeError(
                f"Expected type for {checked_attribute} is {expected_type}, but got {type(checked_attribute)}"
            )
        return checked_attribute
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
class Session:
|
|
147
|
+
# Fallback for cs_num_of_prev_prompts when runner_args does not supply one.
DEFAULT_CONTEXT_STRATEGY_PROMPT = 5

# Schema of the session metadata table. Column order matches
# SessionMetadata.to_tuple() / from_tuple().
# NOTE(review): created_epoch is declared INTEGER although __init__ stores the
# float returned by time.time() - confirm whether REAL was intended.
sql_create_session_metadata_table = """
        CREATE TABLE IF NOT EXISTS session_metadata_table (
        session_id text PRIMARY KEY NOT NULL,
        endpoints text NOT NULL,
        created_epoch INTEGER NOT NULL,
        created_datetime text NOT NULL,
        prompt_template text,
        context_strategy text,
        cs_num_of_prev_prompts int,
        attack_module text,
        metric text,
        system_prompt text
        );
"""

# Schema of a per-endpoint chat history table. The {} placeholder is the table
# name; __init__ fills it with the endpoint id after replacing "-" with "_".
sql_create_chat_history_table = """
        CREATE TABLE IF NOT EXISTS {} (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        connection_id text NOT NULL,
        context_strategy text,
        prompt_template text,
        attack_module text,
        metric text,
        prompt text NOT NULL,
        prepared_prompt text NOT NULL,
        system_prompt text,
        predicted_result text NOT NULL,
        duration text NOT NULL,
        prompt_time text NOT NULL
        );
"""

# Parameterised insert of one session metadata row (ten values, one per column).
sql_create_session_metadata_record = """
    INSERT INTO session_metadata_table (
    session_id,endpoints,created_epoch,created_datetime,prompt_template,context_strategy,cs_num_of_prev_prompts,
    attack_module, metric, system_prompt) VALUES(?,?,?,?,?,?,?,?,?,?)
"""

# Reads every row of the metadata table; callers in this class use only the first row.
sql_read_session_metadata = """
    SELECT * from session_metadata_table
"""

# Updates a single metadata column. The {} placeholder is the column name
# (formatted in by the update_* helpers); the new value and session_id are bound
# as query parameters.
sql_update_session_metadata_field = """
    UPDATE session_metadata_table SET {}=? WHERE session_id=?
"""

# Drops a table; the {} placeholder is the table name.
sql_drop_table = """
    DROP TABLE IF EXISTS {}
"""
|
|
197
|
+
|
|
198
|
+
def __init__(
    self,
    runner_id: str,
    runner_type: RunnerType,
    runner_args: dict,
    database_instance: Any | None,
    endpoints: list[str],
    results_file_path: str,
    progress_callback_func: Callable | None = None,
):
    """
    Sets up a session for a runner: validates the referenced files, then either
    reloads the stored session metadata or creates it together with one chat
    history table per endpoint.

    Args:
        runner_id (str): The unique identifier for the runner. It must already be in slugified form.
        runner_type (RunnerType): The type of runner being used.
        runner_args (dict): A dictionary of arguments specific to the runner.
        database_instance (Any | None): The database instance to connect to.
        endpoints (list[str]): A list of endpoint identifiers.
        results_file_path (str): The file path where results should be stored.
        progress_callback_func (Callable | None): An optional callback function for progress updates.

    Raises:
        RuntimeError: If runner_id is not a valid slug or no database instance is provided.
    """
    now_epoch = time.time()
    now_label = datetime.fromtimestamp(now_epoch).strftime("%Y%m%d-%H%M%S")

    # The runner id is only accepted when slugifying leaves it untouched.
    self.runner_id = slugify(runner_id, lowercase=True)
    if self.runner_id != runner_id:
        raise RuntimeError("[Session] Failed to initialise Session. Invalid Runner ID.")

    self.runner_args = runner_args
    self.runner_type = runner_type
    self.results_file_path = results_file_path
    self.progress_callback_func = progress_callback_func
    self.database_instance = database_instance

    self.red_teaming_progress = RedTeamingProgress(
        self.runner_id, self.runner_args, self.progress_callback_func
    )

    self.cancel_event = asyncio.Event()

    # Optional settings default to "" (or the class default for the prompt count).
    args = self.runner_args
    prompt_template = args.get("prompt_template", "")
    context_strategy = args.get("context_strategy", "")
    cs_num_of_prev_prompts = args.get(
        "cs_num_of_prev_prompts", Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
    )
    attack_module = args.get("attack_module", "")
    system_prompt = args.get("system_prompt", "")
    metric_id = args.get("metric_id", "")

    # Validate each referenced file in the same order as before.
    file_checks = (
        (EnvVariables.PROMPT_TEMPLATES.name, prompt_template, "Prompt Template", "json"),
        (EnvVariables.CONTEXT_STRATEGY.name, context_strategy, "Context Strategy", "py"),
        (EnvVariables.ATTACK_MODULES.name, attack_module, "Attack Module", "py"),
        (EnvVariables.METRICS.name, metric_id, "Metric", "py"),
    )
    for env_name, file_name, label, extension in file_checks:
        self.check_file_exists(env_name, file_name, label, extension)

    if not self.database_instance:
        raise RuntimeError(
            "[Session] Failed to initialise Session. No database instance provided."
        )

    db = self.database_instance

    # Make sure the metadata table is there before reading from it.
    if not Storage.check_database_table_exists(db, "session_metadata_table"):
        Storage.create_database_table(db, Session.sql_create_session_metadata_table)

    stored_rows = Storage.read_database_records(db, Session.sql_read_session_metadata)

    if stored_rows:
        # An earlier session exists: reuse its stored metadata as-is.
        print("[Session] Session already exists.")
        self.session_metadata = SessionMetadata.from_tuple(stored_rows[0])
        return

    print("[Session] Creating new session.")
    self.session_metadata = SessionMetadata(
        runner_id,
        endpoints,
        now_epoch,
        now_label,
        prompt_template,
        context_strategy,
        cs_num_of_prev_prompts,
        attack_module,
        metric_id,
        system_prompt,
    )

    # One chat history table per endpoint; "-" is replaced by "_" in the table name.
    for endpoint in endpoints:
        table_name = endpoint.replace("-", "_")
        Storage.create_database_table(
            db, Session.sql_create_chat_history_table.format(table_name)
        )

    Storage.create_database_record(
        db,
        self.session_metadata.to_tuple(),
        Session.sql_create_session_metadata_record,
    )
|
|
325
|
+
|
|
326
|
+
@staticmethod
def load(database_instance: DBInterface | None) -> dict | None:
    """
    Loads the stored session metadata from a runner's database.

    Parameters:
        database_instance (DBInterface | None): The database accessor instance.

    Returns:
        dict | None: The first session metadata record as a dictionary (see
        SessionMetadata.to_dict), or None when the database has no
        session_metadata_table, i.e. the runner has no session.

    Raises:
        RuntimeError: If no database instance is provided, or if the metadata
        table exists but no record could be read from it.
    """
    if not database_instance:
        raise RuntimeError("[Session] Runner instance database not provided.")

    # No metadata table means this runner never had a session.
    has_metadata_table = Storage.check_database_table_exists(
        database_instance, "session_metadata_table"
    )
    if not has_metadata_table:
        return None

    rows = Storage.read_database_records(
        database_instance, Session.sql_read_session_metadata
    )
    if not rows:
        raise RuntimeError("[Session] Failed to get Session metadata.")

    # Only the first record is used; convert tuple -> object -> dict.
    return SessionMetadata.from_tuple(rows[0]).to_dict()
|
|
366
|
+
|
|
367
|
+
async def run(self) -> list | None:
    """
    Asynchronously executes the session through its runner processing module.

    The module named by runner_args["runner_processing_module"] is loaded and
    instantiated, and its generate() coroutine is awaited with the session's state.
    Whether or not generate() succeeds, the progress status is set to COMPLETED
    afterwards, and for automated red teaming a progress notification is sent.

    Returns:
        The value returned by the runner module's generate() call.

    Raises:
        RuntimeError: If the processing module name is missing or the module cannot be loaded.
        Exception: Any error raised while loading or running the module is logged and re-raised.
    """
    # ------------------------------------------------------------------------------
    # Part 1: Get asyncio running loop
    # ------------------------------------------------------------------------------
    print("[Session] Part 1: Loading asyncio running loop...")
    loop = asyncio.get_running_loop()

    # ------------------------------------------------------------------------------
    # Part 2: Load runner processing module
    # ------------------------------------------------------------------------------
    print("[Session] Part 2: Loading runner processing module...")
    start_time = time.perf_counter()
    runner_module_instance = None
    try:
        runner_processing_module_name = self.runner_args.get(
            "runner_processing_module", None
        )
        if not runner_processing_module_name:
            raise RuntimeError(
                f"Failed to get runner processing module name: {runner_processing_module_name}"
            )

        # Resolve the module file, then fetch the class it defines.
        module_filepath = Storage.get_filepath(
            EnvVariables.RUNNERS_MODULES.name,
            runner_processing_module_name,
            "py",
        )
        runner_module_instance = get_instance(
            runner_processing_module_name, module_filepath
        )
        if not runner_module_instance:
            raise RuntimeError(
                f"Unable to get defined runner module instance - {runner_module_instance}"
            )

        # get_instance returned something callable; instantiate it.
        runner_module_instance = runner_module_instance()

    except Exception as e:
        print(
            f"[Session] Failed to load runner processing module in Part 2 due to error: {str(e)}"
        )
        raise e

    finally:
        print(
            f"[Session] Loading runner processing module took {(time.perf_counter() - start_time):.4f}s"
        )

    # ------------------------------------------------------------------------------
    # Part 3: Run runner processing module
    # ------------------------------------------------------------------------------
    print("[Session] Part 3: Running runner processing module...")
    start_time = time.perf_counter()
    runner_results = {}

    try:
        if not runner_module_instance:
            raise RuntimeError("Failed to initialise runner module instance.")

        runner_results = await runner_module_instance.generate(  # type: ignore ; ducktyping
            loop,
            self.runner_args,
            self.database_instance,
            self.session_metadata,
            self.check_redteaming_type(),
            self.red_teaming_progress,
            self.cancel_event,
        )

    except Exception as e:
        print(
            f"[Session] Failed to run runner processing module in Part 3 due to error: {str(e)}"
        )
        raise e

    finally:
        # Runs on success and on failure alike.
        self.red_teaming_progress.status = RunStatus.COMPLETED
        if self.check_redteaming_type() == RedTeamingType.AUTOMATED:
            self.red_teaming_progress.notify_progress()
        print(
            f"[Session] Running runner processing module took {(time.perf_counter() - start_time):.4f}s"
        )

    # ------------------------------------------------------------------------------
    # Part 4: Wrap up run
    # ------------------------------------------------------------------------------
    print("[Session] Part 4: Wrap up run...")
    return runner_results
|
|
467
|
+
|
|
468
|
+
def cancel(self) -> None:
    """
    Requests cancellation of the automated red teaming process.

    Logs the request and sets the session's cancel_event. The same event object is
    handed to the runner processing module in run(), so setting it is how the
    session signals that the process should stop.

    Returns:
        None
    """
    print("[Session] Cancelling automated red teaming...")
    stop_signal = self.cancel_event
    stop_signal.set()
|
|
481
|
+
|
|
482
|
+
def check_redteaming_type(self) -> RedTeamingType:
    """
    Determines the red teaming type from the runner arguments.

    A non-None "attack_strategies" entry means automated red teaming and takes
    precedence; otherwise a non-None "manual_rt_args" entry means manual red teaming.

    Returns:
        RedTeamingType: The type of red teaming strategy.

    Raises:
        RuntimeError: If neither argument is present with a non-None value.
    """
    args = self.runner_args

    # A missing key and a key explicitly set to None are treated the same way.
    if args.get("attack_strategies") is not None:
        return RedTeamingType.AUTOMATED
    if args.get("manual_rt_args") is not None:
        return RedTeamingType.MANUAL

    raise RuntimeError("Missing red teaming arguments.")
|
|
504
|
+
|
|
505
|
+
@staticmethod
def update_context_strategy(
    db_instance: DBInterface | None, runner_id: str, context_strategy: str
) -> bool:
    """
    Stores a new context strategy in the session metadata of a runner.

    An empty context_strategy is written without any existence check.

    Args:
        db_instance (DBInterface | None): The database instance to update the context strategy in.
        runner_id (str): The ID of the runner, used as the session_id of the record to update.
        context_strategy (str): The name of the context strategy to be used.

    Returns:
        bool: True once the update has been issued.

    Raises:
        RuntimeError: If the database instance is not provided or if the context strategy does not exist.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    # Only a non-empty name is looked up; the lookup is skipped otherwise.
    strategy_missing = bool(context_strategy) and not Storage.is_object_exists(
        EnvVariables.CONTEXT_STRATEGY.name, context_strategy, "py"
    )
    if strategy_missing:
        raise RuntimeError(
            f"[Session] Context Strategy {context_strategy} does not exist."
        )

    update_sql = Session.sql_update_session_metadata_field.format("context_strategy")
    Storage.update_database_record(
        db_instance, (context_strategy, runner_id), update_sql
    )
    return True
|
|
538
|
+
|
|
539
|
+
@staticmethod
|
|
540
|
+
def update_cs_num_of_prev_prompts(
    db_instance: DBInterface | None, runner_id: str, cs_num_of_prev_prompts: int
) -> bool:
    """
    Store the number of previous prompts used by the context strategy of a runner.

    The value is written without any validation.

    Args:
        db_instance (DBInterface | None): The database instance holding the session metadata.
        runner_id (str): The ID of the runner whose metadata is updated.
        cs_num_of_prev_prompts (int): The number of previous prompts to store.

    Returns:
        bool: True once the update call has been issued.

    Raises:
        RuntimeError: If no database instance is given.
    """
    if db_instance:
        update_statement = Session.sql_update_session_metadata_field.format(
            "cs_num_of_prev_prompts"
        )
        Storage.update_database_record(
            db_instance, (cs_num_of_prev_prompts, runner_id), update_statement
        )
        return True

    raise RuntimeError("[Session] Database instance not provided.")
|
|
565
|
+
|
|
566
|
+
@staticmethod
|
|
567
|
+
def update_prompt_template(
    db_instance: DBInterface | None, runner_id: str, prompt_template: str
) -> bool:
    """
    Store a new prompt template in the session metadata of a runner.

    A non-empty template name must exist in storage as a "json" object under
    the PROMPT_TEMPLATES location. A falsy name bypasses that check and is
    written unchanged.

    Args:
        db_instance (DBInterface | None): The database instance holding the session metadata.
        runner_id (str): The ID of the runner whose metadata is updated.
        prompt_template (str): The name of the prompt template to store.

    Returns:
        bool: True once the update call has been issued.

    Raises:
        RuntimeError: If no database instance is given, or if a non-empty
            prompt template cannot be found in storage.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    # The storage lookup only happens for a truthy template name.
    if prompt_template:
        template_found = Storage.is_object_exists(
            EnvVariables.PROMPT_TEMPLATES.name, prompt_template, "json"
        )
        if not template_found:
            raise RuntimeError(
                f"[Session] Prompt Template {prompt_template} does not exist."
            )

    record = (prompt_template, runner_id)
    update_statement = Session.sql_update_session_metadata_field.format(
        "prompt_template"
    )
    Storage.update_database_record(db_instance, record, update_statement)
    return True
|
|
599
|
+
|
|
600
|
+
@staticmethod
|
|
601
|
+
def update_metric(
    db_instance: DBInterface | None, runner_id: str, metric_id: str
) -> bool:
    """
    Store a new metric in the session metadata of a runner.

    A non-empty metric ID must exist in storage as a "py" object under the
    METRICS location. A falsy ID is not checked and is written as given.

    Args:
        db_instance (DBInterface | None): The database instance holding the session metadata.
        runner_id (str): The ID of the runner whose metadata is updated.
        metric_id (str): The ID of the metric to store.

    Returns:
        bool: True once the update call has been issued.

    Raises:
        RuntimeError: If no database instance is given, or if a non-empty
            metric ID cannot be found in storage.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    # Short-circuit keeps the storage lookup from running for a falsy ID.
    metric_missing = bool(metric_id) and not Storage.is_object_exists(
        EnvVariables.METRICS.name, metric_id, "py"
    )
    if metric_missing:
        raise RuntimeError(f"[Session] Metric {metric_id} does not exist.")

    Storage.update_database_record(
        db_instance,
        (metric_id, runner_id),
        Session.sql_update_session_metadata_field.format("metric"),
    )
    return True
|
|
631
|
+
|
|
632
|
+
@staticmethod
|
|
633
|
+
def update_system_prompt(
    db_instance: DBInterface | None, runner_id: str, system_prompt: str
) -> bool:
    """
    Store a new system prompt in the session metadata of a runner.

    The prompt text is written as given; no existence or content check is made.

    Args:
        db_instance (DBInterface | None): The database instance holding the session metadata.
        runner_id (str): The ID of the runner whose metadata is updated.
        system_prompt (str): The system prompt text to store.

    Returns:
        bool: True once the update call has been issued.

    Raises:
        RuntimeError: If no database instance is given.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    record = (system_prompt, runner_id)
    update_statement = Session.sql_update_session_metadata_field.format(
        "system_prompt"
    )
    Storage.update_database_record(db_instance, record, update_statement)
    return True
|
|
658
|
+
|
|
659
|
+
@staticmethod
|
|
660
|
+
def update_attack_module(
    db_instance: DBInterface | None, runner_id: str, attack_module_id: str
) -> bool:
    """
    Store a new attack module in the session metadata of a runner.

    A non-empty attack module ID must exist in storage as a "py" object under
    the ATTACK_MODULES location. A falsy ID is not checked and is written as
    given.

    Args:
        db_instance (DBInterface | None): The database instance holding the session metadata.
        runner_id (str): The ID of the runner whose metadata is updated.
        attack_module_id (str): The ID of the attack module to store.

    Returns:
        bool: True once the update call has been issued.

    Raises:
        RuntimeError: If no database instance is given, or if a non-empty
            attack module ID cannot be found in storage.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    # Validation applies to truthy IDs only.
    if attack_module_id:
        module_found = Storage.is_object_exists(
            EnvVariables.ATTACK_MODULES.name, attack_module_id, "py"
        )
        if not module_found:
            raise RuntimeError(
                f"[Session] Attack Module {attack_module_id} does not exist."
            )

    update_statement = Session.sql_update_session_metadata_field.format(
        "attack_module"
    )
    Storage.update_database_record(
        db_instance, (attack_module_id, runner_id), update_statement
    )
    return True
|
|
692
|
+
|
|
693
|
+
@staticmethod
|
|
694
|
+
def delete(database_instance: DBInterface | None) -> bool:
    """
    Drop the session metadata table and one table per session endpoint.

    The session metadata is read first to learn the endpoint names. The
    metadata table is then dropped, followed by a table for each endpoint,
    named after the endpoint with "-" replaced by "_".

    Args:
        database_instance (DBInterface | None): The database instance to delete the session from.

    Returns:
        bool: True once all drop calls have been issued.

    Raises:
        RuntimeError: If no database instance is given, or if no session
            metadata record could be read.
    """
    if not database_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    metadata_rows = Storage.read_database_records(
        database_instance, Session.sql_read_session_metadata
    )
    if not metadata_rows:
        raise RuntimeError("[Session] Failed to get Session metadata.")

    # Only the first record is used to rebuild the metadata object.
    metadata = SessionMetadata.from_tuple(metadata_rows[0])
    drop_template = Session.sql_drop_table

    # Order matters: the metadata table goes first, then the endpoint tables.
    Storage.delete_database_table(
        database_instance, drop_template.format("session_metadata_table")
    )
    for endpoint_id in metadata.endpoints:
        table_name = endpoint_id.replace("-", "_")
        Storage.delete_database_table(
            database_instance, drop_template.format(table_name)
        )
    return True
|
|
727
|
+
|
|
728
|
+
@staticmethod
|
|
729
|
+
def get_session_chats(database_instance: DBInterface | None) -> dict:
    """
    Collect the chat history of every endpoint recorded in the session.

    The session metadata is loaded from the database. For each endpoint ID
    it lists, the chat history is loaded using the ID with "-" replaced by
    "_". The result is keyed by the original, unmodified endpoint ID.

    Args:
        database_instance (DBInterface | None): The database instance to read the chat history from.

    Raises:
        RuntimeError: If no database instance is given.

    Returns:
        dict: Maps each endpoint ID to whatever Chat.load_chat_history
        returns for it. Empty if no session metadata is loaded or it has no
        "endpoints" entry.
    """
    if not database_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    metadata = Session.load(database_instance)
    if metadata is None or "endpoints" not in metadata:
        return {}

    return {
        endpoint_id: Chat.load_chat_history(
            database_instance, endpoint_id.replace("-", "_")
        )
        for endpoint_id in metadata.get("endpoints", [])
    }
|
|
756
|
+
|
|
757
|
+
def check_file_exists(
    self, env_var_name: str, file_name: str, file_type: str, extension: str
) -> None:
    """
    Verify that a named file is present in storage.

    A falsy file name (for example an empty string or None) is not checked
    at all, so no storage lookup is made for it.

    Args:
        env_var_name (str): The environment variable name identifying where the file is stored.
        file_name (str): The name of the file to look for.
        file_type (str): A label for the kind of file, used only in the error message.
        extension (str): The extension of the file.

    Raises:
        RuntimeError: If a non-empty file name is not found in storage.
    """
    if file_name:
        file_found = Storage.is_object_exists(env_var_name, file_name, extension)
        if not file_found:
            raise RuntimeError(f"[Session] {file_type} {file_name} does not exist.")
|