edgefirst-validator 4.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. deepview/modelpack/utils/argmax.py +16 -0
  2. edgefirst/validator/__init__.py +1 -0
  3. edgefirst/validator/__main__.py +375 -0
  4. edgefirst/validator/datasets/__init__.py +118 -0
  5. edgefirst/validator/datasets/cache.py +296 -0
  6. edgefirst/validator/datasets/core.py +250 -0
  7. edgefirst/validator/datasets/darknet.py +446 -0
  8. edgefirst/validator/datasets/database.py +1067 -0
  9. edgefirst/validator/datasets/instance/__init__.py +4 -0
  10. edgefirst/validator/datasets/instance/core.py +222 -0
  11. edgefirst/validator/datasets/instance/detection.py +145 -0
  12. edgefirst/validator/datasets/instance/multitask.py +80 -0
  13. edgefirst/validator/datasets/instance/segmentation.py +120 -0
  14. edgefirst/validator/datasets/utils/fetch.py +682 -0
  15. edgefirst/validator/datasets/utils/readers.py +425 -0
  16. edgefirst/validator/datasets/utils/transformations.py +1695 -0
  17. edgefirst/validator/evaluators/__init__.py +17 -0
  18. edgefirst/validator/evaluators/callbacks/__init__.py +3 -0
  19. edgefirst/validator/evaluators/callbacks/core.py +192 -0
  20. edgefirst/validator/evaluators/callbacks/plots.py +900 -0
  21. edgefirst/validator/evaluators/callbacks/studio.py +234 -0
  22. edgefirst/validator/evaluators/core.py +257 -0
  23. edgefirst/validator/evaluators/detection.py +749 -0
  24. edgefirst/validator/evaluators/multitask.py +270 -0
  25. edgefirst/validator/evaluators/parameters/__init__.py +53 -0
  26. edgefirst/validator/evaluators/parameters/core.py +554 -0
  27. edgefirst/validator/evaluators/parameters/dataset.py +239 -0
  28. edgefirst/validator/evaluators/parameters/model.py +338 -0
  29. edgefirst/validator/evaluators/parameters/validation.py +528 -0
  30. edgefirst/validator/evaluators/segmentation.py +729 -0
  31. edgefirst/validator/evaluators/utils/__init__.py +3 -0
  32. edgefirst/validator/evaluators/utils/classify.py +292 -0
  33. edgefirst/validator/evaluators/utils/match.py +262 -0
  34. edgefirst/validator/evaluators/utils/timer.py +132 -0
  35. edgefirst/validator/metrics/__init__.py +9 -0
  36. edgefirst/validator/metrics/data/__init__.py +7 -0
  37. edgefirst/validator/metrics/data/label.py +668 -0
  38. edgefirst/validator/metrics/data/metrics.py +759 -0
  39. edgefirst/validator/metrics/data/plots.py +476 -0
  40. edgefirst/validator/metrics/data/stats.py +507 -0
  41. edgefirst/validator/metrics/detection.py +595 -0
  42. edgefirst/validator/metrics/segmentation.py +173 -0
  43. edgefirst/validator/metrics/utils/math.py +717 -0
  44. edgefirst/validator/publishers/__init__.py +3 -0
  45. edgefirst/validator/publishers/console.py +147 -0
  46. edgefirst/validator/publishers/studio.py +128 -0
  47. edgefirst/validator/publishers/tensorboard.py +119 -0
  48. edgefirst/validator/publishers/utils/logger.py +111 -0
  49. edgefirst/validator/publishers/utils/table.py +403 -0
  50. edgefirst/validator/runners/__init__.py +8 -0
  51. edgefirst/validator/runners/core.py +727 -0
  52. edgefirst/validator/runners/deepviewrt.py +177 -0
  53. edgefirst/validator/runners/hailo.py +263 -0
  54. edgefirst/validator/runners/keras.py +150 -0
  55. edgefirst/validator/runners/kinara.py +265 -0
  56. edgefirst/validator/runners/offline.py +228 -0
  57. edgefirst/validator/runners/onnx.py +241 -0
  58. edgefirst/validator/runners/processing/decode.py +320 -0
  59. edgefirst/validator/runners/processing/dvapi.py +4192 -0
  60. edgefirst/validator/runners/processing/nms.py +637 -0
  61. edgefirst/validator/runners/processing/outputs.py +507 -0
  62. edgefirst/validator/runners/tensorrt.py +321 -0
  63. edgefirst/validator/runners/tflite.py +221 -0
  64. edgefirst/validator/validate.py +843 -0
  65. edgefirst/validator/visualize/__init__.py +3 -0
  66. edgefirst/validator/visualize/detection.py +623 -0
  67. edgefirst/validator/visualize/segmentation.py +281 -0
  68. edgefirst/validator/visualize/utils/plots.py +635 -0
  69. edgefirst_validator-4.2.1.dist-info/METADATA +111 -0
  70. edgefirst_validator-4.2.1.dist-info/RECORD +73 -0
  71. edgefirst_validator-4.2.1.dist-info/WHEEL +5 -0
  72. edgefirst_validator-4.2.1.dist-info/entry_points.txt +2 -0
  73. edgefirst_validator-4.2.1.dist-info/top_level.txt +2 -0
@@ -0,0 +1,843 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import ast
5
+ import zipfile
6
+ import datetime
7
+ import traceback
8
+ from typing import TYPE_CHECKING, Union, List, Tuple
9
+
10
+ import yaml
11
+ from edgefirst_client import Client
12
+
13
+ from edgefirst.validator.datasets import instantiate_dataset
14
+ from edgefirst.validator.datasets.utils.fetch import (classify_dataset,
15
+ download_file)
16
+ from edgefirst.validator.datasets.utils.readers import (read_labels_file,
17
+ read_yaml_file)
18
+ from edgefirst.validator.publishers.utils.logger import (logger,
19
+ set_symbol_condition)
20
+ from edgefirst.validator.runners import (TFliteRunner, ONNXRunner, KerasRunner,
21
+ TensorRTRunner, OfflineRunner,
22
+ DeepViewRTRunner, KinaraRunner)
23
+ from edgefirst.validator.evaluators import (CombinedParameters, CommonParameters,
24
+ ModelParameters, DatasetParameters,
25
+ ValidationParameters, TimerContext)
26
+ from edgefirst.validator.evaluators import (YOLOValidator, EdgeFirstValidator,
27
+ YOLOSegmentationValidator,
28
+ SegmentationValidator,
29
+ MultitaskValidator,
30
+ StudioProgress)
31
+ from edgefirst.validator.publishers import StudioPublisher
32
+ from edgefirst.validator.datasets import StudioCache
33
+
34
+ if TYPE_CHECKING:
35
+ from edgefirst.validator.runners import Runner
36
+ from edgefirst.validator.datasets import Dataset
37
+ from edgefirst.validator.evaluators import Evaluator
38
+
39
+
40
def build_parameters(args) -> CombinedParameters:
    """
    Translate the parsed command line arguments into parameter containers.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments.

    Returns
    -------
    CombinedParameters
        A container bundling the model, dataset, and validation
        parameters built from the command line.
    """
    # Timestamp of this run; colons are replaced with underscores before
    # the timestamp is appended to the output directory names.
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d--%H:%M:%S')
    timestamp = timestamp.replace(":", "_")

    def run_directory(parent):
        # Evaluated lazily so `args.model` is only read when an output
        # directory is actually requested.
        model_name = os.path.basename(os.path.normpath(args.model))
        return os.path.join(parent, f"{model_name}_{timestamp}")

    visualize = None
    tensorboard = None
    # The tensorboard directory is only considered when no visualization
    # directory was given.
    if args.visualize:
        visualize = run_directory(args.visualize)
    elif args.tensorboard:
        tensorboard = run_directory(args.tensorboard)

    # A session ID without an explicit JSON output location falls back
    # to the "apex_charts" directory.
    json_out = args.json_out
    if args.session_id is not None and json_out is None:
        json_out = "apex_charts"
    if json_out:
        json_out = run_directory(json_out)

    validation = ValidationParameters(
        method=args.method,
        iou_threshold=args.validation_iou,
        score_threshold=args.validation_score,
        metric=args.metric,
        matching_leniency=args.matching_leniency,
        clamp_boxes=args.clamp_boxes,
        ignore_boxes=args.ignore_boxes,
        display=args.display,
        visualize=visualize,
        tensorboard=tensorboard,
        json_out=json_out,
        csv_out=args.csv,
        include_background=args.include_background
    )

    # The same CommonParameters instance is handed to both the model and
    # the dataset parameters.
    common = CommonParameters(
        norm=args.norm,
        preprocessing=args.preprocessing,
        backend=args.backend
    )
    common.check_backend_availability()

    model = ModelParameters(
        common_parameters=common,
        model_path=args.model,
        iou_threshold=args.nms_iou_threshold,
        score_threshold=args.nms_score_threshold,
        max_detections=args.max_detections,
        engine=args.engine,
        nms=args.nms,
        box_format=args.box_format,
        warmup=args.warmup,
        labels_path=args.model_labels,
        label_offset=args.label_offset,
        agnostic_nms=not args.class_nms
    )
    model.check_nms_availability()

    dataset = DatasetParameters(
        common_parameters=common,
        dataset_path=args.dataset,
        show_missing_annotations=args.show_missing_annotations,
        normalized=args.absolute_annotations,
        box_format=args.annotation_format,
        labels_path=args.dataset_labels,
        label_offset=args.gt_label_offset,
    )
    # Mirror the validation output settings onto the dataset parameters.
    dataset.silent = validation.silent
    dataset.visualize = validation.visualize or validation.tensorboard

    combined = CombinedParameters(
        model_parameters=model,
        dataset_parameters=dataset,
        validation_parameters=validation
    )

    class_based_nms = not model.agnostic_nms
    if class_based_nms and model.nms in ["hal", "numpy", "torch"]:
        logger(
            "Class-based NMS is currently not supported for the {} NMS.".format(
                model.nms), code="INFO"
        )
    return combined
145
+
146
+
147
def build_dataset(
    args,
    parameters: DatasetParameters,
    timer: TimerContext,
    studio_cache: StudioCache,
) -> Dataset:
    """
    Instantiate the Dataset Reader.

    The dataset is downloaded through `studio_cache` when `args.dataset`
    does not exist or is an empty directory. An existing dataset *file*
    (for example a YAML description) is treated as present; previously
    such a path was passed to `os.listdir`, which raises
    `NotADirectoryError` for anything that is not a directory.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments.
    parameters: DatasetParameters
        Contains the dataset parameters set from the command line.
    timer: TimerContext
        A timer object for handling validation timings in
        the dataset input preprocessing.
    studio_cache: StudioCache
        The object used for downloading and caching the dataset.

    Returns
    -------
    Dataset
        This can be any dataset reader such as a DarkNetDataset,
        EdgeFirstDatabase, etc. depending on the dataset format that
        was specified.
    """
    if args.session_id is not None:
        # Avoid the default dataset path for studio validation.
        if args.dataset == "samples/coco128.yaml":
            args.dataset = "dataset"
            parameters.dataset_path = args.dataset

    if parameters.labels_path and os.path.exists(parameters.labels_path):
        parameters.labels = read_labels_file(parameters.labels_path)

    # Only directories can be listed. A directory counts as available
    # when it has at least one entry; any other existing path (a file)
    # counts as available as soon as it exists.
    if os.path.isdir(args.dataset):
        dataset_available = bool(os.listdir(args.dataset))
    else:
        dataset_available = os.path.exists(args.dataset)

    if not dataset_available:
        logger("The dataset does not exist. " +
               f"Attempting to download the dataset to '{args.dataset}'",
               code="INFO")
        studio_cache.download(args.dataset)
    else:
        # Nothing to download: mark the first two progress stages
        # as complete.
        for index in (0, 1):
            studio_cache.complete_stage(
                stage=studio_cache.stages[index][0],
                message=studio_cache.stages[index][1]
            )

    # Use the dataset cache if specified and it exists.
    if args.cache is not None:
        parameters.cache = True
        if os.path.exists(args.cache):
            parameters.dataset_path = args.cache
            studio_cache.complete_stage(
                stage=studio_cache.stages[2][0],
                message=studio_cache.stages[2][1]
            )

    # Determine the dataset type.
    info_dataset = classify_dataset(
        source=parameters.dataset_path,
        labels_path=parameters.labels_path
    )

    # Build the dataset class depending on the type.
    return instantiate_dataset(
        info_dataset=info_dataset,
        parameters=parameters,
        timer=timer
    )
223
+
224
+
225
def build_runner(args, parameters: ModelParameters,
                 timer: TimerContext) -> Runner:
    """
    Create the model runner that matches the model file extension.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments.
    parameters: ModelParameters
        Contains the model parameters set from the command line.
    timer: TimerContext
        A timer object for handling validation timings in the model.

    Returns
    -------
    Runner
        The runner selected from the extension of
        `parameters.model_path` (Keras, TFLite, ONNX, TensorRT, Kinara,
        DeepViewRT) or, for a path without an extension, a Keras saved
        model runner or the offline runner.

    Raises
    ------
    NotImplementedError
        Raised for Hailo (.hef) models and for any unrecognized
        model extension.
    """
    # Fetch the default ONNX model when it was requested but is missing.
    if (not os.path.exists(parameters.model_path) and
            parameters.model_path == "yolov5s.onnx"):
        download_file(
            url="https://github.com/ultralytics/yolov5/releases/download/v7.0/yolov5s.onnx",
            download_path=os.path.join(os.getcwd(), "yolov5s.onnx")
        )

    model_metadata = get_model_metadata(args)

    # Without --override the thresholds and preprocessing settings stored
    # in the model metadata replace the command line values. Keys that
    # are absent from the metadata leave the current values untouched.
    if not args.override and model_metadata is not None:
        validation_meta = model_metadata.get("validation", {})
        parameters.score_threshold = validation_meta.get(
            "score", parameters.score_threshold)
        parameters.iou_threshold = validation_meta.get(
            "iou", parameters.iou_threshold)
        parameters.common.norm = validation_meta.get(
            "normalization", parameters.common.norm)
        parameters.common.preprocessing = validation_meta.get(
            "preprocessing", parameters.common.preprocessing)

    extension = os.path.splitext(parameters.model_path)[1].lower()
    shared = dict(parameters=parameters,
                  metadata=model_metadata,
                  timer=timer)

    # KERAS
    if extension in (".h5", ".keras"):
        return KerasRunner(parameters.model_path, **shared)
    # TFLITE
    if extension == ".tflite":
        return TFliteRunner(parameters.model_path, **shared)
    # ONNX
    if extension == ".onnx":
        return ONNXRunner(parameters.model_path, **shared)
    # TENSORRT
    if extension in (".engine", ".trt"):
        return TensorRTRunner(parameters.model_path, **shared)
    # KINARA
    if extension == ".dvm":
        return KinaraRunner(parameters.model_path, **shared)
    # HAILO
    if extension == ".hef":
        raise NotImplementedError(
            "Running Hailo models is not implemented.")
    # DEEPVIEWRT EVALUATION
    if extension == ".rtm":
        return DeepViewRTRunner(model=parameters.model_path, **shared)
    # OFFLINE (TEXT FILES) or SAVED MODEL Directory
    if extension == "":
        saved_model_runner = find_keras_pb_model(**shared)
        if saved_model_runner is not None:
            return saved_model_runner

        logger("Model extension does not exist, running offline validation.",
               code='INFO')
        return OfflineRunner(
            annotation_source=parameters.model_path,
            parameters=parameters,
            timer=timer
        )

    raise NotImplementedError(
        "Running the model '{}' is currently not supported".format(
            parameters.model_path)
    )
343
+
344
+
345
def build_evaluator(
    args,
    parameters: CombinedParameters,
    client: Client,
    stages: List[Tuple[str, str]]
) -> Evaluator:
    """
    Instantiate the evaluator object depending on the task.

    The dataset reader is built first, then the model labels are
    resolved, then the runner is built, and finally the evaluator is
    selected from the `with_boxes` / `with_masks` flags found on
    `parameters.model.common`.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments.
    parameters: CombinedParameters
        This object is a container for both model, dataset, and validation
        parameters set from the command line.
    client: Client
        The EdgeFirst Client object.
    stages: List[Tuple[str, str]]
        This contains the stages that tracks each progress in Studio.
        A stage contains ("stage identifier", "stage description").

    Returns
    -------
    Evaluator
        One of YOLOSegmentationValidator, MultitaskValidator,
        SegmentationValidator, YOLOValidator, or EdgeFirstValidator.

    Raises
    ------
    ValueError
        Dataset labels were not found.
    RuntimeError
        Both `with_boxes` and `with_masks` are False, so no evaluator
        can be selected.
    NotImplementedError
        Propagated from `build_runner` for unsupported model types.
    """
    # A single timer instance is shared by the dataset and the runner.
    timer = TimerContext()
    studio_cache = StudioCache(
        parameters=parameters.dataset,
        stages=stages,
        client=client,
        session_id=args.session_id,
    )

    dataset = build_dataset(
        args, parameters=parameters.dataset, timer=timer,
        studio_cache=studio_cache,
    )

    # Validation cannot continue without the dataset label names.
    if parameters.dataset.labels is None or len(
            parameters.dataset.labels) == 0:
        raise ValueError(
            "The unique set of string labels from the dataset was not found. " +
            "Try setting --dataset-labels=path/to/labels.txt")

    # Read labels.txt or assign the dataset labels as the model labels as a fallback.
    # During validation, all model indices will be translated to the dataset indices
    # for a 1-to-1 match.
    if parameters.model.labels is None or len(parameters.model.labels) == 0:
        parameters.model.labels = get_model_labels(args, parameters.dataset)

    # Builds the runner and assigns conditions for with_masks or with_boxes.
    runner = build_runner(args, parameters=parameters.model, timer=timer)

    # Cache the dataset if it doesn't exist.
    # This block is placed after the building the runner object to initialize
    # with_masks and with_boxes conditions needed for iterating the dataset.
    if args.cache is not None and not os.path.exists(args.cache):
        logger("The dataset cache does not exist. " +
               f"Attempting to cache existing dataset to {args.cache}",
               code="INFO")
        # The dataset reader is instantiated again from the same dataset
        # info before it is handed to the cache writer.
        dataset = instantiate_dataset(
            info_dataset=dataset.info_dataset,
            parameters=parameters.dataset,
            timer=timer
        )
        dataset = studio_cache.cache(dataset, args.cache)
        parameters.dataset.dataset_path = args.cache

    dataset.verify_dataset()

    # If the model labels has background, but the dataset does not,
    # include background in the dataset labels with a +1 offset to label
    # indices.
    if ("background" in parameters.model.labels and
            "background" not in parameters.dataset.labels):
        parameters.dataset.labels = ['background'] + parameters.dataset.labels
        parameters.dataset.label_offset = 1

    # If the labels in the dataset and the model do not match.
    # However consider possibility of the background class inside the dataset.
    # A difference of exactly one label is tolerated without a warning.
    if abs(len(parameters.dataset.labels) - len(parameters.model.labels)) > 1:
        logger(
            "The model contains {} labels and the dataset contains {} labels.".format(
                len(parameters.model.labels),
                len(parameters.dataset.labels)
            ),
            code="WARNING")

        # Pad the shorter label list in place with "unknown" entries so
        # that both lists end up with the same length.
        dataset_labels = parameters.dataset.labels
        model_labels = parameters.model.labels
        if len(dataset_labels) < len(model_labels):
            offset = len(model_labels) - len(dataset_labels)
            parameters.dataset.labels += ["unknown"] * offset
        else:
            offset = len(dataset_labels) - len(model_labels)
            parameters.model.labels += ["unknown"] * offset

    """
    Instantiate evaluators
    """
    # Multitask Validation
    if parameters.model.common.with_boxes and parameters.model.common.with_masks:
        if (not parameters.model.common.semantic and
                parameters.validation.method in ["ultralytics", "yolov7"]):
            # Ultralytics segmentation models are always multitask models.
            evaluator = YOLOSegmentationValidator(
                parameters=parameters,
                runner=runner,
                dataset=dataset
            )
        else:
            # Reached for semantic models and also for any validation
            # method other than "ultralytics"/"yolov7"; the same log
            # message is emitted in both cases.
            logger("Detected semantic segmentation model. " +
                   "Deploying EdgeFirst validation.",
                   code="INFO")
            evaluator = MultitaskValidator(
                parameters=parameters,
                runner=runner,
                dataset=dataset
            )
    # Segmentation Validation
    elif parameters.model.common.with_masks:
        logger("Detected semantic segmentation model. " +
               "Deploying EdgeFirst validation.",
               code="INFO")
        # Semantic Segmentation models from ModelPack are validated using
        # EdgeFirst
        parameters.validation.method = "edgefirst"
        evaluator = SegmentationValidator(
            parameters=parameters,
            runner=runner,
            dataset=dataset
        )
    # Detection Validation
    elif parameters.model.common.with_boxes:
        if parameters.validation.method in ["ultralytics", "yolov7"]:
            evaluator = YOLOValidator(
                parameters=parameters,
                runner=runner,
                dataset=dataset
            )
        else:
            evaluator = EdgeFirstValidator(
                parameters=parameters,
                runner=runner,
                dataset=dataset
            )
    else:
        raise RuntimeError(
            "Both values for `with_boxes` and `with_masks` were set to False.")

    return evaluator
506
+
507
+
508
def find_keras_pb_model(
    parameters: ModelParameters,
    metadata: dict,
    timer: TimerContext
) -> Union[KerasRunner, None]:
    """
    Instantiate a Keras runner for a saved model directory.

    The directory tree under `parameters.model_path` is walked top-down
    and the search stops at the first directory that contains either
    'keras_metadata.pb' or 'saved_model.pb'. Previously only the inner
    loop was left on a match, so the walk continued and every further
    matching directory constructed another runner, with the last one
    being returned.

    Parameters
    ----------
    parameters: ModelParameters
        These are the model parameters loaded by the command line.
    metadata: dict
        The model metadata which contains information for decoding
        the model outputs.
    timer: TimerContext
        A timer object handling validation timings in the model.

    Returns
    -------
    Union[KerasRunner, None]
        A KerasRunner created for the first directory holding a
        'keras_metadata.pb' or 'saved_model.pb' file. None is returned
        when no such file exists (including when the path itself does
        not exist).
    """
    saved_model_markers = {"keras_metadata.pb", "saved_model.pb"}
    for root, _, files in os.walk(parameters.model_path):
        # `files` holds plain file names, so a set intersection is
        # enough to detect either marker file in this directory.
        if saved_model_markers.intersection(files):
            return KerasRunner(
                model=root,
                parameters=parameters,
                metadata=metadata,
                timer=timer
            )
    return None
546
+
547
+
548
def get_model_labels(args, parameters: DatasetParameters) -> list:
    """
    Fetch the labels associated to the model.

    Candidates are applied in increasing priority: the dataset labels,
    the file given by `args.model_labels`, and the labels embedded in a
    zipped '.tflite' model. A candidate only replaces the current
    choice when it is non-empty, so an empty labels file or an embedded
    text file without usable labels no longer overwrites the fallback
    with an empty list.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments.
    parameters: DatasetParameters
        The dataset parameters set from the command line.

    Returns
    -------
    list
        The list of model labels.
    """
    model_labels = parameters.labels

    arg_labels, embedded_labels = [], []
    if args.model_labels and os.path.exists(args.model_labels):
        arg_labels = read_labels_file(args.model_labels)
        if arg_labels:
            model_labels = arg_labels

    if args.model.endswith('.tflite') and zipfile.is_zipfile(args.model):
        embedded_labels = _read_embedded_labels(args.model)
        if embedded_labels:
            model_labels = embedded_labels

    if len(arg_labels) and len(embedded_labels):
        if arg_labels != embedded_labels:
            logger("The contents of the specified --model-labels does not match " +
                   "the labels embedded in the model. Falling back to the " +
                   "labels embedded in the model", code="WARNING")

    if not (len(arg_labels) or len(embedded_labels)):
        logger("Model labels was not specified. " +
               "Falling back to use the dataset labels for the model.",
               code="WARNING")
    return model_labels


def _read_embedded_labels(model_path: str) -> list:
    """
    Read the labels stored in the first '.txt' member of a zipped model.

    The text is first parsed as a Python literal. A dictionary is
    expected to carry the labels under its "names" key. Any other
    content, including literals that are not dictionaries, is treated
    as one label per non-blank line.

    Parameters
    ----------
    model_path: str
        Path to a model file that is a valid ZIP archive.

    Returns
    -------
    list
        The embedded labels, or an empty list when none were found.
    """
    with zipfile.ZipFile(model_path, 'r') as zip_ref:
        txt_files = [name for name in zip_ref.namelist()
                     if name.lower().endswith('.txt')]
        if not txt_files:
            return []
        # Only the first .txt member is considered.
        with zip_ref.open(txt_files[0]) as file:
            content = file.read().decode('utf-8').strip()

    try:
        model_metadata = ast.literal_eval(content)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Not a Python literal: handled below as a plain labels file.
        model_metadata = None

    if isinstance(model_metadata, dict):
        names = model_metadata.get("names", {})
        if isinstance(names, dict):
            return list(names.values())
        if isinstance(names, (list, tuple)):
            return list(names)
        return []

    # Plain labels file, or a literal that is not a dictionary
    # (calling `.get` on such a value used to raise AttributeError).
    return [line for line in content.splitlines() if line.strip()]
602
+
603
+
604
def get_model_metadata(args) -> Union[dict, None]:
    """
    Load the model metadata used for decoding the model outputs.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments.

    Returns
    -------
    Union[dict, None]
        The contents of `args.config` when it is set. Otherwise the
        parsed 'edgefirst.yaml' (preferred) or 'config.yaml' member of
        the model file when the model is a ZIP archive. None is
        returned when neither source is available.
    """
    # An explicitly provided configuration file takes precedence over
    # anything embedded in the model.
    if args.config is not None:
        return read_yaml_file(args.config)

    if not zipfile.is_zipfile(args.model):
        return None

    with zipfile.ZipFile(args.model) as archive:
        members = archive.namelist()
        for candidate in ("edgefirst.yaml", "config.yaml"):
            if candidate in members:
                with archive.open(candidate) as handle:
                    return yaml.safe_load(handle.read().decode("utf-8"))
    return None
633
+
634
+
635
def download_model_artifacts(args, client: Client):
    """
    Download model artifacts in EdgeFirst Studio.

    The model file is only downloaded when `args.model` does not exist
    locally. The labels and configuration artifacts are always
    downloaded, and `args.model_labels` / `args.config` default to
    'labels.txt' / 'edgefirst.yaml' when unset.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments. `model`, `model_labels`, and
        `config` may be updated in place.
    client: Client
        The EdgeFirst Studio client object to
        communicate with EdgeFirst Studio.

    Raises
    ------
    FileNotFoundError
        A requested artifact does not exist in the training session
        (the server answered with a 404 status).
    RuntimeError
        Any other download failure reported by the client.
    """
    session = client.validation_session(session_id=args.session_id)

    train_session_id = session.training_session_id
    model = session.params["model"]

    logger(f"Downloading model artifacts from train session ID " +
           f"'t-{train_session_id.value:x}'.", code="INFO")

    # Do not auto-download the model, in case offline validation is specified.
    if not os.path.exists(args.model):
        model = str(model)
        # Unwrap only a genuine "String(...)" representation so that a
        # trailing ")" is never stripped from an ordinary file name.
        if model.startswith("String(") and model.endswith(")"):
            model = model.removeprefix("String(").removesuffix(")")

        _download_artifact(client, train_session_id, model)
        args.model = os.path.join(os.path.dirname(args.model), model)

    if args.model_labels is None:
        args.model_labels = "labels.txt"

    if args.config is None:
        args.config = "edgefirst.yaml"

    _download_artifact(client, train_session_id, args.model_labels)
    _download_artifact(client, train_session_id, args.config)


def _download_artifact(client, training_session_id, name):
    """
    Download one artifact and map a 404 response to FileNotFoundError.

    Parameters
    ----------
    client: Client
        The EdgeFirst Studio client object.
    training_session_id
        Identifier of the training session that owns the artifact.
    name: str
        Artifact name, which is also used as the local file name.

    Raises
    ------
    FileNotFoundError
        The error reported by the client contains "Status(404".
    RuntimeError
        Any other error reported by the client is re-raised unchanged.
    """
    try:
        client.download_artifact(
            training_session_id=training_session_id,
            modelname=name,
            filename=name
        )
    except RuntimeError as error:
        if "Status(404" in str(error):
            # Report the artifact that was actually requested and keep
            # the original error attached as the cause.
            raise FileNotFoundError(
                f"The artifact '{name}' does not exist.") from error
        raise
703
+
704
+
705
def update_parameters(args, client: Client):
    """
    Overwrite command line arguments with the validation session values.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments. `method`, `override`,
        `nms_score_threshold`, and `nms_iou_threshold` are updated
        in place.
    client: Client
        The EdgeFirst Client object.
    """
    studio_params = client.validation_session(args.session_id).params

    # Values missing from the session keep their command line setting.
    for name in ("method", "nms_score_threshold", "nms_iou_threshold"):
        setattr(args, name, studio_params.get(name, getattr(args, name)))

    # Only the presence of the key matters here, not its value.
    args.override = "override" in studio_params.keys()
724
+
725
+
726
def initialize_studio_client(args) -> Union[Client, None]:
    """
    Create the EdgeFirst Client when a validation session ID is set.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments. A purely numeric `session_id`
        string is converted to an integer in place.

    Returns
    -------
    Union[Client, None]
        The client created from the token, username, password, and
        server arguments. None is returned when no validation session
        ID was specified.

    Raises
    ------
    ValueError
        The client reported a "MaxRetries" error, which is reported
        as an invalid server.
    RuntimeError
        Any other error raised while creating the client.
    """
    # Without a session ID there is nothing to connect to.
    if args.session_id is None:
        return None

    if args.session_id.isdigit():
        args.session_id = int(args.session_id)
    logger(f"Detected EdgeFirst Studio validation ID: '{args.session_id}'.",
           code="INFO")

    try:
        return Client(
            token=args.token,
            username=args.username,
            password=args.password,
            server=args.server
        )
    except RuntimeError as e:
        if "MaxRetries" in str(e):
            raise ValueError(
                f"Got an invalid server: {args.server}. " +
                "Check that the right server is set.")
        raise
764
+
765
+
766
def validate(args):
    """
    Run a full validation from the parsed command line arguments.

    The Studio client (when a session ID is set), the parameters, and
    the evaluator are created first. Any failure during that setup is
    reported to Studio when a publisher exists, printed, and re-raised.
    The evaluation itself is then run either directly or through the
    Studio progress wrapper.

    Parameters
    ----------
    args: argparse.Namespace
        The command line arguments set.
    """
    set_symbol_condition(args.exclude_symbols)

    client = initialize_studio_client(args)
    evaluator = None

    # Progress stages are defined in the order below.
    # If the order is to change, update the stages defined in StudioCache.
    # The caching stage is only present when a cache path was given and
    # sits right before the validation stage.
    stages = [
        ("fetch_img", "Downloading Images"),
        ("fetch_as", "Downloading Annotations"),
    ]
    if args.cache is not None:
        stages.append(("cache", "Caching Dataset"))
    stages.append(("validate", "Running Validation"))

    studio_publisher = None
    if client is not None:
        studio_publisher = StudioPublisher(
            json_path=args.json_out,
            session_id=args.session_id,
            client=client
        )

    try:
        if studio_publisher is not None:
            session = client.validation_session(session_id=args.session_id)
            client.set_stages(session.task.id, stages)

            download_model_artifacts(args, client=client)
            # Update parameters set from the validation session in studio.
            update_parameters(args=args, client=client)

        parameters = build_parameters(args)
        if studio_publisher is not None:
            studio_publisher.json_path = parameters.validation.json_out

        evaluator = build_evaluator(args, parameters=parameters,
                                    client=client, stages=stages)
    except Exception as e:
        # Surface the setup failure in Studio before re-raising it.
        if studio_publisher is not None:
            studio_publisher.update_stage(
                stage="validate",
                status="error",
                message=str(e),
                percentage=0
            )
        if evaluator is not None:
            evaluator.stop()
        print(traceback.format_exc())
        raise

    # With a session ID the evaluation is driven through the Studio
    # progress wrapper; otherwise the evaluator runs directly.
    if args.session_id is not None:
        evaluation = StudioProgress(
            evaluator=evaluator,
            studio_publisher=studio_publisher,
            stage=stages[-1][0]
        )
    else:
        evaluation = evaluator

    try:
        evaluation.group_evaluation()
    except Exception:
        # Always stop the evaluator before propagating the failure.
        evaluator.stop()
        raise