ChessAnalysisPipeline 0.0.17.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. CHAP/TaskManager.py +216 -0
  2. CHAP/__init__.py +27 -0
  3. CHAP/common/__init__.py +57 -0
  4. CHAP/common/models/__init__.py +8 -0
  5. CHAP/common/models/common.py +124 -0
  6. CHAP/common/models/integration.py +659 -0
  7. CHAP/common/models/map.py +1291 -0
  8. CHAP/common/processor.py +2869 -0
  9. CHAP/common/reader.py +658 -0
  10. CHAP/common/utils.py +110 -0
  11. CHAP/common/writer.py +730 -0
  12. CHAP/edd/__init__.py +23 -0
  13. CHAP/edd/models.py +876 -0
  14. CHAP/edd/processor.py +3069 -0
  15. CHAP/edd/reader.py +1023 -0
  16. CHAP/edd/select_material_params_gui.py +348 -0
  17. CHAP/edd/utils.py +1572 -0
  18. CHAP/edd/writer.py +26 -0
  19. CHAP/foxden/__init__.py +19 -0
  20. CHAP/foxden/models.py +71 -0
  21. CHAP/foxden/processor.py +124 -0
  22. CHAP/foxden/reader.py +224 -0
  23. CHAP/foxden/utils.py +80 -0
  24. CHAP/foxden/writer.py +168 -0
  25. CHAP/giwaxs/__init__.py +11 -0
  26. CHAP/giwaxs/models.py +491 -0
  27. CHAP/giwaxs/processor.py +776 -0
  28. CHAP/giwaxs/reader.py +8 -0
  29. CHAP/giwaxs/writer.py +8 -0
  30. CHAP/inference/__init__.py +7 -0
  31. CHAP/inference/processor.py +69 -0
  32. CHAP/inference/reader.py +8 -0
  33. CHAP/inference/writer.py +8 -0
  34. CHAP/models.py +227 -0
  35. CHAP/pipeline.py +479 -0
  36. CHAP/processor.py +125 -0
  37. CHAP/reader.py +124 -0
  38. CHAP/runner.py +277 -0
  39. CHAP/saxswaxs/__init__.py +7 -0
  40. CHAP/saxswaxs/processor.py +8 -0
  41. CHAP/saxswaxs/reader.py +8 -0
  42. CHAP/saxswaxs/writer.py +8 -0
  43. CHAP/server.py +125 -0
  44. CHAP/sin2psi/__init__.py +7 -0
  45. CHAP/sin2psi/processor.py +8 -0
  46. CHAP/sin2psi/reader.py +8 -0
  47. CHAP/sin2psi/writer.py +8 -0
  48. CHAP/tomo/__init__.py +15 -0
  49. CHAP/tomo/models.py +210 -0
  50. CHAP/tomo/processor.py +3862 -0
  51. CHAP/tomo/reader.py +9 -0
  52. CHAP/tomo/writer.py +59 -0
  53. CHAP/utils/__init__.py +6 -0
  54. CHAP/utils/converters.py +188 -0
  55. CHAP/utils/fit.py +2947 -0
  56. CHAP/utils/general.py +2655 -0
  57. CHAP/utils/material.py +274 -0
  58. CHAP/utils/models.py +595 -0
  59. CHAP/utils/parfile.py +224 -0
  60. CHAP/writer.py +122 -0
  61. MLaaS/__init__.py +0 -0
  62. MLaaS/ktrain.py +205 -0
  63. MLaaS/mnist_img.py +83 -0
  64. MLaaS/tfaas_client.py +371 -0
  65. chessanalysispipeline-0.0.17.dev3.dist-info/LICENSE +60 -0
  66. chessanalysispipeline-0.0.17.dev3.dist-info/METADATA +29 -0
  67. chessanalysispipeline-0.0.17.dev3.dist-info/RECORD +70 -0
  68. chessanalysispipeline-0.0.17.dev3.dist-info/WHEEL +5 -0
  69. chessanalysispipeline-0.0.17.dev3.dist-info/entry_points.txt +2 -0
  70. chessanalysispipeline-0.0.17.dev3.dist-info/top_level.txt +2 -0
CHAP/utils/parfile.py ADDED
@@ -0,0 +1,224 @@
+ """Utilities for interacting with scans using an SMB-style .par file
+ as input.
+ """
+
+ # System modules
+ import csv
+ import os
+
+ # Third party modules
+ import json
+
+ class ParFile():
+     """Representation of a .par file.
+
+     :ivar par_file: Name of the .par file.
+     :type par_file: str
+     :ivar json_file: Name of the .json file containing the keys for
+         the column names of the .par file.
+     :type json_file: str
+     :ivar spec_file: Name of the SPEC data file associated with this
+         .par file.
+     :type spec_file: str
+     :ivar column_names: List of the names of each column in the par
+         file.
+     :type column_names: list[str]
+     :ivar data: A 2D array of the data in this .par file. 0th index:
+         row. 1st index: column.
+     :type data: list[list]
+     """
+     def __init__(self, par_file, scan_numbers=None, scann_col_name='SCAN_N'):
+         # Local modules
+         from CHAP.utils.general import (
+             is_int_series,
+             string_to_list,
+         )
+
+         self.par_file = str(par_file)
+         self.json_file = self.par_file.replace('.par', '.json')
+         self.spec_file = os.path.join(
+             os.path.dirname(self.par_file), 'spec.log')
+
+         with open(self.json_file) as json_file:
+             columns = json.load(json_file)
+         num_column = len(columns)
+         self.column_names = [None] * num_column
+         for i, name in columns.items():
+             self.column_names[int(i)] = name
+
+         self.data = []
+         with open(self.par_file) as f:
+             reader = csv.reader(f, delimiter=' ')
+             for i, row in enumerate(reader):
+                 if len(row) == 0:
+                     continue
+                 if row[0].startswith('#'):
+                     continue
+                 row_data = []
+                 for value in row:
+                     try:
+                         value = int(value)
+                     except ValueError:
+                         try:
+                             value = float(value)
+                         except Exception:
+                             pass
+                     row_data.append(value)
+                 if len(row_data) != num_column:
+                     raise ValueError(
+                         'Mismatch between the number of columns in the json '
+                         f'({num_column}) and line {i+1} of the par file '
+                         f'({len(row_data)})')
+                 self.data.append(row_data)
+
+         self.scann_i = self.column_names.index(scann_col_name)
+         self.scan_numbers = [data[self.scann_i] for data in self.data]
+         if scan_numbers is not None:
+             if isinstance(scan_numbers, int):
+                 scan_numbers = [scan_numbers]
+             elif isinstance(scan_numbers, str):
+                 scan_numbers = string_to_list(scan_numbers)
+             if not is_int_series(scan_numbers, ge=0, log=False):
+                 raise TypeError(
+                     f'Invalid scan_numbers parameter ({scan_numbers})')
+             self.scan_numbers = [
+                 n for n in scan_numbers if n in self.scan_numbers]
+
+     def get_map(
+             self, experiment_type, station, par_dims, other_dims=None):
+         """Return a map configuration based on this par file.
+
+         :param experiment_type: Experiment type name for the map
+             that this .par file represents.
+         :type experiment_type:
+             Literal['SAXSWAXS', 'EDD', 'XRF', 'TOMO']
+         :param station: Station name at which the data were collected.
+         :type station: Literal['id1a3', 'id3a', 'id3b']
+         :param par_dims: List of dictionaries configuring the map's
+             independent dimensions.
+         :type par_dims: list[dict[str, str]]
+         :param other_dims: List of other dimensions to include in
+             the returned MapConfig's independent_dimensions. Use this
+             if each scan in this par file captured more than one
+             frame of data. Defaults to `None`.
+         :type other_dims: list[dict[str, str]], optional
+         :return: The map configuration.
+         :rtype: CHAP.common.models.map.MapConfig
+         """
+         # Third party modules
+         # pylint: disable=import-error
+         from chess_scanparsers import SMBScanParser
+         # pylint: enable=import-error
+
+         # Local modules
+         from CHAP.common.models.map import MapConfig
+
+         scanparser = SMBScanParser(self.spec_file, 1)
+         good_scans = self.good_scan_numbers()
+         if other_dims is None:
+             other_dims = []
+         map_config = {
+             'title': scanparser.scan_name,
+             'station': station,  # scanparser.station
+             'experiment_type': experiment_type,
+             'sample': {'name': scanparser.scan_name},
+             'spec_scans': [
+                 {'spec_file': self.spec_file,
+                  'scan_numbers': good_scans}],
+             'independent_dimensions': [
+                 {'label': dim['label'],
+                  'units': dim['units'],
+                  'name': dim['name'],
+                  'data_type': 'smb_par'}
+                 for dim in par_dims] + other_dims
+         }
+         return MapConfig(**map_config)
+
+     def good_scan_numbers(self, good_col_name='1/0'):
+         """Return the numbers of scans marked with a "1" in the
+         indicated "good" column of the .par file.
+
+         :param good_col_name: The name of the "good" column of the
+             par file, defaults to `'1/0'`.
+         :type good_col_name: str, optional
+         :raises ValueError: If this .par file does not have a column
+             with the same name as `good_col_name`.
+         :return: "Good" scan numbers.
+         :rtype: list[int]
+         """
+         good_col_i = self.column_names.index(good_col_name)
+         return [self.scan_numbers[i] for i in range(len(self.scan_numbers))
+                 if self.data[i][good_col_i] == 1]
+
+     def get_values(self, column, scan_numbers=None):
+         """Return values from a single column of the par file.
+
+         :param column: The string name OR index of the column to
+             return values for.
+         :type column: str or int
+         :param scan_numbers: List of specific scan numbers to return
+             values in the given column for (instead of the default
+             behavior: return the entire column of values).
+         :type scan_numbers: list[int], optional
+         :raises ValueError: If the column name is unavailable.
+         :raises TypeError: If the column name has an illegal type.
+         :return: A list of values from a single column in the par
+             file.
+         :rtype: list[object]
+         """
+         if isinstance(column, str):
+             if column in self.column_names:
+                 column_idx = self.column_names.index(column)
+             # elif column in ('dataset_id', 'scan_type'):
+             #     column_idx = None
+             else:
+                 raise ValueError(
+                     f'Unavailable column name: {column} not in '
+                     f'{self.column_names}')
+         elif isinstance(column, int):
+             column_idx = column
+         else:
+             raise TypeError(
+                 f'column must be a str or int, not {type(column)}')
+
+         if column_idx is None:
+             column_data = [None]*len(self.data)
+         else:
+             column_data = [
+                 self.data[i][column_idx] for i in range(len(self.data))]
+         if scan_numbers is not None:
+             column_data = [column_data[self.scan_numbers.index(scan_n)]
+                            for scan_n in scan_numbers]
+         return column_data
+
+     def map_values(self, map_config, values):
+         """Return a reshaped array of the 1D list `values` so that it
+         matches up with the coordinates of `map_config`.
+
+         :param map_config: The map configuration according to which
+             values will be reshaped.
+         :type map_config: MapConfig
+         :param values: A 1D list of values to reshape.
+         :type values: list or np.ndarray
+         :return: Reshaped array of values.
+         :rtype: np.ndarray
+         """
+         # Third party modules
+         import numpy as np
+
+         good_scans = self.good_scan_numbers()
+         if len(values) != len(good_scans):
+             raise ValueError(
+                 f'Number of values provided ({len(values)}) does not '
+                 'match the number of good scans in '
+                 f'{self.par_file} ({len(good_scans)})')
+         n_map_points = np.prod(map_config.shape)
+         if len(values) != n_map_points:
+             raise ValueError(
+                 f'Cannot reshape {len(values)} values into an array of '
+                 f'shape {map_config.shape}')
+
+         map_values = np.empty(map_config.shape)
+         for map_index in np.ndindex(map_config.shape):
+             _, scan_number, _ = \
+                 map_config.get_scan_step_index(map_index)
+             value_index = good_scans.index(scan_number)
+             map_values[map_index] = values[value_index]
+         return map_values
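
A minimal usage sketch for ParFile. The directory, the .par file name, and the 'temperature' column below are hypothetical illustrations, not names taken from the package; only the '1/0' column is the documented default:

    from CHAP.utils.parfile import ParFile

    # Parse a .par file; column names come from the matching .json file
    # next to it, and spec.log in the same directory is assumed to exist
    parfile = ParFile('/data/sample/sample.par')
    print(parfile.column_names)

    # Scan numbers flagged with a "1" in the "1/0" column
    good_scans = parfile.good_scan_numbers()

    # One column of values, restricted to the good scans
    temperatures = parfile.get_values('temperature', scan_numbers=good_scans)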
CHAP/writer.py ADDED
@@ -0,0 +1,122 @@
+ #!/usr/bin/env python
+ """
+ File       : writer.py
+ Author     : Valentin Kuznetsov <vkuznet AT gmail dot com>
+ Description: generic Writer module
+
+ Define a generic `Writer` object.
+ """
+
+ # System modules
+ import argparse
+ import logging
+ import os
+ from sys import modules
+ from typing import Optional
+
+ # Third party modules
+ from pydantic import model_validator
+
+ # Local modules
+ from CHAP.pipeline import PipelineItem
+
+
+ def validate_writer_model(writer):
+     writer.filename = os.path.normpath(os.path.realpath(
+         os.path.join(writer.outputdir, writer.filename)))
+     if (not writer.force_overwrite
+             and os.path.isfile(writer.filename)):
+         raise ValueError(
+             'Writing to an existing file without overwrite '
+             f'permission. Remove {writer.filename} or set '
+             '"force_overwrite" in the pipeline configuration for '
+             f'{writer.name}')
+     return writer
+
+
+ class Writer(PipelineItem):
+     """Generic file writer.
+
+     The job of any `Writer` in a `Pipeline` is to receive input
+     returned by a previous `PipelineItem`, write its data to a
+     particular file format, then return the same data unaltered so it
+     can be used by a successive `PipelineItem`.
+
+     :ivar filename: Name of file to write to.
+     :type filename: str
+     :ivar force_overwrite: Flag to allow data in `filename` to be
+         overwritten if it already exists, defaults to `False`.
+     :type force_overwrite: bool, optional
+     :ivar remove: Flag to remove the written item from the returned
+         `data`, defaults to `False`.
+     :type remove: bool, optional
+     """
+     filename: str
+     force_overwrite: Optional[bool] = False
+     remove: Optional[bool] = False
+
+     _validate_filename = model_validator(mode="after")(
+         validate_writer_model)
+
+     def write(self, data):
+         """Write the last `CHAP.pipeline.PipelineData` item in `data`
+         as text to a file.
+
+         :param data: Input data.
+         :type data: list[CHAP.pipeline.PipelineData]
+         :return: Contents of the input data.
+         :rtype: list[PipelineData]
+         """
+         ddata = self.unwrap_pipelinedata(data)[-1]
+         if os.path.isfile(self.filename) and not self.force_overwrite:
+             raise FileExistsError(f'{self.filename} already exists')
+         with open(self.filename, 'w') as f:
+             f.write(ddata)
+         if self.remove:
+             data.pop()
+         # Right now this does nothing, but it could back a
+         # modification flag later
+         self.status = 'written'
+         return data
+
+
+ class OptionParser():
+     """User based option parser."""
+     def __init__(self):
+         self.parser = argparse.ArgumentParser(prog='PROG')
+         self.parser.add_argument(
+             '--data', action='store',
+             dest='data', default='', help='Input data')
+         self.parser.add_argument(
+             '--filename', action='store',
+             dest='filename', default='', help='Output file')
+         self.parser.add_argument(
+             '--writer', action='store',
+             dest='writer', default='Writer', help='Writer class name')
+         self.parser.add_argument(
+             '--log-level', choices=logging._nameToLevel.keys(),
+             dest='log_level', default='INFO', help='logging level')
+
+
+ def main(opt_parser=OptionParser):
+     """Main function."""
+     optmgr = opt_parser()
+     opts = optmgr.parser.parse_args()
+     cls_name = opts.writer
+     try:
+         writer_cls = getattr(modules[__name__], cls_name)
+     except AttributeError:
+         print(f'Unsupported writer {cls_name}')
+         raise
+
+     writer = writer_cls(filename=opts.filename)
+     writer.logger.setLevel(getattr(logging, opts.log_level))
+     log_handler = logging.StreamHandler()
+     log_handler.setFormatter(logging.Formatter(
+         '{name:20}: {message}', style='{'))
+     writer.logger.addHandler(log_handler)
+     data = writer.write(opts.data)
+     print(f'Writer {writer} writes to {opts.filename}, data {data}')
+
+
+ if __name__ == '__main__':
+     main()
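
A minimal sketch of using `Writer` directly, under two assumptions not shown in this diff: that `PipelineItem` supplies the `outputdir` and `name` attributes that `validate_writer_model` dereferences, and that `CHAP.pipeline.PipelineData` items can be constructed with `name` and `data` keys that `unwrap_pipelinedata` extracts. In practice a `Writer` runs inside a `Pipeline` rather than standalone:

    from CHAP.pipeline import PipelineData
    from CHAP.writer import Writer

    # Hypothetical standalone use; the filename is resolved against
    # PipelineItem's outputdir by the model validator
    writer = Writer(filename='out.txt', force_overwrite=True)
    data = writer.write([PipelineData(name='demo', data='hello CHAP\n')])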
MLaaS/__init__.py ADDED
File without changes
MLaaS/ktrain.py ADDED
@@ -0,0 +1,205 @@
+ #!/usr/bin/env python
+ #-*- coding: utf-8 -*-
+ #pylint: disable=
+ """
+ File       : ktrain.py
+ Author     : Valentin Kuznetsov <vkuznet AT gmail dot com>
+ Description: Keras based ML network to train over the MNIST dataset
+ """
+
+ # system modules
+ import os
+ import sys
+ import json
+ import gzip
+ import pickle
+ import argparse
+
+ # third-party modules
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow import keras
+ from tensorflow.keras import layers
+ from tensorflow.keras import backend as K
+ from tensorflow.python.tools import saved_model_utils
+
+
+ def modelGraph(model_dir):
+     """
+     Provide the input/output names used by the TF graph along with the
+     graph itself. The code is based on the TF saved_model_cli.py script.
+     """
+     input_names = []
+     output_names = []
+     tag_sets = saved_model_utils.get_saved_model_tag_sets(model_dir)
+     for tag_set in sorted(tag_sets):
+         print('%r' % ', '.join(sorted(tag_set)))
+         meta_graph_def = saved_model_utils.get_meta_graph_def(model_dir, tag_set[0])
+         for key in meta_graph_def.signature_def.keys():
+             meta = meta_graph_def.signature_def[key]
+             if hasattr(meta, 'inputs') and hasattr(meta, 'outputs'):
+                 input_signatures = list(meta.inputs.values())
+                 input_names = [signature.name for signature in input_signatures]
+                 if len(input_names) > 0:
+                     output_signatures = list(meta.outputs.values())
+                     output_names = [signature.name for signature in output_signatures]
+     return input_names, output_names, meta_graph_def
+
+
+ def readData(fin, num_classes):
+     """
+     Helper function to read MNIST data and provide it to
+     upstream code, e.g. to the training layer.
+     """
+     # Load the data and split it between train and test sets
+     # (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
+     f = gzip.open(fin, 'rb')
+     if sys.version_info < (3,):
+         mnist_data = pickle.load(f)
+     else:
+         mnist_data = pickle.load(f, encoding='bytes')
+     f.close()
+     (x_train, y_train), (x_test, y_test) = mnist_data
+
+     # Scale images to the [0, 1] range
+     x_train = x_train.astype("float32") / 255
+     x_test = x_test.astype("float32") / 255
+     # Make sure images have shape (28, 28, 1)
+     x_train = np.expand_dims(x_train, -1)
+     x_test = np.expand_dims(x_test, -1)
+     print("x_train shape:", x_train.shape)
+     print(x_train.shape[0], "train samples")
+     print(x_test.shape[0], "test samples")
+
+     # Convert class vectors to binary class matrices
+     y_train = keras.utils.to_categorical(y_train, num_classes)
+     y_test = keras.utils.to_categorical(y_test, num_classes)
+     return x_train, y_train, x_test, y_test
+
+
+ def train(fin, fout=None, model_name=None, epochs=1, batch_size=128, h5=False):
+     """
+     Train function for MNIST.
+     """
+     # Model / data parameters
+     num_classes = 10
+     input_shape = (28, 28, 1)
+
+     # Create the ML model
+     model = keras.Sequential(
+         [
+             keras.Input(shape=input_shape),
+             layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
+             layers.MaxPooling2D(pool_size=(2, 2)),
+             layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
+             layers.MaxPooling2D(pool_size=(2, 2)),
+             layers.Flatten(),
+             layers.Dropout(0.5),
+             layers.Dense(num_classes, activation="softmax"),
+         ]
+     )
+
+     model.summary()
+     print("model input", model.input, type(model.input), model.input.__dict__)
+     print("model output", model.output, type(model.output), model.output.__dict__)
+     model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
+
+     # Train the model
+     x_train, y_train, x_test, y_test = readData(fin, num_classes)
+     model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
+
+     # Evaluate the trained model
+     score = model.evaluate(x_test, y_test, verbose=0)
+     print("Test loss:", score[0])
+     print("Test accuracy:", score[1])
+     print("save model to", fout)
+     writer(fout, model_name, model, input_shape, h5)
+
+
+ def writer(fout, model_name, model, input_shape, h5=False):
+     """
+     Provide a write function for the given model.
+     """
+     if not fout:
+         return
+     model.save(fout)
+     if h5:
+         model.save('{}/{}'.format(fout, h5), save_format='h5')
+     pbModel = '{}/saved_model.pb'.format(fout)
+     pbtxtModel = '{}/saved_model.pbtxt'.format(fout)
+     convert(pbModel, pbtxtModel)
+
+     # Get meta-data information about our ML model from the saved model area
+     input_names, output_names, model_graph = modelGraph(fout)
+     print("### input", input_names)
+     print("### output", output_names)
+     # The model uses input shape (28, 28, 1), i.e. 28x28 grayscale images;
+     # color images would use shape (28, 28, 3) instead
+     img_channels = input_shape[2]  # the last item is the number of channels
+     meta = {'name': model_name,
+             'model': 'saved_model.pb',
+             'labels': 'labels.txt',
+             'img_channels': img_channels,
+             'input_name': input_names[0].split(':')[0],
+             'output_name': output_names[0].split(':')[0],
+             'input_node': model.input.name,
+             'output_node': model.output.name
+             }
+     with open(fout+'/params.json', 'w') as ostream:
+         ostream.write(json.dumps(meta))
+     with open(fout+'/labels.txt', 'w') as ostream:
+         for i in range(0, 10):
+             ostream.write(str(i)+'\n')
+     with open(fout + '/model.graph', 'wb') as ostream:
+         ostream.write(model_graph.SerializeToString())
+
+
+ def convert(fin, fout):
+     """
+     Convert the input model.pb into an output model.pbtxt.
+     Based on:
+     - https://www.tensorflow.org/guide/saved_model
+     - https://www.programcreek.com/python/example/123317/tensorflow.core.protobuf.saved_model_pb2.SavedModel
+     """
+     from google.protobuf import text_format
+     from tensorflow.core.protobuf import saved_model_pb2
+
+     saved_model = saved_model_pb2.SavedModel()
+
+     with open(fin, 'rb') as f:
+         saved_model.ParseFromString(f.read())
+
+     with open(fout, 'w') as f:
+         f.write(text_format.MessageToString(saved_model))
+
+
+ class OptionParser():
+     def __init__(self):
+         "User based option parser"
+         self.parser = argparse.ArgumentParser(prog='PROG')
+         self.parser.add_argument("--fin", action="store",
+             dest="fin", default="", help="Input MNIST file")
+         self.parser.add_argument("--fout", action="store",
+             dest="fout", default="", help="Output models area")
+         self.parser.add_argument("--model", action="store",
+             dest="model", default="mnist", help="model name")
+         self.parser.add_argument("--epochs", action="store", type=int,
+             dest="epochs", default=1, help="number of epochs to use in ML training")
+         self.parser.add_argument("--batch_size", action="store", type=int,
+             dest="batch_size", default=128, help="batch size to use in training")
+         self.parser.add_argument("--h5", action="store",
+             dest="h5", default="mnist", help="h5 model file name")
+
+
+ def main():
+     "Main function"
+     optmgr = OptionParser()
+     opts = optmgr.parser.parse_args()
+     train(opts.fin, opts.fout,
+           model_name=opts.model,
+           epochs=opts.epochs,
+           batch_size=opts.batch_size,
+           h5=opts.h5)
+
+
+ if __name__ == '__main__':
+     main()
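
A sketch of driving the training entry point from Python, assuming mnist.pkl.gz is a gzipped pickle in the ((x_train, y_train), (x_test, y_test)) layout that readData expects; the input and output paths here are illustrative assumptions:

    from MLaaS.ktrain import train

    # Trains the small CNN, evaluates it on the test split, and saves
    # the model plus params.json/labels.txt/model.graph under models/mnist
    train('mnist.pkl.gz', fout='models/mnist', model_name='mnist',
          epochs=2, batch_size=128)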
MLaaS/mnist_img.py ADDED
@@ -0,0 +1,83 @@
+ #!/usr/bin/env python
+ #-*- coding: utf-8 -*-
+ #pylint: disable=
+ """
+ File       : mnist_img.py
+ Author     : Valentin Kuznetsov <vkuznet AT gmail dot com>
+ Description: produce an image file from the MNIST dataset and convert
+ it to the JSON format used by TFaaS
+ """
+
+ import json
+ import gzip
+ import argparse
+ # from itertools import chain
+
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+
+ def readImage(fname, fout, num_images=5, imgId=2):
+     """
+     Helper function to read an MNIST image.
+     """
+     image_size = 28
+     with gzip.open(fname, 'r') as fstream:
+         fstream.read(16)
+         buf = fstream.read(image_size * image_size * num_images)
+         data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
+         data = data.reshape(num_images, image_size, image_size, 1)
+         image = np.asarray(data[imgId]).squeeze()
+         plt.imsave(fout, image)
+         print("read:", fname, "wrote:", fout, "image:", type(image), "shape:", image.shape)
+
+
+ def img2json(image):
+     """
+     Convert the given image to the JSON data format used by TFaaS.
+     """
+     # values = [int(i) for i in list(chain.from_iterable(image))]
+     # values = image.tolist()
+     values = []
+     for row in image.tolist():
+         row = [int(i) for i in row]
+         vals = [[i] for i in row]
+         values.append(vals)
+     # The final values should be an array of elements, i.e. a single
+     # image representation
+     values = [values]
+     keys = [str(i) for i in range(0, 10)]
+     meta = {
+         'keys': keys,
+         'values': values,
+         'model': 'mnist'
+     }
+     with open('img.json', 'w') as ostream:
+         ostream.write(json.dumps(meta))
+
+
+ class OptionParser():
+     def __init__(self):
+         "User based option parser"
+         fname = "train-images-idx3-ubyte.gz"
+         self.parser = argparse.ArgumentParser(prog='PROG')
+         self.parser.add_argument("--fin", action="store",
+             dest="fin", default=fname, help=f"Input MNIST file, default {fname}")
+         self.parser.add_argument("--fout", action="store",
+             dest="fout", default="img.png", help="Output image file name, default img.png")
+         self.parser.add_argument("--nimages", action="store",
+             dest="nimages", default=5, help="number of images to read, default 5")
+         self.parser.add_argument("--imgid", action="store",
+             dest="imgid", default=2, help="image index to use from nimages, default 2 (the digit 4)")
+
+
+ def main():
+     """
+     Main function to produce an image file from the MNIST dataset.
+     The MNIST dataset can be downloaded with
+     curl -O http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
+     """
+     optmgr = OptionParser()
+     opts = optmgr.parser.parse_args()
+     num_images = int(opts.nimages)
+     imgId = int(opts.imgid)
+     readImage(opts.fin, opts.fout, num_images, imgId)
+
+
+ if __name__ == '__main__':
+     main()
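
A usage sketch mirroring the OptionParser defaults above; train-images-idx3-ubyte.gz is the standard MNIST image archive named in the script, assumed to sit in the working directory:

    from MLaaS.mnist_img import readImage

    # Reads 5 images from the archive and saves the one at index 2
    # (a "4") as img.png
    readImage('train-images-idx3-ubyte.gz', 'img.png', num_images=5, imgId=2)

Note that img2json, which main() never calls, would take the same 28x28 array and write the img.json payload that TFaaS expects.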