aiverify-moonshot 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/METADATA +2 -2
  2. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/RECORD +70 -56
  3. moonshot/__main__.py +77 -35
  4. moonshot/api.py +16 -0
  5. moonshot/integrations/cli/benchmark/benchmark.py +29 -13
  6. moonshot/integrations/cli/benchmark/cookbook.py +62 -24
  7. moonshot/integrations/cli/benchmark/datasets.py +79 -40
  8. moonshot/integrations/cli/benchmark/metrics.py +62 -23
  9. moonshot/integrations/cli/benchmark/recipe.py +89 -69
  10. moonshot/integrations/cli/benchmark/result.py +85 -47
  11. moonshot/integrations/cli/benchmark/run.py +99 -59
  12. moonshot/integrations/cli/common/common.py +20 -6
  13. moonshot/integrations/cli/common/connectors.py +154 -74
  14. moonshot/integrations/cli/common/dataset.py +66 -0
  15. moonshot/integrations/cli/common/prompt_template.py +57 -19
  16. moonshot/integrations/cli/redteam/attack_module.py +90 -24
  17. moonshot/integrations/cli/redteam/context_strategy.py +83 -23
  18. moonshot/integrations/cli/redteam/prompt_template.py +1 -1
  19. moonshot/integrations/cli/redteam/redteam.py +52 -6
  20. moonshot/integrations/cli/redteam/session.py +565 -44
  21. moonshot/integrations/cli/utils/process_data.py +52 -0
  22. moonshot/integrations/web_api/__main__.py +2 -0
  23. moonshot/integrations/web_api/app.py +6 -6
  24. moonshot/integrations/web_api/container.py +12 -2
  25. moonshot/integrations/web_api/routes/bookmark.py +173 -0
  26. moonshot/integrations/web_api/routes/dataset.py +46 -1
  27. moonshot/integrations/web_api/schemas/bookmark_create_dto.py +13 -0
  28. moonshot/integrations/web_api/schemas/dataset_create_dto.py +18 -0
  29. moonshot/integrations/web_api/schemas/recipe_create_dto.py +0 -2
  30. moonshot/integrations/web_api/services/bookmark_service.py +94 -0
  31. moonshot/integrations/web_api/services/dataset_service.py +25 -0
  32. moonshot/integrations/web_api/services/recipe_service.py +0 -1
  33. moonshot/integrations/web_api/services/utils/file_manager.py +52 -0
  34. moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +0 -1
  35. moonshot/integrations/web_api/temp/.gitkeep +0 -0
  36. moonshot/src/api/api_bookmark.py +95 -0
  37. moonshot/src/api/api_connector_endpoint.py +1 -1
  38. moonshot/src/api/api_context_strategy.py +2 -2
  39. moonshot/src/api/api_dataset.py +35 -0
  40. moonshot/src/api/api_recipe.py +0 -3
  41. moonshot/src/api/api_session.py +1 -1
  42. moonshot/src/bookmark/bookmark.py +257 -0
  43. moonshot/src/bookmark/bookmark_arguments.py +38 -0
  44. moonshot/src/configs/env_variables.py +12 -2
  45. moonshot/src/connectors/connector.py +15 -7
  46. moonshot/src/connectors_endpoints/connector_endpoint.py +65 -49
  47. moonshot/src/cookbooks/cookbook.py +57 -37
  48. moonshot/src/datasets/dataset.py +125 -5
  49. moonshot/src/metrics/metric.py +8 -4
  50. moonshot/src/metrics/metric_interface.py +8 -2
  51. moonshot/src/prompt_templates/prompt_template.py +5 -1
  52. moonshot/src/recipes/recipe.py +38 -40
  53. moonshot/src/recipes/recipe_arguments.py +0 -4
  54. moonshot/src/redteaming/attack/attack_module.py +18 -8
  55. moonshot/src/redteaming/attack/context_strategy.py +6 -2
  56. moonshot/src/redteaming/session/session.py +15 -11
  57. moonshot/src/results/result.py +7 -3
  58. moonshot/src/runners/runner.py +65 -42
  59. moonshot/src/runs/run.py +15 -11
  60. moonshot/src/runs/run_progress.py +7 -3
  61. moonshot/src/storage/db_interface.py +14 -0
  62. moonshot/src/storage/storage.py +33 -2
  63. moonshot/src/utils/find_feature.py +45 -0
  64. moonshot/src/utils/log.py +72 -0
  65. moonshot/src/utils/pagination.py +25 -0
  66. moonshot/src/utils/timeit.py +8 -1
  67. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/WHEEL +0 -0
  68. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/AUTHORS.md +0 -0
  69. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/LICENSE.md +0 -0
  70. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/NOTICES.md +0 -0
@@ -0,0 +1,52 @@
1
+ from moonshot.src.utils.find_feature import find_keyword
2
+ from moonshot.src.utils.pagination import get_paginated_lists
3
+
4
+
5
+ def filter_data(
6
+ list_of_data: list, keyword: str = "", pagination: tuple = ()
7
+ ) -> list | None:
8
+ """
9
+ Filters and paginates a list of data.
10
+
11
+ Args:
12
+ list_of_data (list): The list of data to be filtered and paginated.
13
+ keyword (str, optional): The keyword to filter the data. Defaults to "".
14
+ pagination (tuple, optional): A tuple containing the page number and page size. Defaults to ().
15
+
16
+ Returns:
17
+ list | None: The filtered and paginated list of data, or None if no data matches the criteria.
18
+ """
19
+ # if there is a find keyword
20
+ if keyword:
21
+ list_of_data = find_keyword(keyword, list_of_data)
22
+ if not list_of_data:
23
+ return
24
+
25
+ # if pagination is required
26
+ if pagination:
27
+ # add index to every dictionary in the list
28
+ if all(isinstance(item, dict) for item in list_of_data):
29
+ for index, item in enumerate(list_of_data, 1):
30
+ if isinstance(item, dict):
31
+ item["idx"] = index
32
+
33
+ # get paginated data
34
+ page_number = pagination[0]
35
+ page_size = pagination[1]
36
+
37
+ if page_number <= 0 or page_size <= 0:
38
+ # print("Invalid page number or page size. Page number and page size should start from 1.")
39
+ raise RuntimeError(
40
+ "Invalid page number or page size. Page number and page size should start from 1."
41
+ )
42
+
43
+ paginated_data = get_paginated_lists(page_size, list_of_data)
44
+
45
+ # perform index checks
46
+ paginated_data_size = len(paginated_data)
47
+ if page_number > paginated_data_size:
48
+ list_of_data = paginated_data[(paginated_data_size - 1)]
49
+ else:
50
+ list_of_data = paginated_data[(page_number - 1)]
51
+
52
+ return list_of_data
@@ -7,6 +7,7 @@ from dotenv import dotenv_values, load_dotenv
7
7
  from .app import create_app
8
8
  from .container import Container
9
9
  from .logging_conf import configure_app_logging, create_uvicorn_log_config
10
+ from .services.utils.file_manager import create_temp_dir
10
11
  from .types.types import UvicornRunArgs
11
12
 
12
13
 
@@ -20,6 +21,7 @@ def start_app():
20
21
  else:
21
22
  container.config.from_yaml(f"{config_file}", required=True)
22
23
 
24
+ create_temp_dir(container.config.temp_folder())
23
25
  configure_app_logging(container.config)
24
26
  logging.info(f"Environment: {container.config.app_environment()}")
25
27
  ENABLE_SSL = container.config.ssl.enabled()
@@ -14,12 +14,13 @@ from .routes import (
14
14
  attack_modules,
15
15
  benchmark,
16
16
  benchmark_result,
17
+ bookmark,
18
+ context_strategy,
17
19
  cookbook,
18
20
  dataset,
19
21
  endpoint,
20
22
  metric,
21
23
  prompt_template,
22
- context_strategy,
23
24
  recipe,
24
25
  runner,
25
26
  )
@@ -66,13 +67,11 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
66
67
 
67
68
  app_kwargs["swagger_ui_parameters"] = {
68
69
  "defaultModelsExpandDepth": -1,
69
- "docExpansion": None
70
- }
70
+ "docExpansion": None,
71
+ }
71
72
 
72
73
  app: CustomFastAPI = CustomFastAPI(
73
- title="Project Moonshot",
74
- version="0.4.1",
75
- **app_kwargs
74
+ title="Project Moonshot", version="0.4.3", **app_kwargs
76
75
  )
77
76
 
78
77
  if cfg.cors.enabled():
@@ -104,6 +103,7 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
104
103
  app.include_router(runner.router)
105
104
  app.include_router(dataset.router)
106
105
  app.include_router(attack_modules.router)
106
+ app.include_router(bookmark.router)
107
107
 
108
108
  @app.exception_handler(RequestValidationError)
109
109
  async def validation_exception_handler(
@@ -9,6 +9,7 @@ from .services.benchmark_result_service import BenchmarkResultService
9
9
  from .services.benchmark_test_manager import BenchmarkTestManager
10
10
  from .services.benchmark_test_state import BenchmarkTestState
11
11
  from .services.benchmarking_service import BenchmarkingService
12
+ from .services.bookmark_service import BookmarkService
12
13
  from .services.context_strategy_service import ContextStrategyService
13
14
  from .services.cookbook_service import CookbookService
14
15
  from .services.dataset_service import DatasetService
@@ -55,6 +56,11 @@ class Container(containers.DeclarativeContainer):
55
56
  "log_file_max_size": 5242880,
56
57
  "log_file_backup_count": 3,
57
58
  },
59
+ "temp_folder": str(
60
+ importlib.resources.files("moonshot").joinpath(
61
+ "integrations/web_api/temp"
62
+ )
63
+ ),
58
64
  }
59
65
  )
60
66
 
@@ -97,9 +103,9 @@ class Container(containers.DeclarativeContainer):
97
103
  prompt_template_service: providers.Singleton[
98
104
  PromptTemplateService
99
105
  ] = providers.Singleton(PromptTemplateService)
100
- context_strategy_service: providers.Singleton[ContextStrategyService] = providers.Singleton(
106
+ context_strategy_service: providers.Singleton[
101
107
  ContextStrategyService
102
- )
108
+ ] = providers.Singleton(ContextStrategyService)
103
109
  benchmarking_service: providers.Singleton[
104
110
  BenchmarkingService
105
111
  ] = providers.Singleton(
@@ -127,6 +133,9 @@ class Container(containers.DeclarativeContainer):
127
133
  am_service: providers.Singleton[AttackModuleService] = providers.Singleton(
128
134
  AttackModuleService,
129
135
  )
136
+ bookmark_service: providers.Singleton[BookmarkService] = providers.Singleton(
137
+ BookmarkService,
138
+ )
130
139
  wiring_config = containers.WiringConfiguration(
131
140
  modules=[
132
141
  ".routes.redteam",
@@ -141,6 +150,7 @@ class Container(containers.DeclarativeContainer):
141
150
  ".routes.runner",
142
151
  ".routes.dataset",
143
152
  ".routes.attack_modules",
153
+ ".routes.bookmark",
144
154
  ".services.benchmarking_service",
145
155
  ]
146
156
  )
@@ -0,0 +1,173 @@
1
+ from typing import Optional
2
+
3
+ from dependency_injector.wiring import Provide, inject
4
+ from fastapi import APIRouter, Depends, HTTPException, Query
5
+
6
+ from ..container import Container
7
+ from ..schemas.bookmark_create_dto import BookmarkCreateDTO, BookmarkPydanticModel
8
+ from ..services.bookmark_service import BookmarkService
9
+ from ..services.utils.exceptions_handler import ServiceException
10
+
11
+ router = APIRouter(tags=["Bookmark"])
12
+
13
+
14
@router.post(
    "/api/v1/bookmarks", response_description="Bookmark data added to the database"
)
@inject
def insert_bookmark(
    bookmark_data: BookmarkCreateDTO,
    bookmark_service: BookmarkService = Depends(Provide[Container.bookmark_service]),
) -> dict:
    """
    Add a bookmark through the bookmark service.

    Args:
        bookmark_data: The bookmark payload to store.
        bookmark_service: The injected service that performs the insertion.

    Returns:
        The dictionary returned by the service's insert_bookmark call.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound", 400 when it
            reports "ValidationError", and 500 for any other ServiceException.
    """
    try:
        return bookmark_service.insert_bookmark(bookmark_data)
    except ServiceException as e:
        # Map the service error code to an HTTP status; unknown codes become 500.
        status_by_code = {"FileNotFound": 404, "ValidationError": 400}
        raise HTTPException(
            status_code=status_by_code.get(e.error_code, 500),
            detail=f"Failed to insert bookmark: {e.msg}",
        )
51
+
52
+
53
@router.get(
    "/api/v1/bookmarks",
    response_description="List of all bookmarks or a specific bookmark by name",
)
@inject
def get_all_bookmarks(
    name: Optional[str] = Query(None, description="Name of the bookmark to query"),
    bookmark_service: BookmarkService = Depends(Provide[Container.bookmark_service]),
) -> list[BookmarkPydanticModel]:
    """
    Fetch bookmarks through the bookmark service, optionally narrowed by name.

    Args:
        name: The bookmark name to look up; passed straight to the service.
        bookmark_service: The injected service that performs the lookup.

    Returns:
        The list of bookmark models returned by the service.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound", 400 when it
            reports "ValidationError", and 500 for any other ServiceException.
    """
    try:
        return bookmark_service.get_all_bookmarks(name=name)
    except ServiceException as e:
        # Map the service error code to an HTTP status; unknown codes become 500.
        status_by_code = {"FileNotFound": 404, "ValidationError": 400}
        raise HTTPException(
            status_code=status_by_code.get(e.error_code, 500),
            detail=f"Failed to retrieve bookmarks: {e.msg}",
        )
92
+
93
+
94
@router.delete(
    "/api/v1/bookmarks", response_description="Bookmark data deleted from the database"
)
@inject
def delete_bookmark(
    all: bool = Query(False, description="Flag to delete all bookmarks"),
    name: Optional[str] = Query(None, description="Name of the bookmark to delete"),
    bookmark_service: BookmarkService = Depends(Provide[Container.bookmark_service]),
) -> dict:
    """
    Remove either every bookmark or the single bookmark with the given name.

    Args:
        all: When true, every bookmark is deleted and 'name' is ignored.
        name: The bookmark to delete when 'all' is false.
        bookmark_service: The injected service that performs the deletion.

    Returns:
        The dictionary returned by the service's delete_bookmarks call.

    Raises:
        HTTPException: 400 when neither 'all' nor 'name' is supplied; otherwise
            404 when the service reports "FileNotFound", 400 when it reports
            "ValidationError", and 500 for any other ServiceException.
    """
    try:
        if all:
            return bookmark_service.delete_bookmarks(all=True)
        if name is None:
            # Not a ServiceException, so it passes through the handler below untouched.
            raise HTTPException(
                status_code=400, detail="Must specify 'all' or 'name' parameter"
            )
        return bookmark_service.delete_bookmarks(all=False, name=name)
    except ServiceException as e:
        # Map the service error code to an HTTP status; unknown codes become 500.
        status_by_code = {"FileNotFound": 404, "ValidationError": 400}
        raise HTTPException(
            status_code=status_by_code.get(e.error_code, 500),
            detail=f"Failed to delete bookmark: {e.msg}",
        )
139
@router.post(
    "/api/v1/bookmarks/export", response_description="Exporting Bookmark to JSON file"
)
@inject
def export_bookbookmarks(
    export_file_name: Optional[str] = Query(
        "bookmarks", description="Name of the exported file"
    ),
    bookmark_service: BookmarkService = Depends(Provide[Container.bookmark_service]),
) -> str:
    """
    Export the bookmarks to a file through the bookmark service.

    Args:
        export_file_name: The name to give the exported file. Defaults to "bookmarks".
        bookmark_service: The injected service that performs the export.

    Returns:
        The string returned by the service's export_bookmarks call.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound", 400 when it
            reports "ValidationError", and 500 for any other ServiceException.
    """
    try:
        return bookmark_service.export_bookmarks(export_file_name)
    except ServiceException as e:
        # Map the service error code to an HTTP status; unknown codes become 500.
        status_by_code = {"FileNotFound": 404, "ValidationError": 400}
        raise HTTPException(
            status_code=status_by_code.get(e.error_code, 500),
            detail=f"Failed to export bookmark: {e.msg}",
        )
@@ -1,7 +1,8 @@
1
1
  from dependency_injector.wiring import Provide, inject
2
- from fastapi import APIRouter, Depends, HTTPException
2
+ from fastapi import APIRouter, Depends, HTTPException, Query
3
3
 
4
4
  from ..container import Container
5
+ from ..schemas.dataset_create_dto import DatasetCreateDTO
5
6
  from ..schemas.dataset_response_dto import DatasetResponseDTO
6
7
  from ..services.dataset_service import DatasetService
7
8
  from ..services.utils.exceptions_handler import ServiceException
@@ -9,6 +10,50 @@ from ..services.utils.exceptions_handler import ServiceException
9
10
  router = APIRouter(tags=["Datasets"])
10
11
 
11
12
 
13
@router.post("/api/v1/datasets")
@inject
def create_dataset(
    dataset_data: DatasetCreateDTO,
    method: str = Query(
        ...,
        description="The method to use for creating the dataset. Supported methods are 'hf' and 'csv'.",
    ),
    dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
) -> str:
    """
    Create a new dataset using the specified method.

    Args:
        dataset_data (DatasetCreateDTO): The data required to create the dataset.
        method (str): The method to use for creating the dataset. Supported methods are "hf" and "csv".
        dataset_service (DatasetService, optional): The service responsible for creating the dataset.
            Defaults to Depends(Provide[Container.dataset_service]).

    Returns:
        str: The value returned by the service's create_dataset call.

    Raises:
        HTTPException: An error with status code 404 if the dataset file is not found.
                       An error with status code 400 if there is a validation error.
                       An error with status code 500 for any other server-side error.
    """
    try:
        return dataset_service.create_dataset(dataset_data, method)
    except ServiceException as e:
        # Map the service error code to an HTTP status; unknown codes become 500.
        status_by_code = {"FileNotFound": 404, "ValidationError": 400}
        # The detail names the create operation; it previously reported a
        # retrieval failure, copied from the listing endpoint.
        raise HTTPException(
            status_code=status_by_code.get(e.error_code, 500),
            detail=f"Failed to create dataset: {e.msg}",
        )
55
+
56
+
12
57
  @router.get("/api/v1/datasets")
13
58
  @inject
14
59
  def get_all_datasets(
@@ -0,0 +1,13 @@
1
+ from typing import Optional
2
+
3
+ from moonshot.src.bookmark.bookmark_arguments import (
4
+ BookmarkArguments as BookmarkPydanticModel,
5
+ )
6
+
7
+
8
# Request body for creating a bookmark. Extends BookmarkArguments (imported here as
# BookmarkPydanticModel) and declares the fields below with defaults, so a client
# may omit any of them.
# NOTE(review): presumably these redeclare fields of the base model to relax them
# for creation requests - confirm against BookmarkArguments.
class BookmarkCreateDTO(BookmarkPydanticModel):
    # Optional text fields; each defaults to an empty string when omitted.
    prompt_template: Optional[str] = ""
    context_strategy: Optional[str] = ""
    attack_module: Optional[str] = ""
    metric: Optional[str] = ""
    # Defaults to None when the client does not supply a bookmark time.
    bookmark_time: Optional[str] = None
@@ -0,0 +1,18 @@
1
+ from typing import Optional
2
+
3
+ from pydantic import Field
4
+ from pyparsing import Iterator
5
+
6
+ from moonshot.src.datasets.dataset_arguments import (
7
+ DatasetArguments as DatasetPydanticModel,
8
+ )
9
+
10
+
11
# Request body for creating a dataset. Extends DatasetArguments (imported here as
# DatasetPydanticModel).
class DatasetCreateDTO(DatasetPydanticModel):
    # Defaults to None when the client does not supply an id.
    id: Optional[str] = None
    # The default is None, so the annotation must admit None as well.
    # NOTE(review): Iterator is imported from pyparsing at the top of this file;
    # typing.Iterator looks like the intended source - confirm and switch the import.
    examples: Optional[Iterator[dict]] = None
    # Required, and must not be empty.
    name: str = Field(..., min_length=1)
    # NOTE(review): the default "" is shorter than min_length=1, so an omitted
    # description is accepted while an explicit empty one is rejected - confirm
    # which of the two is intended.
    description: str = Field(default="", min_length=1)
    license: Optional[str] = ""
    reference: Optional[str] = ""
    # Required dictionary of additional creation parameters.
    params: dict
@@ -13,7 +13,6 @@ class RecipeCreateDTO(RecipePydanticModel):
13
13
  datasets: list[str] = Field(..., min_length=1)
14
14
  metrics: list[str] = Field(..., min_length=1)
15
15
  prompt_templates: Optional[list[str]] = None
16
- attack_modules: Optional[list[str]] = None
17
16
  grading_scale: Optional[dict[str, list[int]]] = None
18
17
  stats: Optional[dict] = None
19
18
 
@@ -27,6 +26,5 @@ class RecipeUpdateDTO(RecipePydanticModel):
27
26
  datasets: Optional[list[str]] = None
28
27
  prompt_templates: Optional[list[str]] = None
29
28
  metrics: Optional[list[str]] = None
30
- attack_modules: Optional[list[str]] = None
31
29
  grading_scale: Optional[dict[str, list[int]]] = None
32
30
  stats: Optional[dict] = None
@@ -0,0 +1,94 @@
1
+ from .... import api as moonshot_api
2
+ from ..schemas.bookmark_create_dto import BookmarkCreateDTO, BookmarkPydanticModel
3
+ from ..services.base_service import BaseService
4
+ from ..services.utils.exceptions_handler import exception_handler
5
+ from .utils.file_manager import copy_file
6
+
7
+
8
class BookmarkService(BaseService):
    @exception_handler
    def insert_bookmark(self, bookmark_data: BookmarkCreateDTO) -> dict:
        """
        Store a new bookmark via the Moonshot API.

        Args:
            bookmark_data (BookmarkCreateDTO): The bookmark details to store.

        Returns:
            dict: The value returned by api_insert_bookmark.
        """
        return moonshot_api.api_insert_bookmark(
            name=bookmark_data.name,
            prompt=bookmark_data.prompt,
            prepared_prompt=bookmark_data.prepared_prompt,
            response=bookmark_data.response,
            context_strategy=bookmark_data.context_strategy,
            prompt_template=bookmark_data.prompt_template,
            attack_module=bookmark_data.attack_module,
            metric=bookmark_data.metric,
        )

    @exception_handler
    def get_all_bookmarks(self, name: str | None = None) -> list[BookmarkPydanticModel]:
        """
        Fetch one bookmark by name, or every bookmark when no name is given.

        Args:
            name (str | None, optional): The bookmark to fetch. When None or empty,
                all bookmarks are fetched.

        Returns:
            list[BookmarkPydanticModel]: The fetched bookmarks as models.
        """
        if name:
            # A single lookup is wrapped in a list so both paths share the conversion.
            raw_bookmarks = [moonshot_api.api_get_bookmark(name)]
        else:
            raw_bookmarks = moonshot_api.api_get_all_bookmarks()

        return [BookmarkPydanticModel(**entry) for entry in raw_bookmarks]

    @exception_handler
    def delete_bookmarks(self, all: bool = False, name: str | None = None) -> dict:
        """
        Delete every bookmark, or only the bookmark with the given name.

        Args:
            all (bool, optional): If True, all bookmarks are deleted and 'name' is
                ignored. Defaults to False.
            name (str | None, optional): The bookmark to delete when 'all' is False.

        Returns:
            dict: The result returned by the Moonshot API when its "success" entry
            is truthy.

        Raises:
            ValueError: If 'all' is False and 'name' is None.
            Exception: If the API result's "success" entry is falsy; carries the
                result's "message".
        """
        if all:
            outcome = moonshot_api.api_delete_all_bookmark()
        elif name is not None:
            outcome = moonshot_api.api_delete_bookmark(name)
        else:
            raise ValueError("Either 'all' must be True or 'name' must be provided.")

        if outcome["success"]:
            return outcome

        raise Exception(
            outcome["message"], "delete_bookmarks", "DeleteBookmarkError"
        )

    @exception_handler
    def export_bookmarks(self, export_file_name: str = "bookmarks") -> str:
        """
        Export the bookmarks to a file and copy that file with copy_file.

        Args:
            export_file_name (str, optional): The name of the file to export the
                bookmarks to. Defaults to "bookmarks".

        Returns:
            str: The value returned by copy_file for the exported file.
        """
        exported_path = moonshot_api.api_export_bookmarks(export_file_name)
        return copy_file(exported_path)
@@ -1,10 +1,35 @@
1
1
  from .... import api as moonshot_api
2
+ from ..schemas.dataset_create_dto import DatasetCreateDTO
2
3
  from ..schemas.dataset_response_dto import DatasetResponseDTO
3
4
  from ..services.base_service import BaseService
4
5
  from ..services.utils.exceptions_handler import exception_handler
6
+ from .utils.file_manager import copy_file
5
7
 
6
8
 
7
9
  class DatasetService(BaseService):
10
@exception_handler
def create_dataset(self, dataset_data: DatasetCreateDTO, method: str) -> str:
    """
    Create a dataset via the Moonshot API and copy the resulting file with copy_file.

    Args:
        dataset_data (DatasetCreateDTO): The dataset details; its params entries
            are forwarded to the API as extra keyword arguments.
        method (str): The method to use for creating the dataset.
            Supported methods are "hf" and "csv".

    Returns:
        str: The value returned by copy_file for the created dataset file.

    Raises:
        Exception: If an error occurs during dataset creation.
    """
    created_path = moonshot_api.api_create_datasets(
        name=dataset_data.name,
        description=dataset_data.description,
        reference=dataset_data.reference,
        license=dataset_data.license,
        method=method,
        **dataset_data.params,
    )
    return copy_file(created_path)
32
+
8
33
  @exception_handler
9
34
  def get_all_datasets(self) -> list[DatasetResponseDTO]:
10
35
  datasets = moonshot_api.api_get_all_datasets()
@@ -24,7 +24,6 @@ class RecipeService(BaseService):
24
24
  datasets=recipe_data.datasets,
25
25
  prompt_templates=recipe_data.prompt_templates,
26
26
  metrics=recipe_data.metrics,
27
- attack_modules=recipe_data.attack_modules,
28
27
  grading_scale=recipe_data.grading_scale,
29
28
  )
30
29
 
@@ -0,0 +1,52 @@
1
+ import os
2
+ import shutil
3
+
4
+ # Global variable to store the path to the temporary folder
5
+ TEMP_FOLDER_PATH = None
6
+
7
+
8
def create_temp_dir(temp_file_path: str):
    """
    Create a temporary directory at the specified path.

    This function sets a global variable with the path to the temporary directory
    and creates the directory (including missing parents) if it does not already exist.

    Args:
        temp_file_path (str): The file path where the temporary directory will be created.

    Raises:
        FileExistsError: If the path already exists and is not a directory.
    """
    global TEMP_FOLDER_PATH
    TEMP_FOLDER_PATH = temp_file_path
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could fail
    # if the directory appeared between the check and the creation.
    os.makedirs(TEMP_FOLDER_PATH, exist_ok=True)
22
+
23
+
24
def copy_file(file_path: str) -> str:
    """
    Copy a file into the temporary directory, keeping its base name.

    The temporary directory is the one recorded in TEMP_FOLDER_PATH, which
    create_temp_dir sets.

    Args:
        file_path (str): The path of the file to be copied.

    Returns:
        str: The path to the copied file in the temporary directory.

    Raises:
        ValueError: If the temporary folder path is not set.
    """
    temp_dir = TEMP_FOLDER_PATH
    if temp_dir is None:
        raise ValueError(
            "Temporary folder path is not set. Call create_temp_dir first."
        )

    # Only the base name is kept, so the copy lands directly in the temp folder.
    file_name = os.path.basename(file_path)
    destination = os.path.join(temp_dir, file_name)
    shutil.copy(file_path, destination)
    return destination
@@ -60,7 +60,6 @@ class MoonshotUIWebhook(
60
60
  logger = logging.getLogger()
61
61
  logger.debug(json.dumps(progress_data, indent=2))
62
62
  self.auto_red_team_test_state.update_progress_status(progress_data)
63
- print("Calling Auto Red Team Callback")
64
63
  try:
65
64
  response = requests.post(self.art_url, json=progress_data)
66
65
  response.raise_for_status()
File without changes