code-loader 1.0.72.dev4__py3-none-any.whl → 1.0.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -133,6 +133,7 @@ class MetricHandlerData:
133
133
  name: str
134
134
  arg_names: List[str]
135
135
  direction: Union[None, MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward
136
+ compute_insights: Union[bool, Dict[str, bool]] = True
136
137
 
137
138
 
138
139
  @dataclass
@@ -148,7 +149,7 @@ class RawInputsForHeatmap:
148
149
 
149
150
  @dataclass
150
151
  class SamplePreprocessResponse:
151
- sample_ids: np.array
152
+ sample_ids: Union[npt.NDArray[np.float32], npt.NDArray[np.str_]]
152
153
  preprocess_response: PreprocessResponse
153
154
 
154
155
 
@@ -258,7 +258,8 @@ class LeapBinder:
258
258
  ConfusionMatrixCallableInterfaceMultiArgs],
259
259
  name: str,
260
260
  direction: Optional[
261
- Union[MetricDirection, Dict[str, MetricDirection]]] = MetricDirection.Downward) -> None:
261
+ Union[MetricDirection, Dict[str, MetricDirection]]] = MetricDirection.Downward,
262
+ compute_insights: Union[bool, Dict[str, bool]] = True) -> None:
262
263
  """
263
264
  Add a custom metric to the setup.
264
265
 
@@ -267,9 +268,11 @@ class LeapBinder:
267
268
  name (str): The name of the custom metric.
268
269
  direction (Optional[Union[MetricDirection, Dict[str, MetricDirection]]]): The direction of the metric, either
269
270
  MetricDirection.Upward or MetricDirection.Downward, in case custom metric return a dictionary of metrics we can
270
- supply a dictionary of directions correspondingly
271
+ supply a dictionary of directions correspondingly.
271
272
  - MetricDirection.Upward: Indicates that higher values of the metric are better and should be maximized.
272
273
  - MetricDirection.Downward: Indicates that lower values of the metric are better and should be minimized.
274
+ compute_insights (Union[bool, Dict[str, bool]]): Whether to compute insights or not. in case custom metric
275
+ return a dictionary of metrics we can supply a dictionary of values correspondingly
273
276
 
274
277
 
275
278
 
@@ -280,7 +283,8 @@ class LeapBinder:
280
283
  leap_binder.add_custom_metric(custom_metric_function, name='custom_metric', direction=MetricDirection.Downward)
281
284
  """
282
285
  arg_names = inspect.getfullargspec(function)[0]
283
- self.setup_container.metrics.append(MetricHandler(MetricHandlerData(name, arg_names, direction), function))
286
+ metric_handler_data = MetricHandlerData(name, arg_names, direction, compute_insights)
287
+ self.setup_container.metrics.append(MetricHandler(metric_handler_data, function))
284
288
 
285
289
  def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
286
290
  """
@@ -16,7 +16,8 @@ from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, Le
16
16
 
17
17
 
18
18
  def tensorleap_custom_metric(name: str,
19
- direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward):
19
+ direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward,
20
+ compute_insights: Union[bool, Dict[str, bool]] = True):
20
21
  def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
21
22
  CustomMultipleReturnCallableInterfaceMultiArgs,
22
23
  ConfusionMatrixCallableInterfaceMultiArgs]):
@@ -25,7 +26,7 @@ def tensorleap_custom_metric(name: str,
25
26
  raise Exception(f'Metric with name {name} already exists. '
26
27
  f'Please choose another')
27
28
 
28
- leap_binder.add_custom_metric(user_function, name, direction)
29
+ leap_binder.add_custom_metric(user_function, name, direction, compute_insights)
29
30
 
30
31
  def _validate_input_args(*args, **kwargs) -> None:
31
32
  for i, arg in enumerate(args):
@@ -75,6 +76,19 @@ def tensorleap_custom_metric(name: str,
75
76
  (f'tensorleap_custom_metric validation failed: '
76
77
  f'Keys in the return dict should be of type str. Got {type(key)}.')
77
78
  _validate_single_metric(value)
79
+
80
+ if isinstance(direction, dict):
81
+ for direction_key in direction:
82
+ assert direction_key in result, \
83
+ (f'tensorleap_custom_metric validation failed: '
84
+ f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
85
+
86
+ if isinstance(compute_insights, dict):
87
+ for ci_key in compute_insights:
88
+ assert ci_key in result, \
89
+ (f'tensorleap_custom_metric validation failed: '
90
+ f'Keys in the compute_insights mapping should be part of result keys. Got key {ci_key}.')
91
+
78
92
  else:
79
93
  _validate_single_metric(result)
80
94
 
code_loader/leaploader.py CHANGED
@@ -6,7 +6,7 @@ import sys
6
6
  from contextlib import redirect_stdout
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Dict, List, Iterable, Union, Any, Type, Optional
9
+ from typing import Dict, List, Iterable, Union, Any, Type, Optional, Callable
10
10
 
11
11
  import numpy as np
12
12
  import numpy.typing as npt
@@ -213,7 +213,8 @@ class LeapLoader(LeapLoaderBase):
213
213
  preprocess_response, test_result, dataset_base_handler)
214
214
  except Exception as e:
215
215
  line_number, file_name, stacktrace = get_root_exception_file_and_line_number()
216
- test_result[0].display[state_name] = f"{repr(e)} in file {file_name}, line_number: {line_number}\nStacktrace:\n{stacktrace}"
216
+ test_result[0].display[
217
+ state_name] = f"{repr(e)} in file {file_name}, line_number: {line_number}\nStacktrace:\n{stacktrace}"
217
218
  test_result[0].is_passed = False
218
219
 
219
220
  result_payloads.extend(test_result)
@@ -228,39 +229,69 @@ class LeapLoader(LeapLoaderBase):
228
229
  all_dataset_base_handlers.extend(global_leap_binder.setup_container.metadata)
229
230
  return all_dataset_base_handlers
230
231
 
231
- def run_metric(self, metric_name: str,
232
+ def run_metric(self, metric_name: str, sample_ids: np.array, state: DataStateEnum,
232
233
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
233
234
  self._preprocess_result()
234
- return self._metric_handler_by_name()[metric_name].function(**input_tensors_by_arg_name)
235
+
236
+ metric_handler = self._metric_handler_by_name()[metric_name]
237
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(metric_handler.function)
238
+
239
+ if preprocess_response_arg_name is not None:
240
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(
241
+ sample_ids, self._preprocess_result()[state])
242
+
243
+ return metric_handler.function(**input_tensors_by_arg_name)
244
+
245
+ @staticmethod
246
+ def _get_preprocess_response_arg_name(
247
+ func: Callable) -> Optional[str]:
248
+ for arg_name, arg_type in inspect.getfullargspec(func).annotations.items():
249
+ if arg_type == SamplePreprocessResponse:
250
+ return arg_name
251
+ return None
235
252
 
236
253
  def run_custom_loss(self, custom_loss_name: str, sample_ids: np.array, state: DataStateEnum,
237
254
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
238
255
 
239
256
  custom_loss_handler = self._custom_loss_handler_by_name()[custom_loss_name]
240
- preprocess_response_arg_name = None
241
- for arg_name, arg_type in inspect.getfullargspec(custom_loss_handler.function).annotations.items():
242
- if arg_type == SamplePreprocessResponse:
243
- preprocess_response_arg_name = arg_name
244
- break
257
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(custom_loss_handler.function)
245
258
 
246
259
  if preprocess_response_arg_name is not None:
247
- input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids, self._preprocess_result()[state])
260
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
261
+ self._preprocess_result()[
262
+ state])
248
263
  return custom_loss_handler.function(**input_tensors_by_arg_name)
249
264
 
250
- def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
265
+ def run_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
266
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
251
267
  # running preprocessing to sync preprocessing in main thread (can be valuable when preprocess is filling a
252
268
  # global param that visualizer is using)
253
269
  self._preprocess_result()
254
270
 
255
- return self._visualizer_handler_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
271
+ vis_handler = self._visualizer_handler_by_name()[visualizer_name]
272
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(vis_handler.function)
273
+
274
+ if preprocess_response_arg_name is not None:
275
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
276
+ self._preprocess_result()[
277
+ state])
278
+
279
+ return vis_handler.function(**input_tensors_by_arg_name)
256
280
 
257
- def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
281
+ def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
282
+ sample_ids: np.array, state: DataStateEnum,
258
283
  ) -> Optional[npt.NDArray[np.float32]]:
259
284
  heatmap_function = self._visualizer_handler_by_name()[visualizer_name].heatmap_function
260
285
  if heatmap_function is None:
261
286
  assert len(input_tensors_by_arg_name) == 1
262
287
  return None
263
288
 
289
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(heatmap_function)
290
+ if preprocess_response_arg_name is not None:
291
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
292
+ self._preprocess_result()[
293
+ state])
294
+
264
295
  return heatmap_function(**input_tensors_by_arg_name)
265
296
 
266
297
  def get_heatmap_visualizer_raw_vis_input_arg_name(self, visualizer_name: str) -> Optional[str]:
@@ -309,7 +340,8 @@ class LeapLoader(LeapLoaderBase):
309
340
  if hasattr(handler_test_payload.raw_result, 'tolist'):
310
341
  handler_test_payload.raw_result = handler_test_payload.raw_result.tolist()
311
342
  metadata_type = type(handler_test_payload.raw_result)
312
- if metadata_type == int or isinstance(handler_test_payload.raw_result, (np.unsignedinteger, np.signedinteger)):
343
+ if metadata_type == int or isinstance(handler_test_payload.raw_result,
344
+ (np.unsignedinteger, np.signedinteger)):
313
345
  metadata_type = float
314
346
  if isinstance(handler_test_payload.raw_result, str):
315
347
  dataset_metadata_type = DatasetMetadataType.string
@@ -369,7 +401,8 @@ class LeapLoader(LeapLoaderBase):
369
401
 
370
402
  return self._preprocess_result_cached
371
403
 
372
- def get_preprocess_sample_ids(self, update_unlabeled_preprocess=False) -> Dict[DataStateEnum, Union[List[int], List[str]]]:
404
+ def get_preprocess_sample_ids(self, update_unlabeled_preprocess=False) -> Dict[
405
+ DataStateEnum, Union[List[int], List[str]]]:
373
406
  preprocess_result = self._preprocess_result(update_unlabeled_preprocess)
374
407
  sample_ids = {}
375
408
  for state, preprocess_response in preprocess_result.items():
@@ -427,7 +460,8 @@ class LeapLoader(LeapLoaderBase):
427
460
 
428
461
  return converted_value
429
462
 
430
- def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, Union[str, int, bool, float]]:
463
+ def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[
464
+ str, Union[str, int, bool, float]]:
431
465
  result_agg = {}
432
466
  preprocess_result = self._preprocess_result()
433
467
  preprocess_state = preprocess_result[state]
@@ -436,7 +470,8 @@ class LeapLoader(LeapLoaderBase):
436
470
  if isinstance(handler_result, dict):
437
471
  for single_metadata_name, single_metadata_result in handler_result.items():
438
472
  handler_name = f'{handler.name}_{single_metadata_name}'
439
- result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name, single_metadata_result)
473
+ result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name,
474
+ single_metadata_result)
440
475
  else:
441
476
  handler_name = handler.name
442
477
  result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name, handler_result)
@@ -452,5 +487,3 @@ class LeapLoader(LeapLoaderBase):
452
487
  raise Exception("Different id types in preprocess results")
453
488
 
454
489
  return id_type
455
-
456
-
@@ -69,11 +69,12 @@ class LeapLoaderBase:
69
69
  pass
70
70
 
71
71
  @abstractmethod
72
- def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
72
+ def run_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
73
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
73
74
  pass
74
75
 
75
76
  @abstractmethod
76
- def run_metric(self, metric_name: str,
77
+ def run_metric(self, metric_name: str, sample_ids: np.array, state: DataStateEnum,
77
78
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
78
79
  pass
79
80
 
@@ -83,7 +84,8 @@ class LeapLoaderBase:
83
84
  pass
84
85
 
85
86
  @abstractmethod
86
- def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
87
+ def run_heatmap_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
88
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
87
89
  ) -> Optional[npt.NDArray[np.float32]]:
88
90
  pass
89
91
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.72.dev4
3
+ Version: 1.0.73
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=2m9wyF_SO_q3kfrrAELS8rseMX-veeaP-Aof0ZhG_7g,7119
4
+ code_loader/contract/datasetclasses.py,sha256=hmPqgQcXya-8iF2mV06Nc_mJQdLoks7L22Ew9qH01d8,7221
5
5
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
6
6
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
7
  code_loader/contract/responsedataclasses.py,sha256=RSx9m_R3LawhK5o1nAcO3hfp2F9oJYtxZr_bpP3bTmw,4005
@@ -19,14 +19,14 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
19
19
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
20
20
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
21
21
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
22
- code_loader/inner_leap_binder/leapbinder.py,sha256=-fryKzD8T8K2EgrOsR5NryabP8_1k_m3POLwhYIA_8I,26708
23
- code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=ebMxknpKMW-dE8Erq0fFq4RrE5E_Jfx9IvmRRZSdhlc,20813
24
- code_loader/leaploader.py,sha256=g160Z0MRop_7m3bu1HawcGoZNKakvs9Su4tEfbQ9pR0,22914
25
- code_loader/leaploaderbase.py,sha256=ijTodEBL-Q9DulR9z0xU0fo72rVvm06VvoRyrxzoCtE,4012
22
+ code_loader/inner_leap_binder/leapbinder.py,sha256=0l9zjlF27tZwg6SnpyqVoAAgf9QHQcKpR9lg7vho2Xw,27065
23
+ code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=B-XSw4xYF39kMPnMTRNKMYFg09whnfl7VSbcx195VG8,21626
24
+ code_loader/leaploader.py,sha256=he1c46jQNrCBEBq03gQDS0WVxiX7nlGAE9_9hLagxQc,24973
25
+ code_loader/leaploaderbase.py,sha256=VH0vddRmkqLtcDlYPCO7hfz1_VbKo43lUdHDAbd4iJc,4198
26
26
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
27
27
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  code_loader/visualizers/default_visualizers.py,sha256=Ffx5VHVOe5ujBOsjBSxN_aIEVwFSQ6gbhTMG5aUS-po,2305
29
- code_loader-1.0.72.dev4.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
- code_loader-1.0.72.dev4.dist-info/METADATA,sha256=TQ8fEw3bxJr1ZLchIY_e8EMfRVWreauzQhnK426Ca2k,854
31
- code_loader-1.0.72.dev4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
- code_loader-1.0.72.dev4.dist-info/RECORD,,
29
+ code_loader-1.0.73.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
+ code_loader-1.0.73.dist-info/METADATA,sha256=d7-FSiS0YvfmoWlJZyqlf644v1SuNGf-_MfunY5KYcA,849
31
+ code_loader-1.0.73.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
+ code_loader-1.0.73.dist-info/RECORD,,