aiverify-moonshot 0.4.0 (aiverify_moonshot-0.4.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiverify_moonshot-0.4.0.dist-info/METADATA +249 -0
- aiverify_moonshot-0.4.0.dist-info/RECORD +163 -0
- aiverify_moonshot-0.4.0.dist-info/WHEEL +4 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/AUTHORS.md +5 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/LICENSE.md +201 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/NOTICES.md +3340 -0
- moonshot/__init__.py +0 -0
- moonshot/__main__.py +198 -0
- moonshot/api.py +155 -0
- moonshot/integrations/__init__.py +0 -0
- moonshot/integrations/cli/__init__.py +0 -0
- moonshot/integrations/cli/__main__.py +25 -0
- moonshot/integrations/cli/active_session_cfg.py +1 -0
- moonshot/integrations/cli/benchmark/__init__.py +0 -0
- moonshot/integrations/cli/benchmark/benchmark.py +186 -0
- moonshot/integrations/cli/benchmark/cookbook.py +545 -0
- moonshot/integrations/cli/benchmark/datasets.py +164 -0
- moonshot/integrations/cli/benchmark/metrics.py +141 -0
- moonshot/integrations/cli/benchmark/recipe.py +598 -0
- moonshot/integrations/cli/benchmark/result.py +216 -0
- moonshot/integrations/cli/benchmark/run.py +140 -0
- moonshot/integrations/cli/benchmark/runner.py +174 -0
- moonshot/integrations/cli/cli.py +64 -0
- moonshot/integrations/cli/common/__init__.py +0 -0
- moonshot/integrations/cli/common/common.py +72 -0
- moonshot/integrations/cli/common/connectors.py +325 -0
- moonshot/integrations/cli/common/display_helper.py +42 -0
- moonshot/integrations/cli/common/prompt_template.py +94 -0
- moonshot/integrations/cli/initialisation/__init__.py +0 -0
- moonshot/integrations/cli/initialisation/initialisation.py +14 -0
- moonshot/integrations/cli/redteam/__init__.py +0 -0
- moonshot/integrations/cli/redteam/attack_module.py +70 -0
- moonshot/integrations/cli/redteam/context_strategy.py +147 -0
- moonshot/integrations/cli/redteam/prompt_template.py +67 -0
- moonshot/integrations/cli/redteam/redteam.py +90 -0
- moonshot/integrations/cli/redteam/session.py +467 -0
- moonshot/integrations/web_api/.env.dev +7 -0
- moonshot/integrations/web_api/__init__.py +0 -0
- moonshot/integrations/web_api/__main__.py +56 -0
- moonshot/integrations/web_api/app.py +125 -0
- moonshot/integrations/web_api/container.py +146 -0
- moonshot/integrations/web_api/log/.gitkeep +0 -0
- moonshot/integrations/web_api/logging_conf.py +114 -0
- moonshot/integrations/web_api/routes/__init__.py +0 -0
- moonshot/integrations/web_api/routes/attack_modules.py +66 -0
- moonshot/integrations/web_api/routes/benchmark.py +116 -0
- moonshot/integrations/web_api/routes/benchmark_result.py +175 -0
- moonshot/integrations/web_api/routes/context_strategy.py +129 -0
- moonshot/integrations/web_api/routes/cookbook.py +225 -0
- moonshot/integrations/web_api/routes/dataset.py +120 -0
- moonshot/integrations/web_api/routes/endpoint.py +282 -0
- moonshot/integrations/web_api/routes/metric.py +78 -0
- moonshot/integrations/web_api/routes/prompt_template.py +128 -0
- moonshot/integrations/web_api/routes/recipe.py +219 -0
- moonshot/integrations/web_api/routes/redteam.py +609 -0
- moonshot/integrations/web_api/routes/runner.py +239 -0
- moonshot/integrations/web_api/schemas/__init__.py +0 -0
- moonshot/integrations/web_api/schemas/benchmark_runner_dto.py +13 -0
- moonshot/integrations/web_api/schemas/cookbook_create_dto.py +19 -0
- moonshot/integrations/web_api/schemas/cookbook_response_model.py +9 -0
- moonshot/integrations/web_api/schemas/dataset_response_dto.py +9 -0
- moonshot/integrations/web_api/schemas/endpoint_create_dto.py +21 -0
- moonshot/integrations/web_api/schemas/endpoint_response_model.py +11 -0
- moonshot/integrations/web_api/schemas/prompt_response_model.py +14 -0
- moonshot/integrations/web_api/schemas/prompt_template_response_model.py +10 -0
- moonshot/integrations/web_api/schemas/recipe_create_dto.py +32 -0
- moonshot/integrations/web_api/schemas/recipe_response_model.py +7 -0
- moonshot/integrations/web_api/schemas/session_create_dto.py +16 -0
- moonshot/integrations/web_api/schemas/session_prompt_dto.py +7 -0
- moonshot/integrations/web_api/schemas/session_response_model.py +38 -0
- moonshot/integrations/web_api/services/__init__.py +0 -0
- moonshot/integrations/web_api/services/attack_module_service.py +34 -0
- moonshot/integrations/web_api/services/auto_red_team_test_manager.py +86 -0
- moonshot/integrations/web_api/services/auto_red_team_test_state.py +57 -0
- moonshot/integrations/web_api/services/base_service.py +8 -0
- moonshot/integrations/web_api/services/benchmark_result_service.py +25 -0
- moonshot/integrations/web_api/services/benchmark_test_manager.py +106 -0
- moonshot/integrations/web_api/services/benchmark_test_state.py +56 -0
- moonshot/integrations/web_api/services/benchmarking_service.py +31 -0
- moonshot/integrations/web_api/services/context_strategy_service.py +22 -0
- moonshot/integrations/web_api/services/cookbook_service.py +194 -0
- moonshot/integrations/web_api/services/dataset_service.py +20 -0
- moonshot/integrations/web_api/services/endpoint_service.py +65 -0
- moonshot/integrations/web_api/services/metric_service.py +14 -0
- moonshot/integrations/web_api/services/prompt_template_service.py +39 -0
- moonshot/integrations/web_api/services/recipe_service.py +155 -0
- moonshot/integrations/web_api/services/runner_service.py +147 -0
- moonshot/integrations/web_api/services/session_service.py +350 -0
- moonshot/integrations/web_api/services/utils/exceptions_handler.py +41 -0
- moonshot/integrations/web_api/services/utils/results_formatter.py +47 -0
- moonshot/integrations/web_api/status_updater/interface/benchmark_progress_callback.py +14 -0
- moonshot/integrations/web_api/status_updater/interface/redteam_progress_callback.py +14 -0
- moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +72 -0
- moonshot/integrations/web_api/types/types.py +99 -0
- moonshot/src/__init__.py +0 -0
- moonshot/src/api/__init__.py +0 -0
- moonshot/src/api/api_connector.py +58 -0
- moonshot/src/api/api_connector_endpoint.py +162 -0
- moonshot/src/api/api_context_strategy.py +57 -0
- moonshot/src/api/api_cookbook.py +160 -0
- moonshot/src/api/api_dataset.py +46 -0
- moonshot/src/api/api_environment_variables.py +17 -0
- moonshot/src/api/api_metrics.py +51 -0
- moonshot/src/api/api_prompt_template.py +43 -0
- moonshot/src/api/api_recipe.py +182 -0
- moonshot/src/api/api_red_teaming.py +59 -0
- moonshot/src/api/api_result.py +84 -0
- moonshot/src/api/api_run.py +74 -0
- moonshot/src/api/api_runner.py +132 -0
- moonshot/src/api/api_session.py +290 -0
- moonshot/src/configs/__init__.py +0 -0
- moonshot/src/configs/env_variables.py +187 -0
- moonshot/src/connectors/__init__.py +0 -0
- moonshot/src/connectors/connector.py +327 -0
- moonshot/src/connectors/connector_prompt_arguments.py +17 -0
- moonshot/src/connectors_endpoints/__init__.py +0 -0
- moonshot/src/connectors_endpoints/connector_endpoint.py +211 -0
- moonshot/src/connectors_endpoints/connector_endpoint_arguments.py +54 -0
- moonshot/src/cookbooks/__init__.py +0 -0
- moonshot/src/cookbooks/cookbook.py +225 -0
- moonshot/src/cookbooks/cookbook_arguments.py +34 -0
- moonshot/src/datasets/__init__.py +0 -0
- moonshot/src/datasets/dataset.py +255 -0
- moonshot/src/datasets/dataset_arguments.py +50 -0
- moonshot/src/metrics/__init__.py +0 -0
- moonshot/src/metrics/metric.py +192 -0
- moonshot/src/metrics/metric_interface.py +95 -0
- moonshot/src/prompt_templates/__init__.py +0 -0
- moonshot/src/prompt_templates/prompt_template.py +103 -0
- moonshot/src/recipes/__init__.py +0 -0
- moonshot/src/recipes/recipe.py +340 -0
- moonshot/src/recipes/recipe_arguments.py +111 -0
- moonshot/src/redteaming/__init__.py +0 -0
- moonshot/src/redteaming/attack/__init__.py +0 -0
- moonshot/src/redteaming/attack/attack_module.py +618 -0
- moonshot/src/redteaming/attack/attack_module_arguments.py +44 -0
- moonshot/src/redteaming/attack/context_strategy.py +131 -0
- moonshot/src/redteaming/context_strategy/__init__.py +0 -0
- moonshot/src/redteaming/context_strategy/context_strategy_interface.py +46 -0
- moonshot/src/redteaming/session/__init__.py +0 -0
- moonshot/src/redteaming/session/chat.py +209 -0
- moonshot/src/redteaming/session/red_teaming_progress.py +128 -0
- moonshot/src/redteaming/session/red_teaming_type.py +6 -0
- moonshot/src/redteaming/session/session.py +775 -0
- moonshot/src/results/__init__.py +0 -0
- moonshot/src/results/result.py +119 -0
- moonshot/src/results/result_arguments.py +44 -0
- moonshot/src/runners/__init__.py +0 -0
- moonshot/src/runners/runner.py +476 -0
- moonshot/src/runners/runner_arguments.py +46 -0
- moonshot/src/runners/runner_type.py +6 -0
- moonshot/src/runs/__init__.py +0 -0
- moonshot/src/runs/run.py +344 -0
- moonshot/src/runs/run_arguments.py +162 -0
- moonshot/src/runs/run_progress.py +145 -0
- moonshot/src/runs/run_status.py +10 -0
- moonshot/src/storage/__init__.py +0 -0
- moonshot/src/storage/db_interface.py +128 -0
- moonshot/src/storage/io_interface.py +31 -0
- moonshot/src/storage/storage.py +525 -0
- moonshot/src/utils/__init__.py +0 -0
- moonshot/src/utils/import_modules.py +96 -0
- moonshot/src/utils/timeit.py +25 -0
moonshot/src/redteaming/attack/attack_module.py
@@ -0,0 +1,618 @@
+from __future__ import annotations
+
+from abc import abstractmethod
+from datetime import datetime
+from pathlib import Path
+from typing import Any, AsyncGenerator
+
+from jinja2 import Template
+from pydantic import BaseModel
+
+from moonshot.src.configs.env_variables import EnvVariables
+from moonshot.src.connectors.connector import Connector
+from moonshot.src.connectors.connector_prompt_arguments import ConnectorPromptArguments
+from moonshot.src.connectors_endpoints.connector_endpoint import ConnectorEndpoint
+from moonshot.src.metrics.metric import Metric
+from moonshot.src.redteaming.attack.attack_module_arguments import AttackModuleArguments
+from moonshot.src.redteaming.attack.context_strategy import ContextStrategy
+from moonshot.src.runs.run_status import RunStatus
+from moonshot.src.storage.storage import Storage
+from moonshot.src.utils.import_modules import get_instance
+
+
+class AttackModule:
+    cache_name = "cache"
+    cache_extension = "json"
+    sql_create_chat_record = """
+        INSERT INTO {} (connection_id,context_strategy,prompt_template,attack_module,
+        metric,prompt,prepared_prompt,system_prompt,predicted_result,duration,prompt_time)VALUES(?,?,?,?,?,?,?,?,?,?,?)
+    """
+
+    def __init__(self, am_id: str, am_arguments: AttackModuleArguments | None = None):
+        self.id = am_id
+        if am_arguments is not None:
+            self.connector_ids = am_arguments.connector_ids
+            self.prompt_templates = am_arguments.prompt_templates
+            self.prompt = am_arguments.prompt
+            self.system_prompt = am_arguments.system_prompt
+            self.metric_ids = am_arguments.metric_ids
+            self.context_strategy_info = am_arguments.context_strategy_info
+            self.db_instance = am_arguments.db_instance
+            self.red_teaming_progress = am_arguments.red_teaming_progress
+            self.cancel_event = am_arguments.cancel_event
+            self.params = am_arguments.params
+
+    @classmethod
+    def load(
+        cls, am_id: str, am_arguments: AttackModuleArguments | None = None
+    ) -> AttackModule:
+        """
+        Retrieves an attack module instance by its ID.
+
+        This method attempts to load an attack module instance using the provided ID. If the attack module instance
+        is found, it is returned. If the attack module instance does not exist, a RuntimeError is raised.
+
+        Args:
+            am_id (str): The unique identifier of the attack module to be retrieved.
+
+        Returns:
+            AttackModule: The retrieved attack module instance.
+
+        Raises:
+            RuntimeError: If the attack module instance does not exist.
+        """
+        attack_module_inst = get_instance(
+            am_id,
+            Storage.get_filepath(EnvVariables.ATTACK_MODULES.name, am_id, "py"),
+        )
+
+        if attack_module_inst:
+            return attack_module_inst(am_id, am_arguments)
+        else:
+            raise RuntimeError(
+                f"Unable to get defined attack module instance - {am_id}"
+            )
+
+    @abstractmethod
+    def get_metadata(self) -> dict:
+        """
+        Get metadata for the attack module.
+
+        Returns a dictionary of the attack module metadata.
+        Returns:
+            dict: A dictionary containing the metadata of the attack module.
+        """
+        pass
+
+    async def _generate_prompts(
+        self, prompt: str, target_llm_connector_id: str
+    ) -> AsyncGenerator[RedTeamingPromptArguments, None]:
+        """
+        Generates prompts for red teaming.
+
+        This method asynchronously generates prompts for red teaming based on the provided prompt and target LLM
+        connector ID. It processes the prompt using context strategy and prompt template if specified, and
+        yields RedTeamingPromptArguments for each generated prompt.
+
+        Args:
+            prompt (str): The prompt to be processed and sent to the target LLM.
+            target_llm_connector_id (str): The unique identifier of the target LLM connector.
+
+        Yields:
+            RedTeamingPromptArguments: An instance of RedTeamingPromptArguments containing the
+            generated prompt details.
+
+        """
+        if self.context_strategy_info:
+            context_strategy_instance = self.context_strategy_instances[0]
+            num_of_prev_prompts = self.context_strategy_info[0].get(
+                "num_of_prev_prompts"
+            )
+            prompt = ContextStrategy.process_prompt_cs(
+                prompt,
+                context_strategy_instance.id,
+                self.db_instance,
+                target_llm_connector_id,
+                num_of_prev_prompts,
+            )
+        if self.prompt_templates:
+            # prepare prompt template generator
+            pt_id = self.prompt_templates[0]
+            pt_info = Storage.read_object_with_iterator(
+                EnvVariables.PROMPT_TEMPLATES.name,
+                pt_id,
+                "json",
+                iterator_keys=["template"],
+            )
+            pt = next(pt_info["template"])
+            jinja2_template = Template(pt)
+            prompt = jinja2_template.render({"prompt": prompt})
+
+        yield RedTeamingPromptArguments(
+            conn_id=target_llm_connector_id,
+            am_id=self.id,
+            cs_id=self.context_strategy_instances[0].id
+            if self.context_strategy_info
+            else "",
+            pt_id=self.prompt_templates[0] if self.prompt_templates else "",
+            me_id=self.metric_ids[0] if self.metric_ids else "",
+            original_prompt=self.prompt,
+            system_prompt=self.system_prompt,
+            start_time="",
+            connector_prompt=ConnectorPromptArguments(
+                prompt_index=0,
+                prompt=prompt,
+                target="",
+            ),
+        )
+
+    async def _send_prompt_to_all_llm_default(self) -> list:
+        """
+        NOTE: this method does not currently handle callbacks
+        Asynchronously sends the default prompt to all Language Learning Models (LLMs).
+
+        This method generates prompts by appending the contents of the prompt template and modifies the prompt with the
+        context strategy for each LLM, sends each prompt to the respective LLM, and consolidates the responses into a
+        list.
+
+        Returns:
+            list: A list of consolidated responses from all LLMs.
+        """
+        generator_list = []
+        consolidated_result_list = []
+        if self.connector_ids:
+            for target_llm_connector in self.connector_instances:
+                gen_prompts_generator = self._generate_prompts(
+                    self.prompt, target_llm_connector.id
+                )
+                gen_results_generator = self._generate_predictions(
+                    gen_prompts_generator, target_llm_connector
+                )
+                generator_list.append(gen_results_generator)
+
+        for generator in generator_list:
+            async for result in generator:
+                if self.cancel_event.is_set():
+                    print(
+                        "[Red Teaming] Cancellation flag is set. Cancelling task..."
+                    )
+                    break
+                consolidated_result_list.append(result)
+        return consolidated_result_list
+
+    async def _send_prompt_to_all_llm(self, list_of_prompts: list) -> list:
+        """
+        Asynchronously sends prompts to all Language Learning Models (LLMs).
+
+        This method takes a list of prompts, sends each prompt to all LLM connectors, records the responses,
+        and returns a list of consolidated responses.
+
+        Args:
+            list_of_prompts (list): A list of prompts to be sent to the LLM connectors.
+
+        Returns:
+            list: A list of consolidated responses from all LLM connectors.
+        """
+        consolidated_responses = []
+        for prepared_prompt in list_of_prompts:
+            for target_llm_connector in self.connector_instances:
+                if self.cancel_event.is_set():
+                    print("[Red Teaming] Cancellation flag is set. Cancelling task...")
+                    break
+
+                if self.red_teaming_progress:
+                    self.red_teaming_progress.update_red_teaming_progress()
+
+                new_prompt_info = ConnectorPromptArguments(
+                    prompt_index=1, prompt=prepared_prompt, target=""
+                )
+                start_time = datetime.now()
+                response = await Connector.get_prediction(
+                    new_prompt_info, target_llm_connector
+                )
+                consolidated_responses.append(response)
+
+                red_teaming_prompt_arguments = RedTeamingPromptArguments(
+                    conn_id=target_llm_connector.id,
+                    am_id=self.id,
+                    cs_id=self.context_strategy_instances[0].id
+                    if self.context_strategy_info
+                    else "",
+                    me_id=self.metric_ids[0] if self.metric_ids else "",
+                    pt_id=self.prompt_templates[0] if self.prompt_templates else "",
+                    original_prompt=self.prompt,  # original prompt
+                    system_prompt=self.system_prompt,  # system prompt
+                    start_time=str(start_time),
+                    connector_prompt=response,
+                )
+
+                if self.red_teaming_progress:
+                    self.red_teaming_progress.update_red_teaming_chats(
+                        red_teaming_prompt_arguments.to_dict(), RunStatus.RUNNING
+                    )
+
+                self._write_record_to_db(
+                    red_teaming_prompt_arguments.to_tuple(), target_llm_connector.id
+                )
+
+        return consolidated_responses
+
+    async def _send_prompt_to_single_llm(
+        self, list_of_prompts: list, target_llm_connector: Connector
+    ) -> list:
+        """
+        Asynchronously sends prompts to a single Language Learning Model (LLM) connector.
+
+        This method takes a list of prompts, sends each prompt to the specified LLM connector, records the response,
+        and returns a list of consolidated responses.
+
+        Args:
+            list_of_prompts (list): A list of prompts to be sent to the LLM connector.
+            target_llm_connector: The target LLM connector to send the prompts to.
+
+        Returns:
+            list: A list of consolidated responses from the specified LLM connector.
+        """
+        consolidated_responses = []
+        for prepared_prompt in list_of_prompts:
+            if self.cancel_event.is_set():
+                print("[Red Teaming] Cancellation flag is set. Cancelling task...")
+                break
+
+            if self.red_teaming_progress:
+                self.red_teaming_progress.update_red_teaming_progress()
+
+            new_prompt_info = ConnectorPromptArguments(
+                prompt_index=1, prompt=prepared_prompt, target=""
+            )
+            start_time = datetime.now()
+            response = await Connector.get_prediction(
+                new_prompt_info, target_llm_connector
+            )
+            consolidated_responses.append(response)
+            red_teaming_prompt_arguments = RedTeamingPromptArguments(
+                conn_id=target_llm_connector.id,
+                am_id=self.id,
+                cs_id=self.context_strategy_instances[0].id
+                if self.context_strategy_info
+                else "",
+                me_id=self.metric_ids[0] if self.metric_ids else "",
+                pt_id=self.prompt_templates[0] if self.prompt_templates else "",
+                original_prompt=self.prompt,  # original prompt
+                system_prompt=self.system_prompt,  # system prompt
+                start_time=str(start_time),
+                connector_prompt=response,
+            )
+
+            # update callback arguments
+            if self.red_teaming_progress:
+                self.red_teaming_progress.update_red_teaming_chats(
+                    red_teaming_prompt_arguments.to_dict(), RunStatus.RUNNING
+                )
+            self._write_record_to_db(
+                red_teaming_prompt_arguments.to_tuple(), target_llm_connector.id
+            )
+
+        return consolidated_responses
+
+    def _write_record_to_db(
+        self,
+        chat_record_tuple: tuple,
+        chat_record_id: str,
+    ) -> None:
+        """
+        Writes the chat record to the database.
+
+        Args:
+            chat_record_tuple (tuple): A tuple containing the chat record information.
+            chat_record_id (str): The ID of the chat record.
+        """
+        endpoint_id = chat_record_id.replace("-", "_")
+        Storage.create_database_record(
+            self.db_instance,
+            chat_record_tuple,
+            AttackModule.sql_create_chat_record.format(endpoint_id),
+        )
+
+    @abstractmethod
+    async def execute(self) -> Any:
+        """
+        Houses the logic of the attack and is an entry point.
+        * Do not change the name of this function in the attack module
+
+        Returns:
+            Any: the return type from the execution
+        """
+        pass
+
+    async def _generate_predictions(
+        self,
+        gen_prompts_generator: AsyncGenerator[RedTeamingPromptArguments, None],
+        llm_connector: Connector,
+    ) -> AsyncGenerator[RedTeamingPromptArguments, None]:
+        """
+        Asynchronously generates predictions for the given prompts.
+
+        Args:
+            gen_prompts_generator (AsyncGenerator[RedTeamingPromptArguments, None]): An asynchronous generator
+            yielding RedTeamingPromptArguments.
+
+            llm_connector (Connector): The connector to the Language Learning Model (LLM).
+
+        Yields:
+            RedTeamingPromptArguments: An asynchronous generator yielding the new prompt information.
+        """
+        async for prompt_info in gen_prompts_generator:
+            if self.cancel_event.is_set():
+                print("[Red Teaming] Cancellation flag is set. Cancelling task...")
+                break
+            new_prompt_info = RedTeamingPromptArguments(
+                conn_id=prompt_info.conn_id,
+                am_id=prompt_info.am_id,
+                cs_id=prompt_info.cs_id,
+                me_id=prompt_info.me_id,
+                pt_id=prompt_info.pt_id,
+                original_prompt=self.prompt,
+                system_prompt=prompt_info.system_prompt,
+                connector_prompt=prompt_info.connector_prompt,
+                start_time=str(datetime.now()),
+            )
+
+            # send processed prompt to llm and write record to db
+            new_prompt_info.connector_prompt = await Connector.get_prediction(
+                new_prompt_info.connector_prompt, llm_connector
+            )
+
+            self._write_record_to_db(new_prompt_info.to_tuple(), llm_connector.id)
+            yield new_prompt_info
+
+    def load_modules(self) -> None:
+        """
+        Loads connector, metric, and context strategy instances if available.
+        """
+        if self.connector_ids:
+            self.connector_instances = [
+                Connector.create(ConnectorEndpoint.read(endpoint))
+                for endpoint in self.connector_ids
+            ]
+        else:
+            raise RuntimeError(
+                "[Red Teaming] No connector endpoints specified for red teaming."
+            )
+
+        if self.metric_ids:
+            self.metric_instances = [
+                Metric.load(metric_id) for metric_id in self.metric_ids
+            ]
+
+        if self.context_strategy_info:
+            self.context_strategy_instances = [
+                ContextStrategy.load(context_strategy_info.get("context_strategy_id"))
+                for context_strategy_info in self.context_strategy_info
+            ]
+        return None
+
+    @staticmethod
+    def get_cache_information() -> dict:
+        """
+        Retrieves cache information from the storage.
+
+        This method attempts to read the cache information from the storage and return it as a dictionary.
+        If the cache information does not exist or an error occurs, it returns an empty dictionary.
+
+        Returns:
+            dict: A dictionary containing the cache information or an empty dictionary if an error occurs
+            or if the cache information does not exist.
+
+        Raises:
+            Exception: If there's an error during the retrieval process, it is logged and an
+            empty dictionary is returned.
+        """
+        try:
+            # Retrieve cache information from the storage and return it as a dictionary
+            cache_info = Storage.read_object(
+                EnvVariables.ATTACK_MODULES.name, AttackModule.cache_name, "json"
+            )
+            return cache_info if cache_info else {}
+        except Exception as e:
+            print(f"No previous cache information: {str(e)}")
+            return {}
+
+    @staticmethod
+    def write_cache_information(cache_info: dict) -> None:
+        """
+        Writes the updated cache information to the storage.
+
+        Args:
+            cache_info (dict): The cache information to be written.
+        """
+        try:
+            Storage.create_object(
+                obj_type=EnvVariables.ATTACK_MODULES.name,
+                obj_id=AttackModule.cache_name,
+                obj_info=cache_info,
+                obj_extension=AttackModule.cache_extension,
+            )
+        except Exception as e:
+            print(f"Failed to write cache information: {str(e)}")
+            raise e
+
+    @staticmethod
+    def get_available_items() -> tuple[list[str], list[dict]]:
+        """
+        Retrieves the list of available attack modules and their information.
+
+        This method scans the storage for attack modules, filters out any non-relevant files,
+        and updates the cache information if necessary. It returns a tuple containing a list of
+        attack module IDs and a list of dictionaries with detailed information about each module.
+
+        Returns:
+            tuple[list[str], list[dict]]: A tuple with two elements. The first element is a list
+            of attack module IDs. The second element is a list of
+            dictionaries, each containing information about an
+            attack module.
+        """
+        try:
+            retn_ams = []
+            retn_am_ids = []
+            am_cache_info = AttackModule.get_cache_information()
+            cache_needs_update = False  # Initialize a flag to track cache updates
+            ams = Storage.get_objects(EnvVariables.ATTACK_MODULES.name, "py")
+
+            for am in ams:
+                if "__" in am:
+                    continue
+
+                am_name = Path(am).stem
+                am_info, cache_updated = AttackModule._get_or_update_attack_module_info(
+                    am_name, am_cache_info
+                )
+                if cache_updated:
+                    cache_needs_update = True  # Set the flag if any cache was updated
+
+                retn_ams.append(am_info)
+                retn_am_ids.append(am_name)
+
+            if cache_needs_update:  # Check the flag after the loop
+                AttackModule.write_cache_information(am_cache_info)
+
+            return retn_am_ids, retn_ams
+
+        except Exception as e:
+            print(f"Failed to get available attack modules: {str(e)}")
+            raise e
+
+    @staticmethod
+    def _get_or_update_attack_module_info(
+        am_name: str, am_cache_info: dict
+    ) -> tuple[dict, bool]:
+        """
+        Retrieves or updates the attack module information from the cache.
+
+        This method checks if the attack module information is already available in the cache and if the file hash
+        matches the one stored in the cache. If it does, the information is retrieved from the cache.
+
+        If not, the attack module information is read from the storage, the cache is updated with the new information
+        and the new file hash, and a flag is set to indicate that the cache has been updated.
+
+        Args:
+            am_name (str): The name of the attack module.
+            am_cache_info (dict): A dictionary containing the cached attack module information.
+
+        Returns:
+            tuple[dict, bool]: A tuple containing the dictionary with the attack module information
+            and a boolean indicating whether the cache was updated or not.
+        """
+        file_hash = Storage.get_file_hash(
+            EnvVariables.ATTACK_MODULES.name, am_name, "py"
+        )
+        cache_updated = False
+
+        if am_name in am_cache_info and file_hash == am_cache_info[am_name]["hash"]:
+            am_metadata = am_cache_info[am_name].copy()
+            am_metadata.pop("hash", None)
+        else:
+            am_metadata = AttackModule.load(am_name).get_metadata()  # type: ignore ; ducktyping
+            am_cache_info[am_name] = am_metadata.copy()
+            am_cache_info[am_name]["hash"] = file_hash
+            cache_updated = True
+
+        return am_metadata, cache_updated
+
+    @staticmethod
+    def delete(am_id: str) -> bool:
+        """
+        Deletes the specified attack module from storage.
+
+        This method attempts to delete the attack module identified by the given ID from the storage.
+        If the deletion is successful, it returns True. If an exception occurs during the deletion process,
+        it prints an error message and re-raises the exception.
+
+        Args:
+            am_id (str): The ID of the attack module to be deleted.
+
+        Returns:
+            bool: True if the attack module was successfully deleted.
+
+        Raises:
+            Exception: If an error occurs during the deletion process.
+        """
+        try:
+            Storage.delete_object(EnvVariables.ATTACK_MODULES.name, am_id, "py")
+            return True
+
+        except Exception as e:
+            print(f"Failed to delete attack module: {str(e)}")
+            raise e
+
+
+class RedTeamingPromptArguments(BaseModel):
+    conn_id: str  # The ID of the connection, default is an empty string
+
+    am_id: str  # The ID of the attack module, default is an empty string
+
+    cs_id: str = ""  # The ID of the context strategy, default is an empty string
+
+    me_id: str = (
+        ""  # The ID of the metric used to score the result, default is an empty string
+    )
+
+    pt_id: str = ""  # The ID of the prompt template, default is an empty string
+
+    original_prompt: str  # The original prompt used
+
+    system_prompt: str = ""  # The system-generated prompt used
+
+    start_time: str  # The start time of the prediction
+
+    connector_prompt: ConnectorPromptArguments  # The prompt information to send
+
+    def to_tuple(self) -> tuple:
+        """
+        Converts the RedTeamingPromptArguments instance into a tuple.
+
+        This method collects all the attributes of the RedTeamingPromptArguments instance and forms a tuple
+        with the attribute values in this specific order: conn_id, cs_id, pt_id, am_id, me_id, original_prompt,
+        connector_prompt.prompt, connector_prompt.predicted_results, connector_prompt.duration, start_time.
+
+        Returns:
+            tuple: A tuple representation of the RedTeamingPromptArguments instance.
+        """
+        return (
+            self.conn_id,
+            self.cs_id,
+            self.pt_id,
+            self.am_id,
+            self.me_id,
+            self.original_prompt,
+            self.connector_prompt.prompt,
+            self.system_prompt,
+            str(self.connector_prompt.predicted_results),
+            str(self.connector_prompt.duration),
+            self.start_time,
+        )
+
+    def to_dict(self) -> dict:
+        """
+        Converts the RedTeamingPromptArguments instance into a dict.
+
+        This method collects all the attributes of the RedTeamingPromptArguments instance and forms a dict
+        with the keys: conn_id, cs_id, pt_id, am_id, me_id, original_prompt, prepared_prompt, system_prompt,
+        response, duration, start_time.
+
+        Returns:
+            dict: A dict representation of the RedTeamingPromptArguments instance.
+        """
+        return {
+            "conn_id": self.conn_id,
+            "cs_id": self.cs_id,
+            "pt_id": self.pt_id,
+            "am_id": self.am_id,
+            "me_id": self.me_id,
+            "original_prompt": self.original_prompt,
+            "prepared_prompt": self.connector_prompt.prompt,
+            "system_prompt": self.system_prompt,
+            "response": str(self.connector_prompt.predicted_results),
+            "duration": str(self.connector_prompt.duration),
+            "start_time": self.start_time,
+        }
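For orientation, the class above is the abstract base that the package's bundled attack modules subclass: a concrete module only has to provide get_metadata() and execute(), while prompt preparation, prediction, and chat-record bookkeeping come from the helpers shown. The following is a minimal illustrative sketch and is not part of the released wheel; the module name, display name, and description are invented, and it simply delegates to the _send_prompt_to_all_llm_default() helper defined above.

from moonshot.src.redteaming.attack.attack_module import AttackModule
from moonshot.src.redteaming.attack.attack_module_arguments import AttackModuleArguments


class SampleAttackModule(AttackModule):
    """Hypothetical attack module used only to illustrate the required interface."""

    def __init__(self, am_id: str, am_arguments: AttackModuleArguments | None = None):
        super().__init__(am_id, am_arguments)
        self.name = "Sample Attack Module"  # display name (invented for this example)
        self.descr = "Sends the prepared prompt to every configured endpoint once."

    def get_metadata(self) -> dict:
        # This dict is what get_available_items() caches against the module's file hash.
        return {"id": self.id, "name": self.name, "description": self.descr}

    async def execute(self):
        # Instantiate the connectors, metrics, and context strategies named in the arguments.
        self.load_modules()
        # Reuse the base-class helper: it applies the context strategy and prompt template,
        # queries each target LLM, and writes each chat record to the session database.
        return await self._send_prompt_to_all_llm_default()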
moonshot/src/redteaming/attack/attack_module_arguments.py
@@ -0,0 +1,44 @@
+import asyncio
+
+from pydantic import BaseModel
+
+from moonshot.src.redteaming.session.red_teaming_progress import RedTeamingProgress
+from moonshot.src.storage.db_interface import DBInterface
+
+
+class AttackModuleArguments(BaseModel):
+    class Config:
+        arbitrary_types_allowed = True
+
+    # list of connector endpoints
+    connector_ids: list = []
+
+    # list of prompt template ids to be used (if any)
+    prompt_templates: list = []
+
+    # user's prompt
+    prompt: str
+
+    # system prompt
+    system_prompt: str = ""
+
+    # list of metric ids to be used (if any)
+    metric_ids: list = []
+
+    # list of context strategy ids and other params to be used (if any)
+    context_strategy_info: list = []
+
+    # DBAccessor for the attack module to access DB data
+    db_instance: DBInterface
+
+    # chat batch size for returning chat information by callback
+    chat_batch_size: int = 0
+
+    # callback function to return chat information
+    red_teaming_progress: RedTeamingProgress | None = None
+
+    # an asyncio event to cancel red teaming if a cancel signal is sent
+    cancel_event: asyncio.Event
+
+    # a dict that contains other params that is required by the attack module (if any)
+    params: dict = {}