scope 0.5.1__tar.gz → 0.5.3__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: scope
3
- Version: 0.5.1
3
+ Version: 0.5.3
4
4
  Summary: Metrics logging and analysis
5
5
  Home-page: http://github.com/danijar/scope
6
6
  Classifier: Intended Audience :: Science/Research
@@ -1,4 +1,4 @@
1
- __version__ = '0.5.1'
1
+ __version__ = '0.5.3'
2
2
 
3
3
  from .reader import Reader
4
4
  from .writer import Writer
@@ -13,7 +13,14 @@ class Float:
13
13
  return 'float'
14
14
 
15
15
  def valid(self, x):
16
- return x.ndim == 0 and np.isreal(x)
16
+ if isinstance(x, (int, float)):
17
+ return True
18
+ if isinstance(x, np.ndarray):
19
+ return x.ndim == 0 and np.isreal(x)
20
+ return False
21
+
22
+ def convert(self, x):
23
+ return np.asarray(x, np.float64)
17
24
 
18
25
  def create(self, path):
19
26
  pass
@@ -40,6 +47,9 @@ class Text:
40
47
  def valid(self, x):
41
48
  return isinstance(x, str)
42
49
 
50
+ def convert(self, x):
51
+ return x
52
+
43
53
  def create(self, path):
44
54
  path.mkdir(exist_ok=True)
45
55
 
@@ -71,10 +81,14 @@ class Image:
71
81
 
72
82
  def valid(self, x):
73
83
  return (
84
+ isinstance(x, np.ndarray) and
74
85
  x.dtype == np.uint8 and
75
86
  x.ndim == 3 and
76
87
  x.shape[-1] in (1, 3))
77
88
 
89
+ def convert(self, x):
90
+ return x
91
+
78
92
  def create(self, path):
79
93
  path.mkdir(exist_ok=True)
80
94
 
@@ -112,10 +126,14 @@ class Video:
112
126
 
113
127
  def valid(self, x):
114
128
  return (
129
+ isinstance(x, np.ndarray) and
115
130
  x.dtype == np.uint8 and
116
131
  x.ndim == 4 and
117
132
  x.shape[-1] in (1, 3))
118
133
 
134
+ def convert(self, x):
135
+ return x
136
+
119
137
  def create(self, path):
120
138
  path.mkdir(exist_ok=True)
121
139
 
@@ -37,7 +37,7 @@ class Writer:
37
37
  self.logdir.mkdir(parents=True, exist_ok=True)
38
38
  self.workers = workers
39
39
  self.rng = np.random.default_rng(seed=None)
40
- self.fmts = FORMATS
40
+ self.fmts = formats
41
41
  self.cols = {}
42
42
  if workers:
43
43
  self.pool = concurrent.futures.ThreadPoolExecutor(workers, 'scope')
@@ -48,8 +48,6 @@ class Writer:
48
48
  step = int(step)
49
49
  mapping = dict(*args, **kwargs)
50
50
  for key, value in mapping.items():
51
- if not isinstance(value, str):
52
- value = np.asarray(value)
53
51
  if key not in self.cols:
54
52
  assert re.match(r'[a-z0-9_]+(/[a-z0-9_]+)?', key), key
55
53
  for fmt in self.fmts:
@@ -64,6 +62,7 @@ class Writer:
64
62
  if not col.fmt.valid(value):
65
63
  raise ValueError(
66
64
  f"Key '{key}' contains invalid value {self._info(value)}")
65
+ value = col.fmt.convert(value)
67
66
  col.steps.append(step)
68
67
  col.values.append(value)
69
68