code-loader 1.0.40a1__py3-none-any.whl → 1.0.42__py3-none-any.whl

code_loader/contract/datasetclasses.py CHANGED
@@ -4,29 +4,40 @@ from typing import Any, Callable, List, Optional, Dict, Union, Type
  import numpy as np
  import numpy.typing as npt
 
- from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection, InstanceAnalysisType
+ from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
  from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
-     LeapTextMask, LeapImageMask, LeapImageWithBBox
+     LeapTextMask, LeapImageMask, LeapImageWithBBox, LeapImageWithHeatmap
 
  custom_latent_space_attribute = "custom_latent_space"
 
 
  @dataclass
  class PreprocessResponse:
+     """
+     An object that holds the preprocessed data for use within the Tensorleap platform.
+
+     This class is used to encapsulate the results of data preprocessing, including inputs, metadata, labels, and other relevant information.
+     It facilitates handling and integration of the processed data within Tensorleap.
+
+     Attributes:
+         length (int): The number of samples in the preprocessed data.
+         data (Any): The preprocessed data itself. This can be any data type, depending on the preprocessing logic.
+
+     Example:
+         # Example usage of PreprocessResponse
+         preprocessed_data = {
+             'images': ['path/to/image1.jpg', 'path/to/image2.jpg'],
+             'labels': ['SUV', 'truck'],
+             'metadata': [{'id': 1, 'source': 'camera1'}, {'id': 2, 'source': 'camera2'}]
+         }
+         response = PreprocessResponse(length=len(preprocessed_data['images']), data=preprocessed_data)
+     """
      length: int
      data: Any
 
 
- @dataclass
- class ElementInstance:
-     name: str
-     mask: npt.NDArray[np.float32]
-
-
  SectionCallableInterface = Callable[[int, PreprocessResponse], npt.NDArray[np.float32]]
 
- InstanceCallableInterface = Callable[[int, PreprocessResponse], List[ElementInstance]]
-
  MetadataSectionCallableInterface = Union[
      Callable[[int, PreprocessResponse], int],
      Callable[[int, PreprocessResponse], Dict[str, int]],
@@ -58,11 +69,12 @@ VisualizerCallableInterface = Union[
      Callable[..., LeapHorizontalBar],
      Callable[..., LeapImageMask],
      Callable[..., LeapTextMask],
-     Callable[..., LeapImageWithBBox]
+     Callable[..., LeapImageWithBBox],
+     Callable[..., LeapImageWithHeatmap]
  ]
 
  VisualizerCallableReturnType = Union[LeapImage, LeapText, LeapGraph, LeapHorizontalBar,
-                                      LeapImageMask, LeapTextMask, LeapImageWithBBox]
+                                      LeapImageMask, LeapTextMask, LeapImageWithBBox, LeapImageWithHeatmap]
 
  CustomCallableInterface = Callable[..., Any]
 
@@ -97,7 +109,6 @@ class MetricHandler:
      arg_names: List[str]
      direction: Optional[MetricDirection] = MetricDirection.Downward
 
-
  @dataclass
  class RawInputsForHeatmap:
      raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
@@ -123,14 +134,6 @@ class InputHandler(DatasetBaseHandler):
      shape: Optional[List[int]] = None
 
 
- @dataclass
- class ElementInstanceHandler:
-     input_name: str
-     instance_function: InstanceCallableInterface
-     analysis_type: InstanceAnalysisType
-
-
-
  @dataclass
  class GroundTruthHandler(DatasetBaseHandler):
      shape: Optional[List[int]] = None
@@ -164,7 +167,6 @@ class DatasetIntegrationSetup:
      unlabeled_data_preprocess: Optional[UnlabeledDataPreprocessHandler] = None
      visualizers: List[VisualizerHandler] = field(default_factory=list)
      inputs: List[InputHandler] = field(default_factory=list)
-     element_instances: List[ElementInstanceHandler] = field(default_factory=list)
      ground_truths: List[GroundTruthHandler] = field(default_factory=list)
      metadata: List[MetadataHandler] = field(default_factory=list)
      prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
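
To make the contract above concrete, here is a minimal, hypothetical sketch of an encoder matching SectionCallableInterface (Callable[[int, PreprocessResponse], npt.NDArray[np.float32]]). The 'images' key and the placeholder loading are assumptions for illustration, not part of the package.

# Hypothetical sketch of a SectionCallableInterface implementation.
import numpy as np
from code_loader.contract.datasetclasses import PreprocessResponse

def input_encoder(idx: int, preprocess: PreprocessResponse) -> np.ndarray:
    # 'images' is an assumed key in the user-defined dict held by PreprocessResponse.data.
    img_path = preprocess.data['images'][idx]
    # Placeholder for real image loading/normalization of img_path;
    # the interface requires a float32 array to be returned.
    img = np.zeros((224, 224, 3), dtype=np.float32)
    return img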
code_loader/contract/enums.py CHANGED
@@ -26,6 +26,7 @@ class LeapDataType(Enum):
      ImageMask = 'ImageMask'
      TextMask = 'TextMask'
      ImageWithBBox = 'ImageWithBBox'
+     ImageWithHeatmap = 'ImageWithHeatmap'
 
 
  class MetricDirection(Enum):
@@ -63,9 +64,3 @@ class ConfusionMatrixValue(Enum):
  class TestingSectionEnum(Enum):
      Warnings = "Warnings"
      Errors = "Errors"
-
-
-
- class InstanceAnalysisType(Enum):
-     MaskInput = "MaskInput"
-     MaskLatentSpace = "MaskLatentSpace"
code_loader/contract/responsedataclasses.py CHANGED
@@ -1,4 +1,4 @@
- from typing import List, Optional, Dict, Any
+ from typing import List, Optional, Dict, Any, Union
 
  from dataclasses import dataclass, field
  from code_loader.contract.enums import DatasetMetadataType, LeapDataType
@@ -95,7 +95,19 @@ class DatasetTestResultPayload:
 
  @dataclass
  class BoundingBox:
-     # (x, y) is the center of the bounding box
+     """
+     Represents a bounding box for an object in an image.
+
+     Attributes:
+         x (float): The x-coordinate of the center of the bounding box, a value in [0, 1] given as a fraction of the image width.
+         y (float): The y-coordinate of the center of the bounding box, a value in [0, 1] given as a fraction of the image height.
+         width (float): The width of the bounding box, a value in [0, 1] given as a fraction of the image width.
+         height (float): The height of the bounding box, a value in [0, 1] given as a fraction of the image height.
+         confidence (float): The confidence score of the bounding box. For predictions, this is a score typically between [0, 1]. For ground truth data, this can be 1.
+         label (str): The label or class name associated with the bounding box.
+         rotation (float): The rotation of the bounding box, a value between [0, 360] representing the degree of rotation. Default is 0.0.
+         metadata (Optional[Dict[str, Union[str, int, float]]]): Optional metadata associated with the bounding box.
+     """
      x: float  # value between [0, 1], represent the percentage according to the image size.
      y: float  # value between [0, 1], represent the percentage according to the image size.
 
@@ -104,6 +116,7 @@ class BoundingBox:
      confidence: float
      label: str
      rotation: float = 0.0  # value between [0, 360], represent the degree of rotation.
+     metadata: Optional[Dict[str, Union[str, int, float]]] = None
 
 
  @dataclass
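
A minimal sketch of the new optional metadata field added above; the key names and values are illustrative only.

from code_loader.contract.responsedataclasses import BoundingBox

# Coordinates and sizes are relative to the image dimensions, as documented above.
bbox = BoundingBox(
    x=0.5, y=0.4, width=0.2, height=0.3,
    confidence=0.92, label='truck',
    rotation=0.0,
    metadata={'track_id': 7, 'occlusion': 0.1, 'source': 'camera1'},  # new in this version
)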
code_loader/contract/visualizer_classes.py CHANGED
@@ -28,6 +28,13 @@ def validate_type(actual: Any, expected: Any, prefix_message: str = '') -> None:
 
  @dataclass
  class LeapImage:
+     """
+     Visualizer representing an image for Tensorleap.
+
+     Attributes:
+         data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data.
+         type (LeapDataType): The data type, default is LeapDataType.Image.
+     """
      data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
      type: LeapDataType = LeapDataType.Image
 
@@ -41,6 +48,14 @@ class LeapImage:
 
  @dataclass
  class LeapImageWithBBox:
+     """
+     Visualizer representing an image with bounding boxes for Tensorleap, used for object detection tasks.
+
+     Attributes:
+         data (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
+         bounding_boxes (List[BoundingBox]): List of Tensorleap BoundingBox objects, with coordinates relative to the image size.
+         type (LeapDataType): The data type, default is LeapDataType.ImageWithBBox.
+     """
      data: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
      bounding_boxes: List[BoundingBox]
      type: LeapDataType = LeapDataType.ImageWithBBox
@@ -55,6 +70,13 @@ class LeapImageWithBBox:
 
  @dataclass
  class LeapGraph:
+     """
+     Visualizer representing line chart data for Tensorleap.
+
+     Attributes:
+         data (npt.NDArray[np.float32]): The array data, shaped [M, N] where M is the number of data points and N is the number of variables.
+         type (LeapDataType): The data type, default is LeapDataType.Graph.
+     """
      data: npt.NDArray[np.float32]
      type: LeapDataType = LeapDataType.Graph
 
@@ -67,6 +89,14 @@ class LeapGraph:
 
  @dataclass
  class LeapText:
+     """
+     Visualizer representing text data for Tensorleap.
+
+     Attributes:
+         data (List[str]): The text data, consisting of a list of text tokens. If the model requires fixed-length inputs,
+             it is recommended to maintain the fixed length, using empty strings ('') instead of padding tokens ('PAD'), e.g., ['I', 'ate', 'a', 'banana', '', '', '', ...]
+         type (LeapDataType): The data type, default is LeapDataType.Text.
+     """
      data: List[str]
      type: LeapDataType = LeapDataType.Text
 
@@ -79,6 +109,16 @@ class LeapText:
 
  @dataclass
  class LeapHorizontalBar:
+     """
+     Visualizer representing horizontal bar data for Tensorleap.
+     For example, this can be used to visualize the model's prediction scores in a classification problem.
+
+     Attributes:
+         body (npt.NDArray[np.float32]): The data for the bar, shaped [C], where C is the number of data points.
+         labels (List[str]): Labels for the horizontal bar; e.g., when visualizing the model's classification output, labels are the class names.
+             The length of `body` should match the length of `labels`, C.
+         type (LeapDataType): The data type, default is LeapDataType.HorizontalBar.
+     """
      body: npt.NDArray[np.float32]
      labels: List[str]
      type: LeapDataType = LeapDataType.HorizontalBar
@@ -96,6 +136,16 @@ class LeapHorizontalBar:
 
  @dataclass
  class LeapImageMask:
+     """
+     Visualizer representing an image with a mask for Tensorleap.
+     This can be used for tasks such as segmentation and other applications where it is important to highlight specific regions within an image.
+
+     Attributes:
+         mask (npt.NDArray[np.uint8]): The mask data, shaped [H, W].
+         image (npt.NDArray[np.float32] | npt.NDArray[np.uint8]): The image data, shaped [H, W, 3] or [H, W, 1].
+         labels (List[str]): Labels associated with the mask regions; e.g., class names for segmented objects. The length of `labels` should match the number of unique values in `mask`.
+         type (LeapDataType): The data type, default is LeapDataType.ImageMask.
+     """
      mask: npt.NDArray[np.uint8]
      image: Union[npt.NDArray[np.float32], npt.NDArray[np.uint8]]
      labels: List[str]
@@ -117,6 +167,16 @@ class LeapImageMask:
 
  @dataclass
  class LeapTextMask:
+     """
+     Visualizer representing text data with a mask for Tensorleap.
+     This can be used for tasks such as named entity recognition (NER), sentiment analysis, and other applications where it is important to highlight specific tokens or parts of the text.
+
+     Attributes:
+         mask (npt.NDArray[np.uint8]): The mask data, shaped [L].
+         text (List[str]): The text data, consisting of a list of text tokens of length L.
+         labels (List[str]): Labels associated with the masked tokens; e.g., named entities or sentiment categories. The length of `labels` should match the number of unique values in `mask`.
+         type (LeapDataType): The data type, default is LeapDataType.TextMask.
+     """
      mask: npt.NDArray[np.uint8]
      text: List[str]
      labels: List[str]
@@ -135,6 +195,37 @@ class LeapTextMask:
              validate_type(type(label), str)
 
 
+ @dataclass
+ class LeapImageWithHeatmap:
+     """
+     Visualizer representing an image with heatmaps for Tensorleap.
+     This can be used for tasks such as highlighting important regions in an image, visualizing attention maps, and other applications where it is important to overlay heatmaps on images.
+
+     Attributes:
+         image (npt.NDArray[np.float32]): The image data, shaped [H, W, C], where C is the number of channels.
+         heatmaps (npt.NDArray[np.float32]): The heatmap data, shaped [N, H, W], where N is the number of heatmaps.
+         labels (List[str]): Labels associated with the heatmaps; e.g., feature names or attention regions. The length of `labels` should match the number of heatmaps, N.
+         type (LeapDataType): The data type, default is LeapDataType.ImageWithHeatmap.
+     """
+     image: npt.NDArray[np.float32]
+     heatmaps: npt.NDArray[np.float32]
+     labels: List[str]
+     type: LeapDataType = LeapDataType.ImageWithHeatmap
+
+     def __post_init__(self) -> None:
+         validate_type(self.type, LeapDataType.ImageWithHeatmap)
+         validate_type(type(self.heatmaps), np.ndarray)
+         validate_type(self.heatmaps.dtype, np.float32)
+         validate_type(type(self.image), np.ndarray)
+         validate_type(self.image.dtype, np.float32)
+         validate_type(type(self.labels), list)
+         for label in self.labels:
+             validate_type(type(label), str)
+         if self.heatmaps.shape[0] != len(self.labels):
+             raise LeapValidationError(
+                 'Number of heatmaps and labels must be equal')
+
+
  map_leap_data_type_to_visualizer_class = {
      LeapDataType.Image.value: LeapImage,
      LeapDataType.Graph.value: LeapGraph,
@@ -142,5 +233,6 @@ map_leap_data_type_to_visualizer_class = {
      LeapDataType.HorizontalBar.value: LeapHorizontalBar,
      LeapDataType.ImageMask.value: LeapImageMask,
      LeapDataType.TextMask.value: LeapTextMask,
-     LeapDataType.ImageWithBBox.value: LeapImageWithBBox
+     LeapDataType.ImageWithBBox.value: LeapImageWithBBox,
+     LeapDataType.ImageWithHeatmap.value: LeapImageWithHeatmap
  }
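
A minimal sketch of constructing the new LeapImageWithHeatmap visualizer; the shapes and label names are illustrative. Note that __post_init__ requires float32 arrays and as many labels as heatmaps.

import numpy as np
from code_loader.contract.visualizer_classes import LeapImageWithHeatmap

image = np.random.rand(64, 64, 3).astype(np.float32)      # [H, W, C]
heatmaps = np.random.rand(2, 64, 64).astype(np.float32)   # [N, H, W]

# len(labels) must equal heatmaps.shape[0], otherwise a LeapValidationError is raised.
vis = LeapImageWithHeatmap(image=image, heatmaps=heatmaps, labels=['attention', 'saliency'])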
code_loader/inner_leap_binder/leapbinder.py CHANGED
@@ -9,9 +9,8 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
      PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
      MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
      CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, VisualizerCallableReturnType, \
-     CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
-     RawInputsForHeatmap, InstanceCallableInterface, ElementInstanceHandler
- from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, InstanceAnalysisType
+     CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
+ from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
  from code_loader.utils import to_numpy_return_wrapper, get_shape
@@ -22,6 +21,15 @@ from code_loader.visualizers.default_visualizers import DefaultVisualizer, \
 
 
  class LeapBinder:
+     """
+     Interface to the Tensorleap platform. Provides methods to set up preprocessing,
+     visualization, custom loss functions, metrics, and other essential components for integrating the dataset and model
+     with Tensorleap.
+
+     Attributes:
+         setup_container (DatasetIntegrationSetup): Container to hold setup configurations.
+         cache_container (Dict[str, Any]): Cache container to store intermediate data.
+     """
      def __init__(self) -> None:
          self.setup_container = DatasetIntegrationSetup()
          self.cache_container: Dict[str, Any] = {"word_to_index": {}}
@@ -49,6 +57,36 @@ class LeapBinder:
                         name: str,
                         visualizer_type: LeapDataType,
                         heatmap_visualizer: Optional[Callable[..., npt.NDArray[np.float32]]] = None) -> None:
+         """
+         Set a visualizer for a specific data type.
+
+         Args:
+             function (VisualizerCallableInterface): The visualizer function to be used for visualizing the data.
+             name (str): The name of the visualizer.
+             visualizer_type (LeapDataType): The type of data the visualizer handles (e.g., LeapDataType.Image, LeapDataType.Graph, LeapDataType.Text).
+             heatmap_visualizer (Optional[Callable[..., npt.NDArray[np.float32]]]): An optional heatmap visualizer function.
+                 This is used when a heatmap must be reshaped to overlay correctly on the transformed data within the visualizer
+                 function; i.e., if the visualizer changes the shape or scale of the input data, the heatmap visualizer
+                 adjusts the heatmap accordingly so that it aligns properly with the visualized data.
+
+         Example:
+             def image_resize_visualizer(data: np.ndarray) -> LeapImage:
+                 # Resize the image to a fixed size
+                 resized_image = resize_image(data, (224, 224))
+                 return LeapImage(data=resized_image)
+
+             def image_resize_heatmap_visualizer(heatmap: RawInputsForHeatmap) -> np.ndarray:
+                 # Resize the heatmap to match the resized image
+                 resized_heatmap = resize_heatmap(heatmap, (224, 224))
+                 return resized_heatmap
+
+             leap_binder.set_visualizer(
+                 function=image_resize_visualizer,
+                 name='image_resize_visualizer',
+                 visualizer_type=LeapDataType.Image,
+                 heatmap_visualizer=image_resize_heatmap_visualizer
+             )
+         """
          arg_names = inspect.getfullargspec(function)[0]
          if heatmap_visualizer:
              visualizer_arg_names_set = set(arg_names)
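
A hedged sketch of registering a visualizer for the new LeapDataType.ImageWithHeatmap via set_visualizer; the function body, argument names, and the `from code_loader import leap_binder` import path are assumptions for illustration.

import numpy as np
from code_loader import leap_binder  # assumed import path for the global binder
from code_loader.contract.enums import LeapDataType
from code_loader.contract.visualizer_classes import LeapImageWithHeatmap

def heatmap_overlay_visualizer(image: np.ndarray, heatmap: np.ndarray) -> LeapImageWithHeatmap:
    # Stack the single heatmap into the [N, H, W] layout expected by LeapImageWithHeatmap.
    return LeapImageWithHeatmap(image=image.astype(np.float32),
                                heatmaps=heatmap[np.newaxis, ...].astype(np.float32),
                                labels=['activation'])

leap_binder.set_visualizer(function=heatmap_overlay_visualizer,
                           name='heatmap_overlay',
                           visualizer_type=LeapDataType.ImageWithHeatmap)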
@@ -91,22 +129,99 @@ class LeapBinder:
          self._visualizer_names.append(name)
 
      def set_preprocess(self, function: Callable[[], List[PreprocessResponse]]) -> None:
+         """
+         Set the preprocessing function for the dataset, i.e., the function that returns a list of PreprocessResponse objects for use within the Tensorleap platform.
+
+         Args:
+             function (Callable[[], List[PreprocessResponse]]): The preprocessing function.
+
+         Example:
+             def preprocess_func() -> List[PreprocessResponse]:
+                 # Preprocess the dataset
+                 train_data = {
+                     'subset': 'train',
+                     'images': ['path/to/train/image1.jpg', 'path/to/train/image2.jpg'],
+                     'labels': ['SUV', 'truck'],
+                     'metadata': [{'id': 1, 'source': 'camera1'}, {'id': 2, 'source': 'camera2'}]}
+
+                 val_data = {
+                     'subset': 'val',
+                     'images': ['path/to/val/image1.jpg', 'path/to/val/image2.jpg'],
+                     'labels': ['truck', 'truck'],
+                     'metadata': [{'id': 1, 'source': 'camera1'}, {'id': 2, 'source': 'camera2'}]}
+
+                 return [PreprocessResponse(length=len(train_data['images']), data=train_data),
+                         PreprocessResponse(length=len(val_data['images']), data=val_data)]
+
+             leap_binder.set_preprocess(preprocess_func)
+         """
          self.setup_container.preprocess = PreprocessHandler(function)
 
      def set_unlabeled_data_preprocess(self, function: Callable[[], PreprocessResponse]) -> None:
+         """
+         Set the preprocessing function for the unlabeled dataset. This function returns a PreprocessResponse object for use within the Tensorleap platform for sample data that does not contain labels.
+
+         Args:
+             function (Callable[[], PreprocessResponse]): The preprocessing function for unlabeled data.
+
+         Example:
+             def unlabeled_preprocess_func() -> PreprocessResponse:
+                 # Preprocess the dataset
+                 ul_data = {
+                     'subset': 'unlabeled',
+                     'images': ['path/to/unlabeled/image1.jpg', 'path/to/unlabeled/image2.jpg'],
+                     'metadata': [{'id': 1, 'source': 'camera1'}, {'id': 2, 'source': 'camera2'}]}
+
+                 return PreprocessResponse(length=len(ul_data['images']), data=ul_data)
+
+             leap_binder.set_unlabeled_data_preprocess(unlabeled_preprocess_func)
+         """
          self.setup_container.unlabeled_data_preprocess = UnlabeledDataPreprocessHandler(function)
 
      def set_input(self, function: SectionCallableInterface, name: str) -> None:
+         """
+         Set the input handler function.
+
+         Args:
+             function (SectionCallableInterface): The input handler function.
+             name (str): The name of the input section.
+
+         Example:
+             def input_encoder(subset: PreprocessResponse, index: int) -> np.ndarray:
+                 # Return the processed input data for the given index and given subset response
+                 img_path = subset.data["images"][index]
+                 img = read_img(img_path)
+                 img = normalize(img)
+                 return img
+
+             leap_binder.set_input(input_encoder, name='input_encoder')
+         """
          function = to_numpy_return_wrapper(function)
          self.setup_container.inputs.append(InputHandler(name, function))
 
          self._encoder_names.append(name)
 
-     def set_instance_element(self, input_name: str, instance_function: Optional[InstanceCallableInterface] = None,
-                              analysis_type: InstanceAnalysisType = InstanceAnalysisType.MaskInput) -> None:
-         self.setup_container.element_instances.append(ElementInstanceHandler(input_name, instance_function, analysis_type))
-
      def add_custom_loss(self, function: CustomCallableInterface, name: str) -> None:
+         """
+         Add a custom loss function to the setup.
+
+         Args:
+             function (CustomCallableInterface): The custom loss function.
+                 This function receives:
+                 - y_true: The true labels or values.
+                 - y_pred: The predicted labels or values.
+                 This function should return:
+                 - A numeric value representing the loss.
+             name (str): The name of the custom loss function.
+
+         Example:
+             def custom_loss_function(y_true, y_pred):
+                 # Calculate mean squared error as custom loss
+                 return np.mean(np.square(y_true - y_pred))
+
+             leap_binder.add_custom_loss(custom_loss_function, name='custom_loss')
+         """
          arg_names = inspect.getfullargspec(function)[0]
          self.setup_container.custom_loss_handlers.append(CustomLossHandler(name, function, arg_names))
 
@@ -116,29 +231,137 @@ class LeapBinder:
                                          ConfusionMatrixCallableInterfaceMultiArgs],
                        name: str,
                        direction: Optional[MetricDirection] = MetricDirection.Downward) -> None:
+         """
+         Add a custom metric to the setup.
+
+         Args:
+             function (Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]): The custom metric function.
+             name (str): The name of the custom metric.
+             direction (Optional[MetricDirection]): The direction of the metric, either MetricDirection.Upward or MetricDirection.Downward.
+                 - MetricDirection.Upward: Indicates that higher values of the metric are better and should be maximized.
+                 - MetricDirection.Downward: Indicates that lower values of the metric are better and should be minimized.
+
+         Example:
+             def custom_metric_function(y_true, y_pred):
+                 return np.mean(np.abs(y_true - y_pred))
+
+             leap_binder.add_custom_metric(custom_metric_function, name='custom_metric', direction=MetricDirection.Downward)
+         """
          arg_names = inspect.getfullargspec(function)[0]
          self.setup_container.metrics.append(MetricHandler(name, function, arg_names, direction))
 
      def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
+         """
+         Add prediction labels to the setup.
+
+         Args:
+             name (str): The name of the prediction.
+             labels (List[str]): The list of labels for the prediction.
+             channel_dim (int): The axis along which the prediction scores are located, default is -1.
+
+         The number of labels must match the size of the model output along channel_dim, i.e., len(labels) == output.shape[channel_dim].
+
+         Example:
+             leap_binder.add_prediction(name='class_labels', labels=['cat', 'dog'])
+         """
          self.setup_container.prediction_types.append(PredictionTypeHandler(name, labels, channel_dim))
 
      def set_ground_truth(self, function: SectionCallableInterface, name: str) -> None:
+         """
+         Set the ground truth handler function.
+
+         Args:
+             function: The ground truth handler function.
+                 This function receives two parameters:
+                 - subset: A `PreprocessResponse` object that contains the preprocessed data.
+                 - index: The index of the sample within the subset.
+                 This function should return:
+                 - A numpy array representing the ground truth for the given sample.
+
+             name (str): The name of the ground truth section.
+
+         Example:
+             def ground_truth_handler(subset, index):
+                 label = subset.data['labels'][index]
+                 # Assuming labels are integers starting from 0
+                 num_classes = 10  # Example number of classes
+                 one_hot_label = np.zeros(num_classes)
+                 one_hot_label[label] = 1
+                 return one_hot_label
+
+             leap_binder.set_ground_truth(ground_truth_handler, name='ground_truth')
+         """
+
          function = to_numpy_return_wrapper(function)
          self.setup_container.ground_truths.append(GroundTruthHandler(name, function))
 
          self._encoder_names.append(name)
 
      def set_metadata(self, function: MetadataSectionCallableInterface, name: str) -> None:
+         """
+         Set the metadata handler function. This function computes an external, per-sample variable (metadata), which is recommended for analysis within the Tensorleap platform.
+
+         Args:
+             function (MetadataSectionCallableInterface): The metadata handler function.
+                 This function receives:
+                 subset (PreprocessResponse): The subset of the data.
+                 index (int): The index of the sample within the subset.
+                 This function should return one of the following:
+                 int: A single integer value.
+                 Dict[str, int]: A dictionary with string keys and integer values.
+                 str: A single string value.
+                 Dict[str, str]: A dictionary with string keys and string values.
+                 bool: A single boolean value.
+                 Dict[str, bool]: A dictionary with string keys and boolean values.
+                 float: A single float value.
+                 Dict[str, float]: A dictionary with string keys and float values.
+
+             name (str): The name of the metadata section.
+
+         Example:
+             def metadata_handler_index(subset: PreprocessResponse, index: int) -> int:
+                 return subset.data['metadata'][index]['id']
+
+             def metadata_handler_image_mean(subset: PreprocessResponse, index: int) -> float:
+                 fpath = subset.data['images'][index]
+                 image = load_image(fpath)
+                 mean_value = np.mean(image)
+                 return mean_value
+
+             leap_binder.set_metadata(metadata_handler_index, name='metadata_index')
+             leap_binder.set_metadata(metadata_handler_image_mean, name='metadata_image_mean')
+         """
          self.setup_container.metadata.append(MetadataHandler(name, function))
 
      def set_custom_layer(self, custom_layer: Type[Any], name: str, inspect_layer: bool = False,
                           kernel_index: Optional[int] = None, use_custom_latent_space: bool = False) -> None:
+         """
+         Set a custom layer for the model.
+
+         Args:
+             custom_layer (Type[Any]): The custom layer class.
+             name (str): The name of the custom layer.
+             inspect_layer (bool): Whether to inspect the layer, default is False.
+             kernel_index (Optional[int]): The index of the kernel to inspect, if inspect_layer is True.
+             use_custom_latent_space (bool): Whether to use a custom latent space, default is False.
+
+         Example:
+             class CustomLayer:
+                 def __init__(self, units: int):
+                     self.units = units
+
+                 def call(self, inputs):
+                     return inputs * self.units
+
+             leap_binder.set_custom_layer(CustomLayer, name='custom_layer', inspect_layer=True, kernel_index=0)
+         """
          if inspect_layer and kernel_index is not None:
              custom_layer.kernel_index = kernel_index
 
          if use_custom_latent_space and not hasattr(custom_layer, custom_latent_space_attribute):
-             raise Exception(
-                 f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
+             raise Exception(f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
 
          init_args = inspect.getfullargspec(custom_layer.__init__)[0][1:]
          call_args = inspect.getfullargspec(custom_layer.call)[0][1:]
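
A hypothetical sketch of the use_custom_latent_space check above: when the flag is set, the layer class must expose an attribute named by custom_latent_space_attribute ("custom_latent_space"), otherwise the exception shown in the hunk is raised. The layer and its method body are illustrative only.

class MyCustomLayer:
    def __init__(self, units: int):
        self.units = units

    def call(self, inputs):
        return inputs * self.units

    def custom_latent_space(self, inputs):
        # Assumed hook: return the tensor Tensorleap should treat as this layer's latent space.
        return inputs

leap_binder.set_custom_layer(MyCustomLayer, name='my_custom_layer', use_custom_latent_space=True)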
code_loader/leaploader.py CHANGED
@@ -5,14 +5,14 @@ import sys
  from contextlib import redirect_stdout
  from functools import lru_cache
  from pathlib import Path
- from typing import Dict, List, Iterable, Union, Any, Optional
+ from typing import Dict, List, Iterable, Union, Any
 
  import numpy as np
  import numpy.typing as npt
 
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
      PreprocessResponse, VisualizerHandler, VisualizerCallableReturnType, CustomLossHandler, \
-     PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, ElementInstance
+     PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
  from code_loader.contract.exceptions import DatasetScriptException
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
@@ -298,17 +298,6 @@ class LeapLoader:
      def _get_inputs(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
          return self._get_dataset_handlers(global_leap_binder.setup_container.inputs, state, idx)
 
-     def get_instance_elements(self, state: DataStateEnum, idx: int, input_name: str) \
-             -> Tuple[Optional[List[ElementInstance]], Optional[InstanceAnalysisType]]:
-         preprocess_result = self._preprocess_result()
-         preprocess_state = preprocess_result[state]
-         for element_instance in global_leap_binder.setup_container.element_instances:
-             if element_instance.input_name == input_name:
-                 return element_instance.instance_function(idx, preprocess_state), element_instance.analysis_type
-
-         return None, None
-
-
      def _get_gt(self, state: DataStateEnum, idx: int) -> Dict[str, npt.NDArray[np.float32]]:
          return self._get_dataset_handlers(global_leap_binder.setup_container.ground_truths, state, idx)
 
code_loader/visualizers/default_visualizers.py CHANGED
@@ -1,4 +1,5 @@
  from enum import Enum
+ from typing import List
 
  import numpy as np
  import numpy.typing as npt
@@ -16,7 +17,7 @@ class DefaultVisualizer(Enum):
      ImageMask = 'ImageMask'
      TextMask = 'TextMask'
      RawData = 'RawData'
-
+ 
 
  def default_image_visualizer(data: npt.NDArray[np.float32]) -> LeapImage:
      rescaled_data = rescale_min_max(data)
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: code-loader
- Version: 1.0.40a1
+ Version: 1.0.42
  Summary:
  Home-page: https://github.com/tensorleap/code-loader
  License: MIT
@@ -0,0 +1,18 @@
+ LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
+ code_loader/__init__.py,sha256=V3DEXSN6Ie6PlGeSAbzjp9ufRj0XPJLpD7pDLLYxk6M,122
+ code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ code_loader/contract/datasetclasses.py,sha256=81TmCcVol7768lzKUp70MatLLipR3ftcR9jgE1r8Yqo,5698
+ code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
+ code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
+ code_loader/contract/responsedataclasses.py,sha256=w7xVOv2S8Hyb5lqyomMGiKAWXDTSOG-FX1YW39bXD3A,3969
+ code_loader/contract/visualizer_classes.py,sha256=t0EKxTGFJoFBnKuMmjtFpj-L32xEFF3NV-E5XutVXa0,10118
+ code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
+ code_loader/inner_leap_binder/leapbinder.py,sha256=4qmgQPOSh4aQKS8ETCaYIOSP3KXP1dMQuaOO1BorhFo,24020
+ code_loader/leaploader.py,sha256=pUySweZetJ6SsubCcZlDCJpvWmUrm5YlPlkZWQxY1hQ,17289
+ code_loader/utils.py,sha256=61I4PgSl-ZBIe4DifLxMNlBELE-HQR2pB9efVYPceIU,2230
+ code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ code_loader/visualizers/default_visualizers.py,sha256=F2qRT6rNy7SOmr4QfqVxIkLlYEa00CwDkLJuA45lfSI,2254
+ code_loader-1.0.42.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
+ code_loader-1.0.42.dist-info/METADATA,sha256=u87wLNMaA4amzOearJmVSmw3bloCXJMOJqBZCet8Nok,768
+ code_loader-1.0.42.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ code_loader-1.0.42.dist-info/RECORD,,
@@ -1,18 +0,0 @@
- LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
- code_loader/__init__.py,sha256=V3DEXSN6Ie6PlGeSAbzjp9ufRj0XPJLpD7pDLLYxk6M,122
- code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- code_loader/contract/datasetclasses.py,sha256=45g-iy_HEd1Q7LAMaAjNbYqJxgEwvTlNNC_V6hN5koU,5141
- code_loader/contract/enums.py,sha256=3jBkDBa7od9SrSZdozrvy4InvyuRk4Ov1yEMsQxc-Pg,1665
- code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
- code_loader/contract/responsedataclasses.py,sha256=WSHmFZWOFhGL1eED1u-aoRotPQg2owFQ-t3xSViWXSI,2808
- code_loader/contract/visualizer_classes.py,sha256=1FjVO744J_EMuJfHWXGdvSz6vl3Vu7iS3CDfs8MzEEQ,5138
- code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
- code_loader/inner_leap_binder/leapbinder.py,sha256=wf3yxyxhN8k_eFecmt7nTUqbD3uMEoGccZcrOilOftc,14054
- code_loader/leaploader.py,sha256=D1CzclYll0kg_2sTZ_tS1-BFWlMCmiJZHrQCMwRiNtg,17882
- code_loader/utils.py,sha256=61I4PgSl-ZBIe4DifLxMNlBELE-HQR2pB9efVYPceIU,2230
- code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- code_loader/visualizers/default_visualizers.py,sha256=HqWx2qfTrroGl2n8Fpmr_4X-rk7tE2oGapjO3gzz4WY,2226
- code_loader-1.0.40a1.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
- code_loader-1.0.40a1.dist-info/METADATA,sha256=o9NkQbpAF3lZIRPc147tORbE0NHH1gBJ7igKxKvDCIg,770
- code_loader-1.0.40a1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- code_loader-1.0.40a1.dist-info/RECORD,,