code-loader 1.0.72.dev4__py3-none-any.whl → 1.0.72.dev5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
code_loader/leaploader.py CHANGED
@@ -6,7 +6,7 @@ import sys
6
6
  from contextlib import redirect_stdout
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Dict, List, Iterable, Union, Any, Type, Optional
9
+ from typing import Dict, List, Iterable, Union, Any, Type, Optional, Callable
10
10
 
11
11
  import numpy as np
12
12
  import numpy.typing as npt
@@ -213,7 +213,8 @@ class LeapLoader(LeapLoaderBase):
213
213
  preprocess_response, test_result, dataset_base_handler)
214
214
  except Exception as e:
215
215
  line_number, file_name, stacktrace = get_root_exception_file_and_line_number()
216
- test_result[0].display[state_name] = f"{repr(e)} in file {file_name}, line_number: {line_number}\nStacktrace:\n{stacktrace}"
216
+ test_result[0].display[
217
+ state_name] = f"{repr(e)} in file {file_name}, line_number: {line_number}\nStacktrace:\n{stacktrace}"
217
218
  test_result[0].is_passed = False
218
219
 
219
220
  result_payloads.extend(test_result)
@@ -228,39 +229,69 @@ class LeapLoader(LeapLoaderBase):
228
229
  all_dataset_base_handlers.extend(global_leap_binder.setup_container.metadata)
229
230
  return all_dataset_base_handlers
230
231
 
231
- def run_metric(self, metric_name: str,
232
+ def run_metric(self, metric_name: str, sample_ids: np.array, state: DataStateEnum,
232
233
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
233
234
  self._preprocess_result()
234
- return self._metric_handler_by_name()[metric_name].function(**input_tensors_by_arg_name)
235
+
236
+ metric_handler = self._metric_handler_by_name()[metric_name]
237
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(metric_handler.function)
238
+
239
+ if preprocess_response_arg_name is not None:
240
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(
241
+ sample_ids, self._preprocess_result()[state])
242
+
243
+ return metric_handler.function(**input_tensors_by_arg_name)
244
+
245
+ @staticmethod
246
+ def _get_preprocess_response_arg_name(
247
+ func: Callable) -> Optional[str]:
248
+ for arg_name, arg_type in inspect.getfullargspec(func).annotations.items():
249
+ if arg_type == SamplePreprocessResponse:
250
+ return arg_name
251
+ return None
235
252
 
236
253
  def run_custom_loss(self, custom_loss_name: str, sample_ids: np.array, state: DataStateEnum,
237
254
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
238
255
 
239
256
  custom_loss_handler = self._custom_loss_handler_by_name()[custom_loss_name]
240
- preprocess_response_arg_name = None
241
- for arg_name, arg_type in inspect.getfullargspec(custom_loss_handler.function).annotations.items():
242
- if arg_type == SamplePreprocessResponse:
243
- preprocess_response_arg_name = arg_name
244
- break
257
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(custom_loss_handler.function)
245
258
 
246
259
  if preprocess_response_arg_name is not None:
247
- input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids, self._preprocess_result()[state])
260
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
261
+ self._preprocess_result()[
262
+ state])
248
263
  return custom_loss_handler.function(**input_tensors_by_arg_name)
249
264
 
250
- def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
265
+ def run_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
266
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
251
267
  # running preprocessing to sync preprocessing in main thread (can be valuable when preprocess is filling a
252
268
  # global param that visualizer is using)
253
269
  self._preprocess_result()
254
270
 
255
- return self._visualizer_handler_by_name()[visualizer_name].function(**input_tensors_by_arg_name)
271
+ vis_handler = self._visualizer_handler_by_name()[visualizer_name]
272
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(vis_handler.function)
273
+
274
+ if preprocess_response_arg_name is not None:
275
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
276
+ self._preprocess_result()[
277
+ state])
278
+
279
+ return vis_handler.function(**input_tensors_by_arg_name)
256
280
 
257
- def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
281
+ def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
282
+ sample_ids: np.array, state: DataStateEnum,
258
283
  ) -> Optional[npt.NDArray[np.float32]]:
259
284
  heatmap_function = self._visualizer_handler_by_name()[visualizer_name].heatmap_function
260
285
  if heatmap_function is None:
261
286
  assert len(input_tensors_by_arg_name) == 1
262
287
  return None
263
288
 
289
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(heatmap_function)
290
+ if preprocess_response_arg_name is not None:
291
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
292
+ self._preprocess_result()[
293
+ state])
294
+
264
295
  return heatmap_function(**input_tensors_by_arg_name)
265
296
 
266
297
  def get_heatmap_visualizer_raw_vis_input_arg_name(self, visualizer_name: str) -> Optional[str]:
@@ -309,7 +340,8 @@ class LeapLoader(LeapLoaderBase):
309
340
  if hasattr(handler_test_payload.raw_result, 'tolist'):
310
341
  handler_test_payload.raw_result = handler_test_payload.raw_result.tolist()
311
342
  metadata_type = type(handler_test_payload.raw_result)
312
- if metadata_type == int or isinstance(handler_test_payload.raw_result, (np.unsignedinteger, np.signedinteger)):
343
+ if metadata_type == int or isinstance(handler_test_payload.raw_result,
344
+ (np.unsignedinteger, np.signedinteger)):
313
345
  metadata_type = float
314
346
  if isinstance(handler_test_payload.raw_result, str):
315
347
  dataset_metadata_type = DatasetMetadataType.string
@@ -369,7 +401,8 @@ class LeapLoader(LeapLoaderBase):
369
401
 
370
402
  return self._preprocess_result_cached
371
403
 
372
- def get_preprocess_sample_ids(self, update_unlabeled_preprocess=False) -> Dict[DataStateEnum, Union[List[int], List[str]]]:
404
+ def get_preprocess_sample_ids(self, update_unlabeled_preprocess=False) -> Dict[
405
+ DataStateEnum, Union[List[int], List[str]]]:
373
406
  preprocess_result = self._preprocess_result(update_unlabeled_preprocess)
374
407
  sample_ids = {}
375
408
  for state, preprocess_response in preprocess_result.items():
@@ -427,7 +460,8 @@ class LeapLoader(LeapLoaderBase):
427
460
 
428
461
  return converted_value
429
462
 
430
- def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[str, Union[str, int, bool, float]]:
463
+ def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[
464
+ str, Union[str, int, bool, float]]:
431
465
  result_agg = {}
432
466
  preprocess_result = self._preprocess_result()
433
467
  preprocess_state = preprocess_result[state]
@@ -436,7 +470,8 @@ class LeapLoader(LeapLoaderBase):
436
470
  if isinstance(handler_result, dict):
437
471
  for single_metadata_name, single_metadata_result in handler_result.items():
438
472
  handler_name = f'{handler.name}_{single_metadata_name}'
439
- result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name, single_metadata_result)
473
+ result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name,
474
+ single_metadata_result)
440
475
  else:
441
476
  handler_name = handler.name
442
477
  result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name, handler_result)
@@ -452,5 +487,3 @@ class LeapLoader(LeapLoaderBase):
452
487
  raise Exception("Different id types in preprocess results")
453
488
 
454
489
  return id_type
455
-
456
-
@@ -69,11 +69,12 @@ class LeapLoaderBase:
69
69
  pass
70
70
 
71
71
  @abstractmethod
72
- def run_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
72
+ def run_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
73
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
73
74
  pass
74
75
 
75
76
  @abstractmethod
76
- def run_metric(self, metric_name: str,
77
+ def run_metric(self, metric_name: str, sample_ids: np.array, state: DataStateEnum,
77
78
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
78
79
  pass
79
80
 
@@ -83,7 +84,8 @@ class LeapLoaderBase:
83
84
  pass
84
85
 
85
86
  @abstractmethod
86
- def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
87
+ def run_heatmap_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
88
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
87
89
  ) -> Optional[npt.NDArray[np.float32]]:
88
90
  pass
89
91
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.72.dev4
3
+ Version: 1.0.72.dev5
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -21,12 +21,12 @@ code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaS
21
21
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
22
22
  code_loader/inner_leap_binder/leapbinder.py,sha256=-fryKzD8T8K2EgrOsR5NryabP8_1k_m3POLwhYIA_8I,26708
23
23
  code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=ebMxknpKMW-dE8Erq0fFq4RrE5E_Jfx9IvmRRZSdhlc,20813
24
- code_loader/leaploader.py,sha256=g160Z0MRop_7m3bu1HawcGoZNKakvs9Su4tEfbQ9pR0,22914
25
- code_loader/leaploaderbase.py,sha256=ijTodEBL-Q9DulR9z0xU0fo72rVvm06VvoRyrxzoCtE,4012
24
+ code_loader/leaploader.py,sha256=he1c46jQNrCBEBq03gQDS0WVxiX7nlGAE9_9hLagxQc,24973
25
+ code_loader/leaploaderbase.py,sha256=VH0vddRmkqLtcDlYPCO7hfz1_VbKo43lUdHDAbd4iJc,4198
26
26
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
27
27
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  code_loader/visualizers/default_visualizers.py,sha256=Ffx5VHVOe5ujBOsjBSxN_aIEVwFSQ6gbhTMG5aUS-po,2305
29
- code_loader-1.0.72.dev4.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
- code_loader-1.0.72.dev4.dist-info/METADATA,sha256=TQ8fEw3bxJr1ZLchIY_e8EMfRVWreauzQhnK426Ca2k,854
31
- code_loader-1.0.72.dev4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
- code_loader-1.0.72.dev4.dist-info/RECORD,,
29
+ code_loader-1.0.72.dev5.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
+ code_loader-1.0.72.dev5.dist-info/METADATA,sha256=IuNmOiTSxEC2A_-cCOc06mXUHiZPaf-Wg71eT28YGJE,854
31
+ code_loader-1.0.72.dev5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
+ code_loader-1.0.72.dev5.dist-info/RECORD,,