aiverify-moonshot 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/METADATA +2 -2
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/RECORD +70 -56
- moonshot/__main__.py +77 -35
- moonshot/api.py +16 -0
- moonshot/integrations/cli/benchmark/benchmark.py +29 -13
- moonshot/integrations/cli/benchmark/cookbook.py +62 -24
- moonshot/integrations/cli/benchmark/datasets.py +79 -40
- moonshot/integrations/cli/benchmark/metrics.py +62 -23
- moonshot/integrations/cli/benchmark/recipe.py +89 -69
- moonshot/integrations/cli/benchmark/result.py +85 -47
- moonshot/integrations/cli/benchmark/run.py +99 -59
- moonshot/integrations/cli/common/common.py +20 -6
- moonshot/integrations/cli/common/connectors.py +154 -74
- moonshot/integrations/cli/common/dataset.py +66 -0
- moonshot/integrations/cli/common/prompt_template.py +57 -19
- moonshot/integrations/cli/redteam/attack_module.py +90 -24
- moonshot/integrations/cli/redteam/context_strategy.py +83 -23
- moonshot/integrations/cli/redteam/prompt_template.py +1 -1
- moonshot/integrations/cli/redteam/redteam.py +52 -6
- moonshot/integrations/cli/redteam/session.py +565 -44
- moonshot/integrations/cli/utils/process_data.py +52 -0
- moonshot/integrations/web_api/__main__.py +2 -0
- moonshot/integrations/web_api/app.py +6 -6
- moonshot/integrations/web_api/container.py +12 -2
- moonshot/integrations/web_api/routes/bookmark.py +173 -0
- moonshot/integrations/web_api/routes/dataset.py +46 -1
- moonshot/integrations/web_api/schemas/bookmark_create_dto.py +13 -0
- moonshot/integrations/web_api/schemas/dataset_create_dto.py +18 -0
- moonshot/integrations/web_api/schemas/recipe_create_dto.py +0 -2
- moonshot/integrations/web_api/services/bookmark_service.py +94 -0
- moonshot/integrations/web_api/services/dataset_service.py +25 -0
- moonshot/integrations/web_api/services/recipe_service.py +0 -1
- moonshot/integrations/web_api/services/utils/file_manager.py +52 -0
- moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +0 -1
- moonshot/integrations/web_api/temp/.gitkeep +0 -0
- moonshot/src/api/api_bookmark.py +95 -0
- moonshot/src/api/api_connector_endpoint.py +1 -1
- moonshot/src/api/api_context_strategy.py +2 -2
- moonshot/src/api/api_dataset.py +35 -0
- moonshot/src/api/api_recipe.py +0 -3
- moonshot/src/api/api_session.py +1 -1
- moonshot/src/bookmark/bookmark.py +257 -0
- moonshot/src/bookmark/bookmark_arguments.py +38 -0
- moonshot/src/configs/env_variables.py +12 -2
- moonshot/src/connectors/connector.py +15 -7
- moonshot/src/connectors_endpoints/connector_endpoint.py +65 -49
- moonshot/src/cookbooks/cookbook.py +57 -37
- moonshot/src/datasets/dataset.py +125 -5
- moonshot/src/metrics/metric.py +8 -4
- moonshot/src/metrics/metric_interface.py +8 -2
- moonshot/src/prompt_templates/prompt_template.py +5 -1
- moonshot/src/recipes/recipe.py +38 -40
- moonshot/src/recipes/recipe_arguments.py +0 -4
- moonshot/src/redteaming/attack/attack_module.py +18 -8
- moonshot/src/redteaming/attack/context_strategy.py +6 -2
- moonshot/src/redteaming/session/session.py +15 -11
- moonshot/src/results/result.py +7 -3
- moonshot/src/runners/runner.py +65 -42
- moonshot/src/runs/run.py +15 -11
- moonshot/src/runs/run_progress.py +7 -3
- moonshot/src/storage/db_interface.py +14 -0
- moonshot/src/storage/storage.py +33 -2
- moonshot/src/utils/find_feature.py +45 -0
- moonshot/src/utils/log.py +72 -0
- moonshot/src/utils/pagination.py +25 -0
- moonshot/src/utils/timeit.py +8 -1
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/WHEEL +0 -0
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/AUTHORS.md +0 -0
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/LICENSE.md +0 -0
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/NOTICES.md +0 -0
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from moonshot.src.utils.find_feature import find_keyword
|
|
2
|
+
from moonshot.src.utils.pagination import get_paginated_lists
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def filter_data(
|
|
6
|
+
list_of_data: list, keyword: str = "", pagination: tuple = ()
|
|
7
|
+
) -> list | None:
|
|
8
|
+
"""
|
|
9
|
+
Filters and paginates a list of data.
|
|
10
|
+
|
|
11
|
+
Args:
|
|
12
|
+
list_of_data (list): The list of data to be filtered and paginated.
|
|
13
|
+
keyword (str, optional): The keyword to filter the data. Defaults to "".
|
|
14
|
+
pagination (tuple, optional): A tuple containing the page number and page size. Defaults to ().
|
|
15
|
+
|
|
16
|
+
Returns:
|
|
17
|
+
list | None: The filtered and paginated list of data, or None if no data matches the criteria.
|
|
18
|
+
"""
|
|
19
|
+
# if there is a find keyword
|
|
20
|
+
if keyword:
|
|
21
|
+
list_of_data = find_keyword(keyword, list_of_data)
|
|
22
|
+
if not list_of_data:
|
|
23
|
+
return
|
|
24
|
+
|
|
25
|
+
# if pagination is required
|
|
26
|
+
if pagination:
|
|
27
|
+
# add index to every dictionary in the list
|
|
28
|
+
if all(isinstance(item, dict) for item in list_of_data):
|
|
29
|
+
for index, item in enumerate(list_of_data, 1):
|
|
30
|
+
if isinstance(item, dict):
|
|
31
|
+
item["idx"] = index
|
|
32
|
+
|
|
33
|
+
# get paginated data
|
|
34
|
+
page_number = pagination[0]
|
|
35
|
+
page_size = pagination[1]
|
|
36
|
+
|
|
37
|
+
if page_number <= 0 or page_size <= 0:
|
|
38
|
+
# print("Invalid page number or page size. Page number and page size should start from 1.")
|
|
39
|
+
raise RuntimeError(
|
|
40
|
+
"Invalid page number or page size. Page number and page size should start from 1."
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
paginated_data = get_paginated_lists(page_size, list_of_data)
|
|
44
|
+
|
|
45
|
+
# perform index checks
|
|
46
|
+
paginated_data_size = len(paginated_data)
|
|
47
|
+
if page_number > paginated_data_size:
|
|
48
|
+
list_of_data = paginated_data[(paginated_data_size - 1)]
|
|
49
|
+
else:
|
|
50
|
+
list_of_data = paginated_data[(page_number - 1)]
|
|
51
|
+
|
|
52
|
+
return list_of_data
|
|
@@ -7,6 +7,7 @@ from dotenv import dotenv_values, load_dotenv
|
|
|
7
7
|
from .app import create_app
|
|
8
8
|
from .container import Container
|
|
9
9
|
from .logging_conf import configure_app_logging, create_uvicorn_log_config
|
|
10
|
+
from .services.utils.file_manager import create_temp_dir
|
|
10
11
|
from .types.types import UvicornRunArgs
|
|
11
12
|
|
|
12
13
|
|
|
@@ -20,6 +21,7 @@ def start_app():
|
|
|
20
21
|
else:
|
|
21
22
|
container.config.from_yaml(f"{config_file}", required=True)
|
|
22
23
|
|
|
24
|
+
create_temp_dir(container.config.temp_folder())
|
|
23
25
|
configure_app_logging(container.config)
|
|
24
26
|
logging.info(f"Environment: {container.config.app_environment()}")
|
|
25
27
|
ENABLE_SSL = container.config.ssl.enabled()
|
|
@@ -14,12 +14,13 @@ from .routes import (
|
|
|
14
14
|
attack_modules,
|
|
15
15
|
benchmark,
|
|
16
16
|
benchmark_result,
|
|
17
|
+
bookmark,
|
|
18
|
+
context_strategy,
|
|
17
19
|
cookbook,
|
|
18
20
|
dataset,
|
|
19
21
|
endpoint,
|
|
20
22
|
metric,
|
|
21
23
|
prompt_template,
|
|
22
|
-
context_strategy,
|
|
23
24
|
recipe,
|
|
24
25
|
runner,
|
|
25
26
|
)
|
|
@@ -66,13 +67,11 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
|
|
|
66
67
|
|
|
67
68
|
app_kwargs["swagger_ui_parameters"] = {
|
|
68
69
|
"defaultModelsExpandDepth": -1,
|
|
69
|
-
"docExpansion": None
|
|
70
|
-
|
|
70
|
+
"docExpansion": None,
|
|
71
|
+
}
|
|
71
72
|
|
|
72
73
|
app: CustomFastAPI = CustomFastAPI(
|
|
73
|
-
title="Project Moonshot",
|
|
74
|
-
version="0.4.1",
|
|
75
|
-
**app_kwargs
|
|
74
|
+
title="Project Moonshot", version="0.4.3", **app_kwargs
|
|
76
75
|
)
|
|
77
76
|
|
|
78
77
|
if cfg.cors.enabled():
|
|
@@ -104,6 +103,7 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
|
|
|
104
103
|
app.include_router(runner.router)
|
|
105
104
|
app.include_router(dataset.router)
|
|
106
105
|
app.include_router(attack_modules.router)
|
|
106
|
+
app.include_router(bookmark.router)
|
|
107
107
|
|
|
108
108
|
@app.exception_handler(RequestValidationError)
|
|
109
109
|
async def validation_exception_handler(
|
|
@@ -9,6 +9,7 @@ from .services.benchmark_result_service import BenchmarkResultService
|
|
|
9
9
|
from .services.benchmark_test_manager import BenchmarkTestManager
|
|
10
10
|
from .services.benchmark_test_state import BenchmarkTestState
|
|
11
11
|
from .services.benchmarking_service import BenchmarkingService
|
|
12
|
+
from .services.bookmark_service import BookmarkService
|
|
12
13
|
from .services.context_strategy_service import ContextStrategyService
|
|
13
14
|
from .services.cookbook_service import CookbookService
|
|
14
15
|
from .services.dataset_service import DatasetService
|
|
@@ -55,6 +56,11 @@ class Container(containers.DeclarativeContainer):
|
|
|
55
56
|
"log_file_max_size": 5242880,
|
|
56
57
|
"log_file_backup_count": 3,
|
|
57
58
|
},
|
|
59
|
+
"temp_folder": str(
|
|
60
|
+
importlib.resources.files("moonshot").joinpath(
|
|
61
|
+
"integrations/web_api/temp"
|
|
62
|
+
)
|
|
63
|
+
),
|
|
58
64
|
}
|
|
59
65
|
)
|
|
60
66
|
|
|
@@ -97,9 +103,9 @@ class Container(containers.DeclarativeContainer):
|
|
|
97
103
|
prompt_template_service: providers.Singleton[
|
|
98
104
|
PromptTemplateService
|
|
99
105
|
] = providers.Singleton(PromptTemplateService)
|
|
100
|
-
context_strategy_service: providers.Singleton[
|
|
106
|
+
context_strategy_service: providers.Singleton[
|
|
101
107
|
ContextStrategyService
|
|
102
|
-
)
|
|
108
|
+
] = providers.Singleton(ContextStrategyService)
|
|
103
109
|
benchmarking_service: providers.Singleton[
|
|
104
110
|
BenchmarkingService
|
|
105
111
|
] = providers.Singleton(
|
|
@@ -127,6 +133,9 @@ class Container(containers.DeclarativeContainer):
|
|
|
127
133
|
am_service: providers.Singleton[AttackModuleService] = providers.Singleton(
|
|
128
134
|
AttackModuleService,
|
|
129
135
|
)
|
|
136
|
+
bookmark_service: providers.Singleton[BookmarkService] = providers.Singleton(
|
|
137
|
+
BookmarkService,
|
|
138
|
+
)
|
|
130
139
|
wiring_config = containers.WiringConfiguration(
|
|
131
140
|
modules=[
|
|
132
141
|
".routes.redteam",
|
|
@@ -141,6 +150,7 @@ class Container(containers.DeclarativeContainer):
|
|
|
141
150
|
".routes.runner",
|
|
142
151
|
".routes.dataset",
|
|
143
152
|
".routes.attack_modules",
|
|
153
|
+
".routes.bookmark",
|
|
144
154
|
".services.benchmarking_service",
|
|
145
155
|
]
|
|
146
156
|
)
|
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from dependency_injector.wiring import Provide, inject
|
|
4
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
5
|
+
|
|
6
|
+
from ..container import Container
|
|
7
|
+
from ..schemas.bookmark_create_dto import BookmarkCreateDTO, BookmarkPydanticModel
|
|
8
|
+
from ..services.bookmark_service import BookmarkService
|
|
9
|
+
from ..services.utils.exceptions_handler import ServiceException
|
|
10
|
+
|
|
11
|
+
# Router for the bookmark endpoints below, grouped under the "Bookmark" tag.
router = APIRouter(tags=["Bookmark"])
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@router.post(
    "/api/v1/bookmarks", response_description="Bookmark data added to the database"
)
@inject
def insert_bookmark(
    bookmark_data: BookmarkCreateDTO,
    bookmark_service: BookmarkService = Depends(Provide[Container.bookmark_service]),
) -> dict:
    """
    Add a bookmark through the bookmark service.

    Args:
        bookmark_data: Payload describing the bookmark to store.
        bookmark_service: Injected service that performs the insertion.

    Returns:
        The dictionary produced by ``BookmarkService.insert_bookmark``.

    Raises:
        HTTPException: 404 for a ``FileNotFound`` service error, 400 for a
            ``ValidationError`` and 500 for any other ``ServiceException``.
    """
    try:
        return bookmark_service.insert_bookmark(bookmark_data)
    except ServiceException as e:
        # Translate the service-level error code into an HTTP status code.
        status = {"FileNotFound": 404, "ValidationError": 400}.get(e.error_code, 500)
        raise HTTPException(
            status_code=status, detail=f"Failed to insert bookmark: {e.msg}"
        )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@router.get(
    "/api/v1/bookmarks",
    response_description="List of all bookmarks or a specific bookmark by name",
)
@inject
def get_all_bookmarks(
    name: Optional[str] = Query(None, description="Name of the bookmark to query"),
    bookmark_service: BookmarkService = Depends(Provide[Container.bookmark_service]),
) -> list[BookmarkPydanticModel]:
    """
    Fetch every bookmark, or only the one matching ``name``.

    Args:
        name: Bookmark to look up; when omitted, all bookmarks are returned.
        bookmark_service: Injected service that performs the lookup.

    Returns:
        A list of bookmark models (a single-element list when ``name`` is set).

    Raises:
        HTTPException: 404 for a ``FileNotFound`` service error, 400 for a
            ``ValidationError`` and 500 for any other ``ServiceException``.
    """
    try:
        return bookmark_service.get_all_bookmarks(name=name)
    except ServiceException as e:
        # Translate the service-level error code into an HTTP status code.
        status = {"FileNotFound": 404, "ValidationError": 400}.get(e.error_code, 500)
        raise HTTPException(
            status_code=status, detail=f"Failed to retrieve bookmarks: {e.msg}"
        )
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
@router.delete(
    "/api/v1/bookmarks", response_description="Bookmark data deleted from the database"
)
@inject
def delete_bookmark(
    all: bool = Query(False, description="Flag to delete all bookmarks"),
    name: Optional[str] = Query(None, description="Name of the bookmark to delete"),
    bookmark_service: BookmarkService = Depends(Provide[Container.bookmark_service]),
) -> dict:
    """
    Remove a single named bookmark, or every bookmark when ``all`` is set.

    Args:
        all: When True, every bookmark is deleted and ``name`` is ignored.
        name: Name of the single bookmark to delete when ``all`` is False.
        bookmark_service: Injected service that performs the deletion.

    Returns:
        The dictionary produced by ``BookmarkService.delete_bookmarks``.

    Raises:
        HTTPException: 400 when neither ``all`` nor ``name`` is supplied;
            otherwise 404 / 400 / 500 for a ``FileNotFound`` /
            ``ValidationError`` / any other ``ServiceException``.
    """
    # Reject requests that identify no target before touching the service.
    if not all and name is None:
        raise HTTPException(
            status_code=400, detail="Must specify 'all' or 'name' parameter"
        )

    try:
        if all:
            return bookmark_service.delete_bookmarks(all=True)
        return bookmark_service.delete_bookmarks(all=False, name=name)
    except ServiceException as e:
        # Translate the service-level error code into an HTTP status code.
        status = {"FileNotFound": 404, "ValidationError": 400}.get(e.error_code, 500)
        raise HTTPException(
            status_code=status, detail=f"Failed to delete bookmark: {e.msg}"
        )
|
|
139
|
+
@router.post(
    "/api/v1/bookmarks/export", response_description="Exporting Bookmark to JSON file"
)
@inject
def export_bookbookmarks(
    export_file_name: Optional[str] = Query(
        "bookmarks", description="Name of the exported file"
    ),
    bookmark_service: BookmarkService = Depends(Provide[Container.bookmark_service]),
) -> str:
    """
    Export the bookmarks to a JSON file named ``export_file_name``.

    NOTE(review): the doubled "bookbookmarks" in this handler's name looks like
    a typo. It is kept because renaming may change the route's generated
    OpenAPI operation id; confirm with API consumers before renaming.

    Args:
        export_file_name: Name of the file to export the bookmarks to.
        bookmark_service: Injected service that performs the export.

    Returns:
        The string produced by ``BookmarkService.export_bookmarks`` (the path of
        the exported file).

    Raises:
        HTTPException: 404 for a ``FileNotFound`` service error, 400 for a
            ``ValidationError`` and 500 for any other ``ServiceException``.
    """
    try:
        return bookmark_service.export_bookmarks(export_file_name)
    except ServiceException as e:
        # Translate the service-level error code into an HTTP status code.
        status = {"FileNotFound": 404, "ValidationError": 400}.get(e.error_code, 500)
        raise HTTPException(
            status_code=status, detail=f"Failed to export bookmark: {e.msg}"
        )
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from dependency_injector.wiring import Provide, inject
|
|
2
|
-
from fastapi import APIRouter, Depends, HTTPException
|
|
2
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
3
3
|
|
|
4
4
|
from ..container import Container
|
|
5
|
+
from ..schemas.dataset_create_dto import DatasetCreateDTO
|
|
5
6
|
from ..schemas.dataset_response_dto import DatasetResponseDTO
|
|
6
7
|
from ..services.dataset_service import DatasetService
|
|
7
8
|
from ..services.utils.exceptions_handler import ServiceException
|
|
@@ -9,6 +10,50 @@ from ..services.utils.exceptions_handler import ServiceException
|
|
|
9
10
|
# Router for the dataset endpoints, grouped under the "Datasets" tag.
router = APIRouter(tags=["Datasets"])
|
|
10
11
|
|
|
11
12
|
|
|
13
|
+
@router.post("/api/v1/datasets")
@inject
def create_dataset(
    dataset_data: DatasetCreateDTO,
    method: str = Query(
        ...,
        description="The method to use for creating the dataset. Supported methods are 'hf' and 'csv'.",
    ),
    dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
) -> str:
    """
    Create a new dataset using the specified method.

    Args:
        dataset_data (DatasetCreateDTO): The data required to create the dataset.
        method (str): The method to use for creating the dataset. Supported
            methods are "hf" and "csv".
        dataset_service (DatasetService, optional): The service responsible for
            creating the dataset. Defaults to Depends(Provide[Container.dataset_service]).

    Returns:
        str: The value returned by ``DatasetService.create_dataset`` (the path
        of the created dataset file).

    Raises:
        HTTPException: 404 if a required file is not found, 400 if there is a
            validation error, and 500 for any other ``ServiceException``.
    """
    try:
        return dataset_service.create_dataset(dataset_data, method)
    except ServiceException as e:
        # Error details describe the create operation (previously they were
        # copy-pasted from the list endpoint and said "retrieve datasets").
        if e.error_code == "FileNotFound":
            raise HTTPException(
                status_code=404, detail=f"Failed to create dataset: {e.msg}"
            )
        elif e.error_code == "ValidationError":
            raise HTTPException(
                status_code=400, detail=f"Failed to create dataset: {e.msg}"
            )
        else:
            raise HTTPException(
                status_code=500, detail=f"Failed to create dataset: {e.msg}"
            )
|
|
55
|
+
|
|
56
|
+
|
|
12
57
|
@router.get("/api/v1/datasets")
|
|
13
58
|
@inject
|
|
14
59
|
def get_all_datasets(
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from moonshot.src.bookmark.bookmark_arguments import (
|
|
4
|
+
BookmarkArguments as BookmarkPydanticModel,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BookmarkCreateDTO(BookmarkPydanticModel):
    """
    Request body for creating a bookmark through the web API.

    Extends the core bookmark arguments model. The optional fields below
    default to an empty string (``bookmark_time`` defaults to None) when the
    client omits them.
    """

    # Optional; defaults to "".
    prompt_template: Optional[str] = ""
    # Optional; defaults to "".
    context_strategy: Optional[str] = ""
    # Optional; defaults to "".
    attack_module: Optional[str] = ""
    # Optional; defaults to "".
    metric: Optional[str] = ""
    # NOTE(review): BookmarkService.insert_bookmark does not forward this field
    # to the API; confirm whether clients are meant to supply it.
    bookmark_time: Optional[str] = None
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
from pyparsing import Iterator
|
|
5
|
+
|
|
6
|
+
from moonshot.src.datasets.dataset_arguments import (
|
|
7
|
+
DatasetArguments as DatasetPydanticModel,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DatasetCreateDTO(DatasetPydanticModel):
    """
    Request body for creating a dataset through the web API.

    ``params`` holds the method-specific options; DatasetService.create_dataset
    unpacks them as keyword arguments for the selected creation method.
    """

    id: Optional[str] = None
    # Declared Optional because the default is None.
    # NOTE(review): ``Iterator`` is imported from pyparsing at the top of this
    # module; it should come from ``typing`` / ``collections.abc`` instead.
    examples: Optional[Iterator[dict]] = None
    name: str = Field(..., min_length=1)
    # Optional field: the default is "", so a min_length=1 constraint only
    # rejected clients that sent "" explicitly while omitting it yields "".
    description: str = Field(default="")
    license: Optional[str] = ""
    reference: Optional[str] = ""
    params: dict
|
|
@@ -13,7 +13,6 @@ class RecipeCreateDTO(RecipePydanticModel):
|
|
|
13
13
|
datasets: list[str] = Field(..., min_length=1)
|
|
14
14
|
metrics: list[str] = Field(..., min_length=1)
|
|
15
15
|
prompt_templates: Optional[list[str]] = None
|
|
16
|
-
attack_modules: Optional[list[str]] = None
|
|
17
16
|
grading_scale: Optional[dict[str, list[int]]] = None
|
|
18
17
|
stats: Optional[dict] = None
|
|
19
18
|
|
|
@@ -27,6 +26,5 @@ class RecipeUpdateDTO(RecipePydanticModel):
|
|
|
27
26
|
datasets: Optional[list[str]] = None
|
|
28
27
|
prompt_templates: Optional[list[str]] = None
|
|
29
28
|
metrics: Optional[list[str]] = None
|
|
30
|
-
attack_modules: Optional[list[str]] = None
|
|
31
29
|
grading_scale: Optional[dict[str, list[int]]] = None
|
|
32
30
|
stats: Optional[dict] = None
|
|
@@ -0,0 +1,94 @@
|
|
|
1
|
+
from .... import api as moonshot_api
|
|
2
|
+
from ..schemas.bookmark_create_dto import BookmarkCreateDTO, BookmarkPydanticModel
|
|
3
|
+
from ..services.base_service import BaseService
|
|
4
|
+
from ..services.utils.exceptions_handler import exception_handler
|
|
5
|
+
from .utils.file_manager import copy_file
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class BookmarkService(BaseService):
    """Web-API service that delegates bookmark operations to the moonshot API."""

    @exception_handler
    def insert_bookmark(self, bookmark_data: BookmarkCreateDTO) -> dict:
        """
        Store a new bookmark via ``moonshot_api.api_insert_bookmark``.

        Args:
            bookmark_data (BookmarkCreateDTO): The incoming bookmark payload.
                Its ``bookmark_time`` field is not forwarded.

        Returns:
            dict: Whatever ``api_insert_bookmark`` returns.
        """
        inserted = moonshot_api.api_insert_bookmark(
            name=bookmark_data.name,
            prompt=bookmark_data.prompt,
            prepared_prompt=bookmark_data.prepared_prompt,
            response=bookmark_data.response,
            context_strategy=bookmark_data.context_strategy,
            prompt_template=bookmark_data.prompt_template,
            attack_module=bookmark_data.attack_module,
            metric=bookmark_data.metric,
        )
        return inserted

    @exception_handler
    def get_all_bookmarks(self, name: str | None = None) -> list[BookmarkPydanticModel]:
        """
        Fetch every bookmark, or only the one called ``name``.

        Args:
            name (str | None, optional): The bookmark to fetch. A falsy value
                (None or "") fetches all bookmarks.

        Returns:
            list[BookmarkPydanticModel]: The bookmarks wrapped in pydantic models
            (a single-element list when ``name`` is given).
        """
        raw_bookmarks = (
            [moonshot_api.api_get_bookmark(name)]
            if name
            else moonshot_api.api_get_all_bookmarks()
        )
        return [BookmarkPydanticModel(**entry) for entry in raw_bookmarks]

    @exception_handler
    def delete_bookmarks(self, all: bool = False, name: str | None = None) -> dict:
        """
        Delete every bookmark (``all=True``) or the single bookmark ``name``.

        Args:
            all (bool, optional): When True, delete all bookmarks; ``name`` is
                ignored. Defaults to False.
            name (str | None, optional): The bookmark to delete when ``all`` is
                False.

        Returns:
            dict: The result dictionary from the API, returned only when its
            ``"success"`` entry is truthy.

        Raises:
            ValueError: If ``all`` is False and ``name`` is None.
            Exception: If the API reports failure; the arguments mirror the
                (message, method, error code) layout. NOTE(review): presumably
                both are translated by ``exception_handler`` — confirm.
        """
        if all:
            outcome = moonshot_api.api_delete_all_bookmark()
        elif name is None:
            raise ValueError("Either 'all' must be True or 'name' must be provided.")
        else:
            outcome = moonshot_api.api_delete_bookmark(name)

        if not outcome["success"]:
            raise Exception(
                outcome["message"], "delete_bookmarks", "DeleteBookmarkError"
            )
        return outcome

    @exception_handler
    def export_bookmarks(self, export_file_name: str = "bookmarks") -> str:
        """
        Export the bookmarks to a file and copy it into the temp folder.

        Args:
            export_file_name (str, optional): Name passed to
                ``api_export_bookmarks`` for the exported file. Defaults to
                "bookmarks".

        Returns:
            str: Path of the copy of the exported file produced by ``copy_file``.
        """
        exported_path = moonshot_api.api_export_bookmarks(export_file_name)
        return copy_file(exported_path)
|
|
@@ -1,10 +1,35 @@
|
|
|
1
1
|
from .... import api as moonshot_api
|
|
2
|
+
from ..schemas.dataset_create_dto import DatasetCreateDTO
|
|
2
3
|
from ..schemas.dataset_response_dto import DatasetResponseDTO
|
|
3
4
|
from ..services.base_service import BaseService
|
|
4
5
|
from ..services.utils.exceptions_handler import exception_handler
|
|
6
|
+
from .utils.file_manager import copy_file
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
class DatasetService(BaseService):
|
|
10
|
+
@exception_handler
def create_dataset(self, dataset_data: DatasetCreateDTO, method: str) -> str:
    """
    Build a dataset with the chosen creation method and stage it for download.

    Args:
        dataset_data (DatasetCreateDTO): Dataset metadata; its ``params`` are
            passed through as extra keyword arguments.
        method (str): Creation method; supported values are "hf" and "csv".

    Returns:
        str: Path of the copy of the new dataset file made by ``copy_file``.

    Raises:
        Exception: If an error occurs during dataset creation.
    """
    # Fixed metadata arguments; the method-specific params are unpacked separately
    # so a duplicate key still raises TypeError instead of being overridden.
    core_arguments = {
        "name": dataset_data.name,
        "description": dataset_data.description,
        "reference": dataset_data.reference,
        "license": dataset_data.license,
        "method": method,
    }
    created_path = moonshot_api.api_create_datasets(
        **core_arguments, **dataset_data.params
    )
    return copy_file(created_path)
|
|
32
|
+
|
|
8
33
|
@exception_handler
|
|
9
34
|
def get_all_datasets(self) -> list[DatasetResponseDTO]:
|
|
10
35
|
datasets = moonshot_api.api_get_all_datasets()
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import shutil
|
|
3
|
+
|
|
4
|
+
# Global variable to store the path to the temporary folder.
# Assigned by create_temp_dir(); copy_file() refuses to run while it is None.
TEMP_FOLDER_PATH = None


def create_temp_dir(temp_file_path: str):
    """
    Create a temporary directory at the specified path.

    This function records the path in the module-level ``TEMP_FOLDER_PATH``
    and creates the directory (including missing parents) if it does not
    already exist. Calling it again for an existing directory is a no-op.

    Args:
        temp_file_path (str): The path where the temporary directory will be created.

    Raises:
        FileExistsError: If ``temp_file_path`` exists but is not a directory.
    """
    global TEMP_FOLDER_PATH
    TEMP_FOLDER_PATH = temp_file_path
    # exist_ok=True avoids the race of a separate exists-check followed by makedirs.
    os.makedirs(TEMP_FOLDER_PATH, exist_ok=True)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def copy_file(file_path: str) -> str:
    """
    Copy ``file_path`` into the temporary folder registered by create_temp_dir.

    The copy keeps the source file's base name; an existing file with the same
    name in the temporary folder is overwritten.

    Args:
        file_path (str): Path of the file to copy.

    Returns:
        str: Path of the copy inside the temporary folder.

    Raises:
        ValueError: If create_temp_dir has not been called yet.
    """
    temp_folder = TEMP_FOLDER_PATH
    if temp_folder is None:
        raise ValueError(
            "Temporary folder path is not set. Call create_temp_dir first."
        )

    file_name = os.path.split(file_path)[1]
    target_path = os.path.join(temp_folder, file_name)
    shutil.copy(file_path, target_path)
    return target_path
|
|
@@ -60,7 +60,6 @@ class MoonshotUIWebhook(
|
|
|
60
60
|
logger = logging.getLogger()
|
|
61
61
|
logger.debug(json.dumps(progress_data, indent=2))
|
|
62
62
|
self.auto_red_team_test_state.update_progress_status(progress_data)
|
|
63
|
-
print("Calling Auto Red Team Callback")
|
|
64
63
|
try:
|
|
65
64
|
response = requests.post(self.art_url, json=progress_data)
|
|
66
65
|
response.raise_for_status()
|
|
File without changes
|