erictransformer 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- erictransformer/__init__.py +44 -0
- erictransformer/args/__init__.py +7 -0
- erictransformer/args/eric_args.py +50 -0
- erictransformer/eric_tasks/__init__.py +47 -0
- erictransformer/eric_tasks/args/__init__.py +16 -0
- erictransformer/eric_tasks/args/eric_chat_args.py +21 -0
- erictransformer/eric_tasks/args/eric_generation_args.py +20 -0
- erictransformer/eric_tasks/args/eric_text_classification_args.py +13 -0
- erictransformer/eric_tasks/args/eric_text_to_text_args.py +18 -0
- erictransformer/eric_tasks/chat_stream_handlers/__init__.py +6 -0
- erictransformer/eric_tasks/chat_stream_handlers/args.py +13 -0
- erictransformer/eric_tasks/chat_stream_handlers/default.py +19 -0
- erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py +147 -0
- erictransformer/eric_tasks/chat_stream_handlers/smol.py +81 -0
- erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py +17 -0
- erictransformer/eric_tasks/chat_templates/__init__.py +1 -0
- erictransformer/eric_tasks/chat_templates/convert.py +67 -0
- erictransformer/eric_tasks/eric_chat.py +369 -0
- erictransformer/eric_tasks/eric_chat_mlx.py +278 -0
- erictransformer/eric_tasks/eric_generation.py +243 -0
- erictransformer/eric_tasks/eric_text_classification.py +231 -0
- erictransformer/eric_tasks/eric_text_to_text.py +283 -0
- erictransformer/eric_tasks/inference_engine/__init__.py +3 -0
- erictransformer/eric_tasks/inference_engine/text_classification.py +28 -0
- erictransformer/eric_tasks/misc/__init__.py +11 -0
- erictransformer/eric_tasks/misc/call_utils.py +69 -0
- erictransformer/eric_tasks/misc/get_pad_eos.py +24 -0
- erictransformer/eric_tasks/misc/rag.py +17 -0
- erictransformer/eric_tasks/results/__init__.py +6 -0
- erictransformer/eric_tasks/results/call_results.py +30 -0
- erictransformer/eric_tasks/tok/__init__.py +0 -0
- erictransformer/eric_tasks/tok/tok_functions.py +118 -0
- erictransformer/eric_tracker/__init__.py +1 -0
- erictransformer/eric_tracker/eric_tracker.py +256 -0
- erictransformer/eric_tracker/save_plot.py +422 -0
- erictransformer/eric_transformer.py +534 -0
- erictransformer/eval_models/__init__.py +1 -0
- erictransformer/eval_models/eval_model.py +75 -0
- erictransformer/exceptions/__init__.py +19 -0
- erictransformer/exceptions/eric_exceptions.py +74 -0
- erictransformer/loops/__init__.py +2 -0
- erictransformer/loops/eval_loop.py +111 -0
- erictransformer/loops/train_loop.py +310 -0
- erictransformer/utils/__init__.py +21 -0
- erictransformer/utils/init/__init__.py +5 -0
- erictransformer/utils/init/get_components.py +204 -0
- erictransformer/utils/init/get_device.py +22 -0
- erictransformer/utils/init/get_logger.py +15 -0
- erictransformer/utils/load_from_repo_or_path.py +14 -0
- erictransformer/utils/test/__init__.py +1 -0
- erictransformer/utils/test/debug_hook.py +20 -0
- erictransformer/utils/timer/__init__.py +1 -0
- erictransformer/utils/timer/eric_timer.py +145 -0
- erictransformer/utils/tok_data/__init__.py +8 -0
- erictransformer/utils/tok_data/num_proc.py +15 -0
- erictransformer/utils/tok_data/save_tok_data.py +36 -0
- erictransformer/utils/tok_data/tok_data_to_dataset.py +48 -0
- erictransformer/utils/tok_data/tok_helpers.py +79 -0
- erictransformer/utils/train/__init__.py +6 -0
- erictransformer/utils/train/confirm_optimizer.py +18 -0
- erictransformer/utils/train/create_dir.py +72 -0
- erictransformer/utils/train/get_num_training_steps.py +15 -0
- erictransformer/utils/train/get_precision.py +22 -0
- erictransformer/utils/train/get_tok_data.py +105 -0
- erictransformer/utils/train/resume.py +62 -0
- erictransformer/validator/__init__.py +11 -0
- erictransformer/validator/eric/__init__.py +2 -0
- erictransformer/validator/eric/eval_validator.py +75 -0
- erictransformer/validator/eric/train_validator.py +143 -0
- erictransformer/validator/eric_validator.py +10 -0
- erictransformer/validator/tasks/__init__.py +5 -0
- erictransformer/validator/tasks/chat_validator.py +28 -0
- erictransformer/validator/tasks/gen_validator.py +28 -0
- erictransformer/validator/tasks/task_validator.py +54 -0
- erictransformer/validator/tasks/tc_validator.py +45 -0
- erictransformer/validator/tasks/tt_validator.py +28 -0
- erictransformer/validator/tok/__init__.py +1 -0
- erictransformer/validator/tok/tok_validator.py +23 -0
- erictransformer-0.0.1.dist-info/METADATA +72 -0
- erictransformer-0.0.1.dist-info/RECORD +83 -0
- erictransformer-0.0.1.dist-info/WHEEL +5 -0
- erictransformer-0.0.1.dist-info/licenses/LICENSE +202 -0
- erictransformer-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from erictransformer.validator.eric_validator import EricValidator
|
|
2
|
+
|
|
3
|
+
# tasks
|
|
4
|
+
from erictransformer.validator.eric import EvalValidator, TrainValidator
|
|
5
|
+
from erictransformer.validator.tasks import (
|
|
6
|
+
CHATValidator,
|
|
7
|
+
GENValidator,
|
|
8
|
+
TCValidator,
|
|
9
|
+
TTValidator
|
|
10
|
+
)
|
|
11
|
+
from erictransformer.validator.tok.tok_validator import TokValidator
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from logging import Logger
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
from transformers import PreTrainedModel
|
|
6
|
+
|
|
7
|
+
from erictransformer.args import EricEvalArgs
|
|
8
|
+
from erictransformer.exceptions import EricInputError
|
|
9
|
+
from erictransformer.validator import EricValidator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class EvalValidator(EricValidator):
    """Validate the inputs of an evaluation run.

    ``super().__init__()`` is invoked after all attributes are stored;
    presumably ``EricValidator.__init__`` runs ``validate_init()`` — TODO confirm
    against ``EricValidator``.
    """

    def __init__(
        self,
        logger: Logger,
        eval_path: str = "",
        args: EricEvalArgs = EricEvalArgs(),
        out_dir: str = "",
        model: Union[PreTrainedModel, None] = None,
    ):
        """Store the evaluation inputs, run base-class setup, then resolve the source kind.

        Args:
            logger: Logger used for diagnostics.
            eval_path: Path to an input file or a tokenized-data directory.
            args: Evaluation arguments. A deep copy is stored, mirroring
                ``TrainValidator``, so the default instance (created once when
                the function is defined) is never shared between validators.
            out_dir: Output directory; must be non-empty when ``eval_path`` is set.
            model: Optional model instance.
        """
        import copy  # local import keeps this block self-contained

        self.logger = logger
        self.eval_path = eval_path
        self.args = copy.deepcopy(args)
        self.out_dir = out_dir
        self.model = model

        super().__init__()

        self.eval_source = self.resolve_eval_source()

    def validate_init(self):
        """Run every init-time check; each raises on the first problem it finds."""
        self._validate_file_paths_exist()
        self._validate_args_type()
        self._validate_eval_args_fields()
        self._validate_out_dir()

    def _validate_model_or_resume_path(self):
        """Raise ValueError when no model was supplied.

        NOTE(review): not called by ``validate_init``; kept for API compatibility.
        """
        if self.model is None:
            raise ValueError("Provide model_name when calling init")

    def _validate_file_paths_exist(self):
        """Raise FileNotFoundError unless ``eval_path`` is an existing file or directory."""
        if not os.path.isfile(self.eval_path) and not os.path.isdir(
            self.eval_path
        ):
            raise FileNotFoundError(
                f"Input file or tokenized directory not found: {self.eval_path}"
            )

    def _validate_args_type(self):
        """Raise TypeError unless ``args`` is an ``EricEvalArgs`` instance."""
        if not isinstance(self.args, EricEvalArgs):
            raise TypeError(
                f"Expected args to be of type EvalArgs but got {type(self.args).__name__}"
            )

    def _validate_eval_args_fields(self):
        """Raise EricInputError for a non-positive ``bs`` or a negative ``seed``."""
        args = self.args

        if not isinstance(args.bs, int) or args.bs <= 0:
            raise EricInputError("`bs` must be a positive integer.")

        if not isinstance(args.seed, int) or args.seed < 0:
            raise EricInputError("`seed` must be a non-negative integer.")

    def resolve_eval_source(self) -> Union[str, None]:
        """Classify ``eval_path``.

        Returns:
            ``"file"`` for a regular file, ``"folder"`` for a directory, and
            ``None`` when the path is neither (made explicit; previously an
            implicit fall-through despite the ``-> str`` annotation).
        """
        if os.path.isfile(self.eval_path):
            return "file"
        if os.path.isdir(self.eval_path):
            return "folder"
        return None

    def _validate_out_dir(self):
        """Raise EricInputError when ``eval_path`` is set but ``out_dir`` is empty."""
        if self.eval_path:
            if self.out_dir == "":
                raise EricInputError(
                    "`out_dir` cannot be empty when using `eval_path`."
                )
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import os
|
|
3
|
+
from logging import Logger
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
from transformers import PreTrainedModel
|
|
7
|
+
|
|
8
|
+
from erictransformer.args import EricTrainArgs
|
|
9
|
+
from erictransformer.exceptions import EricInputError
|
|
10
|
+
from erictransformer.validator import EricValidator
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TrainValidator(EricValidator):
    """Validate the inputs of a training run: data paths, model/resume source and train args.

    ``super().__init__()`` is invoked after all attributes are stored;
    presumably ``EricValidator.__init__`` runs ``validate_init()`` — TODO confirm
    against ``EricValidator``.
    """

    def __init__(
        self,
        logger: Logger,
        train_path: str = "",
        args: EricTrainArgs = EricTrainArgs(),
        eval_path: str = "",
        resume_path: str = "",
        model: Union[PreTrainedModel, None] = None,
    ):
        """Store the training inputs, run base-class setup, then resolve source kinds.

        Args:
            logger: Logger used for warnings about auto-corrected arguments.
            train_path: Input file or tokenized-data directory (optional when resuming).
            args: Training arguments. A deep copy is stored because
                ``_validate_train_args`` may rewrite ``eval_steps``/``log_steps``
                and the default instance is shared between calls.
            eval_path: Optional eval file or tokenized-data directory.
            resume_path: Optional checkpoint directory to resume from.
            model: Model instance; required unless ``resume_path`` is given.
        """
        self.train_path = train_path
        self.eval_path = eval_path
        self.resume_path = resume_path
        self.args = copy.deepcopy(args)
        self.logger = logger
        self.model = model

        super().__init__()

        self.train_source = self.resolve_train_source()
        self.eval_source = self.resolve_eval_source()

    def validate_init(self):
        """Run every init-time check; each raises on the first problem it finds."""
        self._validate_required_inputs()
        self._validate_model_or_resume_path()
        self._validate_file_paths_exist()
        self._validate_args_type()
        self._validate_train_args()

    def _validate_required_inputs(self):
        """Require at least one of ``train_path`` / ``resume_path``."""
        if not self.train_path and not self.resume_path:
            raise EricInputError(
                "You must provide at least one of: `train_path` or `resume_path`."
            )

    def _validate_model_or_resume_path(self):
        """Require either a loaded model or a ``resume_path``."""
        if self.model is None and not self.resume_path:
            raise ValueError(
                "Either load a model when calling init or provide resume_path"
            )

    def _validate_file_paths_exist(self):
        """Check that every non-empty path points at something that exists."""
        if self.train_path and not (
            os.path.isfile(self.train_path) or os.path.isdir(self.train_path)
        ):
            raise EricInputError(
                f"Input file or tok dir not found: {self.train_path}"
            )
        if self.eval_path and not (
            os.path.isfile(self.eval_path) or os.path.isdir(self.eval_path)
        ):
            raise EricInputError(f"Eval file not found: {self.eval_path}")
        if self.resume_path and not os.path.isdir(self.resume_path):
            raise EricInputError(
                f"Resume path directory not found: {self.resume_path}"
            )

    def _validate_args_type(self):
        """Raise TypeError unless ``args`` is an ``EricTrainArgs`` instance."""
        if not isinstance(self.args, EricTrainArgs):
            raise TypeError(
                f"Expected args to be of type TrainArgs but got {type(self.args).__name__}"
            )

    def _validate_train_args(self):
        """Validate individual training arguments.

        Raises EricInputError for invalid values. Out-of-range ``eval_steps`` /
        ``log_steps`` are not errors: a warning is logged and the value on the
        copied args is replaced with 256.
        """
        args = self.args

        if args.lr <= 0 or args.lr > 1e-2:
            raise EricInputError("`lr` must be > 0 and <= 1e-2.")

        # The isinstance checks come first so that a non-int value raises
        # EricInputError rather than a TypeError from the comparison.
        if not isinstance(args.epochs, int) or args.epochs <= 0:
            raise EricInputError("`epochs` must be an int >= 1.")

        if not isinstance(args.gas, int) or args.gas <= 0:
            raise EricInputError("`gas` must be an int >= 1.")

        if not isinstance(args.bs, int) or args.bs <= 0:
            raise EricInputError("`bs` must be an int > 0.")

        if not isinstance(args.eval_bs, int) or args.eval_bs < 0:
            raise EricInputError("`eval_bs` must be an int >= 0.")

        # 0 and any value >= 1 are accepted as-is; anything else is reset.
        if args.eval_steps < 1 and args.eval_steps != 0:
            self.logger.warning(
                f"eval_steps must be 0 or an integer >= 1, but got {args.eval_steps}. eval_steps will be set to 256."
            )
            args.eval_steps = 256

        if args.log_steps < 1 and args.log_steps != 0:
            self.logger.warning(
                f"log_steps must be 0 or an integer >= 1, but got {args.log_steps}. log_steps will be set to 256."
            )
            args.log_steps = 256

        # -1 is the sentinel that disables checkpointing.
        if args.checkpoint_steps < 1 and args.checkpoint_steps != -1:
            raise EricInputError(
                "`checkpoint_steps` must be an int >= 1. Use -1 to disable checkpointing."
            )

        if args.lr_sched not in ["constant", "warmup_then_decay"]:
            raise EricInputError(
                "`lr_sched` must be 'constant' or 'warmup_then_decay'."
            )

        if not isinstance(args.seed, int) or args.seed < 0:
            raise EricInputError("`seed` must be a non-negative int.")

        if hasattr(args, "project_name"):
            if not isinstance(args.project_name, str):
                raise EricInputError("`project_name` must be a string.")

        if not isinstance(args.run_name, str):
            raise EricInputError("`run_name` must be a string.")

        if not isinstance(args.save_best, bool):
            raise EricInputError("`save_best` must be a boolean.")

    def resolve_train_source(self) -> Union[str, None]:
        """Classify ``train_path``.

        Returns:
            ``"file"``, ``"folder"``, or ``None`` when the path is neither
            (e.g. an empty ``train_path`` when only ``resume_path`` is given).
        """
        if os.path.isfile(self.train_path):
            return "file"
        if os.path.isdir(self.train_path):
            return "folder"
        return None

    def resolve_eval_source(self) -> str:
        """Classify ``eval_path`` as ``"file"``, ``"folder"`` or ``"no_eval"``."""
        if os.path.isfile(self.eval_path):
            return "file"
        if os.path.isdir(self.eval_path):
            return "folder"
        return "no_eval"
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
from erictransformer.validator.tasks.chat_validator import CHATValidator
|
|
2
|
+
from erictransformer.validator.tasks.gen_validator import GENValidator
|
|
3
|
+
from erictransformer.validator.tasks.task_validator import TaskValidator
|
|
4
|
+
from erictransformer.validator.tasks.tc_validator import TCValidator
|
|
5
|
+
from erictransformer.validator.tasks.tt_validator import TTValidator
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from logging import Logger
|
|
2
|
+
from typing import List, Union
|
|
3
|
+
|
|
4
|
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
|
5
|
+
|
|
6
|
+
from erictransformer.validator.tasks.task_validator import TaskValidator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class CHATValidator(TaskValidator):
    """Validator for the chat task; defers all checks to ``TaskValidator``."""

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None],
        trust_remote_code: bool,
        tokenizer: Union[str, PreTrainedTokenizerBase],
        logger: Logger,
    ):
        # Nothing chat-specific to store; hand every input to the base class.
        super().__init__(
            logger=logger,
            tokenizer=tokenizer,
            trust_remote_code=trust_remote_code,
            model_name=model_name,
        )

    def validate_init(self):
        """Run the shared model/tokenizer checks."""
        return super().validate_init()

    def validate_call(self, texts: Union[List[str], str], args=None):
        """Run the shared per-call checks on ``texts`` and ``args``."""
        return super().validate_call(texts, args=args)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from logging import Logger
|
|
2
|
+
from typing import List, Union
|
|
3
|
+
|
|
4
|
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
|
5
|
+
|
|
6
|
+
from erictransformer.validator.tasks.task_validator import TaskValidator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class GENValidator(TaskValidator):
    """Validator for the text-generation task; defers all checks to ``TaskValidator``."""

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None],
        trust_remote_code: bool,
        tokenizer: Union[str, PreTrainedTokenizerBase],
        logger: Logger,
    ):
        # Generation adds no extra state; forward everything to the base class.
        super().__init__(
            logger=logger,
            tokenizer=tokenizer,
            trust_remote_code=trust_remote_code,
            model_name=model_name,
        )

    def validate_init(self):
        """Run the shared model/tokenizer checks."""
        return super().validate_init()

    def validate_call(self, texts: Union[List[str], str], args=None):
        """Run the shared per-call checks on ``texts`` and ``args``."""
        return super().validate_call(texts, args=args)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
from dataclasses import is_dataclass
|
|
3
|
+
from logging import Logger
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
|
7
|
+
|
|
8
|
+
from erictransformer.validator import EricValidator
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TaskValidator(EricValidator):
    """Common validation for task wrappers.

    Checks the model/tokenizer arguments at init time and the input text and
    args dataclass at call time. Subclasses must override ``validate_init``
    and ``validate_call`` (both are abstract) and typically delegate back here.
    """

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None],
        trust_remote_code: bool,
        tokenizer: Union[str, PreTrainedTokenizerBase],
        logger: Logger,
    ):
        """Store the task inputs, then run base-class setup.

        Attributes are assigned before ``super().__init__()`` — presumably
        because the base constructor triggers validation; TODO confirm.
        """
        self.model_name = model_name
        self.trust_remote_code = trust_remote_code
        self.tokenizer = tokenizer
        self.logger = logger

        super().__init__()

    @abstractmethod
    def validate_init(self):
        """Validate ``model_name`` and ``tokenizer``."""
        self._validate_model_name()
        self._validate_tokenizer()

    def _validate_model_name(self):
        """Raise ValueError unless ``model_name`` is a str, a PreTrainedModel, or None."""
        if not (
            isinstance(self.model_name, (str, PreTrainedModel))
            or self.model_name is None
        ):
            raise ValueError(
                "model_name must be a string, PreTrainedModel instance, or None."
            )

    def _validate_tokenizer(self):
        """Raise ValueError when a non-None ``tokenizer`` is neither a str nor a tokenizer."""
        if self.tokenizer is not None:
            if not isinstance(self.tokenizer, (str, PreTrainedTokenizerBase)):
                # Report the offending value through the logger rather than
                # a stray print() to stdout.
                if self.logger is not None:
                    self.logger.error("Invalid tokenizer: %s", self.tokenizer)
                raise ValueError(
                    "tokenizer must be a string, PreTrainedTokenizer, or PreTrainedTokenizerFast."
                )

    @abstractmethod
    def validate_call(self, text: str, args=None):
        """Validate per-call inputs.

        Raises:
            ValueError: if ``text`` is not a str, or ``args`` is not a dataclass
                (this includes the default ``None``).
        """
        if not isinstance(text, str):
            raise ValueError('"text" must be a string')

        if not is_dataclass(args):
            raise ValueError('"args" must be a dataclass')
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from logging import Logger
|
|
2
|
+
from typing import List, Union, Optional
|
|
3
|
+
|
|
4
|
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
|
5
|
+
|
|
6
|
+
from erictransformer.validator.tasks.task_validator import TaskValidator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TCValidator(TaskValidator):
    """Validator for text classification; also checks the optional ``labels`` list."""

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None],
        trust_remote_code: bool,
        tokenizer: Union[str, PreTrainedTokenizerBase],
        logger: Logger,
        labels: Optional[List[str]] = None
    ):
        # ``labels`` is stored before ``super().__init__()`` because
        # ``validate_init`` reads it — presumably the base constructor triggers
        # validation; TODO confirm.
        self.labels = labels
        super().__init__(
            model_name=model_name,
            trust_remote_code=trust_remote_code,
            tokenizer=tokenizer,
            logger=logger,
        )

    def validate_init(self):
        """Check ``labels`` (when given), then run the shared model/tokenizer checks.

        Raises:
            ValueError: if ``labels`` is not a list or contains a non-string.
        """
        if self.labels is not None:
            # isinstance (not ``type(...) is``) so list/str subclasses are accepted.
            if not isinstance(self.labels, list):
                raise ValueError(
                    "self.labels is not a list of strings"
                )

            if not all(isinstance(label, str) for label in self.labels):
                raise ValueError(
                    "self.labels is not a list of strings."
                )

        super().validate_init()

    def validate_call(self, texts: Union[List[str], str], args=None):
        """Run the shared per-call checks on ``texts`` and ``args``."""
        super().validate_call(texts, args)
|
|
45
|
+
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from logging import Logger
|
|
2
|
+
from typing import Union
|
|
3
|
+
|
|
4
|
+
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
|
5
|
+
|
|
6
|
+
from erictransformer.validator.tasks.task_validator import TaskValidator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class TTValidator(TaskValidator):
    """Validator for the text-to-text task; defers all checks to ``TaskValidator``."""

    def __init__(
        self,
        model_name: Union[str, PreTrainedModel, None],
        trust_remote_code: bool,
        tokenizer: Union[str, PreTrainedTokenizerBase],
        logger: Logger,
    ):
        # No task-specific state; pass the inputs straight to the base class.
        super().__init__(
            logger=logger,
            tokenizer=tokenizer,
            trust_remote_code=trust_remote_code,
            model_name=model_name,
        )

    def validate_init(self):
        """Run the shared model/tokenizer checks."""
        return super().validate_init()

    def validate_call(self, text: str, args=None):
        """Run the shared per-call checks on ``text`` and ``args``."""
        return super().validate_call(text, args=args)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from erictransformer.validator.tok.tok_validator import TokValidator
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from erictransformer.args import TokArgs
|
|
2
|
+
from erictransformer.validator import EricValidator
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TokValidator(EricValidator):
    """Validator for tokenization inputs: input data, output directory and ``TokArgs``."""

    def __init__(self, input_data: str, out_dir: str, args: TokArgs):
        self.input_data = input_data
        self.out_dir = out_dir
        self.args = args
        super().__init__()

    def validate_init(self):
        """Run init-time checks.

        NOTE(review): only ``_validate_input_data`` (currently a no-op) runs;
        ``_validate_out_dir`` and ``_validate_args`` are not invoked here —
        confirm whether that is intended.
        """
        self._validate_input_data()

    def _validate_input_data(self):
        # Placeholder: no input-data checks are implemented yet.
        pass

    def _validate_out_dir(self):
        # Placeholder: no output-directory checks are implemented yet.
        pass

    def _validate_args(self):
        """Raise ValueError unless ``args.procs`` is an int >= 0."""
        procs = self.args.procs
        # ``or`` (not ``and``): reject non-ints outright, and only compare ints.
        # With ``and``, negative ints were never rejected and non-ints hit a
        # TypeError in the comparison.
        if not isinstance(procs, int) or procs < 0:
            raise ValueError("procs must be an integer greater than or equal to 0")
|
|
@@ -0,0 +1,72 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: erictransformer
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Local fine-tuning, pre-training and inference for LLMs
|
|
5
|
+
License: Apache 2.0
|
|
6
|
+
Project-URL: Repository, https://github.com/EricFillion/erictransformer
|
|
7
|
+
Keywords: transformer,fine-tuning,pretraining,training,AI,LLM,deep-learning,generative-ai
|
|
8
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Description-Content-Type: text/markdown
|
|
18
|
+
License-File: LICENSE
|
|
19
|
+
Requires-Dist: torch>=2.2
|
|
20
|
+
Requires-Dist: tqdm>=4.66.3
|
|
21
|
+
Requires-Dist: transformers<5.0.0,>=4.57.3
|
|
22
|
+
Requires-Dist: huggingface-hub<1.0,>=0.34.0
|
|
23
|
+
Requires-Dist: datasets<5.0.0,>=4.0.0
|
|
24
|
+
Requires-Dist: sentencepiece
|
|
25
|
+
Requires-Dist: protobuf
|
|
26
|
+
Requires-Dist: tokenizers<0.23.0,>=0.22.0
|
|
27
|
+
Requires-Dist: matplotlib<4.0.0,>=3.9.0
|
|
28
|
+
Requires-Dist: numpy<3.0.0,>=2.0.0
|
|
29
|
+
Requires-Dist: sentence-transformers<6.0.0,>=5.0.0
|
|
30
|
+
Requires-Dist: mlx-lm==0.29.1; platform_system == "Darwin" and platform_machine == "arm64"
|
|
31
|
+
Requires-Dist: ericsearch<1.0.0
|
|
32
|
+
Dynamic: license-file
|
|
33
|
+
|
|
34
|
+
<h1 align="center">
|
|
35
|
+
Eric Transformer
|
|
36
|
+
</h1>
|
|
37
|
+
|
|
38
|
+
<p align="center">
|
|
39
|
+
<a href="https://opensource.org/licenses/Apache-2.0">
|
|
40
|
+
<img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License: Apache-2.0" height="20">
|
|
41
|
+
</a>
|
|
42
|
+
|
|
43
|
+
</p>
|
|
44
|
+
|
|
45
|
+
<p align="center">
|
|
46
|
+
<strong><a href="https://ericfillion.github.io/erictransformer/">https://ericfillion.github.io/erictransformer/</a></strong>
|
|
47
|
+
</p>
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
Local pre-training, fine-tuning and inference for LLMs.
|
|
51
|
+
|
|
52
|
+
- Format your text data in JSONL and then use a few lines of code to train models.
|
|
53
|
+
- Full-parameter training of GPT-OSS-20b on a single H200.
|
|
54
|
+
- Use Apple's new MLX-LM framework for fast inference. Run GPT-OSS-120b locally.
|
|
55
|
+
- Enable RAG powered by [Eric Search](https://github.com/EricFillion/ericsearch).
|
|
56
|
+
- Local experiment tracking that displays charts and metrics.
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
## Install
|
|
60
|
+
```sh
|
|
61
|
+
pip install erictransformer
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
[Documentation](https://ericfillion.github.io/erictransformer)
|
|
65
|
+
|
|
66
|
+
## Maintainers
|
|
67
|
+
- [Eric Fillion](https://github.com/ericfillion) Lead Maintainer
|
|
68
|
+
- [Ted Brownlow](https://github.com/ted537) Maintainer
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
## Contributing
|
|
72
|
+
We are currently not accepting contributions.
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
erictransformer/__init__.py,sha256=Fb88UVYjpX6BQORqjpk4QjG0R26coK-2At4cPpIUKnc,971
|
|
2
|
+
erictransformer/eric_transformer.py,sha256=WBxNzE3G1YpK5npzRZJCzZZ03yUYQPfs2PSWGmtvf-A,17818
|
|
3
|
+
erictransformer/args/__init__.py,sha256=j-6eXRgikyNH8VTv2KwT1whtgQjMaWNBKrXTlvlqrmw,127
|
|
4
|
+
erictransformer/args/eric_args.py,sha256=q0XTfYzuJDyWYTKQv950tYl4tZ_yRiIo3jBnH417F2g,1086
|
|
5
|
+
erictransformer/eric_tasks/__init__.py,sha256=xnQKI4kPlJXUtSzT3LGRVtI7YkX78nz_TGqFRdO9BIA,1109
|
|
6
|
+
erictransformer/eric_tasks/eric_chat.py,sha256=QsrYXob4Eu3g7OZUn_3Rq-64Ntr5Qus-AsGfEW5CJaQ,12241
|
|
7
|
+
erictransformer/eric_tasks/eric_chat_mlx.py,sha256=Kzfyr9uu4lCcIbvWvvnkv_r5nZZTtOr34cC0nBzYh2M,9138
|
|
8
|
+
erictransformer/eric_tasks/eric_generation.py,sha256=C4QYkxWud4fQCdVH7pwKGAK6Iuu0oVgVVqqtoWgDfck,7538
|
|
9
|
+
erictransformer/eric_tasks/eric_text_classification.py,sha256=q6VTTVvPRICU57nM9-WHi30J7zt6wxmR6OhjTs-RmHo,6896
|
|
10
|
+
erictransformer/eric_tasks/eric_text_to_text.py,sha256=9V69Gth2BAWaaQznNJEp2Z273IzGZU4DqHFt_Fzo86c,8845
|
|
11
|
+
erictransformer/eric_tasks/args/__init__.py,sha256=xinzUPZreKbjPX2UVX542cW4jqz86f4eKOs3A5UwmN4,411
|
|
12
|
+
erictransformer/eric_tasks/args/eric_chat_args.py,sha256=6yX7fK5nXP3IzvcOMAN269iLxETmv4iPf1cbnrFqOOc,467
|
|
13
|
+
erictransformer/eric_tasks/args/eric_generation_args.py,sha256=l5Hh68gqNtihBpuPdkvZNqh0Fr68jLenpe0ZUM4bUgw,359
|
|
14
|
+
erictransformer/eric_tasks/args/eric_text_classification_args.py,sha256=7IvdbQawiU9AToIlw1QaeJ2ubw4lL1fM1D54sZ-rclU,225
|
|
15
|
+
erictransformer/eric_tasks/args/eric_text_to_text_args.py,sha256=EgfDZBlGQpugjoMTA3B-Q5C2TBV7J_6vixxdKIsH2EA,355
|
|
16
|
+
erictransformer/eric_tasks/chat_stream_handlers/__init__.py,sha256=KEHCWaKXUT4oQqsYU82pmYWveihxo82wen8-gT8XXmg,347
|
|
17
|
+
erictransformer/eric_tasks/chat_stream_handlers/args.py,sha256=UrGSzXBszkDOIdmzgd7dWcqBQT18E0EYYLdOWgOLlRA,283
|
|
18
|
+
erictransformer/eric_tasks/chat_stream_handlers/default.py,sha256=zvZHmwOADiRNyO9yVW3GqssF7YcGPrD58K9B63fBsDY,595
|
|
19
|
+
erictransformer/eric_tasks/chat_stream_handlers/gpt_oss.py,sha256=fFhTgyTINhpQRv0YkWIRLhFZ4TstYG-TuSahJ_-B76c,5002
|
|
20
|
+
erictransformer/eric_tasks/chat_stream_handlers/smol.py,sha256=imRTmhkuSG9BmbizoiO8uADZLOxPVkMeTZhBqCsbSkM,2563
|
|
21
|
+
erictransformer/eric_tasks/chat_stream_handlers/stream_handler.py,sha256=faLKcmXe3SIECIyiaUDt9nusuoZg-urRxSgdY41nsU0,521
|
|
22
|
+
erictransformer/eric_tasks/chat_templates/__init__.py,sha256=R92aJQsbCUf9HZUWuRxoO4Q-ReAwVv9NOm-GC2vqpo0,77
|
|
23
|
+
erictransformer/eric_tasks/chat_templates/convert.py,sha256=m2wOX3LrXaPBLtdMaeBfjiUx_csJTkeBvjzHcjWj25s,1749
|
|
24
|
+
erictransformer/eric_tasks/inference_engine/__init__.py,sha256=mX_bH2xjYmMWzEKx5B3fxbTz59TZYshbv6S8FqrDLbM,98
|
|
25
|
+
erictransformer/eric_tasks/inference_engine/text_classification.py,sha256=-CtRisiagoOoSt5t0cp1Mzw9mFJS2BczNjq5Bx54b0g,721
|
|
26
|
+
erictransformer/eric_tasks/misc/__init__.py,sha256=yZRzfOd8D2lXQ-2GP978MKE8j6wn_167lZJ-qn_Q1mc,330
|
|
27
|
+
erictransformer/eric_tasks/misc/call_utils.py,sha256=2g_k_feOjOxd14VtIDdKQV9k-fPRz_e2v9kgXhajpM8,1697
|
|
28
|
+
erictransformer/eric_tasks/misc/get_pad_eos.py,sha256=UiHPGaUYngmgTQt1xpGU3ivfyZLevY1Mgod221Zn1as,818
|
|
29
|
+
erictransformer/eric_tasks/misc/rag.py,sha256=2V74eLA4Nro-eFM_9RhnD7oq6_1mTy8EFyXSK3JgJNE,490
|
|
30
|
+
erictransformer/eric_tasks/results/__init__.py,sha256=k1uKA-tIHW0RovDkFOkgbwKO-XSc_7XeDfUOJc9F9Ts,122
|
|
31
|
+
erictransformer/eric_tasks/results/call_results.py,sha256=FKX418TiBZfWRv72kjAoaVyRj3FYFiIJ9MMkxloE2z4,504
|
|
32
|
+
erictransformer/eric_tasks/tok/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
33
|
+
erictransformer/eric_tasks/tok/tok_functions.py,sha256=aYm3josXkQozB4Wfbz38AK5AN4MTDak7_t2PmPAm3fc,3381
|
|
34
|
+
erictransformer/eric_tracker/__init__.py,sha256=wZR9Sc-hOAK6jfhu1tEUtRxKEDXQy391UMxmc78WVPI,66
|
|
35
|
+
erictransformer/eric_tracker/eric_tracker.py,sha256=bqGQfhCHcjDfAKjOs7-8RWhMOCDmUqGQoCBlDwBDWM8,8256
|
|
36
|
+
erictransformer/eric_tracker/save_plot.py,sha256=uCIElKeWFOqRIptF8EsW7826hTbJZGS_VtFm6VDoYGk,12921
|
|
37
|
+
erictransformer/eval_models/__init__.py,sha256=UESPUhXT5I4NesZ93AGA2jr9ACc7gCunj1R8n54QMaE,82
|
|
38
|
+
erictransformer/eval_models/eval_model.py,sha256=1S3sxeKitlw7uOyylPKBhzrhn4qAF9-UFOfqMkOw1hI,2493
|
|
39
|
+
erictransformer/exceptions/__init__.py,sha256=CgRMjDIV9_00WdvUMG1I8LupESCXwx5NLaEZvApohMo,440
|
|
40
|
+
erictransformer/exceptions/eric_exceptions.py,sha256=Pdx4JflFb1VnGpTOLpxl2ARpwE5yqvPtp7vfR7D5yK8,1078
|
|
41
|
+
erictransformer/loops/__init__.py,sha256=wOj26Ow4N5jAwCcNfhRfU9X3MdZkgD53L7HNbBUulaw,135
|
|
42
|
+
erictransformer/loops/eval_loop.py,sha256=O05CQpt4GkiW8520LlZTn2M18mlpvVQvzznOu9cMZOY,3610
|
|
43
|
+
erictransformer/loops/train_loop.py,sha256=Hv9R0_0C-xmS5cr8YmsvLfAJpfchn2JTUMHBuB0wREE,12779
|
|
44
|
+
erictransformer/utils/__init__.py,sha256=v7VUlR_xOcG0sf7Do-M4MvK0t1fGqlR7g7UAS0xUzvQ,685
|
|
45
|
+
erictransformer/utils/load_from_repo_or_path.py,sha256=gl8mig_wRFiLKKKt9W5dOmkS2-Y8u7RbOMlGCefuH2U,435
|
|
46
|
+
erictransformer/utils/init/__init__.py,sha256=D0CXoiL69UyCf4FoXNS1pprSVV3o1XsUueD94SH92jg,140
|
|
47
|
+
erictransformer/utils/init/get_components.py,sha256=cdzJ9nzIQTFkDWPkyEXwzR-J2fbKnLF8iRkW_DNiPJo,5924
|
|
48
|
+
erictransformer/utils/init/get_device.py,sha256=76OawLdU7_GklPKMDvkTVkFyYTh-pm0xOOTx5CzwWOA,526
|
|
49
|
+
erictransformer/utils/init/get_logger.py,sha256=oqSv9ClYjaSyDAkBcfrlEJwBsu6y7ae4-zQ0lK0Y7sk,433
|
|
50
|
+
erictransformer/utils/test/__init__.py,sha256=FBlwdf1LAiB1rGdGfXhjTZS2EH-9v4wofJ-vSnw1_SI,34
|
|
51
|
+
erictransformer/utils/test/debug_hook.py,sha256=jtnDUolEMTUt0DgwlFQGu2_XHjCbK4njz4afi4J3koM,407
|
|
52
|
+
erictransformer/utils/timer/__init__.py,sha256=Vh-OoSsW7DXhGgtg6OgrB4ClLDf_Unoe0p7l24wygvI,61
|
|
53
|
+
erictransformer/utils/timer/eric_timer.py,sha256=9J1w9pUhKLUYzFDV9TTPzRrIItsdVMZ-AbIzi0bI8PI,4848
|
|
54
|
+
erictransformer/utils/tok_data/__init__.py,sha256=TNNYUOOpN7fd-_hPc8CvTgNeHBli-NrYv2NroFCi3uw,358
|
|
55
|
+
erictransformer/utils/tok_data/num_proc.py,sha256=M9Gv7YJrKg51uyS3M5UO4PLXldyP8tRyDsSoCQybt1k,320
|
|
56
|
+
erictransformer/utils/tok_data/save_tok_data.py,sha256=BT7IVCPKRSIVgh9q9rlotWc3ESbEPTDC7ewRb4qa8zo,1122
|
|
57
|
+
erictransformer/utils/tok_data/tok_data_to_dataset.py,sha256=GOYje9W9JddXTh0dZZYwfvOlIKcFM0usjmQDuT2XW_Q,1540
|
|
58
|
+
erictransformer/utils/tok_data/tok_helpers.py,sha256=mCSdz7ujRVm5p98uDIRAhEmILQGxfpXImnDeaUJWI8k,2824
|
|
59
|
+
erictransformer/utils/train/__init__.py,sha256=eJqrwQXzP2-UuaztQv4nyEOM8w5fUZjt9adf9G_zk6I,431
|
|
60
|
+
erictransformer/utils/train/confirm_optimizer.py,sha256=PNqWlEj02nOzpR7tgYI4J8-SebxGrBkYm-S0lSt8VbY,560
|
|
61
|
+
erictransformer/utils/train/create_dir.py,sha256=0SCgCHZS3AxslLEKpIlevIVUZ4OVtBMu9rpUiZQRupI,1806
|
|
62
|
+
erictransformer/utils/train/get_num_training_steps.py,sha256=zCqYEejdYfGDogIY5hxSkEIxXnEJ5nCTweRO5bOOkFA,248
|
|
63
|
+
erictransformer/utils/train/get_precision.py,sha256=sNHL91eeE_bK5-uBwSivucfH_x0RMIUMXpF1jjdJl40,463
|
|
64
|
+
erictransformer/utils/train/get_tok_data.py,sha256=6LBr0hZYhHc8k0YkRCu5xxWD3_ulmDpgOMWEWZfimHM,3162
|
|
65
|
+
erictransformer/utils/train/resume.py,sha256=ybrt33dg4i_Rs-a8w0Timc6t6oD3XwDYjJaAH6MJ_Ns,1851
|
|
66
|
+
erictransformer/validator/__init__.py,sha256=9jF2ReL-XvDNy6PCcn9LhGNBUEDR8TjXAm7EjD8Y3Cs,336
|
|
67
|
+
erictransformer/validator/eric_validator.py,sha256=XEngi3mw9_yJmeuQDpsEehIPIkEG-LcjpKTvrwNZX9c,180
|
|
68
|
+
erictransformer/validator/eric/__init__.py,sha256=WLUjN4TMRSspd3uhhkUP3et8sog8EAXReqLbSqxwF88,146
|
|
69
|
+
erictransformer/validator/eric/eval_validator.py,sha256=7_w48dxEv4cm8c33BODyYfnbGCjaXE9uMqoflYUJCyk,2314
|
|
70
|
+
erictransformer/validator/eric/train_validator.py,sha256=yD3w-mATs7UI0xdPi1ZV0-NsY91v-P3p8x6CMv8tzmE,5083
|
|
71
|
+
erictransformer/validator/tasks/__init__.py,sha256=4RnF6CxTgg2ZeTMuaSQa6trnvfQPg-vVkTKESJvsuec,355
|
|
72
|
+
erictransformer/validator/tasks/chat_validator.py,sha256=3b40PZ3KMnHYlZIPlsFisz4bHPyjNFTOE67Tf5twGSo,800
|
|
73
|
+
erictransformer/validator/tasks/gen_validator.py,sha256=vOnL9Ui7Y9Arj_5f5RjURvGK9qr-isXoZt7Pu9YO_dI,799
|
|
74
|
+
erictransformer/validator/tasks/task_validator.py,sha256=pHN5yaFG63SHrWdydTn66eVfjakSKgtno_tczzA8Tuo,1684
|
|
75
|
+
erictransformer/validator/tasks/tc_validator.py,sha256=phJ5fTYvqEdDPpdd_mf63Whu3SYO95kNP83TyRg_4L4,1283
|
|
76
|
+
erictransformer/validator/tasks/tt_validator.py,sha256=8SKDfk1fB2UI4qRQhpav_g7CO1ruLKJyz_ZR1rM6Dhw,772
|
|
77
|
+
erictransformer/validator/tok/__init__.py,sha256=qBZpErkJnHdejVu8evycMNJPNdQIf7wZIDb4Iz9ec_A,69
|
|
78
|
+
erictransformer/validator/tok/tok_validator.py,sha256=dHi0tafMChZ33mjclqp3vi8MeXWv_XCsZLCm27BYzUA,671
|
|
79
|
+
erictransformer-0.0.1.dist-info/licenses/LICENSE,sha256=vFDZ93WFV6P9Yw5FB4xUX97Gbj66EIR4d--MfJS795s,11342
|
|
80
|
+
erictransformer-0.0.1.dist-info/METADATA,sha256=OHlJqz_reJ3KlgXZUP61zTYmiT4zi-T24SgY1gaGwwM,2540
|
|
81
|
+
erictransformer-0.0.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
82
|
+
erictransformer-0.0.1.dist-info/top_level.txt,sha256=lWsXEJsIM2bx2pyk0dteK44Eo6UwQCxCgRMKwnZCKyo,16
|
|
83
|
+
erictransformer-0.0.1.dist-info/RECORD,,
|