code-loader 1.0.77.dev4__py3-none-any.whl → 1.0.77.dev30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/contract/datasetclasses.py +2 -5
- code_loader/inner_leap_binder/inner_classes.py +34 -0
- code_loader/inner_leap_binder/leapbinder.py +31 -54
- code_loader/inner_leap_binder/leapbinder_decorators.py +15 -11
- code_loader/inner_leap_binder/mapping.py +219 -0
- code_loader/leaploader.py +14 -26
- {code_loader-1.0.77.dev4.dist-info → code_loader-1.0.77.dev30.dist-info}/METADATA +1 -1
- {code_loader-1.0.77.dev4.dist-info → code_loader-1.0.77.dev30.dist-info}/RECORD +10 -8
- {code_loader-1.0.77.dev4.dist-info → code_loader-1.0.77.dev30.dist-info}/LICENSE +0 -0
- {code_loader-1.0.77.dev4.dist-info → code_loader-1.0.77.dev30.dist-info}/WHEEL +0 -0
code_loader/contract/datasetclasses.py
@@ -4,8 +4,7 @@ import re
 import numpy as np
 import numpy.typing as npt
 
-from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue,
-    MetricDirection, DatasetMetadataType
+from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
 from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
     LeapTextMask, LeapImageMask, LeapImageWithBBox, LeapImageWithHeatmap
 
@@ -197,7 +196,6 @@ class GroundTruthHandler(DatasetBaseHandler):
 class MetadataHandler:
     name: str
     function: MetadataSectionCallableInterface
-    metadata_type: Optional[Union[DatasetMetadataType, Dict[str, DatasetMetadataType]]] = None
 
 
 @dataclass
@@ -234,7 +232,6 @@ class DatasetIntegrationSetup:
 class DatasetSample:
     inputs: Dict[str, npt.NDArray[np.float32]]
     gt: Optional[Dict[str, npt.NDArray[np.float32]]]
-    metadata: Dict[str, Union[
-    metadata_is_none: Dict[str, bool]
+    metadata: Dict[str, Union[str, int, bool, float]]
    index: Union[int, str]
    state: DataStateEnum
code_loader/inner_leap_binder/inner_classes.py
@@ -0,0 +1,34 @@
+
+from typing import List, Dict, Union, Callable
+from enum import Enum, IntEnum
+from dataclasses import dataclass, field
+from typing import Union
+
+class FunctionType(IntEnum):
+    gt = 0
+    visualizer = 1
+    input = 2
+    metric = 3
+    loss = 4
+    model_input = 5
+    model_output = 6
+
+@dataclass(frozen=True)
+class TLFunction:
+    FuncType: FunctionType
+    FuncName: str
+
+MappingList = List[Dict[TLFunction, Dict[str, Union[Callable, int]]]]
+
+def leap_input(idx):
+    def dummy():
+        return None
+    dummy.tl_func = TLFunction(FunctionType.model_input, idx)
+    return dummy
+
+def leap_output(idx):
+    def dummy():
+        return None
+    dummy.tl_func = TLFunction(FunctionType.model_output, idx)
+    return dummy
+
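For reference, the new leap_input / leap_output helpers simply return placeholder callables that carry a TLFunction tag pointing at a model input or output. A minimal sketch of what they produce, based on the file above (the index values are illustrative):

    from code_loader.inner_leap_binder.inner_classes import (
        FunctionType, TLFunction, leap_input, leap_output)

    placeholder = leap_output('0')   # dummy callable tagged as a model-output reference
    assert placeholder() is None
    assert placeholder.tl_func == TLFunction(FunctionType.model_output, '0')
    assert leap_input('1').tl_func.FuncType is FunctionType.model_input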
code_loader/inner_leap_binder/leapbinder.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Callable, List, Optional, Dict, Any, Type, Union
+from typing import Callable, List, Optional, Dict, Any, Type, Union, Tuple
 
 import numpy as np
 import numpy.typing as npt
@@ -11,7 +11,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
     CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
     CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
     RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData, SamplePreprocessResponse
-from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
+from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
 from code_loader.contract.responsedataclasses import DatasetTestResultPayload
 from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
 from code_loader.default_losses import loss_name_to_function
@@ -21,6 +21,7 @@ from code_loader.visualizers.default_visualizers import DefaultVisualizer, \
     default_graph_visualizer, \
     default_image_visualizer, default_horizontal_bar_visualizer, default_word_visualizer, \
     default_image_mask_visualizer, default_text_mask_visualizer, default_raw_data_visualizer
+from code_loader.inner_leap_binder.inner_classes import FunctionType, TLFunction, MappingList, leap_output, leap_input
 
 
 class LeapBinder:
@@ -37,6 +38,7 @@ class LeapBinder:
     def __init__(self) -> None:
         self.setup_container = DatasetIntegrationSetup()
         self.cache_container: Dict[str, Any] = {"word_to_index": {}}
+        self._mappings: MappingList = list()
         self._visualizer_names: List[str] = list()
         self._encoder_names: List[str] = list()
         self._extend_with_default_visualizers()
@@ -72,7 +74,8 @@ class LeapBinder:
     def set_visualizer(self, function: VisualizerCallableInterface,
                        name: str,
                        visualizer_type: LeapDataType,
-                       heatmap_visualizer: Optional[Callable[..., npt.NDArray[np.float32]]] = None
+                       heatmap_visualizer: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
+                       connects_to = None) -> None:  # TODO add types
         """
         Set a visualizer for a specific data type.
 
@@ -139,10 +142,11 @@ class LeapBinder:
             raise Exception(
                 f'The return type of function {function.__name__} is invalid. current return type: {return_type}, '
                 f'should be {expected_return_type}')
-
         self.setup_container.visualizers.append(
             VisualizerHandler(VisualizerHandlerData(name, visualizer_type, arg_names), function, heatmap_visualizer))
         self._visualizer_names.append(name)
+        if connects_to is not None:
+            self._mappings.append({TLFunction(FunctionType.visualizer, name): connects_to})
 
     def set_preprocess(self, function: Callable[[], List[PreprocessResponse]]) -> None:
         """
@@ -259,7 +263,8 @@ class LeapBinder:
                           name: str,
                           direction: Optional[
                               Union[MetricDirection, Dict[str, MetricDirection]]] = MetricDirection.Downward,
-                          compute_insights: Union[bool, Dict[str, bool]] = True
+                          compute_insights: Union[bool, Dict[str, bool]] = True,
+                          connects_to = None) -> None:  # TODO TOM add types
         """
         Add a custom metric to the setup.
 
@@ -285,6 +290,8 @@ class LeapBinder:
         arg_names = inspect.getfullargspec(function)[0]
         metric_handler_data = MetricHandlerData(name, arg_names, direction, compute_insights)
         self.setup_container.metrics.append(MetricHandler(metric_handler_data, function))
+        if connects_to is not None:
+            self._mappings.append({TLFunction(FunctionType.metric, name): connects_to})
 
     def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
         """
@@ -333,8 +340,7 @@ class LeapBinder:
 
         self._encoder_names.append(name)
 
-    def set_metadata(self, function: MetadataSectionCallableInterface, name: str
-                     metadata_type: Optional[Union[DatasetMetadataType, Dict[str, DatasetMetadataType]]] = None) -> None:
+    def set_metadata(self, function: MetadataSectionCallableInterface, name: str) -> None:
         """
         Set the metadata handler function. This function is used for measuring and analyzing external variable values per sample, which is recommended for analysis within the Tensorleap platform.
 
@@ -369,7 +375,7 @@ class LeapBinder:
             leap_binder.set_metadata(metadata_handler_index, name='metadata_index')
             leap_binder.set_metadata(metadata_handler_image_mean, name='metadata_image_mean')
         """
-        self.setup_container.metadata.append(MetadataHandler(name, function
+        self.setup_container.metadata.append(MetadataHandler(name, function))
 
     def set_custom_layer(self, custom_layer: Type[Any], name: str, inspect_layer: bool = False,
                          kernel_index: Optional[int] = None, use_custom_latent_space: bool = False) -> None:
@@ -466,55 +472,23 @@ class LeapBinder:
     @staticmethod
     def check_handler(
             preprocess_response: PreprocessResponse, test_result: List[DatasetTestResultPayload],
-            dataset_base_handler: Union[DatasetBaseHandler, MetadataHandler]
+            dataset_base_handler: Union[DatasetBaseHandler, MetadataHandler]) -> List[DatasetTestResultPayload]:
         assert preprocess_response.sample_ids is not None
         raw_result = dataset_base_handler.function(preprocess_response.sample_ids[0], preprocess_response)
         handler_type = 'metadata' if isinstance(dataset_base_handler, MetadataHandler) else None
-        if isinstance(dataset_base_handler, MetadataHandler):
-
-
-
-
-
-
-
-
-
-
-
-
-                    continue
-
-                if dataset_base_handler.metadata_type is None:
-                    raise Exception(f"Metadata {single_metadata_name} is None and no metadata type is provided")
-                elif isinstance(dataset_base_handler.metadata_type, dict):
-                    if single_metadata_name not in dataset_base_handler.metadata_type:
-                        raise Exception(f"Metadata {single_metadata_name} is None and no metadata type is provided")
-                    metadata_type = dataset_base_handler.metadata_type[single_metadata_name]
-                else:
-                    raise Exception(f"Metadata {single_metadata_name} is None and no metadata type is provided")
-
-                result_shape = get_shape(single_metadata_result)
-                metadata_test_result.shape = result_shape
-                metadata_test_result.raw_result = (
-                    single_metadata_result) if single_metadata_result is not None else metadata_type
-                metadata_test_result.handler_type = handler_type
-            test_result = metadata_test_result_payloads
-        else:
-            if raw_result is None:
-                if state != DataStateEnum.training:
-                    return test_result
-
-            if dataset_base_handler.metadata_type is None:
-                raise Exception(f"Metadata {dataset_base_handler.name} is None and no metadata type is provided")
-            elif isinstance(dataset_base_handler.metadata_type, dict):
-                raise Exception(f"Metadata {dataset_base_handler.name} is None and no metadata type is provided")
-            metadata_type = dataset_base_handler.metadata_type
-
-            result_shape = get_shape(raw_result)
-            test_result[0].shape = result_shape
-            test_result[0].raw_result = raw_result if raw_result is not None else metadata_type
-            test_result[0].handler_type = handler_type
+        if isinstance(dataset_base_handler, MetadataHandler) and isinstance(raw_result, dict):
+            metadata_test_result_payloads = [
+                DatasetTestResultPayload(f'{dataset_base_handler.name}_{single_metadata_name}')
+                for single_metadata_name, single_metadata_result in raw_result.items()
+            ]
+            for i, (single_metadata_name, single_metadata_result) in enumerate(raw_result.items()):
+                metadata_test_result = metadata_test_result_payloads[i]
+
+                result_shape = get_shape(single_metadata_result)
+                metadata_test_result.shape = result_shape
+                metadata_test_result.raw_result = single_metadata_result
+                metadata_test_result.handler_type = handler_type
+            test_result = metadata_test_result_payloads
         else:
             result_shape = get_shape(raw_result)
             test_result[0].shape = result_shape
@@ -535,6 +509,9 @@ class LeapBinder:
                 continue
             self.check_handler(preprocess_response, test_result, dataset_base_handler)
 
+    def set_model_inputs(self, connections: Tuple[Callable[..., Any], ...]):
+        self._mappings.append({TLFunction(FunctionType.model_input, "model"): {str(i): (connections[i], ) for i in range(len(connections))}})
+
     def check(self) -> None:
         preprocess_result = self.get_preprocess_result()
         self.check_preprocess(preprocess_result)
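A hedged sketch of the new set_model_inputs hook, based on the method body above; the encoder function below is a stand-in and not part of this diff:

    from code_loader import leap_binder
    from code_loader.inner_leap_binder.inner_classes import FunctionType, TLFunction

    def image_encoder(sample_id, preprocess_response):  # stand-in for a decorated input encoder
        ...

    # Declare which callables feed the model's inputs, in positional order.
    leap_binder.set_model_inputs((image_encoder,))
    # The binder now holds one mapping entry of the form:
    #   {TLFunction(FunctionType.model_input, "model"): {'0': (image_encoder,)}}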
code_loader/inner_leap_binder/leapbinder_decorators.py
@@ -9,15 +9,16 @@ from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs
     CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
     VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
     ConfusionMatrixElement, SamplePreprocessResponse
-from code_loader.contract.enums import MetricDirection, LeapDataType
+from code_loader.contract.enums import MetricDirection, LeapDataType
 from code_loader import leap_binder
 from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
     LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
-
+from code_loader.inner_leap_binder.inner_classes import TLFunction, FunctionType
 
 def tensorleap_custom_metric(name: str,
                              direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward,
-                             compute_insights: Union[bool, Dict[str, bool]] = True
+                             compute_insights: Union[bool, Dict[str, bool]] = True,
+                             connects_to = None):
     def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
                                                   CustomMultipleReturnCallableInterfaceMultiArgs,
                                                   ConfusionMatrixCallableInterfaceMultiArgs]):
@@ -26,7 +27,7 @@ def tensorleap_custom_metric(name: str,
             raise Exception(f'Metric with name {name} already exists. '
                             f'Please choose another')
 
-        leap_binder.add_custom_metric(user_function, name, direction, compute_insights)
+        leap_binder.add_custom_metric(user_function, name, direction, compute_insights, connects_to=connects_to)
 
         def _validate_input_args(*args, **kwargs) -> None:
             for i, arg in enumerate(args):
@@ -106,14 +107,15 @@ def tensorleap_custom_metric(name: str,
 
 
 def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
-                                 heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None
+                                 heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
+                                 connects_to = None):  # TODO TOM ADD TYPES
     def decorating_function(user_function: VisualizerCallableInterface):
         for viz_handler in leap_binder.setup_container.visualizers:
             if viz_handler.visualizer_handler_data.name == name:
                 raise Exception(f'Visualizer with name {name} already exists. '
                                 f'Please choose another')
 
-        leap_binder.set_visualizer(user_function, name, visualizer_type, heatmap_function)
+        leap_binder.set_visualizer(user_function, name, visualizer_type, heatmap_function, connects_to)
 
         def _validate_input_args(*args, **kwargs):
             for i, arg in enumerate(args):
@@ -154,21 +156,21 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
             result = user_function(*args, **kwargs)
             _validate_result(result)
             return result
+        inner.tl_func = TLFunction(FunctionType.visualizer, name)
 
         return inner
 
     return decorating_function
 
 
-def tensorleap_metadata(
-        name: str, metadata_type: Optional[Union[DatasetMetadataType, Dict[str, DatasetMetadataType]]] = None):
+def tensorleap_metadata(name: str):
     def decorating_function(user_function: MetadataSectionCallableInterface):
         for metadata_handler in leap_binder.setup_container.metadata:
             if metadata_handler.name == name:
                 raise Exception(f'Metadata with name {name} already exists. '
                                 f'Please choose another')
 
-        leap_binder.set_metadata(user_function, name
+        leap_binder.set_metadata(user_function, name)
 
         def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
             assert isinstance(sample_id, (int, str)), \
@@ -183,7 +185,7 @@ def tensorleap_metadata(
                 f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
 
         def _validate_result(result):
-            supported_result_types = (
+            supported_result_types = (int, str, bool, float, dict, np.floating,
                                       np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
             assert isinstance(result, supported_result_types), \
                 (f'tensorleap_metadata validation failed: '
@@ -303,6 +305,7 @@ def tensorleap_input_encoder(name: str, channel_dim=-1):
             result = user_function(sample_id, preprocess_response)
             _validate_result(result)
             return result
+        inner.tl_func = TLFunction(FunctionType.input, name)
 
         return inner
 
@@ -343,7 +346,7 @@ def tensorleap_gt_encoder(name: str):
             result = user_function(sample_id, preprocess_response)
             _validate_result(result)
             return result
-
+        inner.tl_func = TLFunction(FunctionType.gt, name)
         return inner
 
     return decorating_function
@@ -395,6 +398,7 @@ def tensorleap_custom_loss(name: str):
             _validate_result(result)
             return result
 
+        inner.tl_func = TLFunction(FunctionType.loss, name)
        return inner
 
     return decorating_function
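A hedged sketch of how connects_to on these decorators might be wired up, based on the signatures above. The encoder and visualizer names are illustrative, and the exact shape passed to connects_to (a dict of user-unique node name to a tuple of source callables) is inferred from how create_connect_nodes iterates the stored mappings, not confirmed by this diff:

    import numpy as np
    from code_loader.contract.enums import LeapDataType
    from code_loader.contract.visualizer_classes import LeapImage
    from code_loader.inner_leap_binder.inner_classes import leap_output
    from code_loader.inner_leap_binder.leapbinder_decorators import (
        tensorleap_input_encoder, tensorleap_custom_visualizer)

    @tensorleap_input_encoder('image')
    def image_encoder(sample_id, preprocess_response):
        return np.zeros((32, 32, 3), dtype=np.float32)

    # Each key becomes the node's user_unique_name; the tuple lists the sources that
    # feed the visualizer's arguments in order (decorated encoders, or leap_input /
    # leap_output placeholders for model tensors).
    @tensorleap_custom_visualizer('image_vis', LeapDataType.Image,
                                  connects_to={'image_vis': (image_encoder, leap_output('0'))})
    def image_visualizer(image, prediction):
        return LeapImage(image.astype(np.uint8))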
code_loader/inner_leap_binder/mapping.py
@@ -0,0 +1,219 @@
+from typing import Optional, Callable, Any, Dict, List, Set, OrderedDict
+from copy import deepcopy
+import numpy as np
+from collections import OrderedDict as od
+from code_loader.inner_leap_binder.inner_classes import FunctionType, TLFunction, MappingList
+from code_loader.contract.datasetclasses import DatasetIntegrationSetup, VisualizerHandler, MetricHandler
+from code_loader.contract.datasetclasses import DatasetBaseHandler, GroundTruthHandler, InputHandler
+
+
+class LeapMapping:
+
+    def __init__(self, setup_container: DatasetIntegrationSetup, mapping: MappingList):
+        self.container : DatasetIntegrationSetup = setup_container
+        self.mapping: MappingList = mapping
+        self.max_id = 20000
+        self.current_id = self.max_id
+        self.mapping_json = []
+        self.handle_dict = {
+            FunctionType.gt : self.container.ground_truths,
+            FunctionType.metric : self.container.metrics,
+            FunctionType.visualizer: self.container.visualizers,
+            FunctionType.input: self.container.inputs,
+        }
+        self.handle_to_op_type = {
+            GroundTruthHandler: "GroundTruth",
+            InputHandler: "Input",
+            VisualizerHandler: "Visualizer"
+
+        }
+        self.function_to_json : Dict[TLFunction, List[Dict[str, Any]]] = {}
+
+    def create_mapping_list(self, connections_list):  # TODO add types
+        pass
+
+    def create_connect_nodes(self):
+        for element in self.mapping:
+            for sink in element.keys():
+                for node_name, connected_func_list in element[sink].items():
+                    tl_connected_functions = [connected_func.tl_func for connected_func in connected_func_list
+                                              if not isinstance(connected_func, int)]
+                    for i, source in enumerate(tl_connected_functions):
+                        print(sink, source, node_name)
+                        sink_node = self.create_node(sink, name=node_name)
+                        self.add_connections(sink_node, source, input_idx=i)
+        print(1)
+        import yaml
+        with open('temp.yaml', 'w') as yf:
+            yaml.safe_dump(({'decoder': self.mapping_json}), yf, default_flow_style=False, sort_keys=False)
+
+    def create_partial_mapping(self):  # TODO add types
+        referenced = self._get_referenced_sinks()
+        for func in referenced:
+            kwargs = {}
+            if func.FuncType == FunctionType.model_output:
+                kwargs = {'name': func.FuncName}
+            self.create_node(func, **kwargs)
+        self.create_connect_nodes()
+
+    def _get_referenced_sinks(self) -> Set[TLFunction]:
+        referenced : Set[TLFunction] = set()
+        references_types = {FunctionType.input, FunctionType.gt, FunctionType.model_output}
+        for element in self.mapping:
+            for func in element.keys():
+                for node_name, connected_func_list in element[func].items():
+                    tl_connected_functions = [connected_func.tl_func for connected_func in connected_func_list
+                                              if not isinstance(connected_func, int)]
+                    for connected_func in tl_connected_functions:
+                        if connected_func.FuncType in references_types:
+                            referenced.add(connected_func)
+        return referenced
+
+    def create_gt_input(self, handle: DatasetBaseHandler, **kwargs) -> Dict[str, Any]:
+        op_type = self.handle_to_op_type[type(handle)]
+        json_node = {'operation': op_type,
+                     'data': {'type': op_type, 'output_name': handle.name},
+                     'id': f'{str(self.current_id)}',
+                     'inputs': {},
+                     'outputs': {handle.name: []}}
+        return json_node
+
+    def create_visualize(self, visualize_handle: VisualizerHandler ,**kwargs) -> Dict[str, Any]:
+        op_type = "Visualizer"
+        json_node = {'operation': op_type,
+                     'data': {'type': op_type,
+                              'name': visualize_handle.name,
+                              'visualizer_name': visualize_handle.name,
+                              'visualizer_type': visualize_handle.type.name,
+                              'arg_names': deepcopy(visualize_handle.arg_names),
+                              'user_unique_name': kwargs['name']},
+                     'id': f'{str(self.current_id)}',
+                     'inputs': {arg: [] for arg in visualize_handle.arg_names},
+                     'outputs': {} }
+        return json_node
+
+    def create_metric(self, metric_handle: MetricHandler, **kwargs) -> Dict[str, Any]:
+        op_type = "Metric"
+        json_node = {'operation': op_type,
+                     'data': {'type': op_type,
+                              'name': metric_handle.name,
+                              'metric_name': metric_handle.name,
+                              'arg_names': deepcopy(metric_handle.arg_names),
+                              'user_unique_name': kwargs['name']},
+                     'id': f'{str(self.current_id)}',
+                     'inputs': {arg: [] for arg in metric_handle.arg_names},
+                     'outputs': {} }
+        return json_node
+
+
+    def create_temp_model_node(self, node_type: TLFunction, **kwargs) -> Dict[str, Any]:
+        default_name = self.get_default_name(node_type.FuncType, kwargs['name'])
+        op_type = "PlaceHolder"
+        json_node = {'operation': op_type,
+                     'data': {'type': op_type,
+                              'name': default_name,
+                              'arg_names': [default_name],
+                              'user_unique_name': default_name},
+                     'id': f'{default_name}',
+                     'inputs': {},
+                     'outputs': {} }
+        if node_type.FuncType == FunctionType.model_input:
+            json_node['inputs'] = {default_name: []}
+        else:
+            json_node['outputs'] = {default_name: []}
+        return json_node
+
+    def _find_handle(self, node: TLFunction) -> DatasetBaseHandler:
+        handle_list = self.handle_dict[node.FuncType]
+        chosen_handle: Optional[DatasetBaseHandler] = None
+        for handle in handle_list:
+            if handle.name == node.FuncName:
+                chosen_handle = handle
+        assert chosen_handle is not None
+        return chosen_handle
+
+    def _find_json_node(self, node: TLFunction, **kwargs) -> Dict[str, Any]:
+        candidates = self.function_to_json[node]
+        if len(candidates) == 1:
+            return candidates[0]
+        elif len(candidates) > 1:
+            raise Exception("More than 1 candidate - solve")
+        else:
+            raise Exception(f"No json node creted for node {node}")
+
+    def _get_node_id(self, node: TLFunction, **kwargs):
+        node = self._find_json_node(node, **kwargs)
+        return node['id']
+
+    def create_node(self, node_type: TLFunction, **kwargs) -> Dict[str, Any]:
+        if node_type.FuncType in [FunctionType.model_output, FunctionType.model_input]:
+            json_node = self.create_temp_model_node(node_type, **kwargs)
+        else:
+            node_creator_map: Dict[FunctionType, Callable[..., Any]] = {FunctionType.gt : self.create_gt_input,
+                                                                        FunctionType.input: self.create_gt_input,
+                                                                        FunctionType.visualizer: self.create_visualize,
+                                                                        FunctionType.metric: self.create_metric}
+            handle = self._find_handle(node_type)
+            json_node = node_creator_map[node_type.FuncType](handle, **kwargs)
+        self.current_id += 1
+        self.function_to_json[node_type] = self.function_to_json.get(node_type, []) + [json_node]
+        self.mapping_json.append(json_node)
+        return json_node
+
+    def get_default_name(self, f_type: FunctionType, name: str):
+        return f"{str(f_type)}_{name}"
+
+    def _get_node_output_name(self, node: TLFunction):
+        if node.FuncType not in {FunctionType.model_output, FunctionType.model_input}:
+            from_handle = self._find_handle(node)
+            output_name = from_handle.name
+        else:
+            output_name = self.get_default_name(node.FuncType, node.FuncName)
+        return output_name
+
+
+    def add_connections(self, to_json: Dict[str, Any], from_node: TLFunction, input_idx: int):
+        try:
+            input_name = to_json['data']['arg_names'][input_idx]
+        except Exception as e:
+            raise Exception(f"arg names mismatch when adding input for {to_json['operation']}")
+        output_key = self._get_node_output_name(from_node)
+        from_json = self._find_json_node(from_node)
+        operation = from_json['operation']
+        from_id = from_json['id']
+        # Add Input
+        to_json['inputs'][input_name].append({'outputKey': output_key,
+                                              'operation': operation,
+                                              'id': from_id
+                                              })
+        # Add Output
+        from_json['outputs'][output_key].append({'inputKey': input_name,
+                                                 'operation': to_json['operation'],
+                                                 'id': to_json['id']
+                                                 })
+
+    def add_output(self, input_node, output_node):
+        pass
+
+def create_mapping(model_parser_map, connections_list):
+
+    referenced_gt = []
+    referenced_inputs = []
+    for element in connections_list:
+        for func in element.keys():
+            for node_name, connected_func_list in element[func].items():
+                tl_connected_functions = [connected_func.tl_func if not isinstance(connected_func, int) else connected_func
+                                          for connected_func in connected_func_list]
+
+
+
+
+
+    # Check Max Node ID
+    # Get Preds Node ID
+
+    # TODO Create Inputs (if referenced)
+    # TODO Create Gts (if referenced)
+    # TODO Create Visualizers
+    # TODO Create Metrics
+    # TODO Create losses
code_loader/leaploader.py
@@ -6,7 +6,7 @@ import sys
 from contextlib import redirect_stdout
 from functools import lru_cache
 from pathlib import Path
-from typing import Dict, List, Iterable, Union, Any, Type, Optional, Callable
+from typing import Dict, List, Iterable, Union, Any, Type, Optional, Callable
 
 import numpy as np
 import numpy.typing as npt
@@ -46,7 +46,7 @@ class LeapLoader(LeapLoaderBase):
 
     def evaluate_module(self) -> None:
         def append_path_recursively(full_path: str) -> None:
-            if
+            if self.code_path not in full_path or full_path == '/':
                 return
 
             parent_path = str(Path(full_path).parent)
@@ -140,11 +140,9 @@ class LeapLoader(LeapLoaderBase):
         if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
             self._preprocess_result(update_unlabeled_preprocess=True)
 
-        metadata, metadata_is_none = self._get_metadata(state, sample_id)
         sample = DatasetSample(inputs=self._get_inputs(state, sample_id),
                                gt=None if state == DataStateEnum.unlabeled else self._get_gt(state, sample_id),
-                               metadata=
-                               metadata_is_none=metadata_is_none,
+                               metadata=self._get_metadata(state, sample_id),
                                index=sample_id,
                                state=state)
         return sample
@@ -212,7 +210,7 @@ class LeapLoader(LeapLoaderBase):
             state_name = state.name
             try:
                 test_result = global_leap_binder.check_handler(
-                    preprocess_response, test_result, dataset_base_handler
+                    preprocess_response, test_result, dataset_base_handler)
             except Exception as e:
                 line_number, file_name, stacktrace = get_root_exception_file_and_line_number()
                 test_result[0].display[
@@ -341,11 +339,6 @@ class LeapLoader(LeapLoaderBase):
                 continue
             if hasattr(handler_test_payload.raw_result, 'tolist'):
                 handler_test_payload.raw_result = handler_test_payload.raw_result.tolist()
-            if isinstance(handler_test_payload.raw_result, DatasetMetadataType):
-                dataset_metadata_type = handler_test_payload.raw_result
-                metadata_instances.append(DatasetMetadataInstance(name=handler_test_payload.name,
-                                                                  type=dataset_metadata_type))
-                continue
             metadata_type = type(handler_test_payload.raw_result)
             if metadata_type == int or isinstance(handler_test_payload.raw_result,
                                                   (np.unsignedinteger, np.signedinteger)):
@@ -445,7 +438,7 @@ class LeapLoader(LeapLoaderBase):
         }
         return metadata_name_to_type
 
-    def _convert_metadata_to_correct_type(self, metadata_name: str, value: Any) ->
+    def _convert_metadata_to_correct_type(self, metadata_name: str, value: Any) -> Any:
         metadata_name_to_type = self._metadata_name_to_type()
         metadata_type_to_python_type = {
             DatasetMetadataType.float: float,
@@ -454,26 +447,22 @@ class LeapLoader(LeapLoaderBase):
             DatasetMetadataType.int: int
         }
         metadata_type_to_default_value = {
-            DatasetMetadataType.float: -1
-            DatasetMetadataType.string:
+            DatasetMetadataType.float: -1,
+            DatasetMetadataType.string: "",
             DatasetMetadataType.boolean: False,
             DatasetMetadataType.int: -1
         }
 
         try:
-            is_none = False
-            if value is None:
-                raise ValueError()
             converted_value = metadata_type_to_python_type[metadata_name_to_type[metadata_name]](value)
         except ValueError:
-            is_none = True
             converted_value = metadata_type_to_default_value[metadata_name_to_type[metadata_name]]
 
-        return converted_value
+        return converted_value
 
-    def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) ->
+    def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[
+            str, Union[str, int, bool, float]]:
         result_agg = {}
-        is_none = {}
         preprocess_result = self._preprocess_result()
         preprocess_state = preprocess_result[state]
         for handler in global_leap_binder.setup_container.metadata:
@@ -481,14 +470,13 @@ class LeapLoader(LeapLoaderBase):
             if isinstance(handler_result, dict):
                 for single_metadata_name, single_metadata_result in handler_result.items():
                     handler_name = f'{handler.name}_{single_metadata_name}'
-                    result_agg[handler_name]
-
+                    result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name,
+                                                                                      single_metadata_result)
             else:
                 handler_name = handler.name
-                result_agg[handler_name]
-                    handler_name, handler_result)
+                result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name, handler_result)
 
-        return result_agg
+        return result_agg
 
     @lru_cache()
     def get_sample_id_type(self) -> Type:
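The reworked _get_metadata flattens dict-valued handler results into per-key entries named '{handler.name}_{key}'. A standalone illustration of that naming scheme (the handler name and values are made up):

    handler_name = 'stats'
    handler_result = {'mean': 0.5, 'std': 0.1}
    flattened = {f'{handler_name}_{key}': value for key, value in handler_result.items()}
    assert flattened == {'stats_mean': 0.5, 'stats_std': 0.1}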
{code_loader-1.0.77.dev4.dist-info → code_loader-1.0.77.dev30.dist-info}/RECORD
@@ -1,7 +1,7 @@
 LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
 code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
 code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-code_loader/contract/datasetclasses.py,sha256=
+code_loader/contract/datasetclasses.py,sha256=t1pvW68dVIbYB0bllqxHXO6jOgHCJI5NGn9rSrYC6Vw,7722
 code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
 code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
 code_loader/contract/responsedataclasses.py,sha256=RSx9m_R3LawhK5o1nAcO3hfp2F9oJYtxZr_bpP3bTmw,4005
@@ -19,14 +19,16 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
 code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
 code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
 code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
-code_loader/inner_leap_binder/
-code_loader/inner_leap_binder/
-code_loader/
+code_loader/inner_leap_binder/inner_classes.py,sha256=DHaIaRUdplQPCaukKLG8RboW9KIkdO6kekeGw9PCmT4,726
+code_loader/inner_leap_binder/leapbinder.py,sha256=jkSsoJwnNEmn8vlibjT4GczNuj4NAIM3_v-KZCLjPuw,27832
+code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=1oUrEbcY3GZkkaxVAdzlV4Vl3Y5XQ4aGwPWvSUb2hWM,22388
+code_loader/inner_leap_binder/mapping.py,sha256=WD3p_dNMp8SQckUq62Yl0bpIBxoaZAkJszCl4IzWSms,10198
+code_loader/leaploader.py,sha256=KkAPKyzZxZYI135_67q63ICVVERu7WH60liB0pboA-c,24983
 code_loader/leaploaderbase.py,sha256=VH0vddRmkqLtcDlYPCO7hfz1_VbKo43lUdHDAbd4iJc,4198
 code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
 code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 code_loader/visualizers/default_visualizers.py,sha256=Ffx5VHVOe5ujBOsjBSxN_aIEVwFSQ6gbhTMG5aUS-po,2305
-code_loader-1.0.77.
-code_loader-1.0.77.
-code_loader-1.0.77.
-code_loader-1.0.77.
+code_loader-1.0.77.dev30.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
+code_loader-1.0.77.dev30.dist-info/METADATA,sha256=ntsmieF-9tHybrMeZ_YqnkSvI5k9Pp3eUKi0uO-GZ0w,855
+code_loader-1.0.77.dev30.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+code_loader-1.0.77.dev30.dist-info/RECORD,,
File without changes
|
File without changes
|