code-loader 1.0.61.dev2__py3-none-any.whl → 1.0.61.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -124,24 +124,34 @@ class CustomLossHandler:
124
124
 
125
125
 
126
126
  @dataclass
127
- class MetricHandler:
127
+ class MetricHandlerData:
128
128
  name: str
129
- function: Union[CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]
130
129
  arg_names: List[str]
131
130
  direction: Optional[MetricDirection] = MetricDirection.Downward
132
131
 
133
132
 
133
+ @dataclass
134
+ class MetricHandler:
135
+ metric_handler_data: MetricHandlerData
136
+ function: Union[CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]
137
+
138
+
134
139
  @dataclass
135
140
  class RawInputsForHeatmap:
136
141
  raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
137
142
 
138
143
 
139
144
  @dataclass
140
- class VisualizerHandler:
145
+ class VisualizerHandlerData:
141
146
  name: str
142
- function: VisualizerCallableInterface
143
147
  type: LeapDataType
144
148
  arg_names: List[str]
149
+
150
+
151
+ @dataclass
152
+ class VisualizerHandler:
153
+ visualizer_handler_data: VisualizerHandlerData
154
+ function: VisualizerCallableInterface
145
155
  heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None
146
156
 
147
157
 
@@ -9,7 +9,8 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
9
9
  PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
12
- CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
12
+ CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
+ RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData
13
14
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
14
15
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
15
16
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
@@ -127,7 +128,7 @@ class LeapBinder:
127
128
  f'should be {expected_return_type}')
128
129
 
129
130
  self.setup_container.visualizers.append(
130
- VisualizerHandler(name, function, visualizer_type, arg_names, heatmap_visualizer))
131
+ VisualizerHandler(VisualizerHandlerData(name, visualizer_type, arg_names), function, heatmap_visualizer))
131
132
  self._visualizer_names.append(name)
132
133
 
133
134
  def set_preprocess(self, function: Callable[[], List[PreprocessResponse]]) -> None:
@@ -251,7 +252,7 @@ class LeapBinder:
251
252
  leap_binder.add_custom_metric(custom_metric_function, name='custom_metric', direction=MetricDirection.Downward)
252
253
  """
253
254
  arg_names = inspect.getfullargspec(function)[0]
254
- self.setup_container.metrics.append(MetricHandler(name, function, arg_names, direction))
255
+ self.setup_container.metrics.append(MetricHandler(MetricHandlerData(name, arg_names, direction), function))
255
256
 
256
257
  def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
257
258
  """
code_loader/leaploader.py CHANGED
@@ -12,7 +12,7 @@ import numpy.typing as npt
12
12
 
13
13
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
14
14
  PreprocessResponse, VisualizerHandler, LeapData, CustomLossHandler, \
15
- PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
15
+ PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, VisualizerHandlerData, MetricHandlerData
16
16
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
17
17
  from code_loader.contract.exceptions import DatasetScriptException
18
18
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
@@ -65,7 +65,16 @@ class LeapLoader(LeapLoaderBase):
65
65
  spec.loader.exec_module(file)
66
66
 
67
67
  @lru_cache()
68
- def metric_by_name(self) -> Dict[str, MetricHandler]:
68
+ def metric_by_name(self) -> Dict[str, MetricHandlerData]:
69
+ self.exec_script()
70
+ setup = global_leap_binder.setup_container
71
+ return {
72
+ metric_handler.name: metric_handler.metric_handler_data
73
+ for metric_handler in setup.metrics
74
+ }
75
+
76
+ @lru_cache()
77
+ def _metric_handler_by_name(self) -> Dict[str, MetricHandler]:
69
78
  self.exec_script()
70
79
  setup = global_leap_binder.setup_container
71
80
  return {
@@ -74,7 +83,16 @@ class LeapLoader(LeapLoaderBase):
74
83
  }
75
84
 
76
85
  @lru_cache()
77
- def visualizer_by_name(self) -> Dict[str, VisualizerHandler]:
86
+ def visualizer_by_name(self) -> Dict[str, VisualizerHandlerData]:
87
+ self.exec_script()
88
+ setup = global_leap_binder.setup_container
89
+ return {
90
+ visualizer_handler.name: visualizer_handler.visualizer_handler_data
91
+ for visualizer_handler in setup.visualizers
92
+ }
93
+
94
+ @lru_cache()
95
+ def _visualizer_handler_by_name(self) -> Dict[str, VisualizerHandler]:
78
96
  self.exec_script()
79
97
  setup = global_leap_binder.setup_container
80
98
  return {
@@ -205,11 +223,11 @@ class LeapLoader(LeapLoaderBase):
205
223
  # global param that visualizer is using)
206
224
  self._preprocess_result()
207
225
 
208
- return self.visualizer_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
226
+ return self._visualizer_handler_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
209
227
 
210
228
  def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
211
229
  ) -> npt.NDArray[np.float32]:
212
- heatmap_function = self.visualizer_by_name()[visualizer_name].heatmap_function
230
+ heatmap_function = self._visualizer_handler_by_name()[visualizer_name].heatmap_function
213
231
  if heatmap_function is None:
214
232
  assert len(input_tensors_by_arg_name) == 1
215
233
  return list(input_tensors_by_arg_name.values())[0]
@@ -7,8 +7,8 @@ from typing import Dict, List, Union, Type
7
7
  import numpy as np
8
8
  import numpy.typing as npt
9
9
 
10
- from code_loader.contract.datasetclasses import DatasetSample, VisualizerHandler, LeapData, CustomLossHandler, \
11
- PredictionTypeHandler, CustomLayerHandler, MetricHandler
10
+ from code_loader.contract.datasetclasses import DatasetSample, LeapData, CustomLossHandler, \
11
+ PredictionTypeHandler, CustomLayerHandler, VisualizerHandlerData, MetricHandlerData
12
12
  from code_loader.contract.enums import DataStateEnum
13
13
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
14
14
  DatasetSetup, ModelSetup
@@ -20,11 +20,11 @@ class LeapLoaderBase:
20
20
  self.code_path = code_path
21
21
 
22
22
  @abstractmethod
23
- def metric_by_name(self) -> Dict[str, MetricHandler]:
23
+ def metric_by_name(self) -> Dict[str, MetricHandlerData]:
24
24
  pass
25
25
 
26
26
  @abstractmethod
27
- def visualizer_by_name(self) -> Dict[str, VisualizerHandler]:
27
+ def visualizer_by_name(self) -> Dict[str, VisualizerHandlerData]:
28
28
  pass
29
29
 
30
30
  @abstractmethod
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.61.dev2
3
+ Version: 1.0.61.dev3
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -2,7 +2,7 @@ LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
3
  code_loader/code_inegration_processes_manager.py,sha256=XslWOPeNQk4RAFJ_f3tP5Oe3EgcIR7BE7Y8r9Ty73-o,3261
4
4
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- code_loader/contract/datasetclasses.py,sha256=lFS7_weizsjzx4_tYwYGrrRUj1sgIl010h9FON4brb8,6670
5
+ code_loader/contract/datasetclasses.py,sha256=LVsZIPpxgjz9ycYU21nNe2-7ug-A8rpi-X236m-2HwM,6844
6
6
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
7
7
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
8
8
  code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
@@ -18,14 +18,14 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
18
18
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
19
19
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
20
20
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
21
- code_loader/inner_leap_binder/leapbinder.py,sha256=35hyesDdmjOD9wdrTLyayb-vm9aDfmEbMA0c4EQR1LA,25090
21
+ code_loader/inner_leap_binder/leapbinder.py,sha256=bb-z_QS3b1eQc2Be0lp5CDm8eRCP9NN1f7Hu2pQe_4E,25180
22
22
  code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=uuM_ht9HZ1GH2IabKeGQ_x9NmD3poK_h1Gt0NruwJuY,19704
23
- code_loader/leaploader.py,sha256=KC_6oso5pbOHZ56sUTcV6qdFzEbIJ8MdEtKu-nDrQfE,19707
24
- code_loader/leaploaderbase.py,sha256=ZPncue31Ld6NeaOZz4H0PJXxl5AYJfC01tX_H_ARVFc,2542
23
+ code_loader/leaploader.py,sha256=K5kVjY67aa-rfU5LwJQ0DZo70GlV394NmplGY4PAdG0,20415
24
+ code_loader/leaploaderbase.py,sha256=n_5zyHl4tlalJ7aDGCbfl9HKl_vL591RUJg1PgnQi9I,2558
25
25
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
26
26
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
28
- code_loader-1.0.61.dev2.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
29
- code_loader-1.0.61.dev2.dist-info/METADATA,sha256=EzpmH_OVUl7AVQxhi2-Gth86YjYw9N0PdIKDl_kd1J4,893
30
- code_loader-1.0.61.dev2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
31
- code_loader-1.0.61.dev2.dist-info/RECORD,,
28
+ code_loader-1.0.61.dev3.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
29
+ code_loader-1.0.61.dev3.dist-info/METADATA,sha256=DuC_vTtfnfP4WZKrlQ5ZGv-_rTiDpHzBAYGXWPeXBlc,893
30
+ code_loader-1.0.61.dev3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
31
+ code_loader-1.0.61.dev3.dist-info/RECORD,,