common-ai 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. common_ai-0.1.2/PKG-INFO +105 -0
  2. common_ai-0.1.2/README.md +82 -0
  3. common_ai-0.1.2/pyproject.toml +32 -0
  4. common_ai-0.1.2/src/common_ai/__init__.py +0 -0
  5. common_ai-0.1.2/src/common_ai/app.yaml +3 -0
  6. common_ai-0.1.2/src/common_ai/component/early_stopping.yaml +2 -0
  7. common_ai-0.1.2/src/common_ai/component/generator.yaml +1 -0
  8. common_ai-0.1.2/src/common_ai/component/initializer.yaml +1 -0
  9. common_ai-0.1.2/src/common_ai/component/logger.yaml +1 -0
  10. common_ai-0.1.2/src/common_ai/component/lr_scheduler.yaml +3 -0
  11. common_ai-0.1.2/src/common_ai/component/optimizer.yaml +3 -0
  12. common_ai-0.1.2/src/common_ai/component/profiler.yaml +3 -0
  13. common_ai-0.1.2/src/common_ai/config.py +198 -0
  14. common_ai-0.1.2/src/common_ai/dataset.py +18 -0
  15. common_ai-0.1.2/src/common_ai/early_stopping.py +38 -0
  16. common_ai-0.1.2/src/common_ai/explain.yaml +7 -0
  17. common_ai-0.1.2/src/common_ai/generator.py +42 -0
  18. common_ai-0.1.2/src/common_ai/gradio_fn.py +42 -0
  19. common_ai-0.1.2/src/common_ai/hpo.py +238 -0
  20. common_ai-0.1.2/src/common_ai/hpo.yaml +11 -0
  21. common_ai-0.1.2/src/common_ai/hta.py +305 -0
  22. common_ai-0.1.2/src/common_ai/hta.yaml +3 -0
  23. common_ai-0.1.2/src/common_ai/infer.yaml +6 -0
  24. common_ai-0.1.2/src/common_ai/inference.py +10 -0
  25. common_ai-0.1.2/src/common_ai/initializer.py +93 -0
  26. common_ai-0.1.2/src/common_ai/logger.py +21 -0
  27. common_ai-0.1.2/src/common_ai/lr_scheduler.py +111 -0
  28. common_ai-0.1.2/src/common_ai/metric.py +20 -0
  29. common_ai-0.1.2/src/common_ai/model.py +12 -0
  30. common_ai-0.1.2/src/common_ai/non_causality_hyena.py +304 -0
  31. common_ai-0.1.2/src/common_ai/optimizer.py +106 -0
  32. common_ai-0.1.2/src/common_ai/profiler.py +82 -0
  33. common_ai-0.1.2/src/common_ai/protein_bert.py +449 -0
  34. common_ai-0.1.2/src/common_ai/prs.py +643 -0
  35. common_ai-0.1.2/src/common_ai/shap.py +390 -0
  36. common_ai-0.1.2/src/common_ai/standalone_hyena.py +293 -0
  37. common_ai-0.1.2/src/common_ai/test.py +210 -0
  38. common_ai-0.1.2/src/common_ai/test.yaml +5 -0
  39. common_ai-0.1.2/src/common_ai/train.py +472 -0
  40. common_ai-0.1.2/src/common_ai/train.yaml +30 -0
  41. common_ai-0.1.2/src/common_ai/upload.py +95 -0
  42. common_ai-0.1.2/src/common_ai/upload.yaml +11 -0
  43. common_ai-0.1.2/src/common_ai/utils.py +125 -0
@@ -0,0 +1,105 @@
1
+ Metadata-Version: 2.3
2
+ Name: common-ai
3
+ Version: 0.1.2
4
+ Summary: Add your description here
5
+ Author: ljw
6
+ Author-email: ljw <ljw2017@sjtu.edu.cn>
7
+ Requires-Dist: deeplift
8
+ Requires-Dist: einops
9
+ Requires-Dist: gradio
10
+ Requires-Dist: holistictraceanalysis
11
+ Requires-Dist: jsonargparse
12
+ Requires-Dist: opt-einsum
13
+ Requires-Dist: optuna
14
+ Requires-Dist: pandas[hdf5]
15
+ Requires-Dist: plotnine
16
+ Requires-Dist: shap
17
+ Requires-Dist: tbparse
18
+ Requires-Dist: transformers
19
+ Requires-Dist: torch
20
+ Requires-Dist: datasets
21
+ Requires-Python: >=3.14
22
+ Description-Content-Type: text/markdown
23
+
24
+ # Introduction
25
+
26
+ This repository contains AI libraries commonly used for all my AI projects.
27
+
28
+ # Train
29
+
30
+ The `MyTrain` class trains subclasses of huggingface `PreTrainedModel`. `model.state_dict` and `model.load_state_dict` must be consistent.
31
+
32
+ ```mermaid
33
+ ---
34
+ title: MyTrain.__call__
35
+ ---
36
+ flowchart TD
37
+ INST[instantiate model and random generator] --> TRAININSTMETRICS[instantiate metrics] --> MODE{{evaluation only?}}
38
+ MODE -- yes --> EVALMODEL[<code>MyTrain.my_eval_model</code>]
39
+
40
+ subgraph EVALMODEL[<code>MyTrain.my_eval_model</code>]
41
+ direction TB
42
+ EVALLOOP[eval loop]
43
+ end
44
+
45
+ subgraph EVALLOOP[eval loop]
46
+ direction TB
47
+ CHECKCONSISTENCY[check config consistency] --> EVALLOADCHECKPOINT[load checkpoint for model and generator] --> EVALDEVICE[set model device] --> EVALDATALOADER[setup data loader] --> EVALEPOCHBRANCH{{implement <code>model.my_eval_epoch</code>?}}
48
+ EVALEPOCHBRANCH -- yes --> CUSTOMEVAL[<code>model.my_eval_epoch</code>]
49
+ EVALEPOCHBRANCH -- no --> COMMONEVAL[<code>MyTrain.my_eval_epoch</code>]
50
+ CUSTOMEVAL --> UPDATECONFIGPERFORM[update configuration]
51
+ COMMONEVAL --> UPDATECONFIGPERFORM
52
+ end
53
+
54
+ MODE -- no --> COMMONTRAIN[<code>MyTrain.my_train_model</code>]
55
+
56
+ subgraph COMMONTRAIN[<code>MyTrain.my_train_model</code>]
57
+ direction TB
58
+ CONTINUETRAIN{{last epoch is -1?}}
59
+ CONTINUETRAIN -- yes --> HASINIT{{implement <code>model.my_initialize_model</code>?}} -- yes --> CUSTOMINIT[<code>model.my_initialize_model</code>]
60
+ HASINIT -- no --> INITWEIGHT[initialize model weights by <code>my_initializer</code>]
61
+ CONTINUETRAIN -- no --> TRAINLOADCHECK[load checkpoint for model and random generator]
62
+ CUSTOMINIT --> TRAINDEVICE[set model device]
63
+ INITWEIGHT --> TRAINDEVICE
64
+ TRAINLOADCHECK --> TRAINDEVICE
65
+ TRAINDEVICE --> INSTOPSC[instantiate optimizer and lr scheduler] --> CONTINUETRAIN2{{last epoch is -1?}} -- no --> TRAINCHECKOPSC[load checkpoint for optimizer and lr scheduler] --> SETUPOPSC[setup optimizer and lr_scheduler]
66
+ CONTINUETRAIN2{{last epoch is -1?}} -- yes --> SETUPOPSC
67
+ SETUPOPSC --> TRAINDATALOADER[setup data loader] --> INSTEARLYSTOP[instantiate early stopping] --> TRAINLOOP[train loop]
68
+ end
69
+
70
+ subgraph TRAINLOOP[train loop]
71
+ direction TB
72
+ M{{implement <code>model.my_train_epoch</code>?}}
73
+ M -- yes --> N[<code>model.my_train_epoch</code>]
74
+ M -- no --> O[<code>MyTrain.my_train_epoch</code>]
75
+ N --> P{{implement <code>model.my_eval_epoch</code>?}}
76
+ O --> P
77
+ P -- yes --> Q[<code>model.my_eval_epoch</code>]
78
+ P -- no --> R[<code>MyTrain.my_eval_epoch</code>]
79
+ Q --> UPDATELR[update learning rate]
80
+ R --> UPDATELR
81
+ UPDATELR --> TRAINSAVE[save epoch configuration and checkpoint] --> EARLYSTOP[check early stopping]
82
+ end
83
+ ```
84
+
85
+ # Test
86
+
87
+ The `MyTest` class tests subclasses of huggingface `PreTrainedModel`. `MyTest` will load the epoch saved by `MyTrain`. If `model.my_train_model` is implemented, then the corresponding `model.my_load_model` is necessary.
88
+
89
+ ```mermaid
90
+ ---
91
+ title: MyTest.__call__
92
+ ---
93
+ flowchart TD
94
+ INSTMODEL[instantiate model and random generator] --> INSTMETRIC[instantiate metrics] --> LOADCHECK[load checkpoint for model and random generator] --> TESTDEVICE[set model device] --> TESTDATA[setup data loader] --> TESTMODEL[test model] --> TESTSAVE[save metrics]
95
+ ```
96
+
97
+ # Metric
98
+
99
+ The metric classes should implement three methods.
100
+ 1. `__init__` initializes the parameters and metric state.
101
+ 2. `step` processes the batches. It receives:
102
+ - `df`: the data frame returned by the model's `eval_output` method.
103
+ - `examples`: the examples in the dataset.
104
+ - `batch`: the batch returned by the model's `data_collator`.
105
+ 3. `epoch` accumulates all batch results, reinitializes the metric state and returns the final metric.
@@ -0,0 +1,82 @@
1
+ # Introduction
2
+
3
+ This repository contains AI libraries commonly used for all my AI projects.
4
+
5
+ # Train
6
+
7
+ The `MyTrain` class trains subclasses of huggingface `PreTrainedModel`. `model.state_dict` and `model.load_state_dict` must be consistent.
8
+
9
+ ```mermaid
10
+ ---
11
+ title: MyTrain.__call__
12
+ ---
13
+ flowchart TD
14
+ INST[instantiate model and random generator] --> TRAININSTMETRICS[instantiate metrics] --> MODE{{evaluation only?}}
15
+ MODE -- yes --> EVALMODEL[<code>MyTrain.my_eval_model</code>]
16
+
17
+ subgraph EVALMODEL[<code>MyTrain.my_eval_model</code>]
18
+ direction TB
19
+ EVALLOOP[eval loop]
20
+ end
21
+
22
+ subgraph EVALLOOP[eval loop]
23
+ direction TB
24
+ CHECKCONSISTENCY[check config consistency] --> EVALLOADCHECKPOINT[load checkpoint for model and generator] --> EVALDEVICE[set model device] --> EVALDATALOADER[setup data loader] --> EVALEPOCHBRANCH{{implement <code>model.my_eval_epoch</code>?}}
25
+ EVALEPOCHBRANCH -- yes --> CUSTOMEVAL[<code>model.my_eval_epoch</code>]
26
+ EVALEPOCHBRANCH -- no --> COMMONEVAL[<code>MyTrain.my_eval_epoch</code>]
27
+ CUSTOMEVAL --> UPDATECONFIGPERFORM[update configuration]
28
+ COMMONEVAL --> UPDATECONFIGPERFORM
29
+ end
30
+
31
+ MODE -- no --> COMMONTRAIN[<code>MyTrain.my_train_model</code>]
32
+
33
+ subgraph COMMONTRAIN[<code>MyTrain.my_train_model</code>]
34
+ direction TB
35
+ CONTINUETRAIN{{last epoch is -1?}}
36
+ CONTINUETRAIN -- yes --> HASINIT{{implement <code>model.my_initialize_model</code>?}} -- yes --> CUSTOMINIT[<code>model.my_initialize_model</code>]
37
+ HASINIT -- no --> INITWEIGHT[initialize model weights by <code>my_initializer</code>]
38
+ CONTINUETRAIN -- no --> TRAINLOADCHECK[load checkpoint for model and random generator]
39
+ CUSTOMINIT --> TRAINDEVICE[set model device]
40
+ INITWEIGHT --> TRAINDEVICE
41
+ TRAINLOADCHECK --> TRAINDEVICE
42
+ TRAINDEVICE --> INSTOPSC[instantiate optimizer and lr scheduler] --> CONTINUETRAIN2{{last epoch is -1?}} -- no --> TRAINCHECKOPSC[load checkpoint for optimizer and lr scheduler] --> SETUPOPSC[setup optimizer and lr_scheduler]
43
+ CONTINUETRAIN2{{last epoch is -1?}} -- yes --> SETUPOPSC
44
+ SETUPOPSC --> TRAINDATALOADER[setup data loader] --> INSTEARLYSTOP[instantiate early stopping] --> TRAINLOOP[train loop]
45
+ end
46
+
47
+ subgraph TRAINLOOP[train loop]
48
+ direction TB
49
+ M{{implement <code>model.my_train_epoch</code>?}}
50
+ M -- yes --> N[<code>model.my_train_epoch</code>]
51
+ M -- no --> O[<code>MyTrain.my_train_epoch</code>]
52
+ N --> P{{implement <code>model.my_eval_epoch</code>?}}
53
+ O --> P
54
+ P -- yes --> Q[<code>model.my_eval_epoch</code>]
55
+ P -- no --> R[<code>MyTrain.my_eval_epoch</code>]
56
+ Q --> UPDATELR[update learning rate]
57
+ R --> UPDATELR
58
+ UPDATELR --> TRAINSAVE[save epoch configuration and checkpoint] --> EARLYSTOP[check early stopping]
59
+ end
60
+ ```
61
+
62
+ # Test
63
+
64
+ The `MyTest` class tests subclasses of huggingface `PreTrainedModel`. `MyTest` will load the epoch saved by `MyTrain`. If `model.my_train_model` is implemented, then the corresponding `model.my_load_model` is necessary.
65
+
66
+ ```mermaid
67
+ ---
68
+ title: MyTest.__call__
69
+ ---
70
+ flowchart TD
71
+ INSTMODEL[instantiate model and random generator] --> INSTMETRIC[instantiate metrics] --> LOADCHECK[load checkpoint for model and random generator] --> TESTDEVICE[set model device] --> TESTDATA[setup data loader] --> TESTMODEL[test model] --> TESTSAVE[save metrics]
72
+ ```
73
+
74
+ # Metric
75
+
76
+ The metric classes should implement three methods.
77
+ 1. `__init__` initializes the parameters and metric state.
78
+ 2. `step` processes the batches. It receives:
79
+ - `df`: the data frame returned by the model's `eval_output` method.
80
+ - `examples`: the examples in the dataset.
81
+ - `batch`: the batch returned by the model's `data_collator`.
82
+ 3. `epoch` accumulates all batch results, reinitializes the metric state and returns the final metric.
@@ -0,0 +1,32 @@
1
+ [project]
2
+ name = "common-ai"
3
+ version = "0.1.2"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "ljw", email = "ljw2017@sjtu.edu.cn" }
8
+ ]
9
+ requires-python = ">=3.14"
10
+ dependencies = [
11
+ "deeplift",
12
+ "einops",
13
+ "gradio",
14
+ "holistictraceanalysis",
15
+ "jsonargparse",
16
+ "opt-einsum",
17
+ "optuna",
18
+ "pandas[hdf5]",
19
+ "plotnine",
20
+ "shap",
21
+ "tbparse",
22
+ "transformers",
23
+ "torch",
24
+ "datasets",
25
+ ]
26
+
27
+ [project.scripts]
28
+ common-ai = "common_ai:main"
29
+
30
+ [build-system]
31
+ requires = ["uv_build>=0.9.16,<0.10.0"]
32
+ build-backend = "uv_build"
File without changes
@@ -0,0 +1,3 @@
1
+ # inference:
2
+
3
+ # test:
@@ -0,0 +1,2 @@
1
+ patience: null
2
+ delta: 0.0
@@ -0,0 +1 @@
1
+ seed: 63036
@@ -0,0 +1 @@
1
+ name: kaiming_uniform_ # uniform_, normal_, xavier_uniform_, xavier_normal_, kaiming_uniform_, kaiming_normal_, trunc_normal_
@@ -0,0 +1 @@
1
+ log_level: WARNING # CRITICAL, FATAL, ERROR, WARNING, INFO, DEBUG, NOTSET
@@ -0,0 +1,3 @@
1
+ name: CosineAnnealingWarmRestarts # CosineAnnealingWarmRestarts, ConstantLR, ReduceLROnPlateau
2
+ warmup_epochs: 3
3
+ period_epochs: 30
@@ -0,0 +1,3 @@
1
+ name: AdamW # Adadelta, Adafactor, Adagrad, Adam, AdamW, Adamax, ASGD, NAdam, RAdam, RMSprop, SGD
2
+ # learning_rate:
3
+ weight_decay: 0.0
@@ -0,0 +1,3 @@
1
+ warmup: 1
2
+ active: 1
3
+ repeat: -1
@@ -0,0 +1,198 @@
1
import jsonargparse

from .train import MyTrain
from .test import MyTest
from .hta import MyHta
from .hpo import MyHpo
from .logger import get_logger
from .generator import MyGenerator
from .initializer import MyInitializer
from .optimizer import MyOptimizer
from .lr_scheduler import MyLrScheduler
from .early_stopping import MyEarlyStopping
from .profiler import MyProfiler
from .dataset import MyDatasetAbstract
from .metric import MyMetricAbstract
from .model import MyModelAbstract
from .inference import MyInferenceAbstract
from .shap import MyShapAbstract
from .upload import MyUpload


def get_train_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser for the ``train`` subcommand.

    Exposes every training-time component (generator, initializer,
    optimizer, lr scheduler, early stopping, profiler, dataset, metrics,
    model) under its own nested key.

    Returns:
        The configured ``train`` parser.
    """
    train_parser = jsonargparse.ArgumentParser(description="Train AI models.")
    train_parser.add_argument("--config", action="config")
    train_parser.add_class_arguments(theclass=MyTrain, nested_key="train")

    train_parser.add_function_arguments(
        function=get_logger,
        nested_key="logger",
    )
    train_parser.add_class_arguments(
        theclass=MyGenerator,
        nested_key="generator",
    )
    train_parser.add_class_arguments(
        theclass=MyInitializer,
        nested_key="initializer",
    )
    train_parser.add_class_arguments(
        theclass=MyOptimizer,
        nested_key="optimizer",
    )
    train_parser.add_class_arguments(
        theclass=MyLrScheduler,
        nested_key="lr_scheduler",
    )
    train_parser.add_class_arguments(
        theclass=MyEarlyStopping,
        nested_key="early_stopping",
    )
    train_parser.add_class_arguments(
        theclass=MyProfiler,
        nested_key="profiler",
    )
    train_parser.add_subclass_arguments(
        baseclass=MyDatasetAbstract,
        nested_key="dataset",
    )
    # Metrics are a list of subclass specs, hence a plain argument with nargs
    # rather than add_subclass_arguments.
    train_parser.add_argument(
        "--metric",
        nargs="+",
        type=MyMetricAbstract,
        required=True,
        enable_path=True,
    )
    train_parser.add_subclass_arguments(
        baseclass=MyModelAbstract,
        nested_key="model",
    )

    return train_parser


def get_test_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser for the ``test`` subcommand.

    Returns:
        The configured ``test`` parser (all ``MyTest`` arguments at top level).
    """
    test_parser = jsonargparse.ArgumentParser(description="Test AI models.")
    test_parser.add_argument("--config", action="config")
    test_parser.add_class_arguments(theclass=MyTest, nested_key=None)

    return test_parser


def get_infer_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser for the ``infer`` subcommand.

    Nests the full ``test`` parser under ``--test`` so inference can reuse
    a test configuration to locate the trained checkpoint.

    Returns:
        The configured ``infer`` parser.
    """
    infer_parser = jsonargparse.ArgumentParser(description="Infer AI models.")
    infer_parser.add_argument("--config", action="config")
    infer_parser.add_argument("--input", required=True, type=str, help="input file")
    infer_parser.add_argument("--output", required=True, type=str, help="output file")
    infer_parser.add_subclass_arguments(
        baseclass=MyInferenceAbstract, nested_key="inference"
    )
    infer_parser.add_argument(
        "--test", action=jsonargparse.ActionParser(parser=get_test_parser())
    )

    return infer_parser


def get_explain_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser for the ``explain`` subcommand.

    Returns:
        The configured ``explain`` parser (shap, inference, test, dataset).
    """
    explain_parser = jsonargparse.ArgumentParser(description="Explain AI models.")
    explain_parser.add_argument("--config", action="config")
    explain_parser.add_subclass_arguments(baseclass=MyShapAbstract, nested_key="shap")
    explain_parser.add_subclass_arguments(
        baseclass=MyInferenceAbstract, nested_key="inference"
    )
    explain_parser.add_argument(
        "--test", action=jsonargparse.ActionParser(parser=get_test_parser())
    )
    explain_parser.add_subclass_arguments(
        baseclass=MyDatasetAbstract, nested_key="dataset"
    )

    return explain_parser


def get_app_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser for the ``app`` subcommand.

    ``--inference`` and ``--test`` are parallel lists: one inference spec
    per test configuration.

    Returns:
        The configured ``app`` parser.
    """
    app_parser = jsonargparse.ArgumentParser(description="App AI models.")
    app_parser.add_argument("--config", action="config")
    app_parser.add_argument(
        "--inference",
        nargs="+",
        type=MyInferenceAbstract,
        required=True,
        enable_path=True,
    )
    app_parser.add_argument(
        "--test",
        nargs="+",
        type=MyTest,
        required=True,
        enable_path=True,
    )

    return app_parser


def get_hta_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser for the ``hta`` (trace analysis) subcommand.

    Returns:
        The configured ``hta`` parser.
    """
    hta_parser = jsonargparse.ArgumentParser(description="Hta AI models.")
    hta_parser.add_argument("--config", action="config")
    hta_parser.add_class_arguments(theclass=MyHta, nested_key=None)

    return hta_parser


def get_hpo_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser for the ``hpo`` subcommand.

    Nests the full ``train`` parser under ``--train`` so each HPO trial can
    carry a complete training configuration.

    Returns:
        The configured ``hpo`` parser.
    """
    hpo_parser = jsonargparse.ArgumentParser(description="Hpo AI models.")
    hpo_parser.add_argument("--config", action="config")
    hpo_parser.add_class_arguments(theclass=MyHpo, nested_key="hpo")
    hpo_parser.add_argument(
        "--train", action=jsonargparse.ActionParser(parser=get_train_parser())
    )

    return hpo_parser


def get_upload_parser() -> jsonargparse.ArgumentParser:
    """Build the argument parser for the ``upload`` subcommand.

    Returns:
        The configured ``upload`` parser.
    """
    upload_parser = jsonargparse.ArgumentParser(description="Upload AI models.")
    upload_parser.add_argument("--config", action="config")
    upload_parser.add_class_arguments(theclass=MyUpload, nested_key=None)

    return upload_parser


def get_config() -> tuple[jsonargparse.ArgumentParser, ...]:
    """Assemble the top-level parser with all subcommands registered.

    Returns:
        A 9-tuple: the root parser followed by the train, test, infer,
        explain, app, hta, hpo and upload sub-parsers.
        (The original annotation ``tuple[jsonargparse.ArgumentParser]``
        declared a 1-tuple; fixed to a variadic tuple.)
    """
    parser = jsonargparse.ArgumentParser(
        description="Arguments of AI models.",
    )
    subcommands = parser.add_subcommands(required=True, dest="subcommand")

    train_parser = get_train_parser()
    subcommands.add_subcommand(name="train", parser=train_parser)

    test_parser = get_test_parser()
    subcommands.add_subcommand(name="test", parser=test_parser)

    infer_parser = get_infer_parser()
    subcommands.add_subcommand(name="infer", parser=infer_parser)

    explain_parser = get_explain_parser()
    subcommands.add_subcommand(name="explain", parser=explain_parser)

    app_parser = get_app_parser()
    subcommands.add_subcommand(name="app", parser=app_parser)

    hta_parser = get_hta_parser()
    subcommands.add_subcommand(name="hta", parser=hta_parser)

    hpo_parser = get_hpo_parser()
    subcommands.add_subcommand(name="hpo", parser=hpo_parser)

    upload_parser = get_upload_parser()
    subcommands.add_subcommand(name="upload", parser=upload_parser)

    return (
        parser,
        train_parser,
        test_parser,
        infer_parser,
        explain_parser,
        app_parser,
        hta_parser,
        hpo_parser,
        upload_parser,
    )
@@ -0,0 +1,18 @@
1
from abc import ABC, abstractmethod


class MyDatasetAbstract(ABC):
    """Interface for datasets used by the training/testing pipeline.

    Subclasses must implement ``__call__`` (materialize the dataset) and
    the ``hpo`` classmethod (expose a hyper-parameter search space).
    """

    def __init__(
        self,
        name: str,
    ):
        """Store the dataset identifier.

        Args:
            name: dataset name.
        """
        self.name = name

    @abstractmethod
    def __call__(self):
        """Build and return the dataset."""
        ...

    @classmethod
    @abstractmethod
    def hpo(cls):
        """Return the hyper-parameter search space for this dataset."""
        ...
@@ -0,0 +1,38 @@
1
from typing import Optional


class MyEarlyStopping:
    """Patience-based early stopping on a monotonically tracked loss."""

    def __init__(
        self,
        patience: Optional[int],
        delta: float,
        **kwargs,
    ) -> None:
        """Early stopping arguments.

        Args:
            patience: early stopping patience.
            delta: early stopping loss improvement threshold.
        """
        self.patience = patience
        self.delta = delta
        # remain_patience counts down on non-improving epochs.
        self.remain_patience = patience
        self.best_loss = float("inf")

    def __call__(self, loss: float) -> bool:
        """Record one epoch's loss; return True when training should stop."""
        # patience=None disables early stopping entirely.
        if self.patience is None:
            return False

        improved = loss < self.best_loss - self.delta
        if improved:
            self.best_loss = loss
            self.remain_patience = self.patience
            return False

        self.remain_patience -= 1
        if self.remain_patience > 0:
            return False

        # Triggered: reset internal state so the instance can be reused.
        self.best_loss = float("inf")
        self.remain_patience = self.patience
        return True
@@ -0,0 +1,7 @@
1
+ # shap:
2
+
3
+ # inference:
4
+
5
+ test: test.yaml
6
+
7
+ # dataset:
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class MyGenerator:
6
+ def __init__(
7
+ self,
8
+ seed: int,
9
+ **kwargs,
10
+ ) -> None:
11
+ """Generator arguments.
12
+
13
+ Args:
14
+ seed: Random seed.
15
+ """
16
+ self.seed = seed
17
+ self.np_rng = np.random.default_rng(self.seed)
18
+ self.torch_c_rng = torch.Generator(device="cpu").manual_seed(self.seed)
19
+ if torch.cuda.is_available():
20
+ self.torch_g_rng = torch.Generator(device="cuda").manual_seed(self.seed)
21
+
22
+ def get_torch_generator_by_device(
23
+ self, device: str | torch.device
24
+ ) -> torch.Generator:
25
+ if device == "cpu" or device == torch.device("cpu"):
26
+ return self.torch_c_rng
27
+ return self.torch_g_rng
28
+
29
+ def state_dict(self) -> dict:
30
+ state_dict = {
31
+ "np_rng": self.np_rng.bit_generator.state,
32
+ "torch_c_rng": self.torch_c_rng.get_state(),
33
+ }
34
+ if torch.cuda.is_available():
35
+ state_dict.update({"torch_g_rng": self.torch_g_rng.get_state()})
36
+ return state_dict
37
+
38
+ def load_state_dict(self, state_dict: dict) -> None:
39
+ self.np_rng.bit_generator.state = state_dict["np_rng"]
40
+ self.torch_c_rng.set_state(state_dict["torch_c_rng"])
41
+ if torch.cuda.is_available():
42
+ self.torch_g_rng.set_state(state_dict["torch_g_rng"])
@@ -0,0 +1,42 @@
1
import importlib
import pathlib
import tempfile
from abc import ABC, abstractmethod

import jsonargparse

from .test import MyTest
from .inference import MyInferenceAbstract


class MyGradioFnAbstract(ABC):
    """Base class mapping repo ids to (inference, test) config pairs for a gradio app."""

    def __init__(
        self,
        app_cfg: jsonargparse.Namespace,
        train_parser: jsonargparse.ArgumentParser,
    ) -> None:
        self.DEFAULT_TEMP_DIR = pathlib.Path(tempfile.gettempdir())
        self.train_parser = train_parser
        self.inference_dict = {}
        # app_cfg.test and app_cfg.inference are parallel lists: one
        # inference spec per test configuration.
        for raw_test, raw_inference in zip(app_cfg.test, app_cfg.inference):
            cfg = raw_test.init_args
            tester = MyTest(**cfg.as_dict())
            _, train_cfg = tester.load_train_cfg(train_parser)
            # repo id is derived from the model class path and dataset name,
            # e.g. "<preprocess>_<ModelCls>_<dataset>".
            _, preprocess, _, model_name = train_cfg.model.class_path.rsplit(".", 3)
            repo_id = f"{preprocess}_{model_name}_{train_cfg.dataset.init_args.name}"
            self.inference_dict[repo_id] = (raw_inference, cfg)

    def load_inference(self, repo_id: str) -> MyInferenceAbstract:
        """Instantiate the inference object registered under *repo_id* and load its model."""
        assert repo_id in self.inference_dict, f"repo id {repo_id} is not found"
        inference_cfg, test_cfg = self.inference_dict[repo_id]
        module_name, cls_name = inference_cfg.class_path.rsplit(".", 1)
        inference_cls = getattr(importlib.import_module(module_name), cls_name)
        my_inference = inference_cls(**inference_cfg.init_args.as_dict())
        my_inference.load_model(test_cfg, self.train_parser)

        return my_inference

    @abstractmethod
    def launch(self):
        """Start the gradio interface."""