code-loader 1.0.64a1__py3-none-any.whl → 1.0.64.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,83 @@
1
+ # mypy: ignore-errors
2
+ import traceback
3
+ from dataclasses import dataclass
4
+
5
+ from typing import List, Tuple, Optional
6
+
7
+ from multiprocessing import Process, Queue
8
+
9
+ from code_loader.leap_loader_parallelized_base import LeapLoaderParallelizedBase
10
+ from code_loader.leaploader import LeapLoader
11
+ from code_loader.contract.enums import DataStateEnum
12
+ from code_loader.metric_calculator_parallelized import MetricCalculatorParallelized
13
+ from code_loader.samples_generator_parallelized import SamplesGeneratorParallelized
14
+
15
+
16
@dataclass
class SampleSerializableError:
    """Error record returned by a worker instead of a sample when sample loading fails.

    Built in ``CodeIntegrationProcessesManager._process_func`` when
    ``LeapLoader.get_sample`` raises, and put on the ``ready_samples`` queue in place of
    the sample (presumably so the failure can be pickled across the process boundary as
    plain data).
    """
    state: DataStateEnum  # data state of the requested sample
    index: int  # index of the requested sample
    leap_script_trace: str  # traceback text after the last 'File "<string>"' marker
    exception_as_str: str  # str() of the raised exception
22
+
23
+
24
class CodeIntegrationProcessesManager:
    """Work-in-progress coordinator for running integration-code workers in subprocesses.

    Holds a ``MetricCalculatorParallelized`` and a ``SamplesGeneratorParallelized`` built
    from the same integration code, and defines a worker loop (``_process_func``) that
    loads a ``LeapLoader`` and serves sample requests from a queue.

    NOTE(review): the class is not functional as written. It uses names that are never
    defined in this class: ``code_path``, ``code_entry_name``, ``multiprocessing_context``,
    ``_inputs_waiting_to_be_process``, ``_ready_processed_results``, ``_start_process_inputs``,
    ``_get_next_ready_processed_result`` and ``start_process_inputs``. The class has no base
    class to inherit them from. ``LeapLoaderParallelizedBase`` is imported by this module but
    unused; perhaps it was meant to be the base. Confirm.
    """
    def __init__(self, code_path: str, code_entry_name: str, n_workers: Optional[int] = 2,
                 max_samples_in_queue: int = 128) -> None:
        # Both helpers are built from the same integration code location.
        # NOTE(review): n_workers and max_samples_in_queue are accepted but unused. Also,
        # code_path / code_entry_name are not stored on self even though
        # _create_and_start_process reads self.code_path and self.code_entry_name.
        self.metric_calculator_parallelized = MetricCalculatorParallelized(code_path, code_entry_name)
        self.samples_generator_parallelized = SamplesGeneratorParallelized(code_path, code_entry_name)

    def _create_and_start_process(self) -> Process:
        """Start one daemon worker process running ``_process_func`` and return it.

        NOTE(review): ``self.multiprocessing_context``, ``self._inputs_waiting_to_be_process``
        and ``self._ready_processed_results`` are never assigned in this class. Only four
        arguments are passed, while ``_process_func`` requires six (``metrics_to_process`` and
        ``ready_metrics`` are missing), so the child would fail with a TypeError.
        """
        process = self.multiprocessing_context.Process(
            target=CodeIntegrationProcessesManager._process_func,
            args=(self.code_path, self.code_entry_name, self._inputs_waiting_to_be_process,
                  self._ready_processed_results))
        process.daemon = True
        process.start()
        return process

    def _run_and_warm_first_process(self):
        """Start a single worker, then push one request (training state, index 0) through it.

        NOTE(review): ``_start_process_inputs`` and ``_get_next_ready_processed_result`` are
        not defined in this class.
        """
        process = self._create_and_start_process()
        self.processes = [process]

        # needed in order to make sure the preprocess func runs once in nonparallel
        self._start_process_inputs([(DataStateEnum.training, 0)])
        self._get_next_ready_processed_result()

    def _operation_decider(self):
        """Pick which kind of pending work to handle next.

        Returns:
            ``'metric'`` if the metric helper has queued inputs but no ready results;
            otherwise ``'dataset'`` under the same condition for the sample helper;
            otherwise ``None`` (implicit fall-through).

        Reads private queue attributes of the two parallelized helpers.
        """
        if self.metric_calculator_parallelized._ready_processed_results.empty() and not \
                self.metric_calculator_parallelized._inputs_waiting_to_be_process.empty():
            return 'metric'

        if self.samples_generator_parallelized._ready_processed_results.empty() and not \
                self.samples_generator_parallelized._inputs_waiting_to_be_process.empty():
            return 'dataset'

    @staticmethod
    def _process_func(code_path: str, code_entry_name: str,
                      samples_to_process: Queue, ready_samples: Queue,
                      metrics_to_process: Queue, ready_metrics: Queue) -> None:
        """Worker loop: load the integration code, then answer sample requests forever.

        Args:
            code_path: integration code path passed to ``LeapLoader``.
            code_entry_name: integration entry name passed to ``LeapLoader``.
            samples_to_process: queue of ``(state, idx)`` requests; ``get`` blocks.
            ready_samples: receives either the sample or a ``SampleSerializableError``.
            metrics_to_process: currently unused.
            ready_metrics: currently unused.
        """
        import os
        # "-1" exposes no CUDA devices to this worker, presumably to keep it on CPU.
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

        leap_loader = LeapLoader(code_path, code_entry_name)
        while True:

            # decide on sample or metric to process
            # NOTE(review): no such decision is implemented yet; only samples are handled.
            state, idx = samples_to_process.get(block=True)
            # make sure preprocessing has run in this process before fetching the sample
            leap_loader._preprocess_result()
            try:
                sample = leap_loader.get_sample(state, idx)
            except Exception as e:
                # keep only the text after the last 'File "<string>"' marker (presumably the
                # frames of the exec'd user script) and report the failure instead of dying
                leap_script_trace = traceback.format_exc().split('File "<string>"')[-1]
                ready_samples.put(SampleSerializableError(state, idx, leap_script_trace, str(e)))
                continue

            ready_samples.put(sample)

    def generate_samples(self, sample_identities: List[Tuple[DataStateEnum, int]]):
        """Request samples for the given ``(state, index)`` identities.

        NOTE(review): ``start_process_inputs`` is not defined on this class (note the
        missing leading underscore compared with ``_start_process_inputs`` above), so this
        raises AttributeError as written.
        """
        return self.start_process_inputs(sample_identities)
+
@@ -117,31 +117,46 @@ MetricCallableReturnType = Union[Any, List[List[ConfusionMatrixElement]]]
117
117
 
118
118
 
119
119
  @dataclass
  class CustomLossHandlerData:
      """Descriptive part of a registered custom loss: its name and argument names.

      ``LeapBinder.add_custom_loss`` fills ``arg_names`` from
      ``inspect.getfullargspec(function)[0]``; the callable itself is kept separately
      on ``CustomLossHandler.function``.
      """
      name: str  # name the loss was registered under
      arg_names: List[str]  # positional parameter names of the loss function
124
123
 
125
124
 
126
125
  @dataclass
  class CustomLossHandler:
      """A registered custom loss: its descriptive data plus the callable to run.

      ``LeapLoader.custom_loss_by_name`` exposes only ``custom_loss_handler_data``, while
      ``LeapLoader.run_custom_loss`` calls ``function``.
      """
      custom_loss_handler_data: CustomLossHandlerData  # name and argument names
      function: CustomCallableInterface  # the user's loss function
129
+
130
+
131
@dataclass
class MetricHandlerData:
    """Descriptive part of a registered metric: name, argument names and direction."""
    name: str  # name the metric was registered under
    arg_names: List[str]  # parameter names taken from inspect.getfullargspec(function)[0]
    # optimization direction; default_metrics registers error metrics as Downward,
    # Accuracy as Upward, and the confusion-matrix metric with None
    direction: Optional[MetricDirection] = MetricDirection.Downward
132
136
 
133
137
 
138
@dataclass
class MetricHandler:
    """A registered metric: its descriptive data plus the callable to run.

    ``LeapLoader.metric_by_name`` exposes only ``metric_handler_data``, while
    ``LeapLoader.run_metric`` calls ``function``.
    """
    metric_handler_data: MetricHandlerData  # name, arg names, direction
    # scalar-style or confusion-matrix-style metric function
    function: Union[CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]
142
+
143
+
134
144
  @dataclass
  class RawInputsForHeatmap:
      """Container of raw visualizer inputs handed to a heatmap function.

      A heatmap-function parameter annotated with this type is located by
      ``LeapLoader.get_heatmap_visualizer_raw_vis_input_arg_name``.
      """
      # keyed by visualizer argument name (the "vizualizer" spelling of the field is kept as-is)
      raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
137
147
 
138
148
 
139
149
  @dataclass
  class VisualizerHandlerData:
      """Descriptive part of a registered visualizer: name, output data type, argument names."""
      name: str  # name the visualizer was registered under
      type: LeapDataType  # kind of LeapData the visualizer produces
      arg_names: List[str]  # parameter names of the visualizer function
154
+
155
+
156
@dataclass
class VisualizerHandler:
    """A registered visualizer: its descriptive data, its callable and an optional heatmap function.

    ``LeapLoader.visualizer_by_name`` exposes only ``visualizer_handler_data``.
    ``LeapLoader.run_visualizer`` calls ``function``. ``LeapLoader.run_heatmap_visualizer``
    calls ``heatmap_function``, or passes its single input through when it is None.
    """
    visualizer_handler_data: VisualizerHandlerData  # name, type, arg names
    function: VisualizerCallableInterface  # produces the LeapData
    heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None
146
161
 
147
162
 
@@ -1,4 +1,4 @@
1
- from typing import List, Any, Union, Optional
1
+ from typing import List, Any, Union
2
2
 
3
3
  import numpy as np
4
4
  import numpy.typing as npt
@@ -112,12 +112,11 @@ class LeapText:
112
112
 
113
113
  Example:
114
114
  text_data = ['I', 'ate', 'a', 'banana', '', '', '']
115
- heatmap = [0.1, 0.3, 0.2, 0.9, 0.0, 0.0, 0.0]
116
- leap_text = LeapText(data=text_data heatmap=heatmap) # Create LeapText object
115
+ leap_text = LeapText(data=text_data) # Create LeapText object
116
+ LeapText(leap_text)
117
117
  """
118
118
  data: List[str]
119
119
  type: LeapDataType = LeapDataType.Text
120
- heatmap: Optional[List[float]] = None
121
120
 
122
121
  def __post_init__(self) -> None:
123
122
  validate_type(self.type, LeapDataType.Text)
@@ -125,15 +124,6 @@ class LeapText:
125
124
  for value in self.data:
126
125
  validate_type(type(value), str)
127
126
 
128
- if self.heatmap is not None:
129
- validate_type(type(self.heatmap), list)
130
- for value in self.heatmap:
131
- validate_type(type(value), float)
132
- if len(self.heatmap) != len(self.data):
133
- raise LeapValidationError(
134
- f"Heatmap length ({len(self.heatmap)}) must match the number of tokens in `data` ({len(self.data)})."
135
- )
136
-
137
127
 
138
128
  @dataclass
139
129
  class LeapHorizontalBar:
@@ -0,0 +1,81 @@
1
+ from enum import Enum
2
+ from typing import List, Tuple
3
+ import numpy as np
4
+
5
+ from code_loader.contract.datasetclasses import ConfusionMatrixElement # type: ignore
6
+ from code_loader.contract.enums import ConfusionMatrixValue, MetricDirection # type: ignore
7
+
8
+
9
class Metric(Enum):
    """Built-in metrics registered on every LeapBinder.

    Each member's value equals its name. ``metrics_names_to_functions_and_direction`` is
    keyed by ``Metric.<member>.name``.
    """
    MeanSquaredError = 'MeanSquaredError'
    MeanSquaredLogarithmicError = 'MeanSquaredLogarithmicError'
    MeanAbsoluteError = 'MeanAbsoluteError'
    MeanAbsolutePercentageError = 'MeanAbsolutePercentageError'
    Accuracy = 'Accuracy'
    ConfusionMatrixClassification = 'ConfusionMatrixClassification'
16
+
17
+
18
def accuracy_reduced(ground_truth: np.ndarray, prediction: np.ndarray) -> np.ndarray:
    """Per-sample binary accuracy over all non-batch elements.

    Predictions are rounded (NumPy rounds halves to even) and both arrays are cast to
    booleans, so any non-zero value counts as True, before element-wise comparison.

    Args:
        ground_truth: array whose first axis is the batch axis.
        prediction: array assumed to match ``ground_truth``'s batch size and per-sample
            element count.

    Returns:
        Array of shape ``(batch,)`` with the fraction of matching elements. It is cast to
        float32 for consistency with the other built-in metrics.
    """
    ground_truth, prediction = flatten_non_batch_dims(ground_truth, prediction)
    matches = np.round(prediction).astype(np.bool_) == ground_truth.astype(np.bool_)
    return np.mean(matches, axis=1).astype(np.float32)
21
+
22
+
23
def mean_squared_error_dimension_reduced(ground_truth: np.ndarray, prediction: np.ndarray) -> np.ndarray:
    """Per-sample mean squared error over all non-batch elements.

    Returns:
        float32 array of shape ``(batch,)``.
    """
    ground_truth, prediction = flatten_non_batch_dims(ground_truth, prediction)
    return ((ground_truth - prediction) ** 2).mean(axis=1).astype(np.float32)
26
+
27
+
28
def mean_absolute_error_dimension_reduced(ground_truth: np.ndarray, prediction: np.ndarray) -> np.ndarray:
    """Per-sample mean absolute error over all non-batch elements.

    Returns:
        float32 array of shape ``(batch,)``.
    """
    ground_truth, prediction = flatten_non_batch_dims(ground_truth, prediction)
    return np.abs(ground_truth - prediction).mean(axis=1).astype(np.float32)
31
+
32
+
33
def mean_absolute_percentage_error_dimension_reduced(ground_truth: np.ndarray, prediction: np.ndarray) -> np.ndarray:
    """Per-sample mean absolute percentage error over all non-batch elements.

    The result is a fraction; it is not multiplied by 100. The denominator ``|ground_truth|``
    is clamped to a small epsilon. Without the clamp, a zero in the ground truth would
    produce ``inf``/``nan`` through division by zero; with it, the sample gets a large
    finite error (the same guard Keras' MAPE uses).

    Returns:
        float32 array of shape ``(batch,)``.
    """
    ground_truth, prediction = flatten_non_batch_dims(ground_truth, prediction)
    # Kept local rather than as a parameter: the metric's arg_names come from
    # inspect.getfullargspec, so an extra parameter would become a required metric input.
    epsilon = 1e-7
    denominator = np.maximum(np.abs(ground_truth), epsilon)
    return (np.abs(ground_truth - prediction) / denominator).mean(axis=1).astype(np.float32)
36
+
37
+
38
def mean_squared_logarithmic_error_dimension_reduced(ground_truth: np.ndarray, prediction: np.ndarray) -> np.ndarray:
    """Per-sample mean squared logarithmic error over all non-batch elements.

    Uses ``np.log1p(x)``, which equals ``log(1 + x)`` but is accurate for small ``x``.
    Values below -1 still yield ``nan``.

    Returns:
        float32 array of shape ``(batch,)``.
    """
    ground_truth, prediction = flatten_non_batch_dims(ground_truth, prediction)
    log_diff = np.log1p(ground_truth) - np.log1p(prediction)
    return np.mean(log_diff ** 2, axis=1).astype(np.float32)
41
+
42
+
43
def flatten_non_batch_dims(ground_truth: np.ndarray, prediction: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Reshape both arrays to ``(batch, -1)``, keeping axis 0 as the batch axis.

    The batch size is read from ``ground_truth``. ``prediction`` is reshaped with that same
    batch size, so its total size must be divisible by it (callers presumably pass arrays
    of matching shape).

    Returns:
        ``(ground_truth, prediction)``, both 2-D.
    """
    batch_size = ground_truth.shape[0]
    return np.reshape(ground_truth, (batch_size, -1)), np.reshape(prediction, (batch_size, -1))
48
+
49
+
50
def confusion_matrix_classification_metric(ground_truth, prediction) -> List[List[ConfusionMatrixElement]]:
    """Build per-sample confusion-matrix elements for a classification output.

    The last axis of ``prediction`` holds one score per class. When there is only one
    column (single-output binary classifier), both arrays are expanded to explicit
    ``[1 - x, x]`` columns labelled ``'0'`` and ``'1'``. The concatenation is along axis 1,
    so 2-D ``(batch, 1)`` inputs are assumed; TODO confirm.

    Returns:
        One list per batch row. Each list holds one ``ConfusionMatrixElement`` per label:
        Positive when the ground-truth entry converts to int 1, otherwise Negative, together
        with the predicted score as a float.
    """
    class_count = prediction.shape[-1]
    if class_count == 1:
        labels = ['0', '1']
        ground_truth = np.concatenate([1 - ground_truth, ground_truth], axis=1)
        prediction = np.concatenate([1 - prediction, prediction], axis=1)
    else:
        labels = [str(class_idx) for class_idx in range(class_count)]

    def _row_elements(gt_row, pred_row) -> List[ConfusionMatrixElement]:
        # One element per label, comparing the (int-cast) ground truth against 1.
        return [
            ConfusionMatrixElement(
                label,
                ConfusionMatrixValue.Positive if int(gt_row[col]) == 1 else ConfusionMatrixValue.Negative,
                float(pred_row[col]))
            for col, label in enumerate(labels)
        ]

    return [
        _row_elements(list(ground_truth[row]), list(prediction[row]))
        for row in range(ground_truth.shape[0])
    ]
70
+
71
+
72
# Built-in metrics, registered on every LeapBinder by LeapBinder._extend_with_default_metrics.
# Maps metric name -> (metric function, MetricDirection). Each function's parameter names
# (ground_truth, prediction) become the metric's arg_names via inspect.getfullargspec in
# add_custom_metric. ConfusionMatrixClassification is registered with direction None.
metrics_names_to_functions_and_direction = {
    Metric.MeanSquaredError.name: (mean_squared_error_dimension_reduced, MetricDirection.Downward),
    Metric.MeanSquaredLogarithmicError.name: (
        mean_squared_logarithmic_error_dimension_reduced, MetricDirection.Downward),
    Metric.MeanAbsoluteError.name: (mean_absolute_error_dimension_reduced, MetricDirection.Downward),
    Metric.MeanAbsolutePercentageError.name: (
        mean_absolute_percentage_error_dimension_reduced, MetricDirection.Downward),
    Metric.Accuracy.name: (accuracy_reduced, MetricDirection.Upward),
    Metric.ConfusionMatrixClassification.name: (confusion_matrix_classification_metric, None),
}
@@ -9,10 +9,12 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
9
9
  PreprocessHandler, VisualizerCallableInterface, CustomLossHandler, CustomCallableInterface, PredictionTypeHandler, \
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
12
- CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, RawInputsForHeatmap
12
+ CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
+ RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData
13
14
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
14
15
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
15
16
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
17
+ from code_loader.default_metrics import metrics_names_to_functions_and_direction
16
18
  from code_loader.utils import to_numpy_return_wrapper, get_shape
17
19
  from code_loader.visualizers.default_visualizers import DefaultVisualizer, \
18
20
  default_graph_visualizer, \
@@ -36,6 +38,7 @@ class LeapBinder:
36
38
  self._visualizer_names: List[str] = list()
37
39
  self._encoder_names: List[str] = list()
38
40
  self._extend_with_default_visualizers()
41
+ self._extend_with_default_metrics()
39
42
 
40
43
  self.batch_size_to_validate: Optional[int] = None
41
44
 
@@ -55,6 +58,10 @@ class LeapBinder:
55
58
  self.set_visualizer(function=default_text_mask_visualizer, name=DefaultVisualizer.TextMask.value,
56
59
  visualizer_type=LeapDataType.TextMask)
57
60
 
61
def _extend_with_default_metrics(self) -> None:
    """Register every built-in metric from ``default_metrics`` on this binder.

    Each entry maps a metric name to a ``(function, direction)`` pair. The pair is
    forwarded positionally to ``add_custom_metric`` as ``(function, name, direction)``.
    """
    for entry_name, function_and_direction in metrics_names_to_functions_and_direction.items():
        metric_function, metric_direction = function_and_direction
        self.add_custom_metric(metric_function, entry_name, metric_direction)
64
+
58
65
  def set_visualizer(self, function: VisualizerCallableInterface,
59
66
  name: str,
60
67
  visualizer_type: LeapDataType,
@@ -127,7 +134,7 @@ class LeapBinder:
127
134
  f'should be {expected_return_type}')
128
135
 
129
136
  self.setup_container.visualizers.append(
130
- VisualizerHandler(name, function, visualizer_type, arg_names, heatmap_visualizer))
137
+ VisualizerHandler(VisualizerHandlerData(name, visualizer_type, arg_names), function, heatmap_visualizer))
131
138
  self._visualizer_names.append(name)
132
139
 
133
140
  def set_preprocess(self, function: Callable[[], List[PreprocessResponse]]) -> None:
@@ -226,7 +233,7 @@ class LeapBinder:
226
233
  leap_binder.add_custom_loss(custom_loss_function, name='custom_loss')
227
234
  """
228
235
  arg_names = inspect.getfullargspec(function)[0]
229
- self.setup_container.custom_loss_handlers.append(CustomLossHandler(name, function, arg_names))
236
+ self.setup_container.custom_loss_handlers.append(CustomLossHandler(CustomLossHandlerData(name, arg_names), function))
230
237
 
231
238
  def add_custom_metric(self,
232
239
  function: Union[CustomCallableInterfaceMultiArgs,
@@ -252,7 +259,7 @@ class LeapBinder:
252
259
  leap_binder.add_custom_metric(custom_metric_function, name='custom_metric', direction=MetricDirection.Downward)
253
260
  """
254
261
  arg_names = inspect.getfullargspec(function)[0]
255
- self.setup_container.metrics.append(MetricHandler(name, function, arg_names, direction))
262
+ self.setup_container.metrics.append(MetricHandler(MetricHandlerData(name, arg_names, direction), function))
256
263
 
257
264
  def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
258
265
  """
@@ -22,7 +22,7 @@ def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = M
22
22
  ConfusionMatrixCallableInterfaceMultiArgs]
23
23
  ):
24
24
  for metric_handler in leap_binder.setup_container.metrics:
25
- if metric_handler.name == name:
25
+ if metric_handler.metric_handler_data.name == name:
26
26
  raise Exception(f'Metric with name {name} already exists. '
27
27
  f'Please choose another')
28
28
 
@@ -94,7 +94,7 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
94
94
  heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None):
95
95
  def decorating_function(user_function: VisualizerCallableInterface):
96
96
  for viz_handler in leap_binder.setup_container.visualizers:
97
- if viz_handler.name == name:
97
+ if viz_handler.visualizer_handler_data.name == name:
98
98
  raise Exception(f'Visualizer with name {name} already exists. '
99
99
  f'Please choose another')
100
100
 
@@ -334,7 +334,7 @@ def tensorleap_gt_encoder(name: str):
334
334
  def tensorleap_custom_loss(name: str):
335
335
  def decorating_function(user_function: CustomCallableInterface):
336
336
  for loss_handler in leap_binder.setup_container.custom_loss_handlers:
337
- if loss_handler.name == name:
337
+ if loss_handler.custom_loss_handler_data.name == name:
338
338
  raise Exception(f'Custom loss with name {name} already exists. '
339
339
  f'Please choose another')
340
340
 
code_loader/leaploader.py CHANGED
@@ -1,32 +1,33 @@
1
1
  # mypy: ignore-errors
2
2
  import importlib.util
3
+ import inspect
3
4
  import io
4
5
  import sys
5
- import time
6
6
  from contextlib import redirect_stdout
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Dict, List, Iterable, Union, Any, Type
9
+ from typing import Dict, List, Iterable, Union, Any, Type, Optional
10
10
 
11
11
  import numpy as np
12
12
  import numpy.typing as npt
13
13
 
14
14
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
15
- PreprocessResponse, VisualizerHandler, LeapData, CustomLossHandler, \
16
- PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler
15
+ PreprocessResponse, VisualizerHandler, LeapData, \
16
+ PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, VisualizerHandlerData, MetricHandlerData, \
17
+ MetricCallableReturnType, CustomLossHandlerData, CustomLossHandler, RawInputsForHeatmap
17
18
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
18
19
  from code_loader.contract.exceptions import DatasetScriptException
19
20
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
20
21
  DatasetPreprocess, DatasetSetup, DatasetInputInstance, DatasetOutputInstance, DatasetMetadataInstance, \
21
22
  VisualizerInstance, PredictionTypeInstance, ModelSetup, CustomLayerInstance, MetricInstance, CustomLossInstance
22
23
  from code_loader.inner_leap_binder import global_leap_binder
24
+ from code_loader.leaploaderbase import LeapLoaderBase
23
25
  from code_loader.utils import get_root_exception_file_and_line_number
24
26
 
25
27
 
26
- class LeapLoader:
28
+ class LeapLoader(LeapLoaderBase):
27
29
  def __init__(self, code_path: str, code_entry_name: str):
28
- self.code_entry_name = code_entry_name
29
- self.code_path = code_path
30
+ super().__init__(code_path, code_entry_name)
30
31
 
31
32
  self._preprocess_result_cached = None
32
33
 
@@ -66,29 +67,56 @@ class LeapLoader:
66
67
  spec.loader.exec_module(file)
67
68
 
68
69
  @lru_cache()
69
- def metric_by_name(self) -> Dict[str, MetricHandler]:
70
+ def metric_by_name(self) -> Dict[str, MetricHandlerData]:
70
71
  self.exec_script()
71
72
  setup = global_leap_binder.setup_container
72
73
  return {
73
- metric_handler.name: metric_handler
74
+ metric_handler.metric_handler_data.name: metric_handler.metric_handler_data
74
75
  for metric_handler in setup.metrics
75
76
  }
76
77
 
77
78
  @lru_cache()
78
- def visualizer_by_name(self) -> Dict[str, VisualizerHandler]:
79
+ def _metric_handler_by_name(self) -> Dict[str, MetricHandler]:
79
80
  self.exec_script()
80
81
  setup = global_leap_binder.setup_container
81
82
  return {
82
- visualizer_handler.name: visualizer_handler
83
+ metric_handler.metric_handler_data.name: metric_handler
84
+ for metric_handler in setup.metrics
85
+ }
86
+
87
@lru_cache()
def visualizer_by_name(self) -> Dict[str, VisualizerHandlerData]:
    """Return the descriptive data of every registered visualizer, keyed by visualizer name.

    Runs ``exec_script`` first, then reads ``global_leap_binder.setup_container.visualizers``.
    """
    self.exec_script()
    return dict(
        (handler.visualizer_handler_data.name, handler.visualizer_handler_data)
        for handler in global_leap_binder.setup_container.visualizers)
85
95
 
86
96
  @lru_cache()
87
- def custom_loss_by_name(self) -> Dict[str, CustomLossHandler]:
97
+ def _visualizer_handler_by_name(self) -> Dict[str, VisualizerHandler]:
98
+ self.exec_script()
99
+ setup = global_leap_binder.setup_container
100
+ return {
101
+ visualizer_handler.visualizer_handler_data.name: visualizer_handler
102
+ for visualizer_handler in setup.visualizers
103
+ }
104
+
105
@lru_cache()
def custom_loss_by_name(self) -> Dict[str, CustomLossHandlerData]:
    """Return the descriptive data of every registered custom loss, keyed by loss name.

    Runs ``exec_script`` first, then reads
    ``global_leap_binder.setup_container.custom_loss_handlers``.
    """
    self.exec_script()
    loss_datas = [loss.custom_loss_handler_data
                  for loss in global_leap_binder.setup_container.custom_loss_handlers]
    return {data.name: data for data in loss_datas}
113
+
114
@lru_cache()
def _custom_loss_handler_by_name(self) -> Dict[str, CustomLossHandler]:
    """Return the full ``CustomLossHandler`` of each custom loss, keyed by name.

    Used by ``run_custom_loss`` to reach the loss function. Runs ``exec_script`` first.
    """
    self.exec_script()
    result: Dict[str, CustomLossHandler] = {}
    for loss_handler in global_leap_binder.setup_container.custom_loss_handlers:
        result[loss_handler.custom_loss_handler_data.name] = loss_handler
    return result
94
122
 
@@ -200,20 +228,41 @@ class LeapLoader:
200
228
  all_dataset_base_handlers.extend(global_leap_binder.setup_container.metadata)
201
229
  return all_dataset_base_handlers
202
230
 
203
- def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
204
- ) -> LeapData:
205
- return self.visualizer_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
231
def run_metric(self, metric_name: str,
               input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
    """Run the named metric on keyword inputs (argument name -> array) and return its result.

    ``_preprocess_result`` is triggered first, before the metric is looked up. An unknown
    name raises KeyError.
    """
    self._preprocess_result()
    metric_function = self._metric_handler_by_name()[metric_name].function
    return metric_function(**input_tensors_by_arg_name)
235
+
236
def run_custom_loss(self, custom_loss_name: str,
                    input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
    """Call the named custom loss with keyword inputs and return whatever it returns.

    Unlike ``run_metric`` and ``run_visualizer``, this does not call ``_preprocess_result``
    first. An unknown name raises KeyError.
    """
    loss_function = self._custom_loss_handler_by_name()[custom_loss_name].function
    return loss_function(**input_tensors_by_arg_name)
239
+
240
def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
    """Run the named visualizer on keyword inputs and return the produced ``LeapData``.

    Preprocessing is triggered first in the calling process. A preprocess function may
    fill a module-level global that the visualizer reads, and this keeps that state in sync.
    """
    self._preprocess_result()
    visualizer = self._visualizer_handler_by_name()[visualizer_name]
    return visualizer.function(**input_tensors_by_arg_name)
206
246
 
207
247
  def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
208
248
  ) -> npt.NDArray[np.float32]:
209
- heatmap_function = self.visualizer_by_name()[visualizer_name].heatmap_function
249
+ heatmap_function = self._visualizer_handler_by_name()[visualizer_name].heatmap_function
210
250
  if heatmap_function is None:
211
251
  assert len(input_tensors_by_arg_name) == 1
212
252
  return list(input_tensors_by_arg_name.values())[0]
213
253
  return heatmap_function(**input_tensors_by_arg_name)
214
254
 
215
- @staticmethod
216
- def get_dataset_setup_response(handlers_test_payloads: List[DatasetTestResultPayload]) -> DatasetSetup:
255
def get_heatmap_visualizer_raw_vis_input_arg_name(self, visualizer_name: str) -> Optional[str]:
    """Return the heatmap-function parameter annotated as ``RawInputsForHeatmap``, if any.

    Only parameters are considered. The ``'return'`` entry of the annotations mapping is
    skipped, so a return annotation can never be reported as an argument name.

    String annotations (e.g. when the integration script uses
    ``from __future__ import annotations``) are resolved with ``typing.get_type_hints``
    before comparing. If resolution fails, the raw string is kept, which simply won't match.

    Returns:
        The parameter name, or None if the visualizer has no heatmap function or no
        parameter carries that annotation.
    """
    from typing import get_type_hints

    heatmap_function = self._visualizer_handler_by_name()[visualizer_name].heatmap_function
    if heatmap_function is None:
        return None

    annotations = inspect.getfullargspec(heatmap_function).annotations
    resolved_hints = None
    for arg_name, arg_type in annotations.items():
        if arg_name == 'return':
            continue
        if isinstance(arg_type, str):
            if resolved_hints is None:
                try:
                    resolved_hints = get_type_hints(heatmap_function)
                except Exception:
                    # Unresolvable forward references or unsupported callable: best effort only.
                    resolved_hints = {}
            arg_type = resolved_hints.get(arg_name, arg_type)
        if arg_type == RawInputsForHeatmap:
            return arg_name
    return None
264
+
265
+ def get_dataset_setup_response(self, handlers_test_payloads: List[DatasetTestResultPayload]) -> DatasetSetup:
217
266
  setup = global_leap_binder.setup_container
218
267
  assert setup.preprocess is not None
219
268
 
@@ -262,10 +311,13 @@ class LeapLoader:
262
311
  type=dataset_metadata_type))
263
312
 
264
313
  visualizers = [
265
- VisualizerInstance(visualizer_handler.name, visualizer_handler.type, visualizer_handler.arg_names)
314
+ VisualizerInstance(
315
+ visualizer_handler.visualizer_handler_data.name, visualizer_handler.visualizer_handler_data.type,
316
+ visualizer_handler.visualizer_handler_data.arg_names)
266
317
  for visualizer_handler in setup.visualizers]
267
318
 
268
- custom_losses = [CustomLossInstance(custom_loss.name, custom_loss.arg_names)
319
+ custom_losses = [CustomLossInstance(custom_loss.custom_loss_handler_data.name,
320
+ custom_loss.custom_loss_handler_data.arg_names)
269
321
  for custom_loss in setup.custom_loss_handlers]
270
322
 
271
323
  prediction_types = []
@@ -276,15 +328,14 @@ class LeapLoader:
276
328
 
277
329
  metrics = []
278
330
  for metric in setup.metrics:
279
- metric_inst = MetricInstance(metric.name, metric.arg_names)
331
+ metric_inst = MetricInstance(metric.metric_handler_data.name, metric.metric_handler_data.arg_names)
280
332
  metrics.append(metric_inst)
281
333
 
282
334
  return DatasetSetup(preprocess=dataset_preprocess, inputs=inputs, outputs=ground_truths,
283
335
  metadata=metadata_instances, visualizers=visualizers, prediction_types=prediction_types,
284
336
  custom_losses=custom_losses, metrics=metrics)
285
337
 
286
- @staticmethod
287
- def get_model_setup_response() -> ModelSetup:
338
+ def get_model_setup_response(self) -> ModelSetup:
288
339
  setup = global_leap_binder.setup_container
289
340
  custom_layer_instances = [
290
341
  CustomLayerInstance(custom_layer_handler.name, custom_layer_handler.init_arg_names,
@@ -389,3 +440,4 @@ class LeapLoader:
389
440
 
390
441
  return id_type
391
442
 
443
+
@@ -0,0 +1,95 @@
1
+ # mypy: ignore-errors
2
+
3
+ from abc import abstractmethod
4
+
5
+ from typing import Dict, List, Union, Type, Optional
6
+
7
+ import numpy as np
8
+ import numpy.typing as npt
9
+
10
+ from code_loader.contract.datasetclasses import DatasetSample, LeapData, \
11
+ PredictionTypeHandler, CustomLayerHandler, VisualizerHandlerData, MetricHandlerData, MetricCallableReturnType, \
12
+ CustomLossHandlerData
13
+ from code_loader.contract.enums import DataStateEnum
14
+ from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
15
+ DatasetSetup, ModelSetup
16
+
17
+
18
class LeapLoaderBase:
    """Interface shared by leap-loader implementations (``LeapLoader`` subclasses it).

    Stores the integration code location and declares what a loader must provide:
    handler lookups by name, sample retrieval, running visualizers, metrics and losses,
    and building setup responses.

    NOTE(review): the class does not derive from ``abc.ABC`` (its metaclass is plain
    ``type``), so ``@abstractmethod`` is not enforced. The base can be instantiated, and a
    subclass that forgets an override silently inherits a ``pass`` body that returns None.
    Consider ``class LeapLoaderBase(ABC)`` once every subclass is confirmed to implement
    all of these methods.
    """
    def __init__(self, code_path: str, code_entry_name: str) -> None:
        # Location of the user's integration code. If either value is empty,
        # is_custom_latent_space() returns False without calling custom_layers().
        self.code_entry_name = code_entry_name
        self.code_path = code_path

    @abstractmethod
    def metric_by_name(self) -> Dict[str, MetricHandlerData]:
        """Return the descriptive data (name, arg names, direction) of each metric, keyed by name."""
        pass

    @abstractmethod
    def visualizer_by_name(self) -> Dict[str, VisualizerHandlerData]:
        """Return the descriptive data (name, type, arg names) of each visualizer, keyed by name."""
        pass

    @abstractmethod
    def custom_loss_by_name(self) -> Dict[str, CustomLossHandlerData]:
        """Return the descriptive data (name, arg names) of each custom loss, keyed by name."""
        pass

    @abstractmethod
    def custom_layers(self) -> Dict[str, CustomLayerHandler]:
        """Return the custom layer handlers, keyed by a string (presumably the layer name)."""
        pass

    @abstractmethod
    def prediction_type_by_name(self) -> Dict[str, PredictionTypeHandler]:
        """Return the prediction-type handlers keyed by name."""
        pass

    @abstractmethod
    def get_sample(self, state: DataStateEnum, sample_id: Union[int, str]) -> DatasetSample:
        """Return the sample identified by ``sample_id`` within data state ``state``."""
        pass

    @abstractmethod
    def check_dataset(self) -> DatasetIntegParseResult:
        """Return a parse result describing the dataset integration (presumably after validating it)."""
        pass

    @abstractmethod
    def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
        """Run the named visualizer on keyword inputs (argument name -> array) and return its LeapData."""
        pass

    @abstractmethod
    def run_metric(self, metric_name: str,
                   input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
        """Run the named metric on keyword inputs and return its result."""
        pass

    @abstractmethod
    def run_custom_loss(self, custom_loss_name: str,
                        input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
        """Run the named custom loss on keyword inputs and return its result."""
        pass

    @abstractmethod
    def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
                               ) -> npt.NDArray[np.float32]:
        """Return the heatmap array for the named visualizer given keyword inputs."""
        pass

    @abstractmethod
    def get_dataset_setup_response(self, handlers_test_payloads: List[DatasetTestResultPayload]) -> DatasetSetup:
        """Build the ``DatasetSetup`` response from the given handler test payloads."""
        pass

    @abstractmethod
    def get_model_setup_response(self) -> ModelSetup:
        """Build the ``ModelSetup`` response."""
        pass

    @abstractmethod
    def get_preprocess_sample_ids(
            self, update_unlabeled_preprocess: bool = False) -> Dict[DataStateEnum, Union[List[int], List[str]]]:
        """Return the sample ids for each data state; a state's ids are all ints or all strs.

        ``update_unlabeled_preprocess``: presumably forces the unlabeled-data preprocess to
        be refreshed. Confirm against the implementation.
        """
        pass

    @abstractmethod
    def get_sample_id_type(self) -> Type:
        """Return the Python type used for sample ids."""
        pass

    @abstractmethod
    def get_heatmap_visualizer_raw_vis_input_arg_name(self, visualizer_name: str) -> Optional[str]:
        """Return the heatmap-function parameter that receives ``RawInputsForHeatmap``, or None."""
        pass

    def is_custom_latent_space(self) -> bool:
        """Return True if any custom layer has ``use_custom_latent_space`` set.

        Returns False immediately when ``code_entry_name`` or ``code_path`` is empty.
        """
        if not self.code_entry_name or not self.code_path:
            return False
        custom_layers = self.custom_layers()
        return any(layer.use_custom_latent_space for layer in custom_layers.values())
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.64a1
3
+ Version: 1.0.64.dev2
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -1,11 +1,13 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
+ code_loader/code_inegration_processes_manager.py,sha256=XslWOPeNQk4RAFJ_f3tP5Oe3EgcIR7BE7Y8r9Ty73-o,3261
3
4
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=cd6fRDC4XLrJa7PcrzoTPIKtGwFZq09DGwiSC5BSKvk,6705
5
+ code_loader/contract/datasetclasses.py,sha256=L_fSdSvf-eKoez2uBJ8VjfrKedEP0szNOPvaUvsWeRQ,6973
5
6
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
6
7
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
8
  code_loader/contract/responsedataclasses.py,sha256=RSx9m_R3LawhK5o1nAcO3hfp2F9oJYtxZr_bpP3bTmw,4005
8
- code_loader/contract/visualizer_classes.py,sha256=uK3Wl6_jN40I9jEZPrLd5396XSr9MxzLzAAIY8W80pE,12436
9
+ code_loader/contract/visualizer_classes.py,sha256=iIa_O2rKvPTwN5ILCTZvRpsGYiiFABKdwQwfIXGigDo,11928
10
+ code_loader/default_metrics.py,sha256=2kSaB71OrbQIXlcTMStYUebL8N-8bE9se3m-AXVdvCY,3936
9
11
  code_loader/experiment_api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
12
  code_loader/experiment_api/api.py,sha256=a7wh6Hhe7IaVxu46eV2soSz-yxnmXG3ipU1BBtsEAaQ,2493
11
13
  code_loader/experiment_api/cli_config_utils.py,sha256=n6JMyNrquxql3KKxHhAP8jAzezlRT-PV2KWI95kKsm0,1140
@@ -17,13 +19,14 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
17
19
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
18
20
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
19
21
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
20
- code_loader/inner_leap_binder/leapbinder.py,sha256=LVzpynjISO-a774flzGt1yAQPsSYNE8B5V58Hacs7bQ,25216
21
- code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=AHU88W-UJ5alHOA3pN2oq7lIMiyYWdFSM0BL2WR-FGs,20998
22
- code_loader/leaploader.py,sha256=Tpf6A25hYuo4D0umGL3BHNYJhmz_NIwvFveQgAlsSOo,19534
22
+ code_loader/inner_leap_binder/leapbinder.py,sha256=y_k7bRFYYmrZo4jCJrZ6mJykxc1slKDkODYCZ58OPs0,25691
23
+ code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=I6ipji6QMN9qqFYxqQyNjtDtsnwpB-NahgKKpLPctMo,21067
24
+ code_loader/leaploader.py,sha256=vssXiifNEGkfapjDtVd0Lpg3WKQ6qxYyCdoeGWvxiNc,22240
25
+ code_loader/leaploaderbase.py,sha256=I1xwgGnDlvfyoOhlpKXXEuCqJUs0jqPzk3DTQaE6riQ,3080
23
26
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
24
27
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
28
  code_loader/visualizers/default_visualizers.py,sha256=VoqO9FN84yXyMjRjHjUTOt2GdTkJRMbHbXJ1cJkREkk,2230
26
- code_loader-1.0.64a1.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
27
- code_loader-1.0.64a1.dist-info/METADATA,sha256=d9JZm1AhWAexNMGsFkcqNbcYQ-xPTg8abDnd6RzRe8Q,851
28
- code_loader-1.0.64a1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
29
- code_loader-1.0.64a1.dist-info/RECORD,,
29
+ code_loader-1.0.64.dev2.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
+ code_loader-1.0.64.dev2.dist-info/METADATA,sha256=vKLtUJvnmvuG4wZ6pkkM54P6Z54v1IOd_PnGwYDJl-s,854
31
+ code_loader-1.0.64.dev2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
+ code_loader-1.0.64.dev2.dist-info/RECORD,,