common-ai 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- common_ai-0.1.2/PKG-INFO +105 -0
- common_ai-0.1.2/README.md +82 -0
- common_ai-0.1.2/pyproject.toml +32 -0
- common_ai-0.1.2/src/common_ai/__init__.py +0 -0
- common_ai-0.1.2/src/common_ai/app.yaml +3 -0
- common_ai-0.1.2/src/common_ai/component/early_stopping.yaml +2 -0
- common_ai-0.1.2/src/common_ai/component/generator.yaml +1 -0
- common_ai-0.1.2/src/common_ai/component/initializer.yaml +1 -0
- common_ai-0.1.2/src/common_ai/component/logger.yaml +1 -0
- common_ai-0.1.2/src/common_ai/component/lr_scheduler.yaml +3 -0
- common_ai-0.1.2/src/common_ai/component/optimizer.yaml +3 -0
- common_ai-0.1.2/src/common_ai/component/profiler.yaml +3 -0
- common_ai-0.1.2/src/common_ai/config.py +198 -0
- common_ai-0.1.2/src/common_ai/dataset.py +18 -0
- common_ai-0.1.2/src/common_ai/early_stopping.py +38 -0
- common_ai-0.1.2/src/common_ai/explain.yaml +7 -0
- common_ai-0.1.2/src/common_ai/generator.py +42 -0
- common_ai-0.1.2/src/common_ai/gradio_fn.py +42 -0
- common_ai-0.1.2/src/common_ai/hpo.py +238 -0
- common_ai-0.1.2/src/common_ai/hpo.yaml +11 -0
- common_ai-0.1.2/src/common_ai/hta.py +305 -0
- common_ai-0.1.2/src/common_ai/hta.yaml +3 -0
- common_ai-0.1.2/src/common_ai/infer.yaml +6 -0
- common_ai-0.1.2/src/common_ai/inference.py +10 -0
- common_ai-0.1.2/src/common_ai/initializer.py +93 -0
- common_ai-0.1.2/src/common_ai/logger.py +21 -0
- common_ai-0.1.2/src/common_ai/lr_scheduler.py +111 -0
- common_ai-0.1.2/src/common_ai/metric.py +20 -0
- common_ai-0.1.2/src/common_ai/model.py +12 -0
- common_ai-0.1.2/src/common_ai/non_causality_hyena.py +304 -0
- common_ai-0.1.2/src/common_ai/optimizer.py +106 -0
- common_ai-0.1.2/src/common_ai/profiler.py +82 -0
- common_ai-0.1.2/src/common_ai/protein_bert.py +449 -0
- common_ai-0.1.2/src/common_ai/prs.py +643 -0
- common_ai-0.1.2/src/common_ai/shap.py +390 -0
- common_ai-0.1.2/src/common_ai/standalone_hyena.py +293 -0
- common_ai-0.1.2/src/common_ai/test.py +210 -0
- common_ai-0.1.2/src/common_ai/test.yaml +5 -0
- common_ai-0.1.2/src/common_ai/train.py +472 -0
- common_ai-0.1.2/src/common_ai/train.yaml +30 -0
- common_ai-0.1.2/src/common_ai/upload.py +95 -0
- common_ai-0.1.2/src/common_ai/upload.yaml +11 -0
- common_ai-0.1.2/src/common_ai/utils.py +125 -0
common_ai-0.1.2/PKG-INFO
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: common-ai
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Add your description here
|
|
5
|
+
Author: ljw
|
|
6
|
+
Author-email: ljw <ljw2017@sjtu.edu.cn>
|
|
7
|
+
Requires-Dist: deeplift
|
|
8
|
+
Requires-Dist: einops
|
|
9
|
+
Requires-Dist: gradio
|
|
10
|
+
Requires-Dist: holistictraceanalysis
|
|
11
|
+
Requires-Dist: jsonargparse
|
|
12
|
+
Requires-Dist: opt-einsum
|
|
13
|
+
Requires-Dist: optuna
|
|
14
|
+
Requires-Dist: pandas[hdf5]
|
|
15
|
+
Requires-Dist: plotnine
|
|
16
|
+
Requires-Dist: shap
|
|
17
|
+
Requires-Dist: tbparse
|
|
18
|
+
Requires-Dist: transformers
|
|
19
|
+
Requires-Dist: torch
|
|
20
|
+
Requires-Dist: datasets
|
|
21
|
+
Requires-Python: >=3.14
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
|
|
24
|
+
# Introduction
|
|
25
|
+
|
|
26
|
+
This repository contains AI libraries commonly used for all my AI projects.
|
|
27
|
+
|
|
28
|
+
# Train
|
|
29
|
+
|
|
30
|
+
The `MyTrain` class trains subclasses of the Hugging Face `PreTrainedModel`. `model.state_dict` and `model.load_state_dict` must be consistent with each other.
|
|
31
|
+
|
|
32
|
+
```mermaid
|
|
33
|
+
---
|
|
34
|
+
title: MyTrain.__call__
|
|
35
|
+
---
|
|
36
|
+
flowchart TD
|
|
37
|
+
INST[instantiate model and random generator] --> TRAININSTMETRICS[instantiate metrics] --> MODE{{evaluation only?}}
|
|
38
|
+
MODE -- yes --> EVALMODEL[<code>MyTrain.my_eval_model</code>]
|
|
39
|
+
|
|
40
|
+
subgraph EVALMODEL[<code>MyTrain.my_eval_model</code>]
|
|
41
|
+
direction TB
|
|
42
|
+
EVALLOOP[eval loop]
|
|
43
|
+
end
|
|
44
|
+
|
|
45
|
+
subgraph EVALLOOP[eval loop]
|
|
46
|
+
direction TB
|
|
47
|
+
CHECKCONSISTENCY[check config consistency] --> EVALLOADCHECKPOINT[load checkpoint for model and generator] --> EVALDEVICE[set model device] --> EVALDATALOADER[setup data loader] --> EVALEPOCHBRANCH{{implement <code>model.my_eval_epoch</code>?}}
|
|
48
|
+
EVALEPOCHBRANCH -- yes --> CUSTOMEVAL[<code>model.my_eval_epoch</code>]
|
|
49
|
+
EVALEPOCHBRANCH -- no --> COMMONEVAL[<code>MyTrain.my_eval_epoch</code>]
|
|
50
|
+
CUSTOMEVAL --> UPDATECONFIGPERFORM[update configuration]
|
|
51
|
+
COMMONEVAL --> UPDATECONFIGPERFORM
|
|
52
|
+
end
|
|
53
|
+
|
|
54
|
+
MODE -- no --> COMMONTRAIN[<code>MyTrain.my_train_model</code>]
|
|
55
|
+
|
|
56
|
+
subgraph COMMONTRAIN[<code>MyTrain.my_train_model</code>]
|
|
57
|
+
direction TB
|
|
58
|
+
CONTINUETRAIN{{last epoch is -1?}}
|
|
59
|
+
CONTINUETRAIN -- yes --> HASINIT{{implement <code>model.my_initialize_model</code>?}} -- yes --> CUSTOMINIT[<code>model.my_initialize_model</code>?]
|
|
60
|
+
HASINIT -- no --> INITWEIGHT[initialize model weights by <code>my_initializer</code>]
|
|
61
|
+
CONTINUETRAIN -- no --> TRAINLOADCHECK[load checkpoint for model and random generator]
|
|
62
|
+
CUSTOMINIT --> TRAINDEVICE[set model device]
|
|
63
|
+
INITWEIGHT --> TRAINDEVICE
|
|
64
|
+
TRAINLOADCHECK --> TRAINDEVICE
|
|
65
|
+
TRAINDEVICE --> INSTOPSC[instantiate optimizer and lr scheduler] --> CONTINUETRAIN2{{last epoch is -1?}} -- no --> TRAINCHECKOPSC[load checkpoint for optimizer and lr scheduler] --> SETUPOPSC[setup optimizer and lr_scheduler]
|
|
66
|
+
CONTINUETRAIN2{{last epoch is -1?}} -- yes --> SETUPOPSC
|
|
67
|
+
SETUPOPSC --> TRAINDATALOADER[setup data loader] --> INSTEARLYSTOP[instantiate early stopping] --> TRAINLOOP[train loop]
|
|
68
|
+
end
|
|
69
|
+
|
|
70
|
+
subgraph TRAINLOOP[train loop]
|
|
71
|
+
direction TB
|
|
72
|
+
M{{implement <code>model.my_train_epoch</code>?}}
|
|
73
|
+
M -- yes --> N[<code>model.my_train_epoch</code>]
|
|
74
|
+
M -- no --> O[<code>MyTrain.my_train_epoch</code>]
|
|
75
|
+
N --> P{{implement <code>model.my_eval_epoch</code>?}}
|
|
76
|
+
O --> P
|
|
77
|
+
P -- yes --> Q[<code>model.my_eval_epoch</code>]
|
|
78
|
+
P -- no --> R[<code>MyTrain.my_eval_epoch</code>]
|
|
79
|
+
Q --> UPDATELR[update learning rate]
|
|
80
|
+
R --> UPDATELR
|
|
81
|
+
UPDATELR --> TRAINSAVE[save epoch configuration and checkpoint] --> EARLYSTOP[check early stopping]
|
|
82
|
+
end
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
# Test
|
|
86
|
+
|
|
87
|
+
The `MyTest` class tests subclasses of the Hugging Face `PreTrainedModel`. `MyTest` loads the epoch saved by `MyTrain`. If `model.my_train_model` is implemented, then the corresponding `model.my_load_model` must be implemented as well.
|
|
88
|
+
|
|
89
|
+
```mermaid
|
|
90
|
+
---
|
|
91
|
+
title: MyTest.__call__
|
|
92
|
+
---
|
|
93
|
+
flowchart TD
|
|
94
|
+
INSTMODEL[instantiate model and random generator] --> INSTMETRIC[instantiate metrics] --> LOADCHECK[load checkpoint for model and random generator] --> TESTDEVICE[set model device] --> TESTDATA[setup data loader] --> TESTMODEL[test model] --> TESTSAVE[save metrics]
|
|
95
|
+
```
|
|
96
|
+
|
|
97
|
+
# Metric
|
|
98
|
+
|
|
99
|
+
The metric classes should implement three methods.
|
|
100
|
+
1. `__init__` initializes the parameters and the metric state.
|
|
101
|
+
2. `step` processes one batch. It receives:
|
|
102
|
+
- `df`: the data frame returned by the model's `eval_output` method.
|
|
103
|
+
- `examples`: the examples in the dataset.
|
|
104
|
+
- `batch`: the batch returned by the model's `data_collator`.
|
|
105
|
+
3. `epoch` accumulates all batch results, reinitializes the metric state, and returns the final metric.
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
# Introduction
|
|
2
|
+
|
|
3
|
+
This repository contains AI libraries commonly used for all my AI projects.
|
|
4
|
+
|
|
5
|
+
# Train
|
|
6
|
+
|
|
7
|
+
The `MyTrain` class trains subclasses of the Hugging Face `PreTrainedModel`. `model.state_dict` and `model.load_state_dict` must be consistent with each other.
|
|
8
|
+
|
|
9
|
+
```mermaid
|
|
10
|
+
---
|
|
11
|
+
title: MyTrain.__call__
|
|
12
|
+
---
|
|
13
|
+
flowchart TD
|
|
14
|
+
INST[instantiate model and random generator] --> TRAININSTMETRICS[instantiate metrics] --> MODE{{evaluation only?}}
|
|
15
|
+
MODE -- yes --> EVALMODEL[<code>MyTrain.my_eval_model</code>]
|
|
16
|
+
|
|
17
|
+
subgraph EVALMODEL[<code>MyTrain.my_eval_model</code>]
|
|
18
|
+
direction TB
|
|
19
|
+
EVALLOOP[eval loop]
|
|
20
|
+
end
|
|
21
|
+
|
|
22
|
+
subgraph EVALLOOP[eval loop]
|
|
23
|
+
direction TB
|
|
24
|
+
CHECKCONSISTENCY[check config consistency] --> EVALLOADCHECKPOINT[load checkpoint for model and generator] --> EVALDEVICE[set model device] --> EVALDATALOADER[setup data loader] --> EVALEPOCHBRANCH{{implement <code>model.my_eval_epoch</code>?}}
|
|
25
|
+
EVALEPOCHBRANCH -- yes --> CUSTOMEVAL[<code>model.my_eval_epoch</code>]
|
|
26
|
+
EVALEPOCHBRANCH -- no --> COMMONEVAL[<code>MyTrain.my_eval_epoch</code>]
|
|
27
|
+
CUSTOMEVAL --> UPDATECONFIGPERFORM[update configuration]
|
|
28
|
+
COMMONEVAL --> UPDATECONFIGPERFORM
|
|
29
|
+
end
|
|
30
|
+
|
|
31
|
+
MODE -- no --> COMMONTRAIN[<code>MyTrain.my_train_model</code>]
|
|
32
|
+
|
|
33
|
+
subgraph COMMONTRAIN[<code>MyTrain.my_train_model</code>]
|
|
34
|
+
direction TB
|
|
35
|
+
CONTINUETRAIN{{last epoch is -1?}}
|
|
36
|
+
CONTINUETRAIN -- yes --> HASINIT{{implement <code>model.my_initialize_model</code>?}} -- yes --> CUSTOMINIT[<code>model.my_initialize_model</code>?]
|
|
37
|
+
HASINIT -- no --> INITWEIGHT[initialize model weights by <code>my_initializer</code>]
|
|
38
|
+
CONTINUETRAIN -- no --> TRAINLOADCHECK[load checkpoint for model and random generator]
|
|
39
|
+
CUSTOMINIT --> TRAINDEVICE[set model device]
|
|
40
|
+
INITWEIGHT --> TRAINDEVICE
|
|
41
|
+
TRAINLOADCHECK --> TRAINDEVICE
|
|
42
|
+
TRAINDEVICE --> INSTOPSC[instantiate optimizer and lr scheduler] --> CONTINUETRAIN2{{last epoch is -1?}} -- no --> TRAINCHECKOPSC[load checkpoint for optimizer and lr scheduler] --> SETUPOPSC[setup optimizer and lr_scheduler]
|
|
43
|
+
CONTINUETRAIN2{{last epoch is -1?}} -- yes --> SETUPOPSC
|
|
44
|
+
SETUPOPSC --> TRAINDATALOADER[setup data loader] --> INSTEARLYSTOP[instantiate early stopping] --> TRAINLOOP[train loop]
|
|
45
|
+
end
|
|
46
|
+
|
|
47
|
+
subgraph TRAINLOOP[train loop]
|
|
48
|
+
direction TB
|
|
49
|
+
M{{implement <code>model.my_train_epoch</code>?}}
|
|
50
|
+
M -- yes --> N[<code>model.my_train_epoch</code>]
|
|
51
|
+
M -- no --> O[<code>MyTrain.my_train_epoch</code>]
|
|
52
|
+
N --> P{{implement <code>model.my_eval_epoch</code>?}}
|
|
53
|
+
O --> P
|
|
54
|
+
P -- yes --> Q[<code>model.my_eval_epoch</code>]
|
|
55
|
+
P -- no --> R[<code>MyTrain.my_eval_epoch</code>]
|
|
56
|
+
Q --> UPDATELR[update learning rate]
|
|
57
|
+
R --> UPDATELR
|
|
58
|
+
UPDATELR --> TRAINSAVE[save epoch configuration and checkpoint] --> EARLYSTOP[check early stopping]
|
|
59
|
+
end
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
# Test
|
|
63
|
+
|
|
64
|
+
The `MyTest` class tests subclasses of the Hugging Face `PreTrainedModel`. `MyTest` loads the epoch saved by `MyTrain`. If `model.my_train_model` is implemented, then the corresponding `model.my_load_model` must be implemented as well.
|
|
65
|
+
|
|
66
|
+
```mermaid
|
|
67
|
+
---
|
|
68
|
+
title: MyTest.__call__
|
|
69
|
+
---
|
|
70
|
+
flowchart TD
|
|
71
|
+
INSTMODEL[instantiate model and random generator] --> INSTMETRIC[instantiate metrics] --> LOADCHECK[load checkpoint for model and random generator] --> TESTDEVICE[set model device] --> TESTDATA[setup data loader] --> TESTMODEL[test model] --> TESTSAVE[save metrics]
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
# Metric
|
|
75
|
+
|
|
76
|
+
The metric classes should implement three methods.
|
|
77
|
+
1. `__init__` initializes the parameters and the metric state.
|
|
78
|
+
2. `step` processes one batch. It receives:
|
|
79
|
+
- `df`: the data frame returned by the model's `eval_output` method.
|
|
80
|
+
- `examples`: the examples in the dataset.
|
|
81
|
+
- `batch`: the batch returned by the model's `data_collator`.
|
|
82
|
+
3. `epoch` accumulates all batch results, reinitializes the metric state, and returns the final metric.
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "common-ai"
|
|
3
|
+
version = "0.1.2"
|
|
4
|
+
description = "Add your description here"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [
|
|
7
|
+
{ name = "ljw", email = "ljw2017@sjtu.edu.cn" }
|
|
8
|
+
]
|
|
9
|
+
requires-python = ">=3.14"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"deeplift",
|
|
12
|
+
"einops",
|
|
13
|
+
"gradio",
|
|
14
|
+
"holistictraceanalysis",
|
|
15
|
+
"jsonargparse",
|
|
16
|
+
"opt-einsum",
|
|
17
|
+
"optuna",
|
|
18
|
+
"pandas[hdf5]",
|
|
19
|
+
"plotnine",
|
|
20
|
+
"shap",
|
|
21
|
+
"tbparse",
|
|
22
|
+
"transformers",
|
|
23
|
+
"torch",
|
|
24
|
+
"datasets",
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
[project.scripts]
|
|
28
|
+
common-ai = "common_ai:main"
|
|
29
|
+
|
|
30
|
+
[build-system]
|
|
31
|
+
requires = ["uv_build>=0.9.16,<0.10.0"]
|
|
32
|
+
build-backend = "uv_build"
|
|
File without changes
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
seed: 63036
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
name: kaiming_uniform_ # uniform_, normal_, xavier_uniform_, xavier_normal_, kaiming_uniform_, kaiming_normal_, trunc_normal_
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
log_level: WARNING # CRITICAL, FATAL, ERROR, WARNING, INFO, DEBUG, NOTSET
|
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
import jsonargparse
|
|
2
|
+
from .train import MyTrain
|
|
3
|
+
from .test import MyTest
|
|
4
|
+
from .hta import MyHta
|
|
5
|
+
from .hpo import MyHpo
|
|
6
|
+
from .logger import get_logger
|
|
7
|
+
from .generator import MyGenerator
|
|
8
|
+
from .initializer import MyInitializer
|
|
9
|
+
from .optimizer import MyOptimizer
|
|
10
|
+
from .lr_scheduler import MyLrScheduler
|
|
11
|
+
from .early_stopping import MyEarlyStopping
|
|
12
|
+
from .profiler import MyProfiler
|
|
13
|
+
from .dataset import MyDatasetAbstract
|
|
14
|
+
from .metric import MyMetricAbstract
|
|
15
|
+
from .model import MyModelAbstract
|
|
16
|
+
from .inference import MyInferenceAbstract
|
|
17
|
+
from .shap import MyShapAbstract
|
|
18
|
+
from .upload import MyUpload
|
|
19
|
+
|
|
20
|
+
def get_train_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser used by the ``train`` subcommand.

    The parser accepts a ``--config`` file and exposes, in this order:
    the ``MyTrain`` constructor arguments under ``train``, the
    ``get_logger`` arguments under ``logger``, one group per training
    component (generator, initializer, optimizer, lr_scheduler,
    early_stopping, profiler), a ``MyDatasetAbstract`` subclass under
    ``dataset``, a required ``--metric`` list and a ``MyModelAbstract``
    subclass under ``model``.

    Returns:
        The configured parser.
    """
    parser = jsonargparse.ArgumentParser(description="Train AI models.")
    parser.add_argument("--config", action="config")
    parser.add_class_arguments(theclass=MyTrain, nested_key="train")

    parser.add_function_arguments(function=get_logger, nested_key="logger")

    # Components are registered in a fixed order, each under its own key.
    component_groups = (
        (MyGenerator, "generator"),
        (MyInitializer, "initializer"),
        (MyOptimizer, "optimizer"),
        (MyLrScheduler, "lr_scheduler"),
        (MyEarlyStopping, "early_stopping"),
        (MyProfiler, "profiler"),
    )
    for component_cls, key in component_groups:
        parser.add_class_arguments(theclass=component_cls, nested_key=key)

    parser.add_subclass_arguments(
        baseclass=MyDatasetAbstract, nested_key="dataset"
    )
    parser.add_argument(
        "--metric",
        nargs="+",
        type=MyMetricAbstract,
        required=True,
        enable_path=True,
    )
    parser.add_subclass_arguments(baseclass=MyModelAbstract, nested_key="model")

    return parser
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def get_test_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser used by the ``test`` subcommand.

    Returns:
        A parser with a ``--config`` option and the ``MyTest`` constructor
        arguments registered without a nested key.
    """
    parser = jsonargparse.ArgumentParser(description="Test AI models.")
    parser.add_argument("--config", action="config")
    # No nested key: the MyTest arguments are not grouped under a prefix.
    parser.add_class_arguments(theclass=MyTest, nested_key=None)

    return parser
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def get_infer_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser used by the ``infer`` subcommand.

    Besides ``--config`` it requires ``--input`` and ``--output`` file
    arguments, a ``MyInferenceAbstract`` subclass under ``inference`` and a
    nested test parser (see ``get_test_parser``) under ``--test``.

    Returns:
        The configured parser.
    """
    parser = jsonargparse.ArgumentParser(description="Infer AI models.")
    parser.add_argument("--config", action="config")

    # Both file arguments are mandatory strings; only name and help differ.
    for flag, help_text in (("--input", "input file"), ("--output", "output file")):
        parser.add_argument(flag, required=True, type=str, help=help_text)

    parser.add_subclass_arguments(
        baseclass=MyInferenceAbstract, nested_key="inference"
    )

    nested_test = jsonargparse.ActionParser(parser=get_test_parser())
    parser.add_argument("--test", action=nested_test)

    return parser
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_explain_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser used by the ``explain`` subcommand.

    Registers ``--config``, a ``MyShapAbstract`` subclass under ``shap``, a
    ``MyInferenceAbstract`` subclass under ``inference``, a nested test
    parser under ``--test`` and a ``MyDatasetAbstract`` subclass under
    ``dataset``.

    Returns:
        The configured parser.
    """
    parser = jsonargparse.ArgumentParser(description="Explain AI models.")
    parser.add_argument("--config", action="config")

    parser.add_subclass_arguments(baseclass=MyShapAbstract, nested_key="shap")
    parser.add_subclass_arguments(
        baseclass=MyInferenceAbstract, nested_key="inference"
    )

    nested_test = jsonargparse.ActionParser(parser=get_test_parser())
    parser.add_argument("--test", action=nested_test)

    parser.add_subclass_arguments(
        baseclass=MyDatasetAbstract, nested_key="dataset"
    )

    return parser
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def get_app_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser used by the ``app`` subcommand.

    Besides ``--config`` it requires two list-valued options, ``--inference``
    (typed as ``MyInferenceAbstract``) and ``--test`` (typed as ``MyTest``),
    each taking one or more values.

    Returns:
        The configured parser.
    """
    parser = jsonargparse.ArgumentParser(description="App AI models.")
    parser.add_argument("--config", action="config")

    # The two options share every setting except their name and type.
    for flag, value_type in (
        ("--inference", MyInferenceAbstract),
        ("--test", MyTest),
    ):
        parser.add_argument(
            flag,
            nargs="+",
            type=value_type,
            required=True,
            enable_path=True,
        )

    return parser
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def get_hta_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser used by the ``hta`` subcommand.

    Returns:
        A parser with a ``--config`` option and the ``MyHta`` constructor
        arguments registered without a nested key.
    """
    parser = jsonargparse.ArgumentParser(description="Hta AI models.")
    parser.add_argument("--config", action="config")
    # No nested key: the MyHta arguments are not grouped under a prefix.
    parser.add_class_arguments(theclass=MyHta, nested_key=None)

    return parser
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def get_hpo_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser used by the ``hpo`` subcommand.

    Registers ``--config``, the ``MyHpo`` constructor arguments under
    ``hpo`` and a nested train parser (see ``get_train_parser``) under
    ``--train``.

    Returns:
        The configured parser.
    """
    parser = jsonargparse.ArgumentParser(description="Hpo AI models.")
    parser.add_argument("--config", action="config")
    parser.add_class_arguments(theclass=MyHpo, nested_key="hpo")

    nested_train = jsonargparse.ActionParser(parser=get_train_parser())
    parser.add_argument("--train", action=nested_train)

    return parser
|
|
150
|
+
|
|
151
|
+
def get_upload_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser used by the ``upload`` subcommand.

    Returns:
        A parser with a ``--config`` option and the ``MyUpload`` constructor
        arguments registered without a nested key.
    """
    parser = jsonargparse.ArgumentParser(description="Upload AI models.")
    parser.add_argument("--config", action="config")
    # No nested key: the MyUpload arguments are not grouped under a prefix.
    parser.add_class_arguments(theclass=MyUpload, nested_key=None)

    return parser
|
|
157
|
+
|
|
158
|
+
def get_config() -> tuple[jsonargparse.ArgumentParser, ...]:
    """Build the top-level parser and every subcommand parser.

    The top-level parser requires exactly one subcommand, stored under the
    ``subcommand`` key. The subcommands are ``train``, ``test``, ``infer``,
    ``explain``, ``app``, ``hta``, ``hpo`` and ``upload``.

    Returns:
        A 9-tuple: the top-level parser followed by the train, test, infer,
        explain, app, hta, hpo and upload parsers, in that order.
    """
    parser = jsonargparse.ArgumentParser(
        description="Arguments of AI models.",
    )
    subcommands = parser.add_subcommands(required=True, dest="subcommand")

    # Order matters: it fixes the position of each parser in the result.
    parser_factories = (
        ("train", get_train_parser),
        ("test", get_test_parser),
        ("infer", get_infer_parser),
        ("explain", get_explain_parser),
        ("app", get_app_parser),
        ("hta", get_hta_parser),
        ("hpo", get_hpo_parser),
        ("upload", get_upload_parser),
    )

    sub_parsers = []
    for name, factory in parser_factories:
        # Build and register one subcommand at a time.
        sub_parser = factory()
        subcommands.add_subcommand(name=name, parser=sub_parser)
        sub_parsers.append(sub_parser)

    return (parser, *sub_parsers)
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Abstract base for dataset providers: concrete subclasses must implement
# both ``__call__`` and the ``hpo`` classmethod.
class MyDatasetAbstract(ABC):
    def __init__(self, name: str):
        # The name is only stored here; it is not validated or used further.
        self.name = name

    @abstractmethod
    def __call__(self):
        """Abstract hook; the base implementation does nothing."""

    @classmethod
    @abstractmethod
    def hpo(cls):
        """Abstract class-level hook; the base implementation does nothing."""
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
# Patience-based early stopping: tracks the best loss seen so far and signals
# a stop after ``patience`` consecutive calls without sufficient improvement.
class MyEarlyStopping:
    def __init__(
        self,
        patience: Optional[int],
        delta: float,
        **kwargs,
    ) -> None:
        """Early stopping arguments.

        Args:
            patience: early stopping patience.
            delta: early stopping loss improvement threshold.
        """
        self.patience = patience
        self.delta = delta
        # Mutable state: best loss so far and the calls left before stopping.
        self.best_loss = float("inf")
        self.remain_patience = patience

    def __call__(self, loss: float) -> bool:
        """Record one loss value and report whether training should stop.

        A loss counts as an improvement only if it is lower than the best
        loss by more than ``delta``. Returns ``True`` once ``patience``
        consecutive non-improving losses have been seen, at which point the
        internal state is reset. Always returns ``False`` when ``patience``
        is ``None``.
        """
        if self.patience is None:
            return False

        improved = loss < self.best_loss - self.delta
        if improved:
            self.best_loss, self.remain_patience = loss, self.patience
            return False

        self.remain_patience -= 1
        out_of_patience = not (self.remain_patience > 0)
        if out_of_patience:
            # Reset internal state so the instance can be reused afterwards.
            self.remain_patience = self.patience
            self.best_loss = float("inf")
        return out_of_patience
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class MyGenerator:
|
|
6
|
+
def __init__(
|
|
7
|
+
self,
|
|
8
|
+
seed: int,
|
|
9
|
+
**kwargs,
|
|
10
|
+
) -> None:
|
|
11
|
+
"""Generator arguments.
|
|
12
|
+
|
|
13
|
+
Args:
|
|
14
|
+
seed: Random seed.
|
|
15
|
+
"""
|
|
16
|
+
self.seed = seed
|
|
17
|
+
self.np_rng = np.random.default_rng(self.seed)
|
|
18
|
+
self.torch_c_rng = torch.Generator(device="cpu").manual_seed(self.seed)
|
|
19
|
+
if torch.cuda.is_available():
|
|
20
|
+
self.torch_g_rng = torch.Generator(device="cuda").manual_seed(self.seed)
|
|
21
|
+
|
|
22
|
+
def get_torch_generator_by_device(
|
|
23
|
+
self, device: str | torch.device
|
|
24
|
+
) -> torch.Generator:
|
|
25
|
+
if device == "cpu" or device == torch.device("cpu"):
|
|
26
|
+
return self.torch_c_rng
|
|
27
|
+
return self.torch_g_rng
|
|
28
|
+
|
|
29
|
+
def state_dict(self) -> dict:
|
|
30
|
+
state_dict = {
|
|
31
|
+
"np_rng": self.np_rng.bit_generator.state,
|
|
32
|
+
"torch_c_rng": self.torch_c_rng.get_state(),
|
|
33
|
+
}
|
|
34
|
+
if torch.cuda.is_available():
|
|
35
|
+
state_dict.update({"torch_g_rng": self.torch_g_rng.get_state()})
|
|
36
|
+
return state_dict
|
|
37
|
+
|
|
38
|
+
def load_state_dict(self, state_dict: dict) -> None:
|
|
39
|
+
self.np_rng.bit_generator.state = state_dict["np_rng"]
|
|
40
|
+
self.torch_c_rng.set_state(state_dict["torch_c_rng"])
|
|
41
|
+
if torch.cuda.is_available():
|
|
42
|
+
self.torch_g_rng.set_state(state_dict["torch_g_rng"])
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import pathlib
|
|
3
|
+
import tempfile
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
|
|
6
|
+
import jsonargparse
|
|
7
|
+
|
|
8
|
+
from .test import MyTest
|
|
9
|
+
from .inference import MyInferenceAbstract
|
|
10
|
+
|
|
11
|
+
class MyGradioFnAbstract(ABC):
    """Abstract base that maps repo ids to inference/test configurations.

    Subclasses must implement ``launch``.
    """

    def __init__(
        self,
        app_cfg: jsonargparse.Namespace,
        train_parser: jsonargparse.ArgumentParser,
    ) -> None:
        """Index every (test, inference) configuration pair by repo id.

        Args:
            app_cfg: namespace whose ``test`` and ``inference`` entries are
                iterated pairwise.
            train_parser: parser handed to ``MyTest.load_train_cfg`` and kept
                for later use by ``load_inference``.
        """
        self.DEFAULT_TEMP_DIR = pathlib.Path(tempfile.gettempdir())
        self.train_parser = train_parser
        self.inference_dict = {}

        for test_entry, inference_cfg in zip(app_cfg.test, app_cfg.inference):
            test_init_args = test_entry.init_args
            tester = MyTest(**test_init_args.as_dict())
            _unused, train_cfg = tester.load_train_cfg(train_parser)

            # The model class path must split into exactly four parts; the
            # second and fourth become part of the repo id.
            path_parts = train_cfg.model.class_path.rsplit(".", 3)
            _, preprocess, _, model_cls = path_parts
            data_name = train_cfg.dataset.init_args.name

            repo_id = f"{preprocess}_{model_cls}_{data_name}"
            self.inference_dict[repo_id] = (inference_cfg, test_init_args)

    def load_inference(self, repo_id: str) -> MyInferenceAbstract:
        """Instantiate the inference object registered for ``repo_id``.

        The inference class is imported from its configured class path,
        constructed from its ``init_args`` and then asked to load its model.

        Args:
            repo_id: key of an entry recorded by ``__init__``.

        Returns:
            The inference object after ``load_model`` has been called on it.
        """
        assert repo_id in self.inference_dict, f"repo id {repo_id} is not found"
        inference_cfg, test_cfg = self.inference_dict[repo_id]

        module_name, cls_name = inference_cfg.class_path.rsplit(".", 1)
        inference_cls = getattr(importlib.import_module(module_name), cls_name)
        my_inference = inference_cls(**inference_cfg.init_args.as_dict())
        my_inference.load_model(test_cfg, self.train_parser)

        return my_inference

    @abstractmethod
    def launch(self):
        """Abstract hook; the base implementation does nothing."""
|