naeural_client-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. naeural_client/__init__.py +13 -0
  2. naeural_client/_ver.py +13 -0
  3. naeural_client/base/__init__.py +6 -0
  4. naeural_client/base/distributed_custom_code_presets.py +44 -0
  5. naeural_client/base/generic_session.py +1763 -0
  6. naeural_client/base/instance.py +616 -0
  7. naeural_client/base/payload/__init__.py +1 -0
  8. naeural_client/base/payload/payload.py +66 -0
  9. naeural_client/base/pipeline.py +1499 -0
  10. naeural_client/base/plugin_template.py +5209 -0
  11. naeural_client/base/responses.py +209 -0
  12. naeural_client/base/transaction.py +157 -0
  13. naeural_client/base_decentra_object.py +143 -0
  14. naeural_client/bc/__init__.py +3 -0
  15. naeural_client/bc/base.py +1046 -0
  16. naeural_client/bc/chain.py +0 -0
  17. naeural_client/bc/ec.py +324 -0
  18. naeural_client/certs/__init__.py +0 -0
  19. naeural_client/certs/r9092118.ala.eu-central-1.emqxsl.com.crt +22 -0
  20. naeural_client/code_cheker/__init__.py +1 -0
  21. naeural_client/code_cheker/base.py +520 -0
  22. naeural_client/code_cheker/checker.py +294 -0
  23. naeural_client/comm/__init__.py +2 -0
  24. naeural_client/comm/amqp_wrapper.py +338 -0
  25. naeural_client/comm/mqtt_wrapper.py +539 -0
  26. naeural_client/const/README.md +3 -0
  27. naeural_client/const/__init__.py +9 -0
  28. naeural_client/const/base.py +101 -0
  29. naeural_client/const/comms.py +80 -0
  30. naeural_client/const/environment.py +26 -0
  31. naeural_client/const/formatter.py +7 -0
  32. naeural_client/const/heartbeat.py +111 -0
  33. naeural_client/const/misc.py +20 -0
  34. naeural_client/const/payload.py +190 -0
  35. naeural_client/default/__init__.py +1 -0
  36. naeural_client/default/instance/__init__.py +4 -0
  37. naeural_client/default/instance/chain_dist_custom_job_01_plugin.py +54 -0
  38. naeural_client/default/instance/custom_web_app_01_plugin.py +118 -0
  39. naeural_client/default/instance/net_mon_01_plugin.py +45 -0
  40. naeural_client/default/instance/view_scene_01_plugin.py +28 -0
  41. naeural_client/default/session/mqtt_session.py +72 -0
  42. naeural_client/io_formatter/__init__.py +2 -0
  43. naeural_client/io_formatter/base/__init__.py +1 -0
  44. naeural_client/io_formatter/base/base_formatter.py +80 -0
  45. naeural_client/io_formatter/default/__init__.py +3 -0
  46. naeural_client/io_formatter/default/a_dummy.py +51 -0
  47. naeural_client/io_formatter/default/aixp1.py +113 -0
  48. naeural_client/io_formatter/default/default.py +22 -0
  49. naeural_client/io_formatter/io_formatter_manager.py +96 -0
  50. naeural_client/logging/__init__.py +1 -0
  51. naeural_client/logging/base_logger.py +2056 -0
  52. naeural_client/logging/logger_mixins/__init__.py +12 -0
  53. naeural_client/logging/logger_mixins/class_instance_mixin.py +92 -0
  54. naeural_client/logging/logger_mixins/computer_vision_mixin.py +443 -0
  55. naeural_client/logging/logger_mixins/datetime_mixin.py +344 -0
  56. naeural_client/logging/logger_mixins/download_mixin.py +421 -0
  57. naeural_client/logging/logger_mixins/general_serialization_mixin.py +242 -0
  58. naeural_client/logging/logger_mixins/json_serialization_mixin.py +481 -0
  59. naeural_client/logging/logger_mixins/pickle_serialization_mixin.py +301 -0
  60. naeural_client/logging/logger_mixins/process_mixin.py +63 -0
  61. naeural_client/logging/logger_mixins/resource_size_mixin.py +81 -0
  62. naeural_client/logging/logger_mixins/timers_mixin.py +501 -0
  63. naeural_client/logging/logger_mixins/upload_mixin.py +260 -0
  64. naeural_client/logging/logger_mixins/utils_mixin.py +675 -0
  65. naeural_client/logging/small_logger.py +93 -0
  66. naeural_client/logging/tzlocal/__init__.py +20 -0
  67. naeural_client/logging/tzlocal/unix.py +231 -0
  68. naeural_client/logging/tzlocal/utils.py +113 -0
  69. naeural_client/logging/tzlocal/win32.py +151 -0
  70. naeural_client/logging/tzlocal/windows_tz.py +718 -0
  71. naeural_client/plugins_manager_mixin.py +273 -0
  72. naeural_client/utils/__init__.py +2 -0
  73. naeural_client/utils/comm_utils.py +44 -0
  74. naeural_client/utils/dotenv.py +75 -0
  75. naeural_client-2.0.0.dist-info/METADATA +365 -0
  76. naeural_client-2.0.0.dist-info/RECORD +78 -0
  77. naeural_client-2.0.0.dist-info/WHEEL +4 -0
  78. naeural_client-2.0.0.dist-info/licenses/LICENSE +201 -0
naeural_client/logging/logger_mixins/__init__.py
@@ -0,0 +1,12 @@
+ from .class_instance_mixin import _ClassInstanceMixin
+ from .computer_vision_mixin import _ComputerVisionMixin
+ from .datetime_mixin import _DateTimeMixin
+ from .download_mixin import _DownloadMixin
+ from .general_serialization_mixin import _GeneralSerializationMixin
+ from .json_serialization_mixin import _JSONSerializationMixin
+ from .pickle_serialization_mixin import _PickleSerializationMixin
+ from .process_mixin import _ProcessMixin
+ from .resource_size_mixin import _ResourceSizeMixin
+ from .timers_mixin import _TimersMixin
+ from .upload_mixin import _UploadMixin
+ from .utils_mixin import _UtilsMixin
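The logger_mixins/__init__.py hunk above only re-exports the individual mixin classes. As a minimal usage sketch (not code from the wheel), assuming the package and its dependencies such as numpy are installed, mixins like these are typically composed into a logger class via multiple inheritance so their methods become instance methods of the logger; MyLogger and its p() method below are hypothetical stand-ins for the real Logger.

# Hypothetical composition example; not code from the package.
from naeural_client.logging.logger_mixins import (
  _ClassInstanceMixin,
  _ComputerVisionMixin,
)

class MyLogger(_ClassInstanceMixin, _ComputerVisionMixin):
  def __init__(self):
    super(MyLogger, self).__init__()

  def p(self, msg, **kwargs):
    # stand-in for the real Logger print method that some mixin methods call
    print(msg)

log = MyLogger()
log.get_class_instance_methods(log)  # provided by _ClassInstanceMixin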
naeural_client/logging/logger_mixins/class_instance_mixin.py
@@ -0,0 +1,92 @@
+ import inspect
+
+ class _ClassInstanceMixin(object):
+   """
+   Mixin for class instance functionalities that are attached to `pye2.Logger`.
+
+   This mixin cannot be instantiated because it is built just to provide some additional
+   functionalities for `pye2.Logger`
+
+   In this mixin we can use any attribute/method of the Logger.
+   """
+
+   def __init__(self):
+     super(_ClassInstanceMixin, self).__init__()
+
+   @staticmethod
+   def get_class_instance_methods(obj, spacing=20):
+     methodList = []
+     for method_name in dir(obj):
+       try:
+         if callable(getattr(obj, method_name)):
+           methodList.append(str(method_name))
+       except:
+         methodList.append(str(method_name))
+     processFunc = (lambda s: ' '.join(s.split())) or (lambda s: s)
+     for method in methodList:
+       try:
+         print(str(method.ljust(spacing)) + ' ' + processFunc(str(getattr(obj, method).__doc__)[0:90]))
+       except:
+         print(method.ljust(spacing) + ' ' + ' getattr() failed')
+     return
+
+
+   @staticmethod
+   def get_class_instance_params(obj, n=None):
+     """
+     Parameters
+     ----------
+     obj : any type
+       the inspected object.
+     n : int, optional
+       the number of params that are returned. The default is None
+       (all params returned).
+
+     Returns
+     -------
+     out_str : str
+       the description of the object 'obj' in terms of parameters values.
+     """
+
+     out_str = obj.__class__.__name__ + "("
+     n_added_to_log = 0
+     for _iter, (prop, value) in enumerate(vars(obj).items()):
+       if type(value) in [int, float, bool]:
+         out_str += prop + '=' + str(value) + ','
+         n_added_to_log += 1
+       elif type(value) in [str]:
+         out_str += prop + "='" + value + "',"
+         n_added_to_log += 1
+
+       if n is not None and n_added_to_log >= n:
+         break
+     # endfor
+
+     out_str = out_str[:-1] if out_str[-1] == ',' else out_str
+     out_str += ')'
+     return out_str
+
+   @staticmethod
+   def get_object_params(*args, **kwargs):
+     print("DeprecationWarning! `get_object_params` is deprecated. Please use `get_class_instance_params` instead")
+     return _ClassInstanceMixin.get_class_instance_params(*args, **kwargs)
+
+   @staticmethod
+   def get_object_methods(*args, **kwargs):
+     print("DeprecationWarning! `get_object_methods` is deprecated. Please use `get_class_instance_methods` instead")
+     return _ClassInstanceMixin.get_class_instance_methods(*args, **kwargs)
+
+   @staticmethod
+   def get_class_methods(cls, include_parent=True):
+     lst_methods = inspect.getmembers(cls, predicate=inspect.isfunction)
+
+     if not include_parent:
+       lst_methods = list(filter(
+         lambda x: cls.__name__ in x[1].__qualname__,
+         lst_methods
+       ))
+
+     return lst_methods
+
+
+
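For reference, a hedged example of what the two main helpers in _ClassInstanceMixin return; the Config class below is a made-up illustration, and the exact output depends on the inspected object.

from naeural_client.logging.logger_mixins import _ClassInstanceMixin

# Hypothetical object used only for illustration.
class Config:
  def __init__(self):
    self.lr = 0.001
    self.epochs = 10
    self.name = 'baseline'

cfg = Config()

# Builds a one-line summary from int/float/bool/str attributes, e.g.
# "Config(lr=0.001,epochs=10,name='baseline')"
print(_ClassInstanceMixin.get_class_instance_params(cfg))

# With include_parent=False, keeps only functions whose __qualname__
# contains the class name, i.e. methods defined on Config itself.
print(_ClassInstanceMixin.get_class_methods(Config, include_parent=False))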
naeural_client/logging/logger_mixins/computer_vision_mixin.py
@@ -0,0 +1,443 @@
+ import os
+ import io
+ import numpy as np
+ import base64
+
+ from time import time as tm
+ from shutil import copyfile
+ from io import BytesIO
+
+ try:
+   import cv2
+ except:
+   pass
+
+ class _ComputerVisionMixin(object):
+   """
+   Mixin for computer vision functionalities that are attached to `libraries.logger.Logger`
+
+   This mixin cannot be instantiated because it is built just to provide some additional
+   functionalities for `libraries.logger.Logger`
+
+   In this mixin we can use any attribute/method of the Logger.
+   """
+
+   def __init__(self):
+     super(_ComputerVisionMixin, self).__init__()
+     return
+
+   @staticmethod
+   def is_image(file):
+     return any([file.endswith(tail) for tail in ('.png', '.jpg', '.jpeg')])
+
+   @staticmethod
+   def increase_brightness(img, value=30):
+     hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+     h, s, v = cv2.split(hsv)
+
+     lim = 255 - value
+     v[v > lim] = 255
+     v[v <= lim] += value
+
+     final_hsv = cv2.merge((h, s, v))
+     img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
+     return img
+
+   @staticmethod
+   def convert_gray(np_src, copy=False):
+     np_source = np_src.copy() if copy else np_src
+     if len(np_source.shape) == 3 and np_source.shape[-1] != 1:
+       np_source = cv2.cvtColor(np_source, cv2.COLOR_BGR2GRAY)
+     return np_source
+
+   @staticmethod
+   def scale_image(np_src, copy=False):
+     np_source = np_src.copy() if copy else np_src
+     if np_source.max() > 1:
+       np_source = (np_source / 255).astype(np.float32)
+     # endif
+     return np_source
+
+   @staticmethod
+   def center_image(image, target_h, target_w, copy=False):
+     image = image.copy() if copy else image
+     new_h, new_w, _ = image.shape
+
+     # determine the new size of the image
+     if (float(target_w) / new_w) < (float(target_h) / new_h):
+       new_h = int((new_h * target_w) / new_w)
+       new_w = int(target_w)
+     else:
+       new_w = int((new_w * target_h) / new_h)
+       new_h = int(target_h)
+
+     # embed the image into the standard letter box
+     new_image = np.ones((target_h, target_w, 3), dtype=np.uint8)
+     top = int((target_h - new_h) // 2)
+     bottom = int((target_h + new_h) // 2)
+     left = int((target_w - new_w) // 2)
+     right = int((target_w + new_w) // 2)
+
+     resized = cv2.resize(image, dsize=(new_w, new_h))
+     new_image[top:bottom, left:right, :] = resized
+     return new_image
+
+   @staticmethod
+   def center_image_coordinates(src_h, src_w, target_h, target_w):
+     asp_src = src_h / src_w
+     asp_dst = target_h / target_w
+     if asp_src > asp_dst:
+       # if src_h > src_w:
+       new_h = target_h
+       new_w = int(target_h / (src_h / src_w))
+     else:
+       new_w = target_w
+       new_h = int(target_w / (src_w / src_h))
+     # endif
+
+     left = target_w // 2 - new_w // 2
+     top = target_h // 2 - new_h // 2
+     right = left + new_w
+     bottom = top + new_h
+     return (top, left, bottom, right), (new_h, new_w)
+
+   @staticmethod
+   def center_image2(np_src, target_h, target_w, copy=False):
+     np_source = np_src.copy() if copy else np_src
+     shape = (target_h, target_w, np_src.shape[-1]) if len(np_src.shape) == 3 else (target_h, target_w)
+     (top, left, bottom, right), (new_h, new_w) = _ComputerVisionMixin.center_image_coordinates(
+       np_source.shape[0], np_source.shape[1],
+       target_h, target_w
+     )
+     np_dest = np.zeros(shape).astype(np.float32 if np_src.max() <= 1 else np.uint8)
+     np_src_mod = cv2.resize(np_src, dsize=(new_w, new_h))
+     np_dest[top:bottom, left:right] = np_src_mod
+     return np_dest
+
+   @staticmethod
+   def rescale_boxes(boxes,
+                     src_h,
+                     src_w,
+                     target_h,
+                     target_w,
+                     ):
+     if np.max(boxes) <= 1:
+       # [0:1] to [0:model_shape]
+       boxes[:, 0] *= target_h
+       boxes[:, 1] *= target_w
+       boxes[:, 2] *= target_h
+       boxes[:, 3] *= target_w
+
+     (top, left, bottom, right), (new_h, new_w) = _ComputerVisionMixin.center_image_coordinates(
+       src_h=src_h,
+       src_w=src_w,
+       target_h=target_h,
+       target_w=target_w
+     )
+
+     # eliminate centering
+     boxes[:, 0] = boxes[:, 0] - top
+     boxes[:, 1] = boxes[:, 1] - left
+     boxes[:, 2] = boxes[:, 2] - top
+     boxes[:, 3] = boxes[:, 3] - left
+
+     # translate to original image scale
+     boxes[:, 0] = boxes[:, 0] / new_h * src_h
+     boxes[:, 1] = boxes[:, 1] / new_w * src_w
+     boxes[:, 2] = boxes[:, 2] / new_h * src_h
+     boxes[:, 3] = boxes[:, 3] / new_w * src_w
+
+     # clipping between [0: max]
+     boxes = boxes.astype(np.int32)
+     boxes[:, 0] = np.maximum(0, boxes[:, 0])
+     boxes[:, 1] = np.maximum(0, boxes[:, 1])
+     boxes[:, 2] = np.minimum(src_h, boxes[:, 2])
+     boxes[:, 3] = np.minimum(src_w, boxes[:, 3])
+     return boxes
+
+   @staticmethod
+   def dir_visual_image_dedup(path_dir, magnify_image=None):
+     """
+     This method reads a folder, displays every image in that folder and lets you
+     decide where to move a specific image.
+     Use:
+     - s for skip
+     - b for bad
+     - g for good
+     - q for quit
+     """
+
+     def close():
+       cv2.waitKey(1)
+       cv2.destroyAllWindows()
+       for _ in range(5):
+         cv2.waitKey(1)
+       # endfor
+
+     assert os.path.exists(path_dir)
+     files = [x for x in os.listdir(path_dir) if x.endswith('.png') or x.endswith('.jpg')]
+     path_good = os.path.join(path_dir, 'good')
+     path_bad = os.path.join(path_dir, 'bad')
+     for file in files:
+       img_path = os.path.join(path_dir, file)
+       img = cv2.imread(img_path)
+       if magnify_image and magnify_image > 0:
+         img = cv2.resize(img, (img.shape[1] * magnify_image, img.shape[0] * magnify_image))
+       # endif
+       cv2.imshow('Img', img)
+       key = cv2.waitKey(0)
+       if key == ord('b'):
+         os.makedirs(path_bad, exist_ok=True)
+         copyfile(img_path, os.path.join(path_bad, file))
+       elif key == ord('g'):
+         os.makedirs(path_good, exist_ok=True)
+         copyfile(img_path, os.path.join(path_good, file))
+       elif key == ord('q'):
+         close()
+         break
+       elif key == ord('s'):
+         pass
+       # endif
+     # endfor
+     close()
+
+   @staticmethod
+   def to_rgb(image):
+     return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+   @staticmethod
+   def change_brightness(img, delta):
+     hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+     h, s, v = cv2.split(hsv)
+     lim = 255 - delta
+     v[v > lim] = 255
+     v[v <= lim] += delta
+     final_hsv = cv2.merge((h, s, v))
+     img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
+     return img
+
+   @staticmethod
+   def rotate(image, angle, center=None, scale=1.0):
+     (h, w) = image.shape[:2]
+     if center is None:
+       center = (w / 2, h / 2)
+     # Perform the rotation
+     M = cv2.getRotationMatrix2D(center, angle, scale)
+     rotated = cv2.warpAffine(image, M, (w, h))
+     return rotated
+
+   @staticmethod
+   def np_image_to_base64(np_image, ENCODING='utf-8', quality=95, max_height=None):
+     from PIL import Image
+     image = Image.fromarray(np_image)
+     w, h = image.size
+     if max_height is not None and h > max_height:
+       ratio = max_height / h
+       new_w, new_h = int(ratio * w), max_height
+       image = image.resize((new_w, new_h), reducing_gap=1)
+     buffered = BytesIO()
+     if quality < 95:
+       image.save(buffered, format='JPEG', quality=quality, optimize=True)
+     else:
+       image.save(buffered, format='JPEG')
+     img_base64 = base64.b64encode(buffered.getvalue())
+     img_str = img_base64.decode(ENCODING)
+     return img_str
+
+   @staticmethod
+   def base64_to_np_image(base64_img):
+     from PIL import Image
+     base64_decoded = base64.b64decode(base64_img)
+     image = Image.open(io.BytesIO(base64_decoded))
+     np_image = np.array(image)
+     return np_image
+
+   @staticmethod
+   def plt_to_base64(plt, close=True):
+     figfile = BytesIO()
+     plt.savefig(figfile, format='JPEG')
+     figfile.seek(0)
+     base64_bytes = base64.b64encode(figfile.getvalue())
+     base64_string = base64_bytes.decode('utf-8')
+     if close:
+       plt.close()
+     return base64_string
+
+   @staticmethod
+   def plt_to_np(plt, close=True, axis=False):
+     # TODO: this method IS NOT thread-safe - needs revision
+     if not axis:
+       plt.axis('off')
+     fig = plt.gcf()
+     fig.canvas.draw()
+     data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+     w, h = fig.canvas.get_width_height()
+     try:
+       np_img = data.reshape((int(h), int(w), -1))
+     except Exception as e:
+       print(e)
+       np_img = None
+     if close:
+       plt.close()
+     return np_img
+
+   def time_model_forward(self, model, img_height, img_width, nr_inputs=1, lst_range=list(range(2, 33, 2))):
+     from tqdm import tqdm
+     lst_batches = list(range(2, 33, 2))
+     self.P('Testing model {} forward time for {} iterations with batches: {}'.format(
+       model.name, len(lst_batches), lst_batches)
+     )
+     for batch in tqdm(lst_batches):
+       for i in range(2):
+         data = nr_inputs * [np.random.uniform(size=(batch, img_height, img_width, 3))]
+         start = tm()
+         model.predict(data)
+         stop = tm()
+         if i == 1:
+           self.P(' Forward time for bs {}: {:.3f}s. Time per obs: {:.3f}s'.format(
+             batch, stop - start, (stop - start) / batch)
+           )
+         # endif
+       # endfor
+     # endfor
+
+   @staticmethod
+   def destroy_all_windows():
+     cv2.waitKey(1)
+     for i in range(5):
+       cv2.waitKey(1)
+       cv2.destroyAllWindows()
+       cv2.waitKey(1)
+     # endfor
+
+   def describe_np_tensor(self, np_v):
+     """
+     Will take a numpy vector and will print statistics about that vector:
+     - shape
+     - min
+     - max
+     - std
+     - median
+     - percentiles (25 - 100)
+
+     Parameters
+     ----------
+     np_v : np.ndarray
+       Tensor of nD
+
+     Returns
+     -------
+     None.
+     """
+     import pandas as pd
+     assert isinstance(np_v, np.ndarray)
+     lst = []
+     lst.append(('Shape', np_v.shape))
+     lst.append(('Dtype', np_v.dtype))
+     lst.append(('Min', np.min(np_v)))
+     lst.append(('Max', np.max(np_v)))
+     lst.append(('Std', np.std(np_v)))
+     lst.append(('Median', np.median(np_v)))
+     for perc in [25, 50, 75, 100]:
+       lst.append((perc, np.percentile(a=np_v, q=perc)))
+     self.p('Tensor description \n{}'.format(pd.DataFrame(lst, columns=['IND', 'VAL'])))
+     return
+
+   def class_distrib(self, v):
+     """
+     Simple function to print the class distribution. Useful especially for supervised classification tasks
+
+     Parameters
+     ----------
+     v : list or ndarray
+       Target tensor used in supervised classification problem.
+
+     Returns
+     -------
+     None.
+
+     """
+     import pandas as pd
+     assert isinstance(v, list) or isinstance(v, np.ndarray)
+     unique, counts = np.unique(v, return_counts=True)
+     self.p('Class distribution \n{}'.format(pd.DataFrame({'CLASS': unique, 'NR': counts})))
+     return
+
+   def cv_check_np_tensors(self, X, y, impose_norm=True, task='binary'):
+     """
+     Common pitfalls checks for input tensors of a supervised computer vision task.
+
+     Parameters
+     ----------
+     X : np.ndarray
+       Input tensor of a supervised computer vision task.
+     y : np.ndarray
+       Target tensor of a supervised computer vision task.
+     impose_norm : boolean, optional
+       Check if input tensor is normed. The default is True.
+
+     Returns
+     -------
+     None.
+
+     """
+     assert task in ['binary', 'multiclass']
+     self.p('Starting tensor sanity check')
+     # check tensor lengths are equal
+     assert X.shape[0] == y.shape[0]
+
+     # check if target is of shape (N, M)
+     assert len(y.shape) == 2
+     if task == 'binary': assert y.shape[-1] == 1
+     if task == 'multiclass': self.p('[Info] Please ensure you have the right number of classes')
+
+     # check if source is normed, only if impose_norm==True
+     if impose_norm:
+       assert X.max() <= 1.0, 'Your input tensor is not normed.'
+     else:
+       self.p('[Warning] Your input tensor is not normed.', color='yellow')
+
+     # check datatypes
+     if impose_norm:
+       assert X.dtype == 'float32', 'Your input tensor should be of type np.float32'
+     else:
+       assert X.dtype == 'uint8', 'You didn\'t normed your input tensor, you should convert it to np.uint8'
+     assert y.dtype == 'uint8', 'Your target tensor should be of type np.uint8'
+
+     self.p('[Warning] RGB/BGR Please ensure that all your images are RGB', color='yellow')
+     self.p('Done tensor sanity check')
+     return
+
+   @staticmethod
+   def image_resize(image, width=None, height=None, inter=None):
+     if inter is None:
+       inter = cv2.INTER_AREA
+     # initialize the dimensions of the image to be resized and
+     # grab the image size
+     dim = None
+     (h, w) = image.shape[:2]
+
+     # if both the width and height are None, then return the
+     # original image
+     if width is None and height is None:
+       return image
+
+     # check to see if the width is None
+     if width is None:
+       # calculate the ratio of the height and construct the
+       # dimensions
+       r = height / float(h)
+       dim = (int(w * r), height)
+
+     # otherwise, the height is None
+     else:
+       # calculate the ratio of the width and construct the
+       # dimensions
+       r = width / float(w)
+       dim = (width, int(h * r))
+
+     # resize the image
+     resized = cv2.resize(image, dim, interpolation = inter)
+
+     # return the resized image
+     return resized
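As a hedged usage sketch of two of the computer-vision helpers added above (not code from the wheel itself), the base64 round trip below assumes numpy and Pillow are installed and that the mixin is importable from the package path shown in the file list; the frame size and quality are example values.

import numpy as np
from naeural_client.logging.logger_mixins import _ComputerVisionMixin as CV

# Random RGB frame used only for illustration.
np_frame = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)

# Encode to a base64 JPEG string; max_height triggers a downscale to 240 px,
# and quality < 95 enables the optimized JPEG save path.
b64_str = CV.np_image_to_base64(np_frame, quality=80, max_height=240)

# Decode back to a numpy array. JPEG is lossy, so only the shape is preserved.
np_back = CV.base64_to_np_image(b64_str)
print(np_back.shape)  # (240, 320, 3) for this 4:3 input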