aiverify-moonshot 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aiverify_moonshot-0.4.0.dist-info/METADATA +249 -0
- aiverify_moonshot-0.4.0.dist-info/RECORD +163 -0
- aiverify_moonshot-0.4.0.dist-info/WHEEL +4 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/AUTHORS.md +5 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/LICENSE.md +201 -0
- aiverify_moonshot-0.4.0.dist-info/licenses/NOTICES.md +3340 -0
- moonshot/__init__.py +0 -0
- moonshot/__main__.py +198 -0
- moonshot/api.py +155 -0
- moonshot/integrations/__init__.py +0 -0
- moonshot/integrations/cli/__init__.py +0 -0
- moonshot/integrations/cli/__main__.py +25 -0
- moonshot/integrations/cli/active_session_cfg.py +1 -0
- moonshot/integrations/cli/benchmark/__init__.py +0 -0
- moonshot/integrations/cli/benchmark/benchmark.py +186 -0
- moonshot/integrations/cli/benchmark/cookbook.py +545 -0
- moonshot/integrations/cli/benchmark/datasets.py +164 -0
- moonshot/integrations/cli/benchmark/metrics.py +141 -0
- moonshot/integrations/cli/benchmark/recipe.py +598 -0
- moonshot/integrations/cli/benchmark/result.py +216 -0
- moonshot/integrations/cli/benchmark/run.py +140 -0
- moonshot/integrations/cli/benchmark/runner.py +174 -0
- moonshot/integrations/cli/cli.py +64 -0
- moonshot/integrations/cli/common/__init__.py +0 -0
- moonshot/integrations/cli/common/common.py +72 -0
- moonshot/integrations/cli/common/connectors.py +325 -0
- moonshot/integrations/cli/common/display_helper.py +42 -0
- moonshot/integrations/cli/common/prompt_template.py +94 -0
- moonshot/integrations/cli/initialisation/__init__.py +0 -0
- moonshot/integrations/cli/initialisation/initialisation.py +14 -0
- moonshot/integrations/cli/redteam/__init__.py +0 -0
- moonshot/integrations/cli/redteam/attack_module.py +70 -0
- moonshot/integrations/cli/redteam/context_strategy.py +147 -0
- moonshot/integrations/cli/redteam/prompt_template.py +67 -0
- moonshot/integrations/cli/redteam/redteam.py +90 -0
- moonshot/integrations/cli/redteam/session.py +467 -0
- moonshot/integrations/web_api/.env.dev +7 -0
- moonshot/integrations/web_api/__init__.py +0 -0
- moonshot/integrations/web_api/__main__.py +56 -0
- moonshot/integrations/web_api/app.py +125 -0
- moonshot/integrations/web_api/container.py +146 -0
- moonshot/integrations/web_api/log/.gitkeep +0 -0
- moonshot/integrations/web_api/logging_conf.py +114 -0
- moonshot/integrations/web_api/routes/__init__.py +0 -0
- moonshot/integrations/web_api/routes/attack_modules.py +66 -0
- moonshot/integrations/web_api/routes/benchmark.py +116 -0
- moonshot/integrations/web_api/routes/benchmark_result.py +175 -0
- moonshot/integrations/web_api/routes/context_strategy.py +129 -0
- moonshot/integrations/web_api/routes/cookbook.py +225 -0
- moonshot/integrations/web_api/routes/dataset.py +120 -0
- moonshot/integrations/web_api/routes/endpoint.py +282 -0
- moonshot/integrations/web_api/routes/metric.py +78 -0
- moonshot/integrations/web_api/routes/prompt_template.py +128 -0
- moonshot/integrations/web_api/routes/recipe.py +219 -0
- moonshot/integrations/web_api/routes/redteam.py +609 -0
- moonshot/integrations/web_api/routes/runner.py +239 -0
- moonshot/integrations/web_api/schemas/__init__.py +0 -0
- moonshot/integrations/web_api/schemas/benchmark_runner_dto.py +13 -0
- moonshot/integrations/web_api/schemas/cookbook_create_dto.py +19 -0
- moonshot/integrations/web_api/schemas/cookbook_response_model.py +9 -0
- moonshot/integrations/web_api/schemas/dataset_response_dto.py +9 -0
- moonshot/integrations/web_api/schemas/endpoint_create_dto.py +21 -0
- moonshot/integrations/web_api/schemas/endpoint_response_model.py +11 -0
- moonshot/integrations/web_api/schemas/prompt_response_model.py +14 -0
- moonshot/integrations/web_api/schemas/prompt_template_response_model.py +10 -0
- moonshot/integrations/web_api/schemas/recipe_create_dto.py +32 -0
- moonshot/integrations/web_api/schemas/recipe_response_model.py +7 -0
- moonshot/integrations/web_api/schemas/session_create_dto.py +16 -0
- moonshot/integrations/web_api/schemas/session_prompt_dto.py +7 -0
- moonshot/integrations/web_api/schemas/session_response_model.py +38 -0
- moonshot/integrations/web_api/services/__init__.py +0 -0
- moonshot/integrations/web_api/services/attack_module_service.py +34 -0
- moonshot/integrations/web_api/services/auto_red_team_test_manager.py +86 -0
- moonshot/integrations/web_api/services/auto_red_team_test_state.py +57 -0
- moonshot/integrations/web_api/services/base_service.py +8 -0
- moonshot/integrations/web_api/services/benchmark_result_service.py +25 -0
- moonshot/integrations/web_api/services/benchmark_test_manager.py +106 -0
- moonshot/integrations/web_api/services/benchmark_test_state.py +56 -0
- moonshot/integrations/web_api/services/benchmarking_service.py +31 -0
- moonshot/integrations/web_api/services/context_strategy_service.py +22 -0
- moonshot/integrations/web_api/services/cookbook_service.py +194 -0
- moonshot/integrations/web_api/services/dataset_service.py +20 -0
- moonshot/integrations/web_api/services/endpoint_service.py +65 -0
- moonshot/integrations/web_api/services/metric_service.py +14 -0
- moonshot/integrations/web_api/services/prompt_template_service.py +39 -0
- moonshot/integrations/web_api/services/recipe_service.py +155 -0
- moonshot/integrations/web_api/services/runner_service.py +147 -0
- moonshot/integrations/web_api/services/session_service.py +350 -0
- moonshot/integrations/web_api/services/utils/exceptions_handler.py +41 -0
- moonshot/integrations/web_api/services/utils/results_formatter.py +47 -0
- moonshot/integrations/web_api/status_updater/interface/benchmark_progress_callback.py +14 -0
- moonshot/integrations/web_api/status_updater/interface/redteam_progress_callback.py +14 -0
- moonshot/integrations/web_api/status_updater/moonshot_ui_webhook.py +72 -0
- moonshot/integrations/web_api/types/types.py +99 -0
- moonshot/src/__init__.py +0 -0
- moonshot/src/api/__init__.py +0 -0
- moonshot/src/api/api_connector.py +58 -0
- moonshot/src/api/api_connector_endpoint.py +162 -0
- moonshot/src/api/api_context_strategy.py +57 -0
- moonshot/src/api/api_cookbook.py +160 -0
- moonshot/src/api/api_dataset.py +46 -0
- moonshot/src/api/api_environment_variables.py +17 -0
- moonshot/src/api/api_metrics.py +51 -0
- moonshot/src/api/api_prompt_template.py +43 -0
- moonshot/src/api/api_recipe.py +182 -0
- moonshot/src/api/api_red_teaming.py +59 -0
- moonshot/src/api/api_result.py +84 -0
- moonshot/src/api/api_run.py +74 -0
- moonshot/src/api/api_runner.py +132 -0
- moonshot/src/api/api_session.py +290 -0
- moonshot/src/configs/__init__.py +0 -0
- moonshot/src/configs/env_variables.py +187 -0
- moonshot/src/connectors/__init__.py +0 -0
- moonshot/src/connectors/connector.py +327 -0
- moonshot/src/connectors/connector_prompt_arguments.py +17 -0
- moonshot/src/connectors_endpoints/__init__.py +0 -0
- moonshot/src/connectors_endpoints/connector_endpoint.py +211 -0
- moonshot/src/connectors_endpoints/connector_endpoint_arguments.py +54 -0
- moonshot/src/cookbooks/__init__.py +0 -0
- moonshot/src/cookbooks/cookbook.py +225 -0
- moonshot/src/cookbooks/cookbook_arguments.py +34 -0
- moonshot/src/datasets/__init__.py +0 -0
- moonshot/src/datasets/dataset.py +255 -0
- moonshot/src/datasets/dataset_arguments.py +50 -0
- moonshot/src/metrics/__init__.py +0 -0
- moonshot/src/metrics/metric.py +192 -0
- moonshot/src/metrics/metric_interface.py +95 -0
- moonshot/src/prompt_templates/__init__.py +0 -0
- moonshot/src/prompt_templates/prompt_template.py +103 -0
- moonshot/src/recipes/__init__.py +0 -0
- moonshot/src/recipes/recipe.py +340 -0
- moonshot/src/recipes/recipe_arguments.py +111 -0
- moonshot/src/redteaming/__init__.py +0 -0
- moonshot/src/redteaming/attack/__init__.py +0 -0
- moonshot/src/redteaming/attack/attack_module.py +618 -0
- moonshot/src/redteaming/attack/attack_module_arguments.py +44 -0
- moonshot/src/redteaming/attack/context_strategy.py +131 -0
- moonshot/src/redteaming/context_strategy/__init__.py +0 -0
- moonshot/src/redteaming/context_strategy/context_strategy_interface.py +46 -0
- moonshot/src/redteaming/session/__init__.py +0 -0
- moonshot/src/redteaming/session/chat.py +209 -0
- moonshot/src/redteaming/session/red_teaming_progress.py +128 -0
- moonshot/src/redteaming/session/red_teaming_type.py +6 -0
- moonshot/src/redteaming/session/session.py +775 -0
- moonshot/src/results/__init__.py +0 -0
- moonshot/src/results/result.py +119 -0
- moonshot/src/results/result_arguments.py +44 -0
- moonshot/src/runners/__init__.py +0 -0
- moonshot/src/runners/runner.py +476 -0
- moonshot/src/runners/runner_arguments.py +46 -0
- moonshot/src/runners/runner_type.py +6 -0
- moonshot/src/runs/__init__.py +0 -0
- moonshot/src/runs/run.py +344 -0
- moonshot/src/runs/run_arguments.py +162 -0
- moonshot/src/runs/run_progress.py +145 -0
- moonshot/src/runs/run_status.py +10 -0
- moonshot/src/storage/__init__.py +0 -0
- moonshot/src/storage/db_interface.py +128 -0
- moonshot/src/storage/io_interface.py +31 -0
- moonshot/src/storage/storage.py +525 -0
- moonshot/src/utils/__init__.py +0 -0
- moonshot/src/utils/import_modules.py +96 -0
- moonshot/src/utils/timeit.py +25 -0
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from moonshot.src.configs.env_variables import EnvVariables
|
|
6
|
+
from moonshot.src.storage.db_interface import DBInterface
|
|
7
|
+
from moonshot.src.storage.storage import Storage
|
|
8
|
+
from moonshot.src.utils.import_modules import get_instance
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ContextStrategy:
    """
    Helpers for loading, listing, deleting and applying context strategy modules that are
    stored as ``.py`` files under the CONTEXT_STRATEGY storage location.
    """

    def __init__(self, cs_id: str) -> None:
        """
        Args:
            cs_id (str): The identifier of this context strategy.
        """
        self.id = cs_id

    @classmethod
    def load(cls, cs_id: str) -> ContextStrategy:
        """
        Retrieves a context strategy module instance by its ID.

        The module's file path is resolved with
        ``Storage.get_filepath(EnvVariables.CONTEXT_STRATEGY.name, cs_id, "py")`` and passed,
        together with ``cs_id``, to ``get_instance``. If that returns something, it is called
        with ``cs_id`` and the resulting instance is returned.

        Args:
            cs_id (str): The unique identifier of the context strategy to be retrieved.

        Returns:
            ContextStrategy: The retrieved context strategy instance.

        Raises:
            RuntimeError: If ``get_instance`` returns nothing for ``cs_id``.
        """
        context_strategy_inst = get_instance(
            cs_id,
            Storage.get_filepath(EnvVariables.CONTEXT_STRATEGY.name, cs_id, "py"),
        )
        if not context_strategy_inst:
            raise RuntimeError(
                f"Unable to get defined context strategy instance - {cs_id}"
            )
        return context_strategy_inst(cs_id)

    @staticmethod
    def get_all_context_strategies() -> list[str]:
        """
        Retrieves the names of all context strategy files.

        Scans the CONTEXT_STRATEGY storage location for ``.py`` files, skips any file whose
        name contains a double underscore (for example ``__init__.py``), and returns the
        remaining file names without their extension.

        Returns:
            list[str]: The context strategy names (file stems).
        """
        context_strategy_files = Storage.get_objects(
            EnvVariables.CONTEXT_STRATEGY.name, "py"
        )
        names = []
        for filepath in context_strategy_files:
            path = Path(filepath)
            # Dunder files (package markers, caches) are not context strategies.
            if "__" in path.name:
                continue
            names.append(path.stem)
        return names

    @staticmethod
    def delete(cs_id: str) -> bool:
        """
        Deletes a context strategy module by its ID.

        Removes the ``py`` object for ``cs_id`` from the CONTEXT_STRATEGY storage location.
        If an error occurs, an error message is printed and the exception is re-raised.

        Args:
            cs_id (str): The unique identifier of the context strategy to be deleted.

        Returns:
            bool: True if the deletion was successful.

        Raises:
            Exception: Whatever ``Storage.delete_object`` raised during the deletion.
        """
        try:
            Storage.delete_object(EnvVariables.CONTEXT_STRATEGY.name, cs_id, "py")
            return True
        except Exception as e:
            print(f"Failed to delete context strategy: {str(e)}")
            raise

    @staticmethod
    def process_prompt_cs(
        user_prompt: str,
        context_strategy_name: str,
        db_instance: DBInterface,
        endpoint_id,
        num_of_previous_chats: int,
    ) -> str:
        """
        Process the user prompt using the specified context strategy.

        Loads the context strategy named ``context_strategy_name`` and fetches up to
        ``num_of_previous_chats`` records with ``Chat.get_n_chat_history``. It then returns
        the strategy's ``add_in_context(user_prompt, list_of_chats)`` result.

        Args:
            user_prompt (str): The user prompt to process.
            context_strategy_name (str): The name of the context strategy to use.
            db_instance (DBInterface): The session database interface instance.
            endpoint_id: The ID of the endpoint; forwarded to ``Chat.get_n_chat_history``.
            num_of_previous_chats (int): The number of previous chats to add.

        Returns:
            str: The processed user prompt, or an empty string if the context strategy
            could not be loaded.
        """
        # Imported here rather than at module level, presumably to avoid a circular import
        # with the session package — TODO confirm.
        from moonshot.src.redteaming.session.chat import Chat

        context_strategy_cls = get_instance(
            context_strategy_name,
            Storage.get_filepath(
                EnvVariables.CONTEXT_STRATEGY.name, context_strategy_name, "py"
            ),
        )

        if not context_strategy_cls:
            print(
                "Cannot load context strategy. Make sure the name of the context strategy is correct."
            )
            return ""

        # Get the last n chats for this endpoint.
        list_of_chats = Chat.get_n_chat_history(
            db_instance, endpoint_id, num_of_previous_chats
        )
        context_strategy_instance = context_strategy_cls(context_strategy_name)
        return context_strategy_instance.add_in_context(user_prompt, list_of_chats)
|
|
File without changes
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
|
|
3
|
+
from moonshot.src.utils.timeit import timeit
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ContextStrategyInterface:
    """
    Base class describing the methods a context strategy module is expected to implement.

    NOTE: this class does not inherit from ``abc.ABC``, so the ``@abstractmethod`` markers
    below document the contract but are not enforced when a subclass is instantiated.
    """

    @abstractmethod
    @timeit
    def get_metadata(self) -> dict | None:
        """
        Abstract method to retrieve metadata from the context strategy.

        Returns:
            dict | None: A dictionary containing the metadata of the context strategy, or
            None if the operation was unsuccessful.
        """
        pass

    @abstractmethod
    def add_in_context(
        self, user_prompt: str, list_of_previous_prompts: list[dict] | None = None
    ) -> str:
        """
        Abstract method to combine a user prompt with previous prompts.

        ``ContextStrategy.process_prompt_cs`` returns the value produced by this method as
        the processed prompt, so implementations should return the resulting prompt string.

        Args:
            user_prompt (str): The user prompt to be added to the context strategy.
            list_of_previous_prompts (list[dict] | None, optional): Previous chat records.
                Defaults to None, which implementations should treat as an empty list.
                A None default is used instead of a shared mutable ``[]``.

        Returns:
            str: The user prompt with the previous prompts added as context.
        """
        pass

    @abstractmethod
    def get_number_of_prev_prompts(self, no_prev_prompts: int) -> int:
        """
        Abstract method to get the number of previous prompts to be retrieved.

        Args:
            no_prev_prompts (int): The requested number of previous prompts.

        Returns:
            int: The number of previous prompts to be retrieved.
        """
        pass
|
|
File without changes
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from slugify import slugify
|
|
6
|
+
|
|
7
|
+
from moonshot.src.storage.db_interface import DBInterface
|
|
8
|
+
from moonshot.src.storage.storage import Storage
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ChatRecord:
    """
    One stored prompt/response exchange from a red-teaming chat.

    Database rows are unpacked positionally into this constructor (``ChatRecord(*row)``),
    so the parameter order is significant.
    """

    def __init__(
        self,
        chat_record_id: str,
        conn_id: str,
        context_strategy: str,
        prompt_template: str,
        attack_module: str,
        metric: str,
        prompt: str,
        prepared_prompt: str,
        system_prompt: str,
        predicted_result: str,
        duration: str,
        prompt_time: str,
    ):
        (
            self.chat_record_id,
            self.conn_id,
            self.context_strategy,
            self.prompt_template,
            self.attack_module,
            self.metric,
            self.prompt,
            self.prepared_prompt,
            self.system_prompt,
            self.predicted_result,
            self.duration,
            self.prompt_time,
        ) = (
            chat_record_id,
            conn_id,
            context_strategy,
            prompt_template,
            attack_module,
            metric,
            prompt,
            prepared_prompt,
            system_prompt,
            predicted_result,
            duration,
            prompt_time,
        )

    def to_dict(self) -> dict[str, str]:
        """
        Build a plain dictionary view of this record.

        Keys are the attribute names, in constructor order.

        Returns:
            dict: Attribute name -> value for every field of the record.
        """
        field_names = (
            "chat_record_id",
            "conn_id",
            "context_strategy",
            "prompt_template",
            "attack_module",
            "metric",
            "prompt",
            "prepared_prompt",
            "system_prompt",
            "predicted_result",
            "duration",
            "prompt_time",
        )
        return {name: getattr(self, name) for name in field_names}
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class Chat:
    """
    A red-teaming chat with one endpoint, persisted as its own table in the session database.

    Constructing a Chat either loads an existing chat's history (when ``chat_id`` is given)
    or creates a new chat history table and inserts a row into ``chat_metadata_table``.
    """

    sql_create_chat_metadata_record = """
        INSERT INTO chat_metadata_table (
        chat_id,endpoint,created_epoch,created_datetime)
        VALUES(?,?,?,?)
        """

    sql_select_n_chat_from_chat_table = (
        """SELECT * FROM {} order by prompt_time desc limit {}"""
    )

    def __init__(
        self,
        session_db_instance: DBInterface,
        endpoint: str = "",
        created_epoch: float = 0.0,
        created_datetime: str = "",
        chat_id: str = "",
    ):
        """
        Args:
            session_db_instance (DBInterface): The session database holding the chat tables.
            endpoint (str): The endpoint this chat talks to.
            created_epoch (float): Creation time stored in the metadata row (new chats only).
            created_datetime (str): Creation datetime; hyphens are replaced with underscores
                and it becomes part of the new chat's id (new chats only).
            chat_id (str): ID of an existing chat to load. When empty, a new chat is created.
        """
        # Holds dicts produced by ChatRecord.to_dict() when loaded from the database.
        self.chat_history: list = []
        if chat_id:
            # Existing chat: table names use underscores in place of hyphens.
            db_chat_id = chat_id.replace("-", "_")
            self.chat_id = db_chat_id
            self.endpoint = endpoint
            self.chat_history = self.load_chat_history(session_db_instance, db_chat_id)
        else:
            # No existing chat: derive an id from the endpoint and creation time, create
            # the chat history table, and record the chat in the metadata table.
            created_datetime = str(created_datetime).replace("-", "_")
            chat_id = f"{slugify(endpoint)}_{created_datetime}"
            db_chat_id = chat_id.replace("-", "_")

            # Column order must match ChatRecord's positional parameters, because rows are
            # read back with `SELECT *` and unpacked via ChatRecord(*row).
            sql_create_chat_history_table = f"""
                CREATE TABLE IF NOT EXISTS {db_chat_id} (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                connection_id text NOT NULL,
                context_strategy text,
                prompt_template text,
                attack_module text,
                metric text,
                prompt text NOT NULL,
                prepared_prompt text NOT NULL,
                system_prompt text,
                predicted_result text NOT NULL,
                duration text NOT NULL,
                prompt_time text NOT NULL
                );
            """
            Storage.create_database_table(
                session_db_instance, sql_create_chat_history_table
            )

            # The metadata row stores the un-normalised chat_id (hyphens kept).
            chat_meta_tuple = (
                chat_id,
                endpoint,
                created_epoch,
                created_datetime,
            )
            Storage.create_database_record(
                session_db_instance,
                chat_meta_tuple,
                Chat.sql_create_chat_metadata_record,
            )

            self.chat_id = db_chat_id
            self.endpoint = endpoint

    def to_dict(self) -> dict[str, Union[str, list[dict[str, str]]]]:
        """
        Converts the Chat instance into a dictionary.

        Entries of the chat history that are already dictionaries (as produced by
        ``load_chat_history``) are used as-is. Any other entry is converted with its
        ``to_dict`` method.

        Returns:
            dict: The chat ID, endpoint, and the list of chat history dictionaries.
        """
        list_of_chat_history_dict = [
            record if isinstance(record, dict) else record.to_dict()
            for record in self.chat_history
        ]
        return {
            "chat_id": self.chat_id,
            "endpoint": self.endpoint,
            "chat_history": list_of_chat_history_dict,
        }

    @staticmethod
    def _rows_to_dicts(rows) -> list[dict]:
        """Converts raw chat-table rows into ChatRecord dictionaries ([] for no rows)."""
        if not rows:
            return []
        return [ChatRecord(*row).to_dict() for row in rows]

    @staticmethod
    def load_chat_history(session_db_instance: DBInterface, chat_id: str) -> list:
        """
        Loads the full chat history stored in the table named ``chat_id``.

        Every row returned by ``SELECT *`` is unpacked into a ChatRecord and converted with
        ``ChatRecord.to_dict()``.

        Args:
            session_db_instance (DBInterface): The session database holding the chat table.
            chat_id (str): The chat table name; ``__init__`` passes the underscore form.

        Returns:
            list[dict]: One dictionary per chat record; empty if no rows were returned.
        """
        # Table names cannot be bound as SQL parameters, so chat_id is interpolated; it
        # must come from trusted code, not directly from user input.
        sql_read_chat_history_for_one_endpoint = f"""SELECT * FROM {chat_id}"""
        list_of_chat_record_tuples = Storage.read_database_records(
            session_db_instance, sql_read_chat_history_for_one_endpoint
        )
        return Chat._rows_to_dicts(list_of_chat_record_tuples)

    @staticmethod
    def get_n_chat_history(
        session_db_instance: DBInterface, endpoint_id: str, num_of_previous_chats: int
    ) -> list[dict]:
        """
        Loads the most recent chat records from the chat table for ``endpoint_id``.

        Hyphens in ``endpoint_id`` are replaced with underscores to form the table name.
        Rows are ordered by ``prompt_time`` descending, so the newest record comes first.

        Args:
            session_db_instance (DBInterface): The session database holding the chat table.
            endpoint_id (str): Identifier used (after normalisation) as the table name.
            num_of_previous_chats (int): Maximum number of records to return. It is coerced
                with ``int()`` before being placed in the LIMIT clause.

        Returns:
            list[dict]: One ChatRecord dictionary per row; empty if no rows were returned.

        Raises:
            ValueError / TypeError: If ``num_of_previous_chats`` cannot be converted to int.
        """
        table_name = endpoint_id.replace("-", "_")
        # LIMIT is string-formatted into the query, so force it to an integer.
        limit = int(num_of_previous_chats)
        list_of_chat_record_tuples = Storage.read_database_records(
            session_db_instance,
            Chat.sql_select_n_chat_from_chat_table.format(table_name, limit),
        )
        return Chat._rows_to_dicts(list_of_chat_record_tuples)
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
from typing import Callable
|
|
2
|
+
|
|
3
|
+
from moonshot.src.runs.run_status import RunStatus
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class RedTeamingProgress:
    """
    Collects red-teaming chat records per connection and reports them in batches through
    an optional progress callback.
    """

    DEFAULT_CHAT_BATCH_SIZE = 5

    def __init__(
        self,
        runner_id: str,
        red_teaming_arguments: dict,
        run_progress_callback_func: Callable | None,
    ):
        """
        Args:
            runner_id (str): The ID of the runner this progress belongs to.
            red_teaming_arguments (dict): Red teaming arguments; ``chat_batch_size`` and
                ``chats`` are read from it when present.
            run_progress_callback_func (Callable | None): Called with ``get_dict()`` when a
                batch is reported; may be None.
        """
        # Information on the run and callback for progress updating
        self.runner_id = runner_id
        self.chat_batch_size = red_teaming_arguments.get(
            "chat_batch_size", RedTeamingProgress.DEFAULT_CHAT_BATCH_SIZE
        )
        self.chats = red_teaming_arguments.get("chats", {})
        self.current_count = 0
        self.run_progress_callback_func = run_progress_callback_func
        self.status = RunStatus.PENDING

    def update_red_teaming_chats(
        self, red_teaming_prompt_arguments: dict, run_status: RunStatus
    ) -> None:
        """
        Records one prompt/response exchange and updates the run status.

        Builds a chat record dictionary from ``red_teaming_prompt_arguments``. The record is
        appended to the list kept for its ``conn_id``, creating that list on first use.

        Args:
            red_teaming_prompt_arguments (dict): Must contain the keys
                conn_id, cs_id, pt_id, am_id, me_id, original_prompt, prepared_prompt,
                system_prompt, response, duration and start_time. They are stored as
                conn_id, context_strategy, prompt_template, attack_module, metric, prompt,
                prepared_prompt, system_prompt, predicted_result, duration and prompt_time.
            run_status (RunStatus): The current status of the run.

        Raises:
            KeyError: If any required key is missing. No state is changed in that case.
        """
        prompt_response_dict = {
            "conn_id": red_teaming_prompt_arguments["conn_id"],
            "context_strategy": red_teaming_prompt_arguments["cs_id"],
            "prompt_template": red_teaming_prompt_arguments["pt_id"],
            "attack_module": red_teaming_prompt_arguments["am_id"],
            "metric": red_teaming_prompt_arguments["me_id"],
            "prompt": red_teaming_prompt_arguments["original_prompt"],
            "prepared_prompt": red_teaming_prompt_arguments["prepared_prompt"],
            "system_prompt": red_teaming_prompt_arguments["system_prompt"],
            "predicted_result": red_teaming_prompt_arguments["response"],
            "duration": red_teaming_prompt_arguments["duration"],
            "prompt_time": red_teaming_prompt_arguments["start_time"],
        }
        self.chats.setdefault(red_teaming_prompt_arguments["conn_id"], []).append(
            prompt_response_dict
        )
        self.status = run_status

    def reset_chats(self) -> None:
        """
        Clears all the chat data stored in the 'chats' attribute, in place.
        """
        self.chats.clear()

    def update_red_teaming_progress(self) -> None:
        """
        Advances the batch counter, reporting and clearing the chats when a batch is full.

        If the current count is at least the batch size, the progress is sent to the
        callback, the chats are reset and the count goes back to zero. In every case the
        count is then incremented by one.
        """
        if self.current_count >= self.chat_batch_size:
            self.notify_progress()
            self.reset_chats()
            self.current_count = 0

        self.current_count += 1

    def notify_progress(self) -> None:
        """
        Calls the run progress callback, if one was given, with ``get_dict()``.
        """
        if self.run_progress_callback_func:
            self.run_progress_callback_func(self.get_dict())

    def get_dict(self) -> dict:
        """
        Returns a dictionary representation of the current red teaming progress.

        Keys:
            - "current_runner_id": The ID of the current runner.
            - "current_chats": A snapshot of the chats (per-connection lists are copied).
              Because it is a snapshot, a later ``reset_chats()`` does not empty data that
              a callback has already received.
            - "current_batch_size": The number of chats reported per callback.
            - "current_status": The name of the current run status.

        Returns:
            dict: The current state of the red teaming progress.
        """
        return {
            "current_runner_id": self.runner_id,
            "current_chats": {
                conn_id: list(records) for conn_id, records in self.chats.items()
            },
            "current_batch_size": self.chat_batch_size,
            "current_status": self.status.name,
        }
|