zarth-utils 1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zarth_utils-1.0/PKG-INFO +32 -0
- zarth_utils-1.0/README.md +1 -0
- zarth_utils-1.0/pyproject.toml +21 -0
- zarth_utils-1.0/requirements.txt +11 -0
- zarth_utils-1.0/setup.cfg +4 -0
- zarth_utils-1.0/tests/test_config.py +37 -0
- zarth_utils-1.0/zarth_utils/__init__.py +3 -0
- zarth_utils-1.0/zarth_utils/config.py +268 -0
- zarth_utils-1.0/zarth_utils/drawer.py +192 -0
- zarth_utils-1.0/zarth_utils/general_utils.py +26 -0
- zarth_utils-1.0/zarth_utils/jupyter_utils.py +7 -0
- zarth_utils-1.0/zarth_utils/logger.py +45 -0
- zarth_utils-1.0/zarth_utils/nn_utils.py +287 -0
- zarth_utils-1.0/zarth_utils/result_recorder.py +369 -0
- zarth_utils-1.0/zarth_utils/text_processing.py +68 -0
- zarth_utils-1.0/zarth_utils/timer.py +86 -0
- zarth_utils-1.0/zarth_utils.egg-info/PKG-INFO +32 -0
- zarth_utils-1.0/zarth_utils.egg-info/SOURCES.txt +19 -0
- zarth_utils-1.0/zarth_utils.egg-info/dependency_links.txt +1 -0
- zarth_utils-1.0/zarth_utils.egg-info/requires.txt +26 -0
- zarth_utils-1.0/zarth_utils.egg-info/top_level.txt +1 -0
zarth_utils-1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: zarth_utils
|
|
3
|
+
Version: 1.0
|
|
4
|
+
Summary: Package used for my personal development on ML projects.
|
|
5
|
+
Requires-Python: >=3.7
|
|
6
|
+
Description-Content-Type: text/x-rst
|
|
7
|
+
Requires-Dist: matplotlib
|
|
8
|
+
Requires-Dist: tqdm
|
|
9
|
+
Requires-Dist: pandas
|
|
10
|
+
Requires-Dist: numpy
|
|
11
|
+
Requires-Dist: joblib
|
|
12
|
+
Requires-Dist: gitpython
|
|
13
|
+
Requires-Dist: ipython
|
|
14
|
+
Requires-Dist: wandb
|
|
15
|
+
Requires-Dist: scikit-learn
|
|
16
|
+
Requires-Dist: gitpython
|
|
17
|
+
Requires-Dist: ipython
|
|
18
|
+
Provides-Extra: all
|
|
19
|
+
Requires-Dist: nltk; extra == "all"
|
|
20
|
+
Requires-Dist: torch; extra == "all"
|
|
21
|
+
Requires-Dist: tensorflow; extra == "all"
|
|
22
|
+
Requires-Dist: jupyterlab; extra == "all"
|
|
23
|
+
Provides-Extra: hf
|
|
24
|
+
Requires-Dist: transformers; extra == "hf"
|
|
25
|
+
Requires-Dist: accelerate; extra == "hf"
|
|
26
|
+
Requires-Dist: trl; extra == "hf"
|
|
27
|
+
Requires-Dist: datasets; extra == "hf"
|
|
28
|
+
Requires-Dist: diffusers; extra == "hf"
|
|
29
|
+
Requires-Dist: tokenizers; extra == "hf"
|
|
30
|
+
Requires-Dist: huggingface_hub; extra == "hf"
|
|
31
|
+
|
|
32
|
+
Package used for my personal development on ML projects.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
Package used for my personal development on ML projects.
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "zarth_utils"
|
|
7
|
+
description = "Package used for my personal development on ML projects."
|
|
8
|
+
version = "1.0"
|
|
9
|
+
dynamic = ["readme", "dependencies"]
|
|
10
|
+
requires-python = ">=3.7"
|
|
11
|
+
|
|
12
|
+
[project.optional-dependencies]
|
|
13
|
+
all = ["nltk", "torch", "tensorflow", "jupyterlab"]
|
|
14
|
+
hf = ["transformers", "accelerate", "trl", "datasets", "diffusers", "tokenizers", "huggingface_hub"]
|
|
15
|
+
|
|
16
|
+
[tool.setuptools.packages.find]
|
|
17
|
+
include = ["zarth_utils"]
|
|
18
|
+
|
|
19
|
+
[tool.setuptools.dynamic]
|
|
20
|
+
readme = { file = "README.md" }
|
|
21
|
+
dependencies = { file = "requirements.txt" }
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from unittest import TestCase
|
|
2
|
+
from zarth_utils.config import Config
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class TestConfig(TestCase):
    def test_config(self):
        """Exercise every supported combination of attribute / item / dotted access."""
        defaults = {
            "a": "a",
            "b": {"c": "c", "d": "d", "e": {"f": "f"}},
        }
        config = Config(default_config_dict=defaults)

        # Mutate through each supported write style; later writes win.
        config.a = "A"
        config["b.c"] = "X"
        config.b.c = "Y"
        config.b["c"] = "C"
        config["b"]["d"] = "D"
        config["b"].e["f"] = "F"

        # Dotted-key reads via item access and via get().
        expected = {"a": "A", "b.c": "C", "b.d": "D", "b.e.f": "F"}
        for dotted_key, want in expected.items():
            self.assertEqual(config[dotted_key], want)
            self.assertEqual(config.get(dotted_key), want)

        # Attribute-style reads.
        self.assertEqual(config.a, "A")
        self.assertEqual(config.b.c, "C")
        self.assertEqual(config.b.d, "D")
        self.assertEqual(config.b.e.f, "F")

        # Item-style and mixed-style reads.
        self.assertEqual(config["b"]["c"], "C")
        self.assertEqual(config["b"]["d"], "D")
        self.assertEqual(config["b"]["e"]["f"], "F")
        self.assertEqual(config["b"].e["f"], "F")
        self.assertEqual(config.b.e["f"], "F")
        self.assertEqual(config.b["e"].f, "F")
|
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import json
|
|
3
|
+
import argparse
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from .general_utils import get_random_time_stamp, makedir_if_not_exist
|
|
9
|
+
from .logger import logging_info
|
|
10
|
+
|
|
11
|
+
try:
|
|
12
|
+
import wandb
|
|
13
|
+
except ModuleNotFoundError as err:
|
|
14
|
+
logging.warning("WandB not installed!")
|
|
15
|
+
|
|
16
|
+
dir_configs = os.path.join(os.getcwd(), "configs")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def smart_load(path_file):
    """
    Load a config file, dispatching on its extension.

    :param path_file: path to a .json / .yaml / .yml config file
    :type path_file: str
    :return: the parsed file content (typically a dict)
    :rtype: dict
    """
    # Use context managers so the file handles are closed deterministically;
    # the original passed bare open() calls into the loaders and relied on
    # garbage collection to close them.
    if path_file.endswith("json"):
        with open(path_file, "r", encoding="utf-8") as fin:
            return json.load(fin)
    elif path_file.endswith(("yaml", "yml")):
        with open(path_file, "r", encoding="utf-8") as fin:
            return yaml.safe_load(fin)
    else:
        logging.warning("Un-identified file type. It will be processed as json by default.")
        with open(path_file, "r", encoding="utf-8") as fin:
            return json.load(fin)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class NestedDict(dict):
    def __init__(self, *args, **kwargs):
        """
        A dict whose elements can be visited either as attributes or as items,
        with dotted keys ("a.b.c") addressing nested levels.

        Examples:
            >>> a = NestedDict()
            >>> a["b"]["c"] = 1
            >>> a.b.c
            1
            >>> a.b.d = 2
            >>> a["b"]["d"]
            2
        """
        super(NestedDict, self).__init__(*args, **kwargs)
        # Convert plain-dict *values* into NestedDict so nested levels also
        # support attribute access. (The original tested `type(k) is dict`
        # on the key instead of the value, so nested dicts passed to the
        # constructor were never converted.)
        for k in list(dict.keys(self)):
            v = dict.__getitem__(self, k)
            if type(v) is dict:
                dict.__setitem__(self, k, NestedDict(v))

    def __getattr__(self, item):
        # Attribute access falls back to item access, enabling cfg.a.b style.
        return self[item]

    def __setattr__(self, key, value):
        self[key] = value

    def __getitem__(self, key):
        # Dotted keys walk down the nested levels one segment at a time.
        ret = self
        for k in key.split("."):
            ret = dict.__getitem__(ret, k)
        return ret

    def __setitem__(self, key, value):
        # Dotted keys create the intermediate NestedDict levels on demand.
        key_list = key.split(".")
        cur = self
        for i in range(len(key_list)):
            key = key_list[i]
            if i == len(key_list) - 1:
                dict.__setitem__(cur, key, value)
            else:
                if key in dict.keys(cur):
                    # An existing intermediate must itself be a NestedDict.
                    assert type(dict.__getitem__(cur, key)) is NestedDict
                else:
                    dict.__setitem__(cur, key, NestedDict())
                cur = dict.__getitem__(cur, key)

    def update(self, new_dict, prefix=None):
        """
        Recursively merge ``new_dict`` into self, flattening nested dicts into
        dotted keys. NOTE: intentionally not signature-compatible with
        dict.update (``prefix`` is used internally for recursion).
        """
        for k in new_dict:
            key = ".".join([prefix, k]) if prefix is not None else k
            value = new_dict[k]
            if type(value) is dict or type(value) is NestedDict:
                self.update(value, prefix=key)
            else:
                self[key] = value

    def keys(self, cur=None, prefix=None):
        """
        Return the flattened dotted leaf keys, e.g. ["a", "b.c", "b.e.f"].
        NOTE: intentionally shadows dict.keys with different semantics.
        """
        if cur is None:
            cur = self

        ret = []
        for k in dict.keys(cur):
            v = cur[k]
            new_prefix = ".".join([prefix, k]) if prefix is not None else k
            if type(v) is dict or type(v) is NestedDict:
                ret += self.keys(cur=v, prefix=new_prefix)
            else:
                ret.append(new_prefix)
        return ret

    def get(self, item, default_value=None):
        """Return self[item] for a dotted *leaf* key, else default_value."""
        if item in self.keys():
            return self[item]
        return default_value

    def to_dict(self):
        """
        Return the config as a plain (recursively converted) dict.

        :return: config dict
        :rtype: dict
        """
        # The original returned self._nested_dict — an attribute that never
        # existed — so to_dict()/show()/dump() always raised KeyError.
        ret = {}
        for k in dict.keys(self):
            v = dict.__getitem__(self, k)
            ret[k] = v.to_dict() if isinstance(v, NestedDict) else v
        return ret

    def show(self):
        """
        Show all the configs in logging. If get_logger is used before, the
        output will also appear in the log file.
        """
        logging_info("\n%s" % json.dumps(self.to_dict(), sort_keys=True, indent=4, separators=(',', ': ')))

    def dump(self, path_dump=None):
        """
        Dump the config as JSON at path_dump; when omitted, write a
        timestamped file under the configs directory. Refuses to overwrite.

        :param path_dump: the path to dump the config
        :type path_dump: str
        """
        if path_dump is None:
            makedir_if_not_exist(dir_configs)
            path_dump = os.path.join(dir_configs, "%s.json" % get_random_time_stamp())
        path_dump = "%s.json" % path_dump if not path_dump.endswith(".json") else path_dump
        assert not os.path.exists(path_dump)
        with open(path_dump, "w", encoding="utf-8") as fout:
            json.dump(self.to_dict(), fout)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class Config(NestedDict):
    def __init__(self, default_config_file=None, default_config_dict=None, use_argparse=True, use_wandb=False):
        """
        Initialize the config. Either default_config_dict or default_config_file
        (json format) must be provided! Keys are turned into argument names and
        types are detected automatically. Priority:
        user-specified CLI parameter (if use_argparse) > user-specified config
        file (if use_argparse) > default config dict > default config file.

        Examples:
            default_config_dict = {"lr": 0.01, "optimizer": "sgd", "num_epoch": 30, "use_early_stop": False}
            With use_argparse=True this adds the equivalent of:
                parser.add_argument("--lr", type=float)
                parser.add_argument("--optimizer", type=str)
                parser.add_argument("--num_epoch", type=int)
                parser.add_argument("--use_early_stop", action="store_true", default=False)
                parser.add_argument("--no-use_early_stop", dest="use_early_stop", action="store_false")

        :param default_config_dict: the default config dict
        :type default_config_dict: dict
        :param default_config_file: the default config file path
        :type default_config_file: str
        :param use_argparse: whether to parse the config from the command line
        :type use_argparse: bool
        :param use_wandb: whether to init wandb with the parent directory as
            project name and exp_name as run name
        :type use_wandb: bool
        """
        super(Config, self).__init__()

        # Fall back to ./default_config.json when nothing was supplied.
        if default_config_dict is None and default_config_file is None:
            fallback_file = os.path.join(os.getcwd(), "default_config.json")
            if os.path.exists(fallback_file):
                default_config_file = fallback_file
            else:
                logging.error("Either default_config_file or default_config_dict must be provided!")
                raise NotImplementedError

        # File first, then dict, so the dict takes precedence.
        if default_config_file is not None:
            self.update(smart_load(default_config_file))
        if default_config_dict is not None:
            self.update(default_config_dict)

        # Mirror every known parameter as a command-line argument.
        if use_argparse:
            parser = argparse.ArgumentParser()
            parser.add_argument("--config_file", type=str, default=None)
            for param_name in self.keys():
                param_value = self[param_name]
                param_type = type(param_value)
                if param_type is bool:
                    # Booleans get a --flag / --no-flag pair.
                    parser.add_argument("--%s" % param_name, action="store_true", default=None)
                    parser.add_argument("--no-%s" % param_name, dest="%s" % param_name,
                                        action="store_false", default=None)
                elif param_type is list:
                    parser.add_argument("--%s" % param_name, type=type(param_value[0]), nargs="+", default=None)
                else:
                    parser.add_argument("--%s" % param_name, type=param_type, default=None)
            args = parser.parse_args()

            # Only keep options the user actually passed on the command line.
            cli_overrides = {
                k: v for k, v in vars(args).items()
                if k != "config_file" and v is not None
            }

            # A user-specified config file overrides the defaults first,
            # then explicit CLI values override everything.
            if args.config_file is not None:
                self.update(smart_load(args.config_file))

            self.update(cli_overrides)

        if use_wandb:
            wandb.login()
            wandb.init(
                project=os.path.split(os.getcwd())[-1],
                name=self["exp_name"],
                config=self.to_dict()
            )
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def get_parser():
    """
    Get a simple parser for argparse, which already contains the config_file argument.
    """
    ret = argparse.ArgumentParser()
    ret.add_argument("--config_file", type=str, default=None)
    return ret
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def parser2config(parser):
    """
    Parse the command line via ``parser`` and wrap the result in a Config.
    """
    parsed_args = parser.parse_args()
    return args2config(parsed_args)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def args2config(args):
    """
    Wrap an already-parsed argparse namespace in a Config (no re-parsing).
    """
    return Config(default_config_dict=vars(args), use_argparse=False)
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def is_configs_same(config_a, config_b, ignored_keys=("load_epoch",)):
    """
    Judge whether two configs are the same.

    Args:
        config_a: the first config
        config_b: the second config
        ignored_keys: the keys that will be ignored when comparing

    Returns: True if the two configs are the same, False otherwise.

    """
    config_a, config_b = config_a.to_dict(), config_b.to_dict()

    # make sure config A is always equal or longer than config B
    if len(config_a.keys()) < len(config_b.keys()):
        config_a, config_b = config_b, config_a

    extra_keys = config_a.keys() - config_b.keys()
    if len(extra_keys) > 1:
        logging.error(
            "Different config numbers: %d (Existing) : %d (New)!" % (len(config_a.keys()), len(config_b.keys())))
        return False
    # A single extra key is tolerated only when it is "config_file".
    # (The original indexed the set difference with [0]; sets do not support
    # indexing, so that branch always raised TypeError.)
    elif len(extra_keys) == 1 and next(iter(extra_keys)) != "config_file":
        logging.error(
            "Different config numbers: %d (Existing) : %d (New)!" % (len(config_a.keys()), len(config_b.keys())))
        return False
    else:
        # Compare shared keys; lists are compared as tuples so that equal
        # contents match regardless of list identity.
        for i in config_a.keys() & config_b.keys():
            _ai = tuple(config_a[i]) if type(config_a[i]) == list else config_a[i]
            _bi = tuple(config_b[i]) if type(config_b[i]) == list else config_b[i]
            if _ai != _bi and i not in ignored_keys:
                logging.error("Mismatch in %s: %s (Existing) - %s (New)" % (str(i), str(config_a[i]), str(config_b[i])))
                return False

    return True
|
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from matplotlib.pyplot import figure
|
|
5
|
+
|
|
6
|
+
from .general_utils import get_random_time_stamp, makedir_if_not_exist
|
|
7
|
+
|
|
8
|
+
dir_figures = os.path.join(os.getcwd(), "figures")
|
|
9
|
+
makedir_if_not_exist(dir_figures)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Drawer:
    def __init__(self, num_row=1, num_col=1, unit_length=10, unit_row_length=None, unit_col_length=None):
        """
        Init the drawer with (width=num_col*unit_row_length, height=num_row*unit_col_length).

        :param num_row: the number of rows
        :type num_row: int
        :param num_col: the number of columns
        :type num_col: int
        :param unit_length: the length of unit
        :type unit_length: float
        :param unit_row_length: the length of unit for rows
        :param unit_col_length: the length of unit for cols
        """
        self.num_row = num_row
        self.num_col = num_col
        unit_row_length = unit_length if unit_row_length is None else unit_row_length
        unit_col_length = unit_length if unit_col_length is None else unit_col_length
        # NOTE(review): width pairs num_col with unit_row_length and height
        # pairs num_row with unit_col_length — looks swapped, but kept as-is
        # so existing figure sizes do not change.
        self.figure = figure(figsize=(num_col * unit_row_length, num_row * unit_col_length))

    def add_one_empty_axes(self, index=1, nrows=None, ncols=None,
                           title="", xlabel="", ylabel="", fontsize=15, xlim=None, ylim=None):
        """
        Add one empty axes (a sub-figure) and return it.

        :param index: The subplot will take the index position on a grid with nrows rows and ncols columns.
        :type index: int
        :param nrows: the number of rows in the figure (defaults to self.num_row)
        :type nrows: int
        :param ncols: the number of columns in the figure (defaults to self.num_col)
        :type ncols: int
        :param title: the title of the axes
        :type title: str
        :param xlabel: the label for the x axis
        :type xlabel: str
        :param ylabel: the label for the y axis
        :type ylabel: str
        :param fontsize: the size of the fonts
        :param xlim: the range of the x axis, (low, upp)
        :param ylim: the range of the y axis, (low, upp)
        :return: the created axes
        """
        nrows = self.num_row if nrows is None else nrows
        ncols = self.num_col if ncols is None else ncols

        ax = self.figure.add_subplot(nrows, ncols, index)

        ax.set_xlabel(xlabel, fontsize=fontsize)
        ax.set_ylabel(ylabel, fontsize=fontsize)
        ax.set_title(title, fontsize=fontsize)
        ax.xaxis.set_tick_params(labelsize=fontsize)
        ax.yaxis.set_tick_params(labelsize=fontsize)
        if xlim is not None:
            ax.set_xlim(*xlim)
        if ylim is not None:
            ax.set_ylim(*ylim)

        return ax

    def draw_one_axes(self, x, y, labels=None, *, index=1, nrows=None, ncols=None,
                      title="", xlabel="", ylabel="", use_marker=False, linewidth=6,
                      fontsize=15, xlim=None, ylim=None, smooth=0, mode="plot", **kwargs):
        """
        Draw one axes (a sub-figure) with one or more lines and return it.

        :param x: the data for the x axis; either one shared list or a list of
            per-line lists matching y
        :param y: the data for the y axis, a list of line lists, e.g. [[1, 2, 3], [2, 3, 1]]
        :param labels: the list of labels of each line, list
        :param index: The subplot will take the index position on a grid with nrows rows and ncols columns.
        :type index: int
        :param nrows: the number of rows in the figure
        :type nrows: int
        :param ncols: the number of columns in the figure
        :type ncols: int
        :param title: the title of the axes
        :type title: str
        :param xlabel: the label for the x axis
        :type xlabel: str
        :param ylabel: the label for the y axis
        :type ylabel: str
        :param use_marker: whether to use markers for the points, default=False
        :type use_marker: bool
        :param linewidth: the width of the lines for mode "plot", or the size of the points for mode "scatter"
        :param fontsize: the size of the fonts
        :param xlim: the range of the x axis, (low, upp)
        :param ylim: the range of the y axis, (low, upp)
        :param smooth: average each point over its +/- smooth neighbours
        :param mode: "plot" or "scatter"
        :return: the created axes
        """
        ax = self.add_one_empty_axes(index, nrows, ncols, title, xlabel, ylabel,
                                     fontsize, xlim, ylim)

        format_generator = self.get_format(use_marker)
        for i, yi in enumerate(y):
            # Resolve whether x is shared across lines or given per line.
            if len(x) == len(y) and type(x[0]) is list:
                xi = x[i]
            elif len(x) == len(y[0]) and type(x[0]) is not list:
                xi = x
            else:
                raise NotImplementedError

            if smooth != 0:
                # Average each point over the inclusive window
                # [j - smooth, j + smooth]. The original used an exclusive
                # right bound, which dropped the right neighbour, ignored
                # the final point entirely at j == len(yi) - 1, and divided
                # by zero for single-point series.
                yi_smoothed = []
                for j in range(len(yi)):
                    _l = max(j - smooth, 0)
                    _r = min(j + smooth + 1, len(yi))
                    yi_smoothed.append(sum(yi[_l:_r]) / (_r - _l))
                yi = yi_smoothed

            # Truncate at the first NaN/inf value; skip lines that start bad.
            len_no_nan = 0
            while len_no_nan < len(yi) and not (np.isnan(yi[len_no_nan]) or np.isinf(yi[len_no_nan])):
                len_no_nan += 1
            if len_no_nan == 0:
                continue

            fmt = next(format_generator)

            if labels is not None:
                kwargs["label"] = labels[i]
            if mode == "plot":
                kwargs["linewidth"] = linewidth

            if mode == "plot":
                ax.plot(xi[:len_no_nan], yi[:len_no_nan], fmt, **kwargs)
            elif mode == "scatter":
                ax.scatter(xi[:len_no_nan], yi[:len_no_nan], c=fmt[0], s=linewidth, **kwargs)
            else:
                raise NotImplementedError

        if labels is not None:
            ax.legend(fontsize=fontsize)

        return ax

    def show(self):
        """
        To show the figure.
        """
        self.figure.show()

    def save(self, fname=None):
        """
        To save the figure as fname (a timestamped name when omitted).

        :param fname: the filename
        :type fname: str
        """
        if fname is None:
            fname = get_random_time_stamp()
        # Append the default extension unless the name already carries it.
        # (The original tested endswith(".config") — a copy-paste slip from
        # the config dumper — so explicit "*.jpeg" names were saved as
        # "*.jpeg.jpeg".)
        fname = "%s.jpeg" % fname if not fname.endswith(".jpeg") else fname
        self.figure.savefig(os.path.join(dir_figures, fname), bbox_inches='tight')

    def clear(self):
        """
        Clear the figure.
        """
        self.figure.clf()

    @staticmethod
    def get_format(use_marker=False):
        """
        Yield an endless sequence of matplotlib format strings
        (color + line style + optional marker).

        :param use_marker: whether to use markers for points or not.
        :type use_marker: bool
        """
        p_color, p_style, p_marker = 0, 0, 0
        colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k']
        styles = ['-', '--', '-.', ':']
        markers = [""]
        if use_marker:
            markers = ['o', 'v', '^', '<', '>', '1', '2', '3', '4', '8', 's', 'p', 'P', '*', 'h', 'H', '+',
                       'x', 'X', 'D', 'd', '|', '_', ]

        while True:
            yield colors[p_color] + styles[p_style] + markers[p_marker]
            p_color += 1
            p_style += 1
            p_marker += 1
            p_color %= len(colors)
            p_style %= len(styles)
            p_marker %= len(markers)
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import random
|
|
3
|
+
import datetime
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_datetime():
    """Return the current local time formatted as 'YYYY.MM.DD-HH.MM.SS'."""
    return format(datetime.datetime.now(), '%Y.%m.%d-%H.%M.%S')
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def get_random_time_stamp():
    """
    Return a random time stamp: a 3-digit random prefix plus the current time.

    :return: random time stamp
    :rtype: str
    """
    prefix = random.randint(100, 999)
    return f"{prefix}-{get_datetime()}"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def makedir_if_not_exist(name):
    """
    Make the directory (including parents) if it does not exist.

    The original check-then-create was racy: another process could create the
    directory between the exists() check and makedirs(), raising
    FileExistsError. exist_ok=True makes the call idempotent.

    :param name: dir name
    :type name: str
    """
    os.makedirs(name, exist_ok=True)
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
"""Notebook bootstrap: widen the cell container and relax pandas display limits."""
import pandas as pd
from IPython.display import display, HTML, Markdown, Latex

# Widen the classic-notebook cell container to 95% of the window.
display(HTML("<style>.container { width:95% !important; }</style>"))

# Show full column contents and up to 1000 rows/columns when rendering frames.
for _option, _value in [('display.max_colwidth', None),
                        ('display.max_columns', 1000),
                        ('display.max_rows', 1000)]:
    pd.set_option(_option, _value)
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
from .general_utils import get_random_time_stamp
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
_PATH_LOG_UNSET = object()  # sentinel: distinguishes "not given" from an explicit None


def get_logger(path_log=_PATH_LOG_UNSET, force_add_handler=False):
    """
    Set up the root logger. Note that the setting will also impact the default
    logging logger, which means that simply using logging.info() will output
    the logs to both stdout and the log file.

    The original evaluated the default filename ("%s.log" % get_random_time_stamp())
    at import time, so every default call in one process shared a single frozen
    timestamped name; the sentinel defers that to call time. Passing None
    explicitly still disables the file handler.

    :param path_log: the filename of the log; None disables the file handler
    :type path_log: str
    :param force_add_handler: if True, will clear logging.root.handlers first
    :type force_add_handler: bool
    :return: the configured root logger
    """
    if path_log is _PATH_LOG_UNSET:
        path_log = "%s.log" % get_random_time_stamp()

    ret_logger = logging.getLogger()
    ret_logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s-%(name)s-%(levelname)s: %(message)s', datefmt='%Y-%m-%d-%H:%M:%S')

    if force_add_handler:
        ret_logger.handlers = []

    # Only attach handlers once; repeated calls would duplicate output.
    if not ret_logger.handlers:
        if path_log is not None:
            path_log = "%s.log" % path_log if not path_log.endswith(".log") else path_log
            fh = logging.FileHandler(path_log)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            ret_logger.addHandler(fh)

        ch = logging.StreamHandler(sys.stdout)
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)
        ret_logger.addHandler(ch)

    return ret_logger
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def logging_info(*args):
    """
    Proxy for logging.info that first makes sure INFO records are not being
    filtered out: if the root level is above INFO, drop the root logger and
    all of its handlers down to DEBUG before emitting.
    """
    if logging.root.level > logging.getLevelName("INFO"):
        logging.root.setLevel(logging.DEBUG)
        for handler in logging.root.handlers:
            handler.setLevel(logging.DEBUG)
    logging.info(*args)
|