ChessAnalysisPipeline 0.0.17.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CHAP/TaskManager.py +216 -0
- CHAP/__init__.py +27 -0
- CHAP/common/__init__.py +57 -0
- CHAP/common/models/__init__.py +8 -0
- CHAP/common/models/common.py +124 -0
- CHAP/common/models/integration.py +659 -0
- CHAP/common/models/map.py +1291 -0
- CHAP/common/processor.py +2869 -0
- CHAP/common/reader.py +658 -0
- CHAP/common/utils.py +110 -0
- CHAP/common/writer.py +730 -0
- CHAP/edd/__init__.py +23 -0
- CHAP/edd/models.py +876 -0
- CHAP/edd/processor.py +3069 -0
- CHAP/edd/reader.py +1023 -0
- CHAP/edd/select_material_params_gui.py +348 -0
- CHAP/edd/utils.py +1572 -0
- CHAP/edd/writer.py +26 -0
- CHAP/foxden/__init__.py +19 -0
- CHAP/foxden/models.py +71 -0
- CHAP/foxden/processor.py +124 -0
- CHAP/foxden/reader.py +224 -0
- CHAP/foxden/utils.py +80 -0
- CHAP/foxden/writer.py +168 -0
- CHAP/giwaxs/__init__.py +11 -0
- CHAP/giwaxs/models.py +491 -0
- CHAP/giwaxs/processor.py +776 -0
- CHAP/giwaxs/reader.py +8 -0
- CHAP/giwaxs/writer.py +8 -0
- CHAP/inference/__init__.py +7 -0
- CHAP/inference/processor.py +69 -0
- CHAP/inference/reader.py +8 -0
- CHAP/inference/writer.py +8 -0
- CHAP/models.py +227 -0
- CHAP/pipeline.py +479 -0
- CHAP/processor.py +125 -0
- CHAP/reader.py +124 -0
- CHAP/runner.py +277 -0
- CHAP/saxswaxs/__init__.py +7 -0
- CHAP/saxswaxs/processor.py +8 -0
- CHAP/saxswaxs/reader.py +8 -0
- CHAP/saxswaxs/writer.py +8 -0
- CHAP/server.py +125 -0
- CHAP/sin2psi/__init__.py +7 -0
- CHAP/sin2psi/processor.py +8 -0
- CHAP/sin2psi/reader.py +8 -0
- CHAP/sin2psi/writer.py +8 -0
- CHAP/tomo/__init__.py +15 -0
- CHAP/tomo/models.py +210 -0
- CHAP/tomo/processor.py +3862 -0
- CHAP/tomo/reader.py +9 -0
- CHAP/tomo/writer.py +59 -0
- CHAP/utils/__init__.py +6 -0
- CHAP/utils/converters.py +188 -0
- CHAP/utils/fit.py +2947 -0
- CHAP/utils/general.py +2655 -0
- CHAP/utils/material.py +274 -0
- CHAP/utils/models.py +595 -0
- CHAP/utils/parfile.py +224 -0
- CHAP/writer.py +122 -0
- MLaaS/__init__.py +0 -0
- MLaaS/ktrain.py +205 -0
- MLaaS/mnist_img.py +83 -0
- MLaaS/tfaas_client.py +371 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/LICENSE +60 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/METADATA +29 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/RECORD +70 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/WHEEL +5 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/entry_points.txt +2 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/top_level.txt +2 -0
CHAP/utils/parfile.py
ADDED
@@ -0,0 +1,224 @@
+"""Utilities for interacting with scans using an SMB-style .par file
+as input.
+"""
+
+# System modules
+import csv
+import os
+
+# Third party modules
+import json
+
+class ParFile():
+    """Representation of a .par file.
+
+    :ivar par_file: Name of the .par file.
+    :type par_file: str
+    :ivar json_file: Name of the .json file containing the keys for
+        the column names of the .par file.
+    :type json_file: str
+    :ivar spec_file: Name of the SPEC data file associated with this
+        .par file.
+    :type spec_file: str
+    :ivar column_names: List of the names of each column in the par
+        file.
+    :type column_names: list[str]
+    :ivar data: A 2D array of the data in this .par file. 0th index:
+        row. 1st index: column
+    :type data: list[list]
+    """
+    def __init__(self, par_file, scan_numbers=None, scann_col_name='SCAN_N'):
+        # Local modules
+        from CHAP.utils.general import (
+            is_int_series,
+            string_to_list,
+        )
+
+        self.par_file = str(par_file)
+        self.json_file = self.par_file.replace('.par', '.json')
+        self.spec_file = os.path.join(
+            os.path.dirname(self.par_file), 'spec.log')
+
+        with open(self.json_file) as json_file:
+            columns = json.load(json_file)
+        num_column = len(columns)
+        self.column_names = [None] * num_column
+        for i, name in columns.items():
+            self.column_names[int(i)] = name
+
+        self.data = []
+        with open(self.par_file) as f:
+            reader = csv.reader(f, delimiter=' ')
+            for i, row in enumerate(reader):
+                if len(row) == 0:
+                    continue
+                if row[0].startswith('#'):
+                    continue
+                row_data = []
+                for value in row:
+                    try:
+                        value = int(value)
+                    except ValueError:
+                        try:
+                            value = float(value)
+                        except Exception:
+                            pass
+                    row_data.append(value)
+                if len(row_data) != num_column:
+                    raise ValueError(
+                        'Mismatch between the number of columns in the json '
+                        f'({num_column}) and line {i+1} of the par file '
+                        f'({len(row_data)})')
+                self.data.append(row_data)
+
+        self.scann_i = self.column_names.index(scann_col_name)
+        self.scan_numbers = [data[self.scann_i] for data in self.data]
+        if scan_numbers is not None:
+            if isinstance(scan_numbers, int):
+                scan_numbers = [scan_numbers]
+            elif isinstance(scan_numbers, str):
+                scan_numbers = string_to_list(scan_numbers)
+            if not is_int_series(scan_numbers, ge=0, log=False):
+                raise TypeError(
+                    f'Invalid scan_numbers parameter ({scan_numbers})')
+            self.scan_numbers = [
+                n for n in scan_numbers if n in self.scan_numbers]
+
+    def get_map(
+            self, experiment_type, station, par_dims, other_dims=None):
+        """Return a map configuration based on this par file.
+
+        :param experiment_type: Experiment type name for the map
+            that this .par file represents.
+        :type experiment_type:
+            Literal['SAXSWAXS', 'EDD', 'XRF', 'TOMO']
+        :param station: Station name at which the data were collected.
+        :type station: Literal['id1a3','id3a','id3b']
+        :param par_dims: List of dictionaries configuring the map's
+            independent dimensions.
+        :type par_dims: list[dict[str, str]]
+        :param other_dims: List of other dimensions to include in
+            the returned MapConfig's independent_dimensions. Use this
+            if each scan in this par file captured more than one
+            frame of data. Defaults to `None`.
+        :type other_dims: list[dict[str,str]], optional
+        :return: The map configuration.
+        :rtype: CHAP.common.models.map.MapConfig
+        """
+        # Third party modules
+        # pylint: disable=import-error
+        from chess_scanparsers import SMBScanParser
+        # pylint: enable=import-error
+
+        # Local modules
+        from CHAP.common.models.map import MapConfig
+
+        scanparser = SMBScanParser(self.spec_file, 1)
+        good_scans = self.good_scan_numbers()
+        if other_dims is None:
+            other_dims = []
+        map_config = {
+            'title': scanparser.scan_name,
+            'station': station, # scanparser.station,
+            'experiment_type': experiment_type,
+            'sample': {'name': scanparser.scan_name},
+            'spec_scans': [
+                {'spec_file': self.spec_file,
+                 'scan_numbers': good_scans}],
+            'independent_dimensions': [
+                {'label': dim['label'],
+                 'units': dim['units'],
+                 'name': dim['name'],
+                 'data_type': 'smb_par'}
+                for dim in par_dims] + other_dims
+        }
+        return MapConfig(**map_config)
+
+    def good_scan_numbers(self, good_col_name='1/0'):
+        """Return the numbers of scans marked with a "1" in the
+        indicated "good" column of the .par file.
+
+        :param good_col_name: The name of the "good" column of the par
+            file, defaults to "1/0"
+        :type good_col_name: str, optional
+        :raises ValueError: If this .par file does not have a column
+            with the same name as `good_col_name`.
+        :return: "good" scan numbers.
+        :rtype: list[int]
+        """
+        good_col_i = self.column_names.index(good_col_name)
+        return [self.scan_numbers[i] for i in range(len(self.scan_numbers))
+                if self.data[i][good_col_i] == 1]
+
+    def get_values(self, column, scan_numbers=None):
+        """Return values from a single column of the par file.
+
+        :param column: The string name OR index of the column to return
+            values for.
+        :type column: str or int
+        :param scan_numbers: List of specific scan numbers to return
+            values in the given column for (instead of the default
+            behavior: return the entire column of values).
+        :type scan_numbers: list[int], optional
+        :raise:
+            ValueError: Unavailable column name.
+            TypeError: Illegal column name type.
+        :return: A list of values from a single column in the par file.
+        :rtype: list[object]
+        """
+        if isinstance(column, str):
+            if column in self.column_names:
+                column_idx = self.column_names.index(column)
+            # elif column in ('dataset_id', 'scan_type'):
+            #     column_idx = None
+            else:
+                raise ValueError(f'Unavailable column name: {column} not in '
+                                 f'{self.column_names}')
+        elif isinstance(column, int):
+            column_idx = column
+        else:
+            raise TypeError(f'column must be a str or int, not {type(column)}')
+
+        if column_idx is None:
+            column_data = [None]*len(self.data)
+        else:
+            column_data = [
+                self.data[i][column_idx] for i in range(len(self.data))]
+        if scan_numbers is not None:
+            column_data = [column_data[self.scan_numbers.index(scan_n)]
+                           for scan_n in scan_numbers]
+        return column_data
+
+    def map_values(self, map_config, values):
+        """Return a reshaped array of the 1D list `values` so that it
+        matches up with the coordinates of `map_config`.
+
+        :param map_config: The map configuration according to which
+            values will be reshaped.
+        :type map_config: MapConfig
+        :param values: A 1D list of values to reshape.
+        :type values: list or np.ndarray
+        :return: Reshaped array of values.
+        :rtype: np.ndarray
+        """
+        # Third party modules
+        import numpy as np
+
+        good_scans = self.good_scan_numbers()
+        if len(values) != len(good_scans):
+            raise ValueError(f'number of values provided ({len(values)}) does '
+                             'not match the number of good scans in '
+                             f'{self.par_file} ({len(good_scans)})')
+        n_map_points = np.prod(map_config.shape)
+        if len(values) != n_map_points:
+            raise ValueError(
+                f'Cannot reshape {len(values)} values into an array of shape '
+                f'{map_config.shape}')
+
+        map_values = np.empty(map_config.shape)
+        for map_index in np.ndindex(map_config.shape):
+            _, scan_number, _ = \
+                map_config.get_scan_step_index(map_index)
+            value_index = good_scans.index(scan_number)
+            map_values[map_index] = values[value_index]
+        return map_values
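
A minimal usage sketch of ParFile (the .par file name and the 'temperature' column are hypothetical; as the constructor above expects, a samplename.json file mapping column indices to names is assumed to sit next to the .par file):

    from CHAP.utils.parfile import ParFile

    # Load the .par file; samplename.json supplies the column names.
    parfile = ParFile('samplename.par')
    # Scan numbers flagged with a 1 in the '1/0' column.
    good_scans = parfile.good_scan_numbers()
    # Values of one (hypothetical) column, restricted to the good scans.
    temperatures = parfile.get_values('temperature', scan_numbers=good_scans)
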
CHAP/writer.py
ADDED
@@ -0,0 +1,122 @@
+#!/usr/bin/env python
+"""
+File : writer.py
+Author : Valentin Kuznetsov <vkuznet AT gmail dot com>
+Description: generic Writer module
+
+Define a generic `Writer` object.
+"""
+
+# System modules
+import argparse
+import logging
+import os
+from sys import modules
+from typing import Optional
+
+# Third party modules
+from pydantic import model_validator
+
+# Local modules
+from CHAP.pipeline import PipelineItem
+
+
+def validate_writer_model(writer):
+    writer.filename = os.path.normpath(os.path.realpath(
+        os.path.join(writer.outputdir, writer.filename)))
+    if (not writer.force_overwrite
+            and os.path.isfile(writer.filename)):
+        raise ValueError(
+            'Writing to an existing file without overwrite '
+            f'permission. Remove {writer.filename} or set '
+            '"force_overwrite" in pipeline configuration for '
+            f'{writer.name}')
+    return writer
+
+
+class Writer(PipelineItem):
+    """Generic file writer.
+
+    The job of any `Writer` in a `Pipeline` is to receive input
+    returned by a previous `PipelineItem`, write its data to a
+    particular file format, then return the same data unaltered so it
+    can be used by a successive `PipelineItem`.
+
+    :ivar filename: Name of file to write to.
+    :type filename: str
+    :ivar force_overwrite: Flag to allow data in `filename` to be
+        overwritten if it already exists, defaults to `False`.
+    :type force_overwrite: bool, optional
+    :ivar remove: Flag to remove the dictionary from `data`,
+        defaults to `False`.
+    :type remove: bool, optional
+    """
+    filename: str
+    force_overwrite: Optional[bool] = False
+    remove: Optional[bool] = False
+
+    _validate_filename = model_validator(mode="after")(
+        validate_writer_model)
+
+    def write(self, data):
+        """Write the last `CHAP.pipeline.PipelineData` item in `data`
+        as text to a file.
+
+        :param data: Input data.
+        :type data: list[CHAP.pipeline.PipelineData]
+        :return: Contents of the input data.
+        :rtype: list[PipelineData]
+        """
+        ddata = self.unwrap_pipelinedata(data)[-1]
+        if os.path.isfile(self.filename) and not self.force_overwrite:
+            raise FileExistsError(f'{self.filename} already exists')
+        with open(self.filename, 'w') as f:
+            f.write(ddata)
+        if self.remove:
+            data.pop()
+        self.status = 'written'  # Right now does nothing yet, but could
+                                 # add a sort of modification flag later
+        return data
+
+
+class OptionParser():
+    """User based option parser."""
+    def __init__(self):
+        self.parser = argparse.ArgumentParser(prog='PROG')
+        self.parser.add_argument(
+            '--data', action='store',
+            dest='data', default='', help='Input data')
+        self.parser.add_argument(
+            '--filename', action='store',
+            dest='filename', default='', help='Output file')
+        self.parser.add_argument(
+            '--writer', action='store',
+            dest='writer', default='Writer', help='Writer class name')
+        self.parser.add_argument(
+            '--log-level', choices=logging._nameToLevel.keys(),
+            dest='log_level', default='INFO', help='logging level')
+
+
+def main(opt_parser=OptionParser):
+    """Main function."""
+    optmgr = opt_parser()
+    opts = optmgr.parser.parse_args()
+    cls_name = opts.writer
+    try:
+        writer_cls = getattr(modules[__name__], cls_name)
+    except AttributeError:
+        print(f'Unsupported writer {cls_name}')
+        raise
+
+    writer = writer_cls()
+    writer.logger.setLevel(getattr(logging, opts.log_level))
+    log_handler = logging.StreamHandler()
+    log_handler.setFormatter(logging.Formatter(
+        '{name:20}: {message}', style='{'))
+    writer.logger.addHandler(log_handler)
+    data = writer.write(opts.data, opts.filename)
+    print(f'Writer {writer} writes to {opts.filename}, data {data}')
+
+
+if __name__ == '__main__':
+    main()
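
The `_validate_filename = model_validator(mode="after")(validate_writer_model)` line above attaches a standalone function as a pydantic "after" model validator, so the same check can be shared by several writer classes. A minimal, self-contained sketch of that pattern (the class and function names here are hypothetical, and the outputdir path handling from validate_writer_model is omitted):

    import os
    from typing import Optional

    from pydantic import BaseModel, model_validator

    def check_overwrite(model):
        # Refuse to target an existing file unless overwriting is allowed.
        if not model.force_overwrite and os.path.isfile(model.filename):
            raise ValueError(f'{model.filename} exists; set force_overwrite')
        return model

    class ExampleWriter(BaseModel):
        filename: str
        force_overwrite: Optional[bool] = False

        # Reuse the standalone check as an "after" validator.
        _check_overwrite = model_validator(mode='after')(check_overwrite)

    ExampleWriter(filename='new_output.txt')  # validated on construction
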
MLaaS/__init__.py
ADDED
File without changes
MLaaS/ktrain.py
ADDED
@@ -0,0 +1,205 @@
+#!/usr/bin/env python
+#-*- coding: utf-8 -*-
+#pylint: disable=
+"""
+File : ktrain.py
+Author : Valentin Kuznetsov <vkuznet AT gmail dot com>
+Description: Keras based ML network to train over MNIST dataset
+"""
+
+# system modules
+import os
+import sys
+import json
+import gzip
+import pickle
+import argparse
+
+# third-party modules
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras import backend as K
+from tensorflow.python.tools import saved_model_utils
+
+
+def modelGraph(model_dir):
+    """
+    Provide the input/output names used by the TF graph along with the graph itself.
+    The code is based on the TF saved_model_cli.py script.
+    """
+    input_names = []
+    output_names = []
+    tag_sets = saved_model_utils.get_saved_model_tag_sets(model_dir)
+    for tag_set in sorted(tag_sets):
+        print('%r' % ', '.join(sorted(tag_set)))
+        meta_graph_def = saved_model_utils.get_meta_graph_def(model_dir, tag_set[0])
+        for key in meta_graph_def.signature_def.keys():
+            meta = meta_graph_def.signature_def[key]
+            if hasattr(meta, 'inputs') and hasattr(meta, 'outputs'):
+                inputs = meta.inputs
+                outputs = meta.outputs
+                input_signatures = list(meta.inputs.values())
+                input_names = [signature.name for signature in input_signatures]
+                if len(input_names) > 0:
+                    output_signatures = list(meta.outputs.values())
+                    output_names = [signature.name for signature in output_signatures]
+    return input_names, output_names, meta_graph_def
+
+def readData(fin, num_classes):
+    """
+    Helper function to read MNIST data and provide it to
+    upstream code, e.g. to the training layer
+    """
+    # Load the data and split it between train and test sets
+    # (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+    f = gzip.open(fin, 'rb')
+    if sys.version_info < (3,):
+        mnist_data = pickle.load(f)
+    else:
+        mnist_data = pickle.load(f, encoding='bytes')
+    f.close()
+    (x_train, y_train), (x_test, y_test) = mnist_data
+
+    # Scale images to the [0, 1] range
+    x_train = x_train.astype("float32") / 255
+    x_test = x_test.astype("float32") / 255
+    # Make sure images have shape (28, 28, 1)
+    x_train = np.expand_dims(x_train, -1)
+    x_test = np.expand_dims(x_test, -1)
+    print("x_train shape:", x_train.shape)
+    print(x_train.shape[0], "train samples")
+    print(x_test.shape[0], "test samples")
+
+
+    # convert class vectors to binary class matrices
+    y_train = keras.utils.to_categorical(y_train, num_classes)
+    y_test = keras.utils.to_categorical(y_test, num_classes)
+    return x_train, y_train, x_test, y_test
+
+
+def train(fin, fout=None, model_name=None, epochs=1, batch_size=128, h5=False):
+    """
+    train function for MNIST
+    """
+    # Model / data parameters
+    num_classes = 10
+    input_shape = (28, 28, 1)
+
+    # create ML model
+    model = keras.Sequential(
+        [
+            keras.Input(shape=input_shape),
+            layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
+            layers.MaxPooling2D(pool_size=(2, 2)),
+            layers.Flatten(),
+            layers.Dropout(0.5),
+            layers.Dense(num_classes, activation="softmax"),
+        ]
+    )
+
+    model.summary()
+    print("model input", model.input, type(model.input), model.input.__dict__)
+    print("model output", model.output, type(model.output), model.output.__dict__)
+    model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+
+    # train model
+    x_train, y_train, x_test, y_test = readData(fin, num_classes)
+    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
+
+    # evaluate trained model
+    score = model.evaluate(x_test, y_test, verbose=0)
+    print("Test loss:", score[0])
+    print("Test accuracy:", score[1])
+    print("save model to", fout)
+    writer(fout, model_name, model, input_shape, h5)
+
+def writer(fout, model_name, model, input_shape, h5=False):
+    """
+    Writer provides a write function for the given model
+    """
+    if not fout:
+        return
+    model.save(fout)
+    if h5:
+        model.save('{}/{}'.format(fout, h5), save_format='h5')
+    pbModel = '{}/saved_model.pb'.format(fout)
+    pbtxtModel = '{}/saved_model.pbtxt'.format(fout)
+    convert(pbModel, pbtxtModel)
+
+    # get meta-data information about our ML model
+    input_names, output_names, model_graph = modelGraph(model_name)
+    print("### input", input_names)
+    print("### output", output_names)
+    # ML uses (28, 28, 1) shape, i.e. 28x28 black-white images
+    # if we use color images we would use shape (28, 28, 3)
+    img_channels = input_shape[2]  # last item represents the number of colors
+    meta = {'name': model_name,
+            'model': 'saved_model.pb',
+            'labels': 'labels.txt',
+            'img_channels': img_channels,
+            'input_name': input_names[0].split(':')[0],
+            'output_name': output_names[0].split(':')[0],
+            'input_node': model.input.name,
+            'output_node': model.output.name
+            }
+    with open(fout+'/params.json', 'w') as ostream:
+        ostream.write(json.dumps(meta))
+    with open(fout+'/labels.txt', 'w') as ostream:
+        for i in range(0, 10):
+            ostream.write(str(i)+'\n')
+    with open(fout + '/model.graph', 'wb') as ostream:
+        ostream.write(model_graph.SerializeToString())
+
+def convert(fin, fout):
+    """
+    convert input model.pb into output model.pbtxt
+    Based on internet search:
+    - https://www.tensorflow.org/guide/saved_model
+    - https://www.programcreek.com/python/example/123317/tensorflow.core.protobuf.saved_model_pb2.SavedModel
+    """
+    import google.protobuf
+    from tensorflow.core.protobuf import saved_model_pb2
+    import tensorflow as tf
+
+    saved_model = saved_model_pb2.SavedModel()
+
+    with open(fin, 'rb') as f:
+        saved_model.ParseFromString(f.read())
+
+    with open(fout, 'w') as f:
+        f.write(google.protobuf.text_format.MessageToString(saved_model))
+
+
+class OptionParser():
+    def __init__(self):
+        "User based option parser"
+        self.parser = argparse.ArgumentParser(prog='PROG')
+        self.parser.add_argument("--fin", action="store",
+            dest="fin", default="", help="Input MNIST file")
+        self.parser.add_argument("--fout", action="store",
+            dest="fout", default="", help="Output models area")
+        self.parser.add_argument("--model", action="store",
+            dest="model", default="mnist", help="model name")
+        self.parser.add_argument("--epochs", action="store",
+            dest="epochs", default=1, help="number of epochs to use in ML training")
+        self.parser.add_argument("--batch_size", action="store",
+            dest="batch_size", default=128, help="batch size to use in training")
+        self.parser.add_argument("--h5", action="store",
+            dest="h5", default="mnist", help="h5 model file name")
+
+def main():
+    "Main function"
+    optmgr = OptionParser()
+    opts = optmgr.parser.parse_args()
+    train(opts.fin, opts.fout,
+          model_name=opts.model,
+          epochs=opts.epochs,
+          batch_size=opts.batch_size,
+          h5=opts.h5)
+
+if __name__ == '__main__':
+    main()
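
readData() above expects --fin to point at a gzip-compressed pickle holding ((x_train, y_train), (x_test, y_test)), the tuple hinted at by the commented-out keras.datasets.mnist.load_data() call. A minimal sketch of producing such a file; the name mnist.pkl.gz is an arbitrary choice:

    import gzip
    import pickle

    from tensorflow import keras

    # ((x_train, y_train), (x_test, y_test)) as returned by Keras.
    mnist_data = keras.datasets.mnist.load_data()
    with gzip.open('mnist.pkl.gz', 'wb') as ostream:
        pickle.dump(mnist_data, ostream)

The resulting file can then be passed to train(), e.g. via --fin on the command line.
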
MLaaS/mnist_img.py
ADDED
@@ -0,0 +1,83 @@
+#!/usr/bin/env python
+#-*- coding: utf-8 -*-
+#pylint: disable=
+"""
+File : mnist_img.py
+Author : Valentin Kuznetsov <vkuznet AT gmail dot com>
+Description:
+"""
+
+import json
+import gzip
+import argparse
+# from itertools import chain
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+def readImage(fname, fout, num_images=5, imgId=2):
+    """
+    Helper function to read an MNIST image
+    """
+    image_size = 28
+    with gzip.open(fname, 'r') as fstream:
+        fstream.read(16)
+        buf = fstream.read(image_size * image_size * num_images)
+        data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
+        data = data.reshape(num_images, image_size, image_size, 1)
+        image = np.asarray(data[imgId]).squeeze()
+        plt.imsave(fout, image)
+        print("read:", fname, "wrote:", fout, "image:", type(image), "shape:", image.shape)
+
+def img2json(image):
+    """
+    Convert given image to the JSON data format used by TFaaS
+    """
+    # values = [int(i) for i in list(chain.from_iterable(image))]
+    # values = image.tolist()
+    values = []
+    for row in image.tolist():
+        row = [int(i) for i in row]
+        vals = [[i] for i in row]
+        values.append(vals)
+    # final values should be an array of elements, e.g. a single image representation
+    values = [values]
+    keys = [str(i) for i in range(0, 10)]
+    meta = {
+        'keys': keys,
+        'values': values,
+        'model': 'mnist'
+    }
+    with open('img.json', 'w') as ostream:
+        ostream.write(json.dumps(meta))
+
+
+class OptionParser():
+    def __init__(self):
+        "User based option parser"
+        fname = "train-images-idx3-ubyte.gz"
+        self.parser = argparse.ArgumentParser(prog='PROG')
+        self.parser.add_argument("--fin", action="store",
+            dest="fin", default=fname, help=f"Input MNIST file, default {fname}")
+        self.parser.add_argument("--fout", action="store",
+            dest="fout", default="img.png", help="Output image file name, default img.png")
+        self.parser.add_argument("--nimages", action="store",
+            dest="nimages", default=5, help="number of images to read, default 5")
+        self.parser.add_argument("--imgid", action="store",
+            dest="imgid", default=2, help="image index to use from nimages, default 2 (number 4)")
+
+def main():
+    """
+    main function to produce an image file from the MNIST dataset.
+    The MNIST dataset can be downloaded with
+    curl -O http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
+    """
+    optmgr = OptionParser()
+    opts = optmgr.parser.parse_args()
+    num_images = int(opts.nimages)
+    imgId = int(opts.imgid)
+    img = readImage(opts.fin, opts.fout, num_images, imgId)
+
+if __name__ == '__main__':
+    main()
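
A minimal sketch tying the two helpers together, assuming the MLaaS package is importable: decode one image from the raw IDX file (the same 16-byte header skip and reshape as in readImage()) and feed it to img2json(), which writes img.json for a TFaaS prediction request. The input file name is the OptionParser default above:

    import gzip

    import numpy as np

    from MLaaS.mnist_img import img2json

    image_size = 28
    with gzip.open('train-images-idx3-ubyte.gz', 'r') as fstream:
        fstream.read(16)  # skip the IDX header
        buf = fstream.read(image_size * image_size)  # first image only
        image = np.frombuffer(buf, dtype=np.uint8).reshape(image_size, image_size)

    img2json(image)  # writes img.json in the current directory
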