code-loader 1.0.93.dev4__py3-none-any.whl → 1.0.94__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of code-loader might be problematic. Click here for more details.
- code_loader/contract/datasetclasses.py +17 -1
- code_loader/contract/visualizer_classes.py +6 -2
- code_loader/inner_leap_binder/leapbinder.py +27 -2
- code_loader/inner_leap_binder/leapbinder_decorators.py +101 -156
- code_loader/leaploader.py +43 -8
- code_loader/leaploaderbase.py +9 -1
- code_loader/utils.py +12 -2
- {code_loader-1.0.93.dev4.dist-info → code_loader-1.0.94.dist-info}/METADATA +1 -1
- {code_loader-1.0.93.dev4.dist-info → code_loader-1.0.94.dist-info}/RECORD +11 -11
- {code_loader-1.0.93.dev4.dist-info → code_loader-1.0.94.dist-info}/LICENSE +0 -0
- {code_loader-1.0.93.dev4.dist-info → code_loader-1.0.94.dist-info}/WHEEL +0 -0
|
@@ -38,6 +38,9 @@ class PreprocessResponse:
|
|
|
38
38
|
sample_ids: Optional[Union[List[str], List[int]]] = None
|
|
39
39
|
state: Optional[DataStateType] = None
|
|
40
40
|
sample_id_type: Optional[Union[Type[str], Type[int]]] = None
|
|
41
|
+
sample_ids_to_instance_mappings: Optional[Dict[Union[str, int], Union[List[str], List[int]]]] = None # in use only for element instance
|
|
42
|
+
instance_to_sample_ids_mappings: Optional[Dict[Union[str, int], Union[str, int]]] = None # in use only for element instance
|
|
43
|
+
|
|
41
44
|
|
|
42
45
|
def __post_init__(self) -> None:
|
|
43
46
|
def is_valid_string(s: str) -> bool:
|
|
@@ -65,8 +68,14 @@ class PreprocessResponse:
|
|
|
65
68
|
assert self.sample_ids is not None
|
|
66
69
|
return len(self.sample_ids)
|
|
67
70
|
|
|
71
|
+
@dataclass
|
|
72
|
+
class ElementInstance:
|
|
73
|
+
name: str
|
|
74
|
+
mask: npt.NDArray[np.float32]
|
|
75
|
+
# instance_filling_type: InstanceFillingType # TODO: implement InstanceFillingType
|
|
68
76
|
|
|
69
77
|
SectionCallableInterface = Callable[[Union[int, str], PreprocessResponse], npt.NDArray[np.float32]]
|
|
78
|
+
InstanceCallableInterface = Callable[[int, PreprocessResponse], List[ElementInstance]]
|
|
70
79
|
|
|
71
80
|
MetadataSectionCallableInterface = Union[
|
|
72
81
|
Callable[[Union[int, str], PreprocessResponse], int],
|
|
@@ -188,6 +197,10 @@ class InputHandler(DatasetBaseHandler):
|
|
|
188
197
|
shape: Optional[List[int]] = None
|
|
189
198
|
channel_dim: Optional[int] = -1
|
|
190
199
|
|
|
200
|
+
@dataclass
|
|
201
|
+
class ElementInstanceMasksHandler:
|
|
202
|
+
name: str
|
|
203
|
+
function: InstanceCallableInterface
|
|
191
204
|
|
|
192
205
|
@dataclass
|
|
193
206
|
class GroundTruthHandler(DatasetBaseHandler):
|
|
@@ -205,7 +218,7 @@ class MetadataHandler:
|
|
|
205
218
|
class PredictionTypeHandler:
|
|
206
219
|
name: str
|
|
207
220
|
labels: List[str]
|
|
208
|
-
channel_dim: int
|
|
221
|
+
channel_dim: int
|
|
209
222
|
|
|
210
223
|
|
|
211
224
|
@dataclass
|
|
@@ -223,6 +236,7 @@ class DatasetIntegrationSetup:
|
|
|
223
236
|
unlabeled_data_preprocess: Optional[UnlabeledDataPreprocessHandler] = None
|
|
224
237
|
visualizers: List[VisualizerHandler] = field(default_factory=list)
|
|
225
238
|
inputs: List[InputHandler] = field(default_factory=list)
|
|
239
|
+
instance_masks: List[ElementInstanceMasksHandler] = field(default_factory=list)
|
|
226
240
|
ground_truths: List[GroundTruthHandler] = field(default_factory=list)
|
|
227
241
|
metadata: List[MetadataHandler] = field(default_factory=list)
|
|
228
242
|
prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
|
|
@@ -239,3 +253,5 @@ class DatasetSample:
|
|
|
239
253
|
metadata_is_none: Dict[str, bool]
|
|
240
254
|
index: Union[int, str]
|
|
241
255
|
state: DataStateEnum
|
|
256
|
+
instance_masks: Optional[Dict[str, List[ElementInstance]]] = None
|
|
257
|
+
|
|
@@ -121,19 +121,23 @@ class LeapGraph:
|
|
|
121
121
|
x_label = 'Frequency [Seconds]'
|
|
122
122
|
y_label = 'Amplitude [Voltage]'
|
|
123
123
|
x_range = (0.1, 3.0)
|
|
124
|
-
|
|
124
|
+
legend = ['experiment1', 'experiment2', 'experiment3']
|
|
125
|
+
leap_graph = LeapGraph(data=graph_data, x_label=x_label, y_label=y_label, x_range=x_range, legend=legend)
|
|
125
126
|
"""
|
|
126
127
|
data: npt.NDArray[np.float32]
|
|
127
128
|
type: LeapDataType = LeapDataType.Graph
|
|
128
129
|
x_label: Optional[str] = None
|
|
129
130
|
y_label: Optional[str] = None
|
|
130
131
|
x_range: Optional[Tuple[float,float]] = None
|
|
132
|
+
legend: Optional[List[str]] = None
|
|
131
133
|
|
|
132
134
|
def __post_init__(self) -> None:
|
|
133
135
|
validate_type(self.type, LeapDataType.Graph)
|
|
134
136
|
validate_type(type(self.data), np.ndarray)
|
|
135
137
|
validate_type(self.data.dtype, np.float32)
|
|
136
|
-
validate_type(len(self.data.shape), 2, 'Graph must be of shape 2')
|
|
138
|
+
validate_type(len(self.data.shape), 2, f'Graph must be of shape 2')
|
|
139
|
+
if self.legend:
|
|
140
|
+
validate_type(self.data.shape[1], len(self.legend), 'Number of labels supplied should equal the number of graphs')
|
|
137
141
|
validate_type(type(self.x_label), [str, type(None)], 'x_label must be a string or None')
|
|
138
142
|
validate_type(type(self.y_label), [str, type(None)], 'y_label must be a string or None')
|
|
139
143
|
validate_type(type(self.x_range), [tuple, type(None)], 'x_range must be a tuple or None')
|
|
@@ -10,14 +10,15 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
|
|
|
10
10
|
MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
|
|
11
11
|
CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
|
|
12
12
|
CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
|
|
13
|
-
RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData, SamplePreprocessResponse
|
|
13
|
+
RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData, SamplePreprocessResponse, \
|
|
14
|
+
ElementInstanceMasksHandler, InstanceCallableInterface
|
|
14
15
|
from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, DatasetMetadataType
|
|
15
16
|
from code_loader.contract.mapping import NodeConnection, NodeMapping, NodeMappingType
|
|
16
17
|
from code_loader.contract.responsedataclasses import DatasetTestResultPayload
|
|
17
18
|
from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
|
|
18
19
|
from code_loader.default_losses import loss_name_to_function
|
|
19
20
|
from code_loader.default_metrics import metrics_names_to_functions_and_direction
|
|
20
|
-
from code_loader.utils import to_numpy_return_wrapper, get_shape
|
|
21
|
+
from code_loader.utils import to_numpy_return_wrapper, get_shape, to_numpy_return_masks_wrapper
|
|
21
22
|
from code_loader.visualizers.default_visualizers import DefaultVisualizer, \
|
|
22
23
|
default_graph_visualizer, \
|
|
23
24
|
default_image_visualizer, default_horizontal_bar_visualizer, default_word_visualizer, \
|
|
@@ -234,6 +235,30 @@ class LeapBinder:
|
|
|
234
235
|
|
|
235
236
|
self._encoder_names.append(name)
|
|
236
237
|
|
|
238
|
+
|
|
239
|
+
def set_instance_masks(self, function: InstanceCallableInterface, name: str) -> None:
|
|
240
|
+
"""
|
|
241
|
+
Set the input handler function.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
function (SectionCallableInterface): The input handler function.
|
|
245
|
+
name (str): The name of the input section.
|
|
246
|
+
channel_dim (int): The dimension of the channels axis
|
|
247
|
+
|
|
248
|
+
Example:
|
|
249
|
+
def input_encoder(subset: PreprocessResponse, index: int) -> np.ndarray:
|
|
250
|
+
# Return the processed input data for the given index and given subset response
|
|
251
|
+
img_path = subset.`data["images"][idx]
|
|
252
|
+
img = read_img(img_path)
|
|
253
|
+
img = normalize(img)
|
|
254
|
+
return img
|
|
255
|
+
|
|
256
|
+
leap_binder.set_input(input_encoder, name='input_encoder', channel_dim=-1)
|
|
257
|
+
"""
|
|
258
|
+
function = to_numpy_return_masks_wrapper(function)
|
|
259
|
+
self.setup_container.instance_masks.append(ElementInstanceMasksHandler(name, function))
|
|
260
|
+
|
|
261
|
+
|
|
237
262
|
def add_custom_loss(self, function: CustomCallableInterface, name: str) -> None:
|
|
238
263
|
"""
|
|
239
264
|
Add a custom loss function to the setup.
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
# mypy: ignore-errors
|
|
2
|
-
|
|
2
|
+
|
|
3
3
|
from typing import Optional, Union, Callable, List, Dict
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
@@ -8,7 +8,7 @@ import numpy.typing as npt
|
|
|
8
8
|
from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
|
|
9
9
|
CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
|
|
10
10
|
VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
|
|
11
|
-
ConfusionMatrixElement, SamplePreprocessResponse,
|
|
11
|
+
ConfusionMatrixElement, SamplePreprocessResponse, InstanceCallableInterface
|
|
12
12
|
from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType
|
|
13
13
|
from code_loader import leap_binder
|
|
14
14
|
from code_loader.contract.mapping import NodeMapping, NodeMappingType, NodeConnection
|
|
@@ -16,77 +16,14 @@ from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, Le
|
|
|
16
16
|
LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
|
|
17
17
|
|
|
18
18
|
|
|
19
|
-
mapping_runtime_mode_env_var_mame = '__MAPPING_RUNTIME_MODE__'
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type):
|
|
23
|
-
main_node_mapping = NodeMapping(name, node_mapping_type, user_unique_name, arg_names=arg_names)
|
|
24
|
-
node_inputs = {}
|
|
25
|
-
for arg_name, destination in zip(arg_names, connection_destinations):
|
|
26
|
-
node_inputs[arg_name] = destination.node_mapping
|
|
27
|
-
|
|
28
|
-
leap_binder.mapping_connections.append(NodeConnection(main_node_mapping, node_inputs))
|
|
29
|
-
|
|
30
|
-
|
|
31
19
|
def _add_mapping_connections(connects_to, arg_names, node_mapping_type, name):
|
|
32
20
|
for user_unique_name, connection_destinations in connects_to.items():
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = None):
|
|
39
|
-
for i, prediction_type in enumerate(prediction_types):
|
|
40
|
-
leap_binder.add_prediction(prediction_type.name, prediction_type.labels, prediction_type.channel_dim, i)
|
|
41
|
-
|
|
42
|
-
def decorating_function(load_model_func):
|
|
43
|
-
class TempMapping:
|
|
44
|
-
pass
|
|
45
|
-
|
|
46
|
-
def mapping_inner():
|
|
47
|
-
class ModelOutputPlaceholder:
|
|
48
|
-
def __init__(self):
|
|
49
|
-
self.node_mapping = NodeMapping('', NodeMappingType.Prediction0)
|
|
50
|
-
|
|
51
|
-
def __getitem__(self, key):
|
|
52
|
-
assert isinstance(key, int), \
|
|
53
|
-
f'Expected key to be an int, got {type(key)} instead.'
|
|
54
|
-
|
|
55
|
-
ret = TempMapping()
|
|
56
|
-
ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
|
|
57
|
-
return ret
|
|
58
|
-
|
|
59
|
-
class ModelPlaceholder:
|
|
60
|
-
#keras interface
|
|
61
|
-
def __call__(self, arg):
|
|
62
|
-
if isinstance(arg, list):
|
|
63
|
-
for i, elem in enumerate(arg):
|
|
64
|
-
elem.node_mapping.type = NodeMappingType[f'Input{str(i)}']
|
|
65
|
-
else:
|
|
66
|
-
arg.node_mapping.type = NodeMappingType.Input0
|
|
67
|
-
|
|
68
|
-
return ModelOutputPlaceholder()
|
|
69
|
-
|
|
70
|
-
# onnx runtime interface
|
|
71
|
-
def run(self, output_names, input_dict):
|
|
72
|
-
assert output_names is None
|
|
73
|
-
assert isinstance(input_dict, dict), \
|
|
74
|
-
f'Expected input_dict to be a dict, got {type(input_dict)} instead.'
|
|
75
|
-
for i, elem in enumerate(input_dict.values()):
|
|
76
|
-
elem.node_mapping.type = NodeMappingType[f'Input{str(i)}']
|
|
77
|
-
|
|
78
|
-
return ModelOutputPlaceholder()
|
|
79
|
-
|
|
80
|
-
return ModelPlaceholder()
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
if os.environ[mapping_runtime_mode_env_var_mame]:
|
|
84
|
-
return mapping_inner
|
|
85
|
-
else:
|
|
86
|
-
return load_model_func
|
|
87
|
-
|
|
88
|
-
return decorating_function
|
|
21
|
+
main_node_mapping = NodeMapping(name, node_mapping_type, user_unique_name, arg_names=arg_names)
|
|
22
|
+
node_inputs = {}
|
|
23
|
+
for arg_name, destination in zip(arg_names, connection_destinations):
|
|
24
|
+
node_inputs[arg_name] = destination.node_mapping
|
|
89
25
|
|
|
26
|
+
leap_binder.mapping_connections.append(NodeConnection(main_node_mapping, node_inputs))
|
|
90
27
|
|
|
91
28
|
|
|
92
29
|
def tensorleap_custom_metric(name: str,
|
|
@@ -94,8 +31,8 @@ def tensorleap_custom_metric(name: str,
|
|
|
94
31
|
compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
|
|
95
32
|
connects_to=None):
|
|
96
33
|
def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
|
|
97
|
-
|
|
98
|
-
|
|
34
|
+
CustomMultipleReturnCallableInterfaceMultiArgs,
|
|
35
|
+
ConfusionMatrixCallableInterfaceMultiArgs]):
|
|
99
36
|
for metric_handler in leap_binder.setup_container.metrics:
|
|
100
37
|
if metric_handler.metric_handler_data.name == name:
|
|
101
38
|
raise Exception(f'Metric with name {name} already exists. '
|
|
@@ -182,31 +119,14 @@ def tensorleap_custom_metric(name: str,
|
|
|
182
119
|
(f'tensorleap_custom_metric validation failed: '
|
|
183
120
|
f'compute_insights should be boolean. Got {type(compute_insights)}.')
|
|
184
121
|
|
|
122
|
+
|
|
185
123
|
def inner(*args, **kwargs):
|
|
186
124
|
_validate_input_args(*args, **kwargs)
|
|
187
125
|
result = user_function(*args, **kwargs)
|
|
188
126
|
_validate_result(result)
|
|
189
127
|
return result
|
|
190
128
|
|
|
191
|
-
|
|
192
|
-
user_unique_name = mapping_inner.name
|
|
193
|
-
if 'user_unique_name' in kwargs:
|
|
194
|
-
user_unique_name = kwargs['user_unique_name']
|
|
195
|
-
|
|
196
|
-
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
197
|
-
ordered_connections = list(args) + ordered_connections
|
|
198
|
-
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
199
|
-
mapping_inner.name, NodeMappingType.Metric)
|
|
200
|
-
|
|
201
|
-
return None
|
|
202
|
-
|
|
203
|
-
mapping_inner.arg_names = leap_binder.setup_container.metrics[-1].metric_handler_data.arg_names
|
|
204
|
-
mapping_inner.name = name
|
|
205
|
-
|
|
206
|
-
if os.environ[mapping_runtime_mode_env_var_mame]:
|
|
207
|
-
return mapping_inner
|
|
208
|
-
else:
|
|
209
|
-
return inner
|
|
129
|
+
return inner
|
|
210
130
|
|
|
211
131
|
return decorating_function
|
|
212
132
|
|
|
@@ -266,25 +186,7 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
|
266
186
|
_validate_result(result)
|
|
267
187
|
return result
|
|
268
188
|
|
|
269
|
-
|
|
270
|
-
user_unique_name = mapping_inner.name
|
|
271
|
-
if 'user_unique_name' in kwargs:
|
|
272
|
-
user_unique_name = kwargs['user_unique_name']
|
|
273
|
-
|
|
274
|
-
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
275
|
-
ordered_connections = list(args) + ordered_connections
|
|
276
|
-
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
277
|
-
mapping_inner.name, NodeMappingType.Visualizer)
|
|
278
|
-
|
|
279
|
-
return None
|
|
280
|
-
|
|
281
|
-
mapping_inner.arg_names = leap_binder.setup_container.visualizers[-1].visualizer_handler_data.arg_names
|
|
282
|
-
mapping_inner.name = name
|
|
283
|
-
|
|
284
|
-
if os.environ[mapping_runtime_mode_env_var_mame]:
|
|
285
|
-
return mapping_inner
|
|
286
|
-
else:
|
|
287
|
-
return inner
|
|
189
|
+
return inner
|
|
288
190
|
|
|
289
191
|
return decorating_function
|
|
290
192
|
|
|
@@ -368,6 +270,60 @@ def tensorleap_preprocess():
|
|
|
368
270
|
|
|
369
271
|
return decorating_function
|
|
370
272
|
|
|
273
|
+
def tensorleap_element_instance_preprocess(instance_mask_encoder: Callable[[int, PreprocessResponse], List[PreprocessResponse]]):
|
|
274
|
+
def decorating_function(user_function: Callable[[], List[PreprocessResponse]]):
|
|
275
|
+
def user_function_instance() -> List[PreprocessResponse]:
|
|
276
|
+
result = user_function()
|
|
277
|
+
for preprocess_response in result:
|
|
278
|
+
sample_ids_to_instance_mappings = {}
|
|
279
|
+
instance_to_sample_ids_mappings = {}
|
|
280
|
+
all_sample_ids = preprocess_response.sample_ids.copy()
|
|
281
|
+
for sample_id in preprocess_response.sample_ids:
|
|
282
|
+
instances_masks = instance_mask_encoder(sample_id, preprocess_response)
|
|
283
|
+
instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(len(instances_masks))]
|
|
284
|
+
sample_ids_to_instance_mappings[sample_id] = instances_ids
|
|
285
|
+
instance_to_sample_ids_mappings[sample_id] = sample_id
|
|
286
|
+
for instance_id in instances_ids:
|
|
287
|
+
instance_to_sample_ids_mappings[instance_id] = sample_id
|
|
288
|
+
all_sample_ids.extend(instances_ids)
|
|
289
|
+
preprocess_response.sample_ids_to_instance_mappings = sample_ids_to_instance_mappings
|
|
290
|
+
preprocess_response.instance_to_sample_ids_mappings = instance_to_sample_ids_mappings
|
|
291
|
+
preprocess_response.sample_ids = all_sample_ids
|
|
292
|
+
return result
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def metadata_is_instance(idx: str, preprocess: PreprocessResponse) -> str:
|
|
296
|
+
return "0"
|
|
297
|
+
leap_binder.set_preprocess(user_function_instance)
|
|
298
|
+
leap_binder.set_metadata(metadata_is_instance, "metadata_is_instance")
|
|
299
|
+
|
|
300
|
+
def _validate_input_args(*args, **kwargs):
|
|
301
|
+
assert len(args) == 0 and len(kwargs) == 0, \
|
|
302
|
+
(f'tensorleap_preprocess validation failed: '
|
|
303
|
+
f'The function should not take any arguments. Got {args} and {kwargs}.')
|
|
304
|
+
|
|
305
|
+
def _validate_result(result):
|
|
306
|
+
assert isinstance(result, list), \
|
|
307
|
+
(f'tensorleap_preprocess validation failed: '
|
|
308
|
+
f'The return type should be a list. Got {type(result)}.')
|
|
309
|
+
for i, response in enumerate(result):
|
|
310
|
+
assert isinstance(response, PreprocessResponse), \
|
|
311
|
+
(f'tensorleap_preprocess validation failed: '
|
|
312
|
+
f'Element #{i} in the return list should be a PreprocessResponse. Got {type(response)}.')
|
|
313
|
+
assert len(set(result)) == len(result), \
|
|
314
|
+
(f'tensorleap_preprocess validation failed: '
|
|
315
|
+
f'The return list should not contain duplicate PreprocessResponse objects.')
|
|
316
|
+
|
|
317
|
+
def inner(*args, **kwargs):
|
|
318
|
+
_validate_input_args(*args, **kwargs)
|
|
319
|
+
result = user_function_instance()
|
|
320
|
+
_validate_result(result)
|
|
321
|
+
return result
|
|
322
|
+
|
|
323
|
+
return inner
|
|
324
|
+
|
|
325
|
+
return decorating_function
|
|
326
|
+
|
|
371
327
|
|
|
372
328
|
def tensorleap_unlabeled_preprocess():
|
|
373
329
|
def decorating_function(user_function: Callable[[], PreprocessResponse]):
|
|
@@ -394,6 +350,38 @@ def tensorleap_unlabeled_preprocess():
|
|
|
394
350
|
return decorating_function
|
|
395
351
|
|
|
396
352
|
|
|
353
|
+
def tensorleap_instances_masks_encoder(name: str):
|
|
354
|
+
def decorating_function(user_function: InstanceCallableInterface):
|
|
355
|
+
leap_binder.set_instance_masks(user_function, name)
|
|
356
|
+
|
|
357
|
+
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
358
|
+
assert isinstance(sample_id, (int, str)), \
|
|
359
|
+
(f'tensorleap_input_encoder validation failed: '
|
|
360
|
+
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
361
|
+
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
362
|
+
(f'tensorleap_input_encoder validation failed: '
|
|
363
|
+
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
364
|
+
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
365
|
+
(f'tensorleap_input_encoder validation failed: '
|
|
366
|
+
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
367
|
+
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
368
|
+
|
|
369
|
+
def _validate_result(result):
|
|
370
|
+
assert isinstance(result, list), \
|
|
371
|
+
(f'tensorleap_input_encoder validation failed: '
|
|
372
|
+
f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
|
|
373
|
+
|
|
374
|
+
def inner(sample_id, preprocess_response):
|
|
375
|
+
_validate_input_args(sample_id, preprocess_response)
|
|
376
|
+
result = user_function(sample_id, preprocess_response)
|
|
377
|
+
_validate_result(result)
|
|
378
|
+
return result
|
|
379
|
+
|
|
380
|
+
return inner
|
|
381
|
+
|
|
382
|
+
return decorating_function
|
|
383
|
+
|
|
384
|
+
|
|
397
385
|
def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
|
|
398
386
|
def decorating_function(user_function: SectionCallableInterface):
|
|
399
387
|
for input_handler in leap_binder.setup_container.inputs:
|
|
@@ -438,21 +426,7 @@ def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
|
|
|
438
426
|
node_mapping_type = NodeMappingType(f'Input{str(model_input_index)}')
|
|
439
427
|
inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
440
428
|
|
|
441
|
-
|
|
442
|
-
def mapping_inner(*args, **kwargs):
|
|
443
|
-
class TempMapping:
|
|
444
|
-
pass
|
|
445
|
-
ret = TempMapping()
|
|
446
|
-
ret.node_mapping = mapping_inner.node_mapping
|
|
447
|
-
|
|
448
|
-
return ret
|
|
449
|
-
|
|
450
|
-
mapping_inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
451
|
-
|
|
452
|
-
if os.environ[mapping_runtime_mode_env_var_mame]:
|
|
453
|
-
return mapping_inner
|
|
454
|
-
else:
|
|
455
|
-
return inner
|
|
429
|
+
return inner
|
|
456
430
|
|
|
457
431
|
return decorating_function
|
|
458
432
|
|
|
@@ -494,20 +468,9 @@ def tensorleap_gt_encoder(name: str):
|
|
|
494
468
|
|
|
495
469
|
inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
496
470
|
|
|
497
|
-
|
|
498
|
-
class TempMapping:
|
|
499
|
-
pass
|
|
500
|
-
ret = TempMapping()
|
|
501
|
-
ret.node_mapping = mapping_inner.node_mapping
|
|
502
|
-
|
|
503
|
-
return ret
|
|
471
|
+
return inner
|
|
504
472
|
|
|
505
|
-
mapping_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
506
473
|
|
|
507
|
-
if os.environ[mapping_runtime_mode_env_var_mame]:
|
|
508
|
-
return mapping_inner
|
|
509
|
-
else:
|
|
510
|
-
return inner
|
|
511
474
|
|
|
512
475
|
return decorating_function
|
|
513
476
|
|
|
@@ -563,25 +526,7 @@ def tensorleap_custom_loss(name: str, connects_to=None):
|
|
|
563
526
|
_validate_result(result)
|
|
564
527
|
return result
|
|
565
528
|
|
|
566
|
-
|
|
567
|
-
user_unique_name = mapping_inner.name
|
|
568
|
-
if 'user_unique_name' in kwargs:
|
|
569
|
-
user_unique_name = kwargs['user_unique_name']
|
|
570
|
-
|
|
571
|
-
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
572
|
-
ordered_connections = list(args) + ordered_connections
|
|
573
|
-
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
574
|
-
mapping_inner.name, NodeMappingType.CustomLoss)
|
|
575
|
-
|
|
576
|
-
return None
|
|
577
|
-
|
|
578
|
-
mapping_inner.arg_names = leap_binder.setup_container.custom_loss_handlers[-1].custom_loss_handler_data.arg_names
|
|
579
|
-
mapping_inner.name = name
|
|
580
|
-
|
|
581
|
-
if os.environ[mapping_runtime_mode_env_var_mame]:
|
|
582
|
-
return mapping_inner
|
|
583
|
-
else:
|
|
584
|
-
return inner
|
|
529
|
+
return inner
|
|
585
530
|
|
|
586
531
|
return decorating_function
|
|
587
532
|
|
code_loader/leaploader.py
CHANGED
|
@@ -2,7 +2,6 @@
|
|
|
2
2
|
import importlib.util
|
|
3
3
|
import inspect
|
|
4
4
|
import io
|
|
5
|
-
import os
|
|
6
5
|
import sys
|
|
7
6
|
from contextlib import redirect_stdout
|
|
8
7
|
from functools import lru_cache
|
|
@@ -15,7 +14,8 @@ import numpy.typing as npt
|
|
|
15
14
|
from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
|
|
16
15
|
PreprocessResponse, VisualizerHandler, LeapData, \
|
|
17
16
|
PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, VisualizerHandlerData, MetricHandlerData, \
|
|
18
|
-
MetricCallableReturnType, CustomLossHandlerData, CustomLossHandler, RawInputsForHeatmap, SamplePreprocessResponse
|
|
17
|
+
MetricCallableReturnType, CustomLossHandlerData, CustomLossHandler, RawInputsForHeatmap, SamplePreprocessResponse, \
|
|
18
|
+
ElementInstance
|
|
19
19
|
from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
|
|
20
20
|
from code_loader.contract.exceptions import DatasetScriptException
|
|
21
21
|
from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
|
|
@@ -23,7 +23,6 @@ from code_loader.contract.responsedataclasses import DatasetIntegParseResult, Da
|
|
|
23
23
|
VisualizerInstance, PredictionTypeInstance, ModelSetup, CustomLayerInstance, MetricInstance, CustomLossInstance, \
|
|
24
24
|
EngineFileContract
|
|
25
25
|
from code_loader.inner_leap_binder import global_leap_binder
|
|
26
|
-
from code_loader.inner_leap_binder.leapbinder_decorators import mapping_runtime_mode_env_var_mame
|
|
27
26
|
from code_loader.leaploaderbase import LeapLoaderBase
|
|
28
27
|
from code_loader.utils import get_root_exception_file_and_line_number
|
|
29
28
|
|
|
@@ -152,6 +151,22 @@ class LeapLoader(LeapLoaderBase):
|
|
|
152
151
|
state=state)
|
|
153
152
|
return sample
|
|
154
153
|
|
|
154
|
+
def get_sample_with_masks(self, state: DataStateEnum, sample_id: Union[int, str]) -> DatasetSample:
|
|
155
|
+
self.exec_script()
|
|
156
|
+
preprocess_result = self._preprocess_result()
|
|
157
|
+
if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
|
|
158
|
+
self._preprocess_result(update_unlabeled_preprocess=True)
|
|
159
|
+
|
|
160
|
+
metadata, metadata_is_none = self._get_metadata(state, sample_id)
|
|
161
|
+
sample = DatasetSample(inputs=self._get_inputs(state, sample_id),
|
|
162
|
+
gt=None if state == DataStateEnum.unlabeled else self._get_gt(state, sample_id),
|
|
163
|
+
metadata=metadata,
|
|
164
|
+
metadata_is_none=metadata_is_none,
|
|
165
|
+
index=sample_id,
|
|
166
|
+
state=state,
|
|
167
|
+
instance_masks=self._get_masks(state, sample_id))
|
|
168
|
+
return sample
|
|
169
|
+
|
|
155
170
|
def check_dataset(self) -> DatasetIntegParseResult:
|
|
156
171
|
test_payloads: List[DatasetTestResultPayload] = []
|
|
157
172
|
setup_response = None
|
|
@@ -159,7 +174,6 @@ class LeapLoader(LeapLoaderBase):
|
|
|
159
174
|
stdout_steam = io.StringIO()
|
|
160
175
|
with redirect_stdout(stdout_steam):
|
|
161
176
|
try:
|
|
162
|
-
os.environ[mapping_runtime_mode_env_var_mame] = 'TRUE'
|
|
163
177
|
self.exec_script()
|
|
164
178
|
preprocess_test_payload = self._check_preprocess()
|
|
165
179
|
test_payloads.append(preprocess_test_payload)
|
|
@@ -176,10 +190,6 @@ class LeapLoader(LeapLoaderBase):
|
|
|
176
190
|
general_error = f"Something went wrong. {repr(e.__cause__)} in file {file_name}, line_number: {line_number}\nStacktrace:\n{stacktrace}"
|
|
177
191
|
is_valid = False
|
|
178
192
|
|
|
179
|
-
del os.environ[mapping_runtime_mode_env_var_mame]
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
193
|
print_log = stdout_steam.getvalue()
|
|
184
194
|
is_valid_for_model = bool(global_leap_binder.setup_container.custom_layers)
|
|
185
195
|
model_setup = self.get_model_setup_response()
|
|
@@ -444,6 +454,16 @@ class LeapLoader(LeapLoaderBase):
|
|
|
444
454
|
def _get_inputs(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
|
|
445
455
|
return self._get_dataset_handlers(global_leap_binder.setup_container.inputs, state, sample_id)
|
|
446
456
|
|
|
457
|
+
def _get_masks(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, List[ElementInstance]]:
|
|
458
|
+
preprocess_result = self._preprocess_result()
|
|
459
|
+
preprocess_state = preprocess_result[state]
|
|
460
|
+
result_agg = {}
|
|
461
|
+
for handler in global_leap_binder.setup_container.instance_masks:
|
|
462
|
+
handler_result = handler.function(sample_id, preprocess_state)
|
|
463
|
+
handler_name = handler.name
|
|
464
|
+
result_agg[handler_name] = handler_result
|
|
465
|
+
return result_agg
|
|
466
|
+
|
|
447
467
|
def _get_gt(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, npt.NDArray[np.float32]]:
|
|
448
468
|
return self._get_dataset_handlers(global_leap_binder.setup_container.ground_truths, state, sample_id)
|
|
449
469
|
|
|
@@ -512,3 +532,18 @@ class LeapLoader(LeapLoaderBase):
|
|
|
512
532
|
raise Exception("Different id types in preprocess results")
|
|
513
533
|
|
|
514
534
|
return id_type
|
|
535
|
+
|
|
536
|
+
def get_instances_data(self, state: DataStateEnum) -> Tuple[Dict[Union[int, str], List[Union[int, str]]], Dict[Union[int, str], Union[int, str]], List[Union[int, str]]]:
|
|
537
|
+
"""
|
|
538
|
+
This Method get the data state and returns two dictionaries that holds the mapping of the sample ids to their
|
|
539
|
+
instances and the other way around and the sample ids array.
|
|
540
|
+
Args:
|
|
541
|
+
state: DataStateEnum state
|
|
542
|
+
Returns:
|
|
543
|
+
sample_ids_to_instance_mappings: sample id to instance mappings
|
|
544
|
+
instance_to_sample_ids_mappings: instance to sample ids mappings
|
|
545
|
+
sample_ids: sample ids array
|
|
546
|
+
"""
|
|
547
|
+
preprocess_result = self._preprocess_result()
|
|
548
|
+
preprocess_state = preprocess_result[state]
|
|
549
|
+
return preprocess_state.sample_ids_to_instance_mappings, preprocess_state.instance_to_sample_ids_mappings, preprocess_state.sample_ids
|
code_loader/leaploaderbase.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from abc import abstractmethod
|
|
4
4
|
|
|
5
|
-
from typing import Dict, List, Union, Type, Optional
|
|
5
|
+
from typing import Dict, List, Union, Type, Optional, Tuple
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
8
|
import numpy.typing as npt
|
|
@@ -64,6 +64,14 @@ class LeapLoaderBase:
|
|
|
64
64
|
def get_sample(self, state: DataStateEnum, sample_id: Union[int, str]) -> DatasetSample:
|
|
65
65
|
pass
|
|
66
66
|
|
|
67
|
+
@abstractmethod
|
|
68
|
+
def get_sample_with_masks(self, state: DataStateEnum, sample_id: Union[int, str]) -> DatasetSample:
|
|
69
|
+
pass
|
|
70
|
+
|
|
71
|
+
@abstractmethod
|
|
72
|
+
def get_instances_data(self, state: DataStateEnum) -> Tuple[Dict[Union[int, str], List[Union[int, str]]], Dict[Union[int, str], Union[int, str]], List[Union[int, str]]]:
|
|
73
|
+
pass
|
|
74
|
+
|
|
67
75
|
@abstractmethod
|
|
68
76
|
def check_dataset(self) -> DatasetIntegParseResult:
|
|
69
77
|
pass
|
code_loader/utils.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
import sys
|
|
2
2
|
from pathlib import Path
|
|
3
3
|
from types import TracebackType
|
|
4
|
-
from typing import List, Union, Tuple, Any
|
|
4
|
+
from typing import List, Union, Tuple, Any, Callable
|
|
5
5
|
import traceback
|
|
6
6
|
import numpy as np
|
|
7
7
|
import numpy.typing as npt
|
|
8
8
|
|
|
9
|
-
from code_loader.contract.datasetclasses import SectionCallableInterface, PreprocessResponse
|
|
9
|
+
from code_loader.contract.datasetclasses import SectionCallableInterface, PreprocessResponse, \
|
|
10
|
+
InstanceCallableInterface, ElementInstance
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
def to_numpy_return_wrapper(encoder_function: SectionCallableInterface) -> SectionCallableInterface:
|
|
@@ -17,6 +18,15 @@ def to_numpy_return_wrapper(encoder_function: SectionCallableInterface) -> Secti
|
|
|
17
18
|
|
|
18
19
|
return numpy_encoder_function
|
|
19
20
|
|
|
21
|
+
def to_numpy_return_masks_wrapper(encoder_function: InstanceCallableInterface) -> Callable[
|
|
22
|
+
[Union[int, str], PreprocessResponse], List[ElementInstance]]:
|
|
23
|
+
def numpy_encoder_function(idx: Union[int, str], samples: PreprocessResponse) -> List[ElementInstance]:
|
|
24
|
+
result = encoder_function(idx, samples)
|
|
25
|
+
for res in result:
|
|
26
|
+
res.mask = np.array(res.mask)
|
|
27
|
+
return result
|
|
28
|
+
return numpy_encoder_function
|
|
29
|
+
|
|
20
30
|
|
|
21
31
|
def get_root_traceback(exc_tb: TracebackType) -> TracebackType:
|
|
22
32
|
return_traceback = exc_tb
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
|
2
2
|
code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
|
|
3
3
|
code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
|
-
code_loader/contract/datasetclasses.py,sha256=
|
|
4
|
+
code_loader/contract/datasetclasses.py,sha256=bNqxut_Ifm0upGRfyTl7j65PysS8xHAmuqTqei_B5Zk,8744
|
|
5
5
|
code_loader/contract/enums.py,sha256=GEFkvUMXnCNt-GOoz7NJ9ecQZ2PPDettJNOsxsiM0wk,1622
|
|
6
6
|
code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
|
|
7
7
|
code_loader/contract/mapping.py,sha256=e11h_sprwOyE32PcqgRq9JvyahQrPzwqgkhmbQLKLQY,1165
|
|
8
8
|
code_loader/contract/responsedataclasses.py,sha256=6-5DJkYBdXb3UB1eNidTTPPBIYxMjEoMdYDkp9VhH8o,4223
|
|
9
|
-
code_loader/contract/visualizer_classes.py,sha256=
|
|
9
|
+
code_loader/contract/visualizer_classes.py,sha256=Wz9eItmoRaKEHa3p0aW0Ypxx4_xUmaZyLBznnTuxwi0,15425
|
|
10
10
|
code_loader/default_losses.py,sha256=NoOQym1106bDN5dcIk56Elr7ZG5quUHArqfP5-Nyxyo,1139
|
|
11
11
|
code_loader/default_metrics.py,sha256=v16Mrt2Ze1tXPgfKywGVdRSrkaK4CKLNQztN1UdVqIY,5010
|
|
12
12
|
code_loader/experiment_api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -20,14 +20,14 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
|
|
|
20
20
|
code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
|
|
21
21
|
code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
|
|
22
22
|
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
|
23
|
-
code_loader/inner_leap_binder/leapbinder.py,sha256=
|
|
24
|
-
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=
|
|
25
|
-
code_loader/leaploader.py,sha256=
|
|
26
|
-
code_loader/leaploaderbase.py,sha256=
|
|
27
|
-
code_loader/utils.py,sha256=
|
|
23
|
+
code_loader/inner_leap_binder/leapbinder.py,sha256=eHnjPfJvYQDQsBM55sf63kI-NC2M-lOB4cwxjYHNTkk,32766
|
|
24
|
+
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=hfT2mdr6SB90vNu42ZxPjgRWaCFuqAqnwAzgb7cWvJI,28955
|
|
25
|
+
code_loader/leaploader.py,sha256=K532J4Z8YUjyBpTagJPff4PD0dkuwT9szgBbwBiWwwY,28846
|
|
26
|
+
code_loader/leaploaderbase.py,sha256=tpMVEd97675b_var4hvesjN7EgQzoCbPEayNBut6AvI,4551
|
|
27
|
+
code_loader/utils.py,sha256=_j8b60pimoNAvWMRj7hEkkT6C76qES6cZoBFHpXHMxA,2698
|
|
28
28
|
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
29
29
|
code_loader/visualizers/default_visualizers.py,sha256=669lBpLISLO6my5Qcgn1FLDDeZgHumPf252m4KHY4YM,2555
|
|
30
|
-
code_loader-1.0.
|
|
31
|
-
code_loader-1.0.
|
|
32
|
-
code_loader-1.0.
|
|
33
|
-
code_loader-1.0.
|
|
30
|
+
code_loader-1.0.94.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
|
31
|
+
code_loader-1.0.94.dist-info/METADATA,sha256=l0E_zdPlClM6EIp1kGIu4szoallA_t_hnfyH9qRJz-w,849
|
|
32
|
+
code_loader-1.0.94.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
33
|
+
code_loader-1.0.94.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|