aiverify-moonshot 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/METADATA +2 -2
  2. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/RECORD +70 -56
  3. moonshot/__main__.py +77 -35
  4. moonshot/api.py +16 -0
  5. moonshot/integrations/cli/benchmark/benchmark.py +29 -13
  6. moonshot/integrations/cli/benchmark/cookbook.py +62 -24
  7. moonshot/integrations/cli/benchmark/datasets.py +79 -40
  8. moonshot/integrations/cli/benchmark/metrics.py +62 -23
  9. moonshot/integrations/cli/benchmark/recipe.py +89 -69
  10. moonshot/integrations/cli/benchmark/result.py +85 -47
  11. moonshot/integrations/cli/benchmark/run.py +99 -59
  12. moonshot/integrations/cli/common/common.py +20 -6
  13. moonshot/integrations/cli/common/connectors.py +154 -74
  14. moonshot/integrations/cli/common/dataset.py +66 -0
  15. moonshot/integrations/cli/common/prompt_template.py +57 -19
  16. moonshot/integrations/cli/redteam/attack_module.py +90 -24
  17. moonshot/integrations/cli/redteam/context_strategy.py +83 -23
  18. moonshot/integrations/cli/redteam/prompt_template.py +1 -1
  19. moonshot/integrations/cli/redteam/redteam.py +52 -6
  20. moonshot/integrations/cli/redteam/session.py +565 -44
  21. moonshot/integrations/cli/utils/process_data.py +52 -0
  22. moonshot/integrations/web_api/__main__.py +2 -0
  23. moonshot/integrations/web_api/app.py +6 -6
  24. moonshot/integrations/web_api/container.py +12 -2
  25. moonshot/integrations/web_api/routes/bookmark.py +173 -0
  26. moonshot/integrations/web_api/routes/dataset.py +46 -1
  27. moonshot/integrations/web_api/schemas/bookmark_create_dto.py +13 -0
  28. moonshot/integrations/web_api/schemas/dataset_create_dto.py +18 -0
  29. moonshot/integrations/web_api/schemas/recipe_create_dto.py +0 -2
  30. moonshot/integrations/web_api/services/bookmark_service.py +94 -0
  31. moonshot/integrations/web_api/services/dataset_service.py +25 -0
  32. moonshot/integrations/web_api/services/recipe_service.py +0 -1
  33. moonshot/integrations/web_api/services/utils/file_manager.py +52 -0
  34. moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +0 -1
  35. moonshot/integrations/web_api/temp/.gitkeep +0 -0
  36. moonshot/src/api/api_bookmark.py +95 -0
  37. moonshot/src/api/api_connector_endpoint.py +1 -1
  38. moonshot/src/api/api_context_strategy.py +2 -2
  39. moonshot/src/api/api_dataset.py +35 -0
  40. moonshot/src/api/api_recipe.py +0 -3
  41. moonshot/src/api/api_session.py +1 -1
  42. moonshot/src/bookmark/bookmark.py +257 -0
  43. moonshot/src/bookmark/bookmark_arguments.py +38 -0
  44. moonshot/src/configs/env_variables.py +12 -2
  45. moonshot/src/connectors/connector.py +15 -7
  46. moonshot/src/connectors_endpoints/connector_endpoint.py +65 -49
  47. moonshot/src/cookbooks/cookbook.py +57 -37
  48. moonshot/src/datasets/dataset.py +125 -5
  49. moonshot/src/metrics/metric.py +8 -4
  50. moonshot/src/metrics/metric_interface.py +8 -2
  51. moonshot/src/prompt_templates/prompt_template.py +5 -1
  52. moonshot/src/recipes/recipe.py +38 -40
  53. moonshot/src/recipes/recipe_arguments.py +0 -4
  54. moonshot/src/redteaming/attack/attack_module.py +18 -8
  55. moonshot/src/redteaming/attack/context_strategy.py +6 -2
  56. moonshot/src/redteaming/session/session.py +15 -11
  57. moonshot/src/results/result.py +7 -3
  58. moonshot/src/runners/runner.py +65 -42
  59. moonshot/src/runs/run.py +15 -11
  60. moonshot/src/runs/run_progress.py +7 -3
  61. moonshot/src/storage/db_interface.py +14 -0
  62. moonshot/src/storage/storage.py +33 -2
  63. moonshot/src/utils/find_feature.py +45 -0
  64. moonshot/src/utils/log.py +72 -0
  65. moonshot/src/utils/pagination.py +25 -0
  66. moonshot/src/utils/timeit.py +8 -1
  67. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/WHEEL +0 -0
  68. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/AUTHORS.md +0 -0
  69. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/LICENSE.md +0 -0
  70. {aiverify_moonshot-0.4.1.dist-info → aiverify_moonshot-0.4.3.dist-info}/licenses/NOTICES.md +0 -0
@@ -2,13 +2,19 @@ import argparse
2
2
 
3
3
  import cmd2
4
4
 
5
+ from moonshot.integrations.cli.common.dataset import (
6
+ add_dataset,
7
+ add_dataset_args
8
+ )
5
9
  from moonshot.integrations.cli.common.connectors import (
6
10
  add_endpoint,
7
11
  add_endpoint_args,
8
12
  delete_endpoint,
9
13
  delete_endpoint_args,
10
14
  list_connector_types,
15
+ list_connector_types_args,
11
16
  list_endpoints,
17
+ list_endpoints_args,
12
18
  update_endpoint,
13
19
  update_endpoint_args,
14
20
  view_endpoint,
@@ -18,6 +24,7 @@ from moonshot.integrations.cli.common.prompt_template import (
18
24
  delete_prompt_template,
19
25
  delete_prompt_template_args,
20
26
  list_prompt_templates,
27
+ list_prompt_templates_args,
21
28
  )
22
29
 
23
30
 
@@ -30,14 +37,17 @@ class CommonCommandSet(cmd2.CommandSet):
30
37
  # List contents
31
38
  # ------------------------------------------------------------------------------
32
39
 
33
- def do_list_connector_types(self, _: cmd2.Statement) -> None:
34
- list_connector_types()
40
+ @cmd2.with_argparser(list_connector_types_args)
41
+ def do_list_connector_types(self, args: argparse.Namespace) -> None:
42
+ list_connector_types(args)
35
43
 
36
- def do_list_endpoints(self, _: cmd2.Statement) -> None:
37
- list_endpoints()
44
+ @cmd2.with_argparser(list_endpoints_args)
45
+ def do_list_endpoints(self, args: argparse.Namespace) -> None:
46
+ list_endpoints(args)
38
47
 
39
- def do_list_prompt_templates(self, _: cmd2.Statement) -> None:
40
- list_prompt_templates()
48
+ @cmd2.with_argparser(list_prompt_templates_args)
49
+ def do_list_prompt_templates(self, args: argparse.Namespace) -> None:
50
+ list_prompt_templates(args)
41
51
 
42
52
  @cmd2.with_argparser(delete_prompt_template_args)
43
53
  def do_delete_prompt_template(self, args: argparse.Namespace) -> None:
@@ -50,6 +60,10 @@ class CommonCommandSet(cmd2.CommandSet):
50
60
  def do_add_endpoint(self, args: argparse.Namespace) -> None:
51
61
  add_endpoint(args)
52
62
 
63
+ @cmd2.with_argparser(add_dataset_args)
64
+ def do_add_dataset(self, args:argparse.Namespace) -> None:
65
+ add_dataset(args)
66
+
53
67
  # ------------------------------------------------------------------------------
54
68
  # Delete contents
55
69
  # ------------------------------------------------------------------------------
@@ -13,6 +13,7 @@ from moonshot.api import (
13
13
  api_read_endpoint,
14
14
  api_update_endpoint,
15
15
  )
16
+ from moonshot.integrations.cli.utils.process_data import filter_data
16
17
 
17
18
  console = Console()
18
19
 
@@ -57,36 +58,70 @@ def add_endpoint(args) -> None:
57
58
  print(f"[add_endpoint]: {str(e)}")
58
59
 
59
60
 
60
- def list_endpoints() -> None:
61
+ def list_endpoints(args) -> list | None:
61
62
  """
62
63
  List all endpoints.
63
64
 
64
65
  This function retrieves all endpoints by calling the api_get_all_endpoint function from the
65
- moonshot.api module. It then displays the endpoints using the display_endpoints function.
66
+ moonshot.api module. It then displays the endpoints using the _display_endpoints function.
67
+
68
+ Args:
69
+ args: A namespace object from argparse. It should have an optional attribute:
70
+ find (str): Optional field to find endpoint(s) with a keyword.
71
+ pagination (str): Optional field to paginate endpoints.
66
72
 
67
73
  Returns:
68
- None
74
+ list | None: A list of ConnectorEndpoint or None if there is no result.
69
75
  """
70
76
  try:
71
- endpoint_list = api_get_all_endpoint()
72
- display_endpoints(endpoint_list)
77
+ endpoints_list = api_get_all_endpoint()
78
+ keyword = args.find.lower() if args.find else ""
79
+ pagination = literal_eval(args.pagination) if args.pagination else ()
80
+
81
+ if endpoints_list:
82
+ filtered_endpoints_list = filter_data(endpoints_list, keyword, pagination)
83
+ if filtered_endpoints_list:
84
+ _display_endpoints(filtered_endpoints_list)
85
+ return filtered_endpoints_list
86
+
87
+ console.print("[red]There are no endpoints found.[/red]")
88
+ return None
89
+
73
90
  except Exception as e:
74
91
  print(f"[list_endpoints]: {str(e)}")
75
92
 
76
93
 
77
- def list_connector_types() -> None:
94
+ def list_connector_types(args) -> list | None:
78
95
  """
79
96
  List all connector types.
80
97
 
81
98
  This function retrieves all connector types by calling the api_get_all_connector_type function from the
82
- moonshot.api module. It then displays the connector types using the display_connector_types function.
99
+ moonshot.api module. It then displays the connector types using the _display_connector_types function.
100
+
101
+ Args:
102
+ args: A namespace object from argparse. It should have an optional attribute:
103
+ find (str): Optional field to find connector type(s) with a keyword.
104
+ pagination (str): Optional field to paginate connector types.
83
105
 
84
106
  Returns:
85
- None
107
+ list | None: A list of Connector or None if there is no result.
86
108
  """
87
109
  try:
88
110
  connector_type_list = api_get_all_connector_type()
89
- display_connector_types(connector_type_list)
111
+ keyword = args.find.lower() if args.find else ""
112
+ pagination = literal_eval(args.pagination) if args.pagination else ()
113
+
114
+ if connector_type_list:
115
+ filtered_connector_type_list = filter_data(
116
+ connector_type_list, keyword, pagination
117
+ )
118
+ if filtered_connector_type_list:
119
+ _display_connector_types(filtered_connector_type_list)
120
+ return filtered_connector_type_list
121
+
122
+ console.print("[red]There are no connector types found.[/red]")
123
+ return None
124
+
90
125
  except Exception as e:
91
126
  print(f"[list_connector_types]: {str(e)}")
92
127
 
@@ -97,7 +132,7 @@ def view_endpoint(args) -> None:
97
132
 
98
133
  This function retrieves a specific endpoint by calling the api_read_endpoint function from the
99
134
  moonshot.api module using the endpoint name provided in the args. It then displays the endpoint
100
- information using the display_endpoints function.
135
+ information using the _display_endpoints function.
101
136
 
102
137
  Args:
103
138
  args: A namespace object from argparse. It should have the following attribute:
@@ -108,7 +143,7 @@ def view_endpoint(args) -> None:
108
143
  """
109
144
  try:
110
145
  endpoint_info = api_read_endpoint(args.endpoint)
111
- display_endpoints([endpoint_info])
146
+ _display_endpoints([endpoint_info])
112
147
  except Exception as e:
113
148
  print(f"[view_endpoint]: {str(e)}")
114
149
 
@@ -170,7 +205,7 @@ def delete_endpoint(args) -> None:
170
205
  # ------------------------------------------------------------------------------
171
206
  # Helper functions: Display on cli
172
207
  # ------------------------------------------------------------------------------
173
- def display_connector_types(connector_types):
208
+ def _display_connector_types(connector_types: list) -> None:
174
209
  """
175
210
  Display a list of connector types.
176
211
 
@@ -183,24 +218,22 @@ def display_connector_types(connector_types):
183
218
  Returns:
184
219
  None
185
220
  """
186
- if connector_types:
187
- table = Table(
188
- title="List of Connector Types",
189
- show_lines=True,
190
- expand=True,
191
- header_style="bold",
192
- )
193
- table.add_column("No.", width=2)
194
- table.add_column("Connector Type", justify="left", width=78)
195
- for connector_id, connector_type in enumerate(connector_types, 1):
196
- table.add_section()
197
- table.add_row(str(connector_id), connector_type)
198
- console.print(table)
199
- else:
200
- console.print("[red]There are no connector types found.[/red]")
221
+ table = Table(
222
+ title="List of Connector Types",
223
+ show_lines=True,
224
+ expand=True,
225
+ header_style="bold",
226
+ )
227
+ table.add_column("No.", width=2)
228
+ table.add_column("Connector Type", justify="left", width=78)
229
+
230
+ for idx, connector_type in enumerate(connector_types, 1):
231
+ table.add_section()
232
+ table.add_row(str(idx), connector_type)
233
+ console.print(table)
201
234
 
202
235
 
203
- def display_endpoints(endpoints_list):
236
+ def _display_endpoints(endpoints_list):
204
237
  """
205
238
  Display a list of endpoints.
206
239
 
@@ -214,52 +247,51 @@ def display_endpoints(endpoints_list):
214
247
  Returns:
215
248
  None
216
249
  """
217
- if endpoints_list:
218
- table = Table(
219
- title="List of Connector Endpoints",
220
- show_lines=True,
221
- expand=True,
222
- header_style="bold",
250
+ table = Table(
251
+ title="List of Connector Endpoints",
252
+ show_lines=True,
253
+ expand=True,
254
+ header_style="bold",
255
+ )
256
+ table.add_column("No.", justify="left", width=2)
257
+ table.add_column("Id", justify="left", width=10)
258
+ table.add_column("Name", justify="left", width=10)
259
+ table.add_column("Connector Type", justify="left", width=10)
260
+ table.add_column("Uri", justify="left", width=10)
261
+ table.add_column("Token", justify="left", width=10)
262
+ table.add_column("Max Calls Per Second", justify="left", width=5)
263
+ table.add_column("Max concurrency", justify="left", width=5)
264
+ table.add_column("Params", justify="left", width=30)
265
+ table.add_column("Created Date", justify="left", width=8)
266
+
267
+ for idx, endpoint in enumerate(endpoints_list, 1):
268
+ (
269
+ id,
270
+ name,
271
+ connector_type,
272
+ uri,
273
+ token,
274
+ max_calls_per_second,
275
+ max_concurrency,
276
+ params,
277
+ created_date,
278
+ *other_args,
279
+ ) = endpoint.values()
280
+ table.add_section()
281
+ idx = endpoint.get("idx", idx)
282
+ table.add_row(
283
+ str(idx),
284
+ id,
285
+ name,
286
+ connector_type,
287
+ uri,
288
+ token,
289
+ str(max_calls_per_second),
290
+ str(max_concurrency),
291
+ escape(str(params)),
292
+ created_date,
223
293
  )
224
- table.add_column("No.", justify="left", width=2)
225
- table.add_column("Id", justify="left", width=10)
226
- table.add_column("Name", justify="left", width=10)
227
- table.add_column("Connector Type", justify="left", width=10)
228
- table.add_column("Uri", justify="left", width=10)
229
- table.add_column("Token", justify="left", width=10)
230
- table.add_column("Max Calls Per Second", justify="left", width=5)
231
- table.add_column("Max concurrency", justify="left", width=5)
232
- table.add_column("Params", justify="left", width=30)
233
- table.add_column("Created Date", justify="left", width=8)
234
-
235
- for endpoint_id, endpoint in enumerate(endpoints_list, 1):
236
- (
237
- id,
238
- name,
239
- connector_type,
240
- uri,
241
- token,
242
- max_calls_per_second,
243
- max_concurrency,
244
- params,
245
- created_date,
246
- ) = endpoint.values()
247
- table.add_section()
248
- table.add_row(
249
- str(endpoint_id),
250
- id,
251
- name,
252
- connector_type,
253
- uri,
254
- token,
255
- str(max_calls_per_second),
256
- str(max_concurrency),
257
- escape(str(params)),
258
- created_date,
259
- )
260
- console.print(table)
261
- else:
262
- console.print("[red]There are no endpoints found.[/red]")
294
+ console.print(table)
263
295
 
264
296
 
265
297
  # ------------------------------------------------------------------------------
@@ -305,7 +337,11 @@ update_endpoint_args = cmd2.Cmd2ArgumentParser(
305
337
  "('uri', 'my-uri-loc'), ('token', 'my-token-here')]\""
306
338
  ),
307
339
  )
308
- update_endpoint_args.add_argument("endpoint", type=str, help="ID of the endpoint. This field is not editable via CLI after creation.")
340
+ update_endpoint_args.add_argument(
341
+ "endpoint",
342
+ type=str,
343
+ help="ID of the endpoint. This field is not editable via CLI after creation.",
344
+ )
309
345
  update_endpoint_args.add_argument(
310
346
  "update_kwargs", type=str, help="Update endpoint key/value"
311
347
  )
@@ -323,3 +359,47 @@ delete_endpoint_args = cmd2.Cmd2ArgumentParser(
323
359
  epilog="Example:\n delete_endpoint openai-gpt4",
324
360
  )
325
361
  delete_endpoint_args.add_argument("endpoint", type=str, help="ID of the endpoint")
362
+
363
+ # List endpoint arguments
364
+ list_endpoints_args = cmd2.Cmd2ArgumentParser(
365
+ description="List all endpoints.",
366
+ epilog='Example:\n list_endpoints -f "gpt"',
367
+ )
368
+
369
+ list_endpoints_args.add_argument(
370
+ "-f",
371
+ "--find",
372
+ type=str,
373
+ help="Optional field to find endpoint(s) with keyword",
374
+ nargs="?",
375
+ )
376
+
377
+ list_endpoints_args.add_argument(
378
+ "-p",
379
+ "--pagination",
380
+ type=str,
381
+ help="Optional tuple to paginate endpoint(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
382
+ nargs="?",
383
+ )
384
+
385
+ # List connector types arguments
386
+ list_connector_types_args = cmd2.Cmd2ArgumentParser(
387
+ description="List all connector types.",
388
+ epilog='Example:\n list_connector_types -f "openai"',
389
+ )
390
+
391
+ list_connector_types_args.add_argument(
392
+ "-f",
393
+ "--find",
394
+ type=str,
395
+ help="Optional field to find connector type(s) with keyword",
396
+ nargs="?",
397
+ )
398
+
399
+ list_connector_types_args.add_argument(
400
+ "-p",
401
+ "--pagination",
402
+ type=str,
403
+ help="Optional tuple to paginate connector type(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
404
+ nargs="?",
405
+ )
@@ -0,0 +1,66 @@
1
+ from ast import literal_eval
2
+
3
+ import cmd2
4
+ from rich.console import Console
5
+
6
+ from moonshot.api import (
7
+ api_create_datasets,
8
+ )
9
+
10
+ console = Console()
11
+ def add_dataset(args) -> None:
12
+ """
13
+ Create a new dataset using the provided arguments and log the result.
14
+
15
+ This function attempts to create a new dataset by calling the `api_create_datasets`
16
+ function with the necessary parameters extracted from `args`. If successful, it logs
17
+ the creation of the dataset with its ID. If an exception occurs, it logs the error.
18
+
19
+ Args:
20
+ args: An argparse.Namespace object containing the following attributes:
21
+ - name (str): Name of the new dataset.
22
+ - description (str): Description of the new dataset.
23
+ - reference (str): Reference URL for the new dataset.
24
+ - license (str): License type for the new dataset.
25
+ - method (str): Method to convert the new dataset ('hf' or 'csv').
26
+ - params (dict): Additional parameters for dataset creation.
27
+ """
28
+ try:
29
+ new_dataset_id = api_create_datasets(
30
+ args.name,
31
+ args.description,
32
+ args.reference,
33
+ args.license,
34
+ args.method,
35
+ **args.params,
36
+ )
37
+ print(f"[add_dataset]: Dataset ({new_dataset_id}) created.")
38
+ except Exception as e:
39
+ print(f"[add_dataset]: {str(e)}")
40
+
41
+ # ------------------------------------------------------------------------------
42
+ # Cmd2 Arguments Parsers
43
+ # ------------------------------------------------------------------------------
44
+ # Add dataset arguments
45
+ add_dataset_args = cmd2.Cmd2ArgumentParser(
46
+ description="Add a new dataset. The 'name' argument will be slugified to create a unique identifier.",
47
+ epilog=(
48
+ "Examples:\n"
49
+ "1. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'csv' \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
50
+ "2. add_dataset 'dataset-name' 'A brief description' 'http://reference.com' 'MIT' 'hf' \"{'dataset_name': 'cais/mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['question','choices'], 'target_col': 'answer'}\""
51
+ ),
52
+ )
53
+ add_dataset_args.add_argument("name", type=str, help="Name of the new dataset")
54
+ add_dataset_args.add_argument("description", type=str, help="Description of the new dataset")
55
+ add_dataset_args.add_argument("reference", type=str, help="Reference of the new dataset")
56
+ add_dataset_args.add_argument("license", type=str, help="License of the new dataset")
57
+ add_dataset_args.add_argument("method", type=str, choices=['hf', 'csv'], help="Method to convert the new dataset. Choose either 'hf' or 'csv'.")
58
+ add_dataset_args.add_argument(
59
+ "params",
60
+ type=literal_eval,
61
+ help=(
62
+ "Params of the new dataset in dictionary format. For example: \n"
63
+ "1. For 'csv' method: \"{'csv_file_path': '/path/to/your/file.csv'}\"\n"
64
+ "2. For 'hf' method: \"{'dataset_name': 'cais_mmlu', 'dataset_config': 'college_biology', 'split': 'test', 'input_col': ['questions','choices'], 'target_col': 'answer'}\""
65
+ )
66
+ )
@@ -1,8 +1,11 @@
1
+ from ast import literal_eval
2
+
1
3
  import cmd2
2
4
  from rich.console import Console
3
5
  from rich.table import Table
4
6
 
5
7
  from moonshot.api import api_delete_prompt_template, api_get_all_prompt_template_detail
8
+ from moonshot.integrations.cli.utils.process_data import filter_data
6
9
 
7
10
  console = Console()
8
11
 
@@ -10,13 +13,34 @@ console = Console()
10
13
  # ------------------------------------------------------------------------------
11
14
  # CLI Functions
12
15
  # ------------------------------------------------------------------------------
13
- def list_prompt_templates() -> None:
16
+ def list_prompt_templates(args) -> list | None:
14
17
  """
15
18
  List all prompt templates available.
19
+
20
+ Args:
21
+ args: A namespace object from argparse. It should have an optional attribute:
22
+ find (str): Optional field to find prompt template(s) with a keyword.
23
+ pagination (str): Optional field to paginate prompt templates.
24
+
25
+ Returns:
26
+ list | None: A list of PromptTemplate or None if there is no result.
16
27
  """
17
28
  try:
18
- prompt_templates = api_get_all_prompt_template_detail()
19
- display_prompt_templates(prompt_templates)
29
+ prompt_templates_list = api_get_all_prompt_template_detail()
30
+ keyword = args.find.lower() if args.find else ""
31
+ pagination = literal_eval(args.pagination) if args.pagination else ()
32
+
33
+ if prompt_templates_list:
34
+ filtered_prompt_templates_list = filter_data(
35
+ prompt_templates_list, keyword, pagination
36
+ )
37
+ if filtered_prompt_templates_list:
38
+ _display_prompt_templates(filtered_prompt_templates_list)
39
+ return filtered_prompt_templates_list
40
+
41
+ console.print("[red]There are no prompt templates found.[/red]")
42
+ return None
43
+
20
44
  except Exception as e:
21
45
  print(f"[list_prompt_templates]: {str(e)}")
22
46
 
@@ -46,7 +70,7 @@ def delete_prompt_template(args) -> None:
46
70
  # ------------------------------------------------------------------------------
47
71
  # Helper functions: Display on cli
48
72
  # ------------------------------------------------------------------------------
49
- def display_prompt_templates(prompt_templates) -> None:
73
+ def _display_prompt_templates(prompt_templates) -> None:
50
74
  """
51
75
  Display the list of prompt templates in a formatted table.
52
76
 
@@ -66,21 +90,13 @@ def display_prompt_templates(prompt_templates) -> None:
66
90
  table.add_column("No.", width=2)
67
91
  table.add_column("Prompt Template", justify="left", width=50)
68
92
  table.add_column("Contains", justify="left", width=48, overflow="fold")
69
- if prompt_templates:
70
- for prompt_index, prompt_template in enumerate(prompt_templates, 1):
71
- (
72
- id,
73
- name,
74
- description,
75
- contents,
76
- ) = prompt_template.values()
77
-
78
- prompt_info = f"[red]id: {id}[/red]\n\n[blue]{name}[/blue]\n{description}"
79
- table.add_section()
80
- table.add_row(str(prompt_index), prompt_info, contents)
81
- console.print(table)
82
- else:
83
- console.print("[red]There are no prompt templates found.[/red]")
93
+ for idx, prompt_template in enumerate(prompt_templates, 1):
94
+ (id, name, description, contents, *other_args) = prompt_template.values()
95
+ idx = prompt_template.get("idx", idx)
96
+ prompt_info = f"[red]id: {id}[/red]\n\n[blue]{name}[/blue]\n{description}"
97
+ table.add_section()
98
+ table.add_row(str(idx), prompt_info, contents)
99
+ console.print(table)
84
100
 
85
101
 
86
102
  # Delete prompt template arguments
@@ -92,3 +108,25 @@ delete_prompt_template_args = cmd2.Cmd2ArgumentParser(
92
108
  delete_prompt_template_args.add_argument(
93
109
  "prompt_template", type=str, help="The ID of the prompt template to delete"
94
110
  )
111
+
112
+ # List prompt template arguments
113
+ list_prompt_templates_args = cmd2.Cmd2ArgumentParser(
114
+ description="List all prompt templates.",
115
+ epilog='Example:\n list_prompt_templates -f "toxicity"',
116
+ )
117
+
118
+ list_prompt_templates_args.add_argument(
119
+ "-f",
120
+ "--find",
121
+ type=str,
122
+ help="Optional field to find prompt template(s) with keyword",
123
+ nargs="?",
124
+ )
125
+
126
+ list_prompt_templates_args.add_argument(
127
+ "-p",
128
+ "--pagination",
129
+ type=str,
130
+ help="Optional tuple to paginate prompt template(s). E.g. (2,10) returns 2nd page with 10 items in each page.",
131
+ nargs="?",
132
+ )