aiverify-moonshot 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/METADATA +2 -2
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/RECORD +70 -56
- moonshot/__main__.py +77 -35
- moonshot/api.py +16 -0
- moonshot/integrations/cli/benchmark/benchmark.py +29 -13
- moonshot/integrations/cli/benchmark/cookbook.py +62 -24
- moonshot/integrations/cli/benchmark/datasets.py +79 -40
- moonshot/integrations/cli/benchmark/metrics.py +62 -23
- moonshot/integrations/cli/benchmark/recipe.py +89 -69
- moonshot/integrations/cli/benchmark/result.py +85 -47
- moonshot/integrations/cli/benchmark/run.py +99 -59
- moonshot/integrations/cli/common/common.py +20 -6
- moonshot/integrations/cli/common/connectors.py +154 -74
- moonshot/integrations/cli/common/dataset.py +66 -0
- moonshot/integrations/cli/common/prompt_template.py +57 -19
- moonshot/integrations/cli/redteam/attack_module.py +90 -24
- moonshot/integrations/cli/redteam/context_strategy.py +83 -23
- moonshot/integrations/cli/redteam/prompt_template.py +1 -1
- moonshot/integrations/cli/redteam/redteam.py +52 -6
- moonshot/integrations/cli/redteam/session.py +565 -44
- moonshot/integrations/cli/utils/process_data.py +52 -0
- moonshot/integrations/web_api/__main__.py +2 -0
- moonshot/integrations/web_api/app.py +6 -6
- moonshot/integrations/web_api/container.py +12 -2
- moonshot/integrations/web_api/routes/bookmark.py +173 -0
- moonshot/integrations/web_api/routes/dataset.py +46 -1
- moonshot/integrations/web_api/schemas/bookmark_create_dto.py +13 -0
- moonshot/integrations/web_api/schemas/dataset_create_dto.py +18 -0
- moonshot/integrations/web_api/schemas/recipe_create_dto.py +0 -2
- moonshot/integrations/web_api/services/bookmark_service.py +94 -0
- moonshot/integrations/web_api/services/dataset_service.py +25 -0
- moonshot/integrations/web_api/services/recipe_service.py +0 -1
- moonshot/integrations/web_api/services/utils/file_manager.py +52 -0
- moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +0 -1
- moonshot/integrations/web_api/temp/.gitkeep +0 -0
- moonshot/src/api/api_bookmark.py +95 -0
- moonshot/src/api/api_connector_endpoint.py +1 -1
- moonshot/src/api/api_context_strategy.py +2 -2
- moonshot/src/api/api_dataset.py +35 -0
- moonshot/src/api/api_recipe.py +0 -3
- moonshot/src/api/api_session.py +1 -1
- moonshot/src/bookmark/bookmark.py +257 -0
- moonshot/src/bookmark/bookmark_arguments.py +38 -0
- moonshot/src/configs/env_variables.py +12 -2
- moonshot/src/connectors/connector.py +15 -7
- moonshot/src/connectors_endpoints/connector_endpoint.py +65 -49
- moonshot/src/cookbooks/cookbook.py +57 -37
- moonshot/src/datasets/dataset.py +125 -5
- moonshot/src/metrics/metric.py +8 -4
- moonshot/src/metrics/metric_interface.py +8 -2
- moonshot/src/prompt_templates/prompt_template.py +5 -1
- moonshot/src/recipes/recipe.py +38 -40
- moonshot/src/recipes/recipe_arguments.py +0 -4
- moonshot/src/redteaming/attack/attack_module.py +18 -8
- moonshot/src/redteaming/attack/context_strategy.py +6 -2
- moonshot/src/redteaming/session/session.py +15 -11
- moonshot/src/results/result.py +7 -3
- moonshot/src/runners/runner.py +65 -42
- moonshot/src/runs/run.py +15 -11
- moonshot/src/runs/run_progress.py +7 -3
- moonshot/src/storage/db_interface.py +14 -0
- moonshot/src/storage/storage.py +33 -2
- moonshot/src/utils/find_feature.py +45 -0
- moonshot/src/utils/log.py +72 -0
- moonshot/src/utils/pagination.py +25 -0
- moonshot/src/utils/timeit.py +8 -1
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/WHEEL +0 -0
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/AUTHORS.md +0 -0
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/LICENSE.md +0 -0
- {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/NOTICES.md +0 -0
|
@@ -2,13 +2,19 @@ import argparse
|
|
|
2
2
|
|
|
3
3
|
import cmd2
|
|
4
4
|
|
|
5
|
+
from moonshot.integrations.cli.common.dataset import (
|
|
6
|
+
add_dataset,
|
|
7
|
+
add_dataset_args
|
|
8
|
+
)
|
|
5
9
|
from moonshot.integrations.cli.common.connectors import (
|
|
6
10
|
add_endpoint,
|
|
7
11
|
add_endpoint_args,
|
|
8
12
|
delete_endpoint,
|
|
9
13
|
delete_endpoint_args,
|
|
10
14
|
list_connector_types,
|
|
15
|
+
list_connector_types_args,
|
|
11
16
|
list_endpoints,
|
|
17
|
+
list_endpoints_args,
|
|
12
18
|
update_endpoint,
|
|
13
19
|
update_endpoint_args,
|
|
14
20
|
view_endpoint,
|
|
@@ -18,6 +24,7 @@ from moonshot.integrations.cli.common.prompt_template import (
|
|
|
18
24
|
delete_prompt_template,
|
|
19
25
|
delete_prompt_template_args,
|
|
20
26
|
list_prompt_templates,
|
|
27
|
+
list_prompt_templates_args,
|
|
21
28
|
)
|
|
22
29
|
|
|
23
30
|
|
|
@@ -30,14 +37,17 @@ class CommonCommandSet(cmd2.CommandSet):
|
|
|
30
37
|
# List contents
|
|
31
38
|
# ------------------------------------------------------------------------------
|
|
32
39
|
|
|
33
|
-
|
|
34
|
-
|
|
40
|
+
@cmd2.with_argparser(list_connector_types_args)
|
|
41
|
+
def do_list_connector_types(self, args: argparse.Namespace) -> None:
|
|
42
|
+
list_connector_types(args)
|
|
35
43
|
|
|
36
|
-
|
|
37
|
-
|
|
44
|
+
@cmd2.with_argparser(list_endpoints_args)
|
|
45
|
+
def do_list_endpoints(self, args: argparse.Namespace) -> None:
|
|
46
|
+
list_endpoints(args)
|
|
38
47
|
|
|
39
|
-
|
|
40
|
-
|
|
48
|
+
@cmd2.with_argparser(list_prompt_templates_args)
|
|
49
|
+
def do_list_prompt_templates(self, args: argparse.Namespace) -> None:
|
|
50
|
+
list_prompt_templates(args)
|
|
41
51
|
|
|
42
52
|
@cmd2.with_argparser(delete_prompt_template_args)
|
|
43
53
|
def do_delete_prompt_template(self, args: argparse.Namespace) -> None:
|
|
@@ -50,6 +60,10 @@ class CommonCommandSet(cmd2.CommandSet):
|
|
|
50
60
|
def do_add_endpoint(self, args: argparse.Namespace) -> None:
|
|
51
61
|
add_endpoint(args)
|
|
52
62
|
|
|
63
|
+
@cmd2.with_argparser(add_dataset_args)
|
|
64
|
+
def do_add_dataset(self, args:argparse.Namespace) -> None:
|
|
65
|
+
add_dataset(args)
|
|
66
|
+
|
|
53
67
|
# ------------------------------------------------------------------------------
|
|
54
68
|
# Delete contents
|
|
55
69
|
# ------------------------------------------------------------------------------
|
|
@@ -13,6 +13,7 @@ from moonshot.api import (
|
|
|
13
13
|
api_read_endpoint,
|
|
14
14
|
api_update_endpoint,
|
|
15
15
|
)
|
|
16
|
+
from moonshot.integrations.cli.utils.process_data import filter_data
|
|
16
17
|
|
|
17
18
|
console = Console()
|
|
18
19
|
|
|
@@ -57,36 +58,70 @@ def add_endpoint(args) -> None:
|
|
|
57
58
|
print(f"[add_endpoint]: {str(e)}")
|
|
58
59
|
|
|
59
60
|
|
|
60
|
-
def list_endpoints() -> None:
|
|
61
|
+
def list_endpoints(args) -> list | None:
|
|
61
62
|
"""
|
|
62
63
|
List all endpoints.
|
|
63
64
|
|
|
64
65
|
This function retrieves all endpoints by calling the api_get_all_endpoint function from the
|
|
65
|
-
moonshot.api module. It then displays the endpoints using the
|
|
66
|
+
moonshot.api module. It then displays the endpoints using the _display_endpoints function.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
args: A namespace object from argparse. It should have an optional attribute:
|
|
70
|
+
find (str): Optional field to find endpoint(s) with a keyword.
|
|
71
|
+
pagination (str): Optional field to paginate endpoints.
|
|
66
72
|
|
|
67
73
|
Returns:
|
|
68
|
-
None
|
|
74
|
+
list | None: A list of ConnectorEndpoint or None if there is no result.
|
|
69
75
|
"""
|
|
70
76
|
try:
|
|
71
|
-
|
|
72
|
-
|
|
77
|
+
endpoints_list = api_get_all_endpoint()
|
|
78
|
+
keyword = args.find.lower() if args.find else ""
|
|
79
|
+
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
80
|
+
|
|
81
|
+
if endpoints_list:
|
|
82
|
+
filtered_endpoints_list = filter_data(endpoints_list, keyword, pagination)
|
|
83
|
+
if filtered_endpoints_list:
|
|
84
|
+
_display_endpoints(filtered_endpoints_list)
|
|
85
|
+
return filtered_endpoints_list
|
|
86
|
+
|
|
87
|
+
console.print("[red]There are no endpoints found.[/red]")
|
|
88
|
+
return None
|
|
89
|
+
|
|
73
90
|
except Exception as e:
|
|
74
91
|
print(f"[list_endpoints]: {str(e)}")
|
|
75
92
|
|
|
76
93
|
|
|
77
|
-
def list_connector_types() -> None:
|
|
94
|
+
def list_connector_types(args) -> list | None:
|
|
78
95
|
"""
|
|
79
96
|
List all connector types.
|
|
80
97
|
|
|
81
98
|
This function retrieves all connector types by calling the api_get_all_connector_type function from the
|
|
82
|
-
moonshot.api module. It then displays the connector types using the
|
|
99
|
+
moonshot.api module. It then displays the connector types using the _display_connector_types function.
|
|
100
|
+
|
|
101
|
+
Args:
|
|
102
|
+
args: A namespace object from argparse. It should have an optional attribute:
|
|
103
|
+
find (str): Optional field to find connector type(s) with a keyword.
|
|
104
|
+
pagination (str): Optional field to paginate connector types.
|
|
83
105
|
|
|
84
106
|
Returns:
|
|
85
|
-
None
|
|
107
|
+
list | None: A list of Connector or None if there is no result.
|
|
86
108
|
"""
|
|
87
109
|
try:
|
|
88
110
|
connector_type_list = api_get_all_connector_type()
|
|
89
|
-
|
|
111
|
+
keyword = args.find.lower() if args.find else ""
|
|
112
|
+
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
113
|
+
|
|
114
|
+
if connector_type_list:
|
|
115
|
+
filtered_connector_type_list = filter_data(
|
|
116
|
+
connector_type_list, keyword, pagination
|
|
117
|
+
)
|
|
118
|
+
if filtered_connector_type_list:
|
|
119
|
+
_display_connector_types(filtered_connector_type_list)
|
|
120
|
+
return filtered_connector_type_list
|
|
121
|
+
|
|
122
|
+
console.print("[red]There are no connector types found.[/red]")
|
|
123
|
+
return None
|
|
124
|
+
|
|
90
125
|
except Exception as e:
|
|
91
126
|
print(f"[list_connector_types]: {str(e)}")
|
|
92
127
|
|
|
@@ -97,7 +132,7 @@ def view_endpoint(args) -> None:
|
|
|
97
132
|
|
|
98
133
|
This function retrieves a specific endpoint by calling the api_read_endpoint function from the
|
|
99
134
|
moonshot.api module using the endpoint name provided in the args. It then displays the endpoint
|
|
100
|
-
information using the
|
|
135
|
+
information using the _display_endpoints function.
|
|
101
136
|
|
|
102
137
|
Args:
|
|
103
138
|
args: A namespace object from argparse. It should have the following attribute:
|
|
@@ -108,7 +143,7 @@ def view_endpoint(args) -> None:
|
|
|
108
143
|
"""
|
|
109
144
|
try:
|
|
110
145
|
endpoint_info = api_read_endpoint(args.endpoint)
|
|
111
|
-
|
|
146
|
+
_display_endpoints([endpoint_info])
|
|
112
147
|
except Exception as e:
|
|
113
148
|
print(f"[view_endpoint]: {str(e)}")
|
|
114
149
|
|
|
@@ -170,7 +205,7 @@ def delete_endpoint(args) -> None:
|
|
|
170
205
|
# ------------------------------------------------------------------------------
|
|
171
206
|
# Helper functions: Display on cli
|
|
172
207
|
# ------------------------------------------------------------------------------
|
|
173
|
-
def
|
|
208
|
+
def _display_connector_types(connector_types: list) -> None:
|
|
174
209
|
"""
|
|
175
210
|
Display a list of connector types.
|
|
176
211
|
|
|
@@ -183,24 +218,22 @@ def display_connector_types(connector_types):
|
|
|
183
218
|
Returns:
|
|
184
219
|
None
|
|
185
220
|
"""
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
else:
|
|
200
|
-
console.print("[red]There are no connector types found.[/red]")
|
|
221
|
+
table = Table(
|
|
222
|
+
title="List of Connector Types",
|
|
223
|
+
show_lines=True,
|
|
224
|
+
expand=True,
|
|
225
|
+
header_style="bold",
|
|
226
|
+
)
|
|
227
|
+
table.add_column("No.", width=2)
|
|
228
|
+
table.add_column("Connector Type", justify="left", width=78)
|
|
229
|
+
|
|
230
|
+
for idx, connector_type in enumerate(connector_types, 1):
|
|
231
|
+
table.add_section()
|
|
232
|
+
table.add_row(str(idx), connector_type)
|
|
233
|
+
console.print(table)
|
|
201
234
|
|
|
202
235
|
|
|
203
|
-
def
|
|
236
|
+
def _display_endpoints(endpoints_list):
|
|
204
237
|
"""
|
|
205
238
|
Display a list of endpoints.
|
|
206
239
|
|
|
@@ -214,52 +247,51 @@ def display_endpoints(endpoints_list):
|
|
|
214
247
|
Returns:
|
|
215
248
|
None
|
|
216
249
|
"""
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
250
|
+
table = Table(
|
|
251
|
+
title="List of Connector Endpoints",
|
|
252
|
+
show_lines=True,
|
|
253
|
+
expand=True,
|
|
254
|
+
header_style="bold",
|
|
255
|
+
)
|
|
256
|
+
table.add_column("No.", justify="left", width=2)
|
|
257
|
+
table.add_column("Id", justify="left", width=10)
|
|
258
|
+
table.add_column("Name", justify="left", width=10)
|
|
259
|
+
table.add_column("Connector Type", justify="left", width=10)
|
|
260
|
+
table.add_column("Uri", justify="left", width=10)
|
|
261
|
+
table.add_column("Token", justify="left", width=10)
|
|
262
|
+
table.add_column("Max Calls Per Second", justify="left", width=5)
|
|
263
|
+
table.add_column("Max concurrency", justify="left", width=5)
|
|
264
|
+
table.add_column("Params", justify="left", width=30)
|
|
265
|
+
table.add_column("Created Date", justify="left", width=8)
|
|
266
|
+
|
|
267
|
+
for idx, endpoint in enumerate(endpoints_list, 1):
|
|
268
|
+
(
|
|
269
|
+
id,
|
|
270
|
+
name,
|
|
271
|
+
connector_type,
|
|
272
|
+
uri,
|
|
273
|
+
token,
|
|
274
|
+
max_calls_per_second,
|
|
275
|
+
max_concurrency,
|
|
276
|
+
params,
|
|
277
|
+
created_date,
|
|
278
|
+
*other_args,
|
|
279
|
+
) = endpoint.values()
|
|
280
|
+
table.add_section()
|
|
281
|
+
idx = endpoint.get("idx", idx)
|
|
282
|
+
table.add_row(
|
|
283
|
+
str(idx),
|
|
284
|
+
id,
|
|
285
|
+
name,
|
|
286
|
+
connector_type,
|
|
287
|
+
uri,
|
|
288
|
+
token,
|
|
289
|
+
str(max_calls_per_second),
|
|
290
|
+
str(max_concurrency),
|
|
291
|
+
escape(str(params)),
|
|
292
|
+
created_date,
|
|
223
293
|
)
|
|
224
|
-
|
|
225
|
-
table.add_column("Id", justify="left", width=10)
|
|
226
|
-
table.add_column("Name", justify="left", width=10)
|
|
227
|
-
table.add_column("Connector Type", justify="left", width=10)
|
|
228
|
-
table.add_column("Uri", justify="left", width=10)
|
|
229
|
-
table.add_column("Token", justify="left", width=10)
|
|
230
|
-
table.add_column("Max Calls Per Second", justify="left", width=5)
|
|
231
|
-
table.add_column("Max concurrency", justify="left", width=5)
|
|
232
|
-
table.add_column("Params", justify="left", width=30)
|
|
233
|
-
table.add_column("Created Date", justify="left", width=8)
|
|
234
|
-
|
|
235
|
-
for endpoint_id, endpoint in enumerate(endpoints_list, 1):
|
|
236
|
-
(
|
|
237
|
-
id,
|
|
238
|
-
name,
|
|
239
|
-
connector_type,
|
|
240
|
-
uri,
|
|
241
|
-
token,
|
|
242
|
-
max_calls_per_second,
|
|
243
|
-
max_concurrency,
|
|
244
|
-
params,
|
|
245
|
-
created_date,
|
|
246
|
-
) = endpoint.values()
|
|
247
|
-
table.add_section()
|
|
248
|
-
table.add_row(
|
|
249
|
-
str(endpoint_id),
|
|
250
|
-
id,
|
|
251
|
-
name,
|
|
252
|
-
connector_type,
|
|
253
|
-
uri,
|
|
254
|
-
token,
|
|
255
|
-
str(max_calls_per_second),
|
|
256
|
-
str(max_concurrency),
|
|
257
|
-
escape(str(params)),
|
|
258
|
-
created_date,
|
|
259
|
-
)
|
|
260
|
-
console.print(table)
|
|
261
|
-
else:
|
|
262
|
-
console.print("[red]There are no endpoints found.[/red]")
|
|
294
|
+
console.print(table)
|
|
263
295
|
|
|
264
296
|
|
|
265
297
|
# ------------------------------------------------------------------------------
|
|
@@ -305,7 +337,11 @@ update_endpoint_args = cmd2.Cmd2ArgumentParser(
|
|
|
305
337
|
"('uri', 'my-uri-loc'), ('token', 'my-token-here')]\""
|
|
306
338
|
),
|
|
307
339
|
)
|
|
308
|
-
update_endpoint_args.add_argument(
|
|
340
|
+
update_endpoint_args.add_argument(
|
|
341
|
+
"endpoint",
|
|
342
|
+
type=str,
|
|
343
|
+
help="ID of the endpoint. This field is not editable via CLI after creation.",
|
|
344
|
+
)
|
|
309
345
|
update_endpoint_args.add_argument(
|
|
310
346
|
"update_kwargs", type=str, help="Update endpoint key/value"
|
|
311
347
|
)
|
|
@@ -323,3 +359,47 @@ delete_endpoint_args = cmd2.Cmd2ArgumentParser(
|
|
|
323
359
|
epilog="Example:\n delete_endpoint openai-gpt4",
|
|
324
360
|
)
|
|
325
361
|
delete_endpoint_args.add_argument("endpoint", type=str, help="ID of the endpoint")
|
|
362
|
+
|
|
363
|
+
# List endpoint arguments
|
|
364
|
+
list_endpoints_args = cmd2.Cmd2ArgumentParser(
|
|
365
|
+
description="List all endpoints.",
|
|
366
|
+
epilog='Example:\n list_endpoints -f "gpt"',
|
|
367
|
+
)
|
|
368
|
+
|
|
369
|
+
list_endpoints_args.add_argument(
|
|
370
|
+
"-f",
|
|
371
|
+
"--find",
|
|
372
|
+
type=str,
|
|
373
|
+
help="Optional field to find endpoint(s) with keyword",
|
|
374
|
+
nargs="?",
|
|
375
|
+
)
|
|
376
|
+
|
|
377
|
+
list_endpoints_args.add_argument(
|
|
378
|
+
"-p",
|
|
379
|
+
"--pagination",
|
|
380
|
+
type=str,
|
|
381
|
+
help="Optional tuple to paginate endpoint(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
|
|
382
|
+
nargs="?",
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
# List connector types arguments
|
|
386
|
+
list_connector_types_args = cmd2.Cmd2ArgumentParser(
|
|
387
|
+
description="List all connector types.",
|
|
388
|
+
epilog='Example:\n list_connector_types -f "openai"',
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
list_connector_types_args.add_argument(
|
|
392
|
+
"-f",
|
|
393
|
+
"--find",
|
|
394
|
+
type=str,
|
|
395
|
+
help="Optional field to find connector type(s) with keyword",
|
|
396
|
+
nargs="?",
|
|
397
|
+
)
|
|
398
|
+
|
|
399
|
+
list_connector_types_args.add_argument(
|
|
400
|
+
"-p",
|
|
401
|
+
"--pagination",
|
|
402
|
+
type=str,
|
|
403
|
+
help="Optional tuple to paginate connector type(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
|
|
404
|
+
nargs="?",
|
|
405
|
+
)
|
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from ast import literal_eval
|
|
2
|
+
|
|
3
|
+
import cmd2
|
|
4
|
+
from rich.console import Console
|
|
5
|
+
|
|
6
|
+
from moonshot.api import (
|
|
7
|
+
api_create_datasets,
|
|
8
|
+
)
|
|
9
|
+
|
|
10
|
+
console = Console()
|
|
11
|
+
def add_dataset(args) -> None:
|
|
12
|
+
"""
|
|
13
|
+
Create a new dataset using the provided arguments and log the result.
|
|
14
|
+
|
|
15
|
+
This function attempts to create a new dataset by calling the `api_create_datasets`
|
|
16
|
+
function with the necessary parameters extracted from `args`. If successful, it logs
|
|
17
|
+
the creation of the dataset with its ID. If an exception occurs, it logs the error.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
args: An argparse.Namespace object containing the following attributes:
|
|
21
|
+
- name (str): Name of the new dataset.
|
|
22
|
+
- description (str): Description of the new dataset.
|
|
23
|
+
- reference (str): Reference URL for the new dataset.
|
|
24
|
+
- license (str): License type for the new dataset.
|
|
25
|
+
- method (str): Method to convert the new dataset ('hf' or 'csv').
|
|
26
|
+
- params (dict): Additional parameters for dataset creation.
|
|
27
|
+
"""
|
|
28
|
+
try:
|
|
29
|
+
new_dataset_id = api_create_datasets(
|
|
30
|
+
args.name,
|
|
31
|
+
args.description,
|
|
32
|
+
args.reference,
|
|
33
|
+
args.license,
|
|
34
|
+
args.method,
|
|
35
|
+
**args.params,
|
|
36
|
+
)
|
|
37
|
+
print(f"[add_dataset]: Dataset ({new_dataset_id}) created.")
|
|
38
|
+
except Exception as e:
|
|
39
|
+
print(f"[add_dataset]: {str(e)}")
|
|
40
|
+
|
|
41
|
+
# ------------------------------------------------------------------------------
|
|
42
|
+
# Cmd2 Arguments Parsers
|
|
43
|
+
# ------------------------------------------------------------------------------
|
|
44
|
+
# Add dataset arguments
|
|
45
|
+
add_dataset_args = cmd2.Cmd2ArgumentParser(
|
|
46
|
+
description="Add a new dataset. The 'name' argument will be slugified to create a unique identifier.",
|
|
47
|
+
epilog=(
|
|
48
|
+
"Examples:\n"
|
|
49
|
+
"1. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'csv' \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
|
|
50
|
+
"2. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'hf' \"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['question','choices'], 'target_col': 'answer'}\""
|
|
51
|
+
),
|
|
52
|
+
)
|
|
53
|
+
add_dataset_args.add_argument("name", type=str, help="Name of the new dataset")
|
|
54
|
+
add_dataset_args.add_argument("description", type=str, help="Description of the new dataset")
|
|
55
|
+
add_dataset_args.add_argument("reference", type=str, help="Reference of the new dataset")
|
|
56
|
+
add_dataset_args.add_argument("license", type=str, help="License of the new dataset")
|
|
57
|
+
add_dataset_args.add_argument("method", type=str, choices=['hf', 'csv'], help="Method to convert the new dataset. Choose either 'hf' or 'csv'.")
|
|
58
|
+
add_dataset_args.add_argument(
|
|
59
|
+
"params",
|
|
60
|
+
type=literal_eval,
|
|
61
|
+
help=(
|
|
62
|
+
"Params of the new dataset in dictionary format. For example: \n"
|
|
63
|
+
"1. For 'csv' method: \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
|
|
64
|
+
"2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['questions','choices'], 'target_col': 'answer'}\""
|
|
65
|
+
)
|
|
66
|
+
)
|
|
@@ -1,8 +1,11 @@
|
|
|
1
|
+
from ast import literal_eval
|
|
2
|
+
|
|
1
3
|
import cmd2
|
|
2
4
|
from rich.console import Console
|
|
3
5
|
from rich.table import Table
|
|
4
6
|
|
|
5
7
|
from moonshot.api import api_delete_prompt_template, api_get_all_prompt_template_detail
|
|
8
|
+
from moonshot.integrations.cli.utils.process_data import filter_data
|
|
6
9
|
|
|
7
10
|
console = Console()
|
|
8
11
|
|
|
@@ -10,13 +13,34 @@ console = Console()
|
|
|
10
13
|
# ------------------------------------------------------------------------------
|
|
11
14
|
# CLI Functions
|
|
12
15
|
# ------------------------------------------------------------------------------
|
|
13
|
-
def list_prompt_templates() -> None:
|
|
16
|
+
def list_prompt_templates(args) -> list | None:
|
|
14
17
|
"""
|
|
15
18
|
List all prompt templates available.
|
|
19
|
+
|
|
20
|
+
Args:
|
|
21
|
+
args: A namespace object from argparse. It should have an optional attribute:
|
|
22
|
+
find (str): Optional field to find prompt template(s) with a keyword.
|
|
23
|
+
pagination (str): Optional field to paginate prompt templates.
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
list | None: A list of PromptTemplate or None if there is no result.
|
|
16
27
|
"""
|
|
17
28
|
try:
|
|
18
|
-
|
|
19
|
-
|
|
29
|
+
prompt_templates_list = api_get_all_prompt_template_detail()
|
|
30
|
+
keyword = args.find.lower() if args.find else ""
|
|
31
|
+
pagination = literal_eval(args.pagination) if args.pagination else ()
|
|
32
|
+
|
|
33
|
+
if prompt_templates_list:
|
|
34
|
+
filtered_prompt_templates_list = filter_data(
|
|
35
|
+
prompt_templates_list, keyword, pagination
|
|
36
|
+
)
|
|
37
|
+
if filtered_prompt_templates_list:
|
|
38
|
+
_display_prompt_templates(filtered_prompt_templates_list)
|
|
39
|
+
return filtered_prompt_templates_list
|
|
40
|
+
|
|
41
|
+
console.print("[red]There are no prompt templates found.[/red]")
|
|
42
|
+
return None
|
|
43
|
+
|
|
20
44
|
except Exception as e:
|
|
21
45
|
print(f"[list_prompt_templates]: {str(e)}")
|
|
22
46
|
|
|
@@ -46,7 +70,7 @@ def delete_prompt_template(args) -> None:
|
|
|
46
70
|
# ------------------------------------------------------------------------------
|
|
47
71
|
# Helper functions: Display on cli
|
|
48
72
|
# ------------------------------------------------------------------------------
|
|
49
|
-
def
|
|
73
|
+
def _display_prompt_templates(prompt_templates) -> None:
|
|
50
74
|
"""
|
|
51
75
|
Display the list of prompt templates in a formatted table.
|
|
52
76
|
|
|
@@ -66,21 +90,13 @@ def display_prompt_templates(prompt_templates) -> None:
|
|
|
66
90
|
table.add_column("No.", width=2)
|
|
67
91
|
table.add_column("Prompt Template", justify="left", width=50)
|
|
68
92
|
table.add_column("Contains", justify="left", width=48, overflow="fold")
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
) = prompt_template.values()
|
|
77
|
-
|
|
78
|
-
prompt_info = f"[red]id: {id}[/red]\n\n[blue]{name}[/blue]\n{description}"
|
|
79
|
-
table.add_section()
|
|
80
|
-
table.add_row(str(prompt_index), prompt_info, contents)
|
|
81
|
-
console.print(table)
|
|
82
|
-
else:
|
|
83
|
-
console.print("[red]There are no prompt templates found.[/red]")
|
|
93
|
+
for idx, prompt_template in enumerate(prompt_templates, 1):
|
|
94
|
+
(id, name, description, contents, *other_args) = prompt_template.values()
|
|
95
|
+
idx = prompt_template.get("idx", idx)
|
|
96
|
+
prompt_info = f"[red]id: {id}[/red]\n\n[blue]{name}[/blue]\n{description}"
|
|
97
|
+
table.add_section()
|
|
98
|
+
table.add_row(str(idx), prompt_info, contents)
|
|
99
|
+
console.print(table)
|
|
84
100
|
|
|
85
101
|
|
|
86
102
|
# Delete prompt template arguments
|
|
@@ -92,3 +108,25 @@ delete_prompt_template_args = cmd2.Cmd2ArgumentParser(
|
|
|
92
108
|
delete_prompt_template_args.add_argument(
|
|
93
109
|
"prompt_template", type=str, help="The ID of the prompt template to delete"
|
|
94
110
|
)
|
|
111
|
+
|
|
112
|
+
# List prompt template arguments
|
|
113
|
+
list_prompt_templates_args = cmd2.Cmd2ArgumentParser(
|
|
114
|
+
description="List all prompt templates.",
|
|
115
|
+
epilog='Example:\n list_prompt_templates -f "toxicity"',
|
|
116
|
+
)
|
|
117
|
+
|
|
118
|
+
list_prompt_templates_args.add_argument(
|
|
119
|
+
"-f",
|
|
120
|
+
"--find",
|
|
121
|
+
type=str,
|
|
122
|
+
help="Optional field to find prompt template(s) with keyword",
|
|
123
|
+
nargs="?",
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
list_prompt_templates_args.add_argument(
|
|
127
|
+
"-p",
|
|
128
|
+
"--pagination",
|
|
129
|
+
type=str,
|
|
130
|
+
help="Optional tuple to paginate prompt template(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
|
|
131
|
+
nargs="?",
|
|
132
|
+
)
|