aiverify-moonshot 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (163):
  1. aiverify_moonshot-0.4.0.dist-info/METADATA +249 -0
  2. aiverify_moonshot-0.4.0.dist-info/RECORD +163 -0
  3. aiverify_moonshot-0.4.0.dist-info/WHEEL +4 -0
  4. aiverify_moonshot-0.4.0.dist-info/licenses/AUTHORS.md +5 -0
  5. aiverify_moonshot-0.4.0.dist-info/licenses/LICENSE.md +201 -0
  6. aiverify_moonshot-0.4.0.dist-info/licenses/NOTICES.md +3340 -0
  7. moonshot/__init__.py +0 -0
  8. moonshot/__main__.py +198 -0
  9. moonshot/api.py +155 -0
  10. moonshot/integrations/__init__.py +0 -0
  11. moonshot/integrations/cli/__init__.py +0 -0
  12. moonshot/integrations/cli/__main__.py +25 -0
  13. moonshot/integrations/cli/active_session_cfg.py +1 -0
  14. moonshot/integrations/cli/benchmark/__init__.py +0 -0
  15. moonshot/integrations/cli/benchmark/benchmark.py +186 -0
  16. moonshot/integrations/cli/benchmark/cookbook.py +545 -0
  17. moonshot/integrations/cli/benchmark/datasets.py +164 -0
  18. moonshot/integrations/cli/benchmark/metrics.py +141 -0
  19. moonshot/integrations/cli/benchmark/recipe.py +598 -0
  20. moonshot/integrations/cli/benchmark/result.py +216 -0
  21. moonshot/integrations/cli/benchmark/run.py +140 -0
  22. moonshot/integrations/cli/benchmark/runner.py +174 -0
  23. moonshot/integrations/cli/cli.py +64 -0
  24. moonshot/integrations/cli/common/__init__.py +0 -0
  25. moonshot/integrations/cli/common/common.py +72 -0
  26. moonshot/integrations/cli/common/connectors.py +325 -0
  27. moonshot/integrations/cli/common/display_helper.py +42 -0
  28. moonshot/integrations/cli/common/prompt_template.py +94 -0
  29. moonshot/integrations/cli/initialisation/__init__.py +0 -0
  30. moonshot/integrations/cli/initialisation/initialisation.py +14 -0
  31. moonshot/integrations/cli/redteam/__init__.py +0 -0
  32. moonshot/integrations/cli/redteam/attack_module.py +70 -0
  33. moonshot/integrations/cli/redteam/context_strategy.py +147 -0
  34. moonshot/integrations/cli/redteam/prompt_template.py +67 -0
  35. moonshot/integrations/cli/redteam/redteam.py +90 -0
  36. moonshot/integrations/cli/redteam/session.py +467 -0
  37. moonshot/integrations/web_api/.env.dev +7 -0
  38. moonshot/integrations/web_api/__init__.py +0 -0
  39. moonshot/integrations/web_api/__main__.py +56 -0
  40. moonshot/integrations/web_api/app.py +125 -0
  41. moonshot/integrations/web_api/container.py +146 -0
  42. moonshot/integrations/web_api/log/.gitkeep +0 -0
  43. moonshot/integrations/web_api/logging_conf.py +114 -0
  44. moonshot/integrations/web_api/routes/__init__.py +0 -0
  45. moonshot/integrations/web_api/routes/attack_modules.py +66 -0
  46. moonshot/integrations/web_api/routes/benchmark.py +116 -0
  47. moonshot/integrations/web_api/routes/benchmark_result.py +175 -0
  48. moonshot/integrations/web_api/routes/context_strategy.py +129 -0
  49. moonshot/integrations/web_api/routes/cookbook.py +225 -0
  50. moonshot/integrations/web_api/routes/dataset.py +120 -0
  51. moonshot/integrations/web_api/routes/endpoint.py +282 -0
  52. moonshot/integrations/web_api/routes/metric.py +78 -0
  53. moonshot/integrations/web_api/routes/prompt_template.py +128 -0
  54. moonshot/integrations/web_api/routes/recipe.py +219 -0
  55. moonshot/integrations/web_api/routes/redteam.py +609 -0
  56. moonshot/integrations/web_api/routes/runner.py +239 -0
  57. moonshot/integrations/web_api/schemas/__init__.py +0 -0
  58. moonshot/integrations/web_api/schemas/benchmark_runner_dto.py +13 -0
  59. moonshot/integrations/web_api/schemas/cookbook_create_dto.py +19 -0
  60. moonshot/integrations/web_api/schemas/cookbook_response_model.py +9 -0
  61. moonshot/integrations/web_api/schemas/dataset_response_dto.py +9 -0
  62. moonshot/integrations/web_api/schemas/endpoint_create_dto.py +21 -0
  63. moonshot/integrations/web_api/schemas/endpoint_response_model.py +11 -0
  64. moonshot/integrations/web_api/schemas/prompt_response_model.py +14 -0
  65. moonshot/integrations/web_api/schemas/prompt_template_response_model.py +10 -0
  66. moonshot/integrations/web_api/schemas/recipe_create_dto.py +32 -0
  67. moonshot/integrations/web_api/schemas/recipe_response_model.py +7 -0
  68. moonshot/integrations/web_api/schemas/session_create_dto.py +16 -0
  69. moonshot/integrations/web_api/schemas/session_prompt_dto.py +7 -0
  70. moonshot/integrations/web_api/schemas/session_response_model.py +38 -0
  71. moonshot/integrations/web_api/services/__init__.py +0 -0
  72. moonshot/integrations/web_api/services/attack_module_service.py +34 -0
  73. moonshot/integrations/web_api/services/auto_red_team_test_manager.py +86 -0
  74. moonshot/integrations/web_api/services/auto_red_team_test_state.py +57 -0
  75. moonshot/integrations/web_api/services/base_service.py +8 -0
  76. moonshot/integrations/web_api/services/benchmark_result_service.py +25 -0
  77. moonshot/integrations/web_api/services/benchmark_test_manager.py +106 -0
  78. moonshot/integrations/web_api/services/benchmark_test_state.py +56 -0
  79. moonshot/integrations/web_api/services/benchmarking_service.py +31 -0
  80. moonshot/integrations/web_api/services/context_strategy_service.py +22 -0
  81. moonshot/integrations/web_api/services/cookbook_service.py +194 -0
  82. moonshot/integrations/web_api/services/dataset_service.py +20 -0
  83. moonshot/integrations/web_api/services/endpoint_service.py +65 -0
  84. moonshot/integrations/web_api/services/metric_service.py +14 -0
  85. moonshot/integrations/web_api/services/prompt_template_service.py +39 -0
  86. moonshot/integrations/web_api/services/recipe_service.py +155 -0
  87. moonshot/integrations/web_api/services/runner_service.py +147 -0
  88. moonshot/integrations/web_api/services/session_service.py +350 -0
  89. moonshot/integrations/web_api/services/utils/exceptions_handler.py +41 -0
  90. moonshot/integrations/web_api/services/utils/results_formatter.py +47 -0
  91. moonshot/integrations/web_api/status_updater/interface/benchmark_progress_callback.py +14 -0
  92. moonshot/integrations/web_api/status_updater/interface/redteam_progress_callback.py +14 -0
  93. moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +72 -0
  94. moonshot/integrations/web_api/types/types.py +99 -0
  95. moonshot/src/__init__.py +0 -0
  96. moonshot/src/api/__init__.py +0 -0
  97. moonshot/src/api/api_connector.py +58 -0
  98. moonshot/src/api/api_connector_endpoint.py +162 -0
  99. moonshot/src/api/api_context_strategy.py +57 -0
  100. moonshot/src/api/api_cookbook.py +160 -0
  101. moonshot/src/api/api_dataset.py +46 -0
  102. moonshot/src/api/api_environment_variables.py +17 -0
  103. moonshot/src/api/api_metrics.py +51 -0
  104. moonshot/src/api/api_prompt_template.py +43 -0
  105. moonshot/src/api/api_recipe.py +182 -0
  106. moonshot/src/api/api_red_teaming.py +59 -0
  107. moonshot/src/api/api_result.py +84 -0
  108. moonshot/src/api/api_run.py +74 -0
  109. moonshot/src/api/api_runner.py +132 -0
  110. moonshot/src/api/api_session.py +290 -0
  111. moonshot/src/configs/__init__.py +0 -0
  112. moonshot/src/configs/env_variables.py +187 -0
  113. moonshot/src/connectors/__init__.py +0 -0
  114. moonshot/src/connectors/connector.py +327 -0
  115. moonshot/src/connectors/connector_prompt_arguments.py +17 -0
  116. moonshot/src/connectors_endpoints/__init__.py +0 -0
  117. moonshot/src/connectors_endpoints/connector_endpoint.py +211 -0
  118. moonshot/src/connectors_endpoints/connector_endpoint_arguments.py +54 -0
  119. moonshot/src/cookbooks/__init__.py +0 -0
  120. moonshot/src/cookbooks/cookbook.py +225 -0
  121. moonshot/src/cookbooks/cookbook_arguments.py +34 -0
  122. moonshot/src/datasets/__init__.py +0 -0
  123. moonshot/src/datasets/dataset.py +255 -0
  124. moonshot/src/datasets/dataset_arguments.py +50 -0
  125. moonshot/src/metrics/__init__.py +0 -0
  126. moonshot/src/metrics/metric.py +192 -0
  127. moonshot/src/metrics/metric_interface.py +95 -0
  128. moonshot/src/prompt_templates/__init__.py +0 -0
  129. moonshot/src/prompt_templates/prompt_template.py +103 -0
  130. moonshot/src/recipes/__init__.py +0 -0
  131. moonshot/src/recipes/recipe.py +340 -0
  132. moonshot/src/recipes/recipe_arguments.py +111 -0
  133. moonshot/src/redteaming/__init__.py +0 -0
  134. moonshot/src/redteaming/attack/__init__.py +0 -0
  135. moonshot/src/redteaming/attack/attack_module.py +618 -0
  136. moonshot/src/redteaming/attack/attack_module_arguments.py +44 -0
  137. moonshot/src/redteaming/attack/context_strategy.py +131 -0
  138. moonshot/src/redteaming/context_strategy/__init__.py +0 -0
  139. moonshot/src/redteaming/context_strategy/context_strategy_interface.py +46 -0
  140. moonshot/src/redteaming/session/__init__.py +0 -0
  141. moonshot/src/redteaming/session/chat.py +209 -0
  142. moonshot/src/redteaming/session/red_teaming_progress.py +128 -0
  143. moonshot/src/redteaming/session/red_teaming_type.py +6 -0
  144. moonshot/src/redteaming/session/session.py +775 -0
  145. moonshot/src/results/__init__.py +0 -0
  146. moonshot/src/results/result.py +119 -0
  147. moonshot/src/results/result_arguments.py +44 -0
  148. moonshot/src/runners/__init__.py +0 -0
  149. moonshot/src/runners/runner.py +476 -0
  150. moonshot/src/runners/runner_arguments.py +46 -0
  151. moonshot/src/runners/runner_type.py +6 -0
  152. moonshot/src/runs/__init__.py +0 -0
  153. moonshot/src/runs/run.py +344 -0
  154. moonshot/src/runs/run_arguments.py +162 -0
  155. moonshot/src/runs/run_progress.py +145 -0
  156. moonshot/src/runs/run_status.py +10 -0
  157. moonshot/src/storage/__init__.py +0 -0
  158. moonshot/src/storage/db_interface.py +128 -0
  159. moonshot/src/storage/io_interface.py +31 -0
  160. moonshot/src/storage/storage.py +525 -0
  161. moonshot/src/utils/__init__.py +0 -0
  162. moonshot/src/utils/import_modules.py +96 -0
  163. moonshot/src/utils/timeit.py +25 -0
@@ -0,0 +1,146 @@
1
+ import importlib.resources
2
+
3
+ from dependency_injector import containers, providers
4
+
5
+ from .services.attack_module_service import AttackModuleService
6
+ from .services.auto_red_team_test_manager import AutoRedTeamTestManager
7
+ from .services.auto_red_team_test_state import AutoRedTeamTestState
8
+ from .services.benchmark_result_service import BenchmarkResultService
9
+ from .services.benchmark_test_manager import BenchmarkTestManager
10
+ from .services.benchmark_test_state import BenchmarkTestState
11
+ from .services.benchmarking_service import BenchmarkingService
12
+ from .services.context_strategy_service import ContextStrategyService
13
+ from .services.cookbook_service import CookbookService
14
+ from .services.dataset_service import DatasetService
15
+ from .services.endpoint_service import EndpointService
16
+ from .services.metric_service import MetricService
17
+ from .services.prompt_template_service import PromptTemplateService
18
+ from .services.recipe_service import RecipeService
19
+ from .services.runner_service import RunnerService
20
+ from .services.session_service import SessionService
21
+ from .status_updater.moonshot_ui_webhook import MoonshotUIWebhook
22
+
23
+
24
class Container(containers.DeclarativeContainer):
    """Dependency-injection container for the web API.

    Declares the default configuration, one singleton provider per service
    or state object, and the modules that dependency_injector should wire.
    """

    # Built-in defaults loaded into the "config" provider.
    config = providers.Configuration("config")
    config.from_dict(
        {
            "app_environment": "DEV",
            "asyncio": {"monitor_task": False},
            "ssl": {
                "enabled": False,
                "file_path": str(
                    importlib.resources.files("moonshot").joinpath(
                        "integrations/web_api/certs"
                    )
                ),
                "cert_filename": "cert.pem",
                "key_filename": "key.pem",
            },
            "cors": {
                "enabled": False,
                "allowed_origins": "http://localhost:3000",
            },
            "log": {
                "logging": True,
                "level": "DEBUG",
                "format": "[%(asctime)s] [%(levelname)s] [%(name)s]: %(message)s",
                "log_file_path": str(
                    importlib.resources.files("moonshot").joinpath(
                        "integrations/web_api/log"
                    )
                ),
                "log_file_max_size": 5 * 1024 * 1024,  # 5 MiB per log file
                "log_file_backup_count": 3,
            },
        }
    )

    # --- Shared state objects and the webhook that receives them ---------
    benchmark_test_state: providers.Singleton[
        BenchmarkTestState
    ] = providers.Singleton(BenchmarkTestState)
    auto_red_team_test_state: providers.Singleton[
        AutoRedTeamTestState
    ] = providers.Singleton(AutoRedTeamTestState)
    webhook: providers.Singleton[MoonshotUIWebhook] = providers.Singleton(
        MoonshotUIWebhook,
        benchmark_test_state=benchmark_test_state,
        auto_red_team_test_state=auto_red_team_test_state,
    )

    # --- Runner service and the managers built on top of it --------------
    runner_service: providers.Singleton[RunnerService] = providers.Singleton(
        RunnerService
    )
    auto_red_team_test_manager: providers.Singleton[
        AutoRedTeamTestManager
    ] = providers.Singleton(
        AutoRedTeamTestManager,
        auto_red_team_test_state=auto_red_team_test_state,
        progress_status_updater=webhook,
        runner_service=runner_service,
    )
    benchmark_test_manager: providers.Singleton[
        BenchmarkTestManager
    ] = providers.Singleton(
        BenchmarkTestManager,
        benchmark_test_state=benchmark_test_state,
        progress_status_updater=webhook,
        runner_service=runner_service,
    )

    # --- Services consumed by the route modules --------------------------
    session_service: providers.Singleton[SessionService] = providers.Singleton(
        SessionService,
        auto_red_team_test_manager=auto_red_team_test_manager,
        progress_status_updater=webhook,
        runner_service=runner_service,
    )
    prompt_template_service: providers.Singleton[
        PromptTemplateService
    ] = providers.Singleton(PromptTemplateService)
    context_strategy_service: providers.Singleton[
        ContextStrategyService
    ] = providers.Singleton(ContextStrategyService)
    benchmarking_service: providers.Singleton[
        BenchmarkingService
    ] = providers.Singleton(
        BenchmarkingService,
        benchmark_test_manager=benchmark_test_manager,
    )
    endpoint_service: providers.Singleton[
        EndpointService
    ] = providers.Singleton(EndpointService)
    recipe_service: providers.Singleton[
        RecipeService
    ] = providers.Singleton(RecipeService)
    cookbook_service: providers.Singleton[
        CookbookService
    ] = providers.Singleton(CookbookService)
    benchmark_result_service: providers.Singleton[
        BenchmarkResultService
    ] = providers.Singleton(BenchmarkResultService)
    metric_service: providers.Singleton[
        MetricService
    ] = providers.Singleton(MetricService)
    dataset_service: providers.Singleton[
        DatasetService
    ] = providers.Singleton(DatasetService)
    am_service: providers.Singleton[
        AttackModuleService
    ] = providers.Singleton(AttackModuleService)

    # Modules (relative to this package) whose @inject markers get wired.
    wiring_config = containers.WiringConfiguration(
        modules=[
            ".routes.redteam",
            ".routes.prompt_template",
            ".routes.context_strategy",
            ".routes.benchmark",
            ".routes.endpoint",
            ".routes.recipe",
            ".routes.cookbook",
            ".routes.benchmark_result",
            ".routes.metric",
            ".routes.runner",
            ".routes.dataset",
            ".routes.attack_modules",
            ".services.benchmarking_service",
        ]
    )
File without changes
@@ -0,0 +1,114 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ from logging.handlers import RotatingFileHandler
5
+ from typing import Literal
6
+
7
+ from dependency_injector import providers
8
+
9
+ from .types.types import UvicornLoggingConfig
10
+
11
+ COLORS = {
12
+ "HEADER": "\033[95m",
13
+ "OKBLUE": "\033[94m",
14
+ "OKGREEN": "\033[92m",
15
+ "WARNING": "\033[93m",
16
+ "FAIL": "\033[91m",
17
+ "ENDC": "\033[0m",
18
+ "BOLD": "\033[1m",
19
+ "UNDERLINE": "\033[4m",
20
+ "WHITE": "\033[97m",
21
+ }
22
+
23
+
24
+ class ColorizedFormatter(logging.Formatter):
25
+ LEVEL_COLORS = {
26
+ logging.DEBUG: COLORS["OKBLUE"],
27
+ logging.INFO: COLORS["OKGREEN"],
28
+ logging.WARNING: COLORS["WARNING"],
29
+ logging.ERROR: COLORS["FAIL"],
30
+ logging.CRITICAL: COLORS["HEADER"],
31
+ }
32
+
33
+ def __init__(
34
+ self,
35
+ fmt: str,
36
+ datefmt: str | None = None,
37
+ style: Literal["%"] = "%",
38
+ disableColor: bool = False,
39
+ ):
40
+ super().__init__(fmt, datefmt, style)
41
+ self.disableColor = disableColor
42
+
43
+ def format(self, record: logging.LogRecord):
44
+ if self.disableColor:
45
+ return super().format(record)
46
+ else:
47
+ color = str(self.LEVEL_COLORS.get(record.levelno))
48
+ message = super().format(record)
49
+ return color + message + COLORS["ENDC"]
50
+
51
+
52
def create_logging_dir(log_file_path: str):
    """Ensure the log directory exists, creating parent directories as needed.

    Uses ``exist_ok=True`` rather than a separate existence check, so a
    directory created concurrently by another process does not raise.

    Args:
        log_file_path: Directory in which log files will be written.

    Raises:
        FileExistsError: If the path exists but is not a directory.
    """
    os.makedirs(log_file_path, exist_ok=True)
55
+
56
+
57
def configure_app_logging(cfg: providers.Configuration):
    """Set up root logging with a rotating file handler and a coloured console.

    Does nothing when ``cfg.log.logging()`` is falsy. Otherwise the log
    directory is created, both handlers are registered through
    ``logging.basicConfig`` and a confirmation message is logged.

    Args:
        cfg: Configuration provider exposing the ``log`` section.
    """
    if not cfg.log.logging():
        return

    create_logging_dir(cfg.log.log_file_path())

    # File output: size-based rotation, no formatter set here so that
    # basicConfig applies the plain format below.
    rotating_file = RotatingFileHandler(
        filename=cfg.log.log_file_path() + "/web_api.log",
        maxBytes=cfg.log.log_file_max_size(),
        backupCount=cfg.log.log_file_backup_count(),
    )

    # Console output: same format, wrapped in level colours.
    console = logging.StreamHandler(sys.stdout)
    console.setFormatter(ColorizedFormatter(cfg.log.format()))

    logging.basicConfig(
        handlers=[rotating_file, console],
        level=cfg.log.level(),
        format=cfg.log.format(),
    )

    logging.info("Logging is configured.")
76
+
77
+
78
def create_uvicorn_log_config(cfg: providers.Configuration) -> UvicornLoggingConfig:
    """Build a ``dictConfig``-style logging configuration from ``cfg``.

    Args:
        cfg: Configuration provider exposing the ``log`` section.

    Returns:
        The logging configuration dictionary when ``cfg.log.logging()`` is
        truthy. When logging is disabled nothing is built and ``None`` is
        returned, despite the declared return type.
    """
    if not cfg.log.logging():
        return None

    create_logging_dir(cfg.log.log_file_path())

    # Dotted path used as the "()" factory for both formatters.
    formatter_factory = "moonshot.integrations.web_api.logging_conf.ColorizedFormatter"

    formatters = {
        "default": {
            "()": formatter_factory,
            "format": cfg.log.format(),
        },
        # The file formatter uses the same class with colours switched off.
        "file_formatter": {
            "()": formatter_factory,
            "format": cfg.log.format(),
            "disableColor": True,
        },
    }
    handlers = {
        "file": {
            "class": "logging.handlers.RotatingFileHandler",
            "filename": cfg.log.log_file_path() + "/web_api.log",
            "maxBytes": cfg.log.log_file_max_size(),
            "backupCount": cfg.log.log_file_backup_count(),
            "formatter": "file_formatter",
        },
        "console": {
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
            "formatter": "default",
        },
    }
    root = {
        "level": cfg.log.level(),
        "handlers": ["file", "console"],
    }

    return {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": formatters,
        "handlers": handlers,
        "root": root,
    }
File without changes
@@ -0,0 +1,66 @@
1
+ from dependency_injector.wiring import Provide, inject
2
+ from fastapi import APIRouter, Depends, HTTPException
3
+
4
+ from ..container import Container
5
+ from ..services.attack_module_service import AttackModuleService
6
+ from ..services.utils.exceptions_handler import ServiceException
7
+
8
+ router = APIRouter(tags=["Attack Modules"])
9
+
10
+
11
@router.get("/api/v1/attack-modules")
@inject
def get_all_attack_module(
    am_service: AttackModuleService = Depends(Provide[Container.am_service]),
) -> list[str]:
    """
    Return every attack module reported by the attack module service.

    Args:
        am_service (AttackModuleService): Injected service that lists the attack modules.

    Returns:
        list[str]: Whatever ``am_service.get_all_attack_module()`` returns.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound",
            400 when it reports "ValidationError",
            500 for any other ServiceException.
    """
    try:
        return am_service.get_all_attack_module()
    except ServiceException as e:
        # Translate the service error code into an HTTP status.
        if e.error_code == "FileNotFound":
            http_status = 404
        elif e.error_code == "ValidationError":
            http_status = 400
        else:
            http_status = 500
        raise HTTPException(
            status_code=http_status,
            detail=f"Failed to retrieve attack modules: {e.msg}",
        )
45
+
46
+
47
@router.get("/api/v1/attack-modules/metadata")
@inject
def get_all_attack_module_metadata(
    am_service: AttackModuleService = Depends(Provide[Container.am_service]),
) -> list:
    """
    Return the metadata of every attack module reported by the service.

    Args:
        am_service (AttackModuleService): Injected service that supplies the metadata.

    Returns:
        list: Whatever ``am_service.get_all_attack_module_metadata()`` returns.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound",
            400 when it reports "ValidationError",
            500 for any other ServiceException.
    """
    try:
        return am_service.get_all_attack_module_metadata()
    except ServiceException as e:
        # Translate the service error code into an HTTP status.
        if e.error_code == "FileNotFound":
            http_status = 404
        elif e.error_code == "ValidationError":
            http_status = 400
        else:
            http_status = 500
        raise HTTPException(
            status_code=http_status,
            detail=f"Failed to retrieve attack modules: {e.msg}",
        )
@@ -0,0 +1,116 @@
1
+ from dependency_injector.wiring import Provide, inject
2
+ from fastapi import APIRouter, Depends, HTTPException
3
+
4
+ from ..container import Container
5
+ from ..schemas.benchmark_runner_dto import BenchmarkRunnerDTO
6
+ from ..services.benchmark_test_state import BenchmarkTestState
7
+ from ..services.benchmarking_service import BenchmarkingService
8
+ from ..services.utils.exceptions_handler import ServiceException
9
+ from ..types.types import BenchmarkCollectionType
10
+
11
+ router = APIRouter(tags=["Benchmarking"])
12
+
13
+
14
@router.post("/api/v1/benchmarks")
@inject
async def benchmark_executor(
    type: BenchmarkCollectionType,
    data: BenchmarkRunnerDTO,
    benchmarking_service: BenchmarkingService = Depends(
        Provide[Container.benchmarking_service]
    ),
) -> dict:
    """
    Start a benchmark run for a cookbook or a recipe.

    Args:
        type (BenchmarkCollectionType): Selects cookbook or recipe execution.
        data (BenchmarkRunnerDTO): Payload handed to the benchmarking service.
        benchmarking_service (BenchmarkingService): Injected service that starts the run.

    Returns:
        dict: ``{"id": ...}`` holding the value returned by the service.

    Raises:
        HTTPException: 400 when ``type`` is neither COOKBOOK nor RECIPE,
            500 when the service raises a ServiceException.
    """
    try:
        if type is BenchmarkCollectionType.COOKBOOK:
            execution_id = await benchmarking_service.execute_cookbook(data)
        elif type is BenchmarkCollectionType.RECIPE:
            execution_id = await benchmarking_service.execute_recipe(data)
        else:
            raise HTTPException(status_code=400, detail="Invalid query parameter: type")
        return {"id": execution_id}
    except ServiceException as e:
        raise HTTPException(
            status_code=500, detail=f"Unable to create and execute benchmark: {e}"
        )
51
+
52
+
53
@router.get("/api/v1/benchmarks/status")
@inject
def get_benchmark_progress(
    benchmark_state: BenchmarkTestState = Depends(
        Provide[Container.benchmark_test_state]
    ),
):
    """
    Return the progress status tracked for all benchmarks.

    Args:
        benchmark_state (BenchmarkTestState): Injected state object holding benchmark progress.

    Returns:
        Whatever ``benchmark_state.get_all_progress_status()`` returns.

    Raises:
        HTTPException: 404 when the ServiceException reports "FileNotFound",
            400 when it reports "ValidationError",
            500 for any other ServiceException.
    """
    try:
        return benchmark_state.get_all_progress_status()
    except ServiceException as e:
        # Translate the service error code into an HTTP status.
        if e.error_code == "FileNotFound":
            http_status = 404
        elif e.error_code == "ValidationError":
            http_status = 400
        else:
            http_status = 500
        raise HTTPException(
            status_code=http_status,
            detail=f"Failed to retrieve progress status: {e.msg}",
        )
89
+
90
+
91
@router.post("/api/v1/benchmarks/cancel/{runner_id}")
@inject
async def cancel_benchmark_executor(
    runner_id: str,
    benchmarking_service: BenchmarkingService = Depends(
        Provide[Container.benchmarking_service]
    ),
):
    """
    Ask the benchmarking service to cancel a running benchmark.

    Args:
        runner_id (str): Identifier of the runner whose execution should stop.
        benchmarking_service (BenchmarkingService): Injected service that performs the cancellation.

    Returns:
        None

    Raises:
        HTTPException: 500 when the service raises a ServiceException.
    """
    try:
        await benchmarking_service.cancel_executor(runner_id)
    except ServiceException as e:
        failure_detail = f"Unable to cancel benchmark: {e}"
        raise HTTPException(status_code=500, detail=failure_detail)
@@ -0,0 +1,175 @@
1
+ from dependency_injector.wiring import Provide, inject
2
+ from fastapi import APIRouter, Depends, HTTPException
3
+
4
+ from ..container import Container
5
+ from ..services.benchmark_result_service import BenchmarkResultService
6
+ from ..services.utils.exceptions_handler import ServiceException
7
+
8
+ router = APIRouter(tags=["Benchmark Results"])
9
+
10
+
11
@router.get("/api/v1/benchmarks/results")
@inject
async def get_all_results(
    benchmark_result_service: BenchmarkResultService = Depends(
        Provide[Container.benchmark_result_service]
    ),
) -> list[dict]:
    """
    Return every benchmark result known to the result service.

    Args:
        benchmark_result_service (BenchmarkResultService): Injected service that loads the results.

    Returns:
        list[dict]: Whatever ``benchmark_result_service.get_all_results()`` returns.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound",
            400 when it reports "ValidationError",
            500 for any other ServiceException.
    """
    try:
        return benchmark_result_service.get_all_results()
    except ServiceException as e:
        # Translate the service error code into an HTTP status.
        if e.error_code == "FileNotFound":
            http_status = 404
        elif e.error_code == "ValidationError":
            http_status = 400
        else:
            http_status = 500
        raise HTTPException(
            status_code=http_status, detail=f"Failed to retrieve results: {e.msg}"
        )
49
+
50
+
51
@router.get("/api/v1/benchmarks/results/name")
@inject
async def get_all_results_name(
    benchmark_result_service: BenchmarkResultService = Depends(
        Provide[Container.benchmark_result_service]
    ),
):
    """
    Return the names of all benchmark results known to the result service.

    Args:
        benchmark_result_service (BenchmarkResultService): Injected service that
            supplies the result names.

    Returns:
        Whatever ``benchmark_result_service.get_all_result_name()`` returns.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound",
            400 when it reports "ValidationError",
            500 for any other ServiceException.
    """
    try:
        return benchmark_result_service.get_all_result_name()
    except ServiceException as e:
        # Translate the service error code into an HTTP status.
        if e.error_code == "FileNotFound":
            http_status = 404
        elif e.error_code == "ValidationError":
            http_status = 400
        else:
            http_status = 500
        raise HTTPException(
            status_code=http_status,
            detail=f"Failed to retrieve result name: {e.msg}",
        )
91
+
92
+
93
@router.get("/api/v1/benchmarks/results/{result_id}")
@inject
async def get_one_results(
    result_id: str,
    benchmark_result_service: BenchmarkResultService = Depends(
        Provide[Container.benchmark_result_service]
    ),
):
    """
    Return the benchmark result identified by ``result_id``.

    Args:
        result_id (str): Identifier of the benchmark result to fetch.
        benchmark_result_service (BenchmarkResultService): Injected service that loads the result.

    Returns:
        Whatever ``benchmark_result_service.get_result_by_id(result_id)`` returns.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound",
            400 when it reports "ValidationError",
            500 for any other ServiceException.
    """
    try:
        return benchmark_result_service.get_result_by_id(result_id)
    except ServiceException as e:
        # Translate the service error code into an HTTP status.
        if e.error_code == "FileNotFound":
            http_status = 404
        elif e.error_code == "ValidationError":
            http_status = 400
        else:
            http_status = 500
        raise HTTPException(
            status_code=http_status, detail=f"Failed to retrieve result: {e.msg}"
        )
133
+
134
@router.delete("/api/v1/benchmarks/results/{result_id}")
@inject
def delete_result(
    result_id: str,
    benchmark_result_service: BenchmarkResultService = Depends(
        Provide[Container.benchmark_result_service]
    ),
) -> dict[str, str] | tuple[dict[str, str], int]:
    """
    Delete the benchmark result identified by ``result_id``.

    Args:
        result_id (str): Identifier of the benchmark result to delete.
        benchmark_result_service (BenchmarkResultService): Injected service that performs the deletion.

    Returns:
        dict[str, str]: A confirmation message once the service call succeeds.

    Raises:
        HTTPException: 404 when the service reports "FileNotFound",
            400 when it reports "ValidationError",
            500 for any other ServiceException.
    """
    try:
        benchmark_result_service.delete_result(result_id)
    except ServiceException as e:
        # Translate the service error code into an HTTP status.
        if e.error_code == "FileNotFound":
            http_status = 404
        elif e.error_code == "ValidationError":
            http_status = 400
        else:
            http_status = 500
        raise HTTPException(
            status_code=http_status, detail=f"Failed to delete result: {e.msg}"
        )
    return {"message": "Result deleted successfully"}
+ )