aiverify-moonshot 0.4.4__py3-none-any.whl → 0.4.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aiverify_moonshot-0.4.4.dist-info → aiverify_moonshot-0.4.6.dist-info}/METADATA +3 -2
- {aiverify_moonshot-0.4.4.dist-info → aiverify_moonshot-0.4.6.dist-info}/RECORD +28 -26
- moonshot/__main__.py +125 -54
- moonshot/integrations/cli/benchmark/cookbook.py +226 -42
- moonshot/integrations/cli/benchmark/datasets.py +53 -8
- moonshot/integrations/cli/benchmark/metrics.py +48 -7
- moonshot/integrations/cli/benchmark/recipe.py +283 -42
- moonshot/integrations/cli/benchmark/result.py +73 -30
- moonshot/integrations/cli/benchmark/run.py +43 -11
- moonshot/integrations/cli/benchmark/runner.py +29 -20
- moonshot/integrations/cli/cli_errors.py +511 -0
- moonshot/integrations/cli/common/connectors.py +137 -6
- moonshot/integrations/cli/common/dataset.py +66 -13
- moonshot/integrations/cli/common/prompt_template.py +38 -2
- moonshot/integrations/cli/redteam/session.py +126 -43
- moonshot/integrations/web_api/app.py +1 -1
- moonshot/integrations/web_api/routes/bookmark.py +7 -4
- moonshot/src/api/api_bookmark.py +6 -6
- moonshot/src/bookmark/bookmark.py +119 -60
- moonshot/src/bookmark/bookmark_arguments.py +10 -0
- moonshot/src/configs/env_variables.py +7 -3
- moonshot/src/messages_constants.py +40 -0
- moonshot/src/runners/runner.py +1 -1
- moonshot/src/runs/run.py +7 -0
- {aiverify_moonshot-0.4.4.dist-info → aiverify_moonshot-0.4.6.dist-info}/WHEEL +0 -0
- {aiverify_moonshot-0.4.4.dist-info → aiverify_moonshot-0.4.6.dist-info}/licenses/AUTHORS.md +0 -0
- {aiverify_moonshot-0.4.4.dist-info → aiverify_moonshot-0.4.6.dist-info}/licenses/LICENSE.md +0 -0
- {aiverify_moonshot-0.4.4.dist-info → aiverify_moonshot-0.4.6.dist-info}/licenses/NOTICES.md +0 -0
|
@@ -13,6 +13,26 @@ from moonshot.api import (
|
|
|
13
13
|
api_read_endpoint,
|
|
14
14
|
api_update_endpoint,
|
|
15
15
|
)
|
|
16
|
+
from moonshot.integrations.cli.cli_errors import (
|
|
17
|
+
ERROR_COMMON_ADD_ENDPOINT_CONNECTOR_TYPE_VALIDATION,
|
|
18
|
+
ERROR_COMMON_ADD_ENDPOINT_MAX_CALLS_PER_SECOND_VALIDATION,
|
|
19
|
+
ERROR_COMMON_ADD_ENDPOINT_MAX_CONCURRENCY_VALIDATION,
|
|
20
|
+
ERROR_COMMON_ADD_ENDPOINT_NAME_VALIDATION,
|
|
21
|
+
ERROR_COMMON_ADD_ENDPOINT_PARAMS_VALIDATION,
|
|
22
|
+
ERROR_COMMON_ADD_ENDPOINT_TOKEN_VALIDATION,
|
|
23
|
+
ERROR_COMMON_ADD_ENDPOINT_URI_VALIDATION,
|
|
24
|
+
ERROR_COMMON_DELETE_ENDPOINT_ENDPOINT_VALIDATION,
|
|
25
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_FIND_VALIDATION,
|
|
26
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION,
|
|
27
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1,
|
|
28
|
+
ERROR_COMMON_LIST_ENDPOINTS_FIND_VALIDATION,
|
|
29
|
+
ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION,
|
|
30
|
+
ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION_1,
|
|
31
|
+
ERROR_COMMON_UPDATE_ENDPOINT_ENDPOINT_VALIDATION,
|
|
32
|
+
ERROR_COMMON_UPDATE_ENDPOINT_VALUES_VALIDATION,
|
|
33
|
+
ERROR_COMMON_UPDATE_ENDPOINT_VALUES_VALIDATION_1,
|
|
34
|
+
ERROR_COMMON_VIEW_ENDPOINT_ENDPOINT_VALIDATION,
|
|
35
|
+
)
|
|
16
36
|
from moonshot.integrations.cli.utils.process_data import filter_data
|
|
17
37
|
|
|
18
38
|
console = Console()
|
|
@@ -42,7 +62,44 @@ def add_endpoint(args) -> None:
|
|
|
42
62
|
None
|
|
43
63
|
"""
|
|
44
64
|
try:
|
|
45
|
-
|
|
65
|
+
if not isinstance(args.name, str) or not args.name or args.name is None:
|
|
66
|
+
raise TypeError(ERROR_COMMON_ADD_ENDPOINT_NAME_VALIDATION)
|
|
67
|
+
|
|
68
|
+
if (
|
|
69
|
+
not isinstance(args.connector_type, str)
|
|
70
|
+
or not args.connector_type
|
|
71
|
+
or args.connector_type is None
|
|
72
|
+
):
|
|
73
|
+
raise TypeError(ERROR_COMMON_ADD_ENDPOINT_CONNECTOR_TYPE_VALIDATION)
|
|
74
|
+
|
|
75
|
+
if not isinstance(args.uri, str) or not args.uri or args.uri is None:
|
|
76
|
+
raise TypeError(ERROR_COMMON_ADD_ENDPOINT_URI_VALIDATION)
|
|
77
|
+
|
|
78
|
+
if not isinstance(args.token, str) or not args.token or args.token is None:
|
|
79
|
+
raise TypeError(ERROR_COMMON_ADD_ENDPOINT_TOKEN_VALIDATION)
|
|
80
|
+
|
|
81
|
+
if (
|
|
82
|
+
not isinstance(args.max_calls_per_second, int)
|
|
83
|
+
or not args.max_calls_per_second
|
|
84
|
+
or args.max_calls_per_second is None
|
|
85
|
+
or args.max_calls_per_second < 0
|
|
86
|
+
):
|
|
87
|
+
raise TypeError(ERROR_COMMON_ADD_ENDPOINT_MAX_CALLS_PER_SECOND_VALIDATION)
|
|
88
|
+
|
|
89
|
+
if (
|
|
90
|
+
not isinstance(args.max_concurrency, int)
|
|
91
|
+
or not args.max_concurrency
|
|
92
|
+
or args.max_concurrency is None
|
|
93
|
+
or args.max_calls_per_second < 0
|
|
94
|
+
):
|
|
95
|
+
raise TypeError(ERROR_COMMON_ADD_ENDPOINT_MAX_CONCURRENCY_VALIDATION)
|
|
96
|
+
|
|
97
|
+
try:
|
|
98
|
+
params_dict = literal_eval(args.params)
|
|
99
|
+
except Exception:
|
|
100
|
+
raise SyntaxError(ERROR_COMMON_ADD_ENDPOINT_PARAMS_VALIDATION)
|
|
101
|
+
if not isinstance(params_dict, dict):
|
|
102
|
+
raise ValueError(ERROR_COMMON_ADD_ENDPOINT_PARAMS_VALIDATION)
|
|
46
103
|
|
|
47
104
|
new_endpoint_id = api_create_endpoint(
|
|
48
105
|
args.name,
|
|
@@ -74,9 +131,30 @@ def list_endpoints(args) -> list | None:
|
|
|
74
131
|
list | None: A list of ConnectorEndpoint or None if there is no result.
|
|
75
132
|
"""
|
|
76
133
|
try:
|
|
134
|
+
if args.find is not None:
|
|
135
|
+
if not isinstance(args.find, str) or not args.find:
|
|
136
|
+
raise TypeError(ERROR_COMMON_LIST_ENDPOINTS_FIND_VALIDATION)
|
|
137
|
+
|
|
138
|
+
if args.pagination is not None:
|
|
139
|
+
if not isinstance(args.pagination, str) or not args.pagination:
|
|
140
|
+
raise TypeError(ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION)
|
|
141
|
+
try:
|
|
142
|
+
pagination = literal_eval(args.pagination)
|
|
143
|
+
if not (
|
|
144
|
+
isinstance(pagination, tuple)
|
|
145
|
+
and len(pagination) == 2
|
|
146
|
+
and all(isinstance(i, int) for i in pagination)
|
|
147
|
+
):
|
|
148
|
+
raise ValueError(
|
|
149
|
+
ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION_1
|
|
150
|
+
)
|
|
151
|
+
except (ValueError, SyntaxError):
|
|
152
|
+
raise ValueError(ERROR_COMMON_LIST_ENDPOINTS_PAGINATION_VALIDATION_1)
|
|
153
|
+
else:
|
|
154
|
+
pagination = ()
|
|
155
|
+
|
|
77
156
|
endpoints_list = api_get_all_endpoint()
|
|
78
157
|
keyword = args.find.lower() if args.find else ""
|
|
79
|
-
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
80
158
|
|
|
81
159
|
if endpoints_list:
|
|
82
160
|
filtered_endpoints_list = filter_data(endpoints_list, keyword, pagination)
|
|
@@ -86,7 +164,6 @@ def list_endpoints(args) -> list | None:
|
|
|
86
164
|
|
|
87
165
|
console.print("[red]There are no endpoints found.[/red]")
|
|
88
166
|
return None
|
|
89
|
-
|
|
90
167
|
except Exception as e:
|
|
91
168
|
print(f"[list_endpoints]: {str(e)}")
|
|
92
169
|
|
|
@@ -107,9 +184,32 @@ def list_connector_types(args) -> list | None:
|
|
|
107
184
|
list | None: A list of Connector or None if there is no result.
|
|
108
185
|
"""
|
|
109
186
|
try:
|
|
187
|
+
if args.find is not None:
|
|
188
|
+
if not isinstance(args.find, str) or not args.find:
|
|
189
|
+
raise TypeError(ERROR_COMMON_LIST_CONNECTOR_TYPES_FIND_VALIDATION)
|
|
190
|
+
|
|
191
|
+
if args.pagination is not None:
|
|
192
|
+
if not isinstance(args.pagination, str) or not args.pagination:
|
|
193
|
+
raise TypeError(ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION)
|
|
194
|
+
try:
|
|
195
|
+
pagination = literal_eval(args.pagination)
|
|
196
|
+
if not (
|
|
197
|
+
isinstance(pagination, tuple)
|
|
198
|
+
and len(pagination) == 2
|
|
199
|
+
and all(isinstance(i, int) for i in pagination)
|
|
200
|
+
):
|
|
201
|
+
raise ValueError(
|
|
202
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1
|
|
203
|
+
)
|
|
204
|
+
except (ValueError, SyntaxError):
|
|
205
|
+
raise ValueError(
|
|
206
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1
|
|
207
|
+
)
|
|
208
|
+
else:
|
|
209
|
+
pagination = ()
|
|
210
|
+
|
|
110
211
|
connector_type_list = api_get_all_connector_type()
|
|
111
212
|
keyword = args.find.lower() if args.find else ""
|
|
112
|
-
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
113
213
|
|
|
114
214
|
if connector_type_list:
|
|
115
215
|
filtered_connector_type_list = filter_data(
|
|
@@ -121,7 +221,6 @@ def list_connector_types(args) -> list | None:
|
|
|
121
221
|
|
|
122
222
|
console.print("[red]There are no connector types found.[/red]")
|
|
123
223
|
return None
|
|
124
|
-
|
|
125
224
|
except Exception as e:
|
|
126
225
|
print(f"[list_connector_types]: {str(e)}")
|
|
127
226
|
|
|
@@ -142,6 +241,13 @@ def view_endpoint(args) -> None:
|
|
|
142
241
|
None
|
|
143
242
|
"""
|
|
144
243
|
try:
|
|
244
|
+
if (
|
|
245
|
+
not isinstance(args.endpoint, str)
|
|
246
|
+
or not args.endpoint
|
|
247
|
+
or args.endpoint is None
|
|
248
|
+
):
|
|
249
|
+
raise TypeError(ERROR_COMMON_VIEW_ENDPOINT_ENDPOINT_VALIDATION)
|
|
250
|
+
|
|
145
251
|
endpoint_info = api_read_endpoint(args.endpoint)
|
|
146
252
|
_display_endpoints([endpoint_info])
|
|
147
253
|
except Exception as e:
|
|
@@ -165,8 +271,27 @@ def update_endpoint(args) -> None:
|
|
|
165
271
|
None
|
|
166
272
|
"""
|
|
167
273
|
try:
|
|
274
|
+
if (
|
|
275
|
+
args.endpoint is None
|
|
276
|
+
or not isinstance(args.endpoint, str)
|
|
277
|
+
or not args.endpoint
|
|
278
|
+
):
|
|
279
|
+
raise ValueError(ERROR_COMMON_UPDATE_ENDPOINT_ENDPOINT_VALIDATION)
|
|
168
280
|
endpoint = args.endpoint
|
|
169
|
-
|
|
281
|
+
|
|
282
|
+
if (
|
|
283
|
+
args.update_values is None
|
|
284
|
+
or not isinstance(args.update_values, str)
|
|
285
|
+
or not args.update_values
|
|
286
|
+
):
|
|
287
|
+
raise ValueError(ERROR_COMMON_UPDATE_ENDPOINT_VALUES_VALIDATION)
|
|
288
|
+
|
|
289
|
+
if literal_eval(args.update_values) and all(
|
|
290
|
+
isinstance(i, tuple) for i in literal_eval(args.update_values)
|
|
291
|
+
):
|
|
292
|
+
update_values = dict(literal_eval(args.update_values))
|
|
293
|
+
else:
|
|
294
|
+
raise ValueError(ERROR_COMMON_UPDATE_ENDPOINT_VALUES_VALIDATION_1)
|
|
170
295
|
api_update_endpoint(endpoint, **update_values)
|
|
171
296
|
print("[update_endpoint]: Endpoint updated.")
|
|
172
297
|
except Exception as e:
|
|
@@ -196,6 +321,12 @@ def delete_endpoint(args) -> None:
|
|
|
196
321
|
console.print("[bold yellow]Endpoint deletion cancelled.[/]")
|
|
197
322
|
return
|
|
198
323
|
try:
|
|
324
|
+
if (
|
|
325
|
+
args.endpoint is None
|
|
326
|
+
or not isinstance(args.endpoint, str)
|
|
327
|
+
or not args.endpoint
|
|
328
|
+
):
|
|
329
|
+
raise ValueError(ERROR_COMMON_DELETE_ENDPOINT_ENDPOINT_VALIDATION)
|
|
199
330
|
api_delete_endpoint(args.endpoint)
|
|
200
331
|
print("[delete_endpoint]: Endpoint deleted.")
|
|
201
332
|
except Exception as e:
|
|
@@ -3,11 +3,19 @@ from ast import literal_eval
|
|
|
3
3
|
import cmd2
|
|
4
4
|
from rich.console import Console
|
|
5
5
|
|
|
6
|
-
from moonshot.api import
|
|
7
|
-
|
|
6
|
+
from moonshot.api import api_create_datasets
|
|
7
|
+
from moonshot.integrations.cli.cli_errors import (
|
|
8
|
+
ERROR_COMMON_ADD_DATASET_DESC_VALIDATION,
|
|
9
|
+
ERROR_COMMON_ADD_DATASET_LICENSE_VALIDATION,
|
|
10
|
+
ERROR_COMMON_ADD_DATASET_METHOD_VALIDATION,
|
|
11
|
+
ERROR_COMMON_ADD_DATASET_NAME_VALIDATION,
|
|
12
|
+
ERROR_COMMON_ADD_DATASET_PARAMS_VALIDATION,
|
|
13
|
+
ERROR_COMMON_ADD_DATASET_REFERENCE_VALIDATION,
|
|
8
14
|
)
|
|
9
15
|
|
|
10
16
|
console = Console()
|
|
17
|
+
|
|
18
|
+
|
|
11
19
|
def add_dataset(args) -> None:
|
|
12
20
|
"""
|
|
13
21
|
Create a new dataset using the provided arguments and log the result.
|
|
@@ -26,6 +34,41 @@ def add_dataset(args) -> None:
|
|
|
26
34
|
- params (dict): Additional parameters for dataset creation.
|
|
27
35
|
"""
|
|
28
36
|
try:
|
|
37
|
+
if not isinstance(args.name, str) or not args.name or args.name is None:
|
|
38
|
+
raise TypeError(ERROR_COMMON_ADD_DATASET_NAME_VALIDATION)
|
|
39
|
+
|
|
40
|
+
if (
|
|
41
|
+
not isinstance(args.description, str)
|
|
42
|
+
or not args.description
|
|
43
|
+
or args.description is None
|
|
44
|
+
):
|
|
45
|
+
raise TypeError(ERROR_COMMON_ADD_DATASET_DESC_VALIDATION)
|
|
46
|
+
|
|
47
|
+
if (
|
|
48
|
+
not isinstance(args.reference, str)
|
|
49
|
+
or not args.reference
|
|
50
|
+
or args.reference is None
|
|
51
|
+
):
|
|
52
|
+
raise TypeError(ERROR_COMMON_ADD_DATASET_REFERENCE_VALIDATION)
|
|
53
|
+
|
|
54
|
+
if (
|
|
55
|
+
not isinstance(args.license, str)
|
|
56
|
+
or not args.license
|
|
57
|
+
or args.license is None
|
|
58
|
+
):
|
|
59
|
+
raise TypeError(ERROR_COMMON_ADD_DATASET_LICENSE_VALIDATION)
|
|
60
|
+
|
|
61
|
+
if (
|
|
62
|
+
not isinstance(args.method, str)
|
|
63
|
+
or not args.method
|
|
64
|
+
or args.method is None
|
|
65
|
+
or args.method.lower() not in ["hf", "csv"]
|
|
66
|
+
):
|
|
67
|
+
raise TypeError(ERROR_COMMON_ADD_DATASET_METHOD_VALIDATION)
|
|
68
|
+
|
|
69
|
+
if not isinstance(args.params, dict) or not args.params or args.params is None:
|
|
70
|
+
raise TypeError(ERROR_COMMON_ADD_DATASET_PARAMS_VALIDATION)
|
|
71
|
+
|
|
29
72
|
new_dataset_id = api_create_datasets(
|
|
30
73
|
args.name,
|
|
31
74
|
args.description,
|
|
@@ -38,7 +81,8 @@ def add_dataset(args) -> None:
|
|
|
38
81
|
except Exception as e:
|
|
39
82
|
print(f"[add_dataset]: {str(e)}")
|
|
40
83
|
|
|
41
|
-
|
|
84
|
+
|
|
85
|
+
# ------------------------------------------------------------------------------
|
|
42
86
|
# Cmd2 Arguments Parsers
|
|
43
87
|
# ------------------------------------------------------------------------------
|
|
44
88
|
# Add dataset arguments
|
|
@@ -46,21 +90,30 @@ add_dataset_args = cmd2.Cmd2ArgumentParser(
|
|
|
46
90
|
description="Add a new dataset. The 'name' argument will be slugified to create a unique identifier.",
|
|
47
91
|
epilog=(
|
|
48
92
|
"Examples:\n"
|
|
49
|
-
"1. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'csv' \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
|
|
50
|
-
"2. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'hf' \"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['question','choices'], 'target_col': 'answer'}\""
|
|
93
|
+
"1. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'csv' \"{'csv_file_path': '/path/to/your/file.csv'}\"\n" # noqa: E501
|
|
94
|
+
"2. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'hf' \"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['question','choices'], 'target_col': 'answer'}\"" # noqa: E501
|
|
51
95
|
),
|
|
52
96
|
)
|
|
53
97
|
add_dataset_args.add_argument("name", type=str, help="Name of the new dataset")
|
|
54
|
-
add_dataset_args.add_argument(
|
|
55
|
-
|
|
98
|
+
add_dataset_args.add_argument(
|
|
99
|
+
"description", type=str, help="Description of the new dataset"
|
|
100
|
+
)
|
|
101
|
+
add_dataset_args.add_argument(
|
|
102
|
+
"reference", type=str, help="Reference of the new dataset"
|
|
103
|
+
)
|
|
56
104
|
add_dataset_args.add_argument("license", type=str, help="License of the new dataset")
|
|
57
|
-
add_dataset_args.add_argument("method", type=str, choices=['hf', 'csv'], help="Method to convert the new dataset. Choose either 'hf' or 'csv'.")
|
|
58
105
|
add_dataset_args.add_argument(
|
|
59
|
-
"
|
|
60
|
-
type=
|
|
106
|
+
"method",
|
|
107
|
+
type=str,
|
|
108
|
+
choices=["hf", "csv"],
|
|
109
|
+
help="Method to convert the new dataset. Choose either 'hf' or 'csv'.",
|
|
110
|
+
)
|
|
111
|
+
add_dataset_args.add_argument(
|
|
112
|
+
"params",
|
|
113
|
+
type=literal_eval,
|
|
61
114
|
help=(
|
|
62
115
|
"Params of the new dataset in dictionary format. For example: \n"
|
|
63
116
|
"1. For 'csv' method: \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
|
|
64
|
-
"2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['questions','choices'], 'target_col': 'answer'}\""
|
|
65
|
-
)
|
|
66
|
-
)
|
|
117
|
+
"2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['questions','choices'], 'target_col': 'answer'}\"" # noqa: E501
|
|
118
|
+
),
|
|
119
|
+
)
|
|
@@ -5,6 +5,12 @@ from rich.console import Console
|
|
|
5
5
|
from rich.table import Table
|
|
6
6
|
|
|
7
7
|
from moonshot.api import api_delete_prompt_template, api_get_all_prompt_template_detail
|
|
8
|
+
from moonshot.integrations.cli.cli_errors import (
|
|
9
|
+
ERROR_COMMON_DELETE_PROMPT_TEMPLATE_PROMPT_TEMPLATE_VALIDATION,
|
|
10
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION,
|
|
11
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1,
|
|
12
|
+
ERROR_COMMON_LIST_PROMPT_TEMPLATES_FIND_VALIDATION,
|
|
13
|
+
)
|
|
8
14
|
from moonshot.integrations.cli.utils.process_data import filter_data
|
|
9
15
|
|
|
10
16
|
console = Console()
|
|
@@ -26,9 +32,32 @@ def list_prompt_templates(args) -> list | None:
|
|
|
26
32
|
list | None: A list of PromptTemplate or None if there is no result.
|
|
27
33
|
"""
|
|
28
34
|
try:
|
|
35
|
+
if args.find is not None:
|
|
36
|
+
if not isinstance(args.find, str) or not args.find:
|
|
37
|
+
raise TypeError(ERROR_COMMON_LIST_PROMPT_TEMPLATES_FIND_VALIDATION)
|
|
38
|
+
|
|
39
|
+
if args.pagination is not None:
|
|
40
|
+
if not isinstance(args.pagination, str) or not args.pagination:
|
|
41
|
+
raise TypeError(ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION)
|
|
42
|
+
try:
|
|
43
|
+
pagination = literal_eval(args.pagination)
|
|
44
|
+
if not (
|
|
45
|
+
isinstance(pagination, tuple)
|
|
46
|
+
and len(pagination) == 2
|
|
47
|
+
and all(isinstance(i, int) for i in pagination)
|
|
48
|
+
):
|
|
49
|
+
raise ValueError(
|
|
50
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1
|
|
51
|
+
)
|
|
52
|
+
except (ValueError, SyntaxError):
|
|
53
|
+
raise ValueError(
|
|
54
|
+
ERROR_COMMON_LIST_CONNECTOR_TYPES_PAGINATION_VALIDATION_1
|
|
55
|
+
)
|
|
56
|
+
else:
|
|
57
|
+
pagination = ()
|
|
58
|
+
|
|
29
59
|
prompt_templates_list = api_get_all_prompt_template_detail()
|
|
30
60
|
keyword = args.find.lower() if args.find else ""
|
|
31
|
-
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
32
61
|
|
|
33
62
|
if prompt_templates_list:
|
|
34
63
|
filtered_prompt_templates_list = filter_data(
|
|
@@ -40,7 +69,6 @@ def list_prompt_templates(args) -> list | None:
|
|
|
40
69
|
|
|
41
70
|
console.print("[red]There are no prompt templates found.[/red]")
|
|
42
71
|
return None
|
|
43
|
-
|
|
44
72
|
except Exception as e:
|
|
45
73
|
print(f"[list_prompt_templates]: {str(e)}")
|
|
46
74
|
|
|
@@ -61,6 +89,14 @@ def delete_prompt_template(args) -> None:
|
|
|
61
89
|
console.print("[bold yellow]Prompt template deletion cancelled.[/]")
|
|
62
90
|
return
|
|
63
91
|
try:
|
|
92
|
+
if (
|
|
93
|
+
args.prompt_template is None
|
|
94
|
+
or not isinstance(args.prompt_template, str)
|
|
95
|
+
or not args.prompt_template
|
|
96
|
+
):
|
|
97
|
+
raise ValueError(
|
|
98
|
+
ERROR_COMMON_DELETE_PROMPT_TEMPLATE_PROMPT_TEMPLATE_VALIDATION
|
|
99
|
+
)
|
|
64
100
|
api_delete_prompt_template(args.prompt_template)
|
|
65
101
|
print("[delete_prompt_template]: Prompt template deleted.")
|
|
66
102
|
except Exception as e:
|
|
@@ -22,6 +22,23 @@ from moonshot.api import (
|
|
|
22
22
|
api_load_session,
|
|
23
23
|
)
|
|
24
24
|
from moonshot.integrations.cli.active_session_cfg import active_session
|
|
25
|
+
from moonshot.integrations.cli.cli_errors import (
|
|
26
|
+
ERROR_RED_TEAMING_ADD_BOOKMARK_ENDPOINT_VALIDATION,
|
|
27
|
+
ERROR_RED_TEAMING_ADD_BOOKMARK_ENDPOINT_VALIDATION_1,
|
|
28
|
+
ERROR_RED_TEAMING_ADD_BOOKMARK_NO_ACTIVE_SESSION,
|
|
29
|
+
ERROR_RED_TEAMING_LIST_SESSIONS_FIND_VALIDATION,
|
|
30
|
+
ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION,
|
|
31
|
+
ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION_1,
|
|
32
|
+
ERROR_RED_TEAMING_NEW_SESSION_ENDPOINTS_VALIDATION,
|
|
33
|
+
ERROR_RED_TEAMING_NEW_SESSION_FAILED_TO_USE_SESSION,
|
|
34
|
+
ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION,
|
|
35
|
+
ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION_1,
|
|
36
|
+
ERROR_RED_TEAMING_SHOW_PROMPTS_NO_ACTIVE_SESSION_VALIDATION,
|
|
37
|
+
ERROR_RED_TEAMING_USE_BOOKMARK_NO_ACTIVE_SESSION,
|
|
38
|
+
ERROR_RED_TEAMING_USE_SESSION_NO_SESSION_METADATA_VALIDATION,
|
|
39
|
+
ERROR_RED_TEAMING_USE_SESSION_RUNNER_ID_TYPE_VALIDATION,
|
|
40
|
+
ERROR_RED_TEAMING_USE_SESSION_RUNNER_ID_VALIDATION,
|
|
41
|
+
)
|
|
25
42
|
from moonshot.integrations.cli.utils.process_data import filter_data
|
|
26
43
|
from moonshot.src.redteaming.session.session import Session
|
|
27
44
|
|
|
@@ -43,38 +60,73 @@ def new_session(args) -> None:
|
|
|
43
60
|
- endpoints (str, optional): The list of endpoints for the runner."""
|
|
44
61
|
global active_session
|
|
45
62
|
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
63
|
+
try:
|
|
64
|
+
required_parameters = [("runner_id", str)]
|
|
65
|
+
optional_parameters = [("context_strategy", str), ("prompt_template", str)]
|
|
66
|
+
# Check if required parameters exist in args
|
|
67
|
+
for param, param_type in required_parameters:
|
|
68
|
+
param_value = getattr(args, param, None)
|
|
69
|
+
if not param_value:
|
|
70
|
+
raise ValueError(
|
|
71
|
+
ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION.format(param=param)
|
|
72
|
+
)
|
|
73
|
+
if not isinstance(param_value, param_type):
|
|
74
|
+
raise TypeError(
|
|
75
|
+
ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION_1.format(
|
|
76
|
+
param=param, param_type=param_type.__name__
|
|
77
|
+
)
|
|
78
|
+
)
|
|
57
79
|
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
80
|
+
# Check the type of optional parameters if they exist
|
|
81
|
+
for param, param_type in optional_parameters:
|
|
82
|
+
param_value = getattr(args, param, None)
|
|
83
|
+
if param_value is not None and not isinstance(param_value, param_type):
|
|
84
|
+
raise TypeError(
|
|
85
|
+
ERROR_RED_TEAMING_NEW_SESSION_PARAMS_VALIDATION_1.format(
|
|
86
|
+
param=param, param_type=param_type.__name__
|
|
87
|
+
)
|
|
88
|
+
)
|
|
61
89
|
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
if
|
|
69
|
-
|
|
70
|
-
if
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
90
|
+
runner_id = args.runner_id
|
|
91
|
+
context_strategy = args.context_strategy if args.context_strategy else ""
|
|
92
|
+
prompt_template = args.prompt_template if args.prompt_template else ""
|
|
93
|
+
endpoints = []
|
|
94
|
+
|
|
95
|
+
# Check if literal eval param is correct type after eval
|
|
96
|
+
if hasattr(args, "endpoints") and args.endpoints:
|
|
97
|
+
endpoints = literal_eval(args.endpoints)
|
|
98
|
+
if not isinstance(endpoints, list):
|
|
99
|
+
raise TypeError(ERROR_RED_TEAMING_NEW_SESSION_ENDPOINTS_VALIDATION)
|
|
100
|
+
|
|
101
|
+
# create new runner and session
|
|
102
|
+
if endpoints:
|
|
103
|
+
runner = api_create_runner(runner_id, endpoints)
|
|
104
|
+
# load existing runner
|
|
76
105
|
else:
|
|
77
|
-
|
|
106
|
+
runner = api_load_runner(runner_id)
|
|
107
|
+
|
|
108
|
+
runner_args = {}
|
|
109
|
+
runner_args["context_strategy"] = context_strategy
|
|
110
|
+
runner_args["prompt_template"] = prompt_template
|
|
111
|
+
|
|
112
|
+
# create new session in runner
|
|
113
|
+
if runner.database_instance:
|
|
114
|
+
api_create_session(
|
|
115
|
+
runner.id, runner.database_instance, runner.endpoints, runner_args
|
|
116
|
+
)
|
|
117
|
+
session_metadata = api_load_session(runner.id)
|
|
118
|
+
if session_metadata:
|
|
119
|
+
active_session.update(session_metadata)
|
|
120
|
+
if active_session["context_strategy"]:
|
|
121
|
+
active_session[
|
|
122
|
+
"cs_num_of_prev_prompts"
|
|
123
|
+
] = Session.DEFAULT_CONTEXT_STRATEGY_PROMPT
|
|
124
|
+
print(f"[new_session] Using session: {active_session['session_id']}")
|
|
125
|
+
update_chat_display()
|
|
126
|
+
else:
|
|
127
|
+
raise RuntimeError(ERROR_RED_TEAMING_NEW_SESSION_FAILED_TO_USE_SESSION)
|
|
128
|
+
except Exception as e:
|
|
129
|
+
print(f"[new_session]: {str(e)}")
|
|
78
130
|
|
|
79
131
|
|
|
80
132
|
def use_session(args) -> None:
|
|
@@ -85,15 +137,19 @@ def use_session(args) -> None:
|
|
|
85
137
|
args (Namespace): The arguments passed to the function.
|
|
86
138
|
"""
|
|
87
139
|
global active_session
|
|
88
|
-
runner_id = args.runner_id
|
|
89
140
|
|
|
90
141
|
# Load session metadata
|
|
91
142
|
try:
|
|
143
|
+
if not args.runner_id or args.runner_id is None:
|
|
144
|
+
raise ValueError(ERROR_RED_TEAMING_USE_SESSION_RUNNER_ID_VALIDATION)
|
|
145
|
+
|
|
146
|
+
if not isinstance(args.runner_id, str):
|
|
147
|
+
raise TypeError(ERROR_RED_TEAMING_USE_SESSION_RUNNER_ID_TYPE_VALIDATION)
|
|
148
|
+
|
|
149
|
+
runner_id = args.runner_id
|
|
92
150
|
session_metadata = api_load_session(runner_id)
|
|
93
151
|
if not session_metadata:
|
|
94
|
-
print(
|
|
95
|
-
"[Session] Cannot find a session with the existing Runner ID. Please try again."
|
|
96
|
-
)
|
|
152
|
+
print(ERROR_RED_TEAMING_USE_SESSION_NO_SESSION_METADATA_VALIDATION)
|
|
97
153
|
return
|
|
98
154
|
|
|
99
155
|
# Set the current session
|
|
@@ -115,7 +171,7 @@ def show_prompts() -> None:
|
|
|
115
171
|
global active_session
|
|
116
172
|
|
|
117
173
|
if not active_session:
|
|
118
|
-
print(
|
|
174
|
+
print(ERROR_RED_TEAMING_SHOW_PROMPTS_NO_ACTIVE_SESSION_VALIDATION)
|
|
119
175
|
return
|
|
120
176
|
|
|
121
177
|
update_chat_display()
|
|
@@ -146,6 +202,28 @@ def list_sessions(args) -> list | None:
|
|
|
146
202
|
"""
|
|
147
203
|
try:
|
|
148
204
|
session_metadata_list = api_get_all_session_metadata()
|
|
205
|
+
if args.find is not None:
|
|
206
|
+
if not isinstance(args.find, str) or not args.find:
|
|
207
|
+
raise TypeError(ERROR_RED_TEAMING_LIST_SESSIONS_FIND_VALIDATION)
|
|
208
|
+
|
|
209
|
+
if args.pagination is not None:
|
|
210
|
+
if not isinstance(args.pagination, str) or not args.pagination:
|
|
211
|
+
raise TypeError(ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION)
|
|
212
|
+
try:
|
|
213
|
+
pagination = literal_eval(args.pagination)
|
|
214
|
+
if not (
|
|
215
|
+
isinstance(pagination, tuple)
|
|
216
|
+
and len(pagination) == 2
|
|
217
|
+
and all(isinstance(i, int) for i in pagination)
|
|
218
|
+
):
|
|
219
|
+
raise ValueError(
|
|
220
|
+
ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION_1
|
|
221
|
+
)
|
|
222
|
+
except (ValueError, SyntaxError):
|
|
223
|
+
raise ValueError(
|
|
224
|
+
ERROR_RED_TEAMING_LIST_SESSIONS_PAGINATION_VALIDATION_1
|
|
225
|
+
)
|
|
226
|
+
|
|
149
227
|
keyword = args.find.lower() if args.find else ""
|
|
150
228
|
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
151
229
|
|
|
@@ -159,7 +237,6 @@ def list_sessions(args) -> list | None:
|
|
|
159
237
|
|
|
160
238
|
console.print("[red]There are no sessions found.[/red]")
|
|
161
239
|
return None
|
|
162
|
-
|
|
163
240
|
except Exception as e:
|
|
164
241
|
print(f"[list_sessions]: {str(e)}")
|
|
165
242
|
|
|
@@ -214,7 +291,6 @@ def update_chat_display() -> None:
|
|
|
214
291
|
title_align="left",
|
|
215
292
|
)
|
|
216
293
|
console.print(panel)
|
|
217
|
-
|
|
218
294
|
else:
|
|
219
295
|
console.print("[red]There are no active session.[/red]")
|
|
220
296
|
|
|
@@ -249,9 +325,7 @@ def add_bookmark(args) -> None:
|
|
|
249
325
|
target_endpoint_chats = list_of_target_endpoint_chat.get(endpoint, None)
|
|
250
326
|
target_endpoint_chat_record = {}
|
|
251
327
|
if not target_endpoint_chats:
|
|
252
|
-
print(
|
|
253
|
-
"Incorrect endpoint. Please select a valid endpoint in this session."
|
|
254
|
-
)
|
|
328
|
+
print(ERROR_RED_TEAMING_ADD_BOOKMARK_ENDPOINT_VALIDATION)
|
|
255
329
|
return
|
|
256
330
|
for endpoint_chat in target_endpoint_chats:
|
|
257
331
|
if endpoint_chat["chat_record_id"] == prompt_id:
|
|
@@ -273,12 +347,14 @@ def add_bookmark(args) -> None:
|
|
|
273
347
|
print("[bookmark_prompt]:", bookmark_message["message"])
|
|
274
348
|
else:
|
|
275
349
|
print(
|
|
276
|
-
|
|
350
|
+
ERROR_RED_TEAMING_ADD_BOOKMARK_ENDPOINT_VALIDATION_1.format(
|
|
351
|
+
endpoint=endpoint
|
|
352
|
+
)
|
|
277
353
|
)
|
|
278
354
|
except Exception as e:
|
|
279
355
|
print(f"[bookmark_prompt]: ({str(e)})")
|
|
280
356
|
else:
|
|
281
|
-
print(
|
|
357
|
+
print(ERROR_RED_TEAMING_ADD_BOOKMARK_NO_ACTIVE_SESSION)
|
|
282
358
|
return
|
|
283
359
|
|
|
284
360
|
|
|
@@ -323,7 +399,7 @@ def use_bookmark(args) -> None:
|
|
|
323
399
|
except Exception as e:
|
|
324
400
|
print(f"[use_bookmark]: {str(e)}")
|
|
325
401
|
else:
|
|
326
|
-
print(
|
|
402
|
+
print(ERROR_RED_TEAMING_USE_BOOKMARK_NO_ACTIVE_SESSION)
|
|
327
403
|
return
|
|
328
404
|
|
|
329
405
|
|
|
@@ -693,6 +769,7 @@ def delete_session(args) -> None:
|
|
|
693
769
|
which is the ID of the session to delete.
|
|
694
770
|
"""
|
|
695
771
|
# Confirm with the user before deleting a session
|
|
772
|
+
|
|
696
773
|
confirmation = console.input(
|
|
697
774
|
"[bold red]Are you sure you want to delete the session (y/N)? [/]"
|
|
698
775
|
)
|
|
@@ -700,6 +777,12 @@ def delete_session(args) -> None:
|
|
|
700
777
|
console.print("[bold yellow]Session deletion cancelled.[/]")
|
|
701
778
|
return
|
|
702
779
|
try:
|
|
780
|
+
if not args.session or args.session is None:
|
|
781
|
+
raise ValueError("Invalid or missing required parameter: session")
|
|
782
|
+
|
|
783
|
+
if not isinstance(args.session, str):
|
|
784
|
+
raise TypeError("Invalid type for parameter: session. Expecting type str.")
|
|
785
|
+
|
|
703
786
|
api_delete_session(args.session)
|
|
704
787
|
print("[delete_session]: Session deleted.")
|
|
705
788
|
except Exception as e:
|
|
@@ -844,7 +927,7 @@ automated_rt_session_args.add_argument(
|
|
|
844
927
|
type=str,
|
|
845
928
|
help=(
|
|
846
929
|
"The number of previous prompts to use with the context strategy. If this is set, it will overwrite the"
|
|
847
|
-
" number of previous
|
|
930
|
+
" number of previous prompts set in the session while running this attack module."
|
|
848
931
|
),
|
|
849
932
|
nargs="?",
|
|
850
933
|
)
|
|
@@ -71,7 +71,7 @@ def create_app(cfg: providers.Configuration) -> CustomFastAPI:
|
|
|
71
71
|
}
|
|
72
72
|
|
|
73
73
|
app: CustomFastAPI = CustomFastAPI(
|
|
74
|
-
title="Project Moonshot", version="0.4.
|
|
74
|
+
title="Project Moonshot", version="0.4.6", **app_kwargs
|
|
75
75
|
)
|
|
76
76
|
|
|
77
77
|
if cfg.cors.enabled():
|