aiverify-moonshot 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35):
  1. {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/METADATA +12 -10
  2. {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/RECORD +35 -31
  3. moonshot/api.py +2 -0
  4. moonshot/integrations/cli/benchmark/cookbook.py +36 -28
  5. moonshot/integrations/cli/benchmark/datasets.py +56 -47
  6. moonshot/integrations/cli/benchmark/metrics.py +39 -30
  7. moonshot/integrations/cli/benchmark/recipe.py +63 -73
  8. moonshot/integrations/cli/benchmark/result.py +62 -54
  9. moonshot/integrations/cli/benchmark/run.py +75 -66
  10. moonshot/integrations/cli/common/common.py +8 -0
  11. moonshot/integrations/cli/common/connectors.py +101 -85
  12. moonshot/integrations/cli/common/dataset.py +66 -0
  13. moonshot/integrations/cli/common/prompt_template.py +30 -27
  14. moonshot/integrations/cli/redteam/attack_module.py +45 -30
  15. moonshot/integrations/cli/redteam/context_strategy.py +36 -30
  16. moonshot/integrations/cli/redteam/session.py +101 -76
  17. moonshot/integrations/cli/utils/process_data.py +52 -0
  18. moonshot/integrations/web_api/app.py +1 -1
  19. moonshot/integrations/web_api/routes/dataset.py +46 -1
  20. moonshot/integrations/web_api/schemas/dataset_create_dto.py +18 -0
  21. moonshot/integrations/web_api/schemas/recipe_create_dto.py +0 -2
  22. moonshot/integrations/web_api/services/bookmark_service.py +6 -2
  23. moonshot/integrations/web_api/services/dataset_service.py +25 -0
  24. moonshot/integrations/web_api/services/recipe_service.py +0 -1
  25. moonshot/src/api/api_dataset.py +35 -0
  26. moonshot/src/api/api_recipe.py +0 -3
  27. moonshot/src/datasets/dataset.py +116 -0
  28. moonshot/src/recipes/recipe.py +0 -15
  29. moonshot/src/recipes/recipe_arguments.py +0 -4
  30. moonshot/src/utils/log.py +12 -6
  31. moonshot/src/utils/pagination.py +25 -0
  32. {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/WHEEL +0 -0
  33. {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/licenses/AUTHORS.md +0 -0
  34. {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/licenses/LICENSE.md +0 -0
  35. {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/licenses/NOTICES.md +0 -0
@@ -1,4 +1,5 @@
1
1
  import argparse
2
+ from ast import literal_eval
2
3
 
3
4
  import cmd2
4
5
  from rich.console import Console
@@ -10,8 +11,8 @@ from moonshot.api import (
10
11
  api_update_context_strategy,
11
12
  )
12
13
  from moonshot.integrations.cli.active_session_cfg import active_session
14
+ from moonshot.integrations.cli.utils.process_data import filter_data
13
15
  from moonshot.src.redteaming.session.session import Session
14
- from moonshot.src.utils.find_feature import find_keyword
15
16
 
16
17
  console = Console()
17
18
 
@@ -60,6 +61,7 @@ def list_context_strategies(args) -> list | None:
60
61
  Args:
61
62
  args: A namespace object from argparse. It should have an optional attribute:
62
63
  find (str): Optional field to find context strategies with a keyword.
64
+ pagination (str): Optional field to paginate context strategies.
63
65
 
64
66
  Returns:
65
67
  list | None: A list of ContextStrategy or None if there is no result.
@@ -67,19 +69,19 @@ def list_context_strategies(args) -> list | None:
67
69
  try:
68
70
  context_strategy_metadata_list = api_get_all_context_strategy_metadata()
69
71
  keyword = args.find.lower() if args.find else ""
70
- if keyword:
71
- filtered_context_strategies_list = find_keyword(
72
- keyword, context_strategy_metadata_list
72
+ pagination = literal_eval(args.pagination) if args.pagination else ()
73
+
74
+ if context_strategy_metadata_list:
75
+ filtered_context_strategies_list = filter_data(
76
+ context_strategy_metadata_list, keyword, pagination
73
77
  )
74
78
  if filtered_context_strategies_list:
75
- display_context_strategies(filtered_context_strategies_list)
79
+ _display_context_strategies(filtered_context_strategies_list)
76
80
  return filtered_context_strategies_list
77
- else:
78
- print("No context strategies containing keyword found.")
79
- return None
80
- else:
81
- display_context_strategies(context_strategy_metadata_list)
82
- return context_strategy_metadata_list
81
+
82
+ console.print("[red]There are no context strategies found.[/red]")
83
+ return None
84
+
83
85
  except Exception as e:
84
86
  print(f"[list_context_strategies]: {str(e)}")
85
87
 
@@ -124,7 +126,7 @@ def delete_context_strategy(args) -> None:
124
126
  print(f"[delete_context_strategy]: {str(e)}")
125
127
 
126
128
 
127
- def display_context_strategies(context_strategies: list) -> None:
129
+ def _display_context_strategies(context_strategies: list) -> None:
128
130
  """
129
131
  Display a list of context strategies.
130
132
 
@@ -137,25 +139,21 @@ def display_context_strategies(context_strategies: list) -> None:
137
139
  Returns:
138
140
  None
139
141
  """
140
- if context_strategies:
141
- table = Table(
142
- title="Context Strategy List",
143
- show_lines=True,
144
- expand=True,
145
- header_style="bold",
146
- )
147
- table.add_column("No.", justify="left", width=2)
148
- table.add_column("Context Strategy Information", justify="left", width=98)
149
- for context_strategy_index, context_strategy_data in enumerate(
150
- context_strategies, 1
151
- ):
152
- context_strategy_data_str = ""
153
- for k, v in context_strategy_data.items():
142
+ table = Table(
143
+ title="Context Strategy List",
144
+ show_lines=True,
145
+ expand=True,
146
+ header_style="bold",
147
+ )
148
+ table.add_column("No.", justify="left", width=2)
149
+ table.add_column("Context Strategy Information", justify="left", width=98)
150
+ for idx, context_strategy_data in enumerate(context_strategies, 1):
151
+ context_strategy_data_str = ""
152
+ for k, v in context_strategy_data.items():
153
+ if k != "idx":
154
154
  context_strategy_data_str += f"[blue]{k.capitalize()}:[/blue] {v}\n\n"
155
- table.add_row(str(context_strategy_index), context_strategy_data_str)
156
- console.print(table)
157
- else:
158
- console.print("[red]There are no context strategies found.[/red]", style="bold")
155
+ table.add_row(str(idx), context_strategy_data_str)
156
+ console.print(table)
159
157
 
160
158
 
161
159
  # Use context strategy arguments
@@ -199,3 +197,11 @@ list_context_strategies_args.add_argument(
199
197
  help="Optional field to find context strategies with keyword",
200
198
  nargs="?",
201
199
  )
200
+
201
# "-p/--pagination" is accepted as a raw string such as "(2,10)"; list_context_strategies
# turns it into a tuple with ast.literal_eval.
list_context_strategies_args.add_argument(
    "-p",
    "--pagination",
    type=str,
    help="Optional tuple to paginate context strategies. E.g. (2,10) returns 2nd page with 10 items in each page.",
    nargs="?",
)
@@ -22,8 +22,8 @@ from moonshot.api import (
22
22
  api_load_session,
23
23
  )
24
24
  from moonshot.integrations.cli.active_session_cfg import active_session
25
+ from moonshot.integrations.cli.utils.process_data import filter_data
25
26
  from moonshot.src.redteaming.session.session import Session
26
- from moonshot.src.utils.find_feature import find_keyword
27
27
 
28
28
  console = Console()
29
29
 
@@ -139,6 +139,7 @@ def list_sessions(args) -> list | None:
139
139
  Args:
140
140
  args: A namespace object from argparse. It should have an optional attribute:
141
141
  find (str): Optional field to find session(s) with a keyword.
142
+ pagination (str): Optional field to paginate sessions.
142
143
 
143
144
  Returns:
144
145
  list | None: A list of Session or None if there is no result.
@@ -146,19 +147,19 @@ def list_sessions(args) -> list | None:
146
147
  try:
147
148
  session_metadata_list = api_get_all_session_metadata()
148
149
  keyword = args.find.lower() if args.find else ""
149
- if keyword:
150
- filtered_session_metadata_list = find_keyword(
151
- keyword, session_metadata_list
150
+ pagination = literal_eval(args.pagination) if args.pagination else ()
151
+
152
+ if session_metadata_list:
153
+ filtered_session_metadata_list = filter_data(
154
+ session_metadata_list, keyword, pagination
152
155
  )
153
156
  if filtered_session_metadata_list:
154
- display_sessions(filtered_session_metadata_list)
157
+ _display_sessions(filtered_session_metadata_list)
155
158
  return filtered_session_metadata_list
156
- else:
157
- print("No sessions containing keyword found.")
158
- return None
159
- else:
160
- display_sessions(session_metadata_list)
161
- return session_metadata_list
159
+
160
+ console.print("[red]There are no sessions found.[/red]")
161
+ return None
162
+
162
163
  except Exception as e:
163
164
  print(f"[list_sessions]: {str(e)}")
164
165
 
@@ -363,12 +364,13 @@ def list_bookmarks(args) -> list | None:
363
364
 
364
365
  This function retrieves all available bookmarks by calling the api_get_all_bookmarks function from the
365
366
  moonshot.api module.
366
- It then displays the retrieved bookmarks using the display_bookmarks function.
367
+ It then displays the retrieved bookmarks using the _display_bookmarks function.
367
368
  If no bookmarks are found, a message is printed to the console.
368
369
 
369
370
  Args:
370
371
  args: A namespace object from argparse. It should have an optional attribute:
371
372
  find (str): Optional field to find bookmark(s) with a keyword.
373
+ pagination (str): Optional field to paginate bookmarks.
372
374
 
373
375
  Returns:
374
376
  list | None: A list of Bookmark or None if there is no result.
@@ -376,22 +378,22 @@ def list_bookmarks(args) -> list | None:
376
378
  try:
377
379
  bookmarks_list = api_get_all_bookmarks()
378
380
  keyword = args.find.lower() if args.find else ""
379
- if keyword:
380
- filtered_bookmarks_list = find_keyword(keyword, bookmarks_list)
381
+ pagination = literal_eval(args.pagination) if args.pagination else ()
382
+
383
+ if bookmarks_list:
384
+ filtered_bookmarks_list = filter_data(bookmarks_list, keyword, pagination)
381
385
  if filtered_bookmarks_list:
382
- display_bookmarks(filtered_bookmarks_list)
386
+ _display_bookmarks(filtered_bookmarks_list)
383
387
  return filtered_bookmarks_list
384
- else:
385
- print("No bookmarks containing keyword found.")
386
- return None
387
- else:
388
- display_bookmarks(bookmarks_list)
389
- return bookmarks_list
388
+
389
+ console.print("[red]There are no bookmarks found.[/red]")
390
+ return None
391
+
390
392
  except Exception as e:
391
393
  print(f"[list_bookmarks]: {str(e)}")
392
394
 
393
395
 
394
- def display_bookmarks(bookmarks_list) -> None:
396
+ def _display_bookmarks(bookmarks_list) -> None:
395
397
  """
396
398
  Display the list of bookmarks in a tabular format.
397
399
 
@@ -402,38 +404,39 @@ def display_bookmarks(bookmarks_list) -> None:
402
404
  Args:
403
405
  bookmarks_list (list): A list of dictionaries, where each dictionary contains the details of a bookmark.
404
406
  """
405
- if bookmarks_list:
406
- table = Table(
407
- title="Bookmark List", show_lines=True, expand=True, header_style="bold"
407
+
408
+ table = Table(
409
+ title="Bookmark List", show_lines=True, expand=True, header_style="bold"
410
+ )
411
+ table.add_column("ID.", justify="left", width=5)
412
+ table.add_column("Name", justify="left", width=20)
413
+ table.add_column("Prepared Prompt", justify="left", width=50)
414
+ table.add_column("Predicted Response", justify="left", width=50)
415
+ table.add_column("Bookmark Time", justify="left", width=20)
416
+ for idx, bookmark in enumerate(bookmarks_list, 1):
417
+ (
418
+ name,
419
+ prompt,
420
+ prepared_prompt,
421
+ response,
422
+ context_strategy,
423
+ prompt_template,
424
+ attack_module,
425
+ metric,
426
+ bookmark_time,
427
+ *other_args,
428
+ ) = bookmark.values()
429
+ idx = bookmark.get("idx", idx)
430
+
431
+ table.add_section()
432
+ table.add_row(
433
+ str(idx),
434
+ name,
435
+ prepared_prompt,
436
+ response,
437
+ bookmark_time,
408
438
  )
409
- table.add_column("ID.", justify="left", width=5)
410
- table.add_column("Name", justify="left", width=20)
411
- table.add_column("Prepared Prompt", justify="left", width=50)
412
- table.add_column("Predicted Response", justify="left", width=50)
413
- table.add_column("Bookmark Time", justify="left", width=20)
414
- for idx, bookmark in enumerate(bookmarks_list, 1):
415
- (
416
- name,
417
- prompt,
418
- prepared_prompt,
419
- response,
420
- context_strategy,
421
- prompt_template,
422
- attack_module,
423
- metric,
424
- bookmark_time,
425
- ) = bookmark.values()
426
- table.add_section()
427
- table.add_row(
428
- str(idx),
429
- name,
430
- prepared_prompt,
431
- response,
432
- bookmark_time,
433
- )
434
- console.print(table)
435
- else:
436
- console.print("[red]There are no bookmarks found.[/red]")
439
+ console.print(table)
437
440
 
438
441
 
439
442
  def view_bookmark(args) -> None:
@@ -446,12 +449,12 @@ def view_bookmark(args) -> None:
446
449
 
447
450
  try:
448
451
  bookmark_info = api_get_bookmark(args.bookmark_name)
449
- display_bookmark(bookmark_info)
452
+ _display_bookmark(bookmark_info)
450
453
  except Exception as e:
451
454
  print(f"[view_bookmark]: {str(e)}")
452
455
 
453
456
 
454
- def display_bookmark(bookmark_info: dict) -> None:
457
+ def _display_bookmark(bookmark_info: dict) -> None:
455
458
  """
456
459
  Display the filtered bookmark in a tabular format.
457
460
 
@@ -600,7 +603,7 @@ def run_attack_module(args):
600
603
  if args.prompt_template:
601
604
  prompt_template = [args.prompt_template]
602
605
  elif active_session["prompt_template"]:
603
- prompt_template = [args.prompt_template]
606
+ prompt_template = [active_session["prompt_template"]]
604
607
  else:
605
608
  prompt_template = []
606
609
 
@@ -703,7 +706,7 @@ def delete_session(args) -> None:
703
706
  print(f"[delete_session]: {str(e)}")
704
707
 
705
708
 
706
- def display_sessions(sessions: list) -> None:
709
+ def _display_sessions(sessions: list) -> None:
707
710
  """
708
711
  Display a list of sessions.
709
712
 
@@ -717,25 +720,32 @@ def display_sessions(sessions: list) -> None:
717
720
  None
718
721
  """
719
722
 
720
- if sessions:
721
- table = Table(
722
- title="Session List", show_lines=True, expand=True, header_style="bold"
723
- )
724
- table.add_column("No.", justify="left", width=2)
725
- table.add_column("Session ID", justify="left", width=20)
726
- table.add_column("Contains", justify="left", width=78)
727
-
728
- for session_index, session_data in enumerate(sessions, 1):
729
- session_id = session_data.get("session_id", "")
730
- endpoints = ", ".join(session_data.get("endpoints", []))
731
- created_datetime = session_data.get("created_datetime", "")
732
-
733
- session_info = f"[red]id: {session_id}[/red]\n\nCreated: {created_datetime}"
734
- contains_info = f"[blue]Endpoints:[/blue] {endpoints}\n\n"
735
- table.add_row(str(session_index), session_info, contains_info)
736
- console.print(table)
737
- else:
738
- console.print("[red]There are no sessions found.[/red]", style="bold")
723
+ table = Table(
724
+ title="Session List", show_lines=True, expand=True, header_style="bold"
725
+ )
726
+ table.add_column("No.", justify="left", width=2)
727
+ table.add_column("Session ID", justify="left", width=20)
728
+ table.add_column("Contains", justify="left", width=78)
729
+
730
+ for idx, session_data in enumerate(sessions, 1):
731
+ (
732
+ session_id,
733
+ endpoints,
734
+ created_epoch,
735
+ created_datetime,
736
+ prompt_template,
737
+ context_strategy,
738
+ cs_num_of_prev_prompts,
739
+ attack_module,
740
+ metric,
741
+ system_prompt,
742
+ *other_args,
743
+ ) = session_data.values()
744
+ idx = session_data.get("idx", idx)
745
+ session_info = f"[red]id: {session_id}[/red]\n\nCreated: {created_datetime}"
746
+ contains_info = f"[blue]Endpoints:[/blue] {endpoints}\n\n"
747
+ table.add_row(str(idx), session_info, contains_info)
748
+ console.print(table)
739
749
 
740
750
 
741
751
  # use session arguments
@@ -885,6 +895,13 @@ list_sessions_args.add_argument(
885
895
  nargs="?",
886
896
  )
887
897
 
898
# Optional page selector for "list_sessions". The value stays a string like "(2,10)";
# list_sessions evaluates it with ast.literal_eval into (page_number, page_size).
list_sessions_args.add_argument(
    "-p",
    "--pagination",
    nargs="?",
    type=str,
    help="Optional tuple to paginate session(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
)
888
905
 
889
906
  # Add bookmark arguments
890
907
  add_bookmark_args = cmd2.Cmd2ArgumentParser(
@@ -973,3 +990,11 @@ list_bookmarks_args.add_argument(
973
990
  help="Optional field to find bookmark(s) with keyword",
974
991
  nargs="?",
975
992
  )
993
+
994
# Optional page selector for "list_bookmarks". The value stays a string like "(2,10)";
# list_bookmarks evaluates it with ast.literal_eval into (page_number, page_size).
list_bookmarks_args.add_argument(
    "-p",
    "--pagination",
    nargs="?",
    type=str,
    help="Optional tuple to paginate bookmark(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
)
@@ -0,0 +1,52 @@
1
+ from moonshot.src.utils.find_feature import find_keyword
2
+ from moonshot.src.utils.pagination import get_paginated_lists
3
+
4
+
5
+ def filter_data(
6
+ list_of_data: list, keyword: str = "", pagination: tuple = ()
7
+ ) -> list | None:
8
+ """
9
+ Filters and paginates a list of data.
10
+
11
+ Args:
12
+ list_of_data (list): The list of data to be filtered and paginated.
13
+ keyword (str, optional): The keyword to filter the data. Defaults to "".
14
+ pagination (tuple, optional): A tuple containing the page number and page size. Defaults to ().
15
+
16
+ Returns:
17
+ list | None: The filtered and paginated list of data, or None if no data matches the criteria.
18
+ """
19
+ # if there is a find keyword
20
+ if keyword:
21
+ list_of_data = find_keyword(keyword, list_of_data)
22
+ if not list_of_data:
23
+ return
24
+
25
+ # if pagination is required
26
+ if pagination:
27
+ # add index to every dictionary in the list
28
+ if all(isinstance(item, dict) for item in list_of_data):
29
+ for index, item in enumerate(list_of_data, 1):
30
+ if isinstance(item, dict):
31
+ item["idx"] = index
32
+
33
+ # get paginated data
34
+ page_number = pagination[0]
35
+ page_size = pagination[1]
36
+
37
+ if page_number <= 0 or page_size <= 0:
38
+ # print("Invalid page number or page size. Page number and page size should start from 1.")
39
+ raise RuntimeError(
40
+ "Invalid page number or page size. Page number and page size should start from 1."
41
+ )
42
+
43
+ paginated_data = get_paginated_lists(page_size, list_of_data)
44
+
45
+ # perform index checks
46
+ paginated_data_size = len(paginated_data)
47
+ if page_number > paginated_data_size:
48
+ list_of_data = paginated_data[(paginated_data_size - 1)]
49
+ else:
50
+ list_of_data = paginated_data[(page_number - 1)]
51
+
52
+ return list_of_data
@@ -71,7 +71,7 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
71
71
  }
72
72
 
73
73
  app: CustomFastAPI = CustomFastAPI(
74
- title="Project Moonshot", version="0.4.2", **app_kwargs
74
+ title="Project Moonshot", version="0.4.4", **app_kwargs
75
75
  )
76
76
 
77
77
  if cfg.cors.enabled():
@@ -1,7 +1,8 @@
1
1
  from dependency_injector.wiring import Provide, inject
2
- from fastapi import APIRouter, Depends, HTTPException
2
+ from fastapi import APIRouter, Depends, HTTPException, Query
3
3
 
4
4
  from ..container import Container
5
+ from ..schemas.dataset_create_dto import DatasetCreateDTO
5
6
  from ..schemas.dataset_response_dto import DatasetResponseDTO
6
7
  from ..services.dataset_service import DatasetService
7
8
  from ..services.utils.exceptions_handler import ServiceException
@@ -9,6 +10,50 @@ from ..services.utils.exceptions_handler import ServiceException
9
10
  router = APIRouter(tags=["Datasets"])
10
11
 
11
12
 
13
@router.post("/api/v1/datasets")
@inject
def create_dataset(
    dataset_data: DatasetCreateDTO,
    method: str = Query(
        ...,
        description="The method to use for creating the dataset. Supported methods are 'hf' and 'csv'.",
    ),
    dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
) -> str:
    """
    Create a new dataset using the specified method.

    Args:
        dataset_data (DatasetCreateDTO): The data required to create the dataset.
        method (str): The method to use for creating the dataset, given as a required
            query parameter. Supported methods are "hf" and "csv".
        dataset_service (DatasetService, optional): The service responsible for creating
            the dataset. Defaults to Depends(Provide[Container.dataset_service]).

    Returns:
        str: The value returned by DatasetService.create_dataset for the new dataset.

    Raises:
        HTTPException: Status 404 if the service reports a "FileNotFound" error,
            400 for a "ValidationError", and 500 for any other ServiceException.
    """
    try:
        return dataset_service.create_dataset(dataset_data, method)
    except ServiceException as e:
        # map the service error code onto an HTTP status; anything unknown is a 500
        if e.error_code == "FileNotFound":
            status_code = 404
        elif e.error_code == "ValidationError":
            status_code = 400
        else:
            status_code = 500
        raise HTTPException(
            status_code=status_code, detail=f"Failed to create dataset: {e.msg}"
        ) from e
55
+
56
+
12
57
  @router.get("/api/v1/datasets")
13
58
  @inject
14
59
  def get_all_datasets(
@@ -0,0 +1,18 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import Field
4
+ from pyparsing import Iterator
5
+
6
+ from moonshot.src.datasets.dataset_arguments import (
7
+ DatasetArguments as DatasetPydanticModel,
8
+ )
9
+
10
+
11
class DatasetCreateDTO(DatasetPydanticModel):
    # Request body accepted by the POST /api/v1/datasets route; the creation
    # method ("hf"/"csv") arrives separately as a query parameter.
    id: Optional[str] = None
    # NOTE(review): Iterator is imported from pyparsing at the top of this module;
    # typing.Iterator was presumably intended — confirm and fix the import.
    examples: Optional[Iterator[dict]] = None
    name: str = Field(..., min_length=1)
    # Optional. The default is "", so an explicit empty description is accepted as
    # well, instead of rejecting the very value used when the field is omitted.
    description: str = ""
    license: Optional[str] = ""
    reference: Optional[str] = ""
    # Extra keyword arguments forwarded to the dataset creation method.
    params: dict
@@ -13,7 +13,6 @@ class RecipeCreateDTO(RecipePydanticModel):
13
13
  datasets: list[str] = Field(..., min_length=1)
14
14
  metrics: list[str] = Field(..., min_length=1)
15
15
  prompt_templates: Optional[list[str]] = None
16
- attack_modules: Optional[list[str]] = None
17
16
  grading_scale: Optional[dict[str, list[int]]] = None
18
17
  stats: Optional[dict] = None
19
18
 
@@ -27,6 +26,5 @@ class RecipeUpdateDTO(RecipePydanticModel):
27
26
  datasets: Optional[list[str]] = None
28
27
  prompt_templates: Optional[list[str]] = None
29
28
  metrics: Optional[list[str]] = None
30
- attack_modules: Optional[list[str]] = None
31
29
  grading_scale: Optional[dict[str, list[int]]] = None
32
30
  stats: Optional[dict] = None
@@ -52,18 +52,22 @@ class BookmarkService(BaseService):
52
52
  @exception_handler
53
53
  def delete_bookmarks(self, all: bool = False, name: str | None = None) -> dict:
54
54
  """
55
- Deletes a single bookmark by its name or all bookmarks if the 'all' flag is set to True.
55
+ Deletes a single bookmark by its name or all bookmarks if the 'all' flag is set to True and returns
56
+ a boolean indicating the success of the operation.
56
57
 
57
58
  Args:
58
59
  all (bool, optional): If True, all bookmarks will be deleted. Defaults to False.
59
60
  name (str | None, optional): The name of the bookmark to delete. If 'all' is False, 'name' must be provided.
61
+
62
+ Returns:
63
+ dict: True if the deletion was successful, False otherwise.
60
64
  """
61
65
  if all:
62
66
  result = moonshot_api.api_delete_all_bookmark()
63
67
  elif name is not None:
64
68
  result = moonshot_api.api_delete_bookmark(name)
65
69
  else:
66
- raise ValueError("Either 'all' must be True or 'id' must be provided.")
70
+ raise ValueError("Either 'all' must be True or 'name' must be provided.")
67
71
 
68
72
  if not result["success"]:
69
73
  raise Exception(
@@ -1,10 +1,35 @@
1
1
  from .... import api as moonshot_api
2
+ from ..schemas.dataset_create_dto import DatasetCreateDTO
2
3
  from ..schemas.dataset_response_dto import DatasetResponseDTO
3
4
  from ..services.base_service import BaseService
4
5
  from ..services.utils.exceptions_handler import exception_handler
6
+ from .utils.file_manager import copy_file
5
7
 
6
8
 
7
9
  class DatasetService(BaseService):
10
@exception_handler
def create_dataset(self, dataset_data: DatasetCreateDTO, method: str) -> str:
    """
    Create a dataset from the DTO using the given creation method.

    The DTO's name, description, reference and license are passed to
    ``moonshot_api.api_create_datasets``, along with every entry of its ``params``
    dict as an extra keyword argument. The value that call returns is then handed
    to ``copy_file``.

    Args:
        dataset_data (DatasetCreateDTO): Metadata and creation parameters for the dataset.
        method (str): Creation method; "hf" and "csv" are the supported values.

    Returns:
        str: The result of ``copy_file`` applied to the value returned by
        ``api_create_datasets`` (NOTE(review): presumably the location of the
        copied dataset file — confirm against copy_file).

    Raises:
        Exception: If an error occurs during dataset creation.
    """
    created_dataset_path = moonshot_api.api_create_datasets(
        method=method,
        name=dataset_data.name,
        description=dataset_data.description,
        license=dataset_data.license,
        reference=dataset_data.reference,
        **dataset_data.params,
    )
    return copy_file(created_dataset_path)
32
+
8
33
  @exception_handler
9
34
  def get_all_datasets(self) -> list[DatasetResponseDTO]:
10
35
  datasets = moonshot_api.api_get_all_datasets()
@@ -24,7 +24,6 @@ class RecipeService(BaseService):
24
24
  datasets=recipe_data.datasets,
25
25
  prompt_templates=recipe_data.prompt_templates,
26
26
  metrics=recipe_data.metrics,
27
- attack_modules=recipe_data.attack_modules,
28
27
  grading_scale=recipe_data.grading_scale,
29
28
  )
30
29
 
@@ -1,6 +1,7 @@
1
1
  from pydantic import validate_call
2
2
 
3
3
  from moonshot.src.datasets.dataset import Dataset
4
+ from moonshot.src.datasets.dataset_arguments import DatasetArguments
4
5
 
5
6
 
6
7
  # ------------------------------------------------------------------------------
@@ -44,3 +45,37 @@ def api_get_all_datasets_name() -> list[str]:
44
45
  """
45
46
  datasets_name, _ = Dataset.get_available_items()
46
47
  return datasets_name
48
+
49
+
50
def api_create_datasets(
    name: str, description: str, reference: str, license: str, method: str, **kwargs
) -> str:
    """
    Create a new dataset from the given metadata by delegating to Dataset.create.

    The metadata is wrapped in a DatasetArguments object with an empty id and no
    examples. The id is left empty because it is expected to be generated from the
    name during creation. That object, the creation method and any extra keyword
    arguments are then passed to Dataset.create.

    Args:
        name (str): The name of the new dataset.
        description (str): A brief description of the new dataset.
        reference (str): A reference link for the new dataset.
        license (str): The license of the new dataset.
        method (str): The method used to create the new dataset ("csv" or "hf").
        kwargs: Additional keyword arguments forwarded to Dataset.create.

    Returns:
        str: Whatever Dataset.create returns for the new dataset.
        NOTE(review): this was documented as the dataset ID, but
        DatasetService.create_dataset treats it as a file path — confirm.
    """
    return Dataset.create(
        DatasetArguments(
            id="",
            name=name,
            description=description,
            reference=reference,
            license=license,
            examples=None,
        ),
        method,
        **kwargs,
    )