dtlpy 1.113.10__py3-none-any.whl → 1.114.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dtlpy/__init__.py +488 -488
- dtlpy/__version__.py +1 -1
- dtlpy/assets/__init__.py +26 -26
- dtlpy/assets/__pycache__/__init__.cpython-38.pyc +0 -0
- dtlpy/assets/code_server/config.yaml +2 -2
- dtlpy/assets/code_server/installation.sh +24 -24
- dtlpy/assets/code_server/launch.json +13 -13
- dtlpy/assets/code_server/settings.json +2 -2
- dtlpy/assets/main.py +53 -53
- dtlpy/assets/main_partial.py +18 -18
- dtlpy/assets/mock.json +11 -11
- dtlpy/assets/model_adapter.py +83 -83
- dtlpy/assets/package.json +61 -61
- dtlpy/assets/package_catalog.json +29 -29
- dtlpy/assets/package_gitignore +307 -307
- dtlpy/assets/service_runners/__init__.py +33 -33
- dtlpy/assets/service_runners/converter.py +96 -96
- dtlpy/assets/service_runners/multi_method.py +49 -49
- dtlpy/assets/service_runners/multi_method_annotation.py +54 -54
- dtlpy/assets/service_runners/multi_method_dataset.py +55 -55
- dtlpy/assets/service_runners/multi_method_item.py +52 -52
- dtlpy/assets/service_runners/multi_method_json.py +52 -52
- dtlpy/assets/service_runners/single_method.py +37 -37
- dtlpy/assets/service_runners/single_method_annotation.py +43 -43
- dtlpy/assets/service_runners/single_method_dataset.py +43 -43
- dtlpy/assets/service_runners/single_method_item.py +41 -41
- dtlpy/assets/service_runners/single_method_json.py +42 -42
- dtlpy/assets/service_runners/single_method_multi_input.py +45 -45
- dtlpy/assets/voc_annotation_template.xml +23 -23
- dtlpy/caches/base_cache.py +32 -32
- dtlpy/caches/cache.py +473 -473
- dtlpy/caches/dl_cache.py +201 -201
- dtlpy/caches/filesystem_cache.py +89 -89
- dtlpy/caches/redis_cache.py +84 -84
- dtlpy/dlp/__init__.py +20 -20
- dtlpy/dlp/cli_utilities.py +367 -367
- dtlpy/dlp/command_executor.py +764 -764
- dtlpy/dlp/dlp +1 -1
- dtlpy/dlp/dlp.bat +1 -1
- dtlpy/dlp/dlp.py +128 -128
- dtlpy/dlp/parser.py +651 -651
- dtlpy/entities/__init__.py +83 -83
- dtlpy/entities/analytic.py +311 -311
- dtlpy/entities/annotation.py +1879 -1879
- dtlpy/entities/annotation_collection.py +699 -699
- dtlpy/entities/annotation_definitions/__init__.py +20 -20
- dtlpy/entities/annotation_definitions/base_annotation_definition.py +100 -100
- dtlpy/entities/annotation_definitions/box.py +195 -195
- dtlpy/entities/annotation_definitions/classification.py +67 -67
- dtlpy/entities/annotation_definitions/comparison.py +72 -72
- dtlpy/entities/annotation_definitions/cube.py +204 -204
- dtlpy/entities/annotation_definitions/cube_3d.py +149 -149
- dtlpy/entities/annotation_definitions/description.py +32 -32
- dtlpy/entities/annotation_definitions/ellipse.py +124 -124
- dtlpy/entities/annotation_definitions/free_text.py +62 -62
- dtlpy/entities/annotation_definitions/gis.py +69 -69
- dtlpy/entities/annotation_definitions/note.py +139 -139
- dtlpy/entities/annotation_definitions/point.py +117 -117
- dtlpy/entities/annotation_definitions/polygon.py +182 -182
- dtlpy/entities/annotation_definitions/polyline.py +111 -111
- dtlpy/entities/annotation_definitions/pose.py +92 -92
- dtlpy/entities/annotation_definitions/ref_image.py +86 -86
- dtlpy/entities/annotation_definitions/segmentation.py +240 -240
- dtlpy/entities/annotation_definitions/subtitle.py +34 -34
- dtlpy/entities/annotation_definitions/text.py +85 -85
- dtlpy/entities/annotation_definitions/undefined_annotation.py +74 -74
- dtlpy/entities/app.py +220 -220
- dtlpy/entities/app_module.py +107 -107
- dtlpy/entities/artifact.py +174 -174
- dtlpy/entities/assignment.py +399 -399
- dtlpy/entities/base_entity.py +214 -214
- dtlpy/entities/bot.py +113 -113
- dtlpy/entities/codebase.py +296 -296
- dtlpy/entities/collection.py +38 -38
- dtlpy/entities/command.py +169 -169
- dtlpy/entities/compute.py +442 -442
- dtlpy/entities/dataset.py +1285 -1285
- dtlpy/entities/directory_tree.py +44 -44
- dtlpy/entities/dpk.py +470 -470
- dtlpy/entities/driver.py +222 -222
- dtlpy/entities/execution.py +397 -397
- dtlpy/entities/feature.py +124 -124
- dtlpy/entities/feature_set.py +145 -145
- dtlpy/entities/filters.py +641 -641
- dtlpy/entities/gis_item.py +107 -107
- dtlpy/entities/integration.py +184 -184
- dtlpy/entities/item.py +953 -953
- dtlpy/entities/label.py +123 -123
- dtlpy/entities/links.py +85 -85
- dtlpy/entities/message.py +175 -175
- dtlpy/entities/model.py +694 -691
- dtlpy/entities/node.py +1005 -1005
- dtlpy/entities/ontology.py +803 -803
- dtlpy/entities/organization.py +287 -287
- dtlpy/entities/package.py +657 -657
- dtlpy/entities/package_defaults.py +5 -5
- dtlpy/entities/package_function.py +185 -185
- dtlpy/entities/package_module.py +113 -113
- dtlpy/entities/package_slot.py +118 -118
- dtlpy/entities/paged_entities.py +290 -267
- dtlpy/entities/pipeline.py +593 -593
- dtlpy/entities/pipeline_execution.py +279 -279
- dtlpy/entities/project.py +394 -394
- dtlpy/entities/prompt_item.py +499 -499
- dtlpy/entities/recipe.py +301 -301
- dtlpy/entities/reflect_dict.py +102 -102
- dtlpy/entities/resource_execution.py +138 -138
- dtlpy/entities/service.py +958 -958
- dtlpy/entities/service_driver.py +117 -117
- dtlpy/entities/setting.py +294 -294
- dtlpy/entities/task.py +491 -491
- dtlpy/entities/time_series.py +143 -143
- dtlpy/entities/trigger.py +426 -426
- dtlpy/entities/user.py +118 -118
- dtlpy/entities/webhook.py +124 -124
- dtlpy/examples/__init__.py +19 -19
- dtlpy/examples/add_labels.py +135 -135
- dtlpy/examples/add_metadata_to_item.py +21 -21
- dtlpy/examples/annotate_items_using_model.py +65 -65
- dtlpy/examples/annotate_video_using_model_and_tracker.py +75 -75
- dtlpy/examples/annotations_convert_to_voc.py +9 -9
- dtlpy/examples/annotations_convert_to_yolo.py +9 -9
- dtlpy/examples/convert_annotation_types.py +51 -51
- dtlpy/examples/converter.py +143 -143
- dtlpy/examples/copy_annotations.py +22 -22
- dtlpy/examples/copy_folder.py +31 -31
- dtlpy/examples/create_annotations.py +51 -51
- dtlpy/examples/create_video_annotations.py +83 -83
- dtlpy/examples/delete_annotations.py +26 -26
- dtlpy/examples/filters.py +113 -113
- dtlpy/examples/move_item.py +23 -23
- dtlpy/examples/play_video_annotation.py +13 -13
- dtlpy/examples/show_item_and_mask.py +53 -53
- dtlpy/examples/triggers.py +49 -49
- dtlpy/examples/upload_batch_of_items.py +20 -20
- dtlpy/examples/upload_items_and_custom_format_annotations.py +55 -55
- dtlpy/examples/upload_items_with_modalities.py +43 -43
- dtlpy/examples/upload_segmentation_annotations_from_mask_image.py +44 -44
- dtlpy/examples/upload_yolo_format_annotations.py +70 -70
- dtlpy/exceptions.py +125 -125
- dtlpy/miscellaneous/__init__.py +20 -20
- dtlpy/miscellaneous/dict_differ.py +95 -95
- dtlpy/miscellaneous/git_utils.py +217 -217
- dtlpy/miscellaneous/json_utils.py +14 -14
- dtlpy/miscellaneous/list_print.py +105 -105
- dtlpy/miscellaneous/zipping.py +130 -130
- dtlpy/ml/__init__.py +20 -20
- dtlpy/ml/base_feature_extractor_adapter.py +27 -27
- dtlpy/ml/base_model_adapter.py +945 -940
- dtlpy/ml/metrics.py +461 -461
- dtlpy/ml/predictions_utils.py +274 -274
- dtlpy/ml/summary_writer.py +57 -57
- dtlpy/ml/train_utils.py +60 -60
- dtlpy/new_instance.py +252 -252
- dtlpy/repositories/__init__.py +56 -56
- dtlpy/repositories/analytics.py +85 -85
- dtlpy/repositories/annotations.py +916 -916
- dtlpy/repositories/apps.py +383 -383
- dtlpy/repositories/artifacts.py +452 -452
- dtlpy/repositories/assignments.py +599 -599
- dtlpy/repositories/bots.py +213 -213
- dtlpy/repositories/codebases.py +559 -559
- dtlpy/repositories/collections.py +332 -348
- dtlpy/repositories/commands.py +158 -158
- dtlpy/repositories/compositions.py +61 -61
- dtlpy/repositories/computes.py +434 -406
- dtlpy/repositories/datasets.py +1291 -1291
- dtlpy/repositories/downloader.py +895 -895
- dtlpy/repositories/dpks.py +433 -433
- dtlpy/repositories/drivers.py +266 -266
- dtlpy/repositories/executions.py +817 -817
- dtlpy/repositories/feature_sets.py +226 -226
- dtlpy/repositories/features.py +238 -238
- dtlpy/repositories/integrations.py +484 -484
- dtlpy/repositories/items.py +909 -915
- dtlpy/repositories/messages.py +94 -94
- dtlpy/repositories/models.py +877 -867
- dtlpy/repositories/nodes.py +80 -80
- dtlpy/repositories/ontologies.py +511 -511
- dtlpy/repositories/organizations.py +525 -525
- dtlpy/repositories/packages.py +1941 -1941
- dtlpy/repositories/pipeline_executions.py +448 -448
- dtlpy/repositories/pipelines.py +642 -642
- dtlpy/repositories/projects.py +539 -539
- dtlpy/repositories/recipes.py +399 -399
- dtlpy/repositories/resource_executions.py +137 -137
- dtlpy/repositories/schema.py +120 -120
- dtlpy/repositories/service_drivers.py +213 -213
- dtlpy/repositories/services.py +1704 -1704
- dtlpy/repositories/settings.py +339 -339
- dtlpy/repositories/tasks.py +1124 -1124
- dtlpy/repositories/times_series.py +278 -278
- dtlpy/repositories/triggers.py +536 -536
- dtlpy/repositories/upload_element.py +257 -257
- dtlpy/repositories/uploader.py +651 -651
- dtlpy/repositories/webhooks.py +249 -249
- dtlpy/services/__init__.py +22 -22
- dtlpy/services/aihttp_retry.py +131 -131
- dtlpy/services/api_client.py +1782 -1782
- dtlpy/services/api_reference.py +40 -40
- dtlpy/services/async_utils.py +133 -133
- dtlpy/services/calls_counter.py +44 -44
- dtlpy/services/check_sdk.py +68 -68
- dtlpy/services/cookie.py +115 -115
- dtlpy/services/create_logger.py +156 -156
- dtlpy/services/events.py +84 -84
- dtlpy/services/logins.py +235 -235
- dtlpy/services/reporter.py +256 -256
- dtlpy/services/service_defaults.py +91 -91
- dtlpy/utilities/__init__.py +20 -20
- dtlpy/utilities/annotations/__init__.py +16 -16
- dtlpy/utilities/annotations/annotation_converters.py +269 -269
- dtlpy/utilities/base_package_runner.py +264 -264
- dtlpy/utilities/converter.py +1650 -1650
- dtlpy/utilities/dataset_generators/__init__.py +1 -1
- dtlpy/utilities/dataset_generators/dataset_generator.py +670 -670
- dtlpy/utilities/dataset_generators/dataset_generator_tensorflow.py +23 -23
- dtlpy/utilities/dataset_generators/dataset_generator_torch.py +21 -21
- dtlpy/utilities/local_development/__init__.py +1 -1
- dtlpy/utilities/local_development/local_session.py +179 -179
- dtlpy/utilities/reports/__init__.py +2 -2
- dtlpy/utilities/reports/figures.py +343 -343
- dtlpy/utilities/reports/report.py +71 -71
- dtlpy/utilities/videos/__init__.py +17 -17
- dtlpy/utilities/videos/video_player.py +598 -598
- dtlpy/utilities/videos/videos.py +470 -470
- {dtlpy-1.113.10.data → dtlpy-1.114.13.data}/scripts/dlp +1 -1
- dtlpy-1.114.13.data/scripts/dlp.bat +2 -0
- {dtlpy-1.113.10.data → dtlpy-1.114.13.data}/scripts/dlp.py +128 -128
- {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/LICENSE +200 -200
- {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/METADATA +172 -172
- dtlpy-1.114.13.dist-info/RECORD +240 -0
- {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/WHEEL +1 -1
- tests/features/environment.py +551 -550
- dtlpy-1.113.10.data/scripts/dlp.bat +0 -2
- dtlpy-1.113.10.dist-info/RECORD +0 -244
- tests/assets/__init__.py +0 -0
- tests/assets/models_flow/__init__.py +0 -0
- tests/assets/models_flow/failedmain.py +0 -52
- tests/assets/models_flow/main.py +0 -62
- tests/assets/models_flow/main_model.py +0 -54
- {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/entry_points.txt +0 -0
- {dtlpy-1.113.10.dist-info → dtlpy-1.114.13.dist-info}/top_level.txt +0 -0
dtlpy/ml/base_model_adapter.py
CHANGED
|
@@ -1,940 +1,945 @@
|
|
|
1
|
-
import dataclasses
|
|
2
|
-
import tempfile
|
|
3
|
-
import datetime
|
|
4
|
-
import logging
|
|
5
|
-
import string
|
|
6
|
-
import shutil
|
|
7
|
-
import random
|
|
8
|
-
import base64
|
|
9
|
-
import tqdm
|
|
10
|
-
import sys
|
|
11
|
-
import io
|
|
12
|
-
import os
|
|
13
|
-
from PIL import Image
|
|
14
|
-
from functools import partial
|
|
15
|
-
import numpy as np
|
|
16
|
-
from concurrent.futures import ThreadPoolExecutor
|
|
17
|
-
import attr
|
|
18
|
-
from .. import entities, utilities, repositories, exceptions
|
|
19
|
-
from ..services import service_defaults
|
|
20
|
-
from ..services.api_client import ApiClient
|
|
21
|
-
|
|
22
|
-
logger = logging.getLogger('ModelAdapter')
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
@dataclasses.dataclass
|
|
26
|
-
class AdapterDefaults(dict):
|
|
27
|
-
# for predict items, dataset, evaluate
|
|
28
|
-
upload_annotations: bool = dataclasses.field(default=True)
|
|
29
|
-
clean_annotations: bool = dataclasses.field(default=True)
|
|
30
|
-
# for embeddings
|
|
31
|
-
upload_features: bool = dataclasses.field(default=True)
|
|
32
|
-
# for training
|
|
33
|
-
root_path: str = dataclasses.field(default=None)
|
|
34
|
-
data_path: str = dataclasses.field(default=None)
|
|
35
|
-
output_path: str = dataclasses.field(default=None)
|
|
36
|
-
|
|
37
|
-
def __post_init__(self):
|
|
38
|
-
# Initialize the internal dictionary with the dataclass fields
|
|
39
|
-
self.update(**dataclasses.asdict(self))
|
|
40
|
-
|
|
41
|
-
def update(self, **kwargs):
|
|
42
|
-
for f in dataclasses.fields(AdapterDefaults):
|
|
43
|
-
if f.name in kwargs:
|
|
44
|
-
setattr(self, f.name, kwargs[f.name])
|
|
45
|
-
super().update(**kwargs)
|
|
46
|
-
|
|
47
|
-
def resolve(self, key, *args):
|
|
48
|
-
|
|
49
|
-
for arg in args:
|
|
50
|
-
if arg is not None:
|
|
51
|
-
return arg
|
|
52
|
-
return self.get(key, None)
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
56
|
-
_client_api = attr.ib(type=ApiClient, repr=False)
|
|
57
|
-
|
|
58
|
-
def __init__(self, model_entity: entities.Model = None):
|
|
59
|
-
self.adapter_defaults = AdapterDefaults()
|
|
60
|
-
self.logger = logger
|
|
61
|
-
# entities
|
|
62
|
-
self._model_entity = None
|
|
63
|
-
self._package = None
|
|
64
|
-
self._base_configuration = dict()
|
|
65
|
-
self.package_name = None
|
|
66
|
-
self.model = None
|
|
67
|
-
self.bucket_path = None
|
|
68
|
-
# funcs
|
|
69
|
-
self.item_to_batch_mapping = {'text': self._item_to_text,
|
|
70
|
-
'image': self._item_to_image}
|
|
71
|
-
if model_entity is not None:
|
|
72
|
-
self.load_from_model(model_entity=model_entity)
|
|
73
|
-
logger.warning(
|
|
74
|
-
"in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'.")
|
|
75
|
-
|
|
76
|
-
##################
|
|
77
|
-
# Configurations #
|
|
78
|
-
##################
|
|
79
|
-
|
|
80
|
-
@property
|
|
81
|
-
def configuration(self) -> dict:
|
|
82
|
-
# load from model
|
|
83
|
-
if self._model_entity is not None:
|
|
84
|
-
configuration = self.model_entity.configuration
|
|
85
|
-
# else - load the default from the package
|
|
86
|
-
elif self._package is not None:
|
|
87
|
-
configuration = self.package.metadata.get('system', {}).get('ml', {}).get('defaultConfiguration', {})
|
|
88
|
-
else:
|
|
89
|
-
configuration = self._base_configuration
|
|
90
|
-
return configuration
|
|
91
|
-
|
|
92
|
-
@configuration.setter
|
|
93
|
-
def configuration(self, d):
|
|
94
|
-
assert isinstance(d, dict)
|
|
95
|
-
if self._model_entity is not None:
|
|
96
|
-
self._model_entity.configuration = d
|
|
97
|
-
|
|
98
|
-
############
|
|
99
|
-
# Entities #
|
|
100
|
-
############
|
|
101
|
-
@property
|
|
102
|
-
def model_entity(self):
|
|
103
|
-
if self._model_entity is None:
|
|
104
|
-
raise ValueError(
|
|
105
|
-
"No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
|
|
106
|
-
assert isinstance(self._model_entity, entities.Model)
|
|
107
|
-
return self._model_entity
|
|
108
|
-
|
|
109
|
-
@model_entity.setter
|
|
110
|
-
def model_entity(self, model_entity):
|
|
111
|
-
assert isinstance(model_entity, entities.Model)
|
|
112
|
-
if self._model_entity is not None and isinstance(self._model_entity, entities.Model):
|
|
113
|
-
if self._model_entity.id != model_entity.id:
|
|
114
|
-
self.logger.warning(
|
|
115
|
-
'Replacing Model from {!r} to {!r}'.format(self._model_entity.name, model_entity.name))
|
|
116
|
-
self._model_entity = model_entity
|
|
117
|
-
self.package = model_entity.package
|
|
118
|
-
|
|
119
|
-
@property
|
|
120
|
-
def package(self):
|
|
121
|
-
if self._model_entity is not None:
|
|
122
|
-
self.package = self.model_entity.package
|
|
123
|
-
if self._package is None:
|
|
124
|
-
raise ValueError('Missing Package entity on adapter. Please set: "adapter.package=package"')
|
|
125
|
-
assert isinstance(self._package, (entities.Package, entities.Dpk))
|
|
126
|
-
return self._package
|
|
127
|
-
|
|
128
|
-
@package.setter
|
|
129
|
-
def package(self, package):
|
|
130
|
-
assert isinstance(package, (entities.Package, entities.Dpk))
|
|
131
|
-
self.package_name = package.name
|
|
132
|
-
self._package = package
|
|
133
|
-
|
|
134
|
-
###################################
|
|
135
|
-
# NEED TO IMPLEMENT THESE METHODS #
|
|
136
|
-
###################################
|
|
137
|
-
|
|
138
|
-
def load(self, local_path, **kwargs):
|
|
139
|
-
""" Loads model and populates self.model with a `runnable` model
|
|
140
|
-
|
|
141
|
-
Virtual method - need to implement
|
|
142
|
-
|
|
143
|
-
This function is called by load_from_model (download to local and then loads)
|
|
144
|
-
|
|
145
|
-
:param local_path: `str` directory path in local FileSystem
|
|
146
|
-
"""
|
|
147
|
-
raise NotImplementedError("Please implement `load` method in {}".format(self.__class__.__name__))
|
|
148
|
-
|
|
149
|
-
def save(self, local_path, **kwargs):
|
|
150
|
-
""" saves configuration and weights locally
|
|
151
|
-
|
|
152
|
-
Virtual method - need to implement
|
|
153
|
-
|
|
154
|
-
the function is called in save_to_model which first save locally and then uploads to model entity
|
|
155
|
-
|
|
156
|
-
:param local_path: `str` directory path in local FileSystem
|
|
157
|
-
"""
|
|
158
|
-
raise NotImplementedError("Please implement `save` method in {}".format(self.__class__.__name__))
|
|
159
|
-
|
|
160
|
-
def train(self, data_path, output_path, **kwargs):
|
|
161
|
-
"""
|
|
162
|
-
Virtual method - need to implement
|
|
163
|
-
|
|
164
|
-
Train the model according to data in data_paths and save the train outputs to output_path,
|
|
165
|
-
this include the weights and any other artifacts created during train
|
|
166
|
-
|
|
167
|
-
:param data_path: `str` local File System path to where the data was downloaded and converted at
|
|
168
|
-
:param output_path: `str` local File System path where to dump training mid-results (checkpoints, logs...)
|
|
169
|
-
"""
|
|
170
|
-
raise NotImplementedError("Please implement `train` method in {}".format(self.__class__.__name__))
|
|
171
|
-
|
|
172
|
-
def predict(self, batch, **kwargs):
|
|
173
|
-
""" Model inference (predictions) on batch of items
|
|
174
|
-
|
|
175
|
-
Virtual method - need to implement
|
|
176
|
-
|
|
177
|
-
:param batch: output of the `prepare_item_func` func
|
|
178
|
-
:return: `list[dl.AnnotationCollection]` each collection is per each image / item in the batch
|
|
179
|
-
"""
|
|
180
|
-
raise NotImplementedError("Please implement `predict` method in {}".format(self.__class__.__name__))
|
|
181
|
-
|
|
182
|
-
def embed(self, batch, **kwargs):
|
|
183
|
-
""" Extract model embeddings on batch of items
|
|
184
|
-
|
|
185
|
-
Virtual method - need to implement
|
|
186
|
-
|
|
187
|
-
:param batch: output of the `prepare_item_func` func
|
|
188
|
-
:return: `list[list]` a feature vector per each item in the batch
|
|
189
|
-
"""
|
|
190
|
-
raise NotImplementedError("Please implement `embed` method in {}".format(self.__class__.__name__))
|
|
191
|
-
|
|
192
|
-
def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
|
|
193
|
-
"""
|
|
194
|
-
This function evaluates the model prediction on a dataset (with GT annotations).
|
|
195
|
-
The evaluation process will upload the scores and metrics to the platform.
|
|
196
|
-
|
|
197
|
-
:param model: The model to evaluate (annotation.metadata.system.model.name
|
|
198
|
-
:param dataset: Dataset where the model predicted and uploaded its annotations
|
|
199
|
-
:param filters: Filters to query items on the dataset
|
|
200
|
-
:return:
|
|
201
|
-
"""
|
|
202
|
-
import dtlpymetrics
|
|
203
|
-
compare_types = model.output_type
|
|
204
|
-
if not filters:
|
|
205
|
-
filters = entities.Filters()
|
|
206
|
-
if filters is not None and isinstance(filters, dict):
|
|
207
|
-
filters = entities.Filters(custom_filter=filters)
|
|
208
|
-
model = dtlpymetrics.scoring.create_model_score(model=model,
|
|
209
|
-
dataset=dataset,
|
|
210
|
-
filters=filters,
|
|
211
|
-
compare_types=compare_types)
|
|
212
|
-
return model
|
|
213
|
-
|
|
214
|
-
def convert_from_dtlpy(self, data_path, **kwargs):
|
|
215
|
-
""" Convert Dataloop structure data to model structured
|
|
216
|
-
|
|
217
|
-
Virtual method - need to implement
|
|
218
|
-
|
|
219
|
-
e.g. take dlp dir structure and construct annotation file
|
|
220
|
-
|
|
221
|
-
:param data_path: `str` local File System directory path where we already downloaded the data from dataloop platform
|
|
222
|
-
:return:
|
|
223
|
-
"""
|
|
224
|
-
raise NotImplementedError("Please implement `convert_from_dtlpy` method in {}".format(self.__class__.__name__))
|
|
225
|
-
|
|
226
|
-
#################
|
|
227
|
-
# DTLPY METHODS #
|
|
228
|
-
################
|
|
229
|
-
def prepare_item_func(self, item: entities.Item):
|
|
230
|
-
"""
|
|
231
|
-
Prepare the Dataloop item before calling the `predict` function with a batch.
|
|
232
|
-
A user can override this function to load item differently
|
|
233
|
-
Default will load the item according the input_type (mapping type to function is in self.item_to_batch_mapping)
|
|
234
|
-
|
|
235
|
-
:param item:
|
|
236
|
-
:return: preprocessed: the var with the loaded item information (e.g. ndarray for image, dict for json files etc)
|
|
237
|
-
"""
|
|
238
|
-
# Item to batch func
|
|
239
|
-
if isinstance(self.model_entity.input_type, list):
|
|
240
|
-
if 'text' in self.model_entity.input_type and 'text' in item.mimetype:
|
|
241
|
-
processed = self._item_to_text(item)
|
|
242
|
-
elif 'image' in self.model_entity.input_type and 'image' in item.mimetype:
|
|
243
|
-
processed = self._item_to_image(item)
|
|
244
|
-
else:
|
|
245
|
-
processed = self._item_to_item(item)
|
|
246
|
-
|
|
247
|
-
elif self.model_entity.input_type in self.item_to_batch_mapping:
|
|
248
|
-
processed = self.item_to_batch_mapping[self.model_entity.input_type](item)
|
|
249
|
-
|
|
250
|
-
else:
|
|
251
|
-
processed = self._item_to_item(item)
|
|
252
|
-
|
|
253
|
-
return processed
|
|
254
|
-
|
|
255
|
-
def prepare_data(self,
|
|
256
|
-
dataset: entities.Dataset,
|
|
257
|
-
# paths
|
|
258
|
-
root_path=None,
|
|
259
|
-
data_path=None,
|
|
260
|
-
output_path=None,
|
|
261
|
-
#
|
|
262
|
-
overwrite=False,
|
|
263
|
-
**kwargs):
|
|
264
|
-
"""
|
|
265
|
-
Prepares dataset locally before training or evaluation.
|
|
266
|
-
download the specific subset selected to data_path and preforms `self.convert` to the data_path dir
|
|
267
|
-
|
|
268
|
-
:param dataset: dl.Dataset
|
|
269
|
-
:param root_path: `str` root directory for training. default is "tmp". Can be set using self.adapter_defaults.root_path
|
|
270
|
-
:param data_path: `str` dataset directory. default <root_path>/"data". Can be set using self.adapter_defaults.data_path
|
|
271
|
-
:param output_path: `str` save everything to this folder. default <root_path>/"output". Can be set using self.adapter_defaults.output_path
|
|
272
|
-
|
|
273
|
-
:param bool overwrite: overwrite the data path (download again). default is False
|
|
274
|
-
"""
|
|
275
|
-
# define paths
|
|
276
|
-
dataloop_path = service_defaults.DATALOOP_PATH
|
|
277
|
-
root_path = self.adapter_defaults.resolve("root_path", root_path)
|
|
278
|
-
data_path = self.adapter_defaults.resolve("data_path", data_path)
|
|
279
|
-
output_path = self.adapter_defaults.resolve("output_path", output_path)
|
|
280
|
-
|
|
281
|
-
if root_path is None:
|
|
282
|
-
now = datetime.datetime.now()
|
|
283
|
-
root_path = os.path.join(dataloop_path,
|
|
284
|
-
'model_data',
|
|
285
|
-
"{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
|
|
286
|
-
now.strftime('%Y-%m-%d-%H%M%S'),
|
|
287
|
-
)
|
|
288
|
-
if data_path is None:
|
|
289
|
-
data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
|
|
290
|
-
os.makedirs(data_path, exist_ok=True)
|
|
291
|
-
if output_path is None:
|
|
292
|
-
output_path = os.path.join(root_path, 'output')
|
|
293
|
-
os.makedirs(output_path, exist_ok=True)
|
|
294
|
-
|
|
295
|
-
if len(os.listdir(data_path)) > 0:
|
|
296
|
-
self.logger.warning("Data path directory ({}) is not empty..".format(data_path))
|
|
297
|
-
|
|
298
|
-
annotation_options = entities.ViewAnnotationOptions.JSON
|
|
299
|
-
if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
|
|
300
|
-
annotation_options = entities.ViewAnnotationOptions.INSTANCE
|
|
301
|
-
|
|
302
|
-
# Download the subset items
|
|
303
|
-
subsets = self.model_entity.metadata.get("system", dict()).get("subsets", None)
|
|
304
|
-
if subsets is None:
|
|
305
|
-
raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))
|
|
306
|
-
for subset, filters_dict in subsets.items():
|
|
307
|
-
filters = entities.Filters(custom_filter=filters_dict)
|
|
308
|
-
data_subset_base_path = os.path.join(data_path, subset)
|
|
309
|
-
if os.path.isdir(data_subset_base_path) and not overwrite:
|
|
310
|
-
# existing and dont overwrite
|
|
311
|
-
self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
|
|
312
|
-
else:
|
|
313
|
-
self.logger.debug("Downloading subset {!r} of {}".format(subset,
|
|
314
|
-
self.model_entity.dataset.name))
|
|
315
|
-
|
|
316
|
-
annotation_filters = None
|
|
317
|
-
if self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
|
|
318
|
-
annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
|
|
319
|
-
if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION,
|
|
320
|
-
entities.AnnotationType.POLYGON]:
|
|
321
|
-
model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
|
|
322
|
-
else:
|
|
323
|
-
model_output_types = [self.model_entity.output_type]
|
|
324
|
-
|
|
325
|
-
annotation_filters.add(
|
|
326
|
-
field=entities.FiltersKnownFields.TYPE,
|
|
327
|
-
values=model_output_types,
|
|
328
|
-
operator=entities.FiltersOperations.IN
|
|
329
|
-
)
|
|
330
|
-
|
|
331
|
-
if not self.configuration.get("include_model_annotations", False):
|
|
332
|
-
annotation_filters.add(
|
|
333
|
-
field="metadata.system.model.name",
|
|
334
|
-
values=False,
|
|
335
|
-
operator=entities.FiltersOperations.EXISTS
|
|
336
|
-
)
|
|
337
|
-
|
|
338
|
-
ret_list = dataset.items.download(filters=filters,
|
|
339
|
-
local_path=data_subset_base_path,
|
|
340
|
-
annotation_options=annotation_options,
|
|
341
|
-
annotation_filters=annotation_filters
|
|
342
|
-
)
|
|
343
|
-
if isinstance(ret_list, list) and len(ret_list) == 0:
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
)
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
:
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
feature_set_name
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
:
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
:
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
"""
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
672
|
-
|
|
673
|
-
|
|
674
|
-
|
|
675
|
-
|
|
676
|
-
|
|
677
|
-
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
#
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
701
|
-
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
|
|
707
|
-
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
|
|
711
|
-
|
|
712
|
-
|
|
713
|
-
|
|
714
|
-
|
|
715
|
-
|
|
716
|
-
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
:
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
logger.info(
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
if
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
"""
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
prediction.
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
if 'model'
|
|
877
|
-
prediction.metadata['
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
'
|
|
881
|
-
|
|
882
|
-
'
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
"""
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
"
|
|
932
|
-
"
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
1
|
+
import dataclasses
|
|
2
|
+
import tempfile
|
|
3
|
+
import datetime
|
|
4
|
+
import logging
|
|
5
|
+
import string
|
|
6
|
+
import shutil
|
|
7
|
+
import random
|
|
8
|
+
import base64
|
|
9
|
+
import tqdm
|
|
10
|
+
import sys
|
|
11
|
+
import io
|
|
12
|
+
import os
|
|
13
|
+
from PIL import Image
|
|
14
|
+
from functools import partial
|
|
15
|
+
import numpy as np
|
|
16
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
17
|
+
import attr
|
|
18
|
+
from .. import entities, utilities, repositories, exceptions
|
|
19
|
+
from ..services import service_defaults
|
|
20
|
+
from ..services.api_client import ApiClient
|
|
21
|
+
|
|
22
|
+
logger = logging.getLogger('ModelAdapter')
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclasses.dataclass
|
|
26
|
+
class AdapterDefaults(dict):
|
|
27
|
+
# for predict items, dataset, evaluate
|
|
28
|
+
upload_annotations: bool = dataclasses.field(default=True)
|
|
29
|
+
clean_annotations: bool = dataclasses.field(default=True)
|
|
30
|
+
# for embeddings
|
|
31
|
+
upload_features: bool = dataclasses.field(default=True)
|
|
32
|
+
# for training
|
|
33
|
+
root_path: str = dataclasses.field(default=None)
|
|
34
|
+
data_path: str = dataclasses.field(default=None)
|
|
35
|
+
output_path: str = dataclasses.field(default=None)
|
|
36
|
+
|
|
37
|
+
def __post_init__(self):
|
|
38
|
+
# Initialize the internal dictionary with the dataclass fields
|
|
39
|
+
self.update(**dataclasses.asdict(self))
|
|
40
|
+
|
|
41
|
+
def update(self, **kwargs):
|
|
42
|
+
for f in dataclasses.fields(AdapterDefaults):
|
|
43
|
+
if f.name in kwargs:
|
|
44
|
+
setattr(self, f.name, kwargs[f.name])
|
|
45
|
+
super().update(**kwargs)
|
|
46
|
+
|
|
47
|
+
def resolve(self, key, *args):
|
|
48
|
+
|
|
49
|
+
for arg in args:
|
|
50
|
+
if arg is not None:
|
|
51
|
+
return arg
|
|
52
|
+
return self.get(key, None)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class BaseModelAdapter(utilities.BaseServiceRunner):
|
|
56
|
+
_client_api = attr.ib(type=ApiClient, repr=False)
|
|
57
|
+
|
|
58
|
+
def __init__(self, model_entity: entities.Model = None):
|
|
59
|
+
self.adapter_defaults = AdapterDefaults()
|
|
60
|
+
self.logger = logger
|
|
61
|
+
# entities
|
|
62
|
+
self._model_entity = None
|
|
63
|
+
self._package = None
|
|
64
|
+
self._base_configuration = dict()
|
|
65
|
+
self.package_name = None
|
|
66
|
+
self.model = None
|
|
67
|
+
self.bucket_path = None
|
|
68
|
+
# funcs
|
|
69
|
+
self.item_to_batch_mapping = {'text': self._item_to_text,
|
|
70
|
+
'image': self._item_to_image}
|
|
71
|
+
if model_entity is not None:
|
|
72
|
+
self.load_from_model(model_entity=model_entity)
|
|
73
|
+
logger.warning(
|
|
74
|
+
"in case of a mismatch between 'model.name' and 'model_info.name' in the model adapter, model_info.name will be updated to align with 'model.name'.")
|
|
75
|
+
|
|
76
|
+
##################
|
|
77
|
+
# Configurations #
|
|
78
|
+
##################
|
|
79
|
+
|
|
80
|
+
@property
|
|
81
|
+
def configuration(self) -> dict:
|
|
82
|
+
# load from model
|
|
83
|
+
if self._model_entity is not None:
|
|
84
|
+
configuration = self.model_entity.configuration
|
|
85
|
+
# else - load the default from the package
|
|
86
|
+
elif self._package is not None:
|
|
87
|
+
configuration = self.package.metadata.get('system', {}).get('ml', {}).get('defaultConfiguration', {})
|
|
88
|
+
else:
|
|
89
|
+
configuration = self._base_configuration
|
|
90
|
+
return configuration
|
|
91
|
+
|
|
92
|
+
@configuration.setter
|
|
93
|
+
def configuration(self, d):
|
|
94
|
+
assert isinstance(d, dict)
|
|
95
|
+
if self._model_entity is not None:
|
|
96
|
+
self._model_entity.configuration = d
|
|
97
|
+
|
|
98
|
+
############
|
|
99
|
+
# Entities #
|
|
100
|
+
############
|
|
101
|
+
@property
|
|
102
|
+
def model_entity(self):
|
|
103
|
+
if self._model_entity is None:
|
|
104
|
+
raise ValueError(
|
|
105
|
+
"No model entity loaded. Please load a model (adapter.load_from_model(<dl.Model>)) or set: 'adapter.model_entity=<dl.Model>'")
|
|
106
|
+
assert isinstance(self._model_entity, entities.Model)
|
|
107
|
+
return self._model_entity
|
|
108
|
+
|
|
109
|
+
@model_entity.setter
|
|
110
|
+
def model_entity(self, model_entity):
|
|
111
|
+
assert isinstance(model_entity, entities.Model)
|
|
112
|
+
if self._model_entity is not None and isinstance(self._model_entity, entities.Model):
|
|
113
|
+
if self._model_entity.id != model_entity.id:
|
|
114
|
+
self.logger.warning(
|
|
115
|
+
'Replacing Model from {!r} to {!r}'.format(self._model_entity.name, model_entity.name))
|
|
116
|
+
self._model_entity = model_entity
|
|
117
|
+
self.package = model_entity.package
|
|
118
|
+
|
|
119
|
+
@property
|
|
120
|
+
def package(self):
|
|
121
|
+
if self._model_entity is not None:
|
|
122
|
+
self.package = self.model_entity.package
|
|
123
|
+
if self._package is None:
|
|
124
|
+
raise ValueError('Missing Package entity on adapter. Please set: "adapter.package=package"')
|
|
125
|
+
assert isinstance(self._package, (entities.Package, entities.Dpk))
|
|
126
|
+
return self._package
|
|
127
|
+
|
|
128
|
+
@package.setter
|
|
129
|
+
def package(self, package):
|
|
130
|
+
assert isinstance(package, (entities.Package, entities.Dpk))
|
|
131
|
+
self.package_name = package.name
|
|
132
|
+
self._package = package
|
|
133
|
+
|
|
134
|
+
###################################
|
|
135
|
+
# NEED TO IMPLEMENT THESE METHODS #
|
|
136
|
+
###################################
|
|
137
|
+
|
|
138
|
+
def load(self, local_path, **kwargs):
|
|
139
|
+
""" Loads model and populates self.model with a `runnable` model
|
|
140
|
+
|
|
141
|
+
Virtual method - need to implement
|
|
142
|
+
|
|
143
|
+
This function is called by load_from_model (download to local and then loads)
|
|
144
|
+
|
|
145
|
+
:param local_path: `str` directory path in local FileSystem
|
|
146
|
+
"""
|
|
147
|
+
raise NotImplementedError("Please implement `load` method in {}".format(self.__class__.__name__))
|
|
148
|
+
|
|
149
|
+
def save(self, local_path, **kwargs):
|
|
150
|
+
""" saves configuration and weights locally
|
|
151
|
+
|
|
152
|
+
Virtual method - need to implement
|
|
153
|
+
|
|
154
|
+
the function is called in save_to_model which first save locally and then uploads to model entity
|
|
155
|
+
|
|
156
|
+
:param local_path: `str` directory path in local FileSystem
|
|
157
|
+
"""
|
|
158
|
+
raise NotImplementedError("Please implement `save` method in {}".format(self.__class__.__name__))
|
|
159
|
+
|
|
160
|
+
def train(self, data_path, output_path, **kwargs):
|
|
161
|
+
"""
|
|
162
|
+
Virtual method - need to implement
|
|
163
|
+
|
|
164
|
+
Train the model according to data in data_paths and save the train outputs to output_path,
|
|
165
|
+
this include the weights and any other artifacts created during train
|
|
166
|
+
|
|
167
|
+
:param data_path: `str` local File System path to where the data was downloaded and converted at
|
|
168
|
+
:param output_path: `str` local File System path where to dump training mid-results (checkpoints, logs...)
|
|
169
|
+
"""
|
|
170
|
+
raise NotImplementedError("Please implement `train` method in {}".format(self.__class__.__name__))
|
|
171
|
+
|
|
172
|
+
def predict(self, batch, **kwargs):
|
|
173
|
+
""" Model inference (predictions) on batch of items
|
|
174
|
+
|
|
175
|
+
Virtual method - need to implement
|
|
176
|
+
|
|
177
|
+
:param batch: output of the `prepare_item_func` func
|
|
178
|
+
:return: `list[dl.AnnotationCollection]` each collection is per each image / item in the batch
|
|
179
|
+
"""
|
|
180
|
+
raise NotImplementedError("Please implement `predict` method in {}".format(self.__class__.__name__))
|
|
181
|
+
|
|
182
|
+
def embed(self, batch, **kwargs):
|
|
183
|
+
""" Extract model embeddings on batch of items
|
|
184
|
+
|
|
185
|
+
Virtual method - need to implement
|
|
186
|
+
|
|
187
|
+
:param batch: output of the `prepare_item_func` func
|
|
188
|
+
:return: `list[list]` a feature vector per each item in the batch
|
|
189
|
+
"""
|
|
190
|
+
raise NotImplementedError("Please implement `embed` method in {}".format(self.__class__.__name__))
|
|
191
|
+
|
|
192
|
+
def evaluate(self, model: entities.Model, dataset: entities.Dataset, filters: entities.Filters) -> entities.Model:
|
|
193
|
+
"""
|
|
194
|
+
This function evaluates the model prediction on a dataset (with GT annotations).
|
|
195
|
+
The evaluation process will upload the scores and metrics to the platform.
|
|
196
|
+
|
|
197
|
+
:param model: The model to evaluate (annotation.metadata.system.model.name
|
|
198
|
+
:param dataset: Dataset where the model predicted and uploaded its annotations
|
|
199
|
+
:param filters: Filters to query items on the dataset
|
|
200
|
+
:return:
|
|
201
|
+
"""
|
|
202
|
+
import dtlpymetrics
|
|
203
|
+
compare_types = model.output_type
|
|
204
|
+
if not filters:
|
|
205
|
+
filters = entities.Filters()
|
|
206
|
+
if filters is not None and isinstance(filters, dict):
|
|
207
|
+
filters = entities.Filters(custom_filter=filters)
|
|
208
|
+
model = dtlpymetrics.scoring.create_model_score(model=model,
|
|
209
|
+
dataset=dataset,
|
|
210
|
+
filters=filters,
|
|
211
|
+
compare_types=compare_types)
|
|
212
|
+
return model
|
|
213
|
+
|
|
214
|
+
def convert_from_dtlpy(self, data_path, **kwargs):
|
|
215
|
+
""" Convert Dataloop structure data to model structured
|
|
216
|
+
|
|
217
|
+
Virtual method - need to implement
|
|
218
|
+
|
|
219
|
+
e.g. take dlp dir structure and construct annotation file
|
|
220
|
+
|
|
221
|
+
:param data_path: `str` local File System directory path where we already downloaded the data from dataloop platform
|
|
222
|
+
:return:
|
|
223
|
+
"""
|
|
224
|
+
raise NotImplementedError("Please implement `convert_from_dtlpy` method in {}".format(self.__class__.__name__))
|
|
225
|
+
|
|
226
|
+
#################
|
|
227
|
+
# DTLPY METHODS #
|
|
228
|
+
################
|
|
229
|
+
def prepare_item_func(self, item: entities.Item):
|
|
230
|
+
"""
|
|
231
|
+
Prepare the Dataloop item before calling the `predict` function with a batch.
|
|
232
|
+
A user can override this function to load item differently
|
|
233
|
+
Default will load the item according the input_type (mapping type to function is in self.item_to_batch_mapping)
|
|
234
|
+
|
|
235
|
+
:param item:
|
|
236
|
+
:return: preprocessed: the var with the loaded item information (e.g. ndarray for image, dict for json files etc)
|
|
237
|
+
"""
|
|
238
|
+
# Item to batch func
|
|
239
|
+
if isinstance(self.model_entity.input_type, list):
|
|
240
|
+
if 'text' in self.model_entity.input_type and 'text' in item.mimetype:
|
|
241
|
+
processed = self._item_to_text(item)
|
|
242
|
+
elif 'image' in self.model_entity.input_type and 'image' in item.mimetype:
|
|
243
|
+
processed = self._item_to_image(item)
|
|
244
|
+
else:
|
|
245
|
+
processed = self._item_to_item(item)
|
|
246
|
+
|
|
247
|
+
elif self.model_entity.input_type in self.item_to_batch_mapping:
|
|
248
|
+
processed = self.item_to_batch_mapping[self.model_entity.input_type](item)
|
|
249
|
+
|
|
250
|
+
else:
|
|
251
|
+
processed = self._item_to_item(item)
|
|
252
|
+
|
|
253
|
+
return processed
|
|
254
|
+
|
|
255
|
+
def prepare_data(self,
|
|
256
|
+
dataset: entities.Dataset,
|
|
257
|
+
# paths
|
|
258
|
+
root_path=None,
|
|
259
|
+
data_path=None,
|
|
260
|
+
output_path=None,
|
|
261
|
+
#
|
|
262
|
+
overwrite=False,
|
|
263
|
+
**kwargs):
|
|
264
|
+
"""
|
|
265
|
+
Prepares dataset locally before training or evaluation.
|
|
266
|
+
download the specific subset selected to data_path and preforms `self.convert` to the data_path dir
|
|
267
|
+
|
|
268
|
+
:param dataset: dl.Dataset
|
|
269
|
+
:param root_path: `str` root directory for training. default is "tmp". Can be set using self.adapter_defaults.root_path
|
|
270
|
+
:param data_path: `str` dataset directory. default <root_path>/"data". Can be set using self.adapter_defaults.data_path
|
|
271
|
+
:param output_path: `str` save everything to this folder. default <root_path>/"output". Can be set using self.adapter_defaults.output_path
|
|
272
|
+
|
|
273
|
+
:param bool overwrite: overwrite the data path (download again). default is False
|
|
274
|
+
"""
|
|
275
|
+
# define paths
|
|
276
|
+
dataloop_path = service_defaults.DATALOOP_PATH
|
|
277
|
+
root_path = self.adapter_defaults.resolve("root_path", root_path)
|
|
278
|
+
data_path = self.adapter_defaults.resolve("data_path", data_path)
|
|
279
|
+
output_path = self.adapter_defaults.resolve("output_path", output_path)
|
|
280
|
+
|
|
281
|
+
if root_path is None:
|
|
282
|
+
now = datetime.datetime.now()
|
|
283
|
+
root_path = os.path.join(dataloop_path,
|
|
284
|
+
'model_data',
|
|
285
|
+
"{s_id}_{s_n}".format(s_id=self.model_entity.id, s_n=self.model_entity.name),
|
|
286
|
+
now.strftime('%Y-%m-%d-%H%M%S'),
|
|
287
|
+
)
|
|
288
|
+
if data_path is None:
|
|
289
|
+
data_path = os.path.join(root_path, 'datasets', self.model_entity.dataset.id)
|
|
290
|
+
os.makedirs(data_path, exist_ok=True)
|
|
291
|
+
if output_path is None:
|
|
292
|
+
output_path = os.path.join(root_path, 'output')
|
|
293
|
+
os.makedirs(output_path, exist_ok=True)
|
|
294
|
+
|
|
295
|
+
if len(os.listdir(data_path)) > 0:
|
|
296
|
+
self.logger.warning("Data path directory ({}) is not empty..".format(data_path))
|
|
297
|
+
|
|
298
|
+
annotation_options = entities.ViewAnnotationOptions.JSON
|
|
299
|
+
if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION]:
|
|
300
|
+
annotation_options = entities.ViewAnnotationOptions.INSTANCE
|
|
301
|
+
|
|
302
|
+
# Download the subset items
|
|
303
|
+
subsets = self.model_entity.metadata.get("system", dict()).get("subsets", None)
|
|
304
|
+
if subsets is None:
|
|
305
|
+
raise ValueError("Model (id: {}) must have subsets in metadata.system.subsets".format(self.model_entity.id))
|
|
306
|
+
for subset, filters_dict in subsets.items():
|
|
307
|
+
filters = entities.Filters(custom_filter=filters_dict)
|
|
308
|
+
data_subset_base_path = os.path.join(data_path, subset)
|
|
309
|
+
if os.path.isdir(data_subset_base_path) and not overwrite:
|
|
310
|
+
# existing and dont overwrite
|
|
311
|
+
self.logger.debug("Subset {!r} already exists (and overwrite=False). Skipping.".format(subset))
|
|
312
|
+
else:
|
|
313
|
+
self.logger.debug("Downloading subset {!r} of {}".format(subset,
|
|
314
|
+
self.model_entity.dataset.name))
|
|
315
|
+
|
|
316
|
+
annotation_filters = None
|
|
317
|
+
if self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
|
|
318
|
+
annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
|
|
319
|
+
if self.model_entity.output_type in [entities.AnnotationType.SEGMENTATION,
|
|
320
|
+
entities.AnnotationType.POLYGON]:
|
|
321
|
+
model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
|
|
322
|
+
else:
|
|
323
|
+
model_output_types = [self.model_entity.output_type]
|
|
324
|
+
|
|
325
|
+
annotation_filters.add(
|
|
326
|
+
field=entities.FiltersKnownFields.TYPE,
|
|
327
|
+
values=model_output_types,
|
|
328
|
+
operator=entities.FiltersOperations.IN
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
if not self.configuration.get("include_model_annotations", False):
|
|
332
|
+
annotation_filters.add(
|
|
333
|
+
field="metadata.system.model.name",
|
|
334
|
+
values=False,
|
|
335
|
+
operator=entities.FiltersOperations.EXISTS
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
ret_list = dataset.items.download(filters=filters,
|
|
339
|
+
local_path=data_subset_base_path,
|
|
340
|
+
annotation_options=annotation_options,
|
|
341
|
+
annotation_filters=annotation_filters
|
|
342
|
+
)
|
|
343
|
+
if isinstance(ret_list, list) and len(ret_list) == 0:
|
|
344
|
+
if annotation_filters is not None:
|
|
345
|
+
annotation_filters_str = annotation_filters.prepare()
|
|
346
|
+
raise ValueError(f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
|
|
347
|
+
f"Subset {subset} filters: {filters.prepare()}\nAnnotation filters: {annotation_filters_str}")
|
|
348
|
+
else:
|
|
349
|
+
raise ValueError(f"No items downloaded for subset {subset}! Cannot train model with empty subset.\n"
|
|
350
|
+
f"Subset {subset} filters: {filters.prepare()}")
|
|
351
|
+
|
|
352
|
+
self.convert_from_dtlpy(data_path=data_path, **kwargs)
|
|
353
|
+
return root_path, data_path, output_path
|
|
354
|
+
|
|
355
|
+
def load_from_model(self, model_entity=None, local_path=None, overwrite=True, **kwargs):
|
|
356
|
+
""" Loads a model from given `dl.Model`.
|
|
357
|
+
Reads configurations and instantiate self.model_entity
|
|
358
|
+
Downloads the model_entity bucket (if available)
|
|
359
|
+
|
|
360
|
+
:param model_entity: `str` dl.Model entity
|
|
361
|
+
:param local_path: `str` directory path in local FileSystem to download the model_entity to
|
|
362
|
+
:param overwrite: `bool` (default False) if False does not download files with same name else (True) download all
|
|
363
|
+
"""
|
|
364
|
+
if model_entity is not None:
|
|
365
|
+
self.model_entity = model_entity
|
|
366
|
+
if local_path is None:
|
|
367
|
+
local_path = os.path.join(service_defaults.DATALOOP_PATH, "models", self.model_entity.name)
|
|
368
|
+
# Load configuration
|
|
369
|
+
self.configuration = self.model_entity.configuration
|
|
370
|
+
# Update the adapter config with the model config to run over defaults if needed
|
|
371
|
+
self.adapter_defaults.update(**self.configuration)
|
|
372
|
+
# Download
|
|
373
|
+
self.model_entity.artifacts.download(
|
|
374
|
+
local_path=local_path,
|
|
375
|
+
overwrite=overwrite
|
|
376
|
+
)
|
|
377
|
+
self.load(local_path, **kwargs)
|
|
378
|
+
|
|
379
|
+
def save_to_model(self, local_path=None, cleanup=False, replace=True, **kwargs):
|
|
380
|
+
"""
|
|
381
|
+
Saves the model state to a new bucket and configuration
|
|
382
|
+
|
|
383
|
+
Saves configuration and weights to artifacts
|
|
384
|
+
Mark the model as `trained`
|
|
385
|
+
loads only applies for remote buckets
|
|
386
|
+
|
|
387
|
+
:param local_path: `str` directory path in local FileSystem to save the current model bucket (weights) (default will create a temp dir)
|
|
388
|
+
:param replace: `bool` will clean the bucket's content before uploading new files
|
|
389
|
+
:param cleanup: `bool` if True (default) remove the data from local FileSystem after upload
|
|
390
|
+
:return:
|
|
391
|
+
"""
|
|
392
|
+
|
|
393
|
+
if local_path is None:
|
|
394
|
+
local_path = tempfile.mkdtemp(prefix="model_{}".format(self.model_entity.name))
|
|
395
|
+
self.logger.debug("Using temporary dir at {}".format(local_path))
|
|
396
|
+
|
|
397
|
+
self.save(local_path=local_path, **kwargs)
|
|
398
|
+
|
|
399
|
+
if self.model_entity is None:
|
|
400
|
+
raise ValueError('Missing model entity on the adapter. '
|
|
401
|
+
'Please set before saving: "adapter.model_entity=model"')
|
|
402
|
+
|
|
403
|
+
self.model_entity.artifacts.upload(filepath=os.path.join(local_path, '*'),
|
|
404
|
+
overwrite=True)
|
|
405
|
+
if cleanup:
|
|
406
|
+
shutil.rmtree(path=local_path, ignore_errors=True)
|
|
407
|
+
self.logger.info("Clean-up. deleting {}".format(local_path))
|
|
```diff
+
+    # ===============
+    # SERVICE METHODS
+    # ===============
+
+    @entities.Package.decorators.function(display_name='Predict Items',
+                                          inputs={'items': 'Item[]'},
+                                          outputs={'items': 'Item[]', 'annotations': 'Annotation[]'})
+    def predict_items(self, items: list, upload_annotations=None, clean_annotations=None, batch_size=None, **kwargs):
+        """
+        Run the predict function on the input list of items (or a single item) and return the items and the predictions.
+        Each prediction is made according to the model output type (package.output_type), with model_info in the metadata.
+
+        :param items: `List[dl.Item]` list of items to predict
+        :param upload_annotations: `bool` uploads the predictions on the given items
+        :param clean_annotations: `bool` deletes previous model annotations (predictions) before uploading new ones
+        :param batch_size: `int` size of batch to run in a single inference
+
+        :return: `List[dl.Item]`, `List[List[dl.Annotation]]`
+        """
+        if batch_size is None:
+            batch_size = self.configuration.get('batch_size', 4)
+        upload_annotations = self.adapter_defaults.resolve("upload_annotations", upload_annotations)
+        clean_annotations = self.adapter_defaults.resolve("clean_annotations", clean_annotations)
+        input_type = self.model_entity.input_type
+        self.logger.debug(
+            "Predicting {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
+        pool = ThreadPoolExecutor(max_workers=16)
+
+        annotations = list()
+        for i_batch in tqdm.tqdm(range(0, len(items), batch_size), desc='predicting', unit='bt', leave=None,
+                                 file=sys.stdout):
+            batch_items = items[i_batch: i_batch + batch_size]
+            batch = list(pool.map(self.prepare_item_func, batch_items))
+            batch_collections = self.predict(batch, **kwargs)
+            _futures = list(pool.map(partial(self._update_predictions_metadata),
+                                     batch_items,
+                                     batch_collections))
+            # Loop over the futures to make sure they are all done to avoid race conditions
+            _ = [_f for _f in _futures]
+            if upload_annotations is True:
+                self.logger.debug(
+                    "Uploading items' annotations for model {!r}.".format(self.model_entity.name))
+                try:
+                    batch_collections = list(pool.map(partial(self._upload_model_annotations,
+                                                              clean_annotations=clean_annotations),
+                                                      batch_items,
+                                                      batch_collections))
+                except Exception:
+                    self.logger.exception("Failed to upload annotations to items.")
+
+            for collection in batch_collections:
+                # the function needs to return `List[List[dl.Annotation]]`:
+                # convert each annotation collection to a list of dl.Annotation
+                if isinstance(collection, entities.AnnotationCollection):
+                    annotations.extend([annotation for annotation in collection.annotations])
+                else:
+                    logger.warning(f'RETURN TYPE MAY BE INVALID: {type(collection)}')
+                    annotations.extend(collection)
+            # TODO call the callback
+
+        pool.shutdown()
+        return items, annotations
```
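A minimal sketch of how `predict_items` is typically driven from a custom adapter. `MyAdapter`, the model id, and the item query are illustrative assumptions, not part of this diff; the subclass hooks (`load`, `predict`) follow the adapter contract shown in this file:

```python
import dtlpy as dl

class MyAdapter(dl.BaseModelAdapter):
    def load(self, local_path, **kwargs):
        ...  # load weights from local_path

    def predict(self, batch, **kwargs):
        ...  # return one dl.AnnotationCollection per input in the batch

model = dl.models.get(model_id='<model-id>')  # hypothetical id
adapter = MyAdapter(model_entity=model)
items = list(model.dataset.items.list().all())
items, annotations = adapter.predict_items(items=items,
                                           upload_annotations=True,
                                           clean_annotations=True,
                                           batch_size=8)
```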
```diff
+
+    @entities.Package.decorators.function(display_name='Embed Items',
+                                          inputs={'items': 'Item[]'},
+                                          outputs={'items': 'Item[]', 'features': 'Json[]'})
+    def embed_items(self, items: list, upload_features=None, batch_size=None, progress: utilities.Progress = None, **kwargs):
+        """
+        Extract features from an input list of items (or a single item) and return the items and the feature vectors.
+
+        :param items: `List[dl.Item]` list of items to embed
+        :param upload_features: `bool` uploads the features on the given items
+        :param batch_size: `int` size of batch to run in a single embed
+        :param progress: `dl.Progress` optional progress reporter for FaaS executions
+
+        :return: `List[dl.Item]`, `List[vector]` one feature vector per item
+        """
+        if batch_size is None:
+            batch_size = self.configuration.get('batch_size', 4)
+        upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
+        input_type = self.model_entity.input_type
+        self.logger.debug(
+            "Embedding {} items, using batch size {}. input type: {}".format(len(items), batch_size, input_type))
+
+        # Search for an existing feature set for this model id
+        feature_set = self.model_entity.feature_set
+        if feature_set is None:
+            logger.info('Feature set not found. Creating...')
+            try:
+                self.model_entity.project.feature_sets.get(feature_set_name=self.model_entity.name)
+                feature_set_name = f"{self.model_entity.name}-{''.join(random.choices(string.ascii_letters + string.digits, k=5))}"
+                logger.warning(
+                    f"Feature set with the model name already exists. Creating new feature set with name {feature_set_name}")
+            except exceptions.NotFound:
+                feature_set_name = self.model_entity.name
+            feature_set = self.model_entity.project.feature_sets.create(name=feature_set_name,
+                                                                        entity_type=entities.FeatureEntityType.ITEM,
+                                                                        model_id=self.model_entity.id,
+                                                                        project_id=self.model_entity.project_id,
+                                                                        set_type=self.model_entity.name,
+                                                                        size=self.configuration.get('embeddings_size', 256))
+            logger.info(f'Feature set created! name: {feature_set.name}, id: {feature_set.id}')
+        else:
+            logger.info(f'Feature set found! name: {feature_set.name}, id: {feature_set.id}')
+
+        # upload the feature vectors
+        pool = ThreadPoolExecutor(max_workers=16)
+        vectors = list()
+        for i_batch in tqdm.tqdm(range(0, len(items), batch_size),
+                                 desc='embedding',
+                                 unit='bt',
+                                 leave=None,
+                                 file=sys.stdout):
+            batch_items = items[i_batch: i_batch + batch_size]
+            batch = list(pool.map(self.prepare_item_func, batch_items))
+            batch_vectors = self.embed(batch, **kwargs)
+            vectors.extend(batch_vectors)
+            if upload_features is True:
+                self.logger.debug(
+                    "Uploading items' feature vectors for model {!r}.".format(self.model_entity.name))
+                try:
+                    list(pool.map(partial(self._upload_model_features,
+                                          progress.logger if progress is not None else self.logger,
+                                          feature_set.id,
+                                          self.model_entity.project_id),
+                                  batch_items,
+                                  batch_vectors))
+                except Exception:
+                    self.logger.exception("Failed to upload feature vectors to items.")
+
+        pool.shutdown()
+        return items, vectors
```
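`embed_items` follows the same batching pattern as prediction. A sketch, assuming the adapter subclass implements `embed` and that `embeddings_size` in the model configuration matches the returned vector length (`adapter` and `items` carry over from the earlier sketch):

```python
# Hedged sketch: extract and upload feature vectors for a list of items.
items, vectors = adapter.embed_items(items=items,
                                     upload_features=True,
                                     batch_size=16)
# one vector per item; its length should match configuration['embeddings_size']
print(len(items), len(vectors), len(vectors[0]))
```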
```diff
+
+    @entities.Package.decorators.function(display_name='Embed Dataset with DQL',
+                                          inputs={'dataset': 'Dataset',
+                                                  'filters': 'Json'})
+    def embed_dataset(self,
+                      dataset: entities.Dataset,
+                      filters: entities.Filters = None,
+                      upload_features=None,
+                      batch_size=None,
+                      progress: utilities.Progress = None,
+                      **kwargs):
+        """
+        Extract features from all the items in the given dataset
+
+        :param dataset: Dataset entity to embed
+        :param filters: Filters entity for filtering before embedding
+        :param upload_features: `bool` uploads the features back to the given items
+        :param batch_size: `int` size of batch to run in a single embed
+        :param progress: `dl.Progress` optional progress reporter for FaaS executions
+
+        :return: `bool` indicating if the embedding process completed successfully
+        """
+        if batch_size is None:
+            batch_size = self.configuration.get('batch_size', 4)
+        upload_features = self.adapter_defaults.resolve("upload_features", upload_features)
+
+        self.logger.debug("Creating embeddings for dataset (name: {}, id: {}), using batch size {}".format(dataset.name,
+                                                                                                           dataset.id,
+                                                                                                           batch_size))
+        if not filters:
+            filters = entities.Filters()
+        if filters is not None and isinstance(filters, dict):
+            filters = entities.Filters(custom_filter=filters)
+        pages = dataset.items.list(filters=filters, page_size=batch_size)
+        # Item type is 'file' only; can be deleted if default filters are added to custom filters
+        items = [item for page in pages for item in page if item.type == 'file']
+        self.embed_items(items=items,
+                         upload_features=upload_features,
+                         batch_size=batch_size,
+                         progress=progress,
+                         **kwargs)
+        return True
```
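Note that `embed_dataset` accepts either a `dl.Filters` entity or a raw DQL dict, which it wraps with `Filters(custom_filter=...)`. A sketch; the dataset id and the query are placeholders:

```python
import dtlpy as dl

dataset = dl.datasets.get(dataset_id='<dataset-id>')  # hypothetical id
adapter.embed_dataset(dataset=dataset,
                      filters={'filter': {'$and': [{'dir': '/train'}]}},  # raw DQL dict
                      upload_features=True)
```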
```diff
+
+    @entities.Package.decorators.function(display_name='Predict Dataset with DQL',
+                                          inputs={'dataset': 'Dataset',
+                                                  'filters': 'Json'})
+    def predict_dataset(self,
+                        dataset: entities.Dataset,
+                        filters: entities.Filters = None,
+                        upload_annotations=None,
+                        clean_annotations=None,
+                        batch_size=None,
+                        **kwargs):
+        """
+        Predict all the items in the given dataset
+
+        :param dataset: Dataset entity to predict
+        :param filters: Filters entity for filtering before predicting
+        :param upload_annotations: `bool` uploads the predictions back to the given items
+        :param clean_annotations: `bool` if set removes existing predictions with the same package-model name (default: False)
+        :param batch_size: `int` size of batch to run in a single inference
+
+        :return: `bool` indicating if the prediction process completed successfully
+        """
+
+        if batch_size is None:
+            batch_size = self.configuration.get('batch_size', 4)
+
+        self.logger.debug("Predicting dataset (name: {}, id: {}), using batch size {}".format(dataset.name,
+                                                                                              dataset.id,
+                                                                                              batch_size))
+        if not filters:
+            filters = entities.Filters()
+        if filters is not None and isinstance(filters, dict):
+            filters = entities.Filters(custom_filter=filters)
+        pages = dataset.items.list(filters=filters, page_size=batch_size)
+        # Item type is 'file' only; can be deleted if default filters are added to custom filters
+        items = [item for page in pages for item in page if item.type == 'file']
+        self.predict_items(items=items,
+                           upload_annotations=upload_annotations,
+                           clean_annotations=clean_annotations,
+                           batch_size=batch_size,
+                           **kwargs)
+        return True
```
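The dataset-level entry point pages through the query and feeds only `'file'`-type items into `predict_items`. A sketch using a `dl.Filters` entity (the directory value is a placeholder):

```python
import dtlpy as dl

filters = dl.Filters(field='dir', values='/validation')  # hypothetical query
adapter.predict_dataset(dataset=dataset,
                        filters=filters,
                        upload_annotations=True,
                        clean_annotations=True,
                        batch_size=8)
```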
```diff
+
+    @entities.Package.decorators.function(display_name='Train a Model',
+                                          inputs={'model': entities.Model},
+                                          outputs={'model': entities.Model})
+    def train_model(self,
+                    model: entities.Model,
+                    cleanup=False,
+                    progress: utilities.Progress = None,
+                    context: utilities.Context = None):
+        """
+        Train an existing model.
+        Data is taken from dl.Model.datasetId,
+        configuration is as defined in dl.Model.configuration,
+        and the output is uploaded to the model's bucket (model.bucket).
+        """
+        if isinstance(model, dict):
+            model = repositories.Models(client_api=self._client_api).get(model_id=model['id'])
+        output_path = None
+        try:
+            logger.info("Received model {s} for training".format(s=model.id))
+            model = model.wait_for_model_ready()
+            if model.status == 'failed':
+                raise ValueError("Model is in failed state, cannot train.")
+
+            ##############
+            # Set status #
+            ##############
+            model.status = 'training'
+            if context is not None:
+                if 'system' not in model.metadata:
+                    model.metadata['system'] = dict()
+            model.update()
+
+            ##########################
+            # load model and weights #
+            ##########################
+            logger.info("Loading Adapter with: {n} ({i!r})".format(n=model.name, i=model.id))
+            self.load_from_model(model_entity=model)
+
+            ################
+            # prepare data #
+            ################
+            root_path, data_path, output_path = self.prepare_data(
+                dataset=self.model_entity.dataset,
+                root_path=os.path.join('tmp', model.id)
+            )
+            # Start the Train
+            logger.info("Training {p_name!r} with model {m_name!r} on data {d_path!r}".
+                        format(p_name=self.package_name, m_name=model.id, d_path=data_path))
+            if progress is not None:
+                progress.update(message='starting training')
+
+            def on_epoch_end_callback(i_epoch, n_epoch):
+                if progress is not None:
+                    progress.update(progress=int(100 * (i_epoch + 1) / n_epoch),
+                                    message='finished epoch: {}/{}'.format(i_epoch, n_epoch))
+
+            self.train(data_path=data_path,
+                       output_path=output_path,
+                       on_epoch_end_callback=on_epoch_end_callback)
+            if progress is not None:
+                progress.update(message='saving model',
+                                progress=99)
+
+            self.save_to_model(local_path=output_path, replace=True)
+            model.status = 'trained'
+            model.update()
+            ###########
+            # cleanup #
+            ###########
+            if cleanup:
+                shutil.rmtree(output_path, ignore_errors=True)
+        except Exception:
+            # save also on failure
+            if output_path is not None:
+                self.save_to_model(local_path=output_path, replace=True)
+            logger.info('Execution failed. Setting model.status to failed')
+            raise
+        return model
```
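`train_model` wires the pieces together: it loads the adapter from the model, prepares data from `model.dataset`, trains with an epoch-end progress callback, and uploads the outputs through `save_to_model`. A sketch of a local run, assuming the adapter subclass implements `train` (the model id is a placeholder):

```python
import dtlpy as dl

model = dl.models.get(model_id='<model-id>')  # hypothetical id
model = adapter.train_model(model=model, cleanup=True)
print(model.status)  # 'trained' on success
```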
```diff
+
+    @entities.Package.decorators.function(display_name='Evaluate a Model',
+                                          inputs={'model': entities.Model,
+                                                  'dataset': entities.Dataset,
+                                                  'filters': 'Json'},
+                                          outputs={'model': entities.Model,
+                                                   'dataset': entities.Dataset
+                                                   })
+    def evaluate_model(self,
+                       model: entities.Model,
+                       dataset: entities.Dataset,
+                       filters: entities.Filters = None,
+                       #
+                       progress: utilities.Progress = None,
+                       context: utilities.Context = None):
+        """
+        Evaluate a model.
+        Data is downloaded from the given dataset and query,
+        configuration is as defined in dl.Model.configuration,
+        and the predictions are uploaded and compared against the ground truth to calculate metrics.
+
+        :param model: Model entity to run the prediction with
+        :param dataset: Dataset to evaluate on
+        :param filters: Filters for selecting specific items from the dataset
+        :param progress: dl.Progress for reporting FaaS progress
+        :param context:
+        :return:
+        """
+        logger.info(
+            f"Received model: {model.id} for evaluation on dataset (name: {dataset.name}, id: {dataset.id})")
+        ##########################
+        # load model and weights #
+        ##########################
+        logger.info(f"Loading Adapter with: {model.name} ({model.id!r})")
+        self.load_from_model(dataset=dataset,
+                             model_entity=model)
+
+        ##############
+        # Predicting #
+        ##############
+        logger.info(f"Calling prediction, dataset: {dataset.name!r} ({model.id!r}), filters: {filters}")
+        if not filters:
+            filters = entities.Filters()
+        self.predict_dataset(dataset=dataset,
+                             filters=filters,
+                             upload_annotations=True)
+
+        ##############
+        # Evaluating #
+        ##############
+        logger.info("Starting adapter.evaluate()")
+        if progress is not None:
+            progress.update(message='calculating metrics',
+                            progress=98)
+        model = self.evaluate(model=model,
+                              dataset=dataset,
+                              filters=filters)
+        #########
+        # Done! #
+        #########
+        if progress is not None:
+            progress.update(message='finishing evaluation',
+                            progress=99)
+        return model, dataset
```
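`evaluate_model` chains `load_from_model`, `predict_dataset`, and the adapter's `evaluate` hook. A sketch, reusing the `adapter`, `model`, and `dataset` objects assumed above:

```python
import dtlpy as dl

model, dataset = adapter.evaluate_model(model=model,
                                        dataset=dataset,
                                        filters=dl.Filters(field='dir', values='/test'))
```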
```diff
+
+    # =============
+    # INNER METHODS
+    # =============
+
+    @staticmethod
+    def _upload_model_features(logger, feature_set_id, project_id, item: entities.Item, vector):
+        try:
+            if vector is not None:
+                item.features.create(value=vector,
+                                     project_id=project_id,
+                                     feature_set_id=feature_set_id,
+                                     entity=item)
+        except Exception as e:
+            logger.error(f'Failed to upload feature vector of length {len(vector)} to item {item.id}, Error: {e}')
+
+    def _upload_model_annotations(self, item: entities.Item, predictions, clean_annotations):
+        """
+        Utility function that uploads predictions to the Dataloop platform based on the package.output_type
+
+        :param item: `dl.Item` the item the predictions belong to
+        :param predictions: `dl.AnnotationCollection`
+        :param clean_annotations: `bool` if set, removes existing predictions with the same package-model name
+        """
+        if not (isinstance(predictions, entities.AnnotationCollection) or isinstance(predictions, list)):
+            raise TypeError('predictions was expected to be of type {}, but instead it is {}'.
+                            format(entities.AnnotationCollection, type(predictions)))
+        if clean_annotations:
+            clean_filter = entities.Filters(resource=entities.FiltersResource.ANNOTATION)
+            clean_filter.add(field='metadata.user.model.name', values=self.model_entity.name)
+            # clean_filter.add(field='type', values=self.model_entity.output_type)
+            item.annotations.delete(filters=clean_filter)
+        annotations = item.annotations.upload(annotations=predictions)
+        return annotations
+
+    @staticmethod
+    def _item_to_image(item):
+        """
+        Preprocess items before calling the `predict` functions.
+        Convert the item to a numpy array
+
+        :param item:
+        :return:
+        """
+        buffer = item.download(save_locally=False)
+        image = np.asarray(Image.open(buffer))
+        return image
+
+    @staticmethod
+    def _item_to_item(item):
+        """
+        Default item-to-batch function.
+        This function should prepare a single item for the predict function, e.g. for images, it loads the image as a numpy array
+
+        :param item:
+        :return:
+        """
+        return item
+
+    @staticmethod
+    def _item_to_text(item):
+        filename = item.download(overwrite=True)
+        text = None
+        if item.mimetype == 'text/plain' or item.mimetype == 'text/markdown':
+            with open(filename, 'r') as f:
+                text = f.read()
+            text = text.replace('\n', ' ')
+        else:
+            logger.warning('Item is not a text file. mimetype: {}'.format(item.mimetype))
+            text = item
+        if os.path.exists(filename):
+            os.remove(filename)
+        return text
+
+    @staticmethod
+    def _uri_to_image(data_uri):
+        # data_uri = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAS4AAAEuCAYAAAAwQP9D...AAAASUVORK5CYII="  # (payload truncated)
+        image_b64 = data_uri.split(",")[1]
+        binary = base64.b64decode(image_b64)
+        image = np.asarray(Image.open(io.BytesIO(binary)))
+        return image
```
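A round-trip sketch for `_uri_to_image`: build a tiny PNG in memory, wrap it in a browser-style data URI, and decode it back to a numpy array. This assumes the class in this diff is the SDK's `BaseModelAdapter`; PIL is the same dependency the adapter already uses:

```python
import base64
import io

import dtlpy as dl
from PIL import Image

# Encode a 2x2 red PNG as a data URI, then decode it through the helper.
buf = io.BytesIO()
Image.new('RGB', (2, 2), color=(255, 0, 0)).save(buf, format='PNG')
data_uri = 'data:image/png;base64,' + base64.b64encode(buf.getvalue()).decode()
image = dl.BaseModelAdapter._uri_to_image(data_uri)
print(image.shape)  # (2, 2, 3)
```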
```diff
+
+    def _update_predictions_metadata(self, item: entities.Item, predictions: entities.AnnotationCollection):
+        """
+        Add model_name and model_id to the user metadata of the annotations,
+        add model_info to the annotation's system metadata,
+        and add the item id to all the annotations in the AnnotationCollection.
+
+        :param item: entities.Item
+        :param predictions: the item's AnnotationCollection
+        :return:
+        """
+        for prediction in predictions:
+            if prediction.type == entities.AnnotationType.SEGMENTATION:
+                color = None
+                try:
+                    color = item.dataset._get_ontology().color_map.get(prediction.label, None)
+                except (exceptions.BadRequest, exceptions.NotFound):
+                    ...
+                if color is None:
+                    if self.model_entity._dataset is not None:
+                        try:
+                            color = self.model_entity.dataset._get_ontology().color_map.get(prediction.label,
+                                                                                            (255, 255, 255))
+                        except (exceptions.BadRequest, exceptions.NotFound):
+                            ...
+                if color is None:
+                    logger.warning("Can't get annotation color from model's dataset, using default.")
+                    color = prediction.color
+                prediction.color = color
+
+            prediction.item_id = item.id
+            if 'user' in prediction.metadata and 'model' in prediction.metadata['user']:
+                prediction.metadata['user']['model']['model_id'] = self.model_entity.id
+                prediction.metadata['user']['model']['name'] = self.model_entity.name
+            if 'system' not in prediction.metadata:
+                prediction.metadata['system'] = dict()
+            if 'model' not in prediction.metadata['system']:
+                prediction.metadata['system']['model'] = dict()
+            confidence = prediction.metadata.get('user', dict()).get('model', dict()).get('confidence', None)
+            prediction.metadata['system']['model'] = {
+                'model_id': self.model_entity.id,
+                'name': self.model_entity.name,
+                'confidence': confidence
+            }
```
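After `_update_predictions_metadata` runs, each prediction carries roughly this metadata shape (values are illustrative placeholders, derived from the assignments above):

```python
# Illustrative shape of prediction.metadata after the update.
expected_metadata = {
    'user': {
        'model': {
            'name': '<model name>',
            'model_id': '<model id>',
            'confidence': 0.87,  # whatever the adapter's predict() set
        }
    },
    'system': {
        'model': {
            'model_id': '<model id>',
            'name': '<model name>',
            'confidence': 0.87,  # copied from the user metadata when present
        }
    },
}
```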
```diff
+
+    ##############################
+    # Callback Factory functions #
+    ##############################
+    @property
+    def dataloop_keras_callback(self):
+        """
+        Returns the constructor for a keras API dump callback.
+        The callback is used by the Dataloop platform to show training losses
+
+        :return: DumpHistoryCallback constructor
+        """
+        try:
+            import keras
+        except (ImportError, ModuleNotFoundError) as err:
+            raise RuntimeError(
+                '{} depends on an external package. Please install keras.'.format(self.__class__.__name__)) from err
+
+        import os
+        import time
+        import json
+
+        class DumpHistoryCallback(keras.callbacks.Callback):
+            def __init__(self, dump_path):
+                super().__init__()
+                if os.path.isdir(dump_path):
+                    dump_path = os.path.join(dump_path,
+                                             '__view__training-history__{}.json'.format(time.strftime("%F-%X")))
+                self.dump_file = dump_path
+                self.data = dict()
+
+            def on_epoch_end(self, epoch, logs=None):
+                logs = logs or {}
+                for name, val in logs.items():
+                    if name not in self.data:
+                        self.data[name] = {'x': list(), 'y': list()}
+                    self.data[name]['x'].append(float(epoch))
+                    self.data[name]['y'].append(float(val))
+                self.dump_history()
+
+            def dump_history(self):
+                _json = {
+                    "query": {},
+                    "datasetId": "",
+                    "xlabel": "epoch",
+                    "title": "training loss",
+                    "ylabel": "val",
+                    "type": "metric",
+                    "data": [{"name": name,
+                              "x": values['x'],
+                              "y": values['y']} for name, values in self.data.items()]
+                }
+
+                with open(self.dump_file, 'w') as f:
+                    json.dump(_json, f, indent=2)
+
+        return DumpHistoryCallback
```
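The property returns the class itself, so it can be instantiated and handed to keras directly. A sketch; `keras_model`, the training data, and `output_path` are assumptions, not part of this diff:

```python
# Hedged sketch: dump per-epoch losses so the platform can plot them.
DumpHistoryCallback = adapter.dataloop_keras_callback
history_cb = DumpHistoryCallback(dump_path=output_path)  # dir or explicit .json path
keras_model.fit(x_train, y_train, epochs=10, callbacks=[history_cb])
```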