code-loader 1.0.71__py3-none-any.whl → 1.0.72__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/contract/datasetclasses.py +3 -1
- code_loader/contract/visualizer_classes.py +3 -2
- code_loader/inner_leap_binder/leapbinder.py +16 -9
- code_loader/inner_leap_binder/leapbinder_decorators.py +25 -12
- {code_loader-1.0.71.dist-info → code_loader-1.0.72.dist-info}/METADATA +1 -1
- {code_loader-1.0.71.dist-info → code_loader-1.0.72.dist-info}/RECORD +8 -8
- {code_loader-1.0.71.dist-info → code_loader-1.0.72.dist-info}/WHEEL +1 -1
- {code_loader-1.0.71.dist-info → code_loader-1.0.72.dist-info}/LICENSE +0 -0
@@ -132,7 +132,8 @@ class CustomLossHandler:
|
|
132
132
|
class MetricHandlerData:
|
133
133
|
name: str
|
134
134
|
arg_names: List[str]
|
135
|
-
direction:
|
135
|
+
direction: Union[None, MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward
|
136
|
+
compute_insights: Union[bool, Dict[str, bool]] = True
|
136
137
|
|
137
138
|
|
138
139
|
@dataclass
|
@@ -171,6 +172,7 @@ class InputHandler(DatasetBaseHandler):
|
|
171
172
|
shape: Optional[List[int]] = None
|
172
173
|
channel_dim: Optional[int] = -1
|
173
174
|
|
175
|
+
|
174
176
|
@dataclass
|
175
177
|
class GroundTruthHandler(DatasetBaseHandler):
|
176
178
|
shape: Optional[List[int]] = None
|
@@ -170,12 +170,13 @@ class LeapHorizontalBar:
|
|
170
170
|
Example:
|
171
171
|
body_data = np.random.rand(5).astype(np.float32)
|
172
172
|
labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
|
173
|
-
|
173
|
+
gt_data = np.array([0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32)
|
174
|
+
leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels, gt=gt_data)
|
174
175
|
"""
|
175
176
|
body: npt.NDArray[np.float32]
|
176
177
|
labels: List[str]
|
177
|
-
type: LeapDataType = LeapDataType.HorizontalBar
|
178
178
|
gt: Optional[npt.NDArray[np.float32]] = None
|
179
|
+
type: LeapDataType = LeapDataType.HorizontalBar
|
179
180
|
|
180
181
|
|
181
182
|
def __post_init__(self) -> None:
|
@@ -33,6 +33,7 @@ class LeapBinder:
|
|
33
33
|
setup_container (DatasetIntegrationSetup): Container to hold setup configurations.
|
34
34
|
cache_container (Dict[str, Any]): Cache container to store intermediate data.
|
35
35
|
"""
|
36
|
+
|
36
37
|
def __init__(self) -> None:
|
37
38
|
self.setup_container = DatasetIntegrationSetup()
|
38
39
|
self.cache_container: Dict[str, Any] = {"word_to_index": {}}
|
@@ -239,23 +240,31 @@ class LeapBinder:
|
|
239
240
|
leap_binder.add_custom_loss(custom_loss_function, name='custom_loss')
|
240
241
|
"""
|
241
242
|
arg_names = inspect.getfullargspec(function)[0]
|
242
|
-
self.setup_container.custom_loss_handlers.append(
|
243
|
+
self.setup_container.custom_loss_handlers.append(
|
244
|
+
CustomLossHandler(CustomLossHandlerData(name, arg_names), function))
|
243
245
|
|
244
246
|
def add_custom_metric(self,
|
245
247
|
function: Union[CustomCallableInterfaceMultiArgs,
|
246
248
|
CustomMultipleReturnCallableInterfaceMultiArgs,
|
247
249
|
ConfusionMatrixCallableInterfaceMultiArgs],
|
248
250
|
name: str,
|
249
|
-
direction: Optional[
|
251
|
+
direction: Optional[
|
252
|
+
Union[MetricDirection, Dict[str, MetricDirection]]] = MetricDirection.Downward,
|
253
|
+
compute_insights: Union[bool, Dict[str, bool]] = True) -> None:
|
250
254
|
"""
|
251
255
|
Add a custom metric to the setup.
|
252
256
|
|
253
257
|
Args:
|
254
258
|
function (Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]): The custom metric function.
|
255
259
|
name (str): The name of the custom metric.
|
256
|
-
direction (Optional[MetricDirection]): The direction of the metric, either
|
260
|
+
direction (Optional[Union[MetricDirection, Dict[str, MetricDirection]]]): The direction of the metric, either
|
261
|
+
MetricDirection.Upward or MetricDirection.Downward, in case custom metric return a dictionary of metrics we can
|
262
|
+
supply a dictionary of directions correspondingly.
|
257
263
|
- MetricDirection.Upward: Indicates that higher values of the metric are better and should be maximized.
|
258
264
|
- MetricDirection.Downward: Indicates that lower values of the metric are better and should be minimized.
|
265
|
+
compute_insights (Union[bool, Dict[str, bool]]): Whether to compute insights or not. in case custom metric
|
266
|
+
return a dictionary of metrics we can supply a dictionary of values correspondingly
|
267
|
+
|
259
268
|
|
260
269
|
|
261
270
|
Example:
|
@@ -265,7 +274,8 @@ class LeapBinder:
|
|
265
274
|
leap_binder.add_custom_metric(custom_metric_function, name='custom_metric', direction=MetricDirection.Downward)
|
266
275
|
"""
|
267
276
|
arg_names = inspect.getfullargspec(function)[0]
|
268
|
-
|
277
|
+
metric_handler_data = MetricHandlerData(name, arg_names, direction, compute_insights)
|
278
|
+
self.setup_container.metrics.append(MetricHandler(metric_handler_data, function))
|
269
279
|
|
270
280
|
def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
|
271
281
|
"""
|
@@ -377,7 +387,8 @@ class LeapBinder:
|
|
377
387
|
custom_layer.kernel_index = kernel_index
|
378
388
|
|
379
389
|
if use_custom_latent_space and not hasattr(custom_layer, custom_latent_space_attribute):
|
380
|
-
raise Exception(
|
390
|
+
raise Exception(
|
391
|
+
f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
|
381
392
|
|
382
393
|
init_args = inspect.getfullargspec(custom_layer.__init__)[0][1:]
|
383
394
|
call_args = inspect.getfullargspec(custom_layer.call)[0][1:]
|
@@ -490,7 +501,3 @@ class LeapBinder:
|
|
490
501
|
|
491
502
|
def set_batch_size_to_validate(self, batch_size: int) -> None:
|
492
503
|
self.batch_size_to_validate = batch_size
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
@@ -1,6 +1,6 @@
|
|
1
1
|
# mypy: ignore-errors
|
2
2
|
|
3
|
-
from typing import Optional, Union, Callable, List
|
3
|
+
from typing import Optional, Union, Callable, List, Dict
|
4
4
|
|
5
5
|
import numpy as np
|
6
6
|
import numpy.typing as npt
|
@@ -15,18 +15,18 @@ from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, Le
|
|
15
15
|
LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
|
16
16
|
|
17
17
|
|
18
|
-
def tensorleap_custom_metric(name: str,
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
):
|
18
|
+
def tensorleap_custom_metric(name: str,
|
19
|
+
direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward,
|
20
|
+
compute_insights: Union[bool, Dict[str, bool]] = True):
|
21
|
+
def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
|
22
|
+
CustomMultipleReturnCallableInterfaceMultiArgs,
|
23
|
+
ConfusionMatrixCallableInterfaceMultiArgs]):
|
24
24
|
for metric_handler in leap_binder.setup_container.metrics:
|
25
25
|
if metric_handler.metric_handler_data.name == name:
|
26
26
|
raise Exception(f'Metric with name {name} already exists. '
|
27
27
|
f'Please choose another')
|
28
28
|
|
29
|
-
leap_binder.add_custom_metric(user_function, name, direction)
|
29
|
+
leap_binder.add_custom_metric(user_function, name, direction, compute_insights)
|
30
30
|
|
31
31
|
def _validate_input_args(*args, **kwargs) -> None:
|
32
32
|
for i, arg in enumerate(args):
|
@@ -76,6 +76,19 @@ def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = M
|
|
76
76
|
(f'tensorleap_custom_metric validation failed: '
|
77
77
|
f'Keys in the return dict should be of type str. Got {type(key)}.')
|
78
78
|
_validate_single_metric(value)
|
79
|
+
|
80
|
+
if isinstance(direction, dict):
|
81
|
+
for direction_key in direction:
|
82
|
+
assert direction_key in result, \
|
83
|
+
(f'tensorleap_custom_metric validation failed: '
|
84
|
+
f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
|
85
|
+
|
86
|
+
if isinstance(compute_insights, dict):
|
87
|
+
for ci_key in compute_insights:
|
88
|
+
assert ci_key in result, \
|
89
|
+
(f'tensorleap_custom_metric validation failed: '
|
90
|
+
f'Keys in the compute_insights mapping should be part of result keys. Got key {ci_key}.')
|
91
|
+
|
79
92
|
else:
|
80
93
|
_validate_single_metric(result)
|
81
94
|
|
@@ -356,15 +369,15 @@ def tensorleap_custom_loss(name: str):
|
|
356
369
|
f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
|
357
370
|
else:
|
358
371
|
assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
|
359
|
-
|
372
|
+
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
360
373
|
for _arg_name, arg in kwargs.items():
|
361
374
|
if isinstance(arg, list):
|
362
375
|
for y, elem in enumerate(arg):
|
363
|
-
assert isinstance(elem,valid_types), (f'tensorleap_custom_loss validation failed: '
|
364
|
-
|
376
|
+
assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
|
377
|
+
f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
|
365
378
|
else:
|
366
379
|
assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
|
367
|
-
|
380
|
+
f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
|
368
381
|
|
369
382
|
def _validate_result(result):
|
370
383
|
assert isinstance(result, valid_types), \
|
@@ -1,11 +1,11 @@
|
|
1
1
|
LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
2
2
|
code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
|
3
3
|
code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
code_loader/contract/datasetclasses.py,sha256=
|
4
|
+
code_loader/contract/datasetclasses.py,sha256=vnxcit0LKNdXAMQi4EhjDiUpwsavAw2YMcEwNwmCsJo,7063
|
5
5
|
code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
|
6
6
|
code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
|
7
7
|
code_loader/contract/responsedataclasses.py,sha256=RSx9m_R3LawhK5o1nAcO3hfp2F9oJYtxZr_bpP3bTmw,4005
|
8
|
-
code_loader/contract/visualizer_classes.py,sha256=
|
8
|
+
code_loader/contract/visualizer_classes.py,sha256=m31lg2P2QJs3Reqr6-N1AlVhH3RxPr772Jw3LuIVCVM,14177
|
9
9
|
code_loader/default_losses.py,sha256=NoOQym1106bDN5dcIk56Elr7ZG5quUHArqfP5-Nyxyo,1139
|
10
10
|
code_loader/default_metrics.py,sha256=v16Mrt2Ze1tXPgfKywGVdRSrkaK4CKLNQztN1UdVqIY,5010
|
11
11
|
code_loader/experiment_api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -19,14 +19,14 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
|
|
19
19
|
code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
|
20
20
|
code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
|
21
21
|
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
22
|
-
code_loader/inner_leap_binder/leapbinder.py,sha256=
|
23
|
-
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=
|
22
|
+
code_loader/inner_leap_binder/leapbinder.py,sha256=qBpN_hNANcXi6ilPBauBBWSmMt3tl4Ha5JkLe1dPzjE,26571
|
23
|
+
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=B-XSw4xYF39kMPnMTRNKMYFg09whnfl7VSbcx195VG8,21626
|
24
24
|
code_loader/leaploader.py,sha256=GWlpvgSsCWevP2BwwFBKTImQeDgHAQg1lMU9bqFMwRw,22315
|
25
25
|
code_loader/leaploaderbase.py,sha256=aHlqWDZRacIdBefeB9goYVnpApaNN2FT24uPIWKkCeQ,3090
|
26
26
|
code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
|
27
27
|
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
code_loader/visualizers/default_visualizers.py,sha256=Ffx5VHVOe5ujBOsjBSxN_aIEVwFSQ6gbhTMG5aUS-po,2305
|
29
|
-
code_loader-1.0.
|
30
|
-
code_loader-1.0.
|
31
|
-
code_loader-1.0.
|
32
|
-
code_loader-1.0.
|
29
|
+
code_loader-1.0.72.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
30
|
+
code_loader-1.0.72.dist-info/METADATA,sha256=YRI0I_t6gfDnauY5w93AZSY1aAILsGSJujWfEFxErY4,849
|
31
|
+
code_loader-1.0.72.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
32
|
+
code_loader-1.0.72.dist-info/RECORD,,
|
File without changes
|