model-train-protocol 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_train_protocol/Protocol.py +193 -0
- model_train_protocol/__init__.py +24 -0
- model_train_protocol/_internal/ProtocolFile.py +236 -0
- model_train_protocol/_internal/TemplateFile.py +148 -0
- model_train_protocol/_internal/__init__.py +0 -0
- model_train_protocol/common/__init__.py +0 -0
- model_train_protocol/common/constants.py +21 -0
- model_train_protocol/common/guardrails/Guardrail.py +51 -0
- model_train_protocol/common/guardrails/__init__.py +9 -0
- model_train_protocol/common/instructions/Instruction.py +213 -0
- model_train_protocol/common/instructions/SimpleInstruction.py +46 -0
- model_train_protocol/common/instructions/UserInstruction.py +72 -0
- model_train_protocol/common/instructions/__init__.py +13 -0
- model_train_protocol/common/pydantic/__init__.py +0 -0
- model_train_protocol/common/pydantic/protocol.py +157 -0
- model_train_protocol/common/tokens/NumListToken.py +21 -0
- model_train_protocol/common/tokens/NumToken.py +33 -0
- model_train_protocol/common/tokens/SpecialToken.py +35 -0
- model_train_protocol/common/tokens/Token.py +95 -0
- model_train_protocol/common/tokens/TokenSet.py +124 -0
- model_train_protocol/common/tokens/UserToken.py +19 -0
- model_train_protocol/common/tokens/__init__.py +21 -0
- model_train_protocol/common/util.py +57 -0
- model_train_protocol-0.1.7.dist-info/METADATA +323 -0
- model_train_protocol-0.1.7.dist-info/RECORD +28 -0
- model_train_protocol-0.1.7.dist-info/WHEEL +5 -0
- model_train_protocol-0.1.7.dist-info/licenses/LICENSE +21 -0
- model_train_protocol-0.1.7.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
|
|
4
|
+
from . import Token
|
|
5
|
+
from ._internal.ProtocolFile import ProtocolFile
|
|
6
|
+
from ._internal.TemplateFile import TemplateFile
|
|
7
|
+
from .common.constants import BOS_TOKEN, EOS_TOKEN, RUN_TOKEN, PAD_TOKEN, UNK_TOKEN
|
|
8
|
+
from .common.instructions.Instruction import Instruction
|
|
9
|
+
from .common.tokens.SpecialToken import SpecialToken
|
|
10
|
+
from .common.util import get_possible_emojis, hash_string, validate_string_set
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Protocol:
    """Model Training Protocol (MTP) class for creating the training configuration."""

    def __init__(self, name: str, context_lines: int, encrypt: bool = True):
        """
        Initialize the Model Training Protocol (MTP).

        :param name: The name of the protocol.
        :param context_lines: The number of lines in each instruction sample. Must be at least 2.
        :param encrypt: Whether tokens without an explicit key receive a hashed key. Default is True.
        :raises ValueError: If context_lines is less than 2.
        """
        self.name: str = name
        self.context_lines: int = context_lines  # Number of lines in instruction samples
        self.encrypt: bool = encrypt
        if self.context_lines < 2:
            raise ValueError("A minimum of 2 context lines is required for all instructions.")
        self.context: list[str] = []
        self.tokens: set[Token] = set()
        self.instructions: set[Instruction] = set()
        self.guardrails: dict[str, list[str]] = dict()
        self.numbers: dict[str, str] = dict()
        self.none = None  # NOTE(review): appears unused here; kept for backward compatibility — confirm before removing
        self.special_tokens: set[Token] = set()
        self.possible_emoji_keys: set[str] = get_possible_emojis()
        self.used_keys: set[str] = set()

    def add_context(self, context: str):
        """
        Adds a line of context to the model.

        :param context: The context line to append.
        :raises TypeError: If context is not a string.
        """
        if not isinstance(context, str):
            raise TypeError("Context must be a string.")

        self.context.append(context)

    def add_instruction(self, instruction: Instruction):
        """
        Adds an Instruction (and its components) to the protocol.

        Asserts that all samples in the instruction match the defined sample line size.

        :param instruction: The Instruction instance to add.
        :raises ValueError: If the instruction was already added, has no samples,
            or any sample's context length differs from ``context_lines``.
        """
        if instruction in self.instructions:
            raise ValueError("Instruction already added to the protocol.")

        if len(instruction.samples) == 0:
            # Only a non-empty check is enforced; the previous message claimed
            # three samples were required, which did not match this check.
            raise ValueError("Instruction must have at least one sample.")

        # Assert all samples match the defined sample line size
        for sample in instruction.samples:
            if not len(sample.context) == self.context_lines:
                raise ValueError(
                    f"Sample context lines ({len(sample.context)}) does not match defined context_lines count ({self.context_lines})"
                    f"\n{sample}."
                )

        # Register every token the instruction introduces
        for token in instruction.get_tokens():
            if token not in self.tokens:
                self._add_token(token)

        # Add the instruction to the protocol
        self.instructions.add(instruction)

    def save(self, name: str | None = None, path: str | None = None):
        """
        Saves the protocol to a JSON file. This file can be submitted to Databiomes for model training.

        :param name: The name of the file (without extension). If None, uses the protocol's name.
        :param path: The directory path where the file will be saved. If None, saves in the current directory.
        """
        if name is None:
            name = self.name
        if path is None:
            path = os.getcwd()
        os.makedirs(path, exist_ok=True)
        # os.path.join keeps the path portable; the previous hard-coded "\\"
        # separator produced broken paths on non-Windows systems.
        filename = os.path.join(path, f"{name}_model.json")

        self._prep_protocol()
        protocol_file: ProtocolFile = ProtocolFile(
            name=self.name, context=self.context, context_lines=self.context_lines,
            tokens=self.tokens, special_tokens=self.special_tokens, instructions=self.instructions,
        )

        print(f"Saving Model Train Protocol to {filename}...")
        with open(filename, 'w', encoding="utf-8") as file:
            json.dump(protocol_file.to_json(), file, indent=4, ensure_ascii=False)

    def template(self, path: str | None = None):
        """
        Create a template JSON file for the model training protocol.

        The template json file includes example usage and all possible combinations of model inputs and
        outputs based on the defined tokens and instructions.

        :param path: The directory path where the template file will be saved. If None, saves in the current directory.
        """
        if path is None:
            path = os.getcwd()
        # Create the target directory if needed, consistent with save()
        os.makedirs(path, exist_ok=True)
        filename = os.path.join(path, f"{self.name}_template.json")

        self._prep_protocol()
        template_file: TemplateFile = TemplateFile(
            instructions=list(self.instructions),
            context_lines=self.context_lines
        )

        print(f"Saving Model Train Protocol Template to {filename}...")
        with open(filename, 'w', encoding="utf-8") as file:
            json.dump(template_file.to_json(), file, indent=4, ensure_ascii=False)

    def _assign_key(self, token: Token):
        """
        Assigns a key to a Token based on the protocol's encryption setting.

        :param token: The Token to assign the key of.
        """
        # If the user has assigned a key, use this key
        if token.key is not None:
            return

        if self.encrypt:
            # Generate a hashed key for the token if encrypting and no key is set
            token.key = hash_string(key=token.value, output_char=6)
        else:
            # Use the value as the key if not encrypting. I.e. Token 'Continue_' has key 'Continue_'
            token.key = token.value

    def _add_token(self, token: Token):
        """
        Adds a unique token to the protocol.

        Validates that the token's value and key are unique.
        :param token: The Token instance to add.
        :raises ValueError: If the token value or its (possibly generated) key is already in use.
        """
        if token in self.tokens:
            raise ValueError(f"Token value '{token.value}' already used.")

        # Assign the key BEFORE the uniqueness check so that generated (hashed)
        # keys are also validated; previously the check ran first, letting
        # colliding generated keys slip through silently.
        self._assign_key(token=token)

        if token.key in self.used_keys:
            raise ValueError(f"Token key '{token.key}' already used.")

        self.tokens.add(token)
        self.used_keys.add(token.key)

        if isinstance(token, SpecialToken):
            self.special_tokens.add(token)

    def _set_guardrails(self):
        """Sets all guardrails from TokenSets into the protocol."""
        # Add all guardrails to the protocol
        for instruction in self.instructions:
            if instruction.response.guardrail is not None:
                # instruction.response is the user TokenSet
                self.guardrails[instruction.response.key] = instruction.response.guardrail.format_samples()

    def _add_default_special_tokens(self):
        """Adds all default special tokens to the protocol."""
        self.special_tokens.add(BOS_TOKEN)
        self.special_tokens.add(EOS_TOKEN)
        self.special_tokens.add(RUN_TOKEN)
        self.special_tokens.add(PAD_TOKEN)
        # <UNK> is only added when guardrails exist (it marks off-protocol prompts)
        if len(self.guardrails) > 0:
            self.special_tokens.add(UNK_TOKEN)

    def _prep_protocol(self):
        """
        Sets all elements in the protocol before serialization.

        Raises errors if any validation checks fail.

        Sets up all necessary components in the protocol before saving or templating.

        This includes setting guardrails from their TokenSets and creating default special tokens.
        """
        if len(self.instructions) == 0:
            raise ValueError("No instructions have been added to Protocol. Call protocol.add_instruction() to add instructions.")

        self._set_guardrails()
        self._add_default_special_tokens()
        used_values: set[str] = {token.value for token in self.tokens}
        validate_string_set(used_values)
        validate_string_set(self.used_keys)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Model Train Protocol (MTP) - A Python package for creating custom Language Model training protocols.
|
|
3
|
+
|
|
4
|
+
MTP is an open-source protocol for training custom Language Models on Databiomes.
|
|
5
|
+
MTP contains all the data that a model is trained on.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from .common.tokens import Token, UserToken, NumToken, NumListToken, Snippet, TokenSet
|
|
9
|
+
from .common.instructions import SimpleInstruction, UserInstruction
|
|
10
|
+
from .common.guardrails import Guardrail
|
|
11
|
+
from .Protocol import Protocol
|
|
12
|
+
|
|
13
|
+
# Public API surface of the package: the names re-exported at the package
# root and returned by `from model_train_protocol import *`.
__all__ = [
    "Protocol",
    "Token",
    "UserToken",
    "NumToken",
    "NumListToken",
    "TokenSet",
    "Snippet",
    "SimpleInstruction",
    "UserInstruction",
    "Guardrail"
]
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Collection
|
|
3
|
+
|
|
4
|
+
from model_train_protocol import Token, NumToken
|
|
5
|
+
from model_train_protocol.common.constants import UNK_TOKEN
|
|
6
|
+
from model_train_protocol.common.instructions import Instruction
|
|
7
|
+
from model_train_protocol.common.guardrails import Guardrail
|
|
8
|
+
from model_train_protocol.common.tokens import TokenSet, SpecialToken
|
|
9
|
+
from model_train_protocol.common.pydantic.protocol import InstructionModel, TokenInfoModel, SampleModel, InstructionSetModel, NumberModel, \
|
|
10
|
+
BatchModel, ProtocolModel
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ProtocolFile:
    """Manages the model.json file for model training protocols."""

    @dataclass
    class ProtocolInstruction:
        """Represents an instruction in the template."""

        # Number of context lines per instruction sample
        context_lines: int
        # Collected ProtocolInstructionSet entries (one per added Instruction)
        sets: list = field(default_factory=list)

    @dataclass
    class ProtocolInstructionSet:
        """Represents an instruction set in the template."""

        # Serialized memory set: grid of token keys for the instruction
        set: list[list[str]]
        # Value of the instruction's final (result) token
        result: str
        # Serialized sample dictionaries
        samples: list
        # Serialized PPO data
        ppo: list

    @dataclass
    class Batches:
        """Represents batches in the template."""

        pretrain: list = field(default_factory=list)
        instruct: list = field(default_factory=list)
        judge: list = field(default_factory=list)
        ppo: list = field(default_factory=list)

    def __init__(self, name: str, context: list[str], context_lines: int, tokens: Collection[Token],
                 special_tokens: Collection[Token], instructions: Collection[Instruction]):
        """
        Initializes the Template with a name and context.

        :param name: Protocol name written to the output file.
        :param context: Lines of model context.
        :param context_lines: Number of context lines per instruction sample.
        :param tokens: Regular tokens to serialize.
        :param special_tokens: Special tokens to serialize.
        :param instructions: Instructions to serialize.
        """
        self._name: str = name
        self._context: list[str] = context
        self._tokens: dict[str, dict] = {}
        self._special_token_keys: set[str] = set()
        self._instruction_token_keys: set[str] = set()
        self._instruction: ProtocolFile.ProtocolInstruction = ProtocolFile.ProtocolInstruction(
            context_lines=context_lines)
        # 'None' placeholder entries mirror the legacy file format
        self._guardrails: dict[str, list[str] | str] = {'None': ''}
        self._numbers: dict[str, str] = {'None': ''}
        self._batches: ProtocolFile.Batches = ProtocolFile.Batches()

        # Add regular tokens
        self.add_tokens(tokens)

        # Add special tokens
        self.add_tokens(special_tokens)

        # Add instructions
        self.add_instructions(instructions)

    def add_tokens(self, tokens: Collection[Token]):
        """
        Adds tokens to the template.

        Each token is stored keyed by its value, with the 'value' field
        removed from the stored dict (it is redundant with the key).
        """
        for token in tokens:
            token_dict: dict[str, dict] = token.to_dict()
            token_dict.pop("value")
            self._tokens[token.value] = token_dict

            # Add numbers to the numbers dictionary
            if isinstance(token, NumToken):
                self._numbers[token.value] = token.protocol_representation

            # Add special tokens to the special tokens set
            if isinstance(token, SpecialToken):
                self._special_token_keys.add(token.key)

    def add_instructions(self, instructions: Collection[Instruction]):
        """
        Adds instructions to the template.

        For each instruction: serializes its memory set, samples and PPO data,
        collects guardrails from its TokenSets, and records the token keys it
        uses so they are emitted under 'special_tokens'.
        """
        for instruction in instructions:
            instruction_set: ProtocolFile.ProtocolInstructionSet = ProtocolFile.ProtocolInstructionSet(
                set=instruction.serialize_memory_set(),
                result=instruction.final.value,
                samples=instruction.serialize_samples(),
                ppo=instruction.serialize_ppo(),
            )
            self._instruction.sets.append(instruction_set)

            # Add guardrails from the instruction's TokenSets
            self._add_guardrails(instruction.get_token_sets())

            # Add instruction token keys
            for token_set in instruction.get_token_sets():
                self._add_instruction_token_key(token_set.get_token_key_set())

            # Add the result token as a special token
            if instruction.final.key is not None:
                self._add_instruction_token_key(instruction.final.key)

    def _add_instruction_token_key(self, key: str):
        """Adds an instruction token key to the template."""
        self._instruction_token_keys.add(key)

    def _add_guardrails(self, token_sets: Collection[TokenSet]):
        """Adds guardrails from TokenSets to the template (keyed by TokenSet key)."""
        for token_set in token_sets:
            if token_set.guardrail is None:
                continue
            guardrail: Guardrail = token_set.guardrail
            self._guardrails[token_set.key] = guardrail.format_samples()

    @classmethod
    def _rename_protocol_elements(cls, protocol_json: dict):
        """
        Renames elements in the ProtocolFile json to match the previous output format for backwards compatibility.

        :param protocol_json: The original json dictionary.
        :return: The modified json with renamed elements.
        """
        # Add special token <UNK> REGARDLESS of whether we have any guardrails
        unk_token_dict: dict = UNK_TOKEN.to_dict()
        unk_token_dict['emoji'] = unk_token_dict.pop('key')
        unk_token_dict.pop('value')
        protocol_json['tokens'][UNK_TOKEN.value] = unk_token_dict
        protocol_json['special_tokens'].append(UNK_TOKEN.key)

        for token_value, token_info in protocol_json.get('tokens', {}).items():
            # Rename Token 'key' to 'emoji'
            if 'key' in token_info:
                token_info['emoji'] = token_info.pop('key')

            # Reassign Token 'num' to boolean
            if 'num' in token_info:
                num: int = token_info['num']
                token_info['num'] = True if num >= 1 else False

        # TODO: Differentiate between num and num_list tokens in the future - currently both are just 'num': True

        for instruction in protocol_json.get('instruction', {}).get('sets', []):

            # Rename sample number to None if an array of empty arrays
            for sample in instruction['samples']:
                if all(num == [] for num in sample['number']):
                    sample['number'] = None

            # Rename sample 'strings' to 'sample'
            for sample in instruction['samples']:
                if 'strings' in sample:
                    sample['sample'] = sample.pop('strings')

            # Rename null values to "None"
            for sample in instruction['samples']:
                if sample['value'] is None:
                    sample['value'] = "None"

        return protocol_json

    def _get_special_token_keys(self):
        """
        Returns a sorted list of tokens that should be under 'special_tokens' in the JSON.

        :return: A sorted list of special token keys.
        """
        return sorted(self._special_token_keys | self._instruction_token_keys)

    def to_json(self):
        """
        Converts the template to a JSON-compatible dictionary using Pydantic models.

        :return: The protocol serialized as a plain dict, after backwards-compatibility renames.
        """

        # Create TokenInfo objects for each token
        token_info_dict = {}
        for token_value, token_dict in self._tokens.items():
            token_info = TokenInfoModel(
                emoji=token_dict.get('emoji', ''),
                num=token_dict.get('num', False),
                user=token_dict.get('user', False),
                desc=token_dict.get('desc'),
                special=token_dict.get('special')
            )
            token_info_dict[token_value] = token_info

        # Create InstructionSet objects
        instruction_sets = []
        for instruction_set in self._instruction.sets:
            # Create Sample objects
            samples = []
            for sample_data in instruction_set.samples:
                # NOTE(review): reads the pre-rename 'strings' field here;
                # _rename_protocol_elements later renames it to 'sample'
                sample = SampleModel(
                    sample=sample_data.get('strings', []),
                    prompt=sample_data.get('prompt', ''),
                    number=sample_data.get('number', []),
                    result=sample_data.get('result', ''),
                    value=sample_data.get('value', '')
                )
                samples.append(sample)

            # Create InstructionSet
            instruction_set_obj = InstructionSetModel(
                set=instruction_set.set,
                result=instruction_set.result,
                samples=samples,
                ppo=instruction_set.ppo
            )
            instruction_sets.append(instruction_set_obj)

        # Create Instruction object
        instruction = InstructionModel(
            memory=self._instruction.context_lines + 1,  # +1 for the response line
            sets=instruction_sets
        )

        # Create Numbers object
        numbers = NumberModel()

        # Create Batches object
        batches = BatchModel(
            pretrain=self._batches.pretrain,
            instruct=self._batches.instruct,
            judge=self._batches.judge,
            ppo=self._batches.ppo
        )

        # Create ProtocolModel
        protocol = ProtocolModel(
            name=self._name,
            context=self._context,
            tokens=token_info_dict,
            special_tokens=self._get_special_token_keys(),
            instruction=instruction,
            guardrails=self._guardrails,
            numbers=numbers,
            batches=batches
        )

        # Convert to JSON and apply backwards compatibility transformations
        json_dict = protocol.model_dump(by_alias=True)
        return self._rename_protocol_elements(json_dict)
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Collection
|
|
3
|
+
|
|
4
|
+
from model_train_protocol import NumToken, SimpleInstruction, UserInstruction
|
|
5
|
+
from model_train_protocol.common.constants import BOS_TOKEN, RUN_TOKEN, EOS_TOKEN
|
|
6
|
+
from model_train_protocol.common.instructions import Instruction
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TemplateFile:
    """Manages the template JSON file for model training protocols."""

    @dataclass
    class ExampleUsage:
        """Stores example usages of the template."""

        input: str
        output: str

    class ModelInput:
        """Represents inputs to the model."""

        def __init__(self):
            # Instance attribute: a class-level mutable default would be
            # shared by every TemplateFile instance and accumulate inputs
            # across instantiations.
            self.inputs: list[list[str]] = []

        def add_inputs_from_instructions(self, instructions: list[Instruction], context_lines: int):
            """Adds input combinations from a list of instructions."""
            unique_sets = {i: set() for i in range(context_lines + 1)}
            for instruction in instructions:
                for idx, token_set in enumerate(instruction.get_token_sets()):
                    token_user = [t.user for t in token_set]
                    token_strings = "".join([t.value for t in token_set])
                    token_keys = []
                    for token in token_set:
                        # NumTokens append their protocol representation to the key
                        token_keys.append(
                            token.key + (token.protocol_representation if isinstance(token, NumToken) else ""))
                    token_keys = "".join(token_keys)
                    # Last line of a user instruction is the user prompt slot;
                    # every other line is followed by a "<string>" placeholder.
                    unique_sets[idx].add(str(token_strings) + ": " + (
                        (str(token_keys) + "USER PROMPT") if any(token_user) and (
                                idx == (len(unique_sets) - 1)) else str(
                        token_keys)) + "\n" + ("<string>" if idx != (len(instruction.context) - 1) else ""))

            for input_set in unique_sets.values():
                # Sorted for deterministic output across runs (sets are unordered)
                self.inputs.append(sorted(input_set))

        def to_json(self):
            """Converts the model input to a JSON-serializable dictionary."""
            model_json: dict[str, Collection[str] | str] = {"<BOS>": BOS_TOKEN.key}
            # Add each input sequence with its index as the key
            for idx, input_seq in enumerate(self.inputs):
                model_json[str(idx)] = input_seq
            model_json["<RUN>"] = RUN_TOKEN.key
            return model_json

    class ModelOutput:
        """Represents outputs of the model."""

        def __init__(self):
            # Instance attributes: class-level mutable defaults would leak
            # results between TemplateFile instances.
            self.model_results: dict[str, str] = {}
            self.model_response: str = "<string>"

        def __setitem__(self, key: str, value: str):
            self.model_results[key] = value

        def add_results_from_instructions(self, instructions: list[Instruction]):
            """Adds model results from a list of instructions."""
            for instruction in instructions:
                self.model_results[str(instruction.final.value)] = str(instruction.final.key)

        def to_json(self):
            """Converts the model output to a JSON-serializable dictionary."""
            model_json: dict[str, str | dict] = {
                "model_response": self.model_response,
                "model_results": {}
            }
            # Add each model result with its key
            for key, value in self.model_results.items():
                model_json["model_results"][key] = value

            model_json["<EOS>"] = EOS_TOKEN.key

            # Sort alphabetically for readability and consistency across runs
            model_json["model_results"] = dict(sorted(model_json["model_results"].items()))

            return model_json

    def __init__(self, context_lines: int, instructions: list[Instruction]):
        """
        Initializes the template.

        :param context_lines: Number of context lines per instruction sample.
        :param instructions: The instructions to derive inputs and outputs from.
        """
        self.model_input: TemplateFile.ModelInput = TemplateFile.ModelInput()
        self.model_output: TemplateFile.ModelOutput = TemplateFile.ModelOutput()
        self.context_lines: int = context_lines
        self.instructions: list[Instruction] = instructions
        self._add_io_from_instructions()

    def _add_io_from_instructions(self):
        """Adds input and output sequences from the instructions."""
        self.model_input.add_inputs_from_instructions(self.instructions, context_lines=self.context_lines)
        self.model_output.add_results_from_instructions(self.instructions)

    def _create_sample_model_output(self):
        """Creates a sample model output string for example usages."""
        sample_output: str = ""
        sample_output += self.model_output.model_response + "\n"
        # Use the alphabetically-first result so the example is deterministic
        sorted_model_results: list[tuple[str, str]] = list(sorted(self.model_output.model_results.items()))
        sample_output += sorted_model_results[0][1] + "\n"
        sample_output += EOS_TOKEN.key
        return sample_output

    def _create_examples(self) -> dict[str, str]:
        """
        Creates example usages of the template.

        Creates a simple instruction example and a user instruction example if available.
        """
        examples: dict[str, str] = dict()
        simple_instruction: SimpleInstruction = next(
            (i for i in self.instructions if isinstance(i, SimpleInstruction)), None)
        user_instruction: UserInstruction = next(
            (i for i in self.instructions if isinstance(i, UserInstruction)), None)

        if simple_instruction:
            simple_input: str = ""
            for token_set in simple_instruction.get_token_sets():
                token_strings = "".join([token.key for token in token_set])
                simple_input += token_strings + "\n"
                simple_input += "<string>\n"
            simple_input = BOS_TOKEN.key + "\n" + simple_input + RUN_TOKEN.key + "\n"
            # The simple example embeds the expected output directly
            examples["simple_instruction_input"] = simple_input + self._create_sample_model_output()

        if user_instruction:
            user_input: str = ""
            for idx, token_set in enumerate(user_instruction.get_token_sets()):
                token_strings = "".join([token.key for token in token_set])
                user_input += token_strings + "\n"
                # The final TokenSet line is the user's prompt slot
                user_input += "<string>\n" if idx != (len(user_instruction.get_token_sets()) - 1) else "USER PROMPT\n"
            user_input = BOS_TOKEN.key + "\n" + user_input + RUN_TOKEN.key + "\n"
            examples["valid_user_input"] = user_input

            examples["valid_output"] = self._create_sample_model_output()

        return examples

    def to_json(self) -> dict:
        """Converts the entire template to a JSON-serializable dictionary."""
        examples: dict[str, str] = self._create_examples()
        json_dict: dict = {
            "all_combinations": {
                "model_input": self.model_input.to_json(),
                "model_output": self.model_output.to_json()
            },
            "example_usage": examples
        }
        return json_dict
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from model_train_protocol.common.tokens import SpecialToken

# Special tokens shared across the protocol package.
NON_TOKEN: SpecialToken = SpecialToken(value="<NON>", key="<NON>", special="none")
BOS_TOKEN: SpecialToken = SpecialToken(value="<BOS>", key="<BOS>", special="start")
EOS_TOKEN: SpecialToken = SpecialToken(value="<EOS>", key="<EOS>", special="end")
RUN_TOKEN: SpecialToken = SpecialToken(value="<RUN>", key="<RUN>", special="infer")
PAD_TOKEN: SpecialToken = SpecialToken(value="<PAD>", key="<PAD>", special="pad")
UNK_TOKEN: SpecialToken = SpecialToken(value="<UNK>", key="<UNK>", special="unknown")

# TODO: Remove this code when emoji dependency is removed
# Default emoji keys for the special tokens, kept for backward compatibility.
special_token_emoji_map = {
    NON_TOKEN.value: "🫙",
    BOS_TOKEN.value: "🏁",
    EOS_TOKEN.value: "🎬",
    RUN_TOKEN.value: "🏃",
    PAD_TOKEN.value: "🗒",
    UNK_TOKEN.value: "🛑",
}
for token in (NON_TOKEN, BOS_TOKEN, EOS_TOKEN, RUN_TOKEN, PAD_TOKEN, UNK_TOKEN):
    token.key = special_token_emoji_map[token.value]
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
class Guardrail:
|
|
2
|
+
"""
|
|
3
|
+
Defines a guardrails response to bad prompts.
|
|
4
|
+
|
|
5
|
+
Guardrails are set on TokenSets. Each TokenSet can have at most one guardrails, but guardrails can be reused.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
def __init__(self, good_prompt: str, bad_prompt: str, bad_output: str):
|
|
9
|
+
"""
|
|
10
|
+
Initializes a Guardrail.
|
|
11
|
+
:param good_prompt: Description of a good prompt.
|
|
12
|
+
:param bad_prompt: Description of a bad prompt.
|
|
13
|
+
:param bad_output: The output the model should produce when a bad prompt is detected.
|
|
14
|
+
|
|
15
|
+
Example:
|
|
16
|
+
good_prompt="Quote being spoken with 1-20 words",
|
|
17
|
+
bad_prompt="Quote being spoken that is irrelevant and off-topic with 1-20 words",
|
|
18
|
+
output="I have no idea what you're talking about."
|
|
19
|
+
"""
|
|
20
|
+
if not all(isinstance(param, str) for param in [good_prompt, bad_prompt, bad_output]):
|
|
21
|
+
raise TypeError("All parameters must be non-empty strings.")
|
|
22
|
+
|
|
23
|
+
if any(param == "" for param in [good_prompt, bad_prompt, bad_output]):
|
|
24
|
+
raise ValueError("All parameters must be non-empty strings.")
|
|
25
|
+
|
|
26
|
+
self.good_prompt: str = good_prompt
|
|
27
|
+
self.bad_prompt: str = bad_prompt
|
|
28
|
+
self.bad_output: str = bad_output
|
|
29
|
+
self.samples: list[str] = []
|
|
30
|
+
|
|
31
|
+
def add_sample(self, sample: str):
|
|
32
|
+
"""
|
|
33
|
+
Add an example of a bad sample prompt to the guardrails.
|
|
34
|
+
|
|
35
|
+
:param sample: An example of a bad prompt that should trigger the guardrails.
|
|
36
|
+
|
|
37
|
+
Example:
|
|
38
|
+
sample="Tell me a joke about politics."
|
|
39
|
+
"""
|
|
40
|
+
if not isinstance(sample, str) or not sample.strip():
|
|
41
|
+
raise ValueError("Sample prompt must be a non-empty string.")
|
|
42
|
+
|
|
43
|
+
if not all(not char.isdigit() for char in sample):
|
|
44
|
+
raise ValueError("Sample prompt cannot contain digits.")
|
|
45
|
+
self.samples.append(sample)
|
|
46
|
+
|
|
47
|
+
def format_samples(self) -> list[str]:
|
|
48
|
+
"""Return the guardrails as a list of strings for JSON formatting."""
|
|
49
|
+
if len(self.samples) < 3:
|
|
50
|
+
raise ValueError("At least 3 sample prompts are required. Call add_sample() to add more.")
|
|
51
|
+
return [self.bad_output, f"<{self.bad_prompt}>", f"<{self.good_prompt}>", self.samples]
|