aiverify-moonshot 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163)
  1. aiverify_moonshot-0.4.0.dist-info/METADATA +249 -0
  2. aiverify_moonshot-0.4.0.dist-info/RECORD +163 -0
  3. aiverify_moonshot-0.4.0.dist-info/WHEEL +4 -0
  4. aiverify_moonshot-0.4.0.dist-info/licenses/AUTHORS.md +5 -0
  5. aiverify_moonshot-0.4.0.dist-info/licenses/LICENSE.md +201 -0
  6. aiverify_moonshot-0.4.0.dist-info/licenses/NOTICES.md +3340 -0
  7. moonshot/__init__.py +0 -0
  8. moonshot/__main__.py +198 -0
  9. moonshot/api.py +155 -0
  10. moonshot/integrations/__init__.py +0 -0
  11. moonshot/integrations/cli/__init__.py +0 -0
  12. moonshot/integrations/cli/__main__.py +25 -0
  13. moonshot/integrations/cli/active_session_cfg.py +1 -0
  14. moonshot/integrations/cli/benchmark/__init__.py +0 -0
  15. moonshot/integrations/cli/benchmark/benchmark.py +186 -0
  16. moonshot/integrations/cli/benchmark/cookbook.py +545 -0
  17. moonshot/integrations/cli/benchmark/datasets.py +164 -0
  18. moonshot/integrations/cli/benchmark/metrics.py +141 -0
  19. moonshot/integrations/cli/benchmark/recipe.py +598 -0
  20. moonshot/integrations/cli/benchmark/result.py +216 -0
  21. moonshot/integrations/cli/benchmark/run.py +140 -0
  22. moonshot/integrations/cli/benchmark/runner.py +174 -0
  23. moonshot/integrations/cli/cli.py +64 -0
  24. moonshot/integrations/cli/common/__init__.py +0 -0
  25. moonshot/integrations/cli/common/common.py +72 -0
  26. moonshot/integrations/cli/common/connectors.py +325 -0
  27. moonshot/integrations/cli/common/display_helper.py +42 -0
  28. moonshot/integrations/cli/common/prompt_template.py +94 -0
  29. moonshot/integrations/cli/initialisation/__init__.py +0 -0
  30. moonshot/integrations/cli/initialisation/initialisation.py +14 -0
  31. moonshot/integrations/cli/redteam/__init__.py +0 -0
  32. moonshot/integrations/cli/redteam/attack_module.py +70 -0
  33. moonshot/integrations/cli/redteam/context_strategy.py +147 -0
  34. moonshot/integrations/cli/redteam/prompt_template.py +67 -0
  35. moonshot/integrations/cli/redteam/redteam.py +90 -0
  36. moonshot/integrations/cli/redteam/session.py +467 -0
  37. moonshot/integrations/web_api/.env.dev +7 -0
  38. moonshot/integrations/web_api/__init__.py +0 -0
  39. moonshot/integrations/web_api/__main__.py +56 -0
  40. moonshot/integrations/web_api/app.py +125 -0
  41. moonshot/integrations/web_api/container.py +146 -0
  42. moonshot/integrations/web_api/log/.gitkeep +0 -0
  43. moonshot/integrations/web_api/logging_conf.py +114 -0
  44. moonshot/integrations/web_api/routes/__init__.py +0 -0
  45. moonshot/integrations/web_api/routes/attack_modules.py +66 -0
  46. moonshot/integrations/web_api/routes/benchmark.py +116 -0
  47. moonshot/integrations/web_api/routes/benchmark_result.py +175 -0
  48. moonshot/integrations/web_api/routes/context_strategy.py +129 -0
  49. moonshot/integrations/web_api/routes/cookbook.py +225 -0
  50. moonshot/integrations/web_api/routes/dataset.py +120 -0
  51. moonshot/integrations/web_api/routes/endpoint.py +282 -0
  52. moonshot/integrations/web_api/routes/metric.py +78 -0
  53. moonshot/integrations/web_api/routes/prompt_template.py +128 -0
  54. moonshot/integrations/web_api/routes/recipe.py +219 -0
  55. moonshot/integrations/web_api/routes/redteam.py +609 -0
  56. moonshot/integrations/web_api/routes/runner.py +239 -0
  57. moonshot/integrations/web_api/schemas/__init__.py +0 -0
  58. moonshot/integrations/web_api/schemas/benchmark_runner_dto.py +13 -0
  59. moonshot/integrations/web_api/schemas/cookbook_create_dto.py +19 -0
  60. moonshot/integrations/web_api/schemas/cookbook_response_model.py +9 -0
  61. moonshot/integrations/web_api/schemas/dataset_response_dto.py +9 -0
  62. moonshot/integrations/web_api/schemas/endpoint_create_dto.py +21 -0
  63. moonshot/integrations/web_api/schemas/endpoint_response_model.py +11 -0
  64. moonshot/integrations/web_api/schemas/prompt_response_model.py +14 -0
  65. moonshot/integrations/web_api/schemas/prompt_template_response_model.py +10 -0
  66. moonshot/integrations/web_api/schemas/recipe_create_dto.py +32 -0
  67. moonshot/integrations/web_api/schemas/recipe_response_model.py +7 -0
  68. moonshot/integrations/web_api/schemas/session_create_dto.py +16 -0
  69. moonshot/integrations/web_api/schemas/session_prompt_dto.py +7 -0
  70. moonshot/integrations/web_api/schemas/session_response_model.py +38 -0
  71. moonshot/integrations/web_api/services/__init__.py +0 -0
  72. moonshot/integrations/web_api/services/attack_module_service.py +34 -0
  73. moonshot/integrations/web_api/services/auto_red_team_test_manager.py +86 -0
  74. moonshot/integrations/web_api/services/auto_red_team_test_state.py +57 -0
  75. moonshot/integrations/web_api/services/base_service.py +8 -0
  76. moonshot/integrations/web_api/services/benchmark_result_service.py +25 -0
  77. moonshot/integrations/web_api/services/benchmark_test_manager.py +106 -0
  78. moonshot/integrations/web_api/services/benchmark_test_state.py +56 -0
  79. moonshot/integrations/web_api/services/benchmarking_service.py +31 -0
  80. moonshot/integrations/web_api/services/context_strategy_service.py +22 -0
  81. moonshot/integrations/web_api/services/cookbook_service.py +194 -0
  82. moonshot/integrations/web_api/services/dataset_service.py +20 -0
  83. moonshot/integrations/web_api/services/endpoint_service.py +65 -0
  84. moonshot/integrations/web_api/services/metric_service.py +14 -0
  85. moonshot/integrations/web_api/services/prompt_template_service.py +39 -0
  86. moonshot/integrations/web_api/services/recipe_service.py +155 -0
  87. moonshot/integrations/web_api/services/runner_service.py +147 -0
  88. moonshot/integrations/web_api/services/session_service.py +350 -0
  89. moonshot/integrations/web_api/services/utils/exceptions_handler.py +41 -0
  90. moonshot/integrations/web_api/services/utils/results_formatter.py +47 -0
  91. moonshot/integrations/web_api/status_updater/interface/benchmark_progress_callback.py +14 -0
  92. moonshot/integrations/web_api/status_updater/interface/redteam_progress_callback.py +14 -0
  93. moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +72 -0
  94. moonshot/integrations/web_api/types/types.py +99 -0
  95. moonshot/src/__init__.py +0 -0
  96. moonshot/src/api/__init__.py +0 -0
  97. moonshot/src/api/api_connector.py +58 -0
  98. moonshot/src/api/api_connector_endpoint.py +162 -0
  99. moonshot/src/api/api_context_strategy.py +57 -0
  100. moonshot/src/api/api_cookbook.py +160 -0
  101. moonshot/src/api/api_dataset.py +46 -0
  102. moonshot/src/api/api_environment_variables.py +17 -0
  103. moonshot/src/api/api_metrics.py +51 -0
  104. moonshot/src/api/api_prompt_template.py +43 -0
  105. moonshot/src/api/api_recipe.py +182 -0
  106. moonshot/src/api/api_red_teaming.py +59 -0
  107. moonshot/src/api/api_result.py +84 -0
  108. moonshot/src/api/api_run.py +74 -0
  109. moonshot/src/api/api_runner.py +132 -0
  110. moonshot/src/api/api_session.py +290 -0
  111. moonshot/src/configs/__init__.py +0 -0
  112. moonshot/src/configs/env_variables.py +187 -0
  113. moonshot/src/connectors/__init__.py +0 -0
  114. moonshot/src/connectors/connector.py +327 -0
  115. moonshot/src/connectors/connector_prompt_arguments.py +17 -0
  116. moonshot/src/connectors_endpoints/__init__.py +0 -0
  117. moonshot/src/connectors_endpoints/connector_endpoint.py +211 -0
  118. moonshot/src/connectors_endpoints/connector_endpoint_arguments.py +54 -0
  119. moonshot/src/cookbooks/__init__.py +0 -0
  120. moonshot/src/cookbooks/cookbook.py +225 -0
  121. moonshot/src/cookbooks/cookbook_arguments.py +34 -0
  122. moonshot/src/datasets/__init__.py +0 -0
  123. moonshot/src/datasets/dataset.py +255 -0
  124. moonshot/src/datasets/dataset_arguments.py +50 -0
  125. moonshot/src/metrics/__init__.py +0 -0
  126. moonshot/src/metrics/metric.py +192 -0
  127. moonshot/src/metrics/metric_interface.py +95 -0
  128. moonshot/src/prompt_templates/__init__.py +0 -0
  129. moonshot/src/prompt_templates/prompt_template.py +103 -0
  130. moonshot/src/recipes/__init__.py +0 -0
  131. moonshot/src/recipes/recipe.py +340 -0
  132. moonshot/src/recipes/recipe_arguments.py +111 -0
  133. moonshot/src/redteaming/__init__.py +0 -0
  134. moonshot/src/redteaming/attack/__init__.py +0 -0
  135. moonshot/src/redteaming/attack/attack_module.py +618 -0
  136. moonshot/src/redteaming/attack/attack_module_arguments.py +44 -0
  137. moonshot/src/redteaming/attack/context_strategy.py +131 -0
  138. moonshot/src/redteaming/context_strategy/__init__.py +0 -0
  139. moonshot/src/redteaming/context_strategy/context_strategy_interface.py +46 -0
  140. moonshot/src/redteaming/session/__init__.py +0 -0
  141. moonshot/src/redteaming/session/chat.py +209 -0
  142. moonshot/src/redteaming/session/red_teaming_progress.py +128 -0
  143. moonshot/src/redteaming/session/red_teaming_type.py +6 -0
  144. moonshot/src/redteaming/session/session.py +775 -0
  145. moonshot/src/results/__init__.py +0 -0
  146. moonshot/src/results/result.py +119 -0
  147. moonshot/src/results/result_arguments.py +44 -0
  148. moonshot/src/runners/__init__.py +0 -0
  149. moonshot/src/runners/runner.py +476 -0
  150. moonshot/src/runners/runner_arguments.py +46 -0
  151. moonshot/src/runners/runner_type.py +6 -0
  152. moonshot/src/runs/__init__.py +0 -0
  153. moonshot/src/runs/run.py +344 -0
  154. moonshot/src/runs/run_arguments.py +162 -0
  155. moonshot/src/runs/run_progress.py +145 -0
  156. moonshot/src/runs/run_status.py +10 -0
  157. moonshot/src/storage/__init__.py +0 -0
  158. moonshot/src/storage/db_interface.py +128 -0
  159. moonshot/src/storage/io_interface.py +31 -0
  160. moonshot/src/storage/storage.py +525 -0
  161. moonshot/src/utils/__init__.py +0 -0
  162. moonshot/src/utils/import_modules.py +96 -0
  163. moonshot/src/utils/timeit.py +25 -0
@@ -0,0 +1,775 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import time
5
+ from ast import literal_eval
6
+ from datetime import datetime
7
+ from typing import Any, Callable
8
+
9
+ from slugify import slugify
10
+
11
+ from moonshot.src.configs.env_variables import EnvVariables
12
+ from moonshot.src.redteaming.session.chat import Chat
13
+ from moonshot.src.redteaming.session.red_teaming_progress import RedTeamingProgress
14
+ from moonshot.src.redteaming.session.red_teaming_type import RedTeamingType
15
+ from moonshot.src.runners.runner_type import RunnerType
16
+ from moonshot.src.runs.run_status import RunStatus
17
+ from moonshot.src.storage.db_interface import DBInterface
18
+ from moonshot.src.storage.storage import Storage
19
+ from moonshot.src.utils.import_modules import get_instance
20
+
21
+
22
class SessionMetadata:
    """
    Type-checked record describing one red teaming session.

    Instances mirror one row of session_metadata_table: to_tuple() yields the positional
    values used for an INSERT, and from_tuple() rebuilds an instance from a row read back.
    """

    # TODO: convert this into a pydantic model
    def __init__(
        self,
        session_id: str,
        endpoints: list[str],
        created_epoch: float,
        created_datetime: str,
        prompt_template: str,
        context_strategy: str,
        cs_num_of_prev_prompts: int,
        attack_module: str,
        metric: str,
        system_prompt: str,
    ):
        """
        Validates the type of every field and stores it on the instance.

        Args:
            session_id (str): The identifier of the session.
            endpoints (list[str]): The endpoint identifiers used in the session.
            created_epoch (float): The creation time of the session as an epoch timestamp.
            created_datetime (str): The creation time of the session as a formatted string.
            prompt_template (str): The prompt template of the session.
            context_strategy (str): The context strategy of the session.
            cs_num_of_prev_prompts (int): The number of previous prompts used by the context strategy.
            attack_module (str): The attack module of the session.
            metric (str): The metric of the session.
            system_prompt (str): The system prompt of the session.

        Raises:
            TypeError: If any argument is not an instance of its expected type.
        """
        self.session_id = self.check_type(session_id, str)
        self.endpoints = self.check_type(endpoints, list)
        self.created_epoch = self.check_type(created_epoch, float)
        self.created_datetime = self.check_type(created_datetime, str)
        self.prompt_template = self.check_type(prompt_template, str)
        self.context_strategy = self.check_type(context_strategy, str)
        self.cs_num_of_prev_prompts = self.check_type(cs_num_of_prev_prompts, int)
        self.attack_module = self.check_type(attack_module, str)
        self.metric = self.check_type(metric, str)
        self.system_prompt = self.check_type(system_prompt, str)

    def to_dict(self) -> dict:
        """
        Converts the SessionMetadata instance into a dictionary.

        Note that created_epoch is rendered as a string here, whereas to_tuple() keeps it numeric.

        Returns:
            dict: A dictionary representation of the SessionMetadata instance.
        """
        return {
            "session_id": self.session_id,
            "endpoints": self.endpoints,
            "created_epoch": str(self.created_epoch),
            "created_datetime": self.created_datetime,
            "prompt_template": self.prompt_template,
            "context_strategy": self.context_strategy,
            "cs_num_of_prev_prompts": self.cs_num_of_prev_prompts,
            "attack_module": self.attack_module,
            "metric": self.metric,
            "system_prompt": self.system_prompt,
        }

    def to_tuple(self) -> tuple:
        """
        Converts the SessionMetadata instance into a tuple.

        The endpoints list is flattened with str() so that it fits into a single text value;
        from_tuple() reverses this.

        Returns:
            tuple: A tuple representation of the SessionMetadata instance, in constructor order.
        """
        return (
            self.session_id,
            str(self.endpoints),
            self.created_epoch,
            self.created_datetime,
            self.prompt_template,
            self.context_strategy,
            self.cs_num_of_prev_prompts,
            self.attack_module,
            self.metric,
            self.system_prompt,
        )

    @classmethod
    def from_tuple(cls, data_tuple: tuple) -> "SessionMetadata":
        """
        Creates a SessionMetadata instance from a tuple using the class method.

        Args:
            data_tuple (tuple): A 10-item tuple containing, in order: session_id, endpoints,
                created_epoch, created_datetime, prompt_template, context_strategy,
                cs_num_of_prev_prompts, attack_module, metric and system_prompt.
                endpoints may be either a list or the str() of a list (as written by to_tuple()).

        Returns:
            SessionMetadata: An instance of SessionMetadata.

        Raises:
            ValueError: If data_tuple does not contain exactly 10 items, or if a string
                endpoints value is not a valid Python literal.
            TypeError: If a field has an unexpected type.
        """
        (
            session_id,
            endpoints,
            created_epoch,
            created_datetime,
            prompt_template,
            context_strategy,
            cs_num_of_prev_prompts,
            attack_module,
            metric,
            system_prompt,
        ) = data_tuple

        # to_tuple() persists endpoints as text; only parse when we actually received text.
        if isinstance(endpoints, str):
            endpoints = literal_eval(endpoints)

        # The session_metadata_table column is declared INTEGER, and SQLite stores a float
        # without a fractional part as an integer in such a column. Widen it back to float
        # so that a stored session can always be reloaded. bool is excluded on purpose.
        if isinstance(created_epoch, int) and not isinstance(created_epoch, bool):
            created_epoch = float(created_epoch)

        return cls(
            session_id,
            endpoints,
            created_epoch,
            created_datetime,
            prompt_template,
            context_strategy,
            cs_num_of_prev_prompts,
            attack_module,
            metric,
            system_prompt,
        )

    def check_type(self, checked_attribute: Any, expected_type: type) -> Any:
        """
        Checks if the type of the given attribute matches the expected type.

        Args:
            checked_attribute (Any): The attribute to be checked.
            expected_type (type): The expected type of the attribute.

        Returns:
            Any: The checked attribute if its type matches the expected type.

        Raises:
            TypeError: If the type of the checked attribute does not match the expected type.
        """
        if not isinstance(checked_attribute, expected_type):
            raise TypeError(
                f"Expected type for {checked_attribute} is {expected_type}, but got {type(checked_attribute)}"
            )
        return checked_attribute
144
+
145
+
146
+ class Session:
147
# Value used for cs_num_of_prev_prompts when runner_args does not provide one.
DEFAULT_CONTEXT_STRATEGY_PROMPT = 5

# DDL for the single-row-per-session metadata table. The column order matches the
# positional order produced by SessionMetadata.to_tuple() and unpacked by from_tuple().
# NOTE(review): created_epoch is declared INTEGER while SessionMetadata type-checks it as
# float - confirm which type is intended.
sql_create_session_metadata_table = """
    CREATE TABLE IF NOT EXISTS session_metadata_table (
    session_id text PRIMARY KEY NOT NULL,
    endpoints text NOT NULL,
    created_epoch INTEGER NOT NULL,
    created_datetime text NOT NULL,
    prompt_template text,
    context_strategy text,
    cs_num_of_prev_prompts int,
    attack_module text,
    metric text,
    system_prompt text
    );
"""

# DDL template for a chat history table. The {} placeholder is the table name; __init__
# fills it with an endpoint id whose dashes have been replaced by underscores.
sql_create_chat_history_table = """
    CREATE TABLE IF NOT EXISTS {} (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    connection_id text NOT NULL,
    context_strategy text,
    prompt_template text,
    attack_module text,
    metric text,
    prompt text NOT NULL,
    prepared_prompt text NOT NULL,
    system_prompt text,
    predicted_result text NOT NULL,
    duration text NOT NULL,
    prompt_time text NOT NULL
    );
"""

# Parameterised INSERT taking the 10 values of SessionMetadata.to_tuple().
sql_create_session_metadata_record = """
    INSERT INTO session_metadata_table (
    session_id,endpoints,created_epoch,created_datetime,prompt_template,context_strategy,cs_num_of_prev_prompts,
    attack_module, metric, system_prompt) VALUES(?,?,?,?,?,?,?,?,?,?)
"""

# Reads every metadata row; callers in this class only use the first record returned.
sql_read_session_metadata = """
    SELECT * from session_metadata_table
"""

# UPDATE template. The {} placeholder is a column name supplied by the update_* methods;
# the two bound parameters are (new value, session_id).
sql_update_session_metadata_field = """
    UPDATE session_metadata_table SET {}=? WHERE session_id=?
"""

# DROP template; the {} placeholder is presumably a table name - its caller is not visible here.
sql_drop_table = """
    DROP TABLE IF EXISTS {}
"""
197
+
198
def __init__(
    self,
    runner_id: str,
    runner_type: RunnerType,
    runner_args: dict,
    database_instance: Any | None,
    endpoints: list[str],
    results_file_path: str,
    progress_callback_func: Callable | None = None,
):
    """
    Sets up a session for a runner: stores the runner settings, verifies the referenced
    resources, and either restores the session metadata already stored in the database or
    creates the metadata record together with one chat history table per endpoint.

    Args:
        runner_id (str): The unique identifier for the runner. It must already be in slug form.
        runner_type (RunnerType): The type of runner being used.
        runner_args (dict): A dictionary of arguments specific to the runner.
        database_instance (Any | None): The database instance to connect to.
        endpoints (list[str]): A list of endpoint identifiers.
        results_file_path (str): The file path where results should be stored.
        progress_callback_func (Callable | None): An optional callback function for progress updates.

    Raises:
        RuntimeError: If runner_id differs from its slugified form, or if no database
            instance is provided.
    """
    created_epoch = time.time()
    created_datetime = datetime.fromtimestamp(created_epoch).strftime("%Y%m%d-%H%M%S")

    # The runner id is only accepted when slugifying it is a no-op.
    slugified_id = slugify(runner_id, lowercase=True)
    self.runner_id = slugified_id
    if slugified_id != runner_id:
        raise RuntimeError("[Session] Failed to initialise Session. Invalid Runner ID.")

    self.runner_args = runner_args
    self.runner_type = runner_type
    self.results_file_path = results_file_path
    self.progress_callback_func = progress_callback_func
    self.database_instance = database_instance

    self.red_teaming_progress = RedTeamingProgress(
        self.runner_id, self.runner_args, self.progress_callback_func
    )
    self.cancel_event = asyncio.Event()

    # Optional session settings; anything missing falls back to an empty string, except
    # the number of previous prompts which has a class-level default.
    settings = self.runner_args
    prompt_template = settings.get("prompt_template", "")
    context_strategy = settings.get("context_strategy", "")
    cs_num_of_prev_prompts = settings.get(
        "cs_num_of_prev_prompts", Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
    )
    attack_module = settings.get("attack_module", "")
    system_prompt = settings.get("system_prompt", "")
    metric_id = settings.get("metric_id", "")

    # check_file_exists is defined outside the visible part of this class; presumably it
    # rejects identifiers whose backing file is missing - verify against its definition.
    resources_to_check = (
        (EnvVariables.PROMPT_TEMPLATES.name, prompt_template, "Prompt Template", "json"),
        (EnvVariables.CONTEXT_STRATEGY.name, context_strategy, "Context Strategy", "py"),
        (EnvVariables.ATTACK_MODULES.name, attack_module, "Attack Module", "py"),
        (EnvVariables.METRICS.name, metric_id, "Metric", "py"),
    )
    for env_name, resource_id, label, extension in resources_to_check:
        self.check_file_exists(env_name, resource_id, label, extension)

    if not self.database_instance:
        raise RuntimeError(
            "[Session] Failed to initialise Session. No database instance provided."
        )

    db = self.database_instance

    # create session metadata table if it does not exist
    if not Storage.check_database_table_exists(db, "session_metadata_table"):
        Storage.create_database_table(db, Session.sql_create_session_metadata_table)

    existing_records = Storage.read_database_records(db, Session.sql_read_session_metadata)

    # an existing record means the session is being resumed
    if existing_records:
        print("[Session] Session already exists.")
        self.session_metadata = SessionMetadata.from_tuple(existing_records[0])
        return

    print("[Session] Creating new session.")
    self.session_metadata = SessionMetadata(
        runner_id,
        endpoints,
        created_epoch,
        created_datetime,
        prompt_template,
        context_strategy,
        cs_num_of_prev_prompts,
        attack_module,
        metric_id,
        system_prompt,
    )

    # create chat history table for each endpoint; dashes are swapped for underscores
    # before the endpoint id is used as the table name
    for endpoint in endpoints:
        table_name = endpoint.replace("-", "_")
        Storage.create_database_table(
            db, Session.sql_create_chat_history_table.format(table_name)
        )

    Storage.create_database_record(
        db,
        self.session_metadata.to_tuple(),
        Session.sql_create_session_metadata_record,
    )
325
+
326
@staticmethod
def load(database_instance: DBInterface | None) -> dict | None:
    """
    Loads the stored session metadata from the given database.

    The first record of session_metadata_table is converted into a SessionMetadata object
    and returned in its dictionary form.

    Parameters:
        database_instance (DBInterface | None): The database accessor instance.

    Returns:
        dict | None: The session metadata as a dictionary, or None when the database has no
        session_metadata_table (i.e. the runner does not have a session).

    Raises:
        RuntimeError: If no database instance is provided, or if the metadata table exists
            but no record could be read from it.
    """
    if not database_instance:
        raise RuntimeError("[Session] Runner instance database not provided.")

    # runner does not have session, return None
    has_metadata_table = Storage.check_database_table_exists(
        database_instance, "session_metadata_table"
    )
    if not has_metadata_table:
        return None

    records = Storage.read_database_records(
        database_instance, Session.sql_read_session_metadata
    )
    if not records:
        raise RuntimeError("[Session] Failed to get Session metadata.")

    # tuple -> SessionMetadata -> dict
    return SessionMetadata.from_tuple(records[0]).to_dict()
366
+
367
async def run(self) -> list | None:
    """
    Asynchronously executes the session run process.

    The runner processing module named by runner_args["runner_processing_module"] is loaded
    and instantiated, then its generate() coroutine is awaited with the session's arguments,
    database, metadata, red teaming type, progress tracker and cancel event. Errors in either
    step are printed and re-raised.

    Returns:
        The value returned by the processing module's generate() call.

    Raises:
        RuntimeError: If the processing module name is missing or the module cannot be loaded.
        Exception: Any error raised while loading or running the processing module.
    """
    # ------------------------------------------------------------------------------
    # Part 1: Get asyncio running loop
    # ------------------------------------------------------------------------------
    print("[Session] Part 1: Loading asyncio running loop...")
    event_loop = asyncio.get_running_loop()

    # ------------------------------------------------------------------------------
    # Part 2: Load runner processing module
    # ------------------------------------------------------------------------------
    print("[Session] Part 2: Loading runner processing module...")
    part_started = time.perf_counter()
    processing_module = None
    try:
        module_name = self.runner_args.get("runner_processing_module", None)
        if not module_name:
            raise RuntimeError(
                f"Failed to get runner processing module name: {module_name}"
            )

        # get_instance hands back something callable (or a falsy value on failure);
        # calling it produces the module instance used in Part 3.
        module_factory = get_instance(
            module_name,
            Storage.get_filepath(EnvVariables.RUNNERS_MODULES.name, module_name, "py"),
        )
        if not module_factory:
            raise RuntimeError(
                f"Unable to get defined runner module instance - {module_factory}"
            )
        processing_module = module_factory()

    except Exception as e:
        print(
            f"[Session] Failed to load runner processing module in Part 2 due to error: {str(e)}"
        )
        raise e

    finally:
        print(
            f"[Session] Loading runner processing module took {(time.perf_counter() - part_started):.4f}s"
        )

    # ------------------------------------------------------------------------------
    # Part 3: Run runner processing module
    # ------------------------------------------------------------------------------
    print("[Session] Part 3: Running runner processing module...")
    part_started = time.perf_counter()
    results = {}

    try:
        if not processing_module:
            raise RuntimeError("Failed to initialise runner module instance.")

        results = await processing_module.generate(  # type: ignore ; ducktyping
            event_loop,
            self.runner_args,
            self.database_instance,
            self.session_metadata,
            self.check_redteaming_type(),
            self.red_teaming_progress,
            self.cancel_event,
        )

    except Exception as e:
        print(
            f"[Session] Failed to run runner processing module in Part 3 due to error: {str(e)}"
        )
        raise e

    finally:
        # NOTE(review): the status is set to COMPLETED whether or not generate() succeeded.
        self.red_teaming_progress.status = RunStatus.COMPLETED
        # Progress is only pushed out for automated red teaming.
        if self.check_redteaming_type() == RedTeamingType.AUTOMATED:
            self.red_teaming_progress.notify_progress()
        print(
            f"[Session] Running runner processing module took {(time.perf_counter() - part_started):.4f}s"
        )

    # ------------------------------------------------------------------------------
    # Part 4: Wrap up run
    # ------------------------------------------------------------------------------
    print("[Session] Part 4: Wrap up run...")
    return results
467
+
468
def cancel(self) -> None:
    """
    Requests cancellation of the automated red teaming process.

    The session's cancel_event is set; the same event object is handed to the runner
    processing module in run(), which can check it to stop gracefully.

    Returns:
        None
    """
    print("[Session] Cancelling automated red teaming...")
    stop_signal = self.cancel_event
    stop_signal.set()
481
+
482
def check_redteaming_type(self) -> RedTeamingType:
    """
    Determines the red teaming type from the runner arguments.

    A non-None "attack_strategies" entry means automated red teaming and takes precedence;
    otherwise a non-None "manual_rt_args" entry means manual red teaming.

    Returns:
        RedTeamingType: The type of red teaming strategy.

    Raises:
        RuntimeError: If neither kind of red teaming argument is present.
    """
    arguments = self.runner_args

    # A missing key and an explicit None are treated the same way.
    if arguments.get("attack_strategies") is not None:
        return RedTeamingType.AUTOMATED
    if arguments.get("manual_rt_args") is not None:
        return RedTeamingType.MANUAL
    raise RuntimeError("Missing red teaming arguments.")
504
+
505
@staticmethod
def update_context_strategy(
    db_instance: DBInterface | None, runner_id: str, context_strategy: str
) -> bool:
    """
    Stores a new context strategy in the session metadata of the given runner.

    An empty context_strategy is written without any existence check.

    Args:
        db_instance (DBInterface | None): The database instance to update the context strategy in.
        runner_id (str): The ID of the runner.
        context_strategy (str): The name of the context strategy to be used.

    Returns:
        bool: True once the update has been issued.

    Raises:
        RuntimeError: If the database instance is not provided or if the context strategy does not exist.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    # Only a non-empty name is looked up in storage.
    if context_strategy:
        strategy_found = Storage.is_object_exists(
            EnvVariables.CONTEXT_STRATEGY.name, context_strategy, "py"
        )
        if not strategy_found:
            raise RuntimeError(
                f"[Session] Context Strategy {context_strategy} does not exist."
            )

    update_sql = Session.sql_update_session_metadata_field.format("context_strategy")
    Storage.update_database_record(db_instance, (context_strategy, runner_id), update_sql)
    return True
538
+
539
@staticmethod
def update_cs_num_of_prev_prompts(
    db_instance: DBInterface | None, runner_id: str, cs_num_of_prev_prompts: int
) -> bool:
    """
    Stores a new number of previous prompts in the session metadata of the given runner.

    Args:
        db_instance (DBInterface | None): The database instance to update the number of previous prompts in.
        runner_id (str): The ID of the runner.
        cs_num_of_prev_prompts (int): The new number of previous prompts to be used.

    Returns:
        bool: True once the update has been issued.

    Raises:
        RuntimeError: If the database instance is not provided.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    update_sql = Session.sql_update_session_metadata_field.format("cs_num_of_prev_prompts")
    bound_values = (cs_num_of_prev_prompts, runner_id)
    Storage.update_database_record(db_instance, bound_values, update_sql)
    return True
565
+
566
@staticmethod
def update_prompt_template(
    db_instance: DBInterface | None, runner_id: str, prompt_template: str
) -> bool:
    """
    Stores a new prompt template in the session metadata of the given runner.

    An empty prompt_template is written without any existence check.

    Args:
        db_instance (DBInterface | None): The database instance to update the prompt template in.
        runner_id (str): The ID of the runner.
        prompt_template (str): The new prompt template to be used.

    Returns:
        bool: True once the update has been issued.

    Raises:
        RuntimeError: If the database instance is not provided or if the prompt template does not exist.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    # Short-circuits so that storage is only consulted for a non-empty template name.
    template_missing = bool(prompt_template) and not Storage.is_object_exists(
        EnvVariables.PROMPT_TEMPLATES.name, prompt_template, "json"
    )
    if template_missing:
        raise RuntimeError(f"[Session] Prompt Template {prompt_template} does not exist.")

    Storage.update_database_record(
        db_instance,
        (prompt_template, runner_id),
        Session.sql_update_session_metadata_field.format("prompt_template"),
    )
    return True
599
+
600
@staticmethod
def update_metric(
    db_instance: DBInterface | None, runner_id: str, metric_id: str
) -> bool:
    """
    Stores a new metric in the session metadata of the given runner.

    An empty metric_id is written without any existence check.

    Args:
        db_instance (DBInterface | None): The database instance to update the metric in.
        runner_id (str): The ID of the runner.
        metric_id (str): The new metric to be used.

    Returns:
        bool: True once the update has been issued.

    Raises:
        RuntimeError: If the database instance is not provided or if the metric does not exist.
    """
    if not db_instance:
        raise RuntimeError("[Session] Database instance not provided.")

    # Only a non-empty metric id is looked up in storage.
    if metric_id:
        metric_found = Storage.is_object_exists(EnvVariables.METRICS.name, metric_id, "py")
        if not metric_found:
            raise RuntimeError(f"[Session] Metric {metric_id} does not exist.")

    update_sql = Session.sql_update_session_metadata_field.format("metric")
    Storage.update_database_record(db_instance, (metric_id, runner_id), update_sql)
    return True
631
+
632
+ @staticmethod
633
+ def update_system_prompt(
634
+ db_instance: DBInterface | None, runner_id: str, system_prompt: str
635
+ ) -> bool:
636
+ """
637
+ Updates the system prompt in the database for the specified runner.
638
+
639
+ Args:
640
+ db_instance (DBInterface | None): The database instance to update the system prompt in.
641
+ runner_id (str): The ID of the runner.
642
+ system_prompt (str): The new system prompt to be used.
643
+
644
+ Returns:
645
+ bool: The status on whether the system prompt is updated successfully.
646
+
647
+ Raises:
648
+ RuntimeError: If the database instance is not provided.
649
+ """
650
+ if not db_instance:
651
+ raise RuntimeError("[Session] Database instance not provided.")
652
+ Storage.update_database_record(
653
+ db_instance,
654
+ (system_prompt, runner_id),
655
+ Session.sql_update_session_metadata_field.format("system_prompt"),
656
+ )
657
+ return True
658
+
659
+ @staticmethod
660
+ def update_attack_module(
661
+ db_instance: DBInterface | None, runner_id: str, attack_module_id: str
662
+ ) -> bool:
663
+ """
664
+ Updates the attack module in the database for the specified runner.
665
+
666
+ Args:
667
+ db_instance (DBInterface | None): The database instance to update the attack module in.
668
+ runner_id (str): The ID of the runner.
669
+ attack_module_id (str): The new attack module to be used.
670
+
671
+ Returns:
672
+ bool: The status on whether the attack module is updated successfully.
673
+
674
+ Raises:
675
+ RuntimeError: If the database instance is not provided or if the attack module does not exist.
676
+ """
677
+ if not db_instance:
678
+ raise RuntimeError("[Session] Database instance not provided.")
679
+ if attack_module_id and not Storage.is_object_exists(
680
+ EnvVariables.ATTACK_MODULES.name, attack_module_id, "py"
681
+ ):
682
+ raise RuntimeError(
683
+ f"[Session] Attack Module {attack_module_id} does not exist."
684
+ )
685
+ else:
686
+ Storage.update_database_record(
687
+ db_instance,
688
+ (attack_module_id, runner_id),
689
+ Session.sql_update_session_metadata_field.format("attack_module"),
690
+ )
691
+ return True
692
+
693
+ @staticmethod
694
+ def delete(database_instance: DBInterface | None) -> bool:
695
+ """
696
+ Deletes the session metadata and associated endpoint tables from the database.
697
+
698
+ Args:
699
+ database_instance (DBInterface | None): The database instance to delete the session from.
700
+
701
+ Returns:
702
+ bool: The status on whether the session is deleted successfully.
703
+
704
+ Raises:
705
+ RuntimeError: If the database instance is not provided or if failed to get session metadata.
706
+ """
707
+ if not database_instance:
708
+ raise RuntimeError("[Session] Database instance not provided.")
709
+
710
+ session_metadata_info = Storage.read_database_records(
711
+ database_instance,
712
+ Session.sql_read_session_metadata,
713
+ )
714
+ if not session_metadata_info:
715
+ raise RuntimeError("[Session] Failed to get Session metadata.")
716
+
717
+ session_metadata_obj = SessionMetadata.from_tuple(session_metadata_info[0])
718
+ Storage.delete_database_table(
719
+ database_instance, Session.sql_drop_table.format("session_metadata_table")
720
+ )
721
+ for endpoint in session_metadata_obj.endpoints:
722
+ endpoint = endpoint.replace("-", "_")
723
+ Storage.delete_database_table(
724
+ database_instance, Session.sql_drop_table.format(endpoint)
725
+ )
726
+ return True
727
+
728
+ @staticmethod
729
+ def get_session_chats(database_instance: DBInterface | None) -> dict:
730
+ """
731
+ Retrieves the chat history for all endpoints in a session.
732
+
733
+ Args:
734
+ database_instance (DBInterface | None): The database instance to retrieve the chat history from.
735
+
736
+ Raises:
737
+ RuntimeError: If the database instance is not provided.
738
+
739
+ Returns:
740
+ dict: A dictionary where the keys are endpoint IDs and the values are lists of chat history
741
+ for each endpoint.
742
+ """
743
+ if not database_instance:
744
+ raise RuntimeError("[Session] Database instance not provided.")
745
+
746
+ session_metadata = Session.load(database_instance)
747
+ chats = {}
748
+ if session_metadata is not None and "endpoints" in session_metadata:
749
+ endpoint_list = session_metadata.get("endpoints", [])
750
+ for endpoint_id in endpoint_list:
751
+ list_of_chats_from_one_ep = Chat.load_chat_history(
752
+ database_instance, endpoint_id.replace("-", "_")
753
+ )
754
+ chats.update({endpoint_id: list_of_chats_from_one_ep})
755
+ return chats
756
+
757
def check_file_exists(
    self, env_var_name: str, file_name: str, file_type: str, extension: str
) -> None:
    """
    Verifies that a named file is present in storage.

    An empty file name is accepted without any lookup.

    Args:
        env_var_name (str): The environment variable name passed to the storage lookup.
        file_name (str): The name of the file to look for.
        file_type (str): A label for the kind of file; only used in the error message.
        extension (str): The extension of the file.

    Raises:
        RuntimeError: If a non-empty file name is not found in storage.
    """
    # Nothing to verify when no file name was supplied.
    if not file_name:
        return

    if Storage.is_object_exists(env_var_name, file_name, extension):
        return

    raise RuntimeError(f"[Session] {file_type} {file_name} does not exist.")