aiverify-moonshot 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163) hide show
  1. aiverify_moonshot-0.4.0.dist-info/METADATA +249 -0
  2. aiverify_moonshot-0.4.0.dist-info/RECORD +163 -0
  3. aiverify_moonshot-0.4.0.dist-info/WHEEL +4 -0
  4. aiverify_moonshot-0.4.0.dist-info/licenses/AUTHORS.md +5 -0
  5. aiverify_moonshot-0.4.0.dist-info/licenses/LICENSE.md +201 -0
  6. aiverify_moonshot-0.4.0.dist-info/licenses/NOTICES.md +3340 -0
  7. moonshot/__init__.py +0 -0
  8. moonshot/__main__.py +198 -0
  9. moonshot/api.py +155 -0
  10. moonshot/integrations/__init__.py +0 -0
  11. moonshot/integrations/cli/__init__.py +0 -0
  12. moonshot/integrations/cli/__main__.py +25 -0
  13. moonshot/integrations/cli/active_session_cfg.py +1 -0
  14. moonshot/integrations/cli/benchmark/__init__.py +0 -0
  15. moonshot/integrations/cli/benchmark/benchmark.py +186 -0
  16. moonshot/integrations/cli/benchmark/cookbook.py +545 -0
  17. moonshot/integrations/cli/benchmark/datasets.py +164 -0
  18. moonshot/integrations/cli/benchmark/metrics.py +141 -0
  19. moonshot/integrations/cli/benchmark/recipe.py +598 -0
  20. moonshot/integrations/cli/benchmark/result.py +216 -0
  21. moonshot/integrations/cli/benchmark/run.py +140 -0
  22. moonshot/integrations/cli/benchmark/runner.py +174 -0
  23. moonshot/integrations/cli/cli.py +64 -0
  24. moonshot/integrations/cli/common/__init__.py +0 -0
  25. moonshot/integrations/cli/common/common.py +72 -0
  26. moonshot/integrations/cli/common/connectors.py +325 -0
  27. moonshot/integrations/cli/common/display_helper.py +42 -0
  28. moonshot/integrations/cli/common/prompt_template.py +94 -0
  29. moonshot/integrations/cli/initialisation/__init__.py +0 -0
  30. moonshot/integrations/cli/initialisation/initialisation.py +14 -0
  31. moonshot/integrations/cli/redteam/__init__.py +0 -0
  32. moonshot/integrations/cli/redteam/attack_module.py +70 -0
  33. moonshot/integrations/cli/redteam/context_strategy.py +147 -0
  34. moonshot/integrations/cli/redteam/prompt_template.py +67 -0
  35. moonshot/integrations/cli/redteam/redteam.py +90 -0
  36. moonshot/integrations/cli/redteam/session.py +467 -0
  37. moonshot/integrations/web_api/.env.dev +7 -0
  38. moonshot/integrations/web_api/__init__.py +0 -0
  39. moonshot/integrations/web_api/__main__.py +56 -0
  40. moonshot/integrations/web_api/app.py +125 -0
  41. moonshot/integrations/web_api/container.py +146 -0
  42. moonshot/integrations/web_api/log/.gitkeep +0 -0
  43. moonshot/integrations/web_api/logging_conf.py +114 -0
  44. moonshot/integrations/web_api/routes/__init__.py +0 -0
  45. moonshot/integrations/web_api/routes/attack_modules.py +66 -0
  46. moonshot/integrations/web_api/routes/benchmark.py +116 -0
  47. moonshot/integrations/web_api/routes/benchmark_result.py +175 -0
  48. moonshot/integrations/web_api/routes/context_strategy.py +129 -0
  49. moonshot/integrations/web_api/routes/cookbook.py +225 -0
  50. moonshot/integrations/web_api/routes/dataset.py +120 -0
  51. moonshot/integrations/web_api/routes/endpoint.py +282 -0
  52. moonshot/integrations/web_api/routes/metric.py +78 -0
  53. moonshot/integrations/web_api/routes/prompt_template.py +128 -0
  54. moonshot/integrations/web_api/routes/recipe.py +219 -0
  55. moonshot/integrations/web_api/routes/redteam.py +609 -0
  56. moonshot/integrations/web_api/routes/runner.py +239 -0
  57. moonshot/integrations/web_api/schemas/__init__.py +0 -0
  58. moonshot/integrations/web_api/schemas/benchmark_runner_dto.py +13 -0
  59. moonshot/integrations/web_api/schemas/cookbook_create_dto.py +19 -0
  60. moonshot/integrations/web_api/schemas/cookbook_response_model.py +9 -0
  61. moonshot/integrations/web_api/schemas/dataset_response_dto.py +9 -0
  62. moonshot/integrations/web_api/schemas/endpoint_create_dto.py +21 -0
  63. moonshot/integrations/web_api/schemas/endpoint_response_model.py +11 -0
  64. moonshot/integrations/web_api/schemas/prompt_response_model.py +14 -0
  65. moonshot/integrations/web_api/schemas/prompt_template_response_model.py +10 -0
  66. moonshot/integrations/web_api/schemas/recipe_create_dto.py +32 -0
  67. moonshot/integrations/web_api/schemas/recipe_response_model.py +7 -0
  68. moonshot/integrations/web_api/schemas/session_create_dto.py +16 -0
  69. moonshot/integrations/web_api/schemas/session_prompt_dto.py +7 -0
  70. moonshot/integrations/web_api/schemas/session_response_model.py +38 -0
  71. moonshot/integrations/web_api/services/__init__.py +0 -0
  72. moonshot/integrations/web_api/services/attack_module_service.py +34 -0
  73. moonshot/integrations/web_api/services/auto_red_team_test_manager.py +86 -0
  74. moonshot/integrations/web_api/services/auto_red_team_test_state.py +57 -0
  75. moonshot/integrations/web_api/services/base_service.py +8 -0
  76. moonshot/integrations/web_api/services/benchmark_result_service.py +25 -0
  77. moonshot/integrations/web_api/services/benchmark_test_manager.py +106 -0
  78. moonshot/integrations/web_api/services/benchmark_test_state.py +56 -0
  79. moonshot/integrations/web_api/services/benchmarking_service.py +31 -0
  80. moonshot/integrations/web_api/services/context_strategy_service.py +22 -0
  81. moonshot/integrations/web_api/services/cookbook_service.py +194 -0
  82. moonshot/integrations/web_api/services/dataset_service.py +20 -0
  83. moonshot/integrations/web_api/services/endpoint_service.py +65 -0
  84. moonshot/integrations/web_api/services/metric_service.py +14 -0
  85. moonshot/integrations/web_api/services/prompt_template_service.py +39 -0
  86. moonshot/integrations/web_api/services/recipe_service.py +155 -0
  87. moonshot/integrations/web_api/services/runner_service.py +147 -0
  88. moonshot/integrations/web_api/services/session_service.py +350 -0
  89. moonshot/integrations/web_api/services/utils/exceptions_handler.py +41 -0
  90. moonshot/integrations/web_api/services/utils/results_formatter.py +47 -0
  91. moonshot/integrations/web_api/status_updater/interface/benchmark_progress_callback.py +14 -0
  92. moonshot/integrations/web_api/status_updater/interface/redteam_progress_callback.py +14 -0
  93. moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +72 -0
  94. moonshot/integrations/web_api/types/types.py +99 -0
  95. moonshot/src/__init__.py +0 -0
  96. moonshot/src/api/__init__.py +0 -0
  97. moonshot/src/api/api_connector.py +58 -0
  98. moonshot/src/api/api_connector_endpoint.py +162 -0
  99. moonshot/src/api/api_context_strategy.py +57 -0
  100. moonshot/src/api/api_cookbook.py +160 -0
  101. moonshot/src/api/api_dataset.py +46 -0
  102. moonshot/src/api/api_environment_variables.py +17 -0
  103. moonshot/src/api/api_metrics.py +51 -0
  104. moonshot/src/api/api_prompt_template.py +43 -0
  105. moonshot/src/api/api_recipe.py +182 -0
  106. moonshot/src/api/api_red_teaming.py +59 -0
  107. moonshot/src/api/api_result.py +84 -0
  108. moonshot/src/api/api_run.py +74 -0
  109. moonshot/src/api/api_runner.py +132 -0
  110. moonshot/src/api/api_session.py +290 -0
  111. moonshot/src/configs/__init__.py +0 -0
  112. moonshot/src/configs/env_variables.py +187 -0
  113. moonshot/src/connectors/__init__.py +0 -0
  114. moonshot/src/connectors/connector.py +327 -0
  115. moonshot/src/connectors/connector_prompt_arguments.py +17 -0
  116. moonshot/src/connectors_endpoints/__init__.py +0 -0
  117. moonshot/src/connectors_endpoints/connector_endpoint.py +211 -0
  118. moonshot/src/connectors_endpoints/connector_endpoint_arguments.py +54 -0
  119. moonshot/src/cookbooks/__init__.py +0 -0
  120. moonshot/src/cookbooks/cookbook.py +225 -0
  121. moonshot/src/cookbooks/cookbook_arguments.py +34 -0
  122. moonshot/src/datasets/__init__.py +0 -0
  123. moonshot/src/datasets/dataset.py +255 -0
  124. moonshot/src/datasets/dataset_arguments.py +50 -0
  125. moonshot/src/metrics/__init__.py +0 -0
  126. moonshot/src/metrics/metric.py +192 -0
  127. moonshot/src/metrics/metric_interface.py +95 -0
  128. moonshot/src/prompt_templates/__init__.py +0 -0
  129. moonshot/src/prompt_templates/prompt_template.py +103 -0
  130. moonshot/src/recipes/__init__.py +0 -0
  131. moonshot/src/recipes/recipe.py +340 -0
  132. moonshot/src/recipes/recipe_arguments.py +111 -0
  133. moonshot/src/redteaming/__init__.py +0 -0
  134. moonshot/src/redteaming/attack/__init__.py +0 -0
  135. moonshot/src/redteaming/attack/attack_module.py +618 -0
  136. moonshot/src/redteaming/attack/attack_module_arguments.py +44 -0
  137. moonshot/src/redteaming/attack/context_strategy.py +131 -0
  138. moonshot/src/redteaming/context_strategy/__init__.py +0 -0
  139. moonshot/src/redteaming/context_strategy/context_strategy_interface.py +46 -0
  140. moonshot/src/redteaming/session/__init__.py +0 -0
  141. moonshot/src/redteaming/session/chat.py +209 -0
  142. moonshot/src/redteaming/session/red_teaming_progress.py +128 -0
  143. moonshot/src/redteaming/session/red_teaming_type.py +6 -0
  144. moonshot/src/redteaming/session/session.py +775 -0
  145. moonshot/src/results/__init__.py +0 -0
  146. moonshot/src/results/result.py +119 -0
  147. moonshot/src/results/result_arguments.py +44 -0
  148. moonshot/src/runners/__init__.py +0 -0
  149. moonshot/src/runners/runner.py +476 -0
  150. moonshot/src/runners/runner_arguments.py +46 -0
  151. moonshot/src/runners/runner_type.py +6 -0
  152. moonshot/src/runs/__init__.py +0 -0
  153. moonshot/src/runs/run.py +344 -0
  154. moonshot/src/runs/run_arguments.py +162 -0
  155. moonshot/src/runs/run_progress.py +145 -0
  156. moonshot/src/runs/run_status.py +10 -0
  157. moonshot/src/storage/__init__.py +0 -0
  158. moonshot/src/storage/db_interface.py +128 -0
  159. moonshot/src/storage/io_interface.py +31 -0
  160. moonshot/src/storage/storage.py +525 -0
  161. moonshot/src/utils/__init__.py +0 -0
  162. moonshot/src/utils/import_modules.py +96 -0
  163. moonshot/src/utils/timeit.py +25 -0
@@ -0,0 +1,618 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import abstractmethod
4
+ from datetime import datetime
5
+ from pathlib import Path
6
+ from typing import Any, AsyncGenerator
7
+
8
+ from jinja2 import Template
9
+ from pydantic import BaseModel
10
+
11
+ from moonshot.src.configs.env_variables import EnvVariables
12
+ from moonshot.src.connectors.connector import Connector
13
+ from moonshot.src.connectors.connector_prompt_arguments import ConnectorPromptArguments
14
+ from moonshot.src.connectors_endpoints.connector_endpoint import ConnectorEndpoint
15
+ from moonshot.src.metrics.metric import Metric
16
+ from moonshot.src.redteaming.attack.attack_module_arguments import AttackModuleArguments
17
+ from moonshot.src.redteaming.attack.context_strategy import ContextStrategy
18
+ from moonshot.src.runs.run_status import RunStatus
19
+ from moonshot.src.storage.storage import Storage
20
+ from moonshot.src.utils.import_modules import get_instance
21
+
22
+
23
class AttackModule:
    """
    Base class for red teaming attack modules.

    Concrete attack modules are looked up by ID through ``load`` and are expected to provide
    ``get_metadata`` and ``execute``. This base class supplies the shared plumbing: preparing prompts
    (context strategy and prompt template), sending them to the target connectors, recording every
    exchange in the session database, and maintaining a metadata cache of the available modules.
    """

    cache_name = "cache"
    cache_extension = "json"
    # The placeholder is filled with the per-endpoint table name (see _write_record_to_db); the values
    # themselves are bound as parameters in the column order produced by RedTeamingPromptArguments.to_tuple.
    sql_create_chat_record = """
        INSERT INTO {} (connection_id,context_strategy,prompt_template,attack_module,
        metric,prompt,prepared_prompt,system_prompt,predicted_result,duration,prompt_time)VALUES(?,?,?,?,?,?,?,?,?,?,?)
        """

    def __init__(self, am_id: str, am_arguments: AttackModuleArguments | None = None):
        """
        Initialises the attack module.

        Args:
            am_id (str): The unique identifier of the attack module.
            am_arguments (AttackModuleArguments | None): The runtime arguments of the attack. When None,
                only ``id`` is set on the instance; none of the other attributes are created.
        """
        self.id = am_id
        if am_arguments is not None:
            self.connector_ids = am_arguments.connector_ids
            self.prompt_templates = am_arguments.prompt_templates
            self.prompt = am_arguments.prompt
            self.system_prompt = am_arguments.system_prompt
            self.metric_ids = am_arguments.metric_ids
            self.context_strategy_info = am_arguments.context_strategy_info
            self.db_instance = am_arguments.db_instance
            self.red_teaming_progress = am_arguments.red_teaming_progress
            self.cancel_event = am_arguments.cancel_event
            self.params = am_arguments.params

    @classmethod
    def load(
        cls, am_id: str, am_arguments: AttackModuleArguments | None = None
    ) -> AttackModule:
        """
        Retrieves an attack module instance by its ID.

        The attack module is resolved with ``get_instance`` from the ``.py`` file that storage reports for
        the given ID, and is then instantiated with the ID and the supplied arguments.

        Args:
            am_id (str): The unique identifier of the attack module to be retrieved.
            am_arguments (AttackModuleArguments | None): The arguments passed on to the attack module's
                constructor. Defaults to None.

        Returns:
            AttackModule: The retrieved attack module instance.

        Raises:
            RuntimeError: If the attack module instance does not exist.
        """
        attack_module_inst = get_instance(
            am_id,
            Storage.get_filepath(EnvVariables.ATTACK_MODULES.name, am_id, "py"),
        )

        if attack_module_inst:
            return attack_module_inst(am_id, am_arguments)

        raise RuntimeError(f"Unable to get defined attack module instance - {am_id}")

    # NOTE: this class does not derive from abc.ABC, so @abstractmethod documents intent only and
    # does not prevent instantiation.
    @abstractmethod
    def get_metadata(self) -> dict:
        """
        Get metadata for the attack module.

        Returns:
            dict: A dictionary containing the metadata of the attack module.
        """
        pass

    async def _generate_prompts(
        self, prompt: str, target_llm_connector_id: str
    ) -> AsyncGenerator[RedTeamingPromptArguments, None]:
        """
        Generates prompts for red teaming.

        The prompt is first processed with the first configured context strategy (if any) and then rendered
        through the first configured prompt template (if any). Exactly one RedTeamingPromptArguments is
        yielded, carrying the prepared prompt.

        Args:
            prompt (str): The prompt to be processed and sent to the target LLM.
            target_llm_connector_id (str): The unique identifier of the target LLM connector.

        Yields:
            RedTeamingPromptArguments: The generated prompt details. ``original_prompt`` is always the
                module's own ``self.prompt`` and ``start_time`` is left empty.
        """
        cs_id = ""
        if self.context_strategy_info:
            context_strategy_instance = self.context_strategy_instances[0]
            cs_id = context_strategy_instance.id
            num_of_prev_prompts = self.context_strategy_info[0].get(
                "num_of_prev_prompts"
            )
            prompt = ContextStrategy.process_prompt_cs(
                prompt,
                cs_id,
                self.db_instance,
                target_llm_connector_id,
                num_of_prev_prompts,
            )

        pt_id = ""
        if self.prompt_templates:
            # Only the first prompt template is applied.
            pt_id = self.prompt_templates[0]
            pt_info = Storage.read_object_with_iterator(
                EnvVariables.PROMPT_TEMPLATES.name,
                pt_id,
                "json",
                iterator_keys=["template"],
            )
            pt = next(pt_info["template"])
            jinja2_template = Template(pt)
            prompt = jinja2_template.render({"prompt": prompt})

        yield RedTeamingPromptArguments(
            conn_id=target_llm_connector_id,
            am_id=self.id,
            cs_id=cs_id,
            pt_id=pt_id,
            me_id=self.metric_ids[0] if self.metric_ids else "",
            original_prompt=self.prompt,
            system_prompt=self.system_prompt,
            start_time="",
            connector_prompt=ConnectorPromptArguments(
                prompt_index=0,
                prompt=prompt,
                target="",
            ),
        )

    async def _send_prompt_to_all_llm_default(self) -> list:
        """
        Asynchronously sends the module's default prompt to every configured LLM connector.

        NOTE: this method does not currently handle callbacks.

        For each connector the prompt is prepared with the context strategy and prompt template, sent to
        the connector, and the results are collected into a single list. Once the cancellation flag is
        observed, the method stops immediately and returns what has been collected so far.

        Returns:
            list: A list of consolidated results (RedTeamingPromptArguments) from all LLM connectors.
        """
        consolidated_result_list = []
        if not self.connector_ids:
            return consolidated_result_list

        # Async generators are lazy: nothing is sent until they are iterated below.
        generator_list = [
            self._generate_predictions(
                self._generate_prompts(self.prompt, target_llm_connector.id),
                target_llm_connector,
            )
            for target_llm_connector in self.connector_instances
        ]

        for generator in generator_list:
            async for result in generator:
                if self.cancel_event.is_set():
                    print("[Red Teaming] Cancellation flag is set. Cancelling task...")
                    # Return rather than break: a break would only leave the inner loop and the
                    # remaining generators would still be driven.
                    return consolidated_result_list
                consolidated_result_list.append(result)
        return consolidated_result_list

    async def _predict_and_record(
        self, prepared_prompt: str, target_llm_connector: Connector
    ) -> ConnectorPromptArguments:
        """
        Sends one prepared prompt to one connector and records the exchange.

        Updates the progress (if a progress object is set), obtains the prediction, reports the chat
        through the progress object and writes the chat record to the database.

        Args:
            prepared_prompt (str): The prompt text to send as-is.
            target_llm_connector (Connector): The connector to send the prompt to.

        Returns:
            ConnectorPromptArguments: The response returned by Connector.get_prediction.
        """
        if self.red_teaming_progress:
            self.red_teaming_progress.update_red_teaming_progress()

        new_prompt_info = ConnectorPromptArguments(
            prompt_index=1, prompt=prepared_prompt, target=""
        )
        start_time = datetime.now()
        response = await Connector.get_prediction(new_prompt_info, target_llm_connector)

        red_teaming_prompt_arguments = RedTeamingPromptArguments(
            conn_id=target_llm_connector.id,
            am_id=self.id,
            cs_id=self.context_strategy_instances[0].id
            if self.context_strategy_info
            else "",
            me_id=self.metric_ids[0] if self.metric_ids else "",
            pt_id=self.prompt_templates[0] if self.prompt_templates else "",
            original_prompt=self.prompt,  # original prompt
            system_prompt=self.system_prompt,  # system prompt
            start_time=str(start_time),
            connector_prompt=response,
        )

        # update callback arguments
        if self.red_teaming_progress:
            self.red_teaming_progress.update_red_teaming_chats(
                red_teaming_prompt_arguments.to_dict(), RunStatus.RUNNING
            )

        self._write_record_to_db(
            red_teaming_prompt_arguments.to_tuple(), target_llm_connector.id
        )
        return response

    async def _send_prompt_to_all_llm(self, list_of_prompts: list) -> list:
        """
        Asynchronously sends prompts to all LLM connectors.

        Every prompt in the list is sent to every connector, each exchange is recorded, and the responses
        are collected. Once the cancellation flag is observed, no further prompt is sent and the responses
        collected so far are returned.

        Args:
            list_of_prompts (list): A list of prompts to be sent to the LLM connectors.

        Returns:
            list: A list of consolidated responses from all LLM connectors.
        """
        consolidated_responses = []
        for prepared_prompt in list_of_prompts:
            for target_llm_connector in self.connector_instances:
                if self.cancel_event.is_set():
                    print("[Red Teaming] Cancellation flag is set. Cancelling task...")
                    # Return rather than break so the outer loop over prompts stops as well.
                    return consolidated_responses

                response = await self._predict_and_record(
                    prepared_prompt, target_llm_connector
                )
                consolidated_responses.append(response)

        return consolidated_responses

    async def _send_prompt_to_single_llm(
        self, list_of_prompts: list, target_llm_connector: Connector
    ) -> list:
        """
        Asynchronously sends prompts to a single LLM connector.

        Every prompt in the list is sent to the given connector, each exchange is recorded, and the
        responses are collected. Sending stops as soon as the cancellation flag is observed.

        Args:
            list_of_prompts (list): A list of prompts to be sent to the LLM connector.
            target_llm_connector (Connector): The target LLM connector to send the prompts to.

        Returns:
            list: A list of consolidated responses from the specified LLM connector.
        """
        consolidated_responses = []
        for prepared_prompt in list_of_prompts:
            if self.cancel_event.is_set():
                print("[Red Teaming] Cancellation flag is set. Cancelling task...")
                break

            response = await self._predict_and_record(
                prepared_prompt, target_llm_connector
            )
            consolidated_responses.append(response)

        return consolidated_responses

    def _write_record_to_db(
        self,
        chat_record_tuple: tuple,
        chat_record_id: str,
    ) -> None:
        """
        Writes the chat record to the database.

        Args:
            chat_record_tuple (tuple): A tuple containing the chat record information.
            chat_record_id (str): The ID used to derive the target table name; callers in this class
                pass the connector ID.
        """
        # The table name is the ID with dashes replaced by underscores. It is interpolated into the
        # SQL text, while the record values are passed separately to the storage layer.
        endpoint_id = chat_record_id.replace("-", "_")
        Storage.create_database_record(
            self.db_instance,
            chat_record_tuple,
            AttackModule.sql_create_chat_record.format(endpoint_id),
        )

    @abstractmethod
    async def execute(self) -> Any:
        """
        Houses the logic of the attack and is an entry point.
        * Do not change the name of this function in the attack module

        Returns:
            Any: the return type from the execution
        """
        pass

    async def _generate_predictions(
        self,
        gen_prompts_generator: AsyncGenerator[RedTeamingPromptArguments, None],
        llm_connector: Connector,
    ) -> AsyncGenerator[RedTeamingPromptArguments, None]:
        """
        Asynchronously generates predictions for the given prompts.

        Args:
            gen_prompts_generator (AsyncGenerator[RedTeamingPromptArguments, None]): An asynchronous generator
                yielding RedTeamingPromptArguments.
            llm_connector (Connector): The connector to the target LLM.

        Yields:
            RedTeamingPromptArguments: The prompt information updated with the prediction and the start time.
        """
        async for prompt_info in gen_prompts_generator:
            if self.cancel_event.is_set():
                print("[Red Teaming] Cancellation flag is set. Cancelling task...")
                break

            # Copy the incoming arguments, stamping the time at which the prediction starts.
            new_prompt_info = RedTeamingPromptArguments(
                conn_id=prompt_info.conn_id,
                am_id=prompt_info.am_id,
                cs_id=prompt_info.cs_id,
                me_id=prompt_info.me_id,
                pt_id=prompt_info.pt_id,
                original_prompt=self.prompt,
                system_prompt=prompt_info.system_prompt,
                connector_prompt=prompt_info.connector_prompt,
                start_time=str(datetime.now()),
            )

            # send processed prompt to llm and write record to db
            new_prompt_info.connector_prompt = await Connector.get_prediction(
                new_prompt_info.connector_prompt, llm_connector
            )

            self._write_record_to_db(new_prompt_info.to_tuple(), llm_connector.id)
            yield new_prompt_info

    def load_modules(self) -> None:
        """
        Loads connector, metric, and context strategy instances.

        ``metric_instances`` and ``context_strategy_instances`` are always defined after this call and
        are empty lists when nothing is configured.

        Raises:
            RuntimeError: If no connector endpoints are specified.
        """
        if not self.connector_ids:
            raise RuntimeError(
                "[Red Teaming] No connector endpoints specified for red teaming."
            )

        self.connector_instances = [
            Connector.create(ConnectorEndpoint.read(endpoint))
            for endpoint in self.connector_ids
        ]

        # Default to empty lists so that later attribute access never raises AttributeError.
        self.metric_instances = (
            [Metric.load(metric_id) for metric_id in self.metric_ids]
            if self.metric_ids
            else []
        )

        self.context_strategy_instances = (
            [
                ContextStrategy.load(context_strategy_info.get("context_strategy_id"))
                for context_strategy_info in self.context_strategy_info
            ]
            if self.context_strategy_info
            else []
        )

    @staticmethod
    def get_cache_information() -> dict:
        """
        Retrieves cache information from the storage.

        This is best-effort: any error while reading the cache is printed and swallowed, and an empty
        dictionary is returned instead. No exception propagates to the caller.

        Returns:
            dict: A dictionary containing the cache information, or an empty dictionary if an error occurs
                or if the cache information does not exist.
        """
        try:
            cache_info = Storage.read_object(
                EnvVariables.ATTACK_MODULES.name, AttackModule.cache_name, "json"
            )
            return cache_info if cache_info else {}
        except Exception as e:
            print(f"No previous cache information: {str(e)}")
            return {}

    @staticmethod
    def write_cache_information(cache_info: dict) -> None:
        """
        Writes the updated cache information to the storage.

        Args:
            cache_info (dict): The cache information to be written.

        Raises:
            Exception: If the cache cannot be written; the error is printed and re-raised.
        """
        try:
            Storage.create_object(
                obj_type=EnvVariables.ATTACK_MODULES.name,
                obj_id=AttackModule.cache_name,
                obj_info=cache_info,
                obj_extension=AttackModule.cache_extension,
            )
        except Exception as e:
            print(f"Failed to write cache information: {str(e)}")
            raise

    @staticmethod
    def get_available_items() -> tuple[list[str], list[dict]]:
        """
        Retrieves the list of available attack modules and their information.

        This method lists the ``.py`` objects in the attack modules storage, skips any entry whose path
        contains a double underscore, and refreshes the cache when at least one entry was missing or stale.

        Returns:
            tuple[list[str], list[dict]]: A tuple with two elements. The first element is a list
                of attack module IDs. The second element is a list of dictionaries, each containing
                information about an attack module.

        Raises:
            Exception: If any step fails; the error is printed and re-raised.
        """
        try:
            retn_ams = []
            retn_am_ids = []
            am_cache_info = AttackModule.get_cache_information()
            cache_needs_update = False
            ams = Storage.get_objects(EnvVariables.ATTACK_MODULES.name, "py")

            for am in ams:
                if "__" in am:
                    continue

                am_name = Path(am).stem
                am_info, cache_updated = AttackModule._get_or_update_attack_module_info(
                    am_name, am_cache_info
                )
                cache_needs_update = cache_needs_update or cache_updated

                retn_ams.append(am_info)
                retn_am_ids.append(am_name)

            # Write the cache once, after the loop, instead of once per refreshed module.
            if cache_needs_update:
                AttackModule.write_cache_information(am_cache_info)

            return retn_am_ids, retn_ams

        except Exception as e:
            print(f"Failed to get available attack modules: {str(e)}")
            raise

    @staticmethod
    def _get_or_update_attack_module_info(
        am_name: str, am_cache_info: dict
    ) -> tuple[dict, bool]:
        """
        Retrieves or updates the attack module information from the cache.

        If the cache holds an entry for the module whose stored hash equals the current file hash, the
        cached metadata (without the hash) is returned. Otherwise the module is loaded, its metadata is
        read, and ``am_cache_info`` is updated in place with the metadata and the new hash. A cache entry
        that is malformed (not a dict, or missing its hash) is treated as stale and regenerated.

        Args:
            am_name (str): The name of the attack module.
            am_cache_info (dict): A dictionary containing the cached attack module information.

        Returns:
            tuple[dict, bool]: A tuple containing the dictionary with the attack module information
                and a boolean indicating whether the cache was updated or not.
        """
        file_hash = Storage.get_file_hash(
            EnvVariables.ATTACK_MODULES.name, am_name, "py"
        )

        cached_entry = am_cache_info.get(am_name)
        cache_is_fresh = (
            isinstance(cached_entry, dict)
            and "hash" in cached_entry
            and cached_entry["hash"] == file_hash
        )

        if cache_is_fresh:
            am_metadata = cached_entry.copy()
            am_metadata.pop("hash", None)
            return am_metadata, False

        am_metadata = AttackModule.load(am_name).get_metadata()  # type: ignore ; ducktyping
        am_cache_info[am_name] = am_metadata.copy()
        am_cache_info[am_name]["hash"] = file_hash
        return am_metadata, True

    @staticmethod
    def delete(am_id: str) -> bool:
        """
        Deletes the specified attack module from storage.

        Args:
            am_id (str): The ID of the attack module to be deleted.

        Returns:
            bool: True if the attack module was successfully deleted.

        Raises:
            Exception: If an error occurs during the deletion process; the error is printed and re-raised.
        """
        try:
            Storage.delete_object(EnvVariables.ATTACK_MODULES.name, am_id, "py")
            return True

        except Exception as e:
            print(f"Failed to delete attack module: {str(e)}")
            raise
547
+
548
+
549
class RedTeamingPromptArguments(BaseModel):
    # One red teaming exchange: which connector, attack module, context strategy, metric and prompt
    # template were involved, plus the prompt payload that is sent to (and answered by) the connector.

    # ID of the connection (connector endpoint); required
    conn_id: str

    # ID of the attack module; required
    am_id: str

    # ID of the context strategy; empty string when none is used
    cs_id: str = ""

    # ID of the metric used to score the result; empty string when none is used
    me_id: str = ""

    # ID of the prompt template; empty string when none is used
    pt_id: str = ""

    # The original prompt, before any preparation; required
    original_prompt: str

    # The system prompt used; empty string by default
    system_prompt: str = ""

    # The start time of the prediction, as a string; required
    start_time: str

    # The prompt information to send; also read back for the prepared prompt, results and duration
    connector_prompt: ConnectorPromptArguments

    def to_tuple(self) -> tuple:
        """
        Converts the RedTeamingPromptArguments instance into a tuple.

        The values appear in this specific order: conn_id, cs_id, pt_id, am_id, me_id, original_prompt,
        connector_prompt.prompt, system_prompt, connector_prompt.predicted_results,
        connector_prompt.duration, start_time. The predicted results and the duration are converted
        to strings.

        Returns:
            tuple: A tuple representation of the RedTeamingPromptArguments instance.
        """
        sent = self.connector_prompt
        ordered_values = [
            self.conn_id,
            self.cs_id,
            self.pt_id,
            self.am_id,
            self.me_id,
            self.original_prompt,
            sent.prompt,
            self.system_prompt,
            str(sent.predicted_results),
            str(sent.duration),
            self.start_time,
        ]
        return tuple(ordered_values)

    def to_dict(self) -> dict:
        """
        Converts the RedTeamingPromptArguments instance into a dict.

        The resulting keys are: conn_id, cs_id, pt_id, am_id, me_id, original_prompt, prepared_prompt,
        system_prompt, response, duration, start_time. The response and the duration are converted
        to strings.

        Returns:
            dict: A dict representation of the RedTeamingPromptArguments instance.
        """
        sent = self.connector_prompt
        return dict(
            conn_id=self.conn_id,
            cs_id=self.cs_id,
            pt_id=self.pt_id,
            am_id=self.am_id,
            me_id=self.me_id,
            original_prompt=self.original_prompt,
            prepared_prompt=sent.prompt,
            system_prompt=self.system_prompt,
            response=str(sent.predicted_results),
            duration=str(sent.duration),
            start_time=self.start_time,
        )
@@ -0,0 +1,44 @@
1
+ import asyncio
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from moonshot.src.redteaming.session.red_teaming_progress import RedTeamingProgress
6
+ from moonshot.src.storage.db_interface import DBInterface
7
+
8
+
9
class AttackModuleArguments(BaseModel):
    # Bundle of inputs handed to an attack module when it is created.

    class Config:
        # Needed because some fields below hold non-pydantic types
        # (DBInterface, RedTeamingProgress, asyncio.Event).
        arbitrary_types_allowed = True

    # Connector endpoints to attack
    connector_ids: list = []

    # Prompt template ids to apply, if any
    prompt_templates: list = []

    # The user's prompt (required)
    prompt: str

    # The system prompt; empty by default
    system_prompt: str = ""

    # Metric ids to use, if any
    metric_ids: list = []

    # Context strategy ids together with their other params, if any
    context_strategy_info: list = []

    # Database accessor through which the attack module reaches DB data (required)
    db_instance: DBInterface

    # Chat batch size for returning chat information by callback
    chat_batch_size: int = 0

    # Progress object used to return chat information; None disables reporting
    red_teaming_progress: RedTeamingProgress | None = None

    # asyncio event that signals red teaming should be cancelled (required)
    cancel_event: asyncio.Event

    # Any additional params required by the attack module
    params: dict = {}