code-loader 1.0.112.dev6__py3-none-any.whl → 1.0.153.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/__init__.py +6 -0
- code_loader/contract/datasetclasses.py +27 -4
- code_loader/contract/mapping.py +11 -0
- code_loader/inner_leap_binder/leapbinder.py +14 -1
- code_loader/inner_leap_binder/leapbinder_decorators.py +993 -142
- code_loader/leaploader.py +53 -25
- code_loader/leaploaderbase.py +30 -5
- code_loader/mixpanel_tracker.py +230 -0
- code_loader/plot_functions/plot_functions.py +1 -2
- code_loader/utils.py +1 -1
- {code_loader-1.0.112.dev6.dist-info → code_loader-1.0.153.dev4.dist-info}/METADATA +4 -2
- {code_loader-1.0.112.dev6.dist-info → code_loader-1.0.153.dev4.dist-info}/RECORD +14 -13
- {code_loader-1.0.112.dev6.dist-info → code_loader-1.0.153.dev4.dist-info}/LICENSE +0 -0
- {code_loader-1.0.112.dev6.dist-info → code_loader-1.0.153.dev4.dist-info}/WHEEL +0 -0
|
@@ -1,26 +1,201 @@
|
|
|
1
1
|
# mypy: ignore-errors
|
|
2
2
|
import os
|
|
3
|
-
|
|
3
|
+
import warnings
|
|
4
|
+
import logging
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from functools import lru_cache
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional, Union, Callable, List, Dict, Set, Any
|
|
9
|
+
from typing import Optional, Union, Callable, List, Dict, get_args, get_origin, DefaultDict
|
|
4
10
|
|
|
5
11
|
import numpy as np
|
|
6
12
|
import numpy.typing as npt
|
|
7
13
|
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
8
16
|
from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
|
|
9
17
|
CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
|
|
10
18
|
VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
|
|
11
19
|
ConfusionMatrixElement, SamplePreprocessResponse, PredictionTypeHandler, InstanceCallableInterface, ElementInstance, \
|
|
12
20
|
InstanceLengthCallableInterface
|
|
13
|
-
from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType
|
|
21
|
+
from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType, DataStateType
|
|
14
22
|
from code_loader import leap_binder
|
|
15
23
|
from code_loader.contract.mapping import NodeMapping, NodeMappingType, NodeConnection
|
|
16
24
|
from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
|
|
17
25
|
LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
|
|
18
26
|
from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
|
|
27
|
+
from code_loader.mixpanel_tracker import clear_integration_events, AnalyticsEvent, emit_integration_event_once
|
|
19
28
|
|
|
20
29
|
import inspect
|
|
21
30
|
import functools
|
|
31
|
+
from pathlib import Path
|
|
22
32
|
|
|
23
33
|
# Integer flag/counter read by other decorators in this module (its exact
# use is not visible in this section).
_called_from_inside_tl_decorator = 0
# Set to True by the @tensorleap_integration_test wrapper while the user's
# integration test runs, and reset to False in its ``finally`` clause.
_called_from_inside_tl_integration_test_decorator = False
# Evaluated once at import time: True iff IS_TENSORLEAP_PLATFORM == 'true'.
_call_from_tl_platform = os.environ.get('IS_TENSORLEAP_PLATFORM') == 'true'

# ---- warnings store (module-level) ----
# Sentinel meaning "argument not passed", distinguishable from None.
_UNSET = object()
_STORED_WARNINGS: List[Dict[str, Any]] = []
_STORED_WARNING_KEYS: Set[tuple] = set()
# Parameter name -> names of the user functions that fell back to its default.
_PARAM_DEFAULT_FUNCS: DefaultDict[str, Set[str]] = defaultdict(set)
# Parameter name -> the default value that was applied (first one recorded).
_PARAM_DEFAULT_VALUE: Dict[str, Any] = {}
# Parameter name -> documentation link for that parameter (first one recorded).
_PARAM_DEFAULT_DOCS: Dict[str, str] = {}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_entry_script_path() -> str:
    """Return the absolute path of the script that started this process.

    ``sys.argv[0]`` is preferred; when it is missing or empty the function
    falls back to ``__main__.__file__``. If neither is available, ``""`` is
    returned.
    """
    import sys

    script = sys.argv[0] if sys.argv else ""
    if not script:
        import __main__
        script = getattr(__main__, "__file__", "") or ""
        if not script:
            return ""
    return str(Path(script).resolve())


def started_from(filename: str) -> bool:
    """Return True when the entry script's base name equals ``filename``."""
    entry_path = get_entry_script_path()
    if not entry_path:
        return False
    return Path(entry_path).name == filename
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def store_warning_by_param(
        *,
        param_name: str,
        user_func_name: str,
        default_value: Any,
        link_to_docs: Optional[str] = None,
) -> None:
    """Record that ``user_func_name`` relied on the default for ``param_name``.

    Every call adds ``user_func_name`` to the set kept for ``param_name``.
    The default value and the docs link are first-write-wins: once stored for
    a parameter, later calls do not overwrite them.

    Args:
        param_name: Name of the parameter whose default was applied.
        user_func_name: Name of the user function that omitted the parameter.
        default_value: The default value that was substituted.
        link_to_docs: Optional documentation URL; falsy values are ignored.
    """
    _PARAM_DEFAULT_FUNCS[param_name].add(user_func_name)
    _PARAM_DEFAULT_VALUE.setdefault(param_name, default_value)
    if link_to_docs:
        _PARAM_DEFAULT_DOCS.setdefault(param_name, link_to_docs)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_param_default_warnings() -> Dict[str, Dict[str, Any]]:
    """Snapshot every parameter that fell back to a default value.

    Returns:
        A dict keyed by parameter name. Each value holds ``default_value``
        (or None), a fresh copy of the ``funcs`` set of user-function names,
        and ``link_to_docs`` (or None).
    """
    return {
        param: {
            "default_value": _PARAM_DEFAULT_VALUE.get(param, None),
            "funcs": set(func_names),
            "link_to_docs": _PARAM_DEFAULT_DOCS.get(param),
        }
        for param, func_names in _PARAM_DEFAULT_FUNCS.items()
    }
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _get_stored_warnings() -> List[Dict[str, Any]]:
    """Return a shallow copy of the module-level stored-warnings list.

    The list itself is new, so callers may append/remove freely; the warning
    dicts inside it are shared with the store, not copied.
    """
    return _STORED_WARNINGS.copy()
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def validate_args_structure(*args, types_order, func_name, expected_names, **kwargs):
    """Validate the arguments a decorated user function was called with.

    Arguments are first normalized into one positional list.

    When ``expected_names`` is non-empty, the i-th expected argument is taken
    from ``args[i]`` if present, otherwise from ``kwargs[expected_names[i]]``.
    Positional arguments beyond ``expected_names``, and keyword arguments
    whose name is not in ``expected_names``, are appended as extras so the
    exact-count check rejects them instead of silently ignoring them.

    Without ``expected_names``, only ``args`` is considered.

    Args:
        *args: Positional arguments passed to the user function.
        types_order: Expected type of each argument, in order. An entry may be
            a ``typing.Union`` of plain classes.
        func_name: Name used in error messages.
        expected_names: Expected argument names, parallel to ``types_order``;
            may be empty or None.
        **kwargs: Keyword arguments passed to the user function.

    Raises:
        AssertionError: if a named argument is missing, the argument count
            differs from ``len(types_order)``, or an argument is not an
            instance of its expected type.
    """

    def _type_name(t):
        # Some typing constructs (e.g. List[int] on older Pythons) have no __name__.
        return getattr(t, "__name__", str(t))

    def _type_to_str(t):
        if get_origin(t) is Union:
            return " | ".join(_type_name(tt) for tt in get_args(t))
        return _type_name(t)

    def _format_types(types, names=None):
        parts = []
        for i, ty in enumerate(types):
            label = names[i] if names else f"arg{i}"
            parts.append(f"{label}: {_type_to_str(ty)}")
        return ", ".join(parts)

    if expected_names:
        normalized_args = []
        for i, name in enumerate(expected_names):
            if i < len(args):
                normalized_args.append(args[i])
            elif name in kwargs:
                normalized_args.append(kwargs[name])
            else:
                raise AssertionError(
                    f"{func_name} validation failed: "
                    f"Missing required argument '{name}'. "
                    f"Expected arguments: {expected_names}."
                )
        # Count surplus arguments instead of dropping them; otherwise the
        # exact-count check below could never fire when names are given.
        normalized_args.extend(args[len(expected_names):])
        normalized_args.extend(value for key, value in kwargs.items() if key not in expected_names)
    else:
        normalized_args = list(args)

    if len(normalized_args) != len(types_order):
        expected = _format_types(types_order, expected_names)
        got_types = ", ".join(type(arg).__name__ for arg in normalized_args)
        raise AssertionError(
            f"{func_name} validation failed: "
            f"Expected exactly {len(types_order)} arguments ({expected}), "
            f"but got {len(normalized_args)} argument(s) of type(s): ({got_types}). "
            f"Correct usage example: {func_name}({expected})"
        )

    for i, (arg, expected_type) in enumerate(zip(normalized_args, types_order)):
        if get_origin(expected_type) is Union:
            allowed_types = get_args(expected_type)
        else:
            allowed_types = (expected_type,)

        if not isinstance(arg, allowed_types):
            allowed_str = " | ".join(_type_name(t) for t in allowed_types)
            arg_label = expected_names[i] if expected_names else f"arg{i}"
            raise AssertionError(
                f"{func_name} validation failed: "
                f"Argument '{arg_label}' "
                f"expected type {allowed_str}, but got {type(arg).__name__}. "
                f"Correct usage example: {func_name}({_format_types(types_order, expected_names)})"
            )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def validate_output_structure(result, func_name: str, expected_type_name="np.ndarray", gt_flag=False):
    """Reject outputs that are clearly not a single return value.

    Args:
        result: The value returned by the user function.
        func_name: Name of the user function, used in error messages.
        expected_type_name: Human-readable name of the expected return type;
            only used in error messages.
        gt_flag: True when validating a ground-truth function. On a missing
            result, the message then also suggests returning an empty array
            for unlabeled datasets.

    Raises:
        AssertionError: if ``result`` is None, a Python ``float`` NaN, or a
            tuple (i.e. the function returned several values).
    """
    returned_nothing = result is None or (isinstance(result, float) and np.isnan(result))
    if returned_nothing:
        if gt_flag:
            raise AssertionError(
                f"{func_name} validation failed: "
                f"The function returned {result!r}. "
                f"If you are working with an unlabeled dataset and no ground truth is available, "
                f"use 'return np.array([], dtype=np.float32)' instead. "
                f"Otherwise, {func_name} expected a single {expected_type_name} object. "
                f"Make sure the function ends with 'return <{expected_type_name}>'."
            )

        # Report the actual value: a NaN result used to be described as "None".
        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned {result!r}. "
            f"Expected a single {expected_type_name} object. "
            f"Make sure the function ends with 'return <{expected_type_name}>'."
        )

    if isinstance(result, tuple):
        element_descriptions = [
            f"[{i}] type: {type(r).__name__}"
            for i, r in enumerate(result)
        ]
        element_summary = "\n ".join(element_descriptions)

        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned multiple outputs ({len(result)} values), "
            f"but only a single {expected_type_name} is allowed.\n\n"
            f"Returned elements:\n"
            f" {element_summary}\n\n"
            f"Correct usage example:\n"
            f" def {func_name}(...):\n"
            f" return <{expected_type_name}>\n\n"
            f"If you intended to return multiple values, combine them into a single "
            f"{expected_type_name} (e.g., by concatenation or stacking)."
        )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def batch_warning(result, func_name):
    """Warn when an output already looks batched (leading axis of size 1).

    Tensorleap prepends a batch axis to the output. If axis 0 is already 1,
    the user probably batched it themselves, so a ``UserWarning`` is emitted.
    Scalars (empty shape) and outputs whose axis 0 is not 1 are ignored.
    """
    shape = result.shape
    if not shape or shape[0] != 1:
        return
    message = (
        f"{func_name} warning: Tensorleap will add a batch dimension at axis 0 to the output of {func_name}, "
        f"although the detected size of axis 0 is already 1. "
        f"This may lead to an extra batch dimension (e.g., shape (1, 1, ...)). "
        f"Please ensure that the output of '{func_name}' is not already batched "
        f"to avoid computation errors."
    )
    warnings.warn(message)
|
|
24
199
|
|
|
25
200
|
|
|
26
201
|
def _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type):
|
|
@@ -41,51 +216,234 @@ def _add_mapping_connections(connects_to, arg_names, node_mapping_type, name):
|
|
|
41
216
|
_add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type)
|
|
42
217
|
|
|
43
218
|
|
|
44
|
-
def
|
|
219
|
+
def tensorleap_integration_test():
    """Decorator factory that registers and wraps the user's integration test.

    The returned decorator stores the user function on
    ``leap_binder.integration_test_func``. Each call of the wrapper:

    1. Outside the Tensorleap platform, validates the ``(idx, preprocess)``
       arguments, passed positionally or by those keyword names.
    2. Runs the user test normally.
    3. Re-runs it with ``mapping_runtime_mode_env_var_mame`` set, passing
       ``None`` and an empty training ``PreprocessResponse``; the env var is
       always removed afterwards.
    4. Calls ``leap_binder.check()`` once both runs succeed, and returns
       the first run's result.
    """

    def decorating_function(integration_test_function: Callable):
        leap_binder.integration_test_func = integration_test_function

        def _validate_input_args(*args, **kwargs):
            # Resolve by position first, then by the keyword names accepted by
            # validate_args_structure, so keyword-style calls also work.
            sample_id = args[0] if len(args) > 0 else kwargs["idx"]
            preprocess_response = args[1] if len(args) > 1 else kwargs["preprocess"]
            assert type(sample_id) == preprocess_response.sample_id_type, (
                f"tensorleap_integration_test validation failed: "
                f"sample_id type ({type(sample_id).__name__}) does not match the expected "
                f"type ({preprocess_response.sample_id_type}) from the PreprocessResponse."
            )

        def inner(*args, **kwargs):
            global _called_from_inside_tl_integration_test_decorator

            if not _call_from_tl_platform:
                set_current('tensorleap_integration_test')
                validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
                                        func_name='integration_test', expected_names=["idx", "preprocess"], **kwargs)
                _validate_input_args(*args, **kwargs)

            # Clear integration test events for new test (best-effort).
            try:
                clear_integration_events()
            except Exception as e:
                logger.debug(f"Failed to clear integration events: {e}")

            try:
                _called_from_inside_tl_integration_test_decorator = True
                if not _call_from_tl_platform:
                    # Put here because otherwise it will become "v" only if the
                    # whole user script finishes.
                    update_env_params_func("tensorleap_integration_test", "v")
                ret = integration_test_function(*args, **kwargs)

                try:
                    os.environ[mapping_runtime_mode_env_var_mame] = 'True'
                    integration_test_function(None, PreprocessResponse(state=DataStateType.training, length=0))
                except Exception as e:
                    import traceback
                    # Innermost frame, i.e. where the failure actually happened.
                    last_frame = traceback.extract_tb(e.__traceback__)[-1]
                    file_name = Path(last_frame.filename).name
                    line_number = last_frame.lineno
                    update_env_params_func("code_mapping", "x")
                    # Raise real exceptions: ``raise (f'...')`` raised a str,
                    # which Python rejects with an unrelated TypeError.
                    if isinstance(e, TypeError) and 'is not subscriptable' in str(e):
                        raise Exception(
                            f'Invalid integration code. File {file_name}, line {line_number}: '
                            f"indexing is supported only on the model's predictions inside the integration test. Please remove this indexing operation usage from the integration test code.") from e
                    raise Exception(
                        f'Invalid integration code. File {file_name}, line {line_number}: '
                        f'Integration test is only allowed to call Tensorleap decorators. '
                        f'Ensure any arithmetics, external library use, Python logic is placed within Tensorleap decoders') from e
                finally:
                    if mapping_runtime_mode_env_var_mame in os.environ:
                        del os.environ[mapping_runtime_mode_env_var_mame]
            finally:
                _called_from_inside_tl_integration_test_decorator = False

            leap_binder.check()
            return ret

        return inner

    return decorating_function
|
|
64
280
|
|
|
65
281
|
|
|
66
|
-
def
|
|
282
|
+
def _safe_get_item(key):
    """Map a zero-based model-input index to the ``NodeMappingType`` member ``Input<key>``.

    Args:
        key: Input position; formatted with ``str`` into the member name.

    Returns:
        The ``NodeMappingType`` member named ``Input<key>`` (e.g. ``Input0``).

    Raises:
        Exception: if ``NodeMappingType`` has no such member, i.e. the model
            has more inputs than Tensorleap supports.
    """
    try:
        return NodeMappingType[f'Input{str(key)}']
    except (KeyError, ValueError) as e:
        # Enum lookup by *name* raises KeyError (lookup by value raises
        # ValueError); catching only ValueError let the raw KeyError escape.
        raise Exception(f'Tensorleap currently supports models with no more then 10 inputs') from e
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = _UNSET):
|
|
290
|
+
prediction_types_was_provided = prediction_types is not _UNSET
|
|
291
|
+
|
|
292
|
+
if not prediction_types_was_provided:
|
|
293
|
+
prediction_types = []
|
|
294
|
+
if not _call_from_tl_platform:
|
|
295
|
+
store_warning_by_param(
|
|
296
|
+
param_name="prediction_types",
|
|
297
|
+
user_func_name="tensorleap_load_model",
|
|
298
|
+
default_value=prediction_types,
|
|
299
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
|
|
300
|
+
)
|
|
301
|
+
assert isinstance(prediction_types, list), (
|
|
302
|
+
f"tensorleap_load_model validation failed: "
|
|
303
|
+
f" prediction_types is an optional argument of type List[PredictionTypeHandler]] but got {type(prediction_types).__name__}."
|
|
304
|
+
)
|
|
67
305
|
for i, prediction_type in enumerate(prediction_types):
|
|
306
|
+
assert isinstance(prediction_type, PredictionTypeHandler), (f"tensorleap_load_model validation failed: "
|
|
307
|
+
f" prediction_types at position {i} must be of type PredictionTypeHandler but got {type(prediction_types[i]).__name__}.")
|
|
308
|
+
prediction_type_channel_dim_was_provided = prediction_type.channel_dim != "tl_default_value"
|
|
309
|
+
if not prediction_type_channel_dim_was_provided:
|
|
310
|
+
prediction_types[i].channel_dim = -1
|
|
311
|
+
if not _call_from_tl_platform:
|
|
312
|
+
store_warning_by_param(
|
|
313
|
+
param_name=f"prediction_types[{i}].channel_dim",
|
|
314
|
+
user_func_name="tensorleap_load_model",
|
|
315
|
+
default_value=prediction_types[i].channel_dim,
|
|
316
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
|
|
317
|
+
)
|
|
68
318
|
leap_binder.add_prediction(prediction_type.name, prediction_type.labels, prediction_type.channel_dim, i)
|
|
69
319
|
|
|
70
|
-
def
|
|
320
|
+
def _validate_result(result) -> None:
    """Check that the user's load-model function returned a supported model.

    Accepted types are ``keras.Model``, ``tf.keras.Model`` and
    ``onnxruntime.InferenceSession``. Each framework is imported lazily and
    treated as absent if the import fails.

    Raises:
        AssertionError: if ``result`` is None/NaN/a tuple (via
            ``validate_output_structure``), if none of keras, tensorflow or
            onnxruntime is importable, or if ``result`` is not an instance of
            any supported model type.
    """
    valid_types = ["onnxruntime", "keras"]
    err_message = f"tensorleap_load_model validation failed:\nSupported models are Keras and onnxruntime only and non of them was returned."
    validate_output_structure(result, func_name="tensorleap_load_model",
                              expected_type_name=" | ".join(valid_types))
    try:
        import keras
    except ImportError:
        keras = None
    try:
        import tensorflow as tf
    except ImportError:
        tf = None
    try:
        import onnxruntime
    except ImportError:
        onnxruntime = None

    # tf.keras is accepted below, so tensorflow alone is a valid backend.
    if not keras and not tf and not onnxruntime:
        raise AssertionError(err_message)

    is_keras_model = (
        bool(keras and isinstance(result, getattr(keras, "Model", tuple())))
        or bool(tf and isinstance(result, getattr(getattr(tf, "keras", None), "Model", tuple())))
    )
    is_onnx_model = bool(onnxruntime and isinstance(result, onnxruntime.InferenceSession))

    if not (is_keras_model or is_onnx_model):
        raise AssertionError(err_message)
|
|
349
|
+
|
|
350
|
+
def decorating_function(load_model_func, prediction_types=prediction_types):
|
|
71
351
|
class TempMapping:
|
|
72
352
|
pass
|
|
73
353
|
|
|
74
|
-
|
|
354
|
+
@lru_cache()
def inner(*args, **kwargs):
    """Load the user's model once and return a validating proxy around it.

    ``lru_cache`` memoizes the returned ``ModelPlaceholder``, so
    ``load_model_func`` only runs on the first call for each distinct
    (hashable) argument set.
    """
    if not _call_from_tl_platform:
        set_current('tensorleap_load_model')
        validate_args_structure(*args, types_order=[],
                                func_name='tensorleap_load_model', expected_names=[], **kwargs)

    class ModelPlaceholder:
        """Proxy exposing a Keras-style ``__call__`` and an ONNX-style ``run``."""

        # ONNX Runtime input type string -> NumPy dtype used when casting inputs.
        _ONNX_TYPE_TO_NP = {
            "tensor(float)": np.float32,
            "tensor(double)": np.float64,
            "tensor(int64)": np.int64,
            "tensor(int32)": np.int32,
            "tensor(int16)": np.int16,
            "tensor(int8)": np.int8,
            "tensor(uint64)": np.uint64,
            "tensor(uint32)": np.uint32,
            "tensor(uint16)": np.uint16,
            "tensor(uint8)": np.uint8,
            "tensor(bool)": np.bool_,
        }

        def __init__(self, prediction_types):
            self.model = load_model_func()  # TODO- check why this fails on onnx model
            self.prediction_types = prediction_types
            _validate_result(self.model)

        # keras interface
        def __call__(self, arg):
            ret = self.model(arg)
            self.validate_declared_prediction_types(ret)
            if isinstance(ret, (list, tuple)):
                return [r.numpy() for r in ret]
            return ret.numpy()

        def validate_declared_prediction_types(self, ret):
            """Raise if the declared prediction types don't match the model outputs.

            The check is skipped when no prediction types were declared. A
            list/tuple output counts as ``len(ret)`` outputs; anything else
            counts as one output.
            """
            if len(self.prediction_types) == 0:
                return
            num_outputs = len(ret) if isinstance(ret, (list, tuple)) else 1
            # The old condition parsed as ``not (A if is_list else 1)``, so a
            # single (non-list) output was never checked.
            if len(self.prediction_types) == num_outputs:
                return
            if not _call_from_tl_platform:
                update_env_params_func("tensorleap_load_model", "x")
            raise Exception(
                f"tensorleap_load_model validation failed: number of declared prediction types({len(self.prediction_types)}) != number of model outputs({num_outputs})")

        def _convert_onnx_inputs_to_correct_type(
                self, float_arrays_inputs: Dict[str, np.ndarray]
        ) -> Dict[str, np.ndarray]:
            """Cast user-provided inputs to the dtypes expected by the ONNX session.

            Inputs whose name the session does not declare are passed through
            unchanged. Shapes are not modified.

            Raises:
                TypeError: if a declared input has an ONNX type missing from
                    ``_ONNX_TYPE_TO_NP``.
                KeyError: if an input declared by the session is not provided.
            """
            coerced = {}
            meta = {i.name: i for i in self.model.get_inputs()}

            for name, arr in float_arrays_inputs.items():
                if name not in meta:
                    # Keep as-is unless extra inputs are disallowed
                    coerced[name] = arr
                    continue

                onnx_type = meta[name].type
                want_dtype = self._ONNX_TYPE_TO_NP.get(onnx_type)
                if want_dtype is None:
                    raise TypeError(f"Unsupported ONNX input type: {onnx_type}")

                # Accept array-likes (lists, framework tensors) as well as ndarrays.
                arr = np.asarray(arr)
                if arr.dtype != want_dtype:
                    arr = arr.astype(want_dtype, copy=False)
                coerced[name] = arr

            missing = [n for n in meta if n not in coerced]
            if missing:
                raise KeyError(f"Missing required input(s): {sorted(missing)}")

            return coerced

        # onnx runtime interface
        def run(self, output_names, input_dict):
            corrected_type_inputs = self._convert_onnx_inputs_to_correct_type(input_dict)
            ret = self.model.run(output_names, corrected_type_inputs)
            self.validate_declared_prediction_types(ret)
            return ret

        def get_inputs(self):
            return self.model.get_inputs()

    model_placeholder = ModelPlaceholder(prediction_types)
    if not _call_from_tl_platform:
        update_env_params_func("tensorleap_load_model", "v")
    return model_placeholder
|
|
89
447
|
|
|
90
448
|
def mapping_inner():
|
|
91
449
|
class ModelOutputPlaceholder:
|
|
@@ -97,15 +455,20 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
|
|
|
97
455
|
f'Expected key to be an int, got {type(key)} instead.'
|
|
98
456
|
|
|
99
457
|
ret = TempMapping()
|
|
100
|
-
|
|
458
|
+
try:
|
|
459
|
+
ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
|
|
460
|
+
except ValueError as e:
|
|
461
|
+
raise Exception(f'Tensorleap currently supports models with no more then 10 active predictions,'
|
|
462
|
+
f' {key} not supported.')
|
|
101
463
|
return ret
|
|
102
464
|
|
|
103
465
|
class ModelPlaceholder:
|
|
466
|
+
|
|
104
467
|
# keras interface
|
|
105
468
|
def __call__(self, arg):
|
|
106
469
|
if isinstance(arg, list):
|
|
107
470
|
for i, elem in enumerate(arg):
|
|
108
|
-
elem.node_mapping.type =
|
|
471
|
+
elem.node_mapping.type = _safe_get_item(i)
|
|
109
472
|
else:
|
|
110
473
|
arg.node_mapping.type = NodeMappingType.Input0
|
|
111
474
|
|
|
@@ -116,18 +479,38 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
|
|
|
116
479
|
assert output_names is None
|
|
117
480
|
assert isinstance(input_dict, dict), \
|
|
118
481
|
f'Expected input_dict to be a dict, got {type(input_dict)} instead.'
|
|
119
|
-
for i, elem in enumerate(input_dict.
|
|
120
|
-
|
|
482
|
+
for i, (input_key, elem) in enumerate(input_dict.items()):
|
|
483
|
+
if isinstance(input_key, NodeMappingType):
|
|
484
|
+
elem.node_mapping.type = input_key
|
|
485
|
+
else:
|
|
486
|
+
elem.node_mapping.type = _safe_get_item(i)
|
|
121
487
|
|
|
122
488
|
return ModelOutputPlaceholder()
|
|
123
489
|
|
|
490
|
+
def get_inputs(self):
    """Return an indexable stand-in for ``InferenceSession.get_inputs()``.

    Indexing it with an int ``i`` yields an object whose ``name`` is the
    ``NodeMappingType`` member for input ``i`` (via ``_safe_get_item``).
    Non-int indices fail an assertion.
    """

    class FollowIndex:
        def __init__(self, index):
            self.name = _safe_get_item(index)

    class FollowInputIndex:
        def __getitem__(self, index):
            assert isinstance(index, int), \
                f'Expected key to be an int, got {type(index)} instead.'
            return FollowIndex(index)

    return FollowInputIndex()
|
|
506
|
+
|
|
124
507
|
return ModelPlaceholder()
|
|
125
508
|
|
|
126
|
-
def final_inner():
|
|
509
|
+
def final_inner(*args, **kwargs):
    """Dispatch between mapping mode and the real (cached) model loader.

    If ``mapping_runtime_mode_env_var_mame`` is set to a non-empty value, the
    arguments are ignored and ``mapping_inner()`` is returned. Otherwise the
    call is forwarded to ``inner``.
    """
    if not os.environ.get(mapping_runtime_mode_env_var_mame):
        return inner(*args, **kwargs)
    return mapping_inner()
|
|
131
514
|
|
|
132
515
|
return final_inner
|
|
133
516
|
|
|
@@ -135,83 +518,186 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
|
|
|
135
518
|
|
|
136
519
|
|
|
137
520
|
def tensorleap_custom_metric(name: str,
|
|
138
|
-
direction: Union[MetricDirection, Dict[str, MetricDirection]] =
|
|
521
|
+
direction: Union[MetricDirection, Dict[str, MetricDirection]] = _UNSET,
|
|
139
522
|
compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
|
|
140
523
|
connects_to=None):
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
524
|
+
name_to_unique_name = defaultdict(set)
|
|
525
|
+
|
|
526
|
+
def decorating_function(
|
|
527
|
+
user_function: Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs,
|
|
528
|
+
ConfusionMatrixCallableInterfaceMultiArgs]):
|
|
529
|
+
nonlocal direction
|
|
530
|
+
|
|
531
|
+
direction_was_provided = direction is not _UNSET
|
|
532
|
+
|
|
533
|
+
if not direction_was_provided:
|
|
534
|
+
direction = MetricDirection.Downward
|
|
535
|
+
if not _call_from_tl_platform:
|
|
536
|
+
store_warning_by_param(
|
|
537
|
+
param_name="direction",
|
|
538
|
+
user_func_name=user_function.__name__,
|
|
539
|
+
default_value=direction,
|
|
540
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/custom-metrics"
|
|
541
|
+
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
def _validate_decorators_signature():
    """Type-check the arguments given to ``tensorleap_custom_metric``.

    Reads ``name``, ``direction``, ``compute_insights`` and ``connects_to``
    from the enclosing decorator call. Raises on the first problem found and
    returns None otherwise.

    Raises:
        TypeError: wrong type for ``name``, ``direction``, ``compute_insights``
            or ``connects_to``, or for a key/value/element inside them.
        ValueError: a ``direction`` value outside
            ``{MetricDirection.Upward, MetricDirection.Downward}``.
    """
    prefix = f"{user_function.__name__} validation failed.\n"

    if not isinstance(name, str):
        raise TypeError(prefix + f"`name` must be a string, got type {type(name).__name__}.")

    # --- direction: a single MetricDirection or a per-key dict of them ---
    allowed_directions = {MetricDirection.Upward, MetricDirection.Downward}
    if isinstance(direction, dict):
        non_str_keys = {k: type(k).__name__ for k in direction.keys() if not isinstance(k, str)}
        if non_str_keys:
            raise TypeError(
                prefix +
                f"All keys in `direction` must be strings, got invalid key types: {non_str_keys}."
            )
        for metric_key, metric_direction in direction.items():
            if metric_direction not in allowed_directions:
                raise ValueError(
                    prefix +
                    f"Invalid direction for key '{metric_key}': {metric_direction}. Must be one of {allowed_directions}, "
                    f"got type {type(metric_direction).__name__}."
                )
    elif isinstance(direction, MetricDirection):
        if direction not in allowed_directions:
            raise ValueError(
                prefix +
                f"Invalid MetricDirection: {direction}. Must be one of {allowed_directions}, "
                f"got type {type(direction).__name__}."
            )
    else:
        raise TypeError(
            prefix +
            f"`direction` must be a MetricDirection or a Dict[str, MetricDirection], "
            f"got type {type(direction).__name__}."
        )

    # --- compute_insights: optional bool or Dict[str, bool] ---
    if compute_insights is not None:
        if isinstance(compute_insights, dict):
            non_str_keys = {k: type(k).__name__ for k in compute_insights.keys() if not isinstance(k, str)}
            if non_str_keys:
                raise TypeError(
                    prefix +
                    f"All keys in `compute_insights` must be strings, got invalid key types: {non_str_keys}."
                )
            for insight_key, flag in compute_insights.items():
                if not isinstance(flag, bool):
                    raise TypeError(
                        prefix +
                        f"Invalid type for compute_insights['{insight_key}']: expected bool, got type {type(flag).__name__}."
                    )
        elif not isinstance(compute_insights, bool):
            raise TypeError(
                prefix +
                f"`compute_insights` must be a bool or a Dict[str, bool], "
                f"got type {type(compute_insights).__name__}."
            )

    # --- connects_to: optional str or collection of str ---
    if connects_to is None:
        return
    accepted_types = (str, list, tuple, set)
    if not isinstance(connects_to, accepted_types):
        raise TypeError(
            prefix +
            f"`connects_to` must be one of {accepted_types}, got type {type(connects_to).__name__}."
        )
    if not isinstance(connects_to, str):
        bad_elems = [f"{type(e).__name__}" for e in connects_to if not isinstance(e, str)]
        if bad_elems:
            raise TypeError(
                prefix +
                f"All elements in `connects_to` must be strings, "
                f"but found element types: {bad_elems}."
            )
|
|
611
|
+
|
|
612
|
+
_validate_decorators_signature()
|
|
613
|
+
|
|
144
614
|
for metric_handler in leap_binder.setup_container.metrics:
|
|
145
615
|
if metric_handler.metric_handler_data.name == name:
|
|
146
616
|
raise Exception(f'Metric with name {name} already exists. '
|
|
147
617
|
f'Please choose another')
|
|
148
618
|
|
|
149
619
|
def _validate_input_args(*args, **kwargs) -> None:
    """Check the arguments passed to the user's custom metric.

    At least one argument is required. Every argument (positional first,
    then keyword) must be an ``np.ndarray`` or a ``SamplePreprocessResponse``.
    When ``leap_binder.batch_size_to_validate`` is set, each ndarray's first
    dimension must equal it.

    Raises:
        AssertionError: on the first argument that violates a rule.
    """
    assert len(args) + len(kwargs) > 0, (
        f"{user_function.__name__}() validation failed: "
        f"Expected at least one positional|key-word argument of type np.ndarray, "
        f"but received none. "
        f"Correct usage example: tensorleap_custom_metric(input_array: np.ndarray, ...)"
    )

    # Positional args are labelled "#<index>", keyword args by their name.
    labelled_args = [(f'#{i}', arg) for i, arg in enumerate(args)]
    labelled_args += [(arg_name, arg) for arg_name, arg in kwargs.items()]
    expected_batch = leap_binder.batch_size_to_validate

    for label, arg in labelled_args:
        assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
            f'{user_function.__name__}() validation failed: '
            f'Argument {label} should be a numpy array. Got {type(arg)}.')
        if expected_batch and isinstance(arg, np.ndarray):
            assert arg.shape[0] == expected_batch, \
                (f'{user_function.__name__}() validation failed: Argument {label} '
                 f'first dim should be as the batch size. Got {arg.shape[0]} '
                 f'instead of {expected_batch}')
|
|
169
645
|
|
|
170
646
|
def _validate_result(result) -> None:
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
647
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
648
|
+
expected_type_name="List[float | int | None | List[ConfusionMatrixElement] ] | NDArray[np.float32] or dictonary with one of these types as its values types")
|
|
649
|
+
supported_types_message = (f'{user_function.__name__}() validation failed: '
|
|
650
|
+
f'{user_function.__name__}() has returned unsupported type.\nSupported types are List[float|int|None], '
|
|
651
|
+
f'List[List[ConfusionMatrixElement]], NDArray[np.float32] or dictonary with one of these types as its values types. ')
|
|
174
652
|
|
|
175
|
-
def _validate_single_metric(single_metric_result):
|
|
653
|
+
def _validate_single_metric(single_metric_result, key=None):
|
|
176
654
|
if isinstance(single_metric_result, list):
|
|
177
655
|
if isinstance(single_metric_result[0], list):
|
|
178
|
-
assert isinstance(single_metric_result[0]
|
|
179
|
-
f
|
|
656
|
+
assert all(isinstance(cm, ConfusionMatrixElement) for cm in single_metric_result[0]), (
|
|
657
|
+
f"{supported_types_message} "
|
|
658
|
+
f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
|
|
659
|
+
f"List[List[{', '.join(type(cm).__name__ for cm in single_metric_result[0])}]]."
|
|
660
|
+
)
|
|
661
|
+
|
|
180
662
|
else:
|
|
181
|
-
assert isinstance(single_metric_result
|
|
182
|
-
|
|
183
|
-
|
|
663
|
+
assert all(isinstance(v, (float, int, type(None), np.float32)) for v in single_metric_result), (
|
|
664
|
+
f"{supported_types_message}\n"
|
|
665
|
+
f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
|
|
666
|
+
f"List[{', '.join(type(v).__name__ for v in single_metric_result)}]."
|
|
667
|
+
)
|
|
184
668
|
else:
|
|
185
669
|
assert isinstance(single_metric_result,
|
|
186
|
-
np.ndarray), f'{supported_types_message}
|
|
187
|
-
assert len(single_metric_result.shape) == 1, (f'
|
|
670
|
+
np.ndarray), f'{supported_types_message}\nGot {type(single_metric_result)}.'
|
|
671
|
+
assert len(single_metric_result.shape) == 1, (f'{user_function.__name__}() validation failed: '
|
|
188
672
|
f'The return shape should be 1D. Got {len(single_metric_result.shape)}D.')
|
|
189
673
|
|
|
190
674
|
if leap_binder.batch_size_to_validate:
|
|
191
675
|
assert len(single_metric_result) == leap_binder.batch_size_to_validate, \
|
|
192
|
-
f'
|
|
676
|
+
f'{user_function.__name__}() validation failed: The return len {f"of srt{key} value" if key is not None else ""} should be as the batch size.'
|
|
193
677
|
|
|
194
678
|
if isinstance(result, dict):
|
|
195
679
|
for key, value in result.items():
|
|
680
|
+
_validate_single_metric(value, key)
|
|
681
|
+
|
|
196
682
|
assert isinstance(key, str), \
|
|
197
|
-
(f'
|
|
683
|
+
(f'{user_function.__name__}() validation failed: '
|
|
198
684
|
f'Keys in the return dict should be of type str. Got {type(key)}.')
|
|
199
685
|
_validate_single_metric(value)
|
|
200
686
|
|
|
201
687
|
if isinstance(direction, dict):
|
|
202
688
|
for direction_key in direction:
|
|
203
689
|
assert direction_key in result, \
|
|
204
|
-
(f'
|
|
690
|
+
(f'{user_function.__name__}() validation failed: '
|
|
205
691
|
f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
|
|
206
692
|
|
|
207
693
|
if compute_insights is not None:
|
|
208
694
|
assert isinstance(compute_insights, dict), \
|
|
209
|
-
(f'
|
|
695
|
+
(f'{user_function.__name__}() validation failed: '
|
|
210
696
|
f'compute_insights should be dict if using the dict results. Got {type(compute_insights)}.')
|
|
211
697
|
|
|
212
698
|
for ci_key in compute_insights:
|
|
213
699
|
assert ci_key in result, \
|
|
214
|
-
(f'
|
|
700
|
+
(f'{user_function.__name__}() validation failed: '
|
|
215
701
|
f'Keys in the compute_insights mapping should be part of result keys. Got key {ci_key}.')
|
|
216
702
|
|
|
217
703
|
else:
|
|
@@ -219,7 +705,7 @@ def tensorleap_custom_metric(name: str,
|
|
|
219
705
|
|
|
220
706
|
if compute_insights is not None:
|
|
221
707
|
assert isinstance(compute_insights, bool), \
|
|
222
|
-
(f'
|
|
708
|
+
(f'{user_function.__name__}() validation failed: '
|
|
223
709
|
f'compute_insights should be boolean. Got {type(compute_insights)}.')
|
|
224
710
|
|
|
225
711
|
@functools.wraps(user_function)
|
|
@@ -246,11 +732,15 @@ def tensorleap_custom_metric(name: str,
|
|
|
246
732
|
_add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
|
|
247
733
|
|
|
248
734
|
def inner(*args, **kwargs):
|
|
735
|
+
if not _call_from_tl_platform:
|
|
736
|
+
set_current('tensorleap_custom_metric')
|
|
249
737
|
_validate_input_args(*args, **kwargs)
|
|
250
738
|
|
|
251
739
|
result = inner_without_validate(*args, **kwargs)
|
|
252
740
|
|
|
253
741
|
_validate_result(result)
|
|
742
|
+
if not _call_from_tl_platform:
|
|
743
|
+
update_env_params_func("tensorleap_custom_metric", "v")
|
|
254
744
|
return result
|
|
255
745
|
|
|
256
746
|
def mapping_inner(*args, **kwargs):
|
|
@@ -260,6 +750,11 @@ def tensorleap_custom_metric(name: str,
|
|
|
260
750
|
|
|
261
751
|
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
262
752
|
ordered_connections = list(args) + ordered_connections
|
|
753
|
+
|
|
754
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
755
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
756
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
757
|
+
|
|
263
758
|
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
264
759
|
mapping_inner.name, NodeMappingType.Metric)
|
|
265
760
|
|
|
@@ -282,29 +777,40 @@ def tensorleap_custom_metric(name: str,
|
|
|
282
777
|
def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
283
778
|
heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
|
|
284
779
|
connects_to=None):
|
|
780
|
+
name_to_unique_name = defaultdict(set)
|
|
781
|
+
|
|
285
782
|
def decorating_function(user_function: VisualizerCallableInterface):
|
|
783
|
+
assert isinstance(visualizer_type, LeapDataType), (f"{user_function.__name__} validation failed: "
|
|
784
|
+
f"visualizer_type should be of type {LeapDataType.__name__} but got {type(visualizer_type)}"
|
|
785
|
+
)
|
|
286
786
|
for viz_handler in leap_binder.setup_container.visualizers:
|
|
287
787
|
if viz_handler.visualizer_handler_data.name == name:
|
|
288
788
|
raise Exception(f'Visualizer with name {name} already exists. '
|
|
289
789
|
f'Please choose another')
|
|
290
790
|
|
|
291
791
|
def _validate_input_args(*args, **kwargs):
|
|
792
|
+
assert len(args) + len(kwargs) > 0, (
|
|
793
|
+
f"{user_function.__name__}() validation failed: "
|
|
794
|
+
f"Expected at least one positional|key-word argument of type np.ndarray, "
|
|
795
|
+
f"but received none. "
|
|
796
|
+
f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
|
|
797
|
+
)
|
|
292
798
|
for i, arg in enumerate(args):
|
|
293
799
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
294
|
-
f'
|
|
800
|
+
f'{user_function.__name__}() validation failed: '
|
|
295
801
|
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
296
802
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
297
803
|
assert arg.shape[0] != leap_binder.batch_size_to_validate, \
|
|
298
|
-
(f'
|
|
804
|
+
(f'{user_function.__name__}() validation failed: '
|
|
299
805
|
f'Argument #{i} should be without batch dimension. ')
|
|
300
806
|
|
|
301
807
|
for _arg_name, arg in kwargs.items():
|
|
302
808
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
303
|
-
f'
|
|
809
|
+
f'{user_function.__name__}() validation failed: '
|
|
304
810
|
f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
305
811
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
306
812
|
assert arg.shape[0] != leap_binder.batch_size_to_validate, \
|
|
307
|
-
(f'
|
|
813
|
+
(f'{user_function.__name__}() validation failed: Argument {_arg_name} '
|
|
308
814
|
f'should be without batch dimension. ')
|
|
309
815
|
|
|
310
816
|
def _validate_result(result):
|
|
@@ -318,8 +824,11 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
|
318
824
|
LeapDataType.ImageWithBBox: LeapImageWithBBox,
|
|
319
825
|
LeapDataType.ImageWithHeatmap: LeapImageWithHeatmap
|
|
320
826
|
}
|
|
827
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
828
|
+
expected_type_name=result_type_map[visualizer_type])
|
|
829
|
+
|
|
321
830
|
assert isinstance(result, result_type_map[visualizer_type]), \
|
|
322
|
-
(f'
|
|
831
|
+
(f'{user_function.__name__}() validation failed: '
|
|
323
832
|
f'The return type should be {result_type_map[visualizer_type]}. Got {type(result)}.')
|
|
324
833
|
|
|
325
834
|
@functools.wraps(user_function)
|
|
@@ -346,11 +855,15 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
|
346
855
|
_add_mapping_connections(connects_to, arg_names, NodeMappingType.Visualizer, name)
|
|
347
856
|
|
|
348
857
|
def inner(*args, **kwargs):
|
|
858
|
+
if not _call_from_tl_platform:
|
|
859
|
+
set_current('tensorleap_custom_visualizer')
|
|
349
860
|
_validate_input_args(*args, **kwargs)
|
|
350
861
|
|
|
351
862
|
result = inner_without_validate(*args, **kwargs)
|
|
352
863
|
|
|
353
864
|
_validate_result(result)
|
|
865
|
+
if not _call_from_tl_platform:
|
|
866
|
+
update_env_params_func("tensorleap_custom_visualizer", "v")
|
|
354
867
|
return result
|
|
355
868
|
|
|
356
869
|
def mapping_inner(*args, **kwargs):
|
|
@@ -358,6 +871,10 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
|
358
871
|
if 'user_unique_name' in kwargs:
|
|
359
872
|
user_unique_name = kwargs['user_unique_name']
|
|
360
873
|
|
|
874
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
875
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
876
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
877
|
+
|
|
361
878
|
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
362
879
|
ordered_connections = list(args) + ordered_connections
|
|
363
880
|
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
@@ -388,30 +905,26 @@ def tensorleap_metadata(
|
|
|
388
905
|
f'Please choose another')
|
|
389
906
|
|
|
390
907
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
391
|
-
assert isinstance(sample_id, (int, str)), \
|
|
392
|
-
(f'tensorleap_metadata validation failed: '
|
|
393
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
394
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
395
|
-
(f'tensorleap_metadata validation failed: '
|
|
396
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
397
908
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
398
|
-
(f'
|
|
909
|
+
(f'{user_function.__name__}() validation failed: '
|
|
399
910
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
400
911
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
401
912
|
|
|
402
913
|
def _validate_result(result):
|
|
403
914
|
supported_result_types = (type(None), int, str, bool, float, dict, np.floating,
|
|
404
915
|
np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
|
|
916
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
917
|
+
expected_type_name=supported_result_types)
|
|
405
918
|
assert isinstance(result, supported_result_types), \
|
|
406
|
-
(f'
|
|
919
|
+
(f'{user_function.__name__}() validation failed: '
|
|
407
920
|
f'Unsupported return type. Got {type(result)}. should be any of {str(supported_result_types)}')
|
|
408
921
|
if isinstance(result, dict):
|
|
409
922
|
for key, value in result.items():
|
|
410
923
|
assert isinstance(key, str), \
|
|
411
|
-
(f'
|
|
924
|
+
(f'{user_function.__name__}() validation failed: '
|
|
412
925
|
f'Keys in the return dict should be of type str. Got {type(key)}.')
|
|
413
926
|
assert isinstance(value, supported_result_types), \
|
|
414
|
-
(f'
|
|
927
|
+
(f'{user_function.__name__}() validation failed: '
|
|
415
928
|
f'Values in the return dict should be of type {str(supported_result_types)}. Got {type(value)}.')
|
|
416
929
|
|
|
417
930
|
def inner_without_validate(sample_id, preprocess_response):
|
|
@@ -428,6 +941,60 @@ def tensorleap_metadata(
|
|
|
428
941
|
|
|
429
942
|
leap_binder.set_metadata(inner_without_validate, name, metadata_type)
|
|
430
943
|
|
|
944
|
+
def inner(*args, **kwargs):
|
|
945
|
+
if not _call_from_tl_platform:
|
|
946
|
+
set_current('tensorleap_metadata')
|
|
947
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
948
|
+
return None
|
|
949
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
950
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
951
|
+
sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
|
|
952
|
+
_validate_input_args(sample_id, preprocess_response)
|
|
953
|
+
|
|
954
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
955
|
+
|
|
956
|
+
_validate_result(result)
|
|
957
|
+
if not _call_from_tl_platform:
|
|
958
|
+
update_env_params_func("tensorleap_metadata", "v")
|
|
959
|
+
return result
|
|
960
|
+
|
|
961
|
+
return inner
|
|
962
|
+
|
|
963
|
+
return decorating_function
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
def tensorleap_custom_latent_space():
|
|
967
|
+
def decorating_function(user_function: SectionCallableInterface):
|
|
968
|
+
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
969
|
+
assert isinstance(sample_id, (int, str)), \
|
|
970
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
971
|
+
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
972
|
+
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
973
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
974
|
+
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
975
|
+
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
976
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
977
|
+
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
978
|
+
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
979
|
+
|
|
980
|
+
def _validate_result(result):
|
|
981
|
+
assert isinstance(result, np.ndarray), \
|
|
982
|
+
(f'tensorleap_custom_loss validation failed: '
|
|
983
|
+
f'The return type should be a numpy array. Got {type(result)}.')
|
|
984
|
+
|
|
985
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
986
|
+
global _called_from_inside_tl_decorator
|
|
987
|
+
_called_from_inside_tl_decorator += 1
|
|
988
|
+
|
|
989
|
+
try:
|
|
990
|
+
result = user_function(sample_id, preprocess_response)
|
|
991
|
+
finally:
|
|
992
|
+
_called_from_inside_tl_decorator -= 1
|
|
993
|
+
|
|
994
|
+
return result
|
|
995
|
+
|
|
996
|
+
leap_binder.set_custom_latent_space(inner_without_validate)
|
|
997
|
+
|
|
431
998
|
def inner(sample_id, preprocess_response):
|
|
432
999
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
433
1000
|
return None
|
|
@@ -449,30 +1016,45 @@ def tensorleap_preprocess():
|
|
|
449
1016
|
leap_binder.set_preprocess(user_function)
|
|
450
1017
|
|
|
451
1018
|
def _validate_input_args(*args, **kwargs):
|
|
452
|
-
assert len(args)
|
|
453
|
-
(f'
|
|
1019
|
+
assert len(args) + len(kwargs) == 0, \
|
|
1020
|
+
(f'{user_function.__name__}() validation failed: '
|
|
454
1021
|
f'The function should not take any arguments. Got {args} and {kwargs}.')
|
|
455
1022
|
|
|
456
1023
|
def _validate_result(result):
|
|
457
|
-
assert isinstance(result, list),
|
|
458
|
-
(
|
|
459
|
-
|
|
1024
|
+
assert isinstance(result, list), (
|
|
1025
|
+
f"{user_function.__name__}() validation failed: expected return type list[{PreprocessResponse.__name__}]"
|
|
1026
|
+
f"(e.g., [PreprocessResponse1, PreprocessResponse2, ...]), but returned type is {type(result).__name__}."
|
|
1027
|
+
if not isinstance(result, tuple)
|
|
1028
|
+
else f"{user_function.__name__}() validation failed: expected to return a single list[{PreprocessResponse.__name__}] object, "
|
|
1029
|
+
f"but returned {len(result)} objects instead."
|
|
1030
|
+
)
|
|
460
1031
|
for i, response in enumerate(result):
|
|
461
1032
|
assert isinstance(response, PreprocessResponse), \
|
|
462
|
-
(f'
|
|
1033
|
+
(f'{user_function.__name__}() validation failed: '
|
|
463
1034
|
f'Element #{i} in the return list should be a PreprocessResponse. Got {type(response)}.')
|
|
464
1035
|
assert len(set(result)) == len(result), \
|
|
465
|
-
(f'
|
|
1036
|
+
(f'{user_function.__name__}() validation failed: '
|
|
466
1037
|
f'The return list should not contain duplicate PreprocessResponse objects.')
|
|
467
1038
|
|
|
468
1039
|
def inner(*args, **kwargs):
|
|
1040
|
+
if not _call_from_tl_platform:
|
|
1041
|
+
set_current('tensorleap_metadata')
|
|
469
1042
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
470
1043
|
return [None, None, None, None]
|
|
471
1044
|
|
|
472
1045
|
_validate_input_args(*args, **kwargs)
|
|
473
|
-
|
|
474
1046
|
result = user_function()
|
|
475
1047
|
_validate_result(result)
|
|
1048
|
+
|
|
1049
|
+
# Emit integration test event once per test
|
|
1050
|
+
try:
|
|
1051
|
+
emit_integration_event_once(AnalyticsEvent.PREPROCESS_INTEGRATION_TEST, {
|
|
1052
|
+
'preprocess_responses_count': len(result)
|
|
1053
|
+
})
|
|
1054
|
+
except Exception as e:
|
|
1055
|
+
logger.debug(f"Failed to emit preprocess integration test event: {e}")
|
|
1056
|
+
if not _call_from_tl_platform:
|
|
1057
|
+
update_env_params_func("tensorleap_preprocess", "v")
|
|
476
1058
|
return result
|
|
477
1059
|
|
|
478
1060
|
return inner
|
|
@@ -481,7 +1063,7 @@ def tensorleap_preprocess():
|
|
|
481
1063
|
|
|
482
1064
|
|
|
483
1065
|
def tensorleap_element_instance_preprocess(
|
|
484
|
-
instance_length_encoder: InstanceLengthCallableInterface):
|
|
1066
|
+
instance_length_encoder: InstanceLengthCallableInterface, instance_mask_encoder: InstanceCallableInterface):
|
|
485
1067
|
def decorating_function(user_function: Callable[[], List[PreprocessResponse]]):
|
|
486
1068
|
def user_function_instance() -> List[PreprocessResponse]:
|
|
487
1069
|
result = user_function()
|
|
@@ -497,16 +1079,39 @@ def tensorleap_element_instance_preprocess(
|
|
|
497
1079
|
for instance_id in instances_ids:
|
|
498
1080
|
instance_to_sample_ids_mappings[instance_id] = sample_id
|
|
499
1081
|
all_sample_ids.extend(instances_ids)
|
|
1082
|
+
preprocess_response.length = len(all_sample_ids)
|
|
500
1083
|
preprocess_response.sample_ids_to_instance_mappings = sample_ids_to_instance_mappings
|
|
501
1084
|
preprocess_response.instance_to_sample_ids_mappings = instance_to_sample_ids_mappings
|
|
502
1085
|
preprocess_response.sample_ids = all_sample_ids
|
|
503
1086
|
return result
|
|
504
1087
|
|
|
1088
|
+
def extract_extra_instance_metadata():
|
|
1089
|
+
result = user_function()
|
|
1090
|
+
for preprocess_response in result:
|
|
1091
|
+
for sample_id in preprocess_response.sample_ids:
|
|
1092
|
+
instances_length = instance_length_encoder(sample_id, preprocess_response)
|
|
1093
|
+
if instances_length > 0:
|
|
1094
|
+
element_instance = instance_mask_encoder(sample_id, preprocess_response, 0)
|
|
1095
|
+
instance_metadata = element_instance.instance_metadata
|
|
1096
|
+
if instance_metadata is None:
|
|
1097
|
+
return {}
|
|
1098
|
+
return instance_metadata
|
|
1099
|
+
return {}
|
|
1100
|
+
|
|
1101
|
+
|
|
505
1102
|
def builtin_instance_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
|
|
506
1103
|
return {'is_instance': '0', 'original_sample_id': idx, 'instance_name': 'none'}
|
|
507
1104
|
|
|
1105
|
+
def builtin_instance_extra_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
|
|
1106
|
+
instance_metadata = extract_extra_instance_metadata()
|
|
1107
|
+
for key, value in instance_metadata.items():
|
|
1108
|
+
instance_metadata[key] = 'unset'
|
|
1109
|
+
return instance_metadata
|
|
1110
|
+
|
|
508
1111
|
leap_binder.set_preprocess(user_function_instance)
|
|
509
1112
|
leap_binder.set_metadata(builtin_instance_metadata, "builtin_instance_metadata")
|
|
1113
|
+
leap_binder.set_metadata(builtin_instance_extra_metadata, "builtin_instance_extra_metadata")
|
|
1114
|
+
|
|
510
1115
|
|
|
511
1116
|
def _validate_input_args(*args, **kwargs):
|
|
512
1117
|
assert len(args) == 0 and len(kwargs) == 0, \
|
|
@@ -533,6 +1138,8 @@ def tensorleap_element_instance_preprocess(
|
|
|
533
1138
|
|
|
534
1139
|
result = user_function_instance()
|
|
535
1140
|
_validate_result(result)
|
|
1141
|
+
if not _call_from_tl_platform:
|
|
1142
|
+
update_env_params_func("tensorleap_preprocess", "v")
|
|
536
1143
|
return result
|
|
537
1144
|
|
|
538
1145
|
return inner
|
|
@@ -567,7 +1174,7 @@ def tensorleap_unlabeled_preprocess():
|
|
|
567
1174
|
|
|
568
1175
|
def tensorleap_instances_masks_encoder(name: str):
|
|
569
1176
|
def decorating_function(user_function: InstanceCallableInterface):
|
|
570
|
-
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse, instance_id:
|
|
1177
|
+
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse, instance_id: int):
|
|
571
1178
|
assert isinstance(sample_id, str), \
|
|
572
1179
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
573
1180
|
f'Argument sample_id should be str. Got {type(sample_id)}.')
|
|
@@ -578,13 +1185,9 @@ def tensorleap_instances_masks_encoder(name: str):
|
|
|
578
1185
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
579
1186
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
580
1187
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
581
|
-
assert isinstance(instance_id,
|
|
582
|
-
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
583
|
-
f'Argument instance_id should be str. Got {type(instance_id)}.')
|
|
584
|
-
assert type(instance_id) == preprocess_response.sample_id_type, \
|
|
1188
|
+
assert isinstance(instance_id, int), \
|
|
585
1189
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
586
|
-
f'Argument instance_id should be
|
|
587
|
-
f'{preprocess_response.sample_id_type}. Got {type(instance_id)}.')
|
|
1190
|
+
f'Argument instance_id should be int. Got {type(instance_id)}.')
|
|
588
1191
|
|
|
589
1192
|
def _validate_result(result):
|
|
590
1193
|
assert isinstance(result, ElementInstance) or (result is None), \
|
|
@@ -619,10 +1222,11 @@ def tensorleap_instances_masks_encoder(name: str):
|
|
|
619
1222
|
|
|
620
1223
|
return decorating_function
|
|
621
1224
|
|
|
1225
|
+
|
|
622
1226
|
def tensorleap_instances_length_encoder(name: str):
|
|
623
1227
|
def decorating_function(user_function: InstanceLengthCallableInterface):
|
|
624
1228
|
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
|
|
625
|
-
assert isinstance(sample_id, str), \
|
|
1229
|
+
assert isinstance(sample_id, (str, int)), \
|
|
626
1230
|
(f'tensorleap_instances_length_encoder validation failed: '
|
|
627
1231
|
f'Argument sample_id should be str. Got {type(sample_id)}.')
|
|
628
1232
|
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
@@ -649,8 +1253,6 @@ def tensorleap_instances_length_encoder(name: str):
|
|
|
649
1253
|
|
|
650
1254
|
return result
|
|
651
1255
|
|
|
652
|
-
# leap_binder.set_instance_masks(inner_without_validate, name). # TODO: do i need this?
|
|
653
|
-
|
|
654
1256
|
def inner(sample_id, preprocess_response):
|
|
655
1257
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
656
1258
|
return None
|
|
@@ -666,46 +1268,88 @@ def tensorleap_instances_length_encoder(name: str):
|
|
|
666
1268
|
|
|
667
1269
|
return decorating_function
|
|
668
1270
|
|
|
669
|
-
|
|
1271
|
+
|
|
1272
|
+
def tensorleap_input_encoder(name: str, channel_dim=_UNSET, model_input_index=None):
|
|
670
1273
|
def decorating_function(user_function: SectionCallableInterface):
|
|
671
1274
|
for input_handler in leap_binder.setup_container.inputs:
|
|
672
1275
|
if input_handler.name == name:
|
|
673
1276
|
raise Exception(f'Input with name {name} already exists. '
|
|
674
1277
|
f'Please choose another')
|
|
1278
|
+
nonlocal channel_dim
|
|
1279
|
+
|
|
1280
|
+
channel_dim_was_provided = channel_dim is not _UNSET
|
|
1281
|
+
|
|
1282
|
+
if not channel_dim_was_provided:
|
|
1283
|
+
channel_dim = -1
|
|
1284
|
+
if not _call_from_tl_platform:
|
|
1285
|
+
store_warning_by_param(
|
|
1286
|
+
param_name="channel_dim",
|
|
1287
|
+
user_func_name=user_function.__name__,
|
|
1288
|
+
default_value=channel_dim,
|
|
1289
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/input-encoder"
|
|
1290
|
+
|
|
1291
|
+
)
|
|
1292
|
+
|
|
675
1293
|
if channel_dim <= 0 and channel_dim != -1:
|
|
676
1294
|
raise Exception(f"Channel dim for input {name} is expected to be either -1 or positive")
|
|
677
1295
|
|
|
678
|
-
leap_binder.set_input(user_function, name, channel_dim=channel_dim)
|
|
679
|
-
|
|
680
1296
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
681
|
-
assert isinstance(sample_id, (int, str)), \
|
|
682
|
-
(f'tensorleap_input_encoder validation failed: '
|
|
683
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
684
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
685
|
-
(f'tensorleap_input_encoder validation failed: '
|
|
686
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
687
1297
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
688
|
-
(f'
|
|
1298
|
+
(f'{user_function.__name__}() validation failed: '
|
|
689
1299
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
690
1300
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
691
1301
|
|
|
692
1302
|
def _validate_result(result):
|
|
1303
|
+
validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray")
|
|
693
1304
|
assert isinstance(result, np.ndarray), \
|
|
694
|
-
(f'
|
|
1305
|
+
(f'{user_function.__name__}() validation failed: '
|
|
695
1306
|
f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
|
|
696
1307
|
assert result.dtype == np.float32, \
|
|
697
|
-
(f'
|
|
1308
|
+
(f'{user_function.__name__}() validation failed: '
|
|
698
1309
|
f'The return type should be a numpy array of type float32. Got {result.dtype}.')
|
|
699
|
-
assert channel_dim - 1 <= len(result.shape), (f'
|
|
1310
|
+
assert channel_dim - 1 <= len(result.shape), (f'{user_function.__name__}() validation failed: '
|
|
700
1311
|
f'The channel_dim ({channel_dim}) should be <= to the rank of the resulting input rank ({len(result.shape)}).')
|
|
701
1312
|
|
|
702
|
-
def
|
|
1313
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
1314
|
+
global _called_from_inside_tl_decorator
|
|
1315
|
+
_called_from_inside_tl_decorator += 1
|
|
1316
|
+
|
|
1317
|
+
try:
|
|
1318
|
+
result = user_function(sample_id, preprocess_response)
|
|
1319
|
+
finally:
|
|
1320
|
+
_called_from_inside_tl_decorator -= 1
|
|
1321
|
+
|
|
1322
|
+
return result
|
|
1323
|
+
|
|
1324
|
+
leap_binder.set_input(inner_without_validate, name, channel_dim=channel_dim)
|
|
1325
|
+
|
|
1326
|
+
def inner(*args, **kwargs):
|
|
1327
|
+
if not _call_from_tl_platform:
|
|
1328
|
+
set_current("tensorleap_input_encoder")
|
|
1329
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
1330
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
1331
|
+
sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
|
|
703
1332
|
_validate_input_args(sample_id, preprocess_response)
|
|
704
|
-
|
|
1333
|
+
|
|
1334
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
1335
|
+
|
|
705
1336
|
_validate_result(result)
|
|
706
1337
|
|
|
707
|
-
if _called_from_inside_tl_decorator == 0:
|
|
1338
|
+
if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
|
|
1339
|
+
batch_warning(result, user_function.__name__)
|
|
708
1340
|
result = np.expand_dims(result, axis=0)
|
|
1341
|
+
# Emit integration test event once per test
|
|
1342
|
+
try:
|
|
1343
|
+
emit_integration_event_once(AnalyticsEvent.INPUT_ENCODER_INTEGRATION_TEST, {
|
|
1344
|
+
'encoder_name': name,
|
|
1345
|
+
'channel_dim': channel_dim,
|
|
1346
|
+
'model_input_index': model_input_index
|
|
1347
|
+
})
|
|
1348
|
+
except Exception as e:
|
|
1349
|
+
logger.debug(f"Failed to emit input_encoder integration test event: {e}")
|
|
1350
|
+
if not _call_from_tl_platform:
|
|
1351
|
+
update_env_params_func("tensorleap_input_encoder", "v")
|
|
1352
|
+
|
|
709
1353
|
return result
|
|
710
1354
|
|
|
711
1355
|
node_mapping_type = NodeMappingType.Input
|
|
@@ -713,22 +1357,23 @@ def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
|
|
|
713
1357
|
node_mapping_type = NodeMappingType(f'Input{str(model_input_index)}')
|
|
714
1358
|
inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
715
1359
|
|
|
716
|
-
def mapping_inner(
|
|
1360
|
+
def mapping_inner(*args, **kwargs):
|
|
717
1361
|
class TempMapping:
|
|
718
1362
|
pass
|
|
719
1363
|
|
|
720
1364
|
ret = TempMapping()
|
|
721
1365
|
ret.node_mapping = mapping_inner.node_mapping
|
|
722
1366
|
|
|
1367
|
+
leap_binder.mapping_connections.append(NodeConnection(mapping_inner.node_mapping, None))
|
|
723
1368
|
return ret
|
|
724
1369
|
|
|
725
1370
|
mapping_inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
726
1371
|
|
|
727
|
-
def final_inner(
|
|
1372
|
+
def final_inner(*args, **kwargs):
|
|
728
1373
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
729
|
-
return mapping_inner(
|
|
1374
|
+
return mapping_inner(*args, **kwargs)
|
|
730
1375
|
else:
|
|
731
|
-
return inner(
|
|
1376
|
+
return inner(*args, **kwargs)
|
|
732
1377
|
|
|
733
1378
|
final_inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
734
1379
|
|
|
@@ -744,40 +1389,64 @@ def tensorleap_gt_encoder(name: str):
|
|
|
744
1389
|
raise Exception(f'GT with name {name} already exists. '
|
|
745
1390
|
f'Please choose another')
|
|
746
1391
|
|
|
747
|
-
leap_binder.set_ground_truth(user_function, name)
|
|
748
|
-
|
|
749
1392
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
750
|
-
assert isinstance(sample_id, (int, str)), \
|
|
751
|
-
(f'tensorleap_gt_encoder validation failed: '
|
|
752
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
753
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
754
|
-
(f'tensorleap_gt_encoder validation failed: '
|
|
755
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
756
1393
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
757
|
-
(f'
|
|
1394
|
+
(f'{user_function.__name__}() validation failed: '
|
|
758
1395
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
759
1396
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
760
1397
|
|
|
761
1398
|
def _validate_result(result):
|
|
1399
|
+
validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray",
|
|
1400
|
+
gt_flag=True)
|
|
762
1401
|
assert isinstance(result, np.ndarray), \
|
|
763
|
-
(f'
|
|
1402
|
+
(f'{user_function.__name__}() validation failed: '
|
|
764
1403
|
f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
|
|
765
1404
|
assert result.dtype == np.float32, \
|
|
766
|
-
(f'
|
|
1405
|
+
(f'{user_function.__name__}() validation failed: '
|
|
767
1406
|
f'The return type should be a numpy array of type float32. Got {result.dtype}.')
|
|
768
1407
|
|
|
769
|
-
def
|
|
1408
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
1409
|
+
global _called_from_inside_tl_decorator
|
|
1410
|
+
_called_from_inside_tl_decorator += 1
|
|
1411
|
+
|
|
1412
|
+
try:
|
|
1413
|
+
result = user_function(sample_id, preprocess_response)
|
|
1414
|
+
finally:
|
|
1415
|
+
_called_from_inside_tl_decorator -= 1
|
|
1416
|
+
|
|
1417
|
+
return result
|
|
1418
|
+
|
|
1419
|
+
leap_binder.set_ground_truth(inner_without_validate, name)
|
|
1420
|
+
|
|
1421
|
+
def inner(*args, **kwargs):
|
|
1422
|
+
if not _call_from_tl_platform:
|
|
1423
|
+
set_current("tensorleap_gt_encoder")
|
|
1424
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
1425
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
1426
|
+
sample_id, preprocess_response = args
|
|
770
1427
|
_validate_input_args(sample_id, preprocess_response)
|
|
771
|
-
|
|
1428
|
+
|
|
1429
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
1430
|
+
|
|
772
1431
|
_validate_result(result)
|
|
773
1432
|
|
|
774
|
-
if _called_from_inside_tl_decorator == 0:
|
|
1433
|
+
if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
|
|
1434
|
+
batch_warning(result, user_function.__name__)
|
|
775
1435
|
result = np.expand_dims(result, axis=0)
|
|
1436
|
+
# Emit integration test event once per test
|
|
1437
|
+
try:
|
|
1438
|
+
emit_integration_event_once(AnalyticsEvent.GT_ENCODER_INTEGRATION_TEST, {
|
|
1439
|
+
'encoder_name': name
|
|
1440
|
+
})
|
|
1441
|
+
except Exception as e:
|
|
1442
|
+
logger.debug(f"Failed to emit gt_encoder integration test event: {e}")
|
|
1443
|
+
if not _call_from_tl_platform:
|
|
1444
|
+
update_env_params_func("tensorleap_gt_encoder", "v")
|
|
776
1445
|
return result
|
|
777
1446
|
|
|
778
1447
|
inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
779
1448
|
|
|
780
|
-
def mapping_inner(
|
|
1449
|
+
def mapping_inner(*args, **kwargs):
|
|
781
1450
|
class TempMapping:
|
|
782
1451
|
pass
|
|
783
1452
|
|
|
@@ -788,11 +1457,11 @@ def tensorleap_gt_encoder(name: str):
|
|
|
788
1457
|
|
|
789
1458
|
mapping_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
790
1459
|
|
|
791
|
-
def final_inner(
|
|
1460
|
+
def final_inner(*args, **kwargs):
|
|
792
1461
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
793
|
-
return mapping_inner(
|
|
1462
|
+
return mapping_inner(*args, **kwargs)
|
|
794
1463
|
else:
|
|
795
|
-
return inner(
|
|
1464
|
+
return inner(*args, **kwargs)
|
|
796
1465
|
|
|
797
1466
|
final_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
798
1467
|
|
|
@@ -802,6 +1471,8 @@ def tensorleap_gt_encoder(name: str):
|
|
|
802
1471
|
|
|
803
1472
|
|
|
804
1473
|
def tensorleap_custom_loss(name: str, connects_to=None):
|
|
1474
|
+
name_to_unique_name = defaultdict(set)
|
|
1475
|
+
|
|
805
1476
|
def decorating_function(user_function: CustomCallableInterface):
|
|
806
1477
|
for loss_handler in leap_binder.setup_container.custom_loss_handlers:
|
|
807
1478
|
if loss_handler.custom_loss_handler_data.name == name:
|
|
@@ -809,35 +1480,29 @@ def tensorleap_custom_loss(name: str, connects_to=None):
|
|
|
809
1480
|
f'Please choose another')
|
|
810
1481
|
|
|
811
1482
|
valid_types = (np.ndarray, SamplePreprocessResponse)
|
|
812
|
-
try:
|
|
813
|
-
import tensorflow as tf
|
|
814
|
-
valid_types = (np.ndarray, SamplePreprocessResponse, tf.Tensor)
|
|
815
|
-
except ImportError:
|
|
816
|
-
pass
|
|
817
1483
|
|
|
818
1484
|
def _validate_input_args(*args, **kwargs):
|
|
819
|
-
|
|
1485
|
+
assert len(args) + len(kwargs) > 0, (
|
|
1486
|
+
f"{user_function.__name__}() validation failed: "
|
|
1487
|
+
f"Expected at least one positional|key-word argument of the allowed types (np.ndarray|SamplePreprocessResponse|). "
|
|
1488
|
+
f"but received none. "
|
|
1489
|
+
f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
|
|
1490
|
+
)
|
|
820
1491
|
for i, arg in enumerate(args):
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
824
|
-
f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
|
|
825
|
-
else:
|
|
826
|
-
assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
827
|
-
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
1492
|
+
assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
|
|
1493
|
+
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
828
1494
|
for _arg_name, arg in kwargs.items():
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
832
|
-
f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
|
|
833
|
-
else:
|
|
834
|
-
assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
835
|
-
f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
1495
|
+
assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
|
|
1496
|
+
f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
836
1497
|
|
|
837
1498
|
def _validate_result(result):
|
|
838
|
-
|
|
839
|
-
|
|
1499
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
1500
|
+
expected_type_name="np.ndarray")
|
|
1501
|
+
assert isinstance(result, np.ndarray), \
|
|
1502
|
+
(f'{user_function.__name__} validation failed: '
|
|
840
1503
|
f'The return type should be a numpy array. Got {type(result)}.')
|
|
1504
|
+
assert result.ndim < 2, (f'{user_function.__name__} validation failed: '
|
|
1505
|
+
f'The return type should be a 1Dim numpy array but got {result.ndim}Dim.')
|
|
841
1506
|
|
|
842
1507
|
@functools.wraps(user_function)
|
|
843
1508
|
def inner_without_validate(*args, **kwargs):
|
|
@@ -863,11 +1528,16 @@ def tensorleap_custom_loss(name: str, connects_to=None):
|
|
|
863
1528
|
_add_mapping_connections(connects_to, arg_names, NodeMappingType.CustomLoss, name)
|
|
864
1529
|
|
|
865
1530
|
def inner(*args, **kwargs):
|
|
1531
|
+
if not _call_from_tl_platform:
|
|
1532
|
+
set_current("tensorleap_custom_loss")
|
|
866
1533
|
_validate_input_args(*args, **kwargs)
|
|
867
1534
|
|
|
868
1535
|
result = inner_without_validate(*args, **kwargs)
|
|
869
1536
|
|
|
870
1537
|
_validate_result(result)
|
|
1538
|
+
if not _call_from_tl_platform:
|
|
1539
|
+
update_env_params_func("tensorleap_custom_loss", "v")
|
|
1540
|
+
|
|
871
1541
|
return result
|
|
872
1542
|
|
|
873
1543
|
def mapping_inner(*args, **kwargs):
|
|
@@ -875,6 +1545,10 @@ def tensorleap_custom_loss(name: str, connects_to=None):
|
|
|
875
1545
|
if 'user_unique_name' in kwargs:
|
|
876
1546
|
user_unique_name = kwargs['user_unique_name']
|
|
877
1547
|
|
|
1548
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
1549
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
1550
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
1551
|
+
|
|
878
1552
|
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
879
1553
|
ordered_connections = list(args) + ordered_connections
|
|
880
1554
|
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
@@ -920,3 +1594,180 @@ def tensorleap_custom_layer(name: str):
|
|
|
920
1594
|
return custom_layer
|
|
921
1595
|
|
|
922
1596
|
return decorating_function
|
|
1597
|
+
|
|
1598
|
+
|
|
1599
|
+
def tensorleap_status_table():
|
|
1600
|
+
import atexit
|
|
1601
|
+
import sys
|
|
1602
|
+
import traceback
|
|
1603
|
+
from typing import Any
|
|
1604
|
+
|
|
1605
|
+
CHECK = "✅"
|
|
1606
|
+
CROSS = "❌"
|
|
1607
|
+
UNKNOWN = "❔"
|
|
1608
|
+
|
|
1609
|
+
code_mapping_failure = [0]
|
|
1610
|
+
|
|
1611
|
+
table = [
|
|
1612
|
+
{"name": "tensorleap_preprocess", "Added to integration": UNKNOWN},
|
|
1613
|
+
{"name": "tensorleap_integration_test", "Added to integration": UNKNOWN},
|
|
1614
|
+
{"name": "tensorleap_input_encoder", "Added to integration": UNKNOWN},
|
|
1615
|
+
{"name": "tensorleap_gt_encoder", "Added to integration": UNKNOWN},
|
|
1616
|
+
{"name": "tensorleap_load_model", "Added to integration": UNKNOWN},
|
|
1617
|
+
{"name": "tensorleap_custom_loss", "Added to integration": UNKNOWN},
|
|
1618
|
+
{"name": "tensorleap_custom_metric (optional)", "Added to integration": UNKNOWN},
|
|
1619
|
+
{"name": "tensorleap_metadata (optional)", "Added to integration": UNKNOWN},
|
|
1620
|
+
{"name": "tensorleap_custom_visualizer (optional)", "Added to integration": UNKNOWN},
|
|
1621
|
+
]
|
|
1622
|
+
|
|
1623
|
+
_finalizer_called = {"done": False}
|
|
1624
|
+
_crashed = {"value": False}
|
|
1625
|
+
_current_func = {"name": None}
|
|
1626
|
+
|
|
1627
|
+
def _link(url: str) -> str:
|
|
1628
|
+
return f"\033{url}\033"
|
|
1629
|
+
|
|
1630
|
+
def _remove_suffix(s: str, suffix: str) -> str:
|
|
1631
|
+
if suffix and s.endswith(suffix):
|
|
1632
|
+
return s[:-len(suffix)]
|
|
1633
|
+
return s
|
|
1634
|
+
|
|
1635
|
+
def _find_row(name: str):
|
|
1636
|
+
for row in table:
|
|
1637
|
+
if _remove_suffix(row["name"], " (optional)") == name:
|
|
1638
|
+
return row
|
|
1639
|
+
return None
|
|
1640
|
+
|
|
1641
|
+
def _set_status(name: str, status_symbol: str):
|
|
1642
|
+
row = _find_row(name)
|
|
1643
|
+
if not row:
|
|
1644
|
+
return
|
|
1645
|
+
|
|
1646
|
+
cur = row["Added to integration"]
|
|
1647
|
+
if status_symbol == UNKNOWN:
|
|
1648
|
+
return
|
|
1649
|
+
if cur == CHECK and status_symbol != CHECK:
|
|
1650
|
+
return
|
|
1651
|
+
|
|
1652
|
+
row["Added to integration"] = status_symbol
|
|
1653
|
+
|
|
1654
|
+
def _mark_unknowns_as_cross():
|
|
1655
|
+
for row in table:
|
|
1656
|
+
if row["Added to integration"] == UNKNOWN:
|
|
1657
|
+
row["Added to integration"] = CROSS
|
|
1658
|
+
|
|
1659
|
+
def _format_default_value(v: Any) -> str:
|
|
1660
|
+
if hasattr(v, "name"):
|
|
1661
|
+
return str(v.name)
|
|
1662
|
+
if isinstance(v, str):
|
|
1663
|
+
return v
|
|
1664
|
+
s = repr(v)
|
|
1665
|
+
return s if len(s) <= 120 else s[:120] + "..."
|
|
1666
|
+
|
|
1667
|
+
def _print_param_default_warnings():
|
|
1668
|
+
data = _get_param_default_warnings()
|
|
1669
|
+
if not data:
|
|
1670
|
+
return
|
|
1671
|
+
|
|
1672
|
+
print("\nWarnings (Default use. It is recommended to set values explicitly):")
|
|
1673
|
+
for param_name in sorted(data.keys()):
|
|
1674
|
+
default_value = data[param_name]["default_value"]
|
|
1675
|
+
funcs = ", ".join(sorted(data[param_name]["funcs"]))
|
|
1676
|
+
dv = _format_default_value(default_value)
|
|
1677
|
+
|
|
1678
|
+
docs_link = data[param_name].get("link_to_docs")
|
|
1679
|
+
docs_part = f" {_link(docs_link)}" if docs_link else ""
|
|
1680
|
+
print(
|
|
1681
|
+
f" ⚠️ Parameter '{param_name}' defaults to {dv} in the following functions: [{funcs}]. "
|
|
1682
|
+
f"For more information, check {docs_part}")
|
|
1683
|
+
print("\nIf this isn’t the intended behaviour, set them explicitly.")
|
|
1684
|
+
|
|
1685
|
+
def _print_table():
|
|
1686
|
+
_print_param_default_warnings()
|
|
1687
|
+
|
|
1688
|
+
if not started_from("leap_integration.py"):
|
|
1689
|
+
return
|
|
1690
|
+
|
|
1691
|
+
ready_mess = "\nAll parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system."
|
|
1692
|
+
not_ready_mess = "\nSome mandatory components have not yet been added to the Integration test. Recommended next interface to add is: "
|
|
1693
|
+
mandatory_ready_mess = "\nAll mandatory parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system or continue to the next optional reccomeded interface,adding: "
|
|
1694
|
+
code_mapping_failure_mes = "Tensorleap_integration_test code flow failed, check raised exception."
|
|
1695
|
+
|
|
1696
|
+
name_width = max(len(row["name"]) for row in table)
|
|
1697
|
+
status_width = max(len(row["Added to integration"]) for row in table)
|
|
1698
|
+
|
|
1699
|
+
header = f"{'Decorator Name'.ljust(name_width)} | {'Added to integration'.ljust(status_width)}"
|
|
1700
|
+
sep = "-" * len(header)
|
|
1701
|
+
|
|
1702
|
+
print("\n" + header)
|
|
1703
|
+
print(sep)
|
|
1704
|
+
|
|
1705
|
+
ready = True
|
|
1706
|
+
next_step = None
|
|
1707
|
+
|
|
1708
|
+
for row in table:
|
|
1709
|
+
print(f"{row['name'].ljust(name_width)} | {row['Added to integration'].ljust(status_width)}")
|
|
1710
|
+
if not _crashed["value"] and ready:
|
|
1711
|
+
if row["Added to integration"] != CHECK:
|
|
1712
|
+
ready = False
|
|
1713
|
+
next_step = row["name"]
|
|
1714
|
+
|
|
1715
|
+
if _crashed["value"]:
|
|
1716
|
+
print(f"\nScript crashed before completing all steps. crashed at function '{_current_func['name']}'.")
|
|
1717
|
+
return
|
|
1718
|
+
|
|
1719
|
+
if code_mapping_failure[0]:
|
|
1720
|
+
print(f"\n{CROSS + code_mapping_failure_mes}.")
|
|
1721
|
+
return
|
|
1722
|
+
|
|
1723
|
+
print(ready_mess) if ready else print(
|
|
1724
|
+
mandatory_ready_mess + next_step
|
|
1725
|
+
) if (next_step and "optional" in next_step) else print(not_ready_mess + (next_step or ""))
|
|
1726
|
+
|
|
1727
|
+
def set_current(name: str):
|
|
1728
|
+
_current_func["name"] = name
|
|
1729
|
+
|
|
1730
|
+
def update_env_params(name: str, status: str = "v"):
|
|
1731
|
+
if name == "code_mapping":
|
|
1732
|
+
code_mapping_failure[0] = 1
|
|
1733
|
+
if status == "v":
|
|
1734
|
+
_set_status(name, CHECK)
|
|
1735
|
+
else:
|
|
1736
|
+
_set_status(name, CROSS)
|
|
1737
|
+
|
|
1738
|
+
def run_on_exit():
|
|
1739
|
+
if _finalizer_called["done"]:
|
|
1740
|
+
return
|
|
1741
|
+
_finalizer_called["done"] = True
|
|
1742
|
+
if not _crashed["value"]:
|
|
1743
|
+
_mark_unknowns_as_cross()
|
|
1744
|
+
|
|
1745
|
+
_print_table()
|
|
1746
|
+
|
|
1747
|
+
def handle_exception(exc_type, exc_value, exc_traceback):
|
|
1748
|
+
_crashed["value"] = True
|
|
1749
|
+
crashed_name = _current_func["name"]
|
|
1750
|
+
if crashed_name:
|
|
1751
|
+
row = _find_row(crashed_name)
|
|
1752
|
+
if row and row["Added to integration"] != CHECK:
|
|
1753
|
+
row["Added to integration"] = CROSS
|
|
1754
|
+
|
|
1755
|
+
traceback.print_exception(exc_type, exc_value, exc_traceback)
|
|
1756
|
+
run_on_exit()
|
|
1757
|
+
|
|
1758
|
+
atexit.register(run_on_exit)
|
|
1759
|
+
sys.excepthook = handle_exception
|
|
1760
|
+
|
|
1761
|
+
return set_current, update_env_params
|
|
1762
|
+
|
|
1763
|
+
|
|
1764
|
+
if not _call_from_tl_platform:
|
|
1765
|
+
set_current, update_env_params_func = tensorleap_status_table()
|
|
1766
|
+
|
|
1767
|
+
|
|
1768
|
+
|
|
1769
|
+
|
|
1770
|
+
|
|
1771
|
+
|
|
1772
|
+
|
|
1773
|
+
|