code-loader 1.0.49.dev10__py3-none-any.whl → 1.0.49.dev100__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,5 @@
1
1
  import inspect
2
+ import sys
2
3
  from typing import Callable, List, Optional, Dict, Any, Type, Union
3
4
 
4
5
  import numpy as np
@@ -37,6 +38,8 @@ class LeapBinder:
37
38
  self._encoder_names: List[str] = list()
38
39
  self._extend_with_default_visualizers()
39
40
 
41
+ self.batch_size_to_validate: Optional[int] = None
42
+
40
43
  def _extend_with_default_visualizers(self) -> None:
41
44
  self.set_visualizer(function=default_image_visualizer, name=DefaultVisualizer.Image.value,
42
45
  visualizer_type=LeapDataType.Image)
@@ -472,4 +475,19 @@ class LeapBinder:
472
475
  self.check_handlers(preprocess_result)
473
476
  print("Successful!")
474
477
 
478
def set_batch_size_to_validate(self, batch_size: int):
    """Remember the batch size that metric-validation code should compare against.

    The value is stored on ``self.batch_size_to_validate``. The
    ``tensorleap_custom_metric`` validators read ``leap_binder.batch_size_to_validate``
    and only perform the batch-size checks when that value is truthy.

    :param batch_size: expected length of the first dimension of metric inputs/outputs.
    """
    setattr(self, 'batch_size_to_validate', batch_size)
480
+
481
@staticmethod
def init(module=None):
    """Call every function in ``module`` that was produced by ``tensorleap_custom_metric``.

    Functions are recognised by their qualified name containing
    ``'tensorleap_custom_metric'``. The wrapper returned by that decorator is a nested
    function defined inside it, so its ``__qualname__`` carries that prefix. This is
    equivalent to the previous ``str(func)`` check, because a plain function's repr is
    ``<function QUALNAME at 0x...>``.

    Each match is called with no arguments. Any ``Exception`` it raises is ignored
    (best effort). ``KeyboardInterrupt``/``SystemExit`` are no longer swallowed, as
    they were by the previous bare ``except``.

    :param module: module object to scan. Defaults to this code's own module
        (``sys.modules[__name__]``), which is what was always scanned before.
    """
    target = sys.modules[__name__] if module is None else module
    for _func_name, func in inspect.getmembers(target, inspect.isfunction):
        if 'tensorleap_custom_metric' not in func.__qualname__:
            continue
        try:
            func()
        except Exception:
            # NOTE(review): a zero-argument call fails (e.g. TypeError) for metrics
            # that require arguments; failures are deliberately ignored here.
            continue
490
+
491
+
492
+
475
493
 
@@ -0,0 +1,56 @@
1
+ from typing import Optional, Union
2
+
3
+ import numpy as np
4
+
5
+ from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
6
+ CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs
7
+ from code_loader.contract.enums import MetricDirection
8
+ from code_loader import leap_binder
9
+
10
+
11
def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = MetricDirection.Downward):
    """Register a custom metric with ``leap_binder`` and wrap it with input/output validation.

    The undecorated ``user_function`` is passed to ``leap_binder.add_custom_metric`` at
    decoration time. The returned wrapper checks, on every call, that:

    * every positional and keyword argument is a ``np.ndarray``;
    * the result is a 1-D ``np.ndarray``;
    * when ``leap_binder.batch_size_to_validate`` is truthy, each argument's first
      dimension and the result's length equal it.

    Violations raise ``AssertionError``. They are raised explicitly rather than with
    ``assert`` statements, so the checks are not stripped under ``python -O``.

    :param name: name under which the metric is registered.
    :param direction: metric direction forwarded to ``add_custom_metric``.
    :return: a decorator that registers the function and returns the validating wrapper.
    """
    def decorating_function(
            user_function: Union[CustomCallableInterfaceMultiArgs,
                                 CustomMultipleReturnCallableInterfaceMultiArgs,
                                 ConfusionMatrixCallableInterfaceMultiArgs]
    ):
        # Registration happens once, when the decorator is applied.
        leap_binder.add_custom_metric(user_function, name, direction)

        def _fail(message: str) -> None:
            # Explicit raise (not `assert`) so validation survives `python -O`.
            raise AssertionError(message)

        def _validate_single_arg(label: str, arg) -> None:
            # `label` is '#<index>' for positional args or the keyword name.
            if not isinstance(arg, np.ndarray):
                _fail(f'tensorleap_custom_metric validation failed: '
                      f'Argument {label} should be a numpy array. Got {type(arg)}.')
            batch_size = leap_binder.batch_size_to_validate
            if batch_size:
                # A 0-d array has no first dimension: report it as a validation
                # failure instead of letting `arg.shape[0]` raise IndexError.
                first_dim = arg.shape[0] if arg.ndim > 0 else 'a 0-d array'
                if first_dim != batch_size:
                    _fail(f'tensorleap_custom_metric validation failed: Argument {label} '
                          f'first dim should be as the batch size. Got {first_dim} '
                          f'instead of {batch_size}')

        def _validate_custom_metric_input_args(*args, **kwargs):
            for i, arg in enumerate(args):
                _validate_single_arg(f'#{i}', arg)
            for arg_name, arg in kwargs.items():
                _validate_single_arg(arg_name, arg)

        def _validate_custom_metric_result(result):
            if not isinstance(result, np.ndarray):
                _fail(f'tensorleap_custom_metric validation failed: '
                      f'The return type should be a numpy array. Got {type(result)}.')
            if result.ndim != 1:
                _fail(f'tensorleap_custom_metric validation failed: '
                      f'The return shape should be 1D. Got {result.ndim}D.')
            batch_size = leap_binder.batch_size_to_validate
            if batch_size and result.shape[0] != batch_size:
                _fail('tensorleap_custom_metric validation failed: The return len should be as the batch size.')

        # The wrapper intentionally keeps its own qualified name (no functools.wraps):
        # LeapBinder.init finds decorated metrics by matching 'tensorleap_custom_metric'
        # in the function's repr / qualified name.
        def inner(*args, **kwargs):
            _validate_custom_metric_input_args(*args, **kwargs)
            result = user_function(*args, **kwargs)
            _validate_custom_metric_result(result)
            return result

        return inner

    return decorating_function
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.49.dev10
3
+ Version: 1.0.49.dev100
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -18,12 +18,13 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
18
18
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
19
19
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
20
20
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
21
- code_loader/inner_leap_binder/leapbinder.py,sha256=4DaLjwwa0wR9qR6K5hKZNakd1oludBRRZPJcCzKsi78,24912
21
+ code_loader/inner_leap_binder/leapbinder.py,sha256=QXHXEXV5jBCqggDBD7hDHVcgveNb1jeL382iPTa9K-o,25425
22
+ code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=pZjIVP-zqdOPk785r6G4ycTTvlNiRB-UqQz9_gcPPKY,3133
22
23
  code_loader/leaploader.py,sha256=POUgD6x1GH_iF_eDGz-VLX4DsIl2kddufKVDdrA_K-U,19491
23
24
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
24
25
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
26
  code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
26
- code_loader-1.0.49.dev10.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
27
- code_loader-1.0.49.dev10.dist-info/METADATA,sha256=8USERR1xGkl_wSILrlnF-6gZSr6QunOIkUyO4ulQUBI,894
28
- code_loader-1.0.49.dev10.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
29
- code_loader-1.0.49.dev10.dist-info/RECORD,,
27
+ code_loader-1.0.49.dev100.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
28
+ code_loader-1.0.49.dev100.dist-info/METADATA,sha256=zgwAvViqFWeiGeHRUr64Ry-umEsndhv8pma3jKkCaoc,895
29
+ code_loader-1.0.49.dev100.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
30
+ code_loader-1.0.49.dev100.dist-info/RECORD,,