elements 2.2.0__tar.gz → 3.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {elements-2.2.0/elements.egg-info → elements-3.1.0}/PKG-INFO +1 -1
  2. {elements-2.2.0 → elements-3.1.0}/elements/__init__.py +1 -1
  3. elements-3.1.0/elements/agg.py +143 -0
  4. {elements-2.2.0 → elements-3.1.0}/elements/checkpoint.py +23 -15
  5. {elements-2.2.0 → elements-3.1.0}/elements/config.py +7 -5
  6. {elements-2.2.0 → elements-3.1.0}/elements/flags.py +43 -19
  7. {elements-2.2.0 → elements-3.1.0}/elements/printing.py +27 -18
  8. {elements-2.2.0 → elements-3.1.0}/elements/tree.py +5 -0
  9. {elements-2.2.0 → elements-3.1.0}/elements/usage.py +1 -1
  10. {elements-2.2.0 → elements-3.1.0}/elements/uuid.py +12 -2
  11. {elements-2.2.0 → elements-3.1.0/elements.egg-info}/PKG-INFO +1 -1
  12. {elements-2.2.0 → elements-3.1.0}/elements.egg-info/SOURCES.txt +4 -1
  13. elements-3.1.0/elements.egg-info/requires.txt +1 -0
  14. {elements-2.2.0 → elements-3.1.0}/requirements-optional.txt +3 -1
  15. elements-3.1.0/requirements.txt +1 -0
  16. elements-3.1.0/tests/test_basics.py +70 -0
  17. elements-3.1.0/tests/test_flags.py +126 -0
  18. elements-3.1.0/tests/test_path.py +45 -0
  19. elements-2.2.0/elements/agg.py +0 -78
  20. elements-2.2.0/elements.egg-info/requires.txt +0 -4
  21. elements-2.2.0/requirements.txt +0 -4
  22. {elements-2.2.0 → elements-3.1.0}/LICENSE +0 -0
  23. {elements-2.2.0 → elements-3.1.0}/MANIFEST.in +0 -0
  24. {elements-2.2.0 → elements-3.1.0}/README.md +0 -0
  25. {elements-2.2.0 → elements-3.1.0}/elements/counter.py +0 -0
  26. {elements-2.2.0 → elements-3.1.0}/elements/fps.py +0 -0
  27. {elements-2.2.0 → elements-3.1.0}/elements/logger.py +0 -0
  28. {elements-2.2.0 → elements-3.1.0}/elements/path.py +0 -0
  29. {elements-2.2.0 → elements-3.1.0}/elements/plotting.py +0 -0
  30. {elements-2.2.0 → elements-3.1.0}/elements/rwlock.py +0 -0
  31. {elements-2.2.0 → elements-3.1.0}/elements/timer.py +0 -0
  32. {elements-2.2.0 → elements-3.1.0}/elements/when.py +0 -0
  33. {elements-2.2.0 → elements-3.1.0}/elements.egg-info/dependency_links.txt +0 -0
  34. {elements-2.2.0 → elements-3.1.0}/elements.egg-info/top_level.txt +0 -0
  35. {elements-2.2.0 → elements-3.1.0}/setup.cfg +0 -0
  36. {elements-2.2.0 → elements-3.1.0}/setup.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: elements
3
- Version: 2.2.0
3
+ Version: 3.1.0
4
4
  Summary: Building blocks for productive research.
5
5
  Home-page: http://github.com/danijar/elements
6
6
  Classifier: Intended Audience :: Science/Research
@@ -1,4 +1,4 @@
1
- __version__ = '2.2.0'
1
+ __version__ = '3.1.0'
2
2
 
3
3
  from .agg import Agg
4
4
  from .checkpoint import Checkpoint
@@ -0,0 +1,143 @@
1
+ import math
2
+ import operator
3
+ from collections import defaultdict
4
+ from functools import partial as bind
5
+
6
+ import numpy as np
7
+
8
+
9
+ class Agg:
10
+
11
+ def __init__(self, maxlen=1e6):
12
+ self.reducers = defaultdict(list)
13
+ self.names = {}
14
+ self.maxlen = int(maxlen)
15
+
16
+ def add(self, key_or_dict, value=None, agg='default', prefix=None):
17
+ if value is not None:
18
+ self._add_single(key_or_dict, value, agg, prefix)
19
+ return
20
+ for key, value in key_or_dict.items():
21
+ self._add_single(key, value, agg, prefix)
22
+
23
+ def result(self, reset=True, prefix=None):
24
+ metrics = {}
25
+ for key, reducers in self.reducers.items():
26
+ if len(reducers) == 1:
27
+ metrics[key] = reducers[0].current()
28
+ else:
29
+ for name, reducer in zip(self.names[key], reducers):
30
+ metrics[f'{key}/{name}'] = reducer.current()
31
+ if prefix:
32
+ metrics = {f'{prefix}/{k}': v for k, v in metrics.items()}
33
+ reset and self.reset()
34
+ return metrics
35
+
36
+ def reset(self):
37
+ self.reducers.clear()
38
+
39
+ def _add_single(self, key, value, agg, prefix):
40
+ key = f'{prefix}/{key}' if prefix else key
41
+ reducers = self.reducers[key]
42
+ if reducers:
43
+ for reducer in reducers:
44
+ reducer.update(value)
45
+ return
46
+ if agg == 'default':
47
+ agg = 'avg' if np.asarray(value).ndim <= 1 else 'last'
48
+ if isinstance(agg, str):
49
+ aggs = (agg,)
50
+ self.names[key] = None
51
+ else:
52
+ aggs = agg
53
+ self.names[key] = aggs
54
+ for agg in aggs:
55
+ if agg == 'avg':
56
+ reducer = Mean(value)
57
+ elif agg == 'sum':
58
+ reducer = Sum(value)
59
+ elif agg == 'min':
60
+ reducer = Min(value)
61
+ elif agg == 'max':
62
+ reducer = Max(value)
63
+ elif agg == 'stack':
64
+ reducer = Stack(value, self.maxlen)
65
+ elif agg == 'last':
66
+ reducer = Last(value)
67
+ else:
68
+ raise ValueError(agg)
69
+ reducers.append(reducer)
70
+
71
+
72
+ class Reducer:
73
+
74
+ def __init__(self, scalar_fn, array_fn, initial):
75
+ self.scalar_fn = scalar_fn
76
+ self.array_fn = array_fn
77
+ self.is_scalar = isinstance(initial, (int, float))
78
+ if self.is_scalar:
79
+ self.interm = initial
80
+ else:
81
+ self.interm = np.array(initial, np.float64)
82
+ self.count = 1
83
+
84
+ def update(self, value):
85
+ if self.is_scalar:
86
+ if math.isnan(value):
87
+ return
88
+ self.interm = self.scalar_fn(self.interm, value)
89
+ else:
90
+ value = np.asarray(value)
91
+ if np.isnan(value).any():
92
+ return
93
+ self.interm = self.array_fn(self.interm, value)
94
+ self.count += 1
95
+
96
+ def current(self):
97
+ return np.array(self.interm)
98
+
99
+
100
+ class Mean:
101
+
102
+ def __init__(self, initial):
103
+ self.reducer = Sum(initial)
104
+
105
+ def update(self, value):
106
+ self.reducer.update(value)
107
+
108
+ def current(self):
109
+ return self.reducer.current() / self.reducer.count
110
+
111
+
112
+ class Stack:
113
+
114
+ def __init__(self, initial, maxlen=1e5):
115
+ self.stack = [initial]
116
+ self.maxlen = int(maxlen)
117
+
118
+ def update(self, value):
119
+ if len(self.stack) < self.maxlen:
120
+ self.stack.append(value)
121
+
122
+ def current(self):
123
+ return np.stack(self.stack)
124
+
125
+
126
+ class Last:
127
+
128
+ def __init__(self, initial):
129
+ self.value = initial
130
+
131
+ def update(self, value):
132
+ self.value = value
133
+
134
+ def current(self):
135
+ return self.value
136
+
137
+
138
+ Sum = bind(
139
+ Reducer, operator.add, lambda x, y: np.add(x, y, out=x, dtype=np.float64))
140
+ Min = bind(
141
+ Reducer, min, lambda x, y: np.minimum(x, y, out=x, dtype=np.float64))
142
+ Max = bind(
143
+ Reducer, max, lambda x, y: np.maximum(x, y, out=x, dtype=np.float64))
@@ -1,8 +1,10 @@
1
1
  import concurrent.futures
2
- import json
2
+ import pickle
3
3
  import time
4
4
 
5
+ from . import printing
5
6
  from . import path
7
+ from . import timer
6
8
 
7
9
 
8
10
  class Checkpoint:
@@ -12,7 +14,7 @@ class Checkpoint:
12
14
  self._values = {}
13
15
  self._parallel = parallel
14
16
  if self._parallel:
15
- self._worker = concurrent.futures.ThreadPoolExecutor(1)
17
+ self._worker = concurrent.futures.ThreadPoolExecutor(1, 'checkpoint')
16
18
  self._promise = None
17
19
 
18
20
  def __setattr__(self, name, value):
@@ -31,7 +33,7 @@ class Checkpoint:
31
33
  if name.startswith('_'):
32
34
  raise AttributeError(name)
33
35
  try:
34
- return getattr(self._values, name)
36
+ return self._values[name]
35
37
  except AttributeError:
36
38
  raise ValueError(name)
37
39
 
@@ -48,33 +50,39 @@ class Checkpoint:
48
50
  def save(self, filename=None, keys=None):
49
51
  assert self._filename or filename
50
52
  filename = path.Path(filename or self._filename)
51
- print(f'Writing checkpoint: {filename}')
53
+ printing.print_(f'Writing checkpoint: {filename}')
52
54
  if self._parallel:
53
55
  self._promise and self._promise.result()
54
56
  self._promise = self._worker.submit(self._save, filename, keys)
55
57
  else:
56
58
  self._save(filename, keys)
57
59
 
60
+ @timer.section('checkpoint_save')
58
61
  def _save(self, filename, keys):
59
62
  keys = tuple(self._values.keys() if keys is None else keys)
60
63
  assert all([not k.startswith('_') for k in keys]), keys
61
64
  data = {k: self._values[k].save() for k in keys}
62
65
  data['_timestamp'] = time.time()
63
66
  filename.parent.mkdirs()
64
- # Write to a temporary file and then atomically rename, so that the
65
- # requested filename either contains a full checkpoint or does not exist if
66
- # writing was interrupted.
67
- tmp = filename.parent / (filename.name + '.tmp')
68
- tmp.write(json.dumps(data), mode='wb')
69
- tmp.move(filename)
70
- print(f'Wrote checkpoint: {filename}')
67
+ content = pickle.dumps(data)
68
+ if str(filename).startswith('gs://'):
69
+ filename.write(content, mode='wb')
70
+ else:
71
+ # Write to a temporary file and then atomically rename, so that the file
72
+ # either contains a complete write or not update at all if writing was
73
+ # interrupted.
74
+ tmp = filename.parent / (filename.name + '.tmp')
75
+ tmp.write(content, mode='wb')
76
+ tmp.move(filename)
77
+ print('Wrote checkpoint.')
71
78
 
79
+ @timer.section('checkpoint_load')
72
80
  def load(self, filename=None, keys=None):
73
81
  assert self._filename or filename
74
82
  self._promise and self._promise.result() # Wait for last save.
75
83
  filename = path.Path(filename or self._filename)
76
- print(f'Loading checkpoint: {filename}')
77
- data = json.loads(filename.read('rb'))
84
+ printing.print_(f'Loading checkpoint: {filename}')
85
+ data = pickle.loads(filename.read('rb'))
78
86
  keys = tuple(data.keys() if keys is None else keys)
79
87
  for key in keys:
80
88
  if key.startswith('_'):
@@ -82,10 +90,10 @@ class Checkpoint:
82
90
  try:
83
91
  self._values[key].load(data[key])
84
92
  except Exception:
85
- print(f'Error loading {key} from checkpoint.')
93
+ print(f"Error loading '{key}' from checkpoint.")
86
94
  raise
87
95
  age = time.time() - data['_timestamp']
88
- print(f'Loaded checkpoint from {age:.0f} seconds ago.')
96
+ printing.print_(f'Loaded checkpoint from {age:.0f} seconds ago.')
89
97
 
90
98
  def load_or_save(self):
91
99
  if self.exists():
@@ -30,9 +30,10 @@ class Config(dict):
30
30
  if filename.suffix == '.json':
31
31
  filename.write(json.dumps(dict(self)))
32
32
  elif filename.suffix in ('.yml', '.yaml'):
33
- import ruamel.yaml as yaml
33
+ from ruamel.yaml import YAML
34
+ yaml = YAML(typ='safe')
34
35
  with io.StringIO() as stream:
35
- yaml.safe_dump(dict(self), stream)
36
+ yaml.dump(dict(self), stream)
36
37
  filename.write(stream.getvalue())
37
38
  else:
38
39
  raise NotImplementedError(filename.suffix)
@@ -41,10 +42,11 @@ class Config(dict):
41
42
  def load(cls, filename):
42
43
  filename = path.Path(filename)
43
44
  if filename.suffix == '.json':
44
- return cls(json.loads(filename.read_text()))
45
+ return cls(json.loads(filename.read()))
45
46
  elif filename.suffix in ('.yml', '.yaml'):
46
- import ruamel.yaml as yaml
47
- return cls(yaml.safe_load(filename.read_text()))
47
+ from ruamel.yaml import YAML
48
+ yaml = YAML(typ='safe')
49
+ return cls(yaml.load(filename.read()))
48
50
  else:
49
51
  raise NotImplementedError(filename.suffix)
50
52
 
@@ -12,9 +12,11 @@ class Flags:
12
12
  def parse(self, argv=None, help_exits=True):
13
13
  parsed, remaining = self.parse_known(argv)
14
14
  for flag in remaining:
15
- if flag.startswith('--'):
16
- raise ValueError(f"Flag '{flag}' did not match any config keys.")
17
- assert not remaining, remaining
15
+ if flag.startswith('--') and flag[2:] not in self._config.flat:
16
+ raise KeyError(f"Flag '{flag}' did not match any config keys.")
17
+ if remaining:
18
+ raise ValueError(
19
+ f'Could not parse all arguments. Remaining: {remaining}')
18
20
  return parsed
19
21
 
20
22
  def parse_known(self, argv=None, help_exits=False):
@@ -52,25 +54,40 @@ class Flags:
52
54
  return
53
55
  if not key:
54
56
  vals = ', '.join(f"'{x}'" for x in vals)
55
- raise ValueError(f"Values {vals} were not preceded by any flag.")
57
+ remaining.extend(vals)
58
+ return
59
+ # raise ValueError(f"Values {vals} were not preceded by any flag.")
56
60
  name = key[len('--'):]
57
61
  if '=' in name:
58
62
  remaining.extend([key] + vals)
59
63
  return
60
- if self._config.IS_PATTERN.fullmatch(name):
64
+ if not vals:
65
+ remaining.extend([key])
66
+ return
67
+ # raise ValueError(f"Flag '{key}' was not followed by any values.")
68
+ if name.endswith('+') and name[:-1] in self._config:
69
+ key = name[:-1]
70
+ default = self._config[key]
71
+ if not isinstance(default, tuple):
72
+ raise TypeError(
73
+ f"Cannot append to key '{key}' which is of type "
74
+ f"'{type(default).__name__}' instead of tuple.")
75
+ if key not in parsed:
76
+ parsed[key] = default
77
+ parsed[key] += self._parse_flag_value(default, vals, key)
78
+ elif self._config.IS_PATTERN.fullmatch(name):
61
79
  pattern = re.compile(name)
62
- keys = {k for k in self._config.flat if pattern.fullmatch(k)}
80
+ keys = [k for k in self._config.flat if pattern.fullmatch(k)]
81
+ if keys:
82
+ for key in keys:
83
+ parsed[key] = self._parse_flag_value(self._config[key], vals, key)
84
+ else:
85
+ remaining.extend([key] + vals)
63
86
  elif name in self._config:
64
- keys = [name]
87
+ key = name
88
+ parsed[key] = self._parse_flag_value(self._config[key], vals, key)
65
89
  else:
66
- keys = []
67
- if not keys:
68
90
  remaining.extend([key] + vals)
69
- return
70
- if not vals:
71
- raise ValueError(f"Flag '{key}' was not followed by any values.")
72
- for key in keys:
73
- parsed[key] = self._parse_flag_value(self._config[key], vals, key)
74
91
 
75
92
  def _parse_flag_value(self, default, value, key):
76
93
  value = value if isinstance(value, (tuple, list)) else (value,)
@@ -78,7 +95,9 @@ class Flags:
78
95
  if len(value) == 1 and ',' in value[0]:
79
96
  value = value[0].split(',')
80
97
  return tuple(self._parse_flag_value(default[0], [x], key) for x in value)
81
- assert len(value) == 1, value
98
+ if len(value) != 1:
99
+ raise TypeError(
100
+ f"Expected a single value for key '{key}' but got: {value}")
82
101
  value = str(value[0])
83
102
  if default is None:
84
103
  return value
@@ -92,11 +111,16 @@ class Flags:
92
111
  try:
93
112
  value = float(value) # Allow scientific notation for integers.
94
113
  assert float(int(value)) == value
95
- except (TypeError, AssertionError):
96
- message = f"Expected int but got float '{value}' for key '{key}'."
114
+ except (ValueError, TypeError, AssertionError):
115
+ message = f"Expected int but got '{value}' for key '{key}'."
97
116
  raise TypeError(message)
98
117
  return int(value)
99
118
  if isinstance(default, dict):
100
- raise TypeError(
119
+ raise KeyError(
101
120
  f"Key '{key}' refers to a whole dict. Please speicfy a subkey.")
102
- return type(default)(value)
121
+ try:
122
+ return type(default)(value)
123
+ except ValueError:
124
+ raise TypeError(
125
+ f"Cannot convert '{value}' to type '{type(default).__name__}' for "
126
+ f"key '{key}'.")
@@ -2,15 +2,22 @@ import re
2
2
 
3
3
  try:
4
4
  import colored
5
- REGEX_TOKEN = re.compile(r"([^a-zA-Z0-9-_./'\"\[\]])", re.MULTILINE)
6
- REGEX_NUMBER = re.compile(r'[-+]?[0-9]+[0-9.]*(e[-+]?[0-9])?')
7
- KEYWORDS = () # ('True', 'False', 'None', 'bool', 'int', 'str', 'float')
8
5
  except ImportError:
9
- print('For colored outputs: pip install colored')
10
6
  colored = None
7
+ print('For colored outputs: pip install colored')
8
+
11
9
 
10
+ REGEX_TOKEN = re.compile(
11
+ r"([^a-zA-Z0-9-_./'\"\[\]]|['\"][^\s]*['\"])", re.MULTILINE)
12
+ REGEX_NUMBER = re.compile(
13
+ r'([-+]?[0-9]+[0-9.,]*(e[-+]?[0-9])?|nan|-?inf)')
14
+ KEYWORDS = (
15
+ 'True', 'False', 'None', 'bool', 'int', 'str', 'float',
16
+ 'uint8', 'float16', 'float32', 'int32', 'int64')
12
17
 
13
- def print_(value, color=True, **kwargs):
18
+
19
+ def print_(*values, color=True, **kwargs):
20
+ value = kwargs.get('sep', ' ').join(str(x) for x in values)
14
21
  assert not color or isinstance(color, (bool, str)), color
15
22
  if isinstance(color, str) and colored:
16
23
  value = colored.stylize(value, colored.fg(color))
@@ -20,34 +27,36 @@ def print_(value, color=True, **kwargs):
20
27
  tokens = REGEX_TOKEN.split(value) + [None]
21
28
  for i, token in enumerate(tokens[:-1]):
22
29
  new = prev.copy()
23
- stripped = token.strip()
30
+ word = token.strip()
24
31
  new[2] = None
25
- if not stripped:
32
+ if not word:
26
33
  new[0] = None
27
- elif stripped in '/-+':
34
+ elif word in '/-+':
28
35
  new[0] = 'green'
29
36
  new[2] = True
30
- elif stripped in '{}()<>,:':
37
+ elif word in '{}()<>,:':
31
38
  new[0] = 'white'
32
39
  elif token == '=':
33
40
  new[0] = 'white'
34
- elif stripped[0].isalpha() and tokens[i + 1] == '=':
41
+ elif word[0].isalpha() and tokens[i + 1] == '=':
35
42
  new[0] = 'magenta'
36
- elif stripped in KEYWORDS:
43
+ elif word in KEYWORDS:
37
44
  new[0] = 'blue'
38
- elif stripped.startswith('---'):
45
+ elif word.startswith('---'):
39
46
  new[1] = True
40
- elif REGEX_NUMBER.match(stripped):
47
+ elif REGEX_NUMBER.match(word):
41
48
  new[0] = 'blue'
42
- elif stripped[0] == stripped[-1] == "'":
49
+ elif word[0] == word[-1] == "'":
43
50
  new[0] = 'red'
44
- elif stripped[0] == stripped[-1] == '"':
51
+ elif word[0] == word[-1] == '"':
45
52
  new[0] = 'red'
46
- elif stripped[0] == '[' and stripped[-1] == ']':
53
+ elif word[0] == '[' and word[-1] == ']':
47
54
  new[0] = 'cyan'
48
- elif stripped[0] == '/':
55
+ elif any(word.startswith(x) for x in ('/', '~', './')):
49
56
  new[0] = 'yellow'
50
- elif stripped[0] == stripped[0].upper():
57
+ elif len(word) >= 3 and word[0] == word[-1] and word[0] in ("'", '"'):
58
+ new[0] = 'green'
59
+ elif word[0] == word[0].upper():
51
60
  new[0] = None
52
61
  else:
53
62
  new[0] = None
@@ -20,6 +20,11 @@ def map_(fn, *trees, isleaf=None):
20
20
  assert all(set(x.keys()) == set(first.keys()) for x in trees), (
21
21
  printing.format_(trees))
22
22
  return {k: map_(fn, *[t[k] for t in trees], **kw) for k in first}
23
+ if hasattr(first, 'keys') and hasattr(first, 'get'):
24
+ assert all(set(x.keys()) == set(first.keys()) for x in trees), (
25
+ printing.format_(trees))
26
+ return type(first)(
27
+ {k: map_(fn, *[t[k] for t in trees], **kw) for k in first})
23
28
  return fn(*trees)
24
29
 
25
30
 
@@ -50,7 +50,7 @@ class NvsmiStats:
50
50
 
51
51
  @timer.section('nvsmi_stats')
52
52
  def __call__(self):
53
- output = os.popen('nvidia-smi --query -d UTILIZATION').read()
53
+ output = os.popen('nvidia-smi --query -d UTILIZATION 2>&1').read()
54
54
  if not output:
55
55
  print('To log GPU stats, make sure nvidia-smi is working.')
56
56
  return {}
@@ -5,11 +5,18 @@ import numpy as np
5
5
 
6
6
 
7
7
  class UUID:
8
+ """UUID that is stored as 16 byte string and can be converted to and from
9
+ int, string, and array types."""
10
+
11
+ __slots__ = ('value', '_hash')
8
12
 
9
13
  DEBUG_ID = None
10
14
  BASE62 = string.digits + string.ascii_letters
11
15
  BASE62REV = {x: i for i, x in enumerate(BASE62)}
12
16
 
17
+ # def __new__(cls, val=None):
18
+ # return val or np.random.randint(1, 2 ** 63)
19
+
13
20
  @classmethod
14
21
  def reset(cls, *, debug):
15
22
  cls.DEBUG_ID = 0 if debug else None
@@ -21,10 +28,13 @@ class UUID:
21
28
  else:
22
29
  type(self).DEBUG_ID += 1
23
30
  self.value = self.DEBUG_ID.to_bytes(16, 'big')
24
- elif isinstance(value, type(self)):
31
+ elif isinstance(value, UUID):
25
32
  self.value = value.value
26
33
  elif isinstance(value, int):
27
34
  self.value = value.to_bytes(16, 'big')
35
+ elif isinstance(value, bytes):
36
+ assert len(value) == 16, value
37
+ self.value = value
28
38
  elif isinstance(value, str):
29
39
  if self.DEBUG_ID is None:
30
40
  integer = 0
@@ -37,7 +47,7 @@ class UUID:
37
47
  self.value = value.tobytes()
38
48
  else:
39
49
  raise ValueError(value)
40
- assert type(self.value) == bytes, type(self.value)
50
+ assert type(self.value) == bytes, type(self.value) # noqa
41
51
  assert len(self.value) == 16, len(self.value)
42
52
  self._hash = hash(self.value)
43
53
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: elements
3
- Version: 2.2.0
3
+ Version: 3.1.0
4
4
  Summary: Building blocks for productive research.
5
5
  Home-page: http://github.com/danijar/elements
6
6
  Classifier: Intended Audience :: Science/Research
@@ -25,4 +25,7 @@ elements.egg-info/PKG-INFO
25
25
  elements.egg-info/SOURCES.txt
26
26
  elements.egg-info/dependency_links.txt
27
27
  elements.egg-info/requires.txt
28
- elements.egg-info/top_level.txt
28
+ elements.egg-info/top_level.txt
29
+ tests/test_basics.py
30
+ tests/test_flags.py
31
+ tests/test_path.py
@@ -0,0 +1 @@
1
+ numpy
@@ -1,7 +1,9 @@
1
1
  colored
2
- gil_load
2
+ gputil
3
3
  matplotlib
4
4
  mlflow
5
+ psutil
5
6
  pytest
7
+ ruamel.yaml
6
8
  tensorflow-cpu
7
9
  wandb
@@ -0,0 +1 @@
1
+ numpy
@@ -0,0 +1,70 @@
1
+ import elements
2
+ import pytest
3
+
4
+
5
+ class TestBasics:
6
+
7
+ def test_config(self):
8
+ config = elements.Config({'one.two': 12, 'foo': {'bar': True}})
9
+ assert config.one.two == 12
10
+ assert config.foo.bar is True
11
+ assert config['foo'].bar is True
12
+ assert 'foo' in config
13
+ assert 'foo.bar' in config
14
+ assert str(config)
15
+ _ = config.update({'one.two': 42})
16
+ with pytest.raises(KeyError):
17
+ _ = config.update({'new': 1})
18
+ with pytest.raises(TypeError):
19
+ _ = config.update({'one.two': 'string'})
20
+ with pytest.raises(AttributeError):
21
+ config.one.two = 1
22
+ with pytest.raises(TypeError):
23
+ elements.Config({'foo': lambda: None})
24
+
25
+ def test_flags(self):
26
+ flags = elements.Flags(foo=12, bar={'baz': True})
27
+ assert flags.parse(['--bar.baz', 'False']).foo == 12
28
+ assert flags.parse(['--bar.baz', 'False']).bar.baz is False
29
+ assert flags.parse(['--bar.*', 'False']).bar.baz is False
30
+ with pytest.raises(TypeError):
31
+ flags.parse(['--bar.baz', '12'])
32
+ with pytest.raises(KeyError):
33
+ flags.parse(['--.*unknown.*', '12'])
34
+ _, remaining = flags.parse_known(['one=two', '--foo', '42', '--three'])
35
+ assert remaining == ['one=two', '--three']
36
+ flags = elements.Flags({'foo': 12})
37
+ _, remaining = flags.parse_known(['--help'], help_exits=False)
38
+ assert remaining == ['--help']
39
+
40
+ def test_logger(self, capsys):
41
+ step = elements.Counter()
42
+ logger = elements.Logger(step, [elements.logger.TerminalOutput()])
43
+ logger.scalar('name', 15)
44
+ logger.write()
45
+ output = capsys.readouterr()
46
+ print(output)
47
+
48
+ def test_every(self):
49
+ should = elements.when.Every(5)
50
+ result = []
51
+ for i in range(16):
52
+ if should(i):
53
+ result.append(i)
54
+ assert result == [0, 5, 10, 15]
55
+
56
+ def test_once(self):
57
+ should = elements.when.Once()
58
+ result = []
59
+ for i in range(16):
60
+ if should():
61
+ result.append(i)
62
+ assert result == [0]
63
+
64
+ def test_until(self):
65
+ should = elements.when.Until(6)
66
+ result = []
67
+ for i in range(16):
68
+ if should(i):
69
+ result.append(i)
70
+ assert result == [0, 1, 2, 3, 4, 5]
@@ -0,0 +1,126 @@
1
+ import elements
2
+ import pytest
3
+
4
+
5
+ class TestFlags:
6
+
7
+ def test_int(self):
8
+ flags = elements.Flags({'foo': 42})
9
+ assert flags.parse(['--foo=1']).foo == 1
10
+ assert flags.parse(['--foo=1.0']).foo == 1
11
+ assert flags.parse(['--foo=1e2']).foo == 100
12
+ with pytest.raises(TypeError):
13
+ flags.parse(['--foo=0.5'])
14
+ with pytest.raises(TypeError):
15
+ flags.parse(['--foo=foo'])
16
+ with pytest.raises(TypeError):
17
+ assert flags.parse(['--foo=1,2,3'])
18
+
19
+ def test_float(self):
20
+ flags = elements.Flags({'foo': 1.0})
21
+ assert flags.parse(['--foo=0.5']).foo == 0.5
22
+ assert flags.parse(['--foo=1']).foo == 1.0
23
+ assert flags.parse(['--foo=1e2']).foo == 1e2
24
+ with pytest.raises(TypeError):
25
+ flags.parse(['--foo=True'])
26
+ with pytest.raises(TypeError):
27
+ flags.parse(['--foo=foo'])
28
+ with pytest.raises(TypeError):
29
+ assert flags.parse(['--foo=0.5,1.0'])
30
+
31
+ def test_bool(self):
32
+ flags = elements.Flags({'foo': True})
33
+ assert flags.parse(['--foo=True']).foo is True
34
+ assert flags.parse(['--foo=False']).foo is False
35
+ with pytest.raises(TypeError):
36
+ flags.parse(['--foo=true'])
37
+ with pytest.raises(TypeError):
38
+ flags.parse(['--foo=1'])
39
+ with pytest.raises(TypeError):
40
+ flags.parse(['--foo=foo'])
41
+ with pytest.raises(TypeError):
42
+ assert flags.parse(['--foo=True,False'])
43
+
44
+ def test_str(self):
45
+ flags = elements.Flags({'foo': 'hello'})
46
+ assert flags.parse(['--foo=hi']).foo == 'hi'
47
+ assert flags.parse(['--foo=1,2,3']).foo == '1,2,3'
48
+ assert flags.parse(['--foo=']).foo == ''
49
+ with pytest.raises(TypeError):
50
+ assert flags.parse(['--foo', '1', '2', '3'])
51
+
52
+ def test_sequence(self):
53
+ flags = elements.Flags({'foo': [1, 2]})
54
+ assert flags.parse(['--foo=1']).foo == (1,)
55
+ assert flags.parse(['--foo=1,2,3']).foo == (1, 2, 3)
56
+ assert flags.parse(['--foo', '1,2,3']).foo == (1, 2, 3)
57
+ assert flags.parse(['--foo', '1', '2', '3']).foo == (1, 2, 3)
58
+ with pytest.raises(TypeError):
59
+ assert flags.parse(['--foo', 'False'])
60
+ with pytest.raises(TypeError):
61
+ assert flags.parse(['--foo=1,2,0.5'])
62
+
63
+ def test_append(self):
64
+ flags = elements.Flags({'foo': [1, 2]})
65
+ assert flags.parse(['--foo+=3']).foo == (1, 2, 3)
66
+ assert flags.parse(['--foo+', '3']).foo == (1, 2, 3)
67
+ with pytest.raises(TypeError):
68
+ assert flags.parse(['--foo+', '0.5'])
69
+ assert flags.parse(['--foo=1', '--foo+=2', '--foo+=3']).foo == (1, 2, 3)
70
+ assert flags.parse(['--foo+=3', '--foo+=4']).foo == (1, 2, 3, 4)
71
+
72
+ def test_nested(self):
73
+ flags = elements.Flags({'foo.bar': 12})
74
+ assert flags.parse(['--foo.bar=42']).foo.bar == 42
75
+ with pytest.raises(KeyError):
76
+ assert flags.parse(['--foo=42'])
77
+ with pytest.raises(KeyError):
78
+ assert flags.parse(['--foo.baz=42'])
79
+ flags = elements.Flags({'foo': {'bar': 12}})
80
+ assert flags.parse(['--foo.bar=42']).foo.bar == 42
81
+
82
+ def test_regex(self):
83
+ flags = elements.Flags({'foo.bar': 12, 'baz': 'text'})
84
+ assert flags.parse([r'--.*\.bar=42']).foo.bar == 42
85
+ assert flags.parse([r'--.*z$=hello']).baz == 'hello'
86
+ parsed = flags.parse([r'--.*=42'])
87
+ assert parsed.foo.bar == 42
88
+ assert parsed.baz == '42'
89
+ with pytest.raises(TypeError):
90
+ assert flags.parse([r'--.*\.bar=0.5'])
91
+
92
+ def test_kwargs(self):
93
+ assert elements.Flags(foo=42).parse(['--foo=12']).foo == 12
94
+ assert elements.Flags(foo=42, bar='text').parse(['--bar=i']).bar == 'i'
95
+
96
+ def test_multiple(self):
97
+ defaults = elements.Config(foo=12, bar=0.5, baz='text')
98
+ flags = elements.Flags(defaults)
99
+ assert flags.parse([]) == defaults
100
+ assert flags.parse(['--bar', '2.5', '--foo=1']) == (
101
+ defaults.update(foo=1, bar=2.5))
102
+ with pytest.raises(ValueError):
103
+ flags.parse(['--bar', '--foo=1'])
104
+
105
+ def test_invalid(self):
106
+ flags = elements.Flags(foo=12)
107
+ with pytest.raises(KeyError):
108
+ flags.parse(['--bar=1'])
109
+ with pytest.raises(ValueError):
110
+ flags.parse(['foo=1'])
111
+ with pytest.raises(ValueError):
112
+ flags.parse(['1', '--foo=5'])
113
+
114
+ def test_from_yaml(self, tmpdir):
115
+ filename = elements.Path(tmpdir) / 'defaults.yaml'
116
+ filename.write("""
117
+ foo: 42
118
+ parent.child: 12
119
+ seq: [1, 2, 3]
120
+ scope:
121
+ inside: foo
122
+ """)
123
+ defaults = elements.Config.load(filename)
124
+ flags = elements.Flags(defaults)
125
+ assert flags.parse(['--parent.child=42']).parent.child == 42
126
+ assert flags.parse(['--seq=2,4,6']).seq == (2, 4, 6)
@@ -0,0 +1,45 @@
1
+ import elements
2
+
3
+
4
+ class TestPath:
5
+
6
+ def test_str_canonical(self):
7
+ examples = ['/', 'foo/bar', 'file.txt', '/bar.tar.gz']
8
+ for example in examples:
9
+ assert str(elements.Path(example)) == example
10
+
11
+ def test_parent_and_name(self):
12
+ examples = ['foo/bar', '/bar.tar.gz', 'file.txt', 'foo/bar/baz']
13
+ for example in examples:
14
+ path = elements.Path(example)
15
+ assert path == path.parent / path.name
16
+
17
+ def test_stem_and_suffix(self):
18
+ examples = ['foo/bar', '/bar.tar.gz', 'file.txt', 'foo/bar/baz']
19
+ for example in examples:
20
+ path = elements.Path(example)
21
+ assert path.name == path.stem + path.suffix
22
+
23
+ def test_leading_dot(self):
24
+ assert str(elements.Path('')) == '.'
25
+ assert str(elements.Path('.')) == '.'
26
+ assert str(elements.Path('./')) == '.'
27
+ assert str(elements.Path('./foo')) == 'foo'
28
+
29
+ def test_trailing_slash(self):
30
+ assert str(elements.Path('./')) == '.'
31
+ assert str(elements.Path('a/')) == 'a'
32
+ assert str(elements.Path('foo/bar/')) == 'foo/bar'
33
+
34
+ # @pytest.mark.filterwarnings('ignore::DeprecationWarning')
35
+ # def test_protocols(self):
36
+ # assert str(elements.Path('gs://')) == ('gs://')
37
+ # assert str(elements.Path('gs://foo/bar')) == 'gs://foo/bar'
38
+
39
+ def test_parent(self):
40
+ empty = elements.Path('.')
41
+ root = elements.Path('/')
42
+ assert (root / 'foo' / 'bar.txt').parent.parent == root
43
+ assert (empty / 'foo' / 'bar.txt').parent.parent == empty
44
+ assert root.parent == root
45
+ assert empty.parent == empty
@@ -1,78 +0,0 @@
1
- from collections import defaultdict
2
-
3
- import numpy as np
4
-
5
-
6
- class Agg:
7
-
8
- def __init__(self, maxlen=int(1e6)):
9
- self.maxlen = maxlen
10
- self.avgs = defaultdict(lambda: [0.0, 0.0])
11
- self.sums = defaultdict(float)
12
- self.mins = defaultdict(float)
13
- self.maxs = defaultdict(float)
14
- self.lasts = defaultdict(lambda: None)
15
- self.stacks = defaultdict(list)
16
-
17
- def add(
18
- self, key_or_dict, value=None, agg='default', prefix=None, nan='keep'):
19
- aggs = (agg,) if isinstance(agg, str) else agg
20
- assert nan in ('keep', 'ignore')
21
- if value is None:
22
- for key, value in dict(key_or_dict).items():
23
- key = f'{prefix}/{key}' if prefix else key
24
- self._add_single(key, value, aggs, nan)
25
- else:
26
- assert not prefix, prefix
27
- self._add_single(key_or_dict, value, aggs, nan)
28
-
29
- def result(self, reset=True, prefix=None):
30
- metrics = {}
31
- metrics.update({k: v[0] / v[1] for k, v in self.avgs.items()})
32
- metrics.update(self.sums)
33
- metrics.update(self.mins.items())
34
- metrics.update(self.maxs.items())
35
- metrics.update(self.lasts.items())
36
- metrics.update({k: np.stack(v) for k, v in self.stacks.items()})
37
- if prefix:
38
- metrics = {f'{prefix}/{k}': v for k, v in metrics.items()}
39
- reset and self.reset()
40
- return metrics
41
-
42
- def reset(self):
43
- self.avgs.clear()
44
- self.sums.clear()
45
- self.mins.clear()
46
- self.maxs.clear()
47
- self.lasts.clear()
48
- self.stacks.clear()
49
-
50
- def _add_single(self, key, value, aggs, nan):
51
- value = np.asarray(value)
52
- if nan == 'ignore' and np.isnan(value):
53
- return
54
- for agg in aggs:
55
- name = key if len(aggs) == 1 else f'{key}/{agg}'
56
- if agg == 'default':
57
- agg = 'avg' if len(value.shape) <= 1 else 'last'
58
- if agg == 'avg':
59
- avg = self.avgs[name]
60
- avg[0] += value
61
- avg[1] += 1
62
- # assert not np.shares_memory(self.avgs[name][0], value)
63
- elif agg == 'sum':
64
- self.sums[name] += value
65
- elif agg == 'min':
66
- self.mins[name] = min(self.mins[name], value)
67
- # assert not np.shares_memory(self.mins[name], value)
68
- elif agg == 'max':
69
- self.maxs[name] = max(self.maxs[name], value)
70
- # assert not np.shares_memory(self.mins[name], value)
71
- elif agg == 'last':
72
- self.lasts[name] = value
73
- elif agg == 'stack':
74
- stack = self.stacks[name]
75
- if len(stack) < self.maxlen:
76
- stack.append(value)
77
- else:
78
- raise KeyError(agg)
@@ -1,4 +0,0 @@
1
- gputil
2
- numpy
3
- psutil
4
- ruamel.yaml
@@ -1,4 +0,0 @@
1
- gputil
2
- numpy
3
- psutil
4
- ruamel.yaml
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes