code-loader 1.0.61.dev4__py3-none-any.whl → 1.0.61.dev6__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -117,12 +117,17 @@ MetricCallableReturnType = Union[Any, List[List[ConfusionMatrixElement]]]
117
117
 
118
118
 
119
119
  @dataclass
120
- class CustomLossHandler:
120
+ class CustomLossHandlerData:
121
121
  name: str
122
- function: CustomCallableInterface
123
122
  arg_names: List[str]
124
123
 
125
124
 
125
+ @dataclass
126
+ class CustomLossHandler:
127
+ custom_loss_handler_data: CustomLossHandlerData
128
+ function: CustomCallableInterface
129
+
130
+
126
131
  @dataclass
127
132
  class MetricHandlerData:
128
133
  name: str
@@ -10,7 +10,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
12
12
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
- RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData
13
+ RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData
14
14
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
15
15
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
16
16
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
@@ -226,7 +226,7 @@ class LeapBinder:
226
226
  leap_binder.add_custom_loss(custom_loss_function, name='custom_loss')
227
227
  """
228
228
  arg_names = inspect.getfullargspec(function)[0]
229
- self.setup_container.custom_loss_handlers.append(CustomLossHandler(name, function, arg_names))
229
+ self.setup_container.custom_loss_handlers.append(CustomLossHandler(CustomLossHandlerData(name, arg_names), function))
230
230
 
231
231
  def add_custom_metric(self,
232
232
  function: Union[CustomCallableInterfaceMultiArgs,
code_loader/leaploader.py CHANGED
@@ -11,9 +11,9 @@ import numpy as np
11
11
  import numpy.typing as npt
12
12
 
13
13
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
14
- PreprocessResponse, VisualizerHandler, LeapData, CustomLossHandler, \
14
+ PreprocessResponse, VisualizerHandler, LeapData, \
15
15
  PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, VisualizerHandlerData, MetricHandlerData, \
16
- MetricCallableReturnType
16
+ MetricCallableReturnType, CustomLossHandlerData, CustomLossHandler
17
17
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
18
18
  from code_loader.contract.exceptions import DatasetScriptException
19
19
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
@@ -102,7 +102,16 @@ class LeapLoader(LeapLoaderBase):
102
102
  }
103
103
 
104
104
  @lru_cache()
105
- def custom_loss_by_name(self) -> Dict[str, CustomLossHandler]:
105
+ def custom_loss_by_name(self) -> Dict[str, CustomLossHandlerData]:
106
+ self.exec_script()
107
+ setup = global_leap_binder.setup_container
108
+ return {
109
+ custom_loss_handler.name: custom_loss_handler.custom_loss_handler_data
110
+ for custom_loss_handler in setup.custom_loss_handlers
111
+ }
112
+
113
+ @lru_cache()
114
+ def _custom_loss_handler_by_name(self) -> Dict[str, CustomLossHandler]:
106
115
  self.exec_script()
107
116
  setup = global_leap_binder.setup_container
108
117
  return {
@@ -223,6 +232,10 @@ class LeapLoader(LeapLoaderBase):
223
232
  self._preprocess_result()
224
233
  return self._metric_handler_by_name()[metric_name].function(**input_tensors_by_arg_name)
225
234
 
235
+ def run_custom_loss(self, custom_loss_name: str,
236
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
237
+ return self._custom_loss_handler_by_name()[custom_loss_name].function(**input_tensors_by_arg_name)
238
+
226
239
  def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
227
240
  # running preprocessing to sync preprocessing in main thread (can be valuable when preprocess is filling a
228
241
  # global param that visualizer is using)
@@ -7,8 +7,9 @@ from typing import Dict, List, Union, Type
7
7
  import numpy as np
8
8
  import numpy.typing as npt
9
9
 
10
- from code_loader.contract.datasetclasses import DatasetSample, LeapData, CustomLossHandler, \
11
- PredictionTypeHandler, CustomLayerHandler, VisualizerHandlerData, MetricHandlerData, MetricCallableReturnType
10
+ from code_loader.contract.datasetclasses import DatasetSample, LeapData, \
11
+ PredictionTypeHandler, CustomLayerHandler, VisualizerHandlerData, MetricHandlerData, MetricCallableReturnType, \
12
+ CustomLossHandlerData
12
13
  from code_loader.contract.enums import DataStateEnum
13
14
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
14
15
  DatasetSetup, ModelSetup
@@ -28,7 +29,7 @@ class LeapLoaderBase:
28
29
  pass
29
30
 
30
31
  @abstractmethod
31
- def custom_loss_by_name(self) -> Dict[str, CustomLossHandler]:
32
+ def custom_loss_by_name(self) -> Dict[str, CustomLossHandlerData]:
32
33
  pass
33
34
 
34
35
  @abstractmethod
@@ -56,6 +57,11 @@ class LeapLoaderBase:
56
57
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
57
58
  pass
58
59
 
60
+ @abstractmethod
61
+ def run_custom_loss(self, custom_loss_name: str,
62
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
63
+ pass
64
+
59
65
  @abstractmethod
60
66
  def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
61
67
  ) -> npt.NDArray[np.float32]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.61.dev4
3
+ Version: 1.0.61.dev6
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -2,7 +2,7 @@ LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
3
  code_loader/code_inegration_processes_manager.py,sha256=XslWOPeNQk4RAFJ_f3tP5Oe3EgcIR7BE7Y8r9Ty73-o,3261
4
4
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- code_loader/contract/datasetclasses.py,sha256=LVsZIPpxgjz9ycYU21nNe2-7ug-A8rpi-X236m-2HwM,6844
5
+ code_loader/contract/datasetclasses.py,sha256=pLlMMmeD6Q6MMbWfxA3Ihn-tAgJaYw2DBcIcxi_HXLQ,6938
6
6
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
7
7
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
8
8
  code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
@@ -18,14 +18,14 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
18
18
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
19
19
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
20
20
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
21
- code_loader/inner_leap_binder/leapbinder.py,sha256=bb-z_QS3b1eQc2Be0lp5CDm8eRCP9NN1f7Hu2pQe_4E,25180
21
+ code_loader/inner_leap_binder/leapbinder.py,sha256=u7WF_nfXTe2c5pvcgVbBzJEf61zGzTGlhPNBa0usaVE,25226
22
22
  code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=uuM_ht9HZ1GH2IabKeGQ_x9NmD3poK_h1Gt0NruwJuY,19704
23
- code_loader/leaploader.py,sha256=uk2E1_0psS89W68y_pTId3ta7d1s8gmJXDbAzt1sqzw,20708
24
- code_loader/leaploaderbase.py,sha256=Ursxo27dhiWUc5kG0WcQwJPkcwBFaq6U3sn8DHDLOdg,2747
23
+ code_loader/leaploader.py,sha256=obFJFRWM4PzKlWcW7JrB41NmFJkyJuqiJOD8tcp4e7Y,21332
24
+ code_loader/leaploaderbase.py,sha256=waxwf5xQP8FhJjP7v7kRxDIRzyGu_26Z-glHWXyDt4g,2936
25
25
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
26
26
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
28
- code_loader-1.0.61.dev4.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
29
- code_loader-1.0.61.dev4.dist-info/METADATA,sha256=aolebyNYL9I7vdmTng-2XkEBSCijaCV1p9nQ0aQLEU4,893
30
- code_loader-1.0.61.dev4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
31
- code_loader-1.0.61.dev4.dist-info/RECORD,,
28
+ code_loader-1.0.61.dev6.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
29
+ code_loader-1.0.61.dev6.dist-info/METADATA,sha256=3pJkO1HzFP0a-NYfy8L3Fu8NsgRP6RFa18Vi8eBhqvM,893
30
+ code_loader-1.0.61.dev6.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
31
+ code_loader-1.0.61.dev6.dist-info/RECORD,,