code-loader 1.0.61.dev2__py3-none-any.whl → 1.0.61.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -124,24 +124,34 @@ class CustomLossHandler:
124
124
 
125
125
 
126
126
  @dataclass
127
- class MetricHandler:
127
+ class MetricHandlerData:
128
128
  name: str
129
- function: Union[CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]
130
129
  arg_names: List[str]
131
130
  direction: Optional[MetricDirection] = MetricDirection.Downward
132
131
 
133
132
 
133
+ @dataclass
134
+ class MetricHandler:
135
+ metric_handler_data: MetricHandlerData
136
+ function: Union[CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]
137
+
138
+
134
139
  @dataclass
135
140
  class RawInputsForHeatmap:
136
141
  raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
137
142
 
138
143
 
139
144
  @dataclass
140
- class VisualizerHandler:
145
+ class VisualizerHandlerData:
141
146
  name: str
142
- function: VisualizerCallableInterface
143
147
  type: LeapDataType
144
148
  arg_names: List[str]
149
+
150
+
151
+ @dataclass
152
+ class VisualizerHandler:
153
+ visualizer_handler_data: VisualizerHandlerData
154
+ function: VisualizerCallableInterface
145
155
  heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None
146
156
 
147
157
 
@@ -9,7 +9,8 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
9
9
  PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
12
- CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
12
+ CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
+ RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData
13
14
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
14
15
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
15
16
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
@@ -127,7 +128,7 @@ class LeapBinder:
127
128
  f'should be {expected_return_type}')
128
129
 
129
130
  self.setup_container.visualizers.append(
130
- VisualizerHandler(name, function, visualizer_type, arg_names, heatmap_visualizer))
131
+ VisualizerHandler(VisualizerHandlerData(name, visualizer_type, arg_names), function, heatmap_visualizer))
131
132
  self._visualizer_names.append(name)
132
133
 
133
134
  def set_preprocess(self, function: Callable[[], List[PreprocessResponse]]) -> None:
@@ -251,7 +252,7 @@ class LeapBinder:
251
252
  leap_binder.add_custom_metric(custom_metric_function, name='custom_metric', direction=MetricDirection.Downward)
252
253
  """
253
254
  arg_names = inspect.getfullargspec(function)[0]
254
- self.setup_container.metrics.append(MetricHandler(name, function, arg_names, direction))
255
+ self.setup_container.metrics.append(MetricHandler(MetricHandlerData(name, arg_names, direction), function))
255
256
 
256
257
  def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
257
258
  """
code_loader/leaploader.py CHANGED
@@ -12,7 +12,8 @@ import numpy.typing as npt
12
12
 
13
13
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
14
14
  PreprocessResponse, VisualizerHandler, LeapData, CustomLossHandler, \
15
- PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
15
+ PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, VisualizerHandlerData, MetricHandlerData, \
16
+ MetricCallableReturnType
16
17
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
17
18
  from code_loader.contract.exceptions import DatasetScriptException
18
19
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
@@ -65,7 +66,16 @@ class LeapLoader(LeapLoaderBase):
65
66
  spec.loader.exec_module(file)
66
67
 
67
68
  @lru_cache()
68
- def metric_by_name(self) -> Dict[str, MetricHandler]:
69
+ def metric_by_name(self) -> Dict[str, MetricHandlerData]:
70
+ self.exec_script()
71
+ setup = global_leap_binder.setup_container
72
+ return {
73
+ metric_handler.name: metric_handler.metric_handler_data
74
+ for metric_handler in setup.metrics
75
+ }
76
+
77
+ @lru_cache()
78
+ def _metric_handler_by_name(self) -> Dict[str, MetricHandler]:
69
79
  self.exec_script()
70
80
  setup = global_leap_binder.setup_container
71
81
  return {
@@ -74,7 +84,16 @@ class LeapLoader(LeapLoaderBase):
74
84
  }
75
85
 
76
86
  @lru_cache()
77
- def visualizer_by_name(self) -> Dict[str, VisualizerHandler]:
87
+ def visualizer_by_name(self) -> Dict[str, VisualizerHandlerData]:
88
+ self.exec_script()
89
+ setup = global_leap_binder.setup_container
90
+ return {
91
+ visualizer_handler.name: visualizer_handler.visualizer_handler_data
92
+ for visualizer_handler in setup.visualizers
93
+ }
94
+
95
+ @lru_cache()
96
+ def _visualizer_handler_by_name(self) -> Dict[str, VisualizerHandler]:
78
97
  self.exec_script()
79
98
  setup = global_leap_binder.setup_container
80
99
  return {
@@ -199,17 +218,21 @@ class LeapLoader(LeapLoaderBase):
199
218
  all_dataset_base_handlers.extend(global_leap_binder.setup_container.metadata)
200
219
  return all_dataset_base_handlers
201
220
 
202
- def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
203
- ) -> LeapData:
221
+ def run_metric(self, metric_name: str,
222
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
223
+ self._preprocess_result()
224
+ return self._metric_handler_by_name()[metric_name].function(**input_tensors_by_arg_name)
225
+
226
+ def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
204
227
  # running preprocessing to sync preprocessing in main thread (can be valuable when preprocess is filling a
205
228
  # global param that visualizer is using)
206
229
  self._preprocess_result()
207
230
 
208
- return self.visualizer_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
231
+ return self._visualizer_handler_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
209
232
 
210
233
  def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
211
234
  ) -> npt.NDArray[np.float32]:
212
- heatmap_function = self.visualizer_by_name()[visualizer_name].heatmap_function
235
+ heatmap_function = self._visualizer_handler_by_name()[visualizer_name].heatmap_function
213
236
  if heatmap_function is None:
214
237
  assert len(input_tensors_by_arg_name) == 1
215
238
  return list(input_tensors_by_arg_name.values())[0]
@@ -7,8 +7,8 @@ from typing import Dict, List, Union, Type
7
7
  import numpy as np
8
8
  import numpy.typing as npt
9
9
 
10
- from code_loader.contract.datasetclasses import DatasetSample, VisualizerHandler, LeapData, CustomLossHandler, \
11
- PredictionTypeHandler, CustomLayerHandler, MetricHandler
10
+ from code_loader.contract.datasetclasses import DatasetSample, LeapData, CustomLossHandler, \
11
+ PredictionTypeHandler, CustomLayerHandler, VisualizerHandlerData, MetricHandlerData, MetricCallableReturnType
12
12
  from code_loader.contract.enums import DataStateEnum
13
13
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
14
14
  DatasetSetup, ModelSetup
@@ -20,11 +20,11 @@ class LeapLoaderBase:
20
20
  self.code_path = code_path
21
21
 
22
22
  @abstractmethod
23
- def metric_by_name(self) -> Dict[str, MetricHandler]:
23
+ def metric_by_name(self) -> Dict[str, MetricHandlerData]:
24
24
  pass
25
25
 
26
26
  @abstractmethod
27
- def visualizer_by_name(self) -> Dict[str, VisualizerHandler]:
27
+ def visualizer_by_name(self) -> Dict[str, VisualizerHandlerData]:
28
28
  pass
29
29
 
30
30
  @abstractmethod
@@ -48,8 +48,12 @@ class LeapLoaderBase:
48
48
  pass
49
49
 
50
50
  @abstractmethod
51
- def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
52
- ) -> LeapData:
51
+ def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
52
+ pass
53
+
54
+ @abstractmethod
55
+ def run_metric(self, metric_name: str,
56
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
53
57
  pass
54
58
 
55
59
  @abstractmethod
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.61.dev2
3
+ Version: 1.0.61.dev4
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -2,7 +2,7 @@ LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
3
  code_loader/code_inegration_processes_manager.py,sha256=XslWOPeNQk4RAFJ_f3tP5Oe3EgcIR7BE7Y8r9Ty73-o,3261
4
4
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- code_loader/contract/datasetclasses.py,sha256=lFS7_weizsjzx4_tYwYGrrRUj1sgIl010h9FON4brb8,6670
5
+ code_loader/contract/datasetclasses.py,sha256=LVsZIPpxgjz9ycYU21nNe2-7ug-A8rpi-X236m-2HwM,6844
6
6
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
7
7
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
8
8
  code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
@@ -18,14 +18,14 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
18
18
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
19
19
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
20
20
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
21
- code_loader/inner_leap_binder/leapbinder.py,sha256=35hyesDdmjOD9wdrTLyayb-vm9aDfmEbMA0c4EQR1LA,25090
21
+ code_loader/inner_leap_binder/leapbinder.py,sha256=bb-z_QS3b1eQc2Be0lp5CDm8eRCP9NN1f7Hu2pQe_4E,25180
22
22
  code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=uuM_ht9HZ1GH2IabKeGQ_x9NmD3poK_h1Gt0NruwJuY,19704
23
- code_loader/leaploader.py,sha256=KC_6oso5pbOHZ56sUTcV6qdFzEbIJ8MdEtKu-nDrQfE,19707
24
- code_loader/leaploaderbase.py,sha256=ZPncue31Ld6NeaOZz4H0PJXxl5AYJfC01tX_H_ARVFc,2542
23
+ code_loader/leaploader.py,sha256=uk2E1_0psS89W68y_pTId3ta7d1s8gmJXDbAzt1sqzw,20708
24
+ code_loader/leaploaderbase.py,sha256=Ursxo27dhiWUc5kG0WcQwJPkcwBFaq6U3sn8DHDLOdg,2747
25
25
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
26
26
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
28
- code_loader-1.0.61.dev2.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
29
- code_loader-1.0.61.dev2.dist-info/METADATA,sha256=EzpmH_OVUl7AVQxhi2-Gth86YjYw9N0PdIKDl_kd1J4,893
30
- code_loader-1.0.61.dev2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
31
- code_loader-1.0.61.dev2.dist-info/RECORD,,
28
+ code_loader-1.0.61.dev4.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
29
+ code_loader-1.0.61.dev4.dist-info/METADATA,sha256=aolebyNYL9I7vdmTng-2XkEBSCijaCV1p9nQ0aQLEU4,893
30
+ code_loader-1.0.61.dev4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
31
+ code_loader-1.0.61.dev4.dist-info/RECORD,,