code-loader 1.0.71__py3-none-any.whl → 1.0.72.dev1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
@@ -132,7 +132,7 @@ class CustomLossHandler:
132
132
  class MetricHandlerData:
133
133
  name: str
134
134
  arg_names: List[str]
135
- direction: Optional[MetricDirection] = MetricDirection.Downward
135
+ direction: Union[None, MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward
136
136
 
137
137
 
138
138
  @dataclass
@@ -146,6 +146,12 @@ class RawInputsForHeatmap:
146
146
  raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
147
147
 
148
148
 
149
+ @dataclass
150
+ class SamplePreprocessResponse:
151
+ sample_ids: np.array
152
+ preprocess_response: PreprocessResponse
153
+
154
+
149
155
  @dataclass
150
156
  class VisualizerHandlerData:
151
157
  name: str
@@ -171,6 +177,7 @@ class InputHandler(DatasetBaseHandler):
171
177
  shape: Optional[List[int]] = None
172
178
  channel_dim: Optional[int] = -1
173
179
 
180
+
174
181
  @dataclass
175
182
  class GroundTruthHandler(DatasetBaseHandler):
176
183
  shape: Optional[List[int]] = None
@@ -170,12 +170,13 @@ class LeapHorizontalBar:
170
170
  Example:
171
171
  body_data = np.random.rand(5).astype(np.float32)
172
172
  labels = ['Class A', 'Class B', 'Class C', 'Class D', 'Class E']
173
- leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels)
173
+ gt_data = np.array([0.1, 0.2, 0.3, 0.4, 0.5]).astype(np.float32)
174
+ leap_horizontal_bar = LeapHorizontalBar(body=body_data, labels=labels, gt=gt_data)
174
175
  """
175
176
  body: npt.NDArray[np.float32]
176
177
  labels: List[str]
177
- type: LeapDataType = LeapDataType.HorizontalBar
178
178
  gt: Optional[npt.NDArray[np.float32]] = None
179
+ type: LeapDataType = LeapDataType.HorizontalBar
179
180
 
180
181
 
181
182
  def __post_init__(self) -> None:
@@ -10,7 +10,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
10
10
  MetadataSectionCallableInterface, UnlabeledDataPreprocessHandler, CustomLayerHandler, MetricHandler, \
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
12
12
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
- RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData
13
+ RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData, SamplePreprocessResponse
14
14
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
15
15
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
16
16
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
@@ -33,6 +33,7 @@ class LeapBinder:
33
33
  setup_container (DatasetIntegrationSetup): Container to hold setup configurations.
34
34
  cache_container (Dict[str, Any]): Cache container to store intermediate data.
35
35
  """
36
+
36
37
  def __init__(self) -> None:
37
38
  self.setup_container = DatasetIntegrationSetup()
38
39
  self.cache_container: Dict[str, Any] = {"word_to_index": {}}
@@ -238,26 +239,41 @@ class LeapBinder:
238
239
 
239
240
  leap_binder.add_custom_loss(custom_loss_function, name='custom_loss')
240
241
  """
241
- arg_names = inspect.getfullargspec(function)[0]
242
- self.setup_container.custom_loss_handlers.append(CustomLossHandler(CustomLossHandlerData(name, arg_names), function))
242
+
243
+ regular_arg_names = []
244
+ preprocess_response_arg_name = None
245
+ for arg_name, arg_type in inspect.getfullargspec(function).annotations.items():
246
+ if arg_type == SamplePreprocessResponse:
247
+ if preprocess_response_arg_name is not None:
248
+ raise Exception("only one argument can be of type SamplePreprocessResponse")
249
+ preprocess_response_arg_name = arg_name
250
+ else:
251
+ regular_arg_names.append(arg_name)
252
+
253
+ self.setup_container.custom_loss_handlers.append(
254
+ CustomLossHandler(CustomLossHandlerData(name, regular_arg_names), function))
243
255
 
244
256
  def add_custom_metric(self,
245
257
  function: Union[CustomCallableInterfaceMultiArgs,
246
258
  CustomMultipleReturnCallableInterfaceMultiArgs,
247
259
  ConfusionMatrixCallableInterfaceMultiArgs],
248
260
  name: str,
249
- direction: Optional[MetricDirection] = MetricDirection.Downward) -> None:
261
+ direction: Optional[
262
+ Union[MetricDirection, Dict[str, MetricDirection]]] = MetricDirection.Downward) -> None:
250
263
  """
251
264
  Add a custom metric to the setup.
252
265
 
253
266
  Args:
254
267
  function (Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]): The custom metric function.
255
268
  name (str): The name of the custom metric.
256
- direction (Optional[MetricDirection]): The direction of the metric, either MetricDirection.Upward or MetricDirection.Downward.
269
+ direction (Optional[Union[MetricDirection, Dict[str, MetricDirection]]]): The direction of the metric, either
270
+ MetricDirection.Upward or MetricDirection.Downward. If the custom metric returns a dictionary of metrics, a
271
+ corresponding dictionary of directions can be supplied instead.
257
272
  - MetricDirection.Upward: Indicates that higher values of the metric are better and should be maximized.
258
273
  - MetricDirection.Downward: Indicates that lower values of the metric are better and should be minimized.
259
274
 
260
275
 
276
+
261
277
  Example:
262
278
  def custom_metric_function(y_true, y_pred):
263
279
  return np.mean(np.abs(y_true - y_pred))
@@ -377,7 +393,8 @@ class LeapBinder:
377
393
  custom_layer.kernel_index = kernel_index
378
394
 
379
395
  if use_custom_latent_space and not hasattr(custom_layer, custom_latent_space_attribute):
380
- raise Exception(f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
396
+ raise Exception(
397
+ f"{custom_latent_space_attribute} function has not been set for custom layer: {custom_layer.__name__}")
381
398
 
382
399
  init_args = inspect.getfullargspec(custom_layer.__init__)[0][1:]
383
400
  call_args = inspect.getfullargspec(custom_layer.call)[0][1:]
@@ -490,7 +507,3 @@ class LeapBinder:
490
507
 
491
508
  def set_batch_size_to_validate(self, batch_size: int) -> None:
492
509
  self.batch_size_to_validate = batch_size
493
-
494
-
495
-
496
-
@@ -1,6 +1,6 @@
1
1
  # mypy: ignore-errors
2
2
 
3
- from typing import Optional, Union, Callable, List
3
+ from typing import Optional, Union, Callable, List, Dict
4
4
 
5
5
  import numpy as np
6
6
  import numpy.typing as npt
@@ -15,12 +15,11 @@ from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, Le
15
15
  LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
16
16
 
17
17
 
18
- def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = MetricDirection.Downward):
19
- def decorating_function(
20
- user_function: Union[CustomCallableInterfaceMultiArgs,
21
- CustomMultipleReturnCallableInterfaceMultiArgs,
22
- ConfusionMatrixCallableInterfaceMultiArgs]
23
- ):
18
+ def tensorleap_custom_metric(name: str,
19
+ direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward):
20
+ def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
21
+ CustomMultipleReturnCallableInterfaceMultiArgs,
22
+ ConfusionMatrixCallableInterfaceMultiArgs]):
24
23
  for metric_handler in leap_binder.setup_container.metrics:
25
24
  if metric_handler.metric_handler_data.name == name:
26
25
  raise Exception(f'Metric with name {name} already exists. '
@@ -356,15 +355,15 @@ def tensorleap_custom_loss(name: str):
356
355
  f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
357
356
  else:
358
357
  assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
359
- f'Argument #{i} should be a numpy array. Got {type(arg)}.')
358
+ f'Argument #{i} should be a numpy array. Got {type(arg)}.')
360
359
  for _arg_name, arg in kwargs.items():
361
360
  if isinstance(arg, list):
362
361
  for y, elem in enumerate(arg):
363
- assert isinstance(elem,valid_types), (f'tensorleap_custom_loss validation failed: '
364
- f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
362
+ assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
363
+ f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
365
364
  else:
366
365
  assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
367
- f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
366
+ f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
368
367
 
369
368
  def _validate_result(result):
370
369
  assert isinstance(result, valid_types), \
code_loader/leaploader.py CHANGED
@@ -14,7 +14,7 @@ import numpy.typing as npt
14
14
  from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandler, GroundTruthHandler, \
15
15
  PreprocessResponse, VisualizerHandler, LeapData, \
16
16
  PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, VisualizerHandlerData, MetricHandlerData, \
17
- MetricCallableReturnType, CustomLossHandlerData, CustomLossHandler, RawInputsForHeatmap
17
+ MetricCallableReturnType, CustomLossHandlerData, CustomLossHandler, RawInputsForHeatmap, SamplePreprocessResponse
18
18
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
19
19
  from code_loader.contract.exceptions import DatasetScriptException
20
20
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
@@ -233,9 +233,19 @@ class LeapLoader(LeapLoaderBase):
233
233
  self._preprocess_result()
234
234
  return self._metric_handler_by_name()[metric_name].function(**input_tensors_by_arg_name)
235
235
 
236
- def run_custom_loss(self, custom_loss_name: str,
236
+ def run_custom_loss(self, custom_loss_name: str, sample_ids: np.array, state: DataStateEnum,
237
237
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
238
- return self._custom_loss_handler_by_name()[custom_loss_name].function(**input_tensors_by_arg_name)
238
+
239
+ custom_loss_handler = self._custom_loss_handler_by_name()[custom_loss_name]
240
+ preprocess_response_arg_name = None
241
+ for arg_name, arg_type in inspect.getfullargspec(custom_loss_handler.function).annotations.items():
242
+ if arg_type == SamplePreprocessResponse:
243
+ preprocess_response_arg_name = arg_name
244
+ break
245
+
246
+ if preprocess_response_arg_name is not None:
247
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids, self._preprocess_result()[state])
248
+ return custom_loss_handler.function(**input_tensors_by_arg_name)
239
249
 
240
250
  def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
241
251
  # running preprocessing to sync preprocessing in main thread (can be valuable when preprocess is filling a
@@ -320,7 +330,8 @@ class LeapLoader(LeapLoaderBase):
320
330
  for visualizer_handler in setup.visualizers]
321
331
 
322
332
  custom_losses = [CustomLossInstance(custom_loss.custom_loss_handler_data.name,
323
- custom_loss.custom_loss_handler_data.arg_names)
333
+ custom_loss.custom_loss_handler_data.arg_names,
334
+ custom_loss.custom_loss_handler_data.preprocess_response_arg_name)
324
335
  for custom_loss in setup.custom_loss_handlers]
325
336
 
326
337
  prediction_types = []
@@ -20,6 +20,9 @@ class LeapLoaderBase:
20
20
  self.code_entry_name = code_entry_name
21
21
  self.code_path = code_path
22
22
 
23
+ self.current_working_sample_ids: Optional[np.array] = None
24
+ self.current_working_state: Optional[DataStateEnum] = None
25
+
23
26
  @abstractmethod
24
27
  def metric_by_name(self) -> Dict[str, MetricHandlerData]:
25
28
  pass
@@ -58,7 +61,7 @@ class LeapLoaderBase:
58
61
  pass
59
62
 
60
63
  @abstractmethod
61
- def run_custom_loss(self, custom_loss_name: str,
64
+ def run_custom_loss(self, custom_loss_name: str, sample_ids: np.array, state: DataStateEnum,
62
65
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
63
66
  pass
64
67
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.71
3
+ Version: 1.0.72.dev1
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -1,11 +1,11 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=L_fSdSvf-eKoez2uBJ8VjfrKedEP0szNOPvaUvsWeRQ,6973
4
+ code_loader/contract/datasetclasses.py,sha256=2m9wyF_SO_q3kfrrAELS8rseMX-veeaP-Aof0ZhG_7g,7119
5
5
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
6
6
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
7
  code_loader/contract/responsedataclasses.py,sha256=RSx9m_R3LawhK5o1nAcO3hfp2F9oJYtxZr_bpP3bTmw,4005
8
- code_loader/contract/visualizer_classes.py,sha256=WpO5KF4FMGGMgsF1vIs7nbVev_Mr5-VE9XPakDk87zM,14092
8
+ code_loader/contract/visualizer_classes.py,sha256=m31lg2P2QJs3Reqr6-N1AlVhH3RxPr772Jw3LuIVCVM,14177
9
9
  code_loader/default_losses.py,sha256=NoOQym1106bDN5dcIk56Elr7ZG5quUHArqfP5-Nyxyo,1139
10
10
  code_loader/default_metrics.py,sha256=v16Mrt2Ze1tXPgfKywGVdRSrkaK4CKLNQztN1UdVqIY,5010
11
11
  code_loader/experiment_api/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -19,14 +19,14 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
19
19
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
20
20
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
21
21
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
22
- code_loader/inner_leap_binder/leapbinder.py,sha256=o57Pj-iY61-OBuTjK-jYUKCJ0g2pPWbbqitv_e75Bps,25959
23
- code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=S4XHoC4GVicUIhM0UAAsQ6qVmztD52L-Uxa-6OcZptA,20780
24
- code_loader/leaploader.py,sha256=GWlpvgSsCWevP2BwwFBKTImQeDgHAQg1lMU9bqFMwRw,22315
25
- code_loader/leaploaderbase.py,sha256=aHlqWDZRacIdBefeB9goYVnpApaNN2FT24uPIWKkCeQ,3090
22
+ code_loader/inner_leap_binder/leapbinder.py,sha256=una-6k_nnSeqrLNjCy3beXdje_MWXbCRLoCalKdNYbg,26693
23
+ code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=ebMxknpKMW-dE8Erq0fFq4RrE5E_Jfx9IvmRRZSdhlc,20813
24
+ code_loader/leaploader.py,sha256=DCrSS2789dkkePROJSwhAhgzQvR4_27YAR__PxCcLNc,23025
25
+ code_loader/leaploaderbase.py,sha256=FDXjTFBjnpfUyZ2tS3wLo7y2j82Qb_VKHI1BMB5gFww,3269
26
26
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
27
27
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  code_loader/visualizers/default_visualizers.py,sha256=Ffx5VHVOe5ujBOsjBSxN_aIEVwFSQ6gbhTMG5aUS-po,2305
29
- code_loader-1.0.71.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
- code_loader-1.0.71.dist-info/METADATA,sha256=5tErPPtvKSmWqVro-YKhWqz6CbGT4z-8KNVYZdDqtGY,849
31
- code_loader-1.0.71.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
32
- code_loader-1.0.71.dist-info/RECORD,,
29
+ code_loader-1.0.72.dev1.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
+ code_loader-1.0.72.dev1.dist-info/METADATA,sha256=WBFziucTTs-QHppp3QzdoSp_j9KXh1W9njNsHAGxxSo,854
31
+ code_loader-1.0.72.dev1.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
+ code_loader-1.0.72.dev1.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.9.1
2
+ Generator: poetry-core 1.9.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any