code-loader 1.0.40a0__py3-none-any.whl → 1.0.40a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,7 @@ from typing import Any, Callable, List, Optional, Dict, Union, Type
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
6
 
7
- from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
7
+ from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection, InstanceAnalysisType
8
8
  from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
9
9
  LeapTextMask, LeapImageMask, LeapImageWithBBox
10
10
 
@@ -121,7 +121,14 @@ class DatasetBaseHandler:
121
121
  @dataclass
122
122
  class InputHandler(DatasetBaseHandler):
123
123
  shape: Optional[List[int]] = None
124
- instance_function: Optional[InstanceCallableInterface] = None
124
+
125
+
126
+ @dataclass
127
+ class ElementInstanceHandler:
128
+ input_name: str
129
+ instance_function: InstanceCallableInterface
130
+ analysis_type: InstanceAnalysisType
131
+
125
132
 
126
133
 
127
134
  @dataclass
@@ -157,6 +164,7 @@ class DatasetIntegrationSetup:
157
164
  unlabeled_data_preprocess: Optional[UnlabeledDataPreprocessHandler] = None
158
165
  visualizers: List[VisualizerHandler] = field(default_factory=list)
159
166
  inputs: List[InputHandler] = field(default_factory=list)
167
+ element_instances: List[ElementInstanceHandler] = field(default_factory=list)
160
168
  ground_truths: List[GroundTruthHandler] = field(default_factory=list)
161
169
  metadata: List[MetadataHandler] = field(default_factory=list)
162
170
  prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
@@ -63,3 +63,9 @@ class ConfusionMatrixValue(Enum):
63
63
  class TestingSectionEnum(Enum):
64
64
  Warnings = "Warnings"
65
65
  Errors = "Errors"
66
+
67
+
68
+
69
+ class InstanceAnalysisType(Enum):
70
+ MaskInput = "MaskInput"
71
+ MaskLatentSpace = "MaskLatentSpace"
@@ -10,8 +10,8 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, VisualizerCallableReturnType, \
12
12
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
- RawInputsForHeatmap, InstanceCallableInterface
14
- from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
13
+ RawInputsForHeatmap, InstanceCallableInterface, ElementInstanceHandler
14
+ from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, InstanceAnalysisType
15
15
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
16
16
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
17
17
  from code_loader.utils import to_numpy_return_wrapper, get_shape
@@ -96,13 +96,16 @@ class LeapBinder:
96
96
  def set_unlabeled_data_preprocess(self, function: Callable[[], PreprocessResponse]) -> None:
97
97
  self.setup_container.unlabeled_data_preprocess = UnlabeledDataPreprocessHandler(function)
98
98
 
99
- def set_input(self, function: SectionCallableInterface, name: str,
100
- instance_function: Optional[InstanceCallableInterface] = None) -> None:
99
+ def set_input(self, function: SectionCallableInterface, name: str) -> None:
101
100
  function = to_numpy_return_wrapper(function)
102
- self.setup_container.inputs.append(InputHandler(name, function, instance_function=instance_function))
101
+ self.setup_container.inputs.append(InputHandler(name, function))
103
102
 
104
103
  self._encoder_names.append(name)
105
104
 
105
+ def set_instance_element(self, input_name: str, instance_function: Optional[InstanceCallableInterface] = None,
106
+ analysis_type: InstanceAnalysisType = InstanceAnalysisType.MaskInput) -> None:
107
+ self.setup_container.element_instances.append(ElementInstanceHandler(input_name, instance_function, analysis_type))
108
+
106
109
  def add_custom_loss(self, function: CustomCallableInterface, name: str) -> None:
107
110
  arg_names = inspect.getfullargspec(function)[0]
108
111
  self.setup_container.custom_loss_handlers.append(CustomLossHandler(name, function, arg_names))
code_loader/leaploader.py CHANGED
@@ -298,13 +298,15 @@ class LeapLoader:
298
298
  def _get_inputs(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
299
299
  return self._get_dataset_handlers(global_leap_binder.setup_container.inputs, state, idx)
300
300
 
301
- def get_instance_elements(self, state: DataStateEnum, idx: int, input_name: str) -> Optional[List[ElementInstance]]:
301
+ def get_instance_elements(self, state: DataStateEnum, idx: int, input_name: str) \
302
+ -> Tuple[Optional[List[ElementInstance]], Optional[InstanceAnalysisType]]:
302
303
  preprocess_result = self._preprocess_result()
303
304
  preprocess_state = preprocess_result[state]
304
- for input in global_leap_binder.setup_container.inputs:
305
- if input.name == input_name:
306
- if input.instance_function is not None:
307
- return input.instance_function(idx, preprocess_state)
305
+ for element_instance in global_leap_binder.setup_container.element_instances:
306
+ if element_instance.input_name == input_name:
307
+ return element_instance.instance_function(idx, preprocess_state), element_instance.analysis_type
308
+
309
+ return None, None
308
310
 
309
311
 
310
312
  def _get_gt(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.40a0
3
+ Version: 1.0.40a1
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -1,18 +1,18 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=V3DEXSN6Ie6PlGeSAbzjp9ufRj0XPJLpD7pDLLYxk6M,122
3
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=eVPrNAoU3ViaxeZQof-qS2bWsgtuJXTqfCF9psS1f7k,4950
5
- code_loader/contract/enums.py,sha256=wJSMFXL_E-JWK_XAgsqOZwOLDiRfcT_rp0oQ3P_edNI,1560
4
+ code_loader/contract/datasetclasses.py,sha256=45g-iy_HEd1Q7LAMaAjNbYqJxgEwvTlNNC_V6hN5koU,5141
5
+ code_loader/contract/enums.py,sha256=3jBkDBa7od9SrSZdozrvy4InvyuRk4Ov1yEMsQxc-Pg,1665
6
6
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
7
  code_loader/contract/responsedataclasses.py,sha256=WSHmFZWOFhGL1eED1u-aoRotPQg2owFQ-t3xSViWXSI,2808
8
8
  code_loader/contract/visualizer_classes.py,sha256=1FjVO744J_EMuJfHWXGdvSz6vl3Vu7iS3CDfs8MzEEQ,5138
9
9
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
10
- code_loader/inner_leap_binder/leapbinder.py,sha256=n9oiuC-bFLQsG-ZRSO6VeOWg2CB6qHZciAjRlIFKmi8,13778
11
- code_loader/leaploader.py,sha256=hsU1SdUTZIU_sRMGs9LjxMuvIECKsKWq_wL9W5sKMX4,17780
10
+ code_loader/inner_leap_binder/leapbinder.py,sha256=wf3yxyxhN8k_eFecmt7nTUqbD3uMEoGccZcrOilOftc,14054
11
+ code_loader/leaploader.py,sha256=D1CzclYll0kg_2sTZ_tS1-BFWlMCmiJZHrQCMwRiNtg,17882
12
12
  code_loader/utils.py,sha256=61I4PgSl-ZBIe4DifLxMNlBELE-HQR2pB9efVYPceIU,2230
13
13
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  code_loader/visualizers/default_visualizers.py,sha256=HqWx2qfTrroGl2n8Fpmr_4X-rk7tE2oGapjO3gzz4WY,2226
15
- code_loader-1.0.40a0.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
16
- code_loader-1.0.40a0.dist-info/METADATA,sha256=ZZDeEYKHHQLlyqMKFzGXyKpWqui6gg9HWXcWrU6FdSs,770
17
- code_loader-1.0.40a0.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
18
- code_loader-1.0.40a0.dist-info/RECORD,,
15
+ code_loader-1.0.40a1.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
16
+ code_loader-1.0.40a1.dist-info/METADATA,sha256=o9NkQbpAF3lZIRPc147tORbE0NHH1gBJ7igKxKvDCIg,770
17
+ code_loader-1.0.40a1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
18
+ code_loader-1.0.40a1.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.8.1
2
+ Generator: poetry-core 1.9.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any