aiverify-moonshot 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/METADATA +12 -10
- {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/RECORD +35 -31
- moonshot/api.py +2 -0
- moonshot/integrations/cli/benchmark/cookbook.py +36 -28
- moonshot/integrations/cli/benchmark/datasets.py +56 -47
- moonshot/integrations/cli/benchmark/metrics.py +39 -30
- moonshot/integrations/cli/benchmark/recipe.py +63 -73
- moonshot/integrations/cli/benchmark/result.py +62 -54
- moonshot/integrations/cli/benchmark/run.py +75 -66
- moonshot/integrations/cli/common/common.py +8 -0
- moonshot/integrations/cli/common/connectors.py +101 -85
- moonshot/integrations/cli/common/dataset.py +66 -0
- moonshot/integrations/cli/common/prompt_template.py +30 -27
- moonshot/integrations/cli/redteam/attack_module.py +45 -30
- moonshot/integrations/cli/redteam/context_strategy.py +36 -30
- moonshot/integrations/cli/redteam/session.py +101 -76
- moonshot/integrations/cli/utils/process_data.py +52 -0
- moonshot/integrations/web_api/app.py +1 -1
- moonshot/integrations/web_api/routes/dataset.py +46 -1
- moonshot/integrations/web_api/schemas/dataset_create_dto.py +18 -0
- moonshot/integrations/web_api/schemas/recipe_create_dto.py +0 -2
- moonshot/integrations/web_api/services/bookmark_service.py +6 -2
- moonshot/integrations/web_api/services/dataset_service.py +25 -0
- moonshot/integrations/web_api/services/recipe_service.py +0 -1
- moonshot/src/api/api_dataset.py +35 -0
- moonshot/src/api/api_recipe.py +0 -3
- moonshot/src/datasets/dataset.py +116 -0
- moonshot/src/recipes/recipe.py +0 -15
- moonshot/src/recipes/recipe_arguments.py +0 -4
- moonshot/src/utils/log.py +12 -6
- moonshot/src/utils/pagination.py +25 -0
- {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/WHEEL +0 -0
- {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/licenses/AUTHORS.md +0 -0
- {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/licenses/LICENSE.md +0 -0
- {aiverify_moonshot-0.4.2.dist-info → aiverify_moonshot-0.4.4.dist-info}/licenses/NOTICES.md +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import argparse
|
|
2
|
+
from ast import literal_eval
|
|
2
3
|
|
|
3
4
|
import cmd2
|
|
4
5
|
from rich.console import Console
|
|
@@ -10,8 +11,8 @@ from moonshot.api import (
|
|
|
10
11
|
api_update_context_strategy,
|
|
11
12
|
)
|
|
12
13
|
from moonshot.integrations.cli.active_session_cfg import active_session
|
|
14
|
+
from moonshot.integrations.cli.utils.process_data import filter_data
|
|
13
15
|
from moonshot.src.redteaming.session.session import Session
|
|
14
|
-
from moonshot.src.utils.find_feature import find_keyword
|
|
15
16
|
|
|
16
17
|
console = Console()
|
|
17
18
|
|
|
@@ -60,6 +61,7 @@ def list_context_strategies(args) -> list | None:
|
|
|
60
61
|
Args:
|
|
61
62
|
args: A namespace object from argparse. It should have an optional attribute:
|
|
62
63
|
find (str): Optional field to find context strategies with a keyword.
|
|
64
|
+
pagination (str): Optional field to paginate context strategies.
|
|
63
65
|
|
|
64
66
|
Returns:
|
|
65
67
|
list | None: A list of ContextStrategy or None if there is no result.
|
|
@@ -67,19 +69,19 @@ def list_context_strategies(args) -> list | None:
|
|
|
67
69
|
try:
|
|
68
70
|
context_strategy_metadata_list = api_get_all_context_strategy_metadata()
|
|
69
71
|
keyword = args.find.lower() if args.find else ""
|
|
70
|
-
if
|
|
71
|
-
|
|
72
|
-
|
|
72
|
+
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
73
|
+
|
|
74
|
+
if context_strategy_metadata_list:
|
|
75
|
+
filtered_context_strategies_list = filter_data(
|
|
76
|
+
context_strategy_metadata_list, keyword, pagination
|
|
73
77
|
)
|
|
74
78
|
if filtered_context_strategies_list:
|
|
75
|
-
|
|
79
|
+
_display_context_strategies(filtered_context_strategies_list)
|
|
76
80
|
return filtered_context_strategies_list
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
display_context_strategies(context_strategy_metadata_list)
|
|
82
|
-
return context_strategy_metadata_list
|
|
81
|
+
|
|
82
|
+
console.print("[red]There are no context strategies found.[/red]")
|
|
83
|
+
return None
|
|
84
|
+
|
|
83
85
|
except Exception as e:
|
|
84
86
|
print(f"[list_context_strategies]: {str(e)}")
|
|
85
87
|
|
|
@@ -124,7 +126,7 @@ def delete_context_strategy(args) -> None:
|
|
|
124
126
|
print(f"[delete_context_strategy]: {str(e)}")
|
|
125
127
|
|
|
126
128
|
|
|
127
|
-
def
|
|
129
|
+
def _display_context_strategies(context_strategies: list) -> None:
|
|
128
130
|
"""
|
|
129
131
|
Display a list of context strategies.
|
|
130
132
|
|
|
@@ -137,25 +139,21 @@ def display_context_strategies(context_strategies: list) -> None:
|
|
|
137
139
|
Returns:
|
|
138
140
|
None
|
|
139
141
|
"""
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
context_strategy_data_str = ""
|
|
153
|
-
for k, v in context_strategy_data.items():
|
|
142
|
+
table = Table(
|
|
143
|
+
title="Context Strategy List",
|
|
144
|
+
show_lines=True,
|
|
145
|
+
expand=True,
|
|
146
|
+
header_style="bold",
|
|
147
|
+
)
|
|
148
|
+
table.add_column("No.", justify="left", width=2)
|
|
149
|
+
table.add_column("Context Strategy Information", justify="left", width=98)
|
|
150
|
+
for idx, context_strategy_data in enumerate(context_strategies, 1):
|
|
151
|
+
context_strategy_data_str = ""
|
|
152
|
+
for k, v in context_strategy_data.items():
|
|
153
|
+
if k != "idx":
|
|
154
154
|
context_strategy_data_str += f"[blue]{k.capitalize()}:[/blue] {v}\n\n"
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
else:
|
|
158
|
-
console.print("[red]There are no context strategies found.[/red]", style="bold")
|
|
155
|
+
table.add_row(str(idx), context_strategy_data_str)
|
|
156
|
+
console.print(table)
|
|
159
157
|
|
|
160
158
|
|
|
161
159
|
# Use context strategy arguments
|
|
@@ -199,3 +197,11 @@ list_context_strategies_args.add_argument(
|
|
|
199
197
|
help="Optional field to find context strategies with keyword",
|
|
200
198
|
nargs="?",
|
|
201
199
|
)
|
|
200
|
+
|
|
201
|
+
list_context_strategies_args.add_argument(
|
|
202
|
+
"-p",
|
|
203
|
+
"--pagination",
|
|
204
|
+
type=str,
|
|
205
|
+
help="Optional tuple to paginate context strategies(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
|
|
206
|
+
nargs="?",
|
|
207
|
+
)
|
|
@@ -22,8 +22,8 @@ from moonshot.api import (
|
|
|
22
22
|
api_load_session,
|
|
23
23
|
)
|
|
24
24
|
from moonshot.integrations.cli.active_session_cfg import active_session
|
|
25
|
+
from moonshot.integrations.cli.utils.process_data import filter_data
|
|
25
26
|
from moonshot.src.redteaming.session.session import Session
|
|
26
|
-
from moonshot.src.utils.find_feature import find_keyword
|
|
27
27
|
|
|
28
28
|
console = Console()
|
|
29
29
|
|
|
@@ -139,6 +139,7 @@ def list_sessions(args) -> list | None:
|
|
|
139
139
|
Args:
|
|
140
140
|
args: A namespace object from argparse. It should have an optional attribute:
|
|
141
141
|
find (str): Optional field to find session(s) with a keyword.
|
|
142
|
+
pagination (str): Optional field to paginate sessions.
|
|
142
143
|
|
|
143
144
|
Returns:
|
|
144
145
|
list | None: A list of Session or None if there is no result.
|
|
@@ -146,19 +147,19 @@ def list_sessions(args) -> list | None:
|
|
|
146
147
|
try:
|
|
147
148
|
session_metadata_list = api_get_all_session_metadata()
|
|
148
149
|
keyword = args.find.lower() if args.find else ""
|
|
149
|
-
if
|
|
150
|
-
|
|
151
|
-
|
|
150
|
+
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
151
|
+
|
|
152
|
+
if session_metadata_list:
|
|
153
|
+
filtered_session_metadata_list = filter_data(
|
|
154
|
+
session_metadata_list, keyword, pagination
|
|
152
155
|
)
|
|
153
156
|
if filtered_session_metadata_list:
|
|
154
|
-
|
|
157
|
+
_display_sessions(filtered_session_metadata_list)
|
|
155
158
|
return filtered_session_metadata_list
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
display_sessions(session_metadata_list)
|
|
161
|
-
return session_metadata_list
|
|
159
|
+
|
|
160
|
+
console.print("[red]There are no sessions found.[/red]")
|
|
161
|
+
return None
|
|
162
|
+
|
|
162
163
|
except Exception as e:
|
|
163
164
|
print(f"[list_sessions]: {str(e)}")
|
|
164
165
|
|
|
@@ -363,12 +364,13 @@ def list_bookmarks(args) -> list | None:
|
|
|
363
364
|
|
|
364
365
|
This function retrieves all available bookmarks by calling the api_get_all_bookmarks function from the
|
|
365
366
|
moonshot.api module.
|
|
366
|
-
It then displays the retrieved bookmarks using the
|
|
367
|
+
It then displays the retrieved bookmarks using the _display_bookmarks function.
|
|
367
368
|
If no bookmarks are found, a message is printed to the console.
|
|
368
369
|
|
|
369
370
|
Args:
|
|
370
371
|
args: A namespace object from argparse. It should have an optional attribute:
|
|
371
372
|
find (str): Optional field to find bookmark(s) with a keyword.
|
|
373
|
+
pagination (str): Optional field to paginate bookmarks.
|
|
372
374
|
|
|
373
375
|
Returns:
|
|
374
376
|
list | None: A list of Bookmark or None if there is no result.
|
|
@@ -376,22 +378,22 @@ def list_bookmarks(args) -> list | None:
|
|
|
376
378
|
try:
|
|
377
379
|
bookmarks_list = api_get_all_bookmarks()
|
|
378
380
|
keyword = args.find.lower() if args.find else ""
|
|
379
|
-
if
|
|
380
|
-
|
|
381
|
+
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
382
|
+
|
|
383
|
+
if bookmarks_list:
|
|
384
|
+
filtered_bookmarks_list = filter_data(bookmarks_list, keyword, pagination)
|
|
381
385
|
if filtered_bookmarks_list:
|
|
382
|
-
|
|
386
|
+
_display_bookmarks(filtered_bookmarks_list)
|
|
383
387
|
return filtered_bookmarks_list
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
display_bookmarks(bookmarks_list)
|
|
389
|
-
return bookmarks_list
|
|
388
|
+
|
|
389
|
+
console.print("[red]There are no bookmarks found.[/red]")
|
|
390
|
+
return None
|
|
391
|
+
|
|
390
392
|
except Exception as e:
|
|
391
393
|
print(f"[list_bookmarks]: {str(e)}")
|
|
392
394
|
|
|
393
395
|
|
|
394
|
-
def
|
|
396
|
+
def _display_bookmarks(bookmarks_list) -> None:
|
|
395
397
|
"""
|
|
396
398
|
Display the list of bookmarks in a tabular format.
|
|
397
399
|
|
|
@@ -402,38 +404,39 @@ def display_bookmarks(bookmarks_list) -> None:
|
|
|
402
404
|
Args:
|
|
403
405
|
bookmarks_list (list): A list of dictionaries, where each dictionary contains the details of a bookmark.
|
|
404
406
|
"""
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
407
|
+
|
|
408
|
+
table = Table(
|
|
409
|
+
title="Bookmark List", show_lines=True, expand=True, header_style="bold"
|
|
410
|
+
)
|
|
411
|
+
table.add_column("ID.", justify="left", width=5)
|
|
412
|
+
table.add_column("Name", justify="left", width=20)
|
|
413
|
+
table.add_column("Prepared Prompt", justify="left", width=50)
|
|
414
|
+
table.add_column("Predicted Response", justify="left", width=50)
|
|
415
|
+
table.add_column("Bookmark Time", justify="left", width=20)
|
|
416
|
+
for idx, bookmark in enumerate(bookmarks_list, 1):
|
|
417
|
+
(
|
|
418
|
+
name,
|
|
419
|
+
prompt,
|
|
420
|
+
prepared_prompt,
|
|
421
|
+
response,
|
|
422
|
+
context_strategy,
|
|
423
|
+
prompt_template,
|
|
424
|
+
attack_module,
|
|
425
|
+
metric,
|
|
426
|
+
bookmark_time,
|
|
427
|
+
*other_args,
|
|
428
|
+
) = bookmark.values()
|
|
429
|
+
idx = bookmark.get("idx", idx)
|
|
430
|
+
|
|
431
|
+
table.add_section()
|
|
432
|
+
table.add_row(
|
|
433
|
+
str(idx),
|
|
434
|
+
name,
|
|
435
|
+
prepared_prompt,
|
|
436
|
+
response,
|
|
437
|
+
bookmark_time,
|
|
408
438
|
)
|
|
409
|
-
|
|
410
|
-
table.add_column("Name", justify="left", width=20)
|
|
411
|
-
table.add_column("Prepared Prompt", justify="left", width=50)
|
|
412
|
-
table.add_column("Predicted Response", justify="left", width=50)
|
|
413
|
-
table.add_column("Bookmark Time", justify="left", width=20)
|
|
414
|
-
for idx, bookmark in enumerate(bookmarks_list, 1):
|
|
415
|
-
(
|
|
416
|
-
name,
|
|
417
|
-
prompt,
|
|
418
|
-
prepared_prompt,
|
|
419
|
-
response,
|
|
420
|
-
context_strategy,
|
|
421
|
-
prompt_template,
|
|
422
|
-
attack_module,
|
|
423
|
-
metric,
|
|
424
|
-
bookmark_time,
|
|
425
|
-
) = bookmark.values()
|
|
426
|
-
table.add_section()
|
|
427
|
-
table.add_row(
|
|
428
|
-
str(idx),
|
|
429
|
-
name,
|
|
430
|
-
prepared_prompt,
|
|
431
|
-
response,
|
|
432
|
-
bookmark_time,
|
|
433
|
-
)
|
|
434
|
-
console.print(table)
|
|
435
|
-
else:
|
|
436
|
-
console.print("[red]There are no bookmarks found.[/red]")
|
|
439
|
+
console.print(table)
|
|
437
440
|
|
|
438
441
|
|
|
439
442
|
def view_bookmark(args) -> None:
|
|
@@ -446,12 +449,12 @@ def view_bookmark(args) -> None:
|
|
|
446
449
|
|
|
447
450
|
try:
|
|
448
451
|
bookmark_info = api_get_bookmark(args.bookmark_name)
|
|
449
|
-
|
|
452
|
+
_display_bookmark(bookmark_info)
|
|
450
453
|
except Exception as e:
|
|
451
454
|
print(f"[view_bookmark]: {str(e)}")
|
|
452
455
|
|
|
453
456
|
|
|
454
|
-
def
|
|
457
|
+
def _display_bookmark(bookmark_info: dict) -> None:
|
|
455
458
|
"""
|
|
456
459
|
Display the filtered bookmark in a tabular format.
|
|
457
460
|
|
|
@@ -600,7 +603,7 @@ def run_attack_module(args):
|
|
|
600
603
|
if args.prompt_template:
|
|
601
604
|
prompt_template = [args.prompt_template]
|
|
602
605
|
elif active_session["prompt_template"]:
|
|
603
|
-
prompt_template = [
|
|
606
|
+
prompt_template = [active_session["prompt_template"]]
|
|
604
607
|
else:
|
|
605
608
|
prompt_template = []
|
|
606
609
|
|
|
@@ -703,7 +706,7 @@ def delete_session(args) -> None:
|
|
|
703
706
|
print(f"[delete_session]: {str(e)}")
|
|
704
707
|
|
|
705
708
|
|
|
706
|
-
def
|
|
709
|
+
def _display_sessions(sessions: list) -> None:
|
|
707
710
|
"""
|
|
708
711
|
Display a list of sessions.
|
|
709
712
|
|
|
@@ -717,25 +720,32 @@ def display_sessions(sessions: list) -> None:
|
|
|
717
720
|
None
|
|
718
721
|
"""
|
|
719
722
|
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
session_id
|
|
730
|
-
endpoints
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
723
|
+
table = Table(
|
|
724
|
+
title="Session List", show_lines=True, expand=True, header_style="bold"
|
|
725
|
+
)
|
|
726
|
+
table.add_column("No.", justify="left", width=2)
|
|
727
|
+
table.add_column("Session ID", justify="left", width=20)
|
|
728
|
+
table.add_column("Contains", justify="left", width=78)
|
|
729
|
+
|
|
730
|
+
for idx, session_data in enumerate(sessions, 1):
|
|
731
|
+
(
|
|
732
|
+
session_id,
|
|
733
|
+
endpoints,
|
|
734
|
+
created_epoch,
|
|
735
|
+
created_datetime,
|
|
736
|
+
prompt_template,
|
|
737
|
+
context_strategy,
|
|
738
|
+
cs_num_of_prev_prompts,
|
|
739
|
+
attack_module,
|
|
740
|
+
metric,
|
|
741
|
+
system_prompt,
|
|
742
|
+
*other_args,
|
|
743
|
+
) = session_data.values()
|
|
744
|
+
idx = session_data.get("idx", idx)
|
|
745
|
+
session_info = f"[red]id: {session_id}[/red]\n\nCreated: {created_datetime}"
|
|
746
|
+
contains_info = f"[blue]Endpoints:[/blue] {endpoints}\n\n"
|
|
747
|
+
table.add_row(str(idx), session_info, contains_info)
|
|
748
|
+
console.print(table)
|
|
739
749
|
|
|
740
750
|
|
|
741
751
|
# use session arguments
|
|
@@ -885,6 +895,13 @@ list_sessions_args.add_argument(
|
|
|
885
895
|
nargs="?",
|
|
886
896
|
)
|
|
887
897
|
|
|
898
|
+
list_sessions_args.add_argument(
|
|
899
|
+
"-p",
|
|
900
|
+
"--pagination",
|
|
901
|
+
type=str,
|
|
902
|
+
help="Optional tuple to paginate session(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
|
|
903
|
+
nargs="?",
|
|
904
|
+
)
|
|
888
905
|
|
|
889
906
|
# Add bookmark arguments
|
|
890
907
|
add_bookmark_args = cmd2.Cmd2ArgumentParser(
|
|
@@ -973,3 +990,11 @@ list_bookmarks_args.add_argument(
|
|
|
973
990
|
help="Optional field to find bookmark(s) with keyword",
|
|
974
991
|
nargs="?",
|
|
975
992
|
)
|
|
993
|
+
|
|
994
|
+
list_bookmarks_args.add_argument(
|
|
995
|
+
"-p",
|
|
996
|
+
"--pagination",
|
|
997
|
+
type=str,
|
|
998
|
+
help="Optional tuple to paginate bookmark(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
|
|
999
|
+
nargs="?",
|
|
1000
|
+
)
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
from moonshot.src.utils.find_feature import find_keyword
|
|
2
|
+
from moonshot.src.utils.pagination import get_paginated_lists
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def filter_data(
|
|
6
|
+
list_of_data: list, keyword: str = "", pagination: tuple = ()
|
|
7
|
+
) -> list | None:
|
|
8
|
+
"""
|
|
9
|
+
Filters and paginates a list of data.
|
|
10
|
+
|
|
11
|
+
Args:
|
|
12
|
+
list_of_data (list): The list of data to be filtered and paginated.
|
|
13
|
+
keyword (str, optional): The keyword to filter the data. Defaults to "".
|
|
14
|
+
pagination (tuple, optional): A tuple containing the page number and page size. Defaults to ().
|
|
15
|
+
|
|
16
|
+
Returns:
|
|
17
|
+
list | None: The filtered and paginated list of data, or None if no data matches the criteria.
|
|
18
|
+
"""
|
|
19
|
+
# if there is a find keyword
|
|
20
|
+
if keyword:
|
|
21
|
+
list_of_data = find_keyword(keyword, list_of_data)
|
|
22
|
+
if not list_of_data:
|
|
23
|
+
return
|
|
24
|
+
|
|
25
|
+
# if pagination is required
|
|
26
|
+
if pagination:
|
|
27
|
+
# add index to every dictionary in the list
|
|
28
|
+
if all(isinstance(item, dict) for item in list_of_data):
|
|
29
|
+
for index, item in enumerate(list_of_data, 1):
|
|
30
|
+
if isinstance(item, dict):
|
|
31
|
+
item["idx"] = index
|
|
32
|
+
|
|
33
|
+
# get paginated data
|
|
34
|
+
page_number = pagination[0]
|
|
35
|
+
page_size = pagination[1]
|
|
36
|
+
|
|
37
|
+
if page_number <= 0 or page_size <= 0:
|
|
38
|
+
# print("Invalid page number or page size. Page number and page size should start from 1.")
|
|
39
|
+
raise RuntimeError(
|
|
40
|
+
"Invalid page number or page size. Page number and page size should start from 1."
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
paginated_data = get_paginated_lists(page_size, list_of_data)
|
|
44
|
+
|
|
45
|
+
# perform index checks
|
|
46
|
+
paginated_data_size = len(paginated_data)
|
|
47
|
+
if page_number > paginated_data_size:
|
|
48
|
+
list_of_data = paginated_data[(paginated_data_size - 1)]
|
|
49
|
+
else:
|
|
50
|
+
list_of_data = paginated_data[(page_number - 1)]
|
|
51
|
+
|
|
52
|
+
return list_of_data
|
|
@@ -71,7 +71,7 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
|
|
|
71
71
|
}
|
|
72
72
|
|
|
73
73
|
app: CustomFastAPI = CustomFastAPI(
|
|
74
|
-
title="Project Moonshot", version="0.4.
|
|
74
|
+
title="Project Moonshot", version="0.4.4", **app_kwargs
|
|
75
75
|
)
|
|
76
76
|
|
|
77
77
|
if cfg.cors.enabled():
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from dependency_injector.wiring import Provide, inject
|
|
2
|
-
from fastapi import APIRouter, Depends, HTTPException
|
|
2
|
+
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
3
3
|
|
|
4
4
|
from ..container import Container
|
|
5
|
+
from ..schemas.dataset_create_dto import DatasetCreateDTO
|
|
5
6
|
from ..schemas.dataset_response_dto import DatasetResponseDTO
|
|
6
7
|
from ..services.dataset_service import DatasetService
|
|
7
8
|
from ..services.utils.exceptions_handler import ServiceException
|
|
@@ -9,6 +10,50 @@ from ..services.utils.exceptions_handler import ServiceException
|
|
|
9
10
|
router = APIRouter(tags=["Datasets"])
|
|
10
11
|
|
|
11
12
|
|
|
13
|
+
@router.post("/api/v1/datasets")
|
|
14
|
+
@inject
|
|
15
|
+
def create_dataset(
|
|
16
|
+
dataset_data: DatasetCreateDTO,
|
|
17
|
+
method: str = Query(
|
|
18
|
+
...,
|
|
19
|
+
description="The method to use for creating the dataset. Supported methods are 'hf' and 'csv'.",
|
|
20
|
+
),
|
|
21
|
+
dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
|
|
22
|
+
) -> str:
|
|
23
|
+
"""
|
|
24
|
+
Create a new dataset using the specified method.
|
|
25
|
+
|
|
26
|
+
Args:
|
|
27
|
+
dataset_data (DatasetCreateDTO): The data required to create the dataset.
|
|
28
|
+
method (str): The method to use for creating the dataset. Supported methods are "hf" and "csv".
|
|
29
|
+
dataset_service (DatasetService, optional): The service responsible for creating the dataset.
|
|
30
|
+
Defaults to Depends(Provide[Container.dataset_service]).
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
dict: A message indicating the dataset was created successfully.
|
|
34
|
+
|
|
35
|
+
Raises:
|
|
36
|
+
HTTPException: An error with status code 404 if the dataset file is not found.
|
|
37
|
+
An error with status code 400 if there is a validation error.
|
|
38
|
+
An error with status code 500 for any other server-side error.
|
|
39
|
+
"""
|
|
40
|
+
try:
|
|
41
|
+
return dataset_service.create_dataset(dataset_data, method)
|
|
42
|
+
except ServiceException as e:
|
|
43
|
+
if e.error_code == "FileNotFound":
|
|
44
|
+
raise HTTPException(
|
|
45
|
+
status_code=404, detail=f"Failed to retrieve datasets: {e.msg}"
|
|
46
|
+
)
|
|
47
|
+
elif e.error_code == "ValidationError":
|
|
48
|
+
raise HTTPException(
|
|
49
|
+
status_code=400, detail=f"Failed to retrieve datasets: {e.msg}"
|
|
50
|
+
)
|
|
51
|
+
else:
|
|
52
|
+
raise HTTPException(
|
|
53
|
+
status_code=500, detail=f"Failed to retrieve datasets: {e.msg}"
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
|
|
12
57
|
@router.get("/api/v1/datasets")
|
|
13
58
|
@inject
|
|
14
59
|
def get_all_datasets(
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
from pyparsing import Iterator
|
|
5
|
+
|
|
6
|
+
from moonshot.src.datasets.dataset_arguments import (
|
|
7
|
+
DatasetArguments as DatasetPydanticModel,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class DatasetCreateDTO(DatasetPydanticModel):
|
|
12
|
+
id: Optional[str] = None
|
|
13
|
+
examples: Iterator[dict] = None
|
|
14
|
+
name: str = Field(..., min_length=1)
|
|
15
|
+
description: str = Field(default="", min_length=1)
|
|
16
|
+
license: Optional[str] = ""
|
|
17
|
+
reference: Optional[str] = ""
|
|
18
|
+
params: dict
|
|
@@ -13,7 +13,6 @@ class RecipeCreateDTO(RecipePydanticModel):
|
|
|
13
13
|
datasets: list[str] = Field(..., min_length=1)
|
|
14
14
|
metrics: list[str] = Field(..., min_length=1)
|
|
15
15
|
prompt_templates: Optional[list[str]] = None
|
|
16
|
-
attack_modules: Optional[list[str]] = None
|
|
17
16
|
grading_scale: Optional[dict[str, list[int]]] = None
|
|
18
17
|
stats: Optional[dict] = None
|
|
19
18
|
|
|
@@ -27,6 +26,5 @@ class RecipeUpdateDTO(RecipePydanticModel):
|
|
|
27
26
|
datasets: Optional[list[str]] = None
|
|
28
27
|
prompt_templates: Optional[list[str]] = None
|
|
29
28
|
metrics: Optional[list[str]] = None
|
|
30
|
-
attack_modules: Optional[list[str]] = None
|
|
31
29
|
grading_scale: Optional[dict[str, list[int]]] = None
|
|
32
30
|
stats: Optional[dict] = None
|
|
@@ -52,18 +52,22 @@ class BookmarkService(BaseService):
|
|
|
52
52
|
@exception_handler
|
|
53
53
|
def delete_bookmarks(self, all: bool = False, name: str | None = None) -> dict:
|
|
54
54
|
"""
|
|
55
|
-
Deletes a single bookmark by its name or all bookmarks if the 'all' flag is set to True
|
|
55
|
+
Deletes a single bookmark by its name or all bookmarks if the 'all' flag is set to True and returns
|
|
56
|
+
a boolean indicating the success of the operation.
|
|
56
57
|
|
|
57
58
|
Args:
|
|
58
59
|
all (bool, optional): If True, all bookmarks will be deleted. Defaults to False.
|
|
59
60
|
name (str | None, optional): The name of the bookmark to delete. If 'all' is False, 'name' must be provided.
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
dict: True if the deletion was successful, False otherwise.
|
|
60
64
|
"""
|
|
61
65
|
if all:
|
|
62
66
|
result = moonshot_api.api_delete_all_bookmark()
|
|
63
67
|
elif name is not None:
|
|
64
68
|
result = moonshot_api.api_delete_bookmark(name)
|
|
65
69
|
else:
|
|
66
|
-
raise ValueError("Either 'all' must be True or '
|
|
70
|
+
raise ValueError("Either 'all' must be True or 'name' must be provided.")
|
|
67
71
|
|
|
68
72
|
if not result["success"]:
|
|
69
73
|
raise Exception(
|
|
@@ -1,10 +1,35 @@
|
|
|
1
1
|
from .... import api as moonshot_api
|
|
2
|
+
from ..schemas.dataset_create_dto import DatasetCreateDTO
|
|
2
3
|
from ..schemas.dataset_response_dto import DatasetResponseDTO
|
|
3
4
|
from ..services.base_service import BaseService
|
|
4
5
|
from ..services.utils.exceptions_handler import exception_handler
|
|
6
|
+
from .utils.file_manager import copy_file
|
|
5
7
|
|
|
6
8
|
|
|
7
9
|
class DatasetService(BaseService):
|
|
10
|
+
@exception_handler
|
|
11
|
+
def create_dataset(self, dataset_data: DatasetCreateDTO, method: str) -> str:
|
|
12
|
+
"""
|
|
13
|
+
Create a dataset using the specified method.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
dataset_data (DatasetCreateDTO): The data required to create the dataset.
|
|
17
|
+
method (str): The method to use for creating the dataset.
|
|
18
|
+
Supported methods are "hf" and "csv".
|
|
19
|
+
|
|
20
|
+
Raises:
|
|
21
|
+
Exception: If an error occurs during dataset creation.
|
|
22
|
+
"""
|
|
23
|
+
new_ds_path = moonshot_api.api_create_datasets(
|
|
24
|
+
name=dataset_data.name,
|
|
25
|
+
description=dataset_data.description,
|
|
26
|
+
reference=dataset_data.reference,
|
|
27
|
+
license=dataset_data.license,
|
|
28
|
+
method=method,
|
|
29
|
+
**dataset_data.params,
|
|
30
|
+
)
|
|
31
|
+
return copy_file(new_ds_path)
|
|
32
|
+
|
|
8
33
|
@exception_handler
|
|
9
34
|
def get_all_datasets(self) -> list[DatasetResponseDTO]:
|
|
10
35
|
datasets = moonshot_api.api_get_all_datasets()
|
moonshot/src/api/api_dataset.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from pydantic import validate_call
|
|
2
2
|
|
|
3
3
|
from moonshot.src.datasets.dataset import Dataset
|
|
4
|
+
from moonshot.src.datasets.dataset_arguments import DatasetArguments
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
# ------------------------------------------------------------------------------
|
|
@@ -44,3 +45,37 @@ def api_get_all_datasets_name() -> list[str]:
|
|
|
44
45
|
"""
|
|
45
46
|
datasets_name, _ = Dataset.get_available_items()
|
|
46
47
|
return datasets_name
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def api_create_datasets(
|
|
51
|
+
name: str, description: str, reference: str, license: str, method: str, **kwargs
|
|
52
|
+
) -> str:
|
|
53
|
+
"""
|
|
54
|
+
This function creates a new dataset.
|
|
55
|
+
|
|
56
|
+
This function takes the name, description, reference, and license for a new dataset as input. It then creates a new
|
|
57
|
+
DatasetArguments object with these details and an empty id. The id is left empty because it will be generated
|
|
58
|
+
from the name during the creation process. The function then calls the Dataset's create method to
|
|
59
|
+
create the new dataset.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
name (str): The name of the new dataset.
|
|
63
|
+
description (str): A brief description of the new dataset.
|
|
64
|
+
reference (str): A reference link for the new dataset.
|
|
65
|
+
license (str): The license of the new dataset.
|
|
66
|
+
method (str): The method to create new dataset. (csv/hf)
|
|
67
|
+
kwargs: Additional keyword arguments for the Dataset's create method.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
str: The ID of the newly created dataset.
|
|
71
|
+
"""
|
|
72
|
+
ds_args = DatasetArguments(
|
|
73
|
+
id="",
|
|
74
|
+
name=name,
|
|
75
|
+
description=description,
|
|
76
|
+
reference=reference,
|
|
77
|
+
license=license,
|
|
78
|
+
examples=None,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
return Dataset.create(ds_args, method, **kwargs)
|