aiverify-moonshot 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/METADATA +2 -2
  2. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/RECORD +70 -56
  3. moonshot/__main__.py +77 -35
  4. moonshot/api.py +16 -0
  5. moonshot/integrations/cli/benchmark/benchmark.py +29 -13
  6. moonshot/integrations/cli/benchmark/cookbook.py +62 -24
  7. moonshot/integrations/cli/benchmark/datasets.py +79 -40
  8. moonshot/integrations/cli/benchmark/metrics.py +62 -23
  9. moonshot/integrations/cli/benchmark/recipe.py +89 -69
  10. moonshot/integrations/cli/benchmark/result.py +85 -47
  11. moonshot/integrations/cli/benchmark/run.py +99 -59
  12. moonshot/integrations/cli/common/common.py +20 -6
  13. moonshot/integrations/cli/common/connectors.py +154 -74
  14. moonshot/integrations/cli/common/dataset.py +66 -0
  15. moonshot/integrations/cli/common/prompt_template.py +57 -19
  16. moonshot/integrations/cli/redteam/attack_module.py +90 -24
  17. moonshot/integrations/cli/redteam/context_strategy.py +83 -23
  18. moonshot/integrations/cli/redteam/prompt_template.py +1 -1
  19. moonshot/integrations/cli/redteam/redteam.py +52 -6
  20. moonshot/integrations/cli/redteam/session.py +565 -44
  21. moonshot/integrations/cli/utils/process_data.py +52 -0
  22. moonshot/integrations/web_api/__main__.py +2 -0
  23. moonshot/integrations/web_api/app.py +6 -6
  24. moonshot/integrations/web_api/container.py +12 -2
  25. moonshot/integrations/web_api/routes/bookmark.py +173 -0
  26. moonshot/integrations/web_api/routes/dataset.py +46 -1
  27. moonshot/integrations/web_api/schemas/bookmark_create_dto.py +13 -0
  28. moonshot/integrations/web_api/schemas/dataset_create_dto.py +18 -0
  29. moonshot/integrations/web_api/schemas/recipe_create_dto.py +0 -2
  30. moonshot/integrations/web_api/services/bookmark_service.py +94 -0
  31. moonshot/integrations/web_api/services/dataset_service.py +25 -0
  32. moonshot/integrations/web_api/services/recipe_service.py +0 -1
  33. moonshot/integrations/web_api/services/utils/file_manager.py +52 -0
  34. moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +0 -1
  35. moonshot/integrations/web_api/temp/.gitkeep +0 -0
  36. moonshot/src/api/api_bookmark.py +95 -0
  37. moonshot/src/api/api_connector_endpoint.py +1 -1
  38. moonshot/src/api/api_context_strategy.py +2 -2
  39. moonshot/src/api/api_dataset.py +35 -0
  40. moonshot/src/api/api_recipe.py +0 -3
  41. moonshot/src/api/api_session.py +1 -1
  42. moonshot/src/bookmark/bookmark.py +257 -0
  43. moonshot/src/bookmark/bookmark_arguments.py +38 -0
  44. moonshot/src/configs/env_variables.py +12 -2
  45. moonshot/src/connectors/connector.py +15 -7
  46. moonshot/src/connectors_endpoints/connector_endpoint.py +65 -49
  47. moonshot/src/cookbooks/cookbook.py +57 -37
  48. moonshot/src/datasets/dataset.py +125 -5
  49. moonshot/src/metrics/metric.py +8 -4
  50. moonshot/src/metrics/metric_interface.py +8 -2
  51. moonshot/src/prompt_templates/prompt_template.py +5 -1
  52. moonshot/src/recipes/recipe.py +38 -40
  53. moonshot/src/recipes/recipe_arguments.py +0 -4
  54. moonshot/src/redteaming/attack/attack_module.py +18 -8
  55. moonshot/src/redteaming/attack/context_strategy.py +6 -2
  56. moonshot/src/redteaming/session/session.py +15 -11
  57. moonshot/src/results/result.py +7 -3
  58. moonshot/src/runners/runner.py +65 -42
  59. moonshot/src/runs/run.py +15 -11
  60. moonshot/src/runs/run_progress.py +7 -3
  61. moonshot/src/storage/db_interface.py +14 -0
  62. moonshot/src/storage/storage.py +33 -2
  63. moonshot/src/utils/find_feature.py +45 -0
  64. moonshot/src/utils/log.py +72 -0
  65. moonshot/src/utils/pagination.py +25 -0
  66. moonshot/src/utils/timeit.py +8 -1
  67. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/WHEEL +0 -0
  68. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/AUTHORS.md +0 -0
  69. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/LICENSE.md +0 -0
  70. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/NOTICES.md +0 -0
@@ -10,13 +10,19 @@ from rich.table import Table
10
10
  from moonshot.api import (
11
11
  api_create_runner,
12
12
  api_create_session,
13
+ api_delete_bookmark,
13
14
  api_delete_session,
15
+ api_export_bookmarks,
16
+ api_get_all_bookmarks,
14
17
  api_get_all_chats_from_session,
15
18
  api_get_all_session_metadata,
19
+ api_get_bookmark,
20
+ api_insert_bookmark,
16
21
  api_load_runner,
17
22
  api_load_session,
18
23
  )
19
24
  from moonshot.integrations.cli.active_session_cfg import active_session
25
+ from moonshot.integrations.cli.utils.process_data import filter_data
20
26
  from moonshot.src.redteaming.session.session import Session
21
27
 
22
28
  # Module-wide rich Console used to print the tables and messages produced by these CLI commands.
  console = Console()
@@ -24,11 +30,17 @@ console = Console()
24
30
 
25
31
  def new_session(args) -> None:
26
32
  """
27
- Creates a new session based on the provided arguments.
33
+ Creates a new red teaming session or loads an existing one.
34
+
35
+ This function either creates a new runner and session or loads an existing runner based on the provided arguments.
36
+ It updates the global active_session with the session metadata and displays the chat history.
28
37
 
29
38
  Args:
30
- args (Namespace): The arguments passed to the function.
31
- """
39
+ args (Namespace): The arguments passed to the function, containing:
40
+ - runner_id (str): The ID of the runner.
41
+ - context_strategy (str, optional): The context strategy to be used.
42
+ - prompt_template (str, optional): The prompt template to be used.
43
+ - endpoints (str, optional): The list of endpoints for the runner."""
32
44
  global active_session
33
45
 
34
46
  runner_id = args.runner_id
@@ -52,7 +64,7 @@ def new_session(args) -> None:
52
64
  api_create_session(
53
65
  runner.id, runner.database_instance, runner.endpoints, runner_args
54
66
  )
55
- session_metadata = api_load_session(runner_id)
67
+ session_metadata = api_load_session(runner.id)
56
68
  if session_metadata:
57
69
  active_session.update(session_metadata)
58
70
  if active_session["context_strategy"]:
@@ -96,6 +108,19 @@ def use_session(args) -> None:
96
108
  print(f"[use_session]: {str(e)}")
97
109
 
98
110
 
111
def show_prompts() -> None:
    """
    Re-render the chat table of the active session.

    Lets users view the chat table without restarting the session. When no session is
    active, a notice is printed and nothing is rendered; otherwise rendering is delegated
    to update_chat_display().
    """
    if active_session:
        update_chat_display()
        return
    print("There is no active session. Activate a session to show a chat table.")
122
+
123
+
99
124
  def end_session() -> None:
100
125
  """
101
126
  Ends the current session by clearing active_session variable.
@@ -104,33 +129,39 @@ def end_session() -> None:
104
129
  active_session.clear()
105
130
 
106
131
 
107
def list_sessions(args) -> list | None:
    """
    Retrieve, filter and display the red teaming sessions.

    All session metadata is fetched via api_get_all_session_metadata, narrowed with
    filter_data (lower-cased keyword plus optional pagination tuple) and rendered with
    _display_sessions. A notice is printed when nothing is left to show.

    Args:
        args: A namespace object from argparse with optional attributes:
            find (str): Keyword used to find session(s).
            pagination (str): Tuple literal such as "(2, 10)" used to paginate sessions.

    Returns:
        list | None: The sessions that were displayed, or None when there were none or an
        error occurred (errors are printed, not raised).
    """
    try:
        all_sessions = api_get_all_session_metadata()
        search_term = "" if not args.find else args.find.lower()
        page_spec = () if not args.pagination else literal_eval(args.pagination)

        matches = (
            filter_data(all_sessions, search_term, page_spec) if all_sessions else None
        )
        if matches:
            _display_sessions(matches)
            return matches

        console.print("[red]There are no sessions found.[/red]")
        return None
    except Exception as e:
        print(f"[list_sessions]: {str(e)}")
134
165
 
135
166
 
136
167
  def update_chat_display() -> None:
@@ -151,16 +182,20 @@ def update_chat_display() -> None:
151
182
  # Prepare for table display
152
183
  table = Table(expand=True, show_lines=True, header_style="bold")
153
184
  table_list = []
185
+ active_session["list_of_endpoint_chats"] = list_of_endpoint_chats
186
+
154
187
  for endpoint, endpoint_chats in list_of_endpoint_chats.items():
155
188
  table.add_column(endpoint, justify="center")
156
189
  new_table = Table(expand=True)
190
+ new_table.add_column("ID", justify="left", ratio=1, min_width=5)
157
191
  new_table.add_column(
158
- "Prepared Prompts", justify="left", style="cyan", width=50
192
+ "Prepared Prompts", justify="left", style="cyan", ratio=7
159
193
  )
160
- new_table.add_column("Prompt/Response", justify="left", width=50)
194
+ new_table.add_column("Prompt/Response", justify="left", ratio=7)
161
195
 
162
196
  for chat_with_details in endpoint_chats:
163
197
  new_table.add_row(
198
+ str(chat_with_details["chat_record_id"]),
164
199
  chat_with_details["prepared_prompt"],
165
200
  (
166
201
  f"[magenta]{chat_with_details['prompt']}[/magenta] \n"
@@ -184,6 +219,308 @@ def update_chat_display() -> None:
184
219
  console.print("[red]There are no active session.[/red]")
185
220
 
186
221
 
222
def add_bookmark(args) -> None:
    """
    Bookmark a specific prompt in the active session.

    The chat record is looked up in active_session["list_of_endpoint_chats"], which is
    stored when the chat table is rendered by update_chat_display(). If a record with the
    given endpoint and prompt ID is found, a bookmark is inserted with the given name and
    the record's details; otherwise an error message is printed.

    Args:
        args (Namespace): The arguments passed to the function, containing:
            - endpoint (str): The endpoint to which the prompt was sent.
            - prompt_id (int): The ID of the prompt (the leftmost column of the chat table).
            - bookmark_name (str): The name of the bookmark to be created.

    If there is no active session, a message is printed to the console and the function returns.
    """
    if not active_session:
        print("There is no active session. Activate a session to bookmark a prompt.")
        return

    try:
        endpoint = args.endpoint
        prompt_id = args.prompt_id
        bookmark_name = args.bookmark_name

        # The chat cache is absent until the chat table has been rendered; fall back to an
        # empty mapping so the user gets the "incorrect endpoint" message instead of an
        # AttributeError on None.
        chats_by_endpoint = active_session.get("list_of_endpoint_chats") or {}
        target_endpoint_chats = chats_by_endpoint.get(endpoint)
        if not target_endpoint_chats:
            print("Incorrect endpoint. Please select a valid endpoint in this session.")
            return

        # First chat record whose ID matches the requested prompt ID, if any.
        target_endpoint_chat_record = next(
            (
                endpoint_chat
                for endpoint_chat in target_endpoint_chats
                if endpoint_chat["chat_record_id"] == prompt_id
            ),
            None,
        )
        if not target_endpoint_chat_record:
            print(
                f"Unable to find prompt ID in the list of prompts for endpoint {endpoint}. "
                "Please select a valid ID."
            )
            return

        bookmark_message = api_insert_bookmark(
            bookmark_name,
            target_endpoint_chat_record["prompt"],
            target_endpoint_chat_record["prepared_prompt"],
            target_endpoint_chat_record["predicted_result"],
            target_endpoint_chat_record["context_strategy"],
            target_endpoint_chat_record["prompt_template"],
            target_endpoint_chat_record["attack_module"],
            target_endpoint_chat_record["metric"],
        )
        print("[bookmark_prompt]:", bookmark_message["message"])
    except Exception as e:
        print(f"[bookmark_prompt]: ({str(e)})")
283
+
284
+
285
def use_bookmark(args) -> None:
    """
    Print a bookmarked prompt so it can be re-run in the active session.

    The bookmark is fetched by name with api_get_bookmark. If it records an attack module, a
    ready-made ``run_attack_module`` CLI command (attack module plus the bookmarked prepared
    prompt) is printed for the user to copy and paste (automated red teaming). Otherwise the
    bookmarked prepared prompt itself is printed for manual red teaming. The active session
    itself is not modified.

    Args:
        args (Namespace): The arguments passed to the function, containing:
            - bookmark_name (str): The name of the bookmark to use.

    If there is no active session, a message is printed to the console and the function
    returns. If the bookmark lookup returns nothing, a "not found" message is printed.
    """
    if not active_session:
        print("There is no active session. Activate a session to use a bookmark.")
        return

    try:
        bookmark_name = args.bookmark_name
        bookmark_details = api_get_bookmark(bookmark_name)
        if not bookmark_details:
            print(f"[use_bookmark]: Bookmark '{bookmark_name}' not found.")
            return

        bookmarked_prompt = bookmark_details["prepared_prompt"]
        attack_module = bookmark_details["attack_module"]
        if attack_module:
            # Automated red teaming: craft a CLI command for the user to copy and paste.
            # NOTE(review): the prompt is wrapped in double quotes verbatim, so a prompt that
            # itself contains double quotes may not parse back as a single argument.
            run_attack_module_cmd = (
                f'run_attack_module {attack_module} "{bookmarked_prompt}"'
            )
            console.print(
                f"[bold yellow]Copy this command and paste it below:[/]\n{run_attack_module_cmd}\n"
            )
        else:
            # Manual red teaming: show the prompt for the user to copy and paste.
            console.print(
                f"[bold yellow]Copy this prompt and paste it below: [/]\n{bookmarked_prompt}\n"
            )
    except Exception as e:
        print(f"[use_bookmark]: {str(e)}")
328
+
329
+
330
def delete_bookmark(args) -> None:
    """
    Delete a bookmark.

    Asks the user for confirmation first; any answer other than "y" (case-insensitive)
    cancels the deletion. On confirmation it calls api_delete_bookmark from the moonshot.api
    module with the given bookmark name and prints the message returned by that call.

    If an exception occurs, it prints an error message.

    Args:
        args: A namespace object from argparse. It should have the following attribute:
            bookmark_name (str): The name of the bookmark to delete.

    Returns:
        None
    """
    # Confirm with the user before deleting a bookmark
    confirmation = console.input(
        "[bold red]Are you sure you want to delete the bookmark (y/N)? [/]"
    )
    if confirmation.lower() != "y":
        console.print("[bold yellow]Bookmark deletion cancelled.[/]")
        return
    try:
        bookmark_message = api_delete_bookmark(args.bookmark_name)
        print("[delete_bookmark]:", bookmark_message["message"])
    except Exception as e:
        print(f"[delete_bookmark]: {str(e)}")
359
+
360
+
361
def list_bookmarks(args) -> list | None:
    """
    List the stored bookmarks, optionally filtered and paginated.

    Fetches every bookmark via api_get_all_bookmarks, narrows them with filter_data using the
    lower-cased ``find`` keyword and the ``pagination`` tuple, and renders the result with
    _display_bookmarks. A notice is printed when nothing is left to show.

    Args:
        args: A namespace object from argparse with optional attributes:
            find (str): Keyword used to find bookmark(s).
            pagination (str): Tuple literal such as "(2, 10)" used to paginate bookmarks.

    Returns:
        list | None: The bookmarks that were displayed, or None when there were none or an
        error occurred (errors are printed, not raised).
    """
    try:
        all_bookmarks = api_get_all_bookmarks()
        search_term = args.find.lower() if args.find else ""
        page_spec = literal_eval(args.pagination) if args.pagination else ()

        shown = (
            filter_data(all_bookmarks, search_term, page_spec) if all_bookmarks else None
        )
        if not shown:
            console.print("[red]There are no bookmarks found.[/red]")
            return None

        _display_bookmarks(shown)
        return shown
    except Exception as e:
        print(f"[list_bookmarks]: {str(e)}")
394
+
395
+
396
def _display_bookmarks(bookmarks_list) -> None:
    """
    Render a list of bookmarks as a rich table.

    Each row shows the bookmark's index (the entry's "idx" value when present, otherwise its
    1-based position in the list), name, prepared prompt, predicted response and bookmark time.
    Fields are read positionally from each dict's values in the order: name, prompt, prepared
    prompt, response, context strategy, prompt template, attack module, metric, bookmark time;
    any extra trailing values are ignored.

    Args:
        bookmarks_list (list): A list of dictionaries, each holding the details of a bookmark.
    """
    table = Table(
        title="Bookmark List", show_lines=True, expand=True, header_style="bold"
    )
    column_specs = (
        ("ID.", 5),
        ("Name", 20),
        ("Prepared Prompt", 50),
        ("Predicted Response", 50),
        ("Bookmark Time", 20),
    )
    for header, col_width in column_specs:
        table.add_column(header, justify="left", width=col_width)

    for position, entry in enumerate(bookmarks_list, start=1):
        (
            name,
            _prompt,
            prepared_prompt,
            response,
            _context_strategy,
            _prompt_template,
            _attack_module,
            _metric,
            bookmark_time,
            *_rest,
        ) = entry.values()
        row_id = entry.get("idx", position)

        table.add_section()
        table.add_row(str(row_id), name, prepared_prompt, response, bookmark_time)
    console.print(table)
440
+
441
+
442
def view_bookmark(args) -> None:
    """
    Show the full details of a single bookmark.

    Args:
        args (Namespace): Parsed arguments; ``args.bookmark_name`` names the bookmark to show.
    """
    try:
        _display_bookmark(api_get_bookmark(args.bookmark_name))
    except Exception as err:
        print(f"[view_bookmark]: {str(err)}")
455
+
456
+
457
def _display_bookmark(bookmark_info: dict) -> None:
    """
    Display a single bookmark's details in a table.

    The table includes the bookmark name, prompt, prepared prompt, response, context strategy,
    prompt template, attack module, metric and bookmark time. These are read positionally from
    ``bookmark_info.values()`` in that order; any additional trailing values are ignored. If
    ``bookmark_info`` is empty (or None), a "no bookmarks found" message is printed instead.

    Args:
        bookmark_info (dict): A dictionary which contains the details of a bookmark.
    """
    if not bookmark_info:
        console.print("[red]There are no bookmarks found.[/red]")
        return

    table = Table(
        title="Bookmark List", show_lines=True, expand=True, header_style="bold"
    )
    table.add_column("Name", justify="left", width=10)
    table.add_column("Prompt", justify="left", width=30)
    table.add_column("Prepared Prompt", justify="left", width=30)
    table.add_column("Predicted Response", justify="left", width=30)
    table.add_column("Context Strategy", justify="left", width=5)
    table.add_column("Prompt Template", justify="left", width=5)
    table.add_column("Attack Module", justify="left", width=5)
    table.add_column("Metric", justify="left", width=5)
    table.add_column("Bookmark Time", justify="left", width=5)

    # Absorb any extra trailing fields (as _display_bookmarks does) instead of failing with
    # "too many values to unpack" when the bookmark dict carries more than nine values.
    (
        name,
        prompt,
        prepared_prompt,
        response,
        context_strategy,
        prompt_template,
        attack_module,
        metric,
        bookmark_time,
        *_other_fields,
    ) = bookmark_info.values()
    table.add_section()
    table.add_row(
        name,
        prompt,
        prepared_prompt,
        response,
        context_strategy,
        prompt_template,
        attack_module,
        metric,
        bookmark_time,
    )
    console.print(table)
508
+
509
+
510
def export_bookmarks(args) -> None:
    """
    Export every bookmark to a JSON file.

    Args:
        args (Namespace): Parsed arguments; ``args.bookmark_list_name`` is passed to
            api_export_bookmarks as the export file name.
    """
    try:
        exported_path = api_export_bookmarks(export_file_name=args.bookmark_list_name)
        success_note = f"Bookmarks exported successfully. Written to: {exported_path}"
        print(success_note)
    except Exception as err:
        print(f"[export_bookmarks]: {str(err)}")
522
+
523
+
187
524
  def manual_red_teaming(user_prompt: str) -> None:
188
525
  """
189
526
  Initiates manual red teaming with the provided user prompt.
@@ -199,6 +536,7 @@ def manual_red_teaming(user_prompt: str) -> None:
199
536
  if not active_session:
200
537
  print("There is no active session. Activate a session to start red teaming.")
201
538
  return
539
+
202
540
  prompt_template = (
203
541
  [active_session["prompt_template"]] if active_session["prompt_template"] else []
204
542
  )
@@ -237,7 +575,7 @@ def manual_red_teaming(user_prompt: str) -> None:
237
575
  runner.close()
238
576
  _reload_session(active_session["session_id"])
239
577
  except Exception as e:
240
- print(f"[manual_red_teaming]: str({e})")
578
+ print(f"[manual_red_teaming]: ({str(e)})")
241
579
 
242
580
 
243
581
  def run_attack_module(args):
@@ -260,18 +598,34 @@ def run_attack_module(args):
260
598
  attack_module_id = args.attack_module_id
261
599
  prompt = args.prompt
262
600
  system_prompt = args.system_prompt if args.system_prompt else ""
263
- context_strategy = args.context_strategy or []
264
- prompt_template = [args.prompt_template] if args.prompt_template else []
265
- metric = [args.metric] if args.metric else []
601
+ # context strategy and prompt template should come from the session instead of the command
602
+
603
+ if args.prompt_template:
604
+ prompt_template = [args.prompt_template]
605
+ elif active_session["prompt_template"]:
606
+ prompt_template = [active_session["prompt_template"]]
607
+ else:
608
+ prompt_template = []
609
+
610
+ if args.context_strategy:
611
+ context_strategy = args.context_strategy
612
+ num_of_prev_prompts = (
613
+ args.cs_num_of_prev_prompts
614
+ if args.cs_num_of_prev_prompts
615
+ else Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
616
+ )
617
+ elif active_session["context_strategy"]:
618
+ context_strategy = active_session["context_strategy"]
619
+ num_of_prev_prompts = active_session["cs_num_of_prev_prompts"]
620
+ else:
621
+ context_strategy = []
622
+ num_of_prev_prompts = Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
623
+
266
624
  optional_arguments = (
267
625
  literal_eval(args.optional_args) if args.optional_args else {}
268
626
  )
269
- num_of_prev_prompts = (
270
- args.num_of_prev_prompts
271
- if args.num_of_prev_prompts
272
- else Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
273
- )
274
627
 
628
+ metric = [args.metric] if args.metric else []
275
629
  if context_strategy:
276
630
  context_strategy_info = [
277
631
  {
@@ -294,6 +648,7 @@ def run_attack_module(args):
294
648
  "optional_params": optional_arguments,
295
649
  }
296
650
  ]
651
+
297
652
  runner_args = {}
298
653
  runner_args["attack_strategies"] = attack_strategy
299
654
 
@@ -351,6 +706,48 @@ def delete_session(args) -> None:
351
706
  print(f"[delete_session]: {str(e)}")
352
707
 
353
708
 
709
def _display_sessions(sessions: list) -> None:
    """
    Display a list of sessions in a table.

    Each row shows the session's index (its "idx" value when present, otherwise its 1-based
    position in ``sessions``), its ID and creation time, and the endpoints it contains
    (comma-separated).

    Args:
        sessions (list): Session metadata dictionaries, as returned by
            api_get_all_session_metadata (optionally filtered/paginated).

    Returns:
        None
    """
    table = Table(
        title="Session List", show_lines=True, expand=True, header_style="bold"
    )
    table.add_column("No.", justify="left", width=2)
    table.add_column("Session ID", justify="left", width=20)
    table.add_column("Contains", justify="left", width=78)

    for idx, session_data in enumerate(sessions, 1):
        # Read fields by key instead of unpacking dict values positionally, which depends on
        # key order and fails when the metadata has fewer fields than expected.
        idx = session_data.get("idx", idx)
        session_id = session_data.get("session_id", "")
        # Join the endpoint names so they are not rendered as a Python list repr.
        endpoints = ", ".join(session_data.get("endpoints", []))
        created_datetime = session_data.get("created_datetime", "")

        session_info = f"[red]id: {session_id}[/red]\n\nCreated: {created_datetime}"
        contains_info = f"[blue]Endpoints:[/blue] {endpoints}\n\n"
        table.add_row(str(idx), session_info, contains_info)
    console.print(table)
749
+
750
+
354
751
  # use session arguments
355
752
  use_session_args = cmd2.Cmd2ArgumentParser(
356
753
  description="Use an existing red teaming session by specifying the runner ID.",
@@ -410,9 +807,7 @@ new_session_args.add_argument(
410
807
  automated_rt_session_args = cmd2.Cmd2ArgumentParser(
411
808
  description="Runs automated red teaming in the current session.",
412
809
  epilog=(
413
- 'Example:\n run_attack_module sample_attack_module "this is my prompt" -s "test system prompt" '
414
- '-c "add_previous_prompt" -p "mmlu" -m "bleuscore" '
415
- "-o \"{'max_number_of_iteration': 1, 'my_optional_param': 'hello world'}\""
810
+ 'Example:\n run_attack_module sample_attack_module "this is my prompt" -s "test system prompt" -m bleuscore'
416
811
  ),
417
812
  )
418
813
 
@@ -436,23 +831,30 @@ automated_rt_session_args.add_argument(
436
831
  "-c",
437
832
  "--context_strategy",
438
833
  type=str,
439
- help="Name of the context strategy module to be used.",
834
+ help=(
835
+ "Name of the context strategy module to be used. If this is set, it will overwrite the context strategy"
836
+ " set in the session while running this attack module."
837
+ ),
440
838
  nargs="?",
441
839
  )
442
840
 
443
841
  automated_rt_session_args.add_argument(
444
842
  "-n",
445
- "--num_of_prev_prompts",
843
+ "--cs_num_of_prev_prompts",
446
844
  type=str,
447
- help="The number of previous prompts to use with the context strategy.",
845
+ help=(
846
+ "The number of previous prompts to use with the context strategy. If this is set, it will overwrite the"
847
+ " number of previous promtps set in the session while running this attack module."
848
+ ),
448
849
  nargs="?",
449
850
  )
450
851
 
451
852
  automated_rt_session_args.add_argument(
452
853
  "-p",
453
- "--prompt-template",
854
+ "--prompt_template",
454
855
  type=str,
455
- help="Name of the prompt template to be used.",
856
+ help="Name of the prompt template to be used. If this is set, it will overwrite the prompt template set in"
857
+ " the session while running this attack module.",
456
858
  nargs="?",
457
859
  )
458
860
 
@@ -477,3 +879,122 @@ delete_session_args = cmd2.Cmd2ArgumentParser(
477
879
  delete_session_args.add_argument(
478
880
  "session", type=str, help="The runner ID of the session to delete"
479
881
  )
882
+
883
+
884
# List sessions arguments: optional keyword search and pagination tuple.
list_sessions_args = cmd2.Cmd2ArgumentParser(
    description="List all sessions.",
    epilog='Example:\n list_sessions -f "my-sessions"',
)
list_sessions_args.add_argument(
    "-f",
    "--find",
    nargs="?",
    type=str,
    help="Optional field to find session(s) with keyword",
)
list_sessions_args.add_argument(
    "-p",
    "--pagination",
    nargs="?",
    type=str,
    help="Optional tuple to paginate session(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
)

# Add bookmark arguments: endpoint, prompt ID (int) and bookmark name, all positional.
add_bookmark_args = cmd2.Cmd2ArgumentParser(
    description="Bookmark a prompt",
    epilog="Example:\n add_bookmark openai-connector 2 my-bookmarked-prompt",
)
add_bookmark_args.add_argument(
    "endpoint", type=str, help="Endpoint which the prompt was sent to."
)
add_bookmark_args.add_argument(
    "prompt_id", type=int, help="ID of the prompt (the leftmost column)"
)
add_bookmark_args.add_argument("bookmark_name", type=str, help="Name of the bookmark")

# Use bookmark arguments
use_bookmark_args = cmd2.Cmd2ArgumentParser(
    description="Use a bookmarked prompt",
    epilog="Example:\n use_bookmark my_bookmark",
)
use_bookmark_args.add_argument("bookmark_name", type=str, help="Name of the bookmark")

# Delete bookmark arguments
delete_bookmark_args = cmd2.Cmd2ArgumentParser(
    description="Delete a bookmark",
    epilog="Example:\n delete_bookmark my_bookmarked_prompt",
)
delete_bookmark_args.add_argument(
    "bookmark_name", type=str, help="Name of the bookmark"
)

# View bookmark arguments
view_bookmark_args = cmd2.Cmd2ArgumentParser(
    description="View a bookmark",
    epilog="Example:\n view_bookmark my_bookmarked_prompt",
)
view_bookmark_args.add_argument(
    "bookmark_name", type=str, help="Name of the bookmark you want to view"
)

# Export bookmarks arguments: target file name without the .json extension.
export_bookmarks_args = cmd2.Cmd2ArgumentParser(
    description="Exports bookmarks as a JSON file",
    epilog='Example:\n export_bookmarks "my_list_of_exported_bookmarks"',
)
export_bookmarks_args.add_argument(
    "bookmark_list_name",
    type=str,
    help="Name of the exported bookmarks JSON file you want to save as (without the .json extension)",
)

# List bookmarks arguments: optional keyword search and pagination tuple.
list_bookmarks_args = cmd2.Cmd2ArgumentParser(
    description="List all bookmarks.",
    epilog="Example:\n list_bookmarks -f my_bookmark",
)
list_bookmarks_args.add_argument(
    "-f",
    "--find",
    nargs="?",
    type=str,
    help="Optional field to find bookmark(s) with keyword",
)
list_bookmarks_args.add_argument(
    "-p",
    "--pagination",
    nargs="?",
    type=str,
    help="Optional tuple to paginate bookmark(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
)