model-train-protocol 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. model_train_protocol/Protocol.py +193 -0
  2. model_train_protocol/__init__.py +24 -0
  3. model_train_protocol/_internal/ProtocolFile.py +236 -0
  4. model_train_protocol/_internal/TemplateFile.py +148 -0
  5. model_train_protocol/_internal/__init__.py +0 -0
  6. model_train_protocol/common/__init__.py +0 -0
  7. model_train_protocol/common/constants.py +21 -0
  8. model_train_protocol/common/guardrails/Guardrail.py +51 -0
  9. model_train_protocol/common/guardrails/__init__.py +9 -0
  10. model_train_protocol/common/instructions/Instruction.py +213 -0
  11. model_train_protocol/common/instructions/SimpleInstruction.py +46 -0
  12. model_train_protocol/common/instructions/UserInstruction.py +72 -0
  13. model_train_protocol/common/instructions/__init__.py +13 -0
  14. model_train_protocol/common/pydantic/__init__.py +0 -0
  15. model_train_protocol/common/pydantic/protocol.py +157 -0
  16. model_train_protocol/common/tokens/NumListToken.py +21 -0
  17. model_train_protocol/common/tokens/NumToken.py +33 -0
  18. model_train_protocol/common/tokens/SpecialToken.py +35 -0
  19. model_train_protocol/common/tokens/Token.py +95 -0
  20. model_train_protocol/common/tokens/TokenSet.py +124 -0
  21. model_train_protocol/common/tokens/UserToken.py +19 -0
  22. model_train_protocol/common/tokens/__init__.py +21 -0
  23. model_train_protocol/common/util.py +57 -0
  24. model_train_protocol-0.1.7.dist-info/METADATA +323 -0
  25. model_train_protocol-0.1.7.dist-info/RECORD +28 -0
  26. model_train_protocol-0.1.7.dist-info/WHEEL +5 -0
  27. model_train_protocol-0.1.7.dist-info/licenses/LICENSE +21 -0
  28. model_train_protocol-0.1.7.dist-info/top_level.txt +1 -0
@@ -0,0 +1,193 @@
1
+ import json
2
+ import os
3
+
4
+ from . import Token
5
+ from ._internal.ProtocolFile import ProtocolFile
6
+ from ._internal.TemplateFile import TemplateFile
7
+ from .common.constants import BOS_TOKEN, EOS_TOKEN, RUN_TOKEN, PAD_TOKEN, UNK_TOKEN
8
+ from .common.instructions.Instruction import Instruction
9
+ from .common.tokens.SpecialToken import SpecialToken
10
+ from .common.util import get_possible_emojis, hash_string, validate_string_set
11
+
12
+
13
class Protocol:
    """Model Training Protocol (MTP) class for creating the training configuration."""

    def __init__(self, name: str, context_lines: int, encrypt: bool = True):
        """
        Initialize the Model Training Protocol (MTP).

        :param name: The name of the protocol.
        :param context_lines: The number of lines in each instruction sample. Must be at least 2.
        :param encrypt: Whether to replace unspecified Token keys with hashed keys. Default is True.
        :raises ValueError: If context_lines is less than 2.
        """
        self.name: str = name
        self.context_lines: int = context_lines  # Number of lines in instruction samples
        self.encrypt: bool = encrypt
        if self.context_lines < 2:
            raise ValueError("A minimum of 2 context lines is required for all instructions.")
        self.context: list[str] = []
        self.tokens: set[Token] = set()
        self.instructions: set[Instruction] = set()
        self.guardrails: dict[str, list[str]] = dict()
        self.numbers: dict[str, str] = dict()
        self.none = None  # NOTE(review): purpose unclear from this file — confirm before removing
        self.special_tokens: set[Token] = set()
        self.possible_emoji_keys: set[str] = get_possible_emojis()
        self.used_keys: set[str] = set()

    def add_context(self, context: str):
        """
        Adds a line of context to the model.

        :param context: The context line to append.
        :raises TypeError: If context is not a string.
        """
        if not isinstance(context, str):
            raise TypeError("Context must be a string.")

        self.context.append(context)

    def add_instruction(self, instruction: Instruction):
        """
        Adds an Instruction (and its components) to the protocol.

        Asserts that all samples in the instruction match the defined sample line size.

        :param instruction: The Instruction to add.
        :raises ValueError: If the instruction is a duplicate, has no samples, or any
            sample's context length differs from context_lines.
        """
        if instruction in self.instructions:
            raise ValueError("Instruction already added to the protocol.")

        # Message now matches the actual check (it previously claimed "three samples").
        if len(instruction.samples) == 0:
            raise ValueError("Instruction must have at least one sample.")

        # Assert all samples match the defined sample line size
        for sample in instruction.samples:
            if not len(sample.context) == self.context_lines:
                raise ValueError(
                    f"Sample context lines ({len(sample.context)}) does not match defined context_lines count ({self.context_lines})"
                    f"\n{sample}."
                )

        # Add all tokens (deduplicated against tokens already in the protocol)
        for token in instruction.get_tokens():
            if token not in self.tokens:
                self._add_token(token)

        # Add the instruction to the protocol
        self.instructions.add(instruction)

    def save(self, name: str | None = None, path: str | None = None):
        """
        Saves the protocol to a JSON file. This file can be submitted to Databiomes for model training.

        :param name: The name of the file (without extension). If None, uses the protocol's name.
        :param path: The directory path where the file will be saved. If None, saves in the current directory.
        """
        if name is None:
            name = self.name
        if path is None:
            path = os.getcwd()
        os.makedirs(path, exist_ok=True)
        # os.path.join keeps the path cross-platform (a literal backslash join is Windows-only).
        filename = os.path.join(path, f"{name}_model.json")

        self._prep_protocol()
        protocol_file: ProtocolFile = ProtocolFile(
            name=self.name, context=self.context, context_lines=self.context_lines,
            tokens=self.tokens, special_tokens=self.special_tokens, instructions=self.instructions,
        )

        print(f"Saving Model Train Protocol to {filename}...")
        with open(filename, 'w', encoding="utf-8") as file:
            json.dump(protocol_file.to_json(), file, indent=4, ensure_ascii=False)

    def template(self, path: str | None = None):
        """
        Create a template JSON file for the model training protocol.

        The template json file includes example usage and all possible combinations of model inputs and
        outputs based on the defined tokens and instructions.

        :param path: The directory path where the template file will be saved. If None, saves in the current directory.
        """
        if path is None:
            path = os.getcwd()
        # Ensure the target directory exists (consistent with save()).
        os.makedirs(path, exist_ok=True)
        filename = os.path.join(path, f"{self.name}_template.json")

        self._prep_protocol()
        template_file: TemplateFile = TemplateFile(
            instructions=list(self.instructions),
            context_lines=self.context_lines
        )

        print(f"Saving Model Train Protocol Template to {filename}...")
        with open(filename, 'w', encoding="utf-8") as file:
            json.dump(template_file.to_json(), file, indent=4, ensure_ascii=False)

    def _assign_key(self, token: Token):
        """
        Assigns a key to a Token based on the protocol's encryption setting.

        :param token: The Token to assign the key of.
        """
        # If the user has assigned a key, use this key
        if token.key is not None:
            return

        if self.encrypt:
            # Generate a random key for the token if encrypting and no key is set
            token.key = hash_string(key=token.value, output_char=6)
        else:
            # Use the value as the key if not encrypting. I.e. Token 'Continue_' has key 'Continue_'
            token.key = token.value

    def _add_token(self, token: Token):
        """
        Adds a unique token to the protocol.

        Validates that the token's value and key are unique.

        :param token: The Token instance to add.
        :raises ValueError: If the token's value or key is already in use.
        """
        if token in self.tokens:
            raise ValueError(f"Token value '{token.value}' already used.")

        # Assign the key BEFORE the uniqueness check so generated (hashed) keys are
        # validated too; previously tokens without a preset key skipped the check.
        self._assign_key(token=token)

        if token.key in self.used_keys:
            raise ValueError(f"Token key '{token.key}' already used.")

        self.tokens.add(token)
        self.used_keys.add(token.key)

        if isinstance(token, SpecialToken):
            self.special_tokens.add(token)

    def _set_guardrails(self):
        """Sets all guardrails from TokenSets into the protocol."""
        # Add all guardrails to the protocol
        for instruction in self.instructions:
            if instruction.response.guardrail is not None:
                # instruction.response is the user TokenSet
                self.guardrails[instruction.response.key] = instruction.response.guardrail.format_samples()

    def _add_default_special_tokens(self):
        """Adds all special tokens to the protocol."""
        self.special_tokens.add(BOS_TOKEN)
        self.special_tokens.add(EOS_TOKEN)
        self.special_tokens.add(RUN_TOKEN)
        self.special_tokens.add(PAD_TOKEN)
        # <UNK> is only needed when guardrails can reject a prompt
        if len(self.guardrails) > 0:
            self.special_tokens.add(UNK_TOKEN)

    def _prep_protocol(self):
        """
        Sets all elements in the protocol before serialization.

        Raises errors if any validation checks fail.

        Sets up all necessary components in the protocol before saving or templating.

        This includes setting guardrails from their TokenSets and creating default special tokens.
        """
        if len(self.instructions) == 0:
            raise ValueError("No instructions have been added to Protocol. Call protocol.add_instruction() to add instructions.")

        self._set_guardrails()
        self._add_default_special_tokens()
        used_values: set[str] = {token.value for token in self.tokens}
        validate_string_set(used_values)
        validate_string_set(self.used_keys)
@@ -0,0 +1,24 @@
1
"""
Model Train Protocol (MTP) - A Python package for creating custom Language Model training protocols.

MTP is an open-source protocol for training custom Language Models on Databiomes.
MTP contains all the data that a model is trained on.
"""

# Re-export the public API at the package root so callers can write
# `from model_train_protocol import Protocol, Token, ...`.
from .common.tokens import Token, UserToken, NumToken, NumListToken, Snippet, TokenSet
from .common.instructions import SimpleInstruction, UserInstruction
from .common.guardrails import Guardrail
from .Protocol import Protocol

# Explicit public API; names not listed here are considered internal.
__all__ = [
    "Protocol",
    "Token",
    "UserToken",
    "NumToken",
    "NumListToken",
    "TokenSet",
    "Snippet",
    "SimpleInstruction",
    "UserInstruction",
    "Guardrail"
]
@@ -0,0 +1,236 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Collection
3
+
4
+ from model_train_protocol import Token, NumToken
5
+ from model_train_protocol.common.constants import UNK_TOKEN
6
+ from model_train_protocol.common.instructions import Instruction
7
+ from model_train_protocol.common.guardrails import Guardrail
8
+ from model_train_protocol.common.tokens import TokenSet, SpecialToken
9
+ from model_train_protocol.common.pydantic.protocol import InstructionModel, TokenInfoModel, SampleModel, InstructionSetModel, NumberModel, \
10
+ BatchModel, ProtocolModel
11
+
12
+
13
class ProtocolFile:
    """Manages the model.json file for model training protocols."""

    @dataclass
    class ProtocolInstruction:
        """Represents an instruction in the template."""

        context_lines: int  # number of context lines per sample
        sets: list = field(default_factory=list)  # ProtocolInstructionSet entries

    @dataclass
    class ProtocolInstructionSet:
        """Represents an instruction set in the template."""

        set: list[list[str]]  # serialized memory set of the instruction
        result: str  # value of the instruction's final token
        samples: list  # serialized samples
        ppo: list  # serialized PPO data

    @dataclass
    class Batches:
        """Represents batches in the template."""

        pretrain: list = field(default_factory=list)
        instruct: list = field(default_factory=list)
        judge: list = field(default_factory=list)
        ppo: list = field(default_factory=list)

    def __init__(self, name: str, context: list[str], context_lines: int, tokens: Collection[Token],
                 special_tokens: Collection[Token], instructions: Collection[Instruction]):
        """Initializes the Template with a name and context."""
        self._name: str = name
        self._context: list[str] = context
        self._tokens: dict[str, dict] = {}
        self._special_token_keys: set[str] = set()
        self._instruction_token_keys: set[str] = set()
        self._instruction: ProtocolFile.ProtocolInstruction = ProtocolFile.ProtocolInstruction(
            context_lines=context_lines)
        # 'None' placeholder entries keep the output schema stable even when empty.
        self._guardrails: dict[str, list[str] | str] = {'None': ''}
        self._numbers: dict[str, str] = {'None': ''}
        self._batches: ProtocolFile.Batches = ProtocolFile.Batches()

        # Add regular tokens
        self.add_tokens(tokens)

        # Add special tokens
        self.add_tokens(special_tokens)

        # Add instructions
        self.add_instructions(instructions)

    def add_tokens(self, tokens: Collection[Token]):
        """Adds tokens to the template, keyed by token value ('value' is dropped from the dict)."""
        for token in tokens:
            token_dict: dict[str, dict] = token.to_dict()
            token_dict.pop("value")
            self._tokens[token.value] = token_dict

            # Add numbers to the numbers dictionary
            if isinstance(token, NumToken):
                self._numbers[token.value] = token.protocol_representation

            # Add special tokens to the special tokens set
            if isinstance(token, SpecialToken):
                self._special_token_keys.add(token.key)

    def add_instructions(self, instructions: Collection[Instruction]):
        """Adds instructions (sets, guardrails, and token keys) to the template."""
        for instruction in instructions:
            instruction_set: ProtocolFile.ProtocolInstructionSet = ProtocolFile.ProtocolInstructionSet(
                set=instruction.serialize_memory_set(),
                result=instruction.final.value,
                samples=instruction.serialize_samples(),
                ppo=instruction.serialize_ppo(),
            )
            self._instruction.sets.append(instruction_set)

            # Add guardrails from the instruction's TokenSets
            self._add_guardrails(instruction.get_token_sets())

            # Add instruction token keys
            # NOTE(review): annotated param of _add_instruction_token_key is str —
            # confirm get_token_key_set() returns a hashable key here.
            for token_set in instruction.get_token_sets():
                self._add_instruction_token_key(token_set.get_token_key_set())

            # Add the result token as a special token
            if instruction.final.key is not None:
                self._add_instruction_token_key(instruction.final.key)

    def _add_instruction_token_key(self, key: str):
        """Adds an instruction token key to the template."""
        self._instruction_token_keys.add(key)

    def _add_guardrails(self, token_sets: Collection[TokenSet]):
        """Adds guardrails from TokenSets to the template (TokenSets without one are skipped)."""
        for token_set in token_sets:
            if token_set.guardrail is None:
                continue
            guardrail: Guardrail = token_set.guardrail
            self._guardrails[token_set.key] = guardrail.format_samples()

    @classmethod
    def _rename_protocol_elements(cls, protocol_json: dict):
        """
        Renames elements in the ProtocolFile json to match the previous output format for backwards compatibility.
        :param protocol_json: The original json dictionary.
        :return: The modified json with renamed elements.
        """
        # Add special token <UNK> REGARDLESS of whether we have any guardrails
        # NOTE(review): Protocol._add_default_special_tokens only adds <UNK> when
        # guardrails exist; here it is unconditional — confirm this is intended.
        unk_token_dict: dict = UNK_TOKEN.to_dict()
        unk_token_dict['emoji'] = unk_token_dict.pop('key')
        unk_token_dict.pop('value')
        protocol_json['tokens'][UNK_TOKEN.value] = unk_token_dict
        protocol_json['special_tokens'].append(UNK_TOKEN.key)

        for token_value, token_info in protocol_json.get('tokens', {}).items():
            # Rename Token 'key' to 'emoji'
            if 'key' in token_info:
                token_info['emoji'] = token_info.pop('key')

            # Reassign Token 'num' to boolean
            if 'num' in token_info:
                num: int = token_info['num']
                token_info['num'] = True if num >= 1 else False

            # TODO: Differentiate between num and num_list tokens in the future - currently both are just 'num': True

        for instruction in protocol_json.get('instruction', {}).get('sets', []):

            # Rename sample number to None if an array of empty arrays
            for sample in instruction['samples']:
                if all(num == [] for num in sample['number']):
                    sample['number'] = None

            # Rename sample 'strings' to 'sample'
            for sample in instruction['samples']:
                if 'strings' in sample:
                    sample['sample'] = sample.pop('strings')

            # Rename null values to "None"
            for sample in instruction['samples']:
                if sample['value'] is None:
                    sample['value'] = "None"

        return protocol_json

    def _get_special_token_keys(self):
        """
        Returns a sorted list of tokens that should be under 'special_tokens' in the JSON.

        :return: A sorted list of special token keys.
        """
        return sorted(self._special_token_keys | self._instruction_token_keys)

    def to_json(self):
        """Converts the template to a JSON-compatible dictionary using Pydantic models."""

        # Create TokenInfo objects for each token
        token_info_dict = {}
        for token_value, token_dict in self._tokens.items():
            token_info = TokenInfoModel(
                emoji=token_dict.get('emoji', ''),
                num=token_dict.get('num', False),
                user=token_dict.get('user', False),
                desc=token_dict.get('desc'),
                special=token_dict.get('special')
            )
            token_info_dict[token_value] = token_info

        # Create InstructionSet objects
        instruction_sets = []
        for instruction_set in self._instruction.sets:
            # Create Sample objects
            samples = []
            for sample_data in instruction_set.samples:
                sample = SampleModel(
                    sample=sample_data.get('strings', []),
                    prompt=sample_data.get('prompt', ''),
                    number=sample_data.get('number', []),
                    result=sample_data.get('result', ''),
                    value=sample_data.get('value', '')
                )
                samples.append(sample)

            # Create InstructionSet
            instruction_set_obj = InstructionSetModel(
                set=instruction_set.set,
                result=instruction_set.result,
                samples=samples,
                ppo=instruction_set.ppo
            )
            instruction_sets.append(instruction_set_obj)

        # Create Instruction object
        instruction = InstructionModel(
            memory=self._instruction.context_lines + 1,  # +1 for the response line
            sets=instruction_sets
        )

        # Create Numbers object
        numbers = NumberModel()

        # Create Batches object
        batches = BatchModel(
            pretrain=self._batches.pretrain,
            instruct=self._batches.instruct,
            judge=self._batches.judge,
            ppo=self._batches.ppo
        )

        # Create ProtocolModel
        protocol = ProtocolModel(
            name=self._name,
            context=self._context,
            tokens=token_info_dict,
            special_tokens=self._get_special_token_keys(),
            instruction=instruction,
            guardrails=self._guardrails,
            numbers=numbers,
            batches=batches
        )

        # Convert to JSON and apply backwards compatibility transformations
        json_dict = protocol.model_dump(by_alias=True)
        return self._rename_protocol_elements(json_dict)
@@ -0,0 +1,148 @@
1
+ from dataclasses import dataclass
2
+ from typing import Collection
3
+
4
+ from model_train_protocol import NumToken, SimpleInstruction, UserInstruction
5
+ from model_train_protocol.common.constants import BOS_TOKEN, RUN_TOKEN, EOS_TOKEN
6
+ from model_train_protocol.common.instructions import Instruction
7
+
8
+
9
class TemplateFile:
    """Manages the template.json file for model training protocols."""

    @dataclass
    class ExampleUsage:
        """Stores example usages of the template."""

        input: str
        output: str

    class ModelInput:
        """Represents inputs to the model."""

        def __init__(self):
            # Instance attribute. This was previously a class-level list shared by
            # every ModelInput, so inputs leaked across TemplateFile instances.
            self.inputs: list[list[str]] = []

        def add_inputs_from_instructions(self, instructions: list[Instruction], context_lines: int):
            """Adds input combinations from a list of instructions."""
            unique_sets = {i: set() for i in range(context_lines + 1)}
            for instruction in instructions:
                for idx, token_set in enumerate(instruction.get_token_sets()):
                    token_user = [t.user for t in token_set]
                    token_strings = "".join([t.value for t in token_set])
                    token_keys = []
                    for token in token_set:
                        token_keys.append(
                            token.key + (token.protocol_representation if isinstance(token, NumToken) else ""))
                    token_keys = "".join(token_keys)
                    # The last line of a user instruction takes the user prompt instead of a string.
                    unique_sets[idx].add(str(token_strings) + ": " + (
                        (str(token_keys) + "USER PROMPT") if any(token_user) and (
                                idx == (len(unique_sets) - 1)) else str(
                            token_keys)) + "\n" + ("<string>" if idx != (len(instruction.context) - 1) else ""))

            for input_set in unique_sets.values():
                self.inputs.append(list(input_set))

        def to_json(self):
            """Converts the model input to a JSON-serializable dictionary."""
            model_json: dict[str, Collection[str] | str] = {"<BOS>": BOS_TOKEN.key}
            # Add each input sequence with its index as the key
            for idx, input_seq in enumerate(self.inputs):
                model_json[str(idx)] = input_seq
            model_json["<RUN>"] = RUN_TOKEN.key
            return model_json

    class ModelOutput:
        """Represents outputs of the model."""

        def __init__(self):
            # Instance attributes. model_results was previously class-level shared
            # state, leaking results across TemplateFile instances.
            self.model_results: dict[str, str] = {}
            self.model_response: str = "<string>"

        def __setitem__(self, key: str, value: str):
            self.model_results[key] = value

        def add_results_from_instructions(self, instructions: list[Instruction]):
            """Adds model results from a list of instructions."""
            for instruction in instructions:
                self.model_results[str(instruction.final.value)] = str(instruction.final.key)

        def to_json(self):
            """Converts the model output to a JSON-serializable dictionary."""
            model_json: dict[str, str | dict] = {
                "model_response": self.model_response,
                "model_results": {}
            }
            # Add each model result with its key
            for key, value in self.model_results.items():
                model_json["model_results"][key] = value

            model_json["<EOS>"] = EOS_TOKEN.key

            # Sort alphabetically for readability and consistency across runs
            model_json["model_results"] = dict(sorted(model_json["model_results"].items()))

            return model_json

    def __init__(self, context_lines: int, instructions: list[Instruction]):
        """
        Initializes the template.

        :param context_lines: Number of context lines per instruction sample.
        :param instructions: The instructions to build inputs/outputs from.
        """
        self.model_input: TemplateFile.ModelInput = TemplateFile.ModelInput()
        self.model_output: TemplateFile.ModelOutput = TemplateFile.ModelOutput()
        self.context_lines: int = context_lines
        self.instructions: list[Instruction] = instructions
        self._add_io_from_instructions()

    def _add_io_from_instructions(self):
        """Adds input and output sequences from the instructions."""
        self.model_input.add_inputs_from_instructions(self.instructions, context_lines=self.context_lines)
        self.model_output.add_results_from_instructions(self.instructions)

    def _create_sample_model_output(self):
        """Creates a sample model output string for example usages."""
        sample_output: str = ""
        sample_output += self.model_output.model_response + "\n"
        # Use the alphabetically-first result as the example result.
        sorted_model_results: list[tuple[str, str]] = list(sorted(self.model_output.model_results.items()))
        sample_output += sorted_model_results[0][1] + "\n"
        sample_output += EOS_TOKEN.key
        return sample_output

    def _create_examples(self) -> dict[str, str]:
        """
        Creates example usages of the template.

        Creates a simple instruction example and a user instruction example if available.
        """
        examples: dict[str, str] = dict()
        simple_instruction: SimpleInstruction = next(
            (i for i in self.instructions if isinstance(i, SimpleInstruction)), None)
        user_instruction: UserInstruction = next(
            (i for i in self.instructions if isinstance(i, UserInstruction)), None)

        if simple_instruction:
            simple_input: str = ""
            for token_set in simple_instruction.get_token_sets():
                token_strings = "".join([token.key for token in token_set])
                simple_input += token_strings + "\n"
                simple_input += "<string>\n"
            simple_input = BOS_TOKEN.key + "\n" + simple_input + RUN_TOKEN.key + "\n"
            examples["simple_instruction_input"] = simple_input + self._create_sample_model_output()

        if user_instruction:
            user_input: str = ""
            for idx, token_set in enumerate(user_instruction.get_token_sets()):
                token_strings = "".join([token.key for token in token_set])
                user_input += token_strings + "\n"
                # The final TokenSet line carries the user's prompt instead of a string.
                user_input += "<string>\n" if idx != (len(user_instruction.get_token_sets()) - 1) else "USER PROMPT\n"
            user_input = BOS_TOKEN.key + "\n" + user_input + RUN_TOKEN.key + "\n"
            examples["valid_user_input"] = user_input

        examples["valid_output"] = self._create_sample_model_output()

        return examples

    def to_json(self) -> dict:
        """Converts the entire template to a JSON-serializable dictionary."""
        examples: dict[str, str] = self._create_examples()
        json_dict: dict = {
            "all_combinations": {
                "model_input": self.model_input.to_json(),
                "model_output": self.model_output.to_json()
            },
            "example_usage": examples
        }
        return json_dict
File without changes
File without changes
@@ -0,0 +1,21 @@
1
+ from model_train_protocol.common.tokens import SpecialToken
2
+
3
# Default special tokens used by every protocol.
NON_TOKEN: SpecialToken = SpecialToken(value="<NON>", key="<NON>", special="none")
BOS_TOKEN: SpecialToken = SpecialToken(value="<BOS>", key="<BOS>", special="start")
EOS_TOKEN: SpecialToken = SpecialToken(value="<EOS>", key="<EOS>", special="end")
RUN_TOKEN: SpecialToken = SpecialToken(value="<RUN>", key="<RUN>", special="infer")
PAD_TOKEN: SpecialToken = SpecialToken(value="<PAD>", key="<PAD>", special="pad")
UNK_TOKEN: SpecialToken = SpecialToken(value="<UNK>", key="<UNK>", special="unknown")

# TODO: Remove this code when emoji dependency is removed
# Assign default emoji keys to special tokens for backward compatibility.
# (Previously a module-level bare string was used as a comment; that is a no-op
# statement, not documentation.)
special_token_emoji_map = {
    NON_TOKEN.value: "🫙",
    BOS_TOKEN.value: "🏁",
    EOS_TOKEN.value: "🎬",
    RUN_TOKEN.value: "🏃",
    PAD_TOKEN.value: "🗒",
    UNK_TOKEN.value: "🛑"
}
for token in (NON_TOKEN, BOS_TOKEN, EOS_TOKEN, RUN_TOKEN, PAD_TOKEN, UNK_TOKEN):
    token.key = special_token_emoji_map[token.value]
@@ -0,0 +1,51 @@
1
class Guardrail:
    """
    Defines a guardrail response to bad prompts.

    Guardrails are set on TokenSets. Each TokenSet can have at most one guardrail, but guardrails can be reused.
    """

    def __init__(self, good_prompt: str, bad_prompt: str, bad_output: str):
        """
        Initializes a Guardrail.

        :param good_prompt: Description of a good prompt.
        :param bad_prompt: Description of a bad prompt.
        :param bad_output: The output the model should produce when a bad prompt is detected.
        :raises TypeError: If any parameter is not a string.
        :raises ValueError: If any parameter is an empty string.

        Example:
            good_prompt="Quote being spoken with 1-20 words",
            bad_prompt="Quote being spoken that is irrelevant and off-topic with 1-20 words",
            bad_output="I have no idea what you're talking about."
        """
        params = (good_prompt, bad_prompt, bad_output)
        if not all(isinstance(param, str) for param in params):
            raise TypeError("All parameters must be non-empty strings.")

        if any(param == "" for param in params):
            raise ValueError("All parameters must be non-empty strings.")

        self.good_prompt: str = good_prompt
        self.bad_prompt: str = bad_prompt
        self.bad_output: str = bad_output
        self.samples: list[str] = []  # example bad prompts added via add_sample()

    def add_sample(self, sample: str):
        """
        Add an example of a bad sample prompt to the guardrail.

        :param sample: An example of a bad prompt that should trigger the guardrail.
        :raises ValueError: If the sample is not a non-empty string or contains digits.

        Example:
            sample="Tell me a joke about politics."
        """
        if not isinstance(sample, str) or not sample.strip():
            raise ValueError("Sample prompt must be a non-empty string.")

        # NOTE(review): digits are rejected — presumably numbers must go through
        # NumToken handling instead; confirm the rationale.
        if any(char.isdigit() for char in sample):
            raise ValueError("Sample prompt cannot contain digits.")
        self.samples.append(sample)

    def format_samples(self) -> list[str]:
        """
        Return the guardrail as a list of strings for JSON formatting.

        :return: [bad_output, "<bad_prompt>", "<good_prompt>", samples].
        :raises ValueError: If fewer than 3 sample prompts have been added.
        """
        if len(self.samples) < 3:
            raise ValueError("At least 3 sample prompts are required. Call add_sample() to add more.")
        return [self.bad_output, f"<{self.bad_prompt}>", f"<{self.good_prompt}>", self.samples]