code-loader 1.0.72.dev3__py3-none-any.whl → 1.0.72.dev5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/leaploader.py +52 -19
- code_loader/leaploaderbase.py +23 -4
- {code_loader-1.0.72.dev3.dist-info → code_loader-1.0.72.dev5.dist-info}/METADATA +1 -1
- {code_loader-1.0.72.dev3.dist-info → code_loader-1.0.72.dev5.dist-info}/RECORD +6 -6
- {code_loader-1.0.72.dev3.dist-info → code_loader-1.0.72.dev5.dist-info}/LICENSE +0 -0
- {code_loader-1.0.72.dev3.dist-info → code_loader-1.0.72.dev5.dist-info}/WHEEL +0 -0
code_loader/leaploader.py
CHANGED
@@ -6,7 +6,7 @@ import sys
|
|
6
6
|
from contextlib import redirect_stdout
|
7
7
|
from functools import lru_cache
|
8
8
|
from pathlib import Path
|
9
|
-
from typing import Dict, List, Iterable, Union, Any, Type, Optional
|
9
|
+
from typing import Dict, List, Iterable, Union, Any, Type, Optional, Callable
|
10
10
|
|
11
11
|
import numpy as np
|
12
12
|
import numpy.typing as npt
|
@@ -213,7 +213,8 @@ class LeapLoader(LeapLoaderBase):
|
|
213
213
|
preprocess_response, test_result, dataset_base_handler)
|
214
214
|
except Exception as e:
|
215
215
|
line_number, file_name, stacktrace = get_root_exception_file_and_line_number()
|
216
|
-
test_result[0].display[
|
216
|
+
test_result[0].display[
|
217
|
+
state_name] = f"{repr(e)} in file {file_name}, line_number: {line_number}\nStacktrace:\n{stacktrace}"
|
217
218
|
test_result[0].is_passed = False
|
218
219
|
|
219
220
|
result_payloads.extend(test_result)
|
@@ -228,39 +229,69 @@ class LeapLoader(LeapLoaderBase):
|
|
228
229
|
all_dataset_base_handlers.extend(global_leap_binder.setup_container.metadata)
|
229
230
|
return all_dataset_base_handlers
|
230
231
|
|
231
|
-
def run_metric(self, metric_name: str,
|
232
|
+
def run_metric(self, metric_name: str, sample_ids: np.array, state: DataStateEnum,
|
232
233
|
input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
|
233
234
|
self._preprocess_result()
|
234
|
-
|
235
|
+
|
236
|
+
metric_handler = self._metric_handler_by_name()[metric_name]
|
237
|
+
preprocess_response_arg_name = self._get_preprocess_response_arg_name(metric_handler.function)
|
238
|
+
|
239
|
+
if preprocess_response_arg_name is not None:
|
240
|
+
input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(
|
241
|
+
sample_ids, self._preprocess_result()[state])
|
242
|
+
|
243
|
+
return metric_handler.function(**input_tensors_by_arg_name)
|
244
|
+
|
245
|
+
@staticmethod
|
246
|
+
def _get_preprocess_response_arg_name(
|
247
|
+
func: Callable) -> Optional[str]:
|
248
|
+
for arg_name, arg_type in inspect.getfullargspec(func).annotations.items():
|
249
|
+
if arg_type == SamplePreprocessResponse:
|
250
|
+
return arg_name
|
251
|
+
return None
|
235
252
|
|
236
253
|
def run_custom_loss(self, custom_loss_name: str, sample_ids: np.array, state: DataStateEnum,
|
237
254
|
input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
|
238
255
|
|
239
256
|
custom_loss_handler = self._custom_loss_handler_by_name()[custom_loss_name]
|
240
|
-
preprocess_response_arg_name =
|
241
|
-
for arg_name, arg_type in inspect.getfullargspec(custom_loss_handler.function).annotations.items():
|
242
|
-
if arg_type == SamplePreprocessResponse:
|
243
|
-
preprocess_response_arg_name = arg_name
|
244
|
-
break
|
257
|
+
preprocess_response_arg_name = self._get_preprocess_response_arg_name(custom_loss_handler.function)
|
245
258
|
|
246
259
|
if preprocess_response_arg_name is not None:
|
247
|
-
input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
|
260
|
+
input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
|
261
|
+
self._preprocess_result()[
|
262
|
+
state])
|
248
263
|
return custom_loss_handler.function(**input_tensors_by_arg_name)
|
249
264
|
|
250
|
-
def run_visualizer(self, visualizer_name: str,
|
265
|
+
def run_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
|
266
|
+
input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
|
251
267
|
# running preprocessing to sync preprocessing in main thread (can be valuable when preprocess is filling a
|
252
268
|
# global param that visualizer is using)
|
253
269
|
self._preprocess_result()
|
254
270
|
|
255
|
-
|
271
|
+
vis_handler = self._visualizer_handler_by_name()[visualizer_name]
|
272
|
+
preprocess_response_arg_name = self._get_preprocess_response_arg_name(vis_handler.function)
|
273
|
+
|
274
|
+
if preprocess_response_arg_name is not None:
|
275
|
+
input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
|
276
|
+
self._preprocess_result()[
|
277
|
+
state])
|
278
|
+
|
279
|
+
return vis_handler.function(**input_tensors_by_arg_name)
|
256
280
|
|
257
|
-
def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
|
281
|
+
def run_heatmap_visualizer(self, visualizer_name: str, input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]],
|
282
|
+
sample_ids: np.array, state: DataStateEnum,
|
258
283
|
) -> Optional[npt.NDArray[np.float32]]:
|
259
284
|
heatmap_function = self._visualizer_handler_by_name()[visualizer_name].heatmap_function
|
260
285
|
if heatmap_function is None:
|
261
286
|
assert len(input_tensors_by_arg_name) == 1
|
262
287
|
return None
|
263
288
|
|
289
|
+
preprocess_response_arg_name = self._get_preprocess_response_arg_name(heatmap_function)
|
290
|
+
if preprocess_response_arg_name is not None:
|
291
|
+
input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(sample_ids,
|
292
|
+
self._preprocess_result()[
|
293
|
+
state])
|
294
|
+
|
264
295
|
return heatmap_function(**input_tensors_by_arg_name)
|
265
296
|
|
266
297
|
def get_heatmap_visualizer_raw_vis_input_arg_name(self, visualizer_name: str) -> Optional[str]:
|
@@ -309,7 +340,8 @@ class LeapLoader(LeapLoaderBase):
|
|
309
340
|
if hasattr(handler_test_payload.raw_result, 'tolist'):
|
310
341
|
handler_test_payload.raw_result = handler_test_payload.raw_result.tolist()
|
311
342
|
metadata_type = type(handler_test_payload.raw_result)
|
312
|
-
if metadata_type == int or isinstance(handler_test_payload.raw_result,
|
343
|
+
if metadata_type == int or isinstance(handler_test_payload.raw_result,
|
344
|
+
(np.unsignedinteger, np.signedinteger)):
|
313
345
|
metadata_type = float
|
314
346
|
if isinstance(handler_test_payload.raw_result, str):
|
315
347
|
dataset_metadata_type = DatasetMetadataType.string
|
@@ -369,7 +401,8 @@ class LeapLoader(LeapLoaderBase):
|
|
369
401
|
|
370
402
|
return self._preprocess_result_cached
|
371
403
|
|
372
|
-
def get_preprocess_sample_ids(self, update_unlabeled_preprocess=False) -> Dict[
|
404
|
+
def get_preprocess_sample_ids(self, update_unlabeled_preprocess=False) -> Dict[
|
405
|
+
DataStateEnum, Union[List[int], List[str]]]:
|
373
406
|
preprocess_result = self._preprocess_result(update_unlabeled_preprocess)
|
374
407
|
sample_ids = {}
|
375
408
|
for state, preprocess_response in preprocess_result.items():
|
@@ -427,7 +460,8 @@ class LeapLoader(LeapLoaderBase):
|
|
427
460
|
|
428
461
|
return converted_value
|
429
462
|
|
430
|
-
def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[
|
463
|
+
def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[
|
464
|
+
str, Union[str, int, bool, float]]:
|
431
465
|
result_agg = {}
|
432
466
|
preprocess_result = self._preprocess_result()
|
433
467
|
preprocess_state = preprocess_result[state]
|
@@ -436,7 +470,8 @@ class LeapLoader(LeapLoaderBase):
|
|
436
470
|
if isinstance(handler_result, dict):
|
437
471
|
for single_metadata_name, single_metadata_result in handler_result.items():
|
438
472
|
handler_name = f'{handler.name}_{single_metadata_name}'
|
439
|
-
result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name,
|
473
|
+
result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name,
|
474
|
+
single_metadata_result)
|
440
475
|
else:
|
441
476
|
handler_name = handler.name
|
442
477
|
result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name, handler_result)
|
@@ -452,5 +487,3 @@ class LeapLoader(LeapLoaderBase):
|
|
452
487
|
raise Exception("Different id types in preprocess results")
|
453
488
|
|
454
489
|
return id_type
|
455
|
-
|
456
|
-
|
code_loader/leaploaderbase.py
CHANGED
@@ -10,7 +10,7 @@ import numpy.typing as npt
|
|
10
10
|
from code_loader.contract.datasetclasses import DatasetSample, LeapData, \
|
11
11
|
PredictionTypeHandler, CustomLayerHandler, VisualizerHandlerData, MetricHandlerData, MetricCallableReturnType, \
|
12
12
|
CustomLossHandlerData
|
13
|
-
from code_loader.contract.enums import DataStateEnum
|
13
|
+
from code_loader.contract.enums import DataStateEnum, DataStateType
|
14
14
|
from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
|
15
15
|
DatasetSetup, ModelSetup
|
16
16
|
|
@@ -23,6 +23,23 @@ class LeapLoaderBase:
|
|
23
23
|
self.current_working_sample_ids: Optional[np.array] = None
|
24
24
|
self.current_working_state: Optional[DataStateEnum] = None
|
25
25
|
|
26
|
+
def set_current_working_sample_ids(self, sample_ids: np.array):
|
27
|
+
if type(sample_ids[0]) is bytes:
|
28
|
+
sample_ids = np.array([sample_id.decode('utf-8') for sample_id in sample_ids])
|
29
|
+
self.current_working_sample_ids = sample_ids
|
30
|
+
|
31
|
+
def set_current_working_state(self, state: Union[DataStateEnum, DataStateType, str, int, bytes]):
|
32
|
+
if type(state) is bytes:
|
33
|
+
state = DataStateEnum[state.decode('utf-8')]
|
34
|
+
elif type(state) is str:
|
35
|
+
state = DataStateEnum[state]
|
36
|
+
elif type(state) is int:
|
37
|
+
state = DataStateEnum(state)
|
38
|
+
elif type(state) is DataStateType:
|
39
|
+
state = DataStateEnum[state.name]
|
40
|
+
|
41
|
+
self.current_working_state = state
|
42
|
+
|
26
43
|
@abstractmethod
|
27
44
|
def metric_by_name(self) -> Dict[str, MetricHandlerData]:
|
28
45
|
pass
|
@@ -52,11 +69,12 @@ class LeapLoaderBase:
|
|
52
69
|
pass
|
53
70
|
|
54
71
|
@abstractmethod
|
55
|
-
def run_visualizer(self, visualizer_name: str,
|
72
|
+
def run_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
|
73
|
+
input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> LeapData:
|
56
74
|
pass
|
57
75
|
|
58
76
|
@abstractmethod
|
59
|
-
def run_metric(self, metric_name: str,
|
77
|
+
def run_metric(self, metric_name: str, sample_ids: np.array, state: DataStateEnum,
|
60
78
|
input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
|
61
79
|
pass
|
62
80
|
|
@@ -66,7 +84,8 @@ class LeapLoaderBase:
|
|
66
84
|
pass
|
67
85
|
|
68
86
|
@abstractmethod
|
69
|
-
def run_heatmap_visualizer(self, visualizer_name: str,
|
87
|
+
def run_heatmap_visualizer(self, visualizer_name: str, sample_ids: np.array, state: DataStateEnum,
|
88
|
+
input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
|
70
89
|
) -> Optional[npt.NDArray[np.float32]]:
|
71
90
|
pass
|
72
91
|
|
@@ -21,12 +21,12 @@ code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaS
|
|
21
21
|
code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
|
22
22
|
code_loader/inner_leap_binder/leapbinder.py,sha256=-fryKzD8T8K2EgrOsR5NryabP8_1k_m3POLwhYIA_8I,26708
|
23
23
|
code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=ebMxknpKMW-dE8Erq0fFq4RrE5E_Jfx9IvmRRZSdhlc,20813
|
24
|
-
code_loader/leaploader.py,sha256=
|
25
|
-
code_loader/leaploaderbase.py,sha256=
|
24
|
+
code_loader/leaploader.py,sha256=he1c46jQNrCBEBq03gQDS0WVxiX7nlGAE9_9hLagxQc,24973
|
25
|
+
code_loader/leaploaderbase.py,sha256=VH0vddRmkqLtcDlYPCO7hfz1_VbKo43lUdHDAbd4iJc,4198
|
26
26
|
code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
|
27
27
|
code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
28
|
code_loader/visualizers/default_visualizers.py,sha256=Ffx5VHVOe5ujBOsjBSxN_aIEVwFSQ6gbhTMG5aUS-po,2305
|
29
|
-
code_loader-1.0.72.
|
30
|
-
code_loader-1.0.72.
|
31
|
-
code_loader-1.0.72.
|
32
|
-
code_loader-1.0.72.
|
29
|
+
code_loader-1.0.72.dev5.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
|
30
|
+
code_loader-1.0.72.dev5.dist-info/METADATA,sha256=IuNmOiTSxEC2A_-cCOc06mXUHiZPaf-Wg71eT28YGJE,854
|
31
|
+
code_loader-1.0.72.dev5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
32
|
+
code_loader-1.0.72.dev5.dist-info/RECORD,,
|
File without changes
|
File without changes
|