aiverify-moonshot 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiverify_moonshot-0.4.0.dist-info/METADATA +249 -0
- aiverify_moonshot-0.4.0.dist-info/RECORD +163 -0
- aiverify_moonshot-0.4.0.dist-info/WHEEL +4 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/AUTHORS.md +5 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/LICENSE.md +201 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/NOTICES.md +3340 -0
- moonshot/__init__.py +0 -0
- moonshot/__main__.py +198 -0
- moonshot/api.py +155 -0
- moonshot/integrations/__init__.py +0 -0
- moonshot/integrations/cli/__init__.py +0 -0
- moonshot/integrations/cli/__main__.py +25 -0
- moonshot/integrations/cli/active_session_cfg.py +1 -0
- moonshot/integrations/cli/benchmark/__init__.py +0 -0
- moonshot/integrations/cli/benchmark/benchmark.py +186 -0
- moonshot/integrations/cli/benchmark/cookbook.py +545 -0
- moonshot/integrations/cli/benchmark/datasets.py +164 -0
- moonshot/integrations/cli/benchmark/metrics.py +141 -0
- moonshot/integrations/cli/benchmark/recipe.py +598 -0
- moonshot/integrations/cli/benchmark/result.py +216 -0
- moonshot/integrations/cli/benchmark/run.py +140 -0
- moonshot/integrations/cli/benchmark/runner.py +174 -0
- moonshot/integrations/cli/cli.py +64 -0
- moonshot/integrations/cli/common/__init__.py +0 -0
- moonshot/integrations/cli/common/common.py +72 -0
- moonshot/integrations/cli/common/connectors.py +325 -0
- moonshot/integrations/cli/common/display_helper.py +42 -0
- moonshot/integrations/cli/common/prompt_template.py +94 -0
- moonshot/integrations/cli/initialisation/__init__.py +0 -0
- moonshot/integrations/cli/initialisation/initialisation.py +14 -0
- moonshot/integrations/cli/redteam/__init__.py +0 -0
- moonshot/integrations/cli/redteam/attack_module.py +70 -0
- moonshot/integrations/cli/redteam/context_strategy.py +147 -0
- moonshot/integrations/cli/redteam/prompt_template.py +67 -0
- moonshot/integrations/cli/redteam/redteam.py +90 -0
- moonshot/integrations/cli/redteam/session.py +467 -0
- moonshot/integrations/web_api/.env.dev +7 -0
- moonshot/integrations/web_api/__init__.py +0 -0
- moonshot/integrations/web_api/__main__.py +56 -0
- moonshot/integrations/web_api/app.py +125 -0
- moonshot/integrations/web_api/container.py +146 -0
- moonshot/integrations/web_api/log/.gitkeep +0 -0
- moonshot/integrations/web_api/logging_conf.py +114 -0
- moonshot/integrations/web_api/routes/__init__.py +0 -0
- moonshot/integrations/web_api/routes/attack_modules.py +66 -0
- moonshot/integrations/web_api/routes/benchmark.py +116 -0
- moonshot/integrations/web_api/routes/benchmark_result.py +175 -0
- moonshot/integrations/web_api/routes/context_strategy.py +129 -0
- moonshot/integrations/web_api/routes/cookbook.py +225 -0
- moonshot/integrations/web_api/routes/dataset.py +120 -0
- moonshot/integrations/web_api/routes/endpoint.py +282 -0
- moonshot/integrations/web_api/routes/metric.py +78 -0
- moonshot/integrations/web_api/routes/prompt_template.py +128 -0
- moonshot/integrations/web_api/routes/recipe.py +219 -0
- moonshot/integrations/web_api/routes/redteam.py +609 -0
- moonshot/integrations/web_api/routes/runner.py +239 -0
- moonshot/integrations/web_api/schemas/__init__.py +0 -0
- moonshot/integrations/web_api/schemas/benchmark_runner_dto.py +13 -0
- moonshot/integrations/web_api/schemas/cookbook_create_dto.py +19 -0
- moonshot/integrations/web_api/schemas/cookbook_response_model.py +9 -0
- moonshot/integrations/web_api/schemas/dataset_response_dto.py +9 -0
- moonshot/integrations/web_api/schemas/endpoint_create_dto.py +21 -0
- moonshot/integrations/web_api/schemas/endpoint_response_model.py +11 -0
- moonshot/integrations/web_api/schemas/prompt_response_model.py +14 -0
- moonshot/integrations/web_api/schemas/prompt_template_response_model.py +10 -0
- moonshot/integrations/web_api/schemas/recipe_create_dto.py +32 -0
- moonshot/integrations/web_api/schemas/recipe_response_model.py +7 -0
- moonshot/integrations/web_api/schemas/session_create_dto.py +16 -0
- moonshot/integrations/web_api/schemas/session_prompt_dto.py +7 -0
- moonshot/integrations/web_api/schemas/session_response_model.py +38 -0
- moonshot/integrations/web_api/services/__init__.py +0 -0
- moonshot/integrations/web_api/services/attack_module_service.py +34 -0
- moonshot/integrations/web_api/services/auto_red_team_test_manager.py +86 -0
- moonshot/integrations/web_api/services/auto_red_team_test_state.py +57 -0
- moonshot/integrations/web_api/services/base_service.py +8 -0
- moonshot/integrations/web_api/services/benchmark_result_service.py +25 -0
- moonshot/integrations/web_api/services/benchmark_test_manager.py +106 -0
- moonshot/integrations/web_api/services/benchmark_test_state.py +56 -0
- moonshot/integrations/web_api/services/benchmarking_service.py +31 -0
- moonshot/integrations/web_api/services/context_strategy_service.py +22 -0
- moonshot/integrations/web_api/services/cookbook_service.py +194 -0
- moonshot/integrations/web_api/services/dataset_service.py +20 -0
- moonshot/integrations/web_api/services/endpoint_service.py +65 -0
- moonshot/integrations/web_api/services/metric_service.py +14 -0
- moonshot/integrations/web_api/services/prompt_template_service.py +39 -0
- moonshot/integrations/web_api/services/recipe_service.py +155 -0
- moonshot/integrations/web_api/services/runner_service.py +147 -0
- moonshot/integrations/web_api/services/session_service.py +350 -0
- moonshot/integrations/web_api/services/utils/exceptions_handler.py +41 -0
- moonshot/integrations/web_api/services/utils/results_formatter.py +47 -0
- moonshot/integrations/web_api/status_updater/interface/benchmark_progress_callback.py +14 -0
- moonshot/integrations/web_api/status_updater/interface/redteam_progress_callback.py +14 -0
- moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +72 -0
- moonshot/integrations/web_api/types/types.py +99 -0
- moonshot/src/__init__.py +0 -0
- moonshot/src/api/__init__.py +0 -0
- moonshot/src/api/api_connector.py +58 -0
- moonshot/src/api/api_connector_endpoint.py +162 -0
- moonshot/src/api/api_context_strategy.py +57 -0
- moonshot/src/api/api_cookbook.py +160 -0
- moonshot/src/api/api_dataset.py +46 -0
- moonshot/src/api/api_environment_variables.py +17 -0
- moonshot/src/api/api_metrics.py +51 -0
- moonshot/src/api/api_prompt_template.py +43 -0
- moonshot/src/api/api_recipe.py +182 -0
- moonshot/src/api/api_red_teaming.py +59 -0
- moonshot/src/api/api_result.py +84 -0
- moonshot/src/api/api_run.py +74 -0
- moonshot/src/api/api_runner.py +132 -0
- moonshot/src/api/api_session.py +290 -0
- moonshot/src/configs/__init__.py +0 -0
- moonshot/src/configs/env_variables.py +187 -0
- moonshot/src/connectors/__init__.py +0 -0
- moonshot/src/connectors/connector.py +327 -0
- moonshot/src/connectors/connector_prompt_arguments.py +17 -0
- moonshot/src/connectors_endpoints/__init__.py +0 -0
- moonshot/src/connectors_endpoints/connector_endpoint.py +211 -0
- moonshot/src/connectors_endpoints/connector_endpoint_arguments.py +54 -0
- moonshot/src/cookbooks/__init__.py +0 -0
- moonshot/src/cookbooks/cookbook.py +225 -0
- moonshot/src/cookbooks/cookbook_arguments.py +34 -0
- moonshot/src/datasets/__init__.py +0 -0
- moonshot/src/datasets/dataset.py +255 -0
- moonshot/src/datasets/dataset_arguments.py +50 -0
- moonshot/src/metrics/__init__.py +0 -0
- moonshot/src/metrics/metric.py +192 -0
- moonshot/src/metrics/metric_interface.py +95 -0
- moonshot/src/prompt_templates/__init__.py +0 -0
- moonshot/src/prompt_templates/prompt_template.py +103 -0
- moonshot/src/recipes/__init__.py +0 -0
- moonshot/src/recipes/recipe.py +340 -0
- moonshot/src/recipes/recipe_arguments.py +111 -0
- moonshot/src/redteaming/__init__.py +0 -0
- moonshot/src/redteaming/attack/__init__.py +0 -0
- moonshot/src/redteaming/attack/attack_module.py +618 -0
- moonshot/src/redteaming/attack/attack_module_arguments.py +44 -0
- moonshot/src/redteaming/attack/context_strategy.py +131 -0
- moonshot/src/redteaming/context_strategy/__init__.py +0 -0
- moonshot/src/redteaming/context_strategy/context_strategy_interface.py +46 -0
- moonshot/src/redteaming/session/__init__.py +0 -0
- moonshot/src/redteaming/session/chat.py +209 -0
- moonshot/src/redteaming/session/red_teaming_progress.py +128 -0
- moonshot/src/redteaming/session/red_teaming_type.py +6 -0
- moonshot/src/redteaming/session/session.py +775 -0
- moonshot/src/results/__init__.py +0 -0
- moonshot/src/results/result.py +119 -0
- moonshot/src/results/result_arguments.py +44 -0
- moonshot/src/runners/__init__.py +0 -0
- moonshot/src/runners/runner.py +476 -0
- moonshot/src/runners/runner_arguments.py +46 -0
- moonshot/src/runners/runner_type.py +6 -0
- moonshot/src/runs/__init__.py +0 -0
- moonshot/src/runs/run.py +344 -0
- moonshot/src/runs/run_arguments.py +162 -0
- moonshot/src/runs/run_progress.py +145 -0
- moonshot/src/runs/run_status.py +10 -0
- moonshot/src/storage/__init__.py +0 -0
- moonshot/src/storage/db_interface.py +128 -0
- moonshot/src/storage/io_interface.py +31 -0
- moonshot/src/storage/storage.py +525 -0
- moonshot/src/utils/__init__.py +0 -0
- moonshot/src/utils/import_modules.py +96 -0
- moonshot/src/utils/timeit.py +25 -0
|
@@ -0,0 +1,467 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from ast import literal_eval
|
|
3
|
+
|
|
4
|
+
import cmd2
|
|
5
|
+
from rich.columns import Columns
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.panel import Panel
|
|
8
|
+
from rich.table import Table
|
|
9
|
+
|
|
10
|
+
from moonshot.api import (
|
|
11
|
+
api_create_runner,
|
|
12
|
+
api_create_session,
|
|
13
|
+
api_delete_session,
|
|
14
|
+
api_get_all_chats_from_session,
|
|
15
|
+
api_get_all_session_metadata,
|
|
16
|
+
api_load_runner,
|
|
17
|
+
api_load_session,
|
|
18
|
+
)
|
|
19
|
+
from moonshot.integrations.cli.active_session_cfg import active_session
|
|
20
|
+
from moonshot.src.redteaming.session.session import Session
|
|
21
|
+
|
|
22
|
+
console = Console()
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def new_session(args) -> None:
    """
    Start a red teaming session on a runner and make it the active session.

    When endpoints are supplied a new runner is created with them; otherwise
    the runner identified by ``args.runner_id`` is loaded.

    Args:
        args (Namespace): Parsed CLI arguments; ``runner_id``, ``endpoints``,
            ``context_strategy`` and ``prompt_template`` are read.

    Raises:
        RuntimeError: If no session metadata can be loaded after the session
            has been created.
    """
    global active_session

    runner_id = args.runner_id
    session_args = {
        "context_strategy": args.context_strategy or "",
        "prompt_template": args.prompt_template or "",
    }
    # The endpoints arrive as the string form of a Python list.
    endpoint_ids = literal_eval(args.endpoints) if args.endpoints else []

    # Endpoints given -> create a new runner; none given -> load an existing one.
    runner = (
        api_create_runner(runner_id, endpoint_ids)
        if endpoint_ids
        else api_load_runner(runner_id)
    )

    # Nothing further happens when the runner has no database instance.
    if not runner.database_instance:
        return

    api_create_session(
        runner.id, runner.database_instance, runner.endpoints, session_args
    )
    metadata = api_load_session(runner_id)
    if not metadata:
        raise RuntimeError("Unable to use session")

    active_session.update(metadata)
    if active_session["context_strategy"]:
        active_session["cs_num_of_prev_prompts"] = (
            Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
        )
    print(f"Using session: {active_session['session_id']}")
    update_chat_display()
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def use_session(args) -> None:
    """
    Resume the session stored under a runner ID and make it the active session.

    Any exception raised while loading or displaying the session is caught and
    its message is printed.

    Args:
        args (Namespace): Parsed CLI arguments; only ``runner_id`` is read.
    """
    global active_session
    runner_id = args.runner_id

    try:
        metadata = api_load_session(runner_id)
        if metadata:
            # Adopt the loaded metadata as the current session.
            active_session.update(metadata)
            if active_session["context_strategy"]:
                active_session["cs_num_of_prev_prompts"] = (
                    Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
                )
            print(f"Using session: {active_session['session_id']}. ")
            update_chat_display()
        else:
            print(
                "[Session] Cannot find a session with the existing Runner ID. Please try again."
            )
    except Exception as e:
        print(f"[use_session]: {str(e)}")
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def end_session() -> None:
    """
    End the current session by emptying the shared ``active_session`` mapping.
    """
    # The mapping is mutated in place, so no rebinding of the global is needed.
    active_session.clear()
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def list_sessions() -> None:
    """
    Print every known session as a table.

    One row is emitted per session metadata dict, showing a running number,
    the session ID with its creation time, and the session's endpoints. When
    no metadata is returned a notice is printed instead.
    """
    all_metadata = api_get_all_session_metadata()
    if not all_metadata:
        console.print("[red]There are no sessions found.[/red]", style="bold")
        return

    table = Table(
        title="Session List", show_lines=True, expand=True, header_style="bold"
    )
    for heading, column_width in (("No.", 2), ("Session ID", 20), ("Contains", 78)):
        table.add_column(heading, justify="left", width=column_width)

    for row_number, metadata in enumerate(all_metadata, 1):
        session_id = metadata.get("session_id", "")
        endpoint_names = ", ".join(metadata.get("endpoints", []))
        created = metadata.get("created_datetime", "")
        table.add_row(
            str(row_number),
            f"[red]id: {session_id}[/red]\n\nCreated: {created}",
            f"[blue]Endpoints:[/blue] {endpoint_names}\n\n",
        )
    console.print(table)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def update_chat_display() -> None:
    """
    Render the chat history of the active session.

    The chats are fetched per endpoint. Each endpoint becomes one column of an
    outer table whose single row holds a nested table listing the prepared
    prompt and the prompt/response pair of every chat. The result is printed
    inside a panel titled with the session ID. Without an active session only
    a notice is printed.
    """
    if not active_session:
        console.print("[red]There are no active session.[/red]")
        return

    chats_by_endpoint = api_get_all_chats_from_session(active_session["session_id"])

    # Build one nested table per endpoint and collect them for a single outer row.
    outer_table = Table(expand=True, show_lines=True, header_style="bold")
    endpoint_tables = []
    for endpoint_name, chats in chats_by_endpoint.items():
        outer_table.add_column(endpoint_name, justify="center")

        chat_table = Table(expand=True)
        chat_table.add_column(
            "Prepared Prompts", justify="left", style="cyan", width=50
        )
        chat_table.add_column("Prompt/Response", justify="left", width=50)

        for chat in chats:
            prepared = chat["prepared_prompt"]
            exchange = (
                f"[magenta]{chat['prompt']}[/magenta] \n"
                f"|---> [green]{chat['predicted_result']}[/green]"
            )
            chat_table.add_row(prepared, exchange)
            chat_table.add_section()
        endpoint_tables.append(chat_table)
    outer_table.add_row(*endpoint_tables)

    # Show everything inside a panel named after the session.
    console.print(
        Panel.fit(
            Columns([outer_table], expand=True),
            title=active_session["session_id"],
            border_style="red",
            title_align="left",
        )
    )
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def manual_red_teaming(user_prompt: str) -> None:
    """
    Send a single user prompt through manual red teaming in the active session.

    The prompt template and context strategy are taken from the active session.
    The runner named after the session ID is loaded, the red teaming coroutine
    is run to completion, the runner is closed, and the session metadata is
    reloaded. Any exception raised along the way is caught and its message is
    printed.

    Args:
        user_prompt (str): The prompt to send.

    If there is no active session, a message is printed and the function
    returns without doing anything else.
    """
    if not active_session:
        print("There is no active session. Activate a session to start red teaming.")
        return
    prompt_template = (
        [active_session["prompt_template"]] if active_session["prompt_template"] else []
    )
    context_strategy = (
        active_session["context_strategy"] if active_session["context_strategy"] else []
    )
    num_of_prev_prompts = (
        active_session["cs_num_of_prev_prompts"]
        if active_session["context_strategy"]
        else Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
    )

    if context_strategy:
        context_strategy_info = [
            {
                "context_strategy_id": context_strategy,
                "num_of_prev_prompts": num_of_prev_prompts,
            }
        ]
    else:
        context_strategy_info = []

    mrt_arguments = {
        "manual_rt_args": {
            "prompt": user_prompt,
            "context_strategy_info": context_strategy_info,
            "prompt_template_ids": prompt_template,
        }
    }

    # load runner, perform red teaming and close the runner
    try:
        runner = api_load_runner(active_session["session_id"])
        try:
            loop = asyncio.get_event_loop()
            loop.run_until_complete(runner.run_red_teaming(mrt_arguments))
        finally:
            # Close the runner even when red teaming fails so it is not leaked.
            runner.close()
        _reload_session(active_session["session_id"])
    except Exception as e:
        # str() must be inside the braces; otherwise "str(...)" is printed literally.
        print(f"[manual_red_teaming]: {str(e)}")
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def run_attack_module(args):
    """
    Run automated red teaming with an attack module in the active session.

    The attack strategy is assembled from the CLI arguments. The runner named
    after the session ID is loaded, the red teaming coroutine is run to
    completion, the runner is closed, the session metadata is reloaded and the
    chat display is refreshed. Any exception raised along the way is caught
    and its message is printed.

    Args:
        args: Parsed CLI arguments; ``attack_module_id``, ``prompt``,
            ``system_prompt``, ``context_strategy``, ``prompt_template``,
            ``metric`` and ``num_of_prev_prompts`` are read.

    If there is no active session, a message is printed and the function
    returns without doing anything else.
    """
    if not active_session:
        print("There is no active session. Activate a session to start red teaming.")
        return
    try:
        attack_module_id = args.attack_module_id
        prompt = args.prompt
        system_prompt = args.system_prompt if args.system_prompt else ""
        context_strategy = args.context_strategy or []
        prompt_template = [args.prompt_template] if args.prompt_template else []
        metric = [args.metric] if args.metric else []
        num_of_prev_prompts = (
            args.num_of_prev_prompts
            if args.num_of_prev_prompts
            else Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
        )

        if context_strategy:
            context_strategy_info = [
                {
                    "context_strategy_id": context_strategy,
                    "num_of_prev_prompts": num_of_prev_prompts,
                }
            ]
        else:
            context_strategy_info = []

        # form runner arguments
        attack_strategy = [
            {
                "attack_module_id": attack_module_id,
                "prompt": prompt,
                "system_prompt": system_prompt,
                "context_strategy_info": context_strategy_info,
                "prompt_template_ids": prompt_template,
                "metric_ids": metric,
            }
        ]
        runner_args = {"attack_strategies": attack_strategy}

        # load runner, perform red teaming and close the runner
        runner = api_load_runner(active_session["session_id"])
        try:
            loop = asyncio.get_event_loop()
            loop.run_until_complete(runner.run_red_teaming(runner_args))
        finally:
            # Close the runner even when red teaming fails so it is not leaked.
            runner.close()
        _reload_session(active_session["session_id"])
        update_chat_display()
    except Exception as e:
        # str() must be inside the braces; otherwise "str(...)" is printed literally.
        print(f"[run_attack_module]: {str(e)}")
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _reload_session(runner_id: str) -> None:
    """
    Reload the session metadata for a runner ID into the active session.

    If no metadata is found a notice is printed and the active session is left
    untouched. Any exception raised while loading is caught and its message is
    printed.

    Args:
        runner_id (str): The ID of the runner whose session metadata is reloaded.
    """
    global active_session
    try:
        session_metadata = api_load_session(runner_id)
        if not session_metadata:
            print(
                "[Session] Cannot find a session with the existing Runner ID. Please try again."
            )
            return
        active_session.update(session_metadata)
    except Exception as e:
        # str() must be inside the braces; otherwise "str(...)" is printed literally.
        print(f"[reload_session]: {str(e)}")
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def delete_session(args) -> None:
    """
    Delete a session once the user has confirmed the action.

    Only an answer of "y" (in either case) proceeds; any other input cancels
    the deletion. Errors raised by the deletion are caught and their message
    is printed.

    Args:
        args (object): Parsed CLI arguments; ``session`` holds the ID of the
            session to delete.
    """
    answer = console.input(
        "[bold red]Are you sure you want to delete the session (y/N)? [/]"
    )
    if answer.lower() == "y":
        try:
            api_delete_session(args.session)
            print("[delete_session]: Session deleted.")
        except Exception as e:
            print(f"[delete_session]: {str(e)}")
    else:
        console.print("[bold yellow]Session deletion cancelled.[/]")
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
# Parser for the command that resumes an existing session.
use_session_args = cmd2.Cmd2ArgumentParser(
    epilog="Example:\n use_session 'my-runner'",
    description="Use an existing red teaming session by specifying the runner ID.",
)
# Single positional argument: the runner whose session is resumed.
use_session_args.add_argument(
    "runner_id",
    help="The ID of the runner which contains the session you want to use.",
    type=str,
)
|
|
360
|
+
|
|
361
|
+
# Parser for the command that creates a new red teaming session.
new_session_args = cmd2.Cmd2ArgumentParser(
    epilog=(
        "Example(create new runner): new_session my-runner -e \"['openai-gpt4']\" -c add_previous_prompt -p mmlu\n"
        "Example(load existing runner): new_session my-runner -c add_previous_prompt -p mmlu"
    ),
    description="Creates a new red teaming session.",
)

# Positional: the runner to create or load.
new_session_args.add_argument(
    "runner_id",
    help="ID of the runner. Creates a new runner if runner does not exist.",
    type=str,
)

# Optional: endpoints, given as the string form of a list (see the epilog example).
new_session_args.add_argument(
    "-e",
    "--endpoints",
    nargs="?",
    help="List of endpoint(s) for the runner that is only compulsory for creating a new runner.",
    type=str,
)

# Optional: context strategy stored with the session.
new_session_args.add_argument(
    "-c",
    "--context_strategy",
    nargs="?",
    help=(
        "Name of the context_strategy to be used - indicate context strategy here if you wish to use with "
        "the selected attack."
    ),
    type=str,
)

# Optional: prompt template stored with the session.
new_session_args.add_argument(
    "-p",
    "--prompt_template",
    nargs="?",
    help=(
        "Name of the prompt template to be used - indicate prompt template here if you wish to use with "
        "the selected attack."
    ),
    type=str,
)
|
|
404
|
+
|
|
405
|
+
# Parser for the command that runs an attack module in the current session.
automated_rt_session_args = cmd2.Cmd2ArgumentParser(
    epilog=(
        'Example:\n run_attack_module sample_attack_module "this is my prompt" -s "test system prompt" '
        '-c "add_previous_prompt" -p "mmlu" -m "bleuscore"'
    ),
    description="Runs automated red teaming in the current session.",
)

# Positionals: which attack module to run and the prompt it starts from.
automated_rt_session_args.add_argument(
    "attack_module_id",
    help="ID of the attack module.",
    type=str,
)
automated_rt_session_args.add_argument(
    "prompt",
    help="Prompt to be used for the attack.",
    type=str,
)

# Optional: system prompt for the attack.
automated_rt_session_args.add_argument(
    "-s",
    "--system_prompt",
    nargs="?",
    help="System Prompt to be used for the attack. If not specified, the default system prompt will be used.",
    type=str,
)

# Optional: context strategy module.
automated_rt_session_args.add_argument(
    "-c",
    "--context_strategy",
    nargs="?",
    help="Name of the context strategy module to be used.",
    type=str,
)

# Optional: number of previous prompts for the context strategy.
# NOTE(review): parsed as str although it represents a count - confirm whether
# an int is expected downstream.
automated_rt_session_args.add_argument(
    "-n",
    "--num_of_prev_prompts",
    nargs="?",
    help="The number of previous prompts to use with the context strategy.",
    type=str,
)

# Optional: prompt template (argparse exposes the value as args.prompt_template).
automated_rt_session_args.add_argument(
    "-p",
    "--prompt-template",
    nargs="?",
    help="Name of the prompt template to be used.",
    type=str,
)

# Optional: metric module.
automated_rt_session_args.add_argument(
    "-m",
    "--metric",
    nargs="?",
    help="Name of the metric module to be used.",
    type=str,
)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
# Parser for the command that deletes a session.
delete_session_args = cmd2.Cmd2ArgumentParser(
    epilog="Example:\n delete_session my-test-runner",
    description="Delete a session",
)

# Single positional argument: the runner ID identifying the session.
delete_session_args.add_argument(
    "session",
    help="The runner ID of the session to delete",
    type=str,
)
|
|
File without changes
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
import uvicorn
|
|
5
|
+
from dotenv import dotenv_values, load_dotenv
|
|
6
|
+
|
|
7
|
+
from .app import create_app
|
|
8
|
+
from .container import Container
|
|
9
|
+
from .logging_conf import configure_app_logging, create_uvicorn_log_config
|
|
10
|
+
from .types.types import UvicornRunArgs
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def start_app():
    """
    Configure and launch the web API with uvicorn.

    The container configuration is loaded from the YAML file named by
    ``MS_WEB_API_CONFIG`` in the .env file, or from the default configuration
    when that key is absent. The port and host are read from ``HOST_PORT`` and
    ``HOST_ADDRESS`` in the .env file, defaulting to 5000 and 127.0.0.1. SSL is
    passed to uvicorn only when it is enabled in the configuration, a
    certificate directory is set, and both the configured key file and the
    configured certificate file exist in that directory.
    """
    load_dotenv()
    container: Container = Container()
    # Parse the .env file once and reuse the values for every lookup below.
    env_values = dotenv_values()

    # use our own config.yml
    config_file = env_values.get("MS_WEB_API_CONFIG")
    if config_file is None:
        container.config.from_default()
    else:
        container.config.from_yaml(f"{config_file}", required=True)

    configure_app_logging(container.config)
    logging.info(f"Environment: {container.config.app_environment()}")
    ENABLE_SSL = container.config.ssl.enabled()
    SSL_CERT_PATH = container.config.ssl.file_path()
    app = create_app(container.config)

    run_kwargs: UvicornRunArgs = {}
    port = env_values.get("HOST_PORT", 5000)
    if port is not None:
        port = int(port)
        run_kwargs["port"] = port
    run_kwargs["host"] = env_values.get("HOST_ADDRESS", "127.0.0.1")
    run_kwargs["log_config"] = create_uvicorn_log_config(container.config)
    if ENABLE_SSL:
        if not SSL_CERT_PATH:
            logging.debug("SSL_CERT_PATH not set, not enabling SSL")
        else:
            # Check the same files that are handed to uvicorn: the configured
            # key/cert filenames, not hard-coded "key.pem"/"cert.pem".
            ssl_keyfile = str(
                os.path.join(SSL_CERT_PATH, str(container.config.ssl.key_filename()))
            )
            ssl_certfile = str(
                os.path.join(SSL_CERT_PATH, str(container.config.ssl.cert_filename()))
            )
            if os.path.exists(ssl_keyfile) and os.path.exists(ssl_certfile):
                run_kwargs["ssl_keyfile"] = ssl_keyfile
                run_kwargs["ssl_certfile"] = ssl_certfile
            else:
                logging.debug(
                    "SSL_CERT_PATH does not contain necessary files, not enabling SSL"
                )
    uvicorn.run(app, **run_kwargs)  # type: ignore
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
# Launch the web API only when this module is executed as a script.
if __name__ == "__main__":
    start_app()
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from contextlib import asynccontextmanager
|
|
4
|
+
from typing import Awaitable, Callable
|
|
5
|
+
|
|
6
|
+
from dependency_injector.wiring import providers
|
|
7
|
+
from fastapi import FastAPI, Request, Response
|
|
8
|
+
from fastapi.exceptions import RequestValidationError
|
|
9
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
10
|
+
from fastapi.responses import JSONResponse
|
|
11
|
+
|
|
12
|
+
from .container import Container
|
|
13
|
+
from .routes import (
|
|
14
|
+
attack_modules,
|
|
15
|
+
benchmark,
|
|
16
|
+
benchmark_result,
|
|
17
|
+
cookbook,
|
|
18
|
+
dataset,
|
|
19
|
+
endpoint,
|
|
20
|
+
metric,
|
|
21
|
+
prompt_template,
|
|
22
|
+
context_strategy,
|
|
23
|
+
recipe,
|
|
24
|
+
runner,
|
|
25
|
+
)
|
|
26
|
+
from .routes.redteam import router as red_team_router
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class CustomFastAPI(FastAPI):
    """FastAPI subclass that declares a ``container`` attribute."""

    # Declared for type checkers only; no value is assigned here.
    container: Container
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
async def monitor_tasks(loop: asyncio.AbstractEventLoop):
    """
    Log every task of the given event loop at debug level, once per second.

    The coroutine never returns on its own.

    Args:
        loop (asyncio.AbstractEventLoop): The event loop whose tasks are logged.
    """
    while True:
        for pending_task in asyncio.all_tasks(loop):
            logger.debug(pending_task)
        await asyncio.sleep(1)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Run the task monitor for the lifetime of the application.

    The monitor task is scheduled on the running loop before the application
    starts serving, and is cancelled when the context exits.

    Args:
        app (FastAPI): The application this lifespan is attached to (unused).
    """
    loop = asyncio.get_running_loop()
    # Keep a reference: the event loop only holds weak references to tasks, so
    # an unreferenced task may be garbage-collected while still running.
    monitor = loop.create_task(monitor_tasks(loop))
    try:
        yield
    finally:
        # monitor_tasks loops forever, so stop it explicitly on shutdown.
        monitor.cancel()
        # Wait for the cancellation to finish without re-raising its outcome.
        await asyncio.gather(monitor, return_exceptions=True)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
async def log_request_origin(
    request: Request, call_next: Callable[[Request], Awaitable[Response]]
):
    """
    HTTP middleware that logs the ``origin`` header of each request.

    Args:
        request (Request): The incoming request.
        call_next: Callable that passes the request on and yields the response.

    Returns:
        The response produced by ``call_next``, unmodified.
    """
    logger.info(f"Request origin: {request.headers.get('origin')}")
    return await call_next(request)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def create_app(cfg: providers.Configuration) -> CustomFastAPI:
    """
    Build the FastAPI application from the given configuration.

    Depending on the configuration this attaches the task-monitoring lifespan,
    the CORS middleware and (in DEV/DEVELOPMENT/LOCAL environments) the
    request-origin logging middleware. All routers are then registered along
    with a handler that answers request validation errors with HTTP 422.

    Args:
        cfg (providers.Configuration): The application configuration provider.

    Returns:
        CustomFastAPI: The configured application.
    """
    app_kwargs = {}
    if cfg.asyncio.monitor_task():
        # Logger.warn is a deprecated alias; use warning().
        logger.warning("Monitoring tasks in uvicorn's asyncio event loop")
        app_kwargs["lifespan"] = lifespan

    app_kwargs["swagger_ui_parameters"] = {
        "defaultModelsExpandDepth": -1,
        "docExpansion": None,
    }

    app: CustomFastAPI = CustomFastAPI(
        title="Project Moonshot",
        version="0.4.0",
        **app_kwargs,
    )

    if cfg.cors.enabled():
        logger.info("CORS is enabled")
        allowed_origins_raw: str = cfg.cors.allowed_origins()
        # Origins are compared as exact strings, so strip the whitespace around
        # each comma-separated entry and drop empty entries.
        allowed_origins = (
            [
                origin.strip()
                for origin in allowed_origins_raw.split(",")
                if origin.strip()
            ]
            if allowed_origins_raw
            else []
        )
        app.add_middleware(
            CORSMiddleware,
            allow_origins=allowed_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    else:
        logger.warning("CORS is disabled")

    if cfg.app_environment().upper() in ["DEV", "DEVELOPMENT", "LOCAL"]:
        app.middleware("http")(log_request_origin)

    app.include_router(red_team_router)
    app.include_router(prompt_template.router)
    app.include_router(context_strategy.router)
    app.include_router(benchmark.router)
    app.include_router(endpoint.router)
    app.include_router(recipe.router)
    app.include_router(cookbook.router)
    app.include_router(benchmark_result.router)
    app.include_router(metric.router)
    app.include_router(runner.router)
    app.include_router(dataset.router)
    app.include_router(attack_modules.router)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(
        request: Request, exc: RequestValidationError
    ) -> JSONResponse:
        # Build copies of the error details without the 'url' key and use those
        # copies for both the log and the response, rather than mutating the
        # exception's own error list.
        modified_errors: list[dict] = [
            {key: value for key, value in error.items() if key != "url"}
            for error in exc.errors()
        ]

        logger.error(f"Validation error for request {request.url}: {modified_errors}")
        return JSONResponse(
            status_code=422,
            content={"error": modified_errors},
        )

    return app
|