code-loader 1.0.117__py3-none-any.whl → 1.0.153.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- code_loader/contract/datasetclasses.py +30 -7
- code_loader/contract/mapping.py +11 -0
- code_loader/inner_leap_binder/leapbinder.py +14 -1
- code_loader/inner_leap_binder/leapbinder_decorators.py +1046 -146
- code_loader/leaploader.py +47 -29
- code_loader/leaploaderbase.py +30 -5
- code_loader/mixpanel_tracker.py +106 -10
- code_loader/plot_functions/plot_functions.py +1 -2
- code_loader/utils.py +6 -6
- code_loader-1.0.153.dev4.dist-info/LICENSE +21 -0
- {code_loader-1.0.117.dist-info → code_loader-1.0.153.dev4.dist-info}/METADATA +3 -3
- {code_loader-1.0.117.dist-info → code_loader-1.0.153.dev4.dist-info}/RECORD +14 -13
- {code_loader-1.0.117.dist-info → code_loader-1.0.153.dev4.dist-info}/WHEEL +1 -1
- /code_loader-1.0.117.dist-info/LICENSE → /LICENSE +0 -0
|
@@ -1,25 +1,201 @@
|
|
|
1
1
|
# mypy: ignore-errors
|
|
2
2
|
import os
|
|
3
|
-
|
|
3
|
+
import warnings
|
|
4
|
+
import logging
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from functools import lru_cache
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional, Union, Callable, List, Dict, Set, Any
|
|
9
|
+
from typing import Optional, Union, Callable, List, Dict, get_args, get_origin, DefaultDict
|
|
4
10
|
|
|
5
11
|
import numpy as np
|
|
6
12
|
import numpy.typing as npt
|
|
7
13
|
|
|
14
|
+
# Module logger; used for non-fatal diagnostics such as a failure to clear
# integration events at the start of an integration test.
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
8
16
|
from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
|
|
9
17
|
CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
|
|
10
18
|
VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
|
|
11
|
-
ConfusionMatrixElement, SamplePreprocessResponse, PredictionTypeHandler, InstanceCallableInterface, ElementInstance
|
|
12
|
-
|
|
19
|
+
ConfusionMatrixElement, SamplePreprocessResponse, PredictionTypeHandler, InstanceCallableInterface, ElementInstance, \
|
|
20
|
+
InstanceLengthCallableInterface
|
|
21
|
+
from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType, DataStateType
|
|
13
22
|
from code_loader import leap_binder
|
|
14
23
|
from code_loader.contract.mapping import NodeMapping, NodeMappingType, NodeConnection
|
|
15
24
|
from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
|
|
16
25
|
LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
|
|
17
26
|
from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
|
|
27
|
+
from code_loader.mixpanel_tracker import clear_integration_events, AnalyticsEvent, emit_integration_event_once
|
|
18
28
|
|
|
19
29
|
import inspect
|
|
20
30
|
import functools
|
|
31
|
+
from pathlib import Path
|
|
21
32
|
|
|
22
33
|
# Module-level counter; it is not read in the visible part of this file —
# presumably tracks nesting of Tensorleap decorator calls (TODO confirm).
_called_from_inside_tl_decorator = 0
# True only while tensorleap_integration_test's wrapper is running the user's
# test function (set before the call, reset in a ``finally``).
_called_from_inside_tl_integration_test_decorator = False
# Read once at import time: True when the IS_TENSORLEAP_PLATFORM env var is
# exactly 'true'. Decorators use it to skip local-only validation and
# env-param bookkeeping when running on the Tensorleap platform.
_call_from_tl_platform = os.environ.get('IS_TENSORLEAP_PLATFORM') == 'true'

# ---- warnings store (module-level) ----
# Sentinel default that lets decorators distinguish "argument omitted" from any
# explicitly passed value (including None).
_UNSET = object()
# Stored warning records; _get_stored_warnings returns a copy of this list.
_STORED_WARNINGS: List[Dict[str, Any]] = []
# presumably used to de-duplicate _STORED_WARNINGS entries — no use visible here.
_STORED_WARNING_KEYS: Set[tuple] = set()
# param_name -> set(user_func_name)
_PARAM_DEFAULT_FUNCS: DefaultDict[str, Set[str]] = defaultdict(set)
# param_name -> default_value used (repr-able)
_PARAM_DEFAULT_VALUE: Dict[str, Any] = {}
# param_name -> documentation link explaining how to set the parameter explicitly
_PARAM_DEFAULT_DOCS: Dict[str, str] = {}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_entry_script_path() -> str:
    """Return the absolute path of the script that started the process.

    ``sys.argv[0]`` is preferred. When it is missing or empty, the ``__file__``
    attribute of the ``__main__`` module is used instead. An empty string is
    returned when neither source yields a path.
    """
    import sys

    candidate = sys.argv[0] if sys.argv else ""
    if not candidate:
        import __main__
        candidate = getattr(__main__, "__file__", "") or ""
    if not candidate:
        return ""
    return str(Path(candidate).resolve())
|
|
56
|
+
|
|
57
|
+
def started_from(filename: str) -> bool:
    """Tell whether the entry script's base name is exactly *filename*.

    Returns False when the entry script path cannot be determined.
    """
    entry_path = get_entry_script_path()
    if not entry_path:
        return False
    return Path(entry_path).name == filename
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def store_warning_by_param(
        *,
        param_name: str,
        user_func_name: str,
        default_value: Any,
        link_to_docs: Optional[str] = None,
) -> None:
    """Record that *user_func_name* relied on the default value of *param_name*.

    Every reporting function is accumulated per parameter. The default value is
    stored only the first time a parameter is reported. The documentation link is
    stored only the first time a non-empty link is given. Later calls never
    overwrite either.

    Args:
        param_name: name of the parameter whose default was applied.
        user_func_name: name of the user function / decorator that omitted it.
        default_value: the value substituted for the missing argument.
        link_to_docs: optional URL explaining how to set the parameter explicitly.
    """
    _PARAM_DEFAULT_FUNCS[param_name].add(user_func_name)

    # First report wins: keep the originally recorded default for this parameter.
    _PARAM_DEFAULT_VALUE.setdefault(param_name, default_value)

    if link_to_docs and param_name not in _PARAM_DEFAULT_DOCS:
        _PARAM_DEFAULT_DOCS[param_name] = link_to_docs
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_param_default_warnings() -> Dict[str, Dict[str, Any]]:
    """Snapshot the recorded default-parameter usages.

    Returns a dict keyed by parameter name. Each value holds:

    - the default value that was used;
    - a copy of the set of function names that relied on it;
    - the docs link (``None`` when none was recorded).
    """
    return {
        param: {
            "default_value": _PARAM_DEFAULT_VALUE.get(param, None),
            "funcs": set(func_names),
            "link_to_docs": _PARAM_DEFAULT_DOCS.get(param),
        }
        for param, func_names in _PARAM_DEFAULT_FUNCS.items()
    }
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _get_stored_warnings() -> List[Dict[str, Any]]:
    """Return a shallow copy of the stored warning records."""
    snapshot = _STORED_WARNINGS[:]
    return snapshot
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def validate_args_structure(*args, types_order, func_name, expected_names, **kwargs):
    """Validate the arguments passed to a decorated user function.

    When ``expected_names`` is given, each name is filled from the positional
    arguments first and then from ``kwargs``. Positional arguments beyond
    ``expected_names`` are kept, so the count check rejects them instead of
    silently dropping them. Each argument is then checked with ``isinstance``
    against the matching entry of ``types_order``. ``Union[...]`` entries accept
    any of their members.

    Args:
        *args: positional arguments received by the decorated function.
        types_order: expected type of each argument, in order.
        func_name: name used in error messages.
        expected_names: expected argument names. When empty or None, only
            positional arguments are considered.
        **kwargs: keyword arguments received by the decorated function.

    Raises:
        AssertionError: on a missing argument, a wrong argument count or a type mismatch.
    """
    def _type_to_str(t):
        origin = get_origin(t)
        if origin is Union:
            return " | ".join(_type_to_str(tt) for tt in get_args(t))
        elif hasattr(t, "__name__"):
            return t.__name__
        else:
            return str(t)

    def _arg_label(i, names):
        # Fall back to a positional label when no (or too few) names were given.
        return names[i] if names and i < len(names) else f"arg{i}"

    def _format_types(types, names=None):
        return ", ".join(f"{_arg_label(i, names)}: {_type_to_str(ty)}" for i, ty in enumerate(types))

    if expected_names:
        normalized_args = []
        for i, name in enumerate(expected_names):
            if i < len(args):
                normalized_args.append(args[i])
            elif name in kwargs:
                normalized_args.append(kwargs[name])
            else:
                raise AssertionError(
                    f"{func_name} validation failed: "
                    f"Missing required argument '{name}'. "
                    f"Expected arguments: {expected_names}."
                )
        # Keep surplus positional arguments so the length check below reports them.
        normalized_args.extend(args[len(expected_names):])
    else:
        normalized_args = list(args)

    if len(normalized_args) != len(types_order):
        expected = _format_types(types_order, expected_names)
        got_types = ", ".join(type(arg).__name__ for arg in normalized_args)
        raise AssertionError(
            f"{func_name} validation failed: "
            f"Expected exactly {len(types_order)} arguments ({expected}), "
            f"but got {len(normalized_args)} argument(s) of type(s): ({got_types}). "
            f"Correct usage example: {func_name}({expected})"
        )

    for i, (arg, expected_type) in enumerate(zip(normalized_args, types_order)):
        if get_origin(expected_type) is Union:
            allowed_types = get_args(expected_type)
        else:
            allowed_types = (expected_type,)

        if not isinstance(arg, allowed_types):
            allowed_str = " | ".join(_type_to_str(t) for t in allowed_types)
            raise AssertionError(
                f"{func_name} validation failed: "
                f"Argument '{_arg_label(i, expected_names)}' "
                f"expected type {allowed_str}, but got {type(arg).__name__}. "
                f"Correct usage example: {func_name}({_format_types(types_order, expected_names)})"
            )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def validate_output_structure(result, func_name: str, expected_type_name="np.ndarray", gt_flag=False):
    """Reject return values that cannot be a single ``expected_type_name`` object.

    ``None`` and scalar NaN (Python or NumPy floating) are rejected. When
    ``gt_flag`` is set, the error includes a hint about unlabeled datasets. Any
    tuple is also rejected, because it means several values were returned. Other
    values are not checked here.

    Args:
        result: value returned by the user function.
        func_name: name used in error messages.
        expected_type_name: human-readable name of the expected return type.
        gt_flag: True when validating a ground-truth encoder.

    Raises:
        AssertionError: when ``result`` is None/NaN or a tuple.
    """
    # np.floating covers NumPy scalars such as np.float32, which are not float subclasses.
    if result is None or (isinstance(result, (float, np.floating)) and np.isnan(result)):
        if gt_flag:
            raise AssertionError(
                f"{func_name} validation failed: "
                f"The function returned {result!r}. "
                f"If you are working with an unlabeled dataset and no ground truth is available, "
                f"use 'return np.array([], dtype=np.float32)' instead. "
                f"Otherwise, {func_name} expected a single {expected_type_name} object. "
                f"Make sure the function ends with 'return <{expected_type_name}>'."
            )

        # Report the actual value (None or nan) instead of always claiming None.
        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned {result!r}. "
            f"Expected a single {expected_type_name} object. "
            f"Make sure the function ends with 'return <{expected_type_name}>'."
        )
    if isinstance(result, tuple):
        element_descriptions = [
            f"[{i}] type: {type(r).__name__}"
            for i, r in enumerate(result)
        ]
        element_summary = "\n ".join(element_descriptions)

        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned multiple outputs ({len(result)} values), "
            f"but only a single {expected_type_name} is allowed.\n\n"
            f"Returned elements:\n"
            f" {element_summary}\n\n"
            f"Correct usage example:\n"
            f" def {func_name}(...):\n"
            f" return <{expected_type_name}>\n\n"
            f"If you intended to return multiple values, combine them into a single "
            f"{expected_type_name} (e.g., by concatenation or stacking)."
        )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def batch_warning(result, func_name):
    """Warn when ``result`` already has a leading axis of size 1.

    Tensorleap adds its own batch axis at position 0 to the output of
    *func_name*, so an output whose first dimension is already 1 may end up
    with an extra batch dimension.
    """
    shape = result.shape
    looks_batched = bool(shape) and shape[0] == 1
    if not looks_batched:
        return
    warnings.warn(
        f"{func_name} warning: Tensorleap will add a batch dimension at axis 0 to the output of {func_name}, "
        f"although the detected size of axis 0 is already 1. "
        f"This may lead to an extra batch dimension (e.g., shape (1, 1, ...)). "
        f"Please ensure that the output of '{func_name}' is not already batched "
        f"to avoid computation errors."
    )
|
|
23
199
|
|
|
24
200
|
|
|
25
201
|
def _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type):
|
|
@@ -40,51 +216,234 @@ def _add_mapping_connections(connects_to, arg_names, node_mapping_type, name):
|
|
|
40
216
|
_add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type)
|
|
41
217
|
|
|
42
218
|
|
|
43
|
-
def
|
|
219
|
+
def tensorleap_integration_test():
    """Decorator factory for the user's integration-test function.

    The decorated function is registered as ``leap_binder.integration_test_func``.
    Each call of the returned wrapper does the following:

    1. Off-platform only: validates that it received ``(idx, preprocess)``,
       positionally or by keyword, and that the sample id type matches
       ``preprocess.sample_id_type``.
    2. Clears previously recorded integration events (best effort; failures
       are only logged).
    3. Runs the user function with the given arguments.
    4. Re-runs it in mapping mode (``mapping_runtime_mode_env_var_mame`` set)
       with ``None`` and an empty training ``PreprocessResponse``. Any failure
       is converted into an ``Exception`` naming the offending file and line.
    5. Calls ``leap_binder.check()``.

    The wrapper returns ``None``.
    """
    def decorating_function(integration_test_function: Callable):
        leap_binder.integration_test_func = integration_test_function

        def _validate_input_args(*args, **kwargs):
            # validate_args_structure accepts keyword calls (idx=..., preprocess=...),
            # so read the values from kwargs too instead of unpacking ``args`` only.
            sample_id = args[0] if len(args) > 0 else kwargs.get("idx")
            preprocess_response = args[1] if len(args) > 1 else kwargs.get("preprocess")
            assert type(sample_id) == preprocess_response.sample_id_type, (
                f"tensorleap_integration_test validation failed: "
                f"sample_id type ({type(sample_id).__name__}) does not match the expected "
                f"type ({preprocess_response.sample_id_type}) from the PreprocessResponse."
            )

        def _build_mapping_error(e: Exception) -> Exception:
            # Turn a failure of the mapping-mode run into a readable Exception.
            # The original code did ``raise (f'...')``, which raises a str and fails
            # with "TypeError: exceptions must derive from BaseException".
            import traceback
            innermost_frame = traceback.extract_tb(e.__traceback__)[-1]
            file_name = Path(innermost_frame.filename).name
            line_number = innermost_frame.lineno
            update_env_params_func("code_mapping", "x")
            if isinstance(e, TypeError) and 'is not subscriptable' in str(e):
                return Exception(
                    f'Invalid integration code. File {file_name}, line {line_number}: '
                    f"indexing is supported only on the model's predictions inside the integration test. Please remove this indexing operation usage from the integration test code.")
            return Exception(
                f'Invalid integration code. File {file_name}, line {line_number}: '
                f'Integration test is only allowed to call Tensorleap decorators. '
                f'Ensure any arithmetics, external library use, Python logic is placed within Tensorleap decoders')

        def inner(*args, **kwargs):
            global _called_from_inside_tl_integration_test_decorator

            if not _call_from_tl_platform:
                set_current('tensorleap_integration_test')
                validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
                                        func_name='integration_test', expected_names=["idx", "preprocess"], **kwargs)
                _validate_input_args(*args, **kwargs)

            # Clear integration test events for new test
            try:
                clear_integration_events()
            except Exception as e:
                logger.debug(f"Failed to clear integration events: {e}")

            try:
                _called_from_inside_tl_integration_test_decorator = True
                if not _call_from_tl_platform:
                    # put here because otherwise it will become v only if it finishes all the script
                    update_env_params_func("tensorleap_integration_test", "v")
                integration_test_function(*args, **kwargs)

                try:
                    os.environ[mapping_runtime_mode_env_var_mame] = 'True'
                    integration_test_function(None, PreprocessResponse(state=DataStateType.training, length=0))
                except Exception as e:
                    raise _build_mapping_error(e) from e
                finally:
                    if mapping_runtime_mode_env_var_mame in os.environ:
                        del os.environ[mapping_runtime_mode_env_var_mame]
            finally:
                _called_from_inside_tl_integration_test_decorator = False

            leap_binder.check()

        return inner

    return decorating_function
|
|
63
280
|
|
|
64
281
|
|
|
65
|
-
def
|
|
282
|
+
def _safe_get_item(key):
    """Return the ``NodeMappingType`` member named ``Input<key>``.

    Raises:
        Exception: when no such member exists. Enum name lookup (``Cls[name]``)
            raises ``KeyError`` for unknown names, so that is what is caught here.
            The previous ``except ValueError`` never matched.
    """
    try:
        return NodeMappingType[f'Input{key}']
    except KeyError as e:
        raise Exception(f'Tensorleap currently supports models with no more then 10 inputs') from e
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = _UNSET):
|
|
290
|
+
prediction_types_was_provided = prediction_types is not _UNSET
|
|
291
|
+
|
|
292
|
+
if not prediction_types_was_provided:
|
|
293
|
+
prediction_types = []
|
|
294
|
+
if not _call_from_tl_platform:
|
|
295
|
+
store_warning_by_param(
|
|
296
|
+
param_name="prediction_types",
|
|
297
|
+
user_func_name="tensorleap_load_model",
|
|
298
|
+
default_value=prediction_types,
|
|
299
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
|
|
300
|
+
)
|
|
301
|
+
assert isinstance(prediction_types, list), (
|
|
302
|
+
f"tensorleap_load_model validation failed: "
|
|
303
|
+
f" prediction_types is an optional argument of type List[PredictionTypeHandler]] but got {type(prediction_types).__name__}."
|
|
304
|
+
)
|
|
66
305
|
for i, prediction_type in enumerate(prediction_types):
|
|
306
|
+
assert isinstance(prediction_type, PredictionTypeHandler), (f"tensorleap_load_model validation failed: "
|
|
307
|
+
f" prediction_types at position {i} must be of type PredictionTypeHandler but got {type(prediction_types[i]).__name__}.")
|
|
308
|
+
prediction_type_channel_dim_was_provided = prediction_type.channel_dim != "tl_default_value"
|
|
309
|
+
if not prediction_type_channel_dim_was_provided:
|
|
310
|
+
prediction_types[i].channel_dim = -1
|
|
311
|
+
if not _call_from_tl_platform:
|
|
312
|
+
store_warning_by_param(
|
|
313
|
+
param_name=f"prediction_types[{i}].channel_dim",
|
|
314
|
+
user_func_name="tensorleap_load_model",
|
|
315
|
+
default_value=prediction_types[i].channel_dim,
|
|
316
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
|
|
317
|
+
)
|
|
67
318
|
leap_binder.add_prediction(prediction_type.name, prediction_type.labels, prediction_type.channel_dim, i)
|
|
68
319
|
|
|
69
|
-
def
|
|
320
|
+
def _validate_result(result) -> None:
|
|
321
|
+
valid_types = ["onnxruntime", "keras"]
|
|
322
|
+
err_message = f"tensorleap_load_model validation failed:\nSupported models are Keras and onnxruntime only and non of them was returned."
|
|
323
|
+
validate_output_structure(result, func_name="tensorleap_load_model",
|
|
324
|
+
expected_type_name=[" | ".join(t for t in valid_types)][0])
|
|
325
|
+
try:
|
|
326
|
+
import keras
|
|
327
|
+
except ImportError:
|
|
328
|
+
keras = None
|
|
329
|
+
try:
|
|
330
|
+
import tensorflow as tf
|
|
331
|
+
except ImportError:
|
|
332
|
+
tf = None
|
|
333
|
+
try:
|
|
334
|
+
import onnxruntime
|
|
335
|
+
except ImportError:
|
|
336
|
+
onnxruntime = None
|
|
337
|
+
|
|
338
|
+
if not keras and not onnxruntime:
|
|
339
|
+
raise AssertionError(err_message)
|
|
340
|
+
|
|
341
|
+
is_keras_model = (
|
|
342
|
+
bool(keras and isinstance(result, getattr(keras, "Model", tuple())))
|
|
343
|
+
or bool(tf and isinstance(result, getattr(tf.keras, "Model", tuple())))
|
|
344
|
+
)
|
|
345
|
+
is_onnx_model = bool(onnxruntime and isinstance(result, onnxruntime.InferenceSession))
|
|
346
|
+
|
|
347
|
+
if not any([is_keras_model, is_onnx_model]):
|
|
348
|
+
raise AssertionError(err_message)
|
|
349
|
+
|
|
350
|
+
def decorating_function(load_model_func, prediction_types=prediction_types):
|
|
70
351
|
class TempMapping:
|
|
71
352
|
pass
|
|
72
353
|
|
|
73
|
-
|
|
354
|
+
@lru_cache()
|
|
355
|
+
def inner(*args, **kwargs):
|
|
356
|
+
if not _call_from_tl_platform:
|
|
357
|
+
set_current('tensorleap_load_model')
|
|
358
|
+
validate_args_structure(*args, types_order=[],
|
|
359
|
+
func_name='tensorleap_load_model', expected_names=[], **kwargs)
|
|
360
|
+
|
|
74
361
|
class ModelPlaceholder:
|
|
75
|
-
def __init__(self):
|
|
76
|
-
self.model = load_model_func()
|
|
362
|
+
def __init__(self, prediction_types):
|
|
363
|
+
self.model = load_model_func() # TODO- check why this fails on onnx model
|
|
364
|
+
self.prediction_types = prediction_types
|
|
365
|
+
_validate_result(self.model)
|
|
77
366
|
|
|
78
367
|
# keras interface
|
|
79
368
|
def __call__(self, arg):
|
|
80
369
|
ret = self.model(arg)
|
|
370
|
+
self.validate_declared_prediction_types(ret)
|
|
371
|
+
if isinstance(ret, list):
|
|
372
|
+
return [r.numpy() for r in ret]
|
|
81
373
|
return ret.numpy()
|
|
82
374
|
|
|
375
|
+
def validate_declared_prediction_types(self, ret):
|
|
376
|
+
if not (len(self.prediction_types) == len(ret) if isinstance(ret, list) else 1) and len(
|
|
377
|
+
self.prediction_types) != 0:
|
|
378
|
+
if not _call_from_tl_platform:
|
|
379
|
+
update_env_params_func("tensorleap_load_model", "x")
|
|
380
|
+
raise Exception(
|
|
381
|
+
f"tensorleap_load_model validation failed: number of declared prediction types({len(prediction_types)}) != number of model outputs({len(ret) if isinstance(ret, list) else 1})")
|
|
382
|
+
|
|
383
|
+
def _convert_onnx_inputs_to_correct_type(
|
|
384
|
+
self, float_arrays_inputs: Dict[str, np.ndarray]
|
|
385
|
+
) -> Dict[str, np.ndarray]:
|
|
386
|
+
ONNX_TYPE_TO_NP = {
|
|
387
|
+
"tensor(float)": np.float32,
|
|
388
|
+
"tensor(double)": np.float64,
|
|
389
|
+
"tensor(int64)": np.int64,
|
|
390
|
+
"tensor(int32)": np.int32,
|
|
391
|
+
"tensor(int16)": np.int16,
|
|
392
|
+
"tensor(int8)": np.int8,
|
|
393
|
+
"tensor(uint64)": np.uint64,
|
|
394
|
+
"tensor(uint32)": np.uint32,
|
|
395
|
+
"tensor(uint16)": np.uint16,
|
|
396
|
+
"tensor(uint8)": np.uint8,
|
|
397
|
+
"tensor(bool)": np.bool_,
|
|
398
|
+
}
|
|
399
|
+
|
|
400
|
+
"""
|
|
401
|
+
Cast user-provided NumPy inputs to match the dtypes/shapes
|
|
402
|
+
expected by an ONNX Runtime InferenceSession.
|
|
403
|
+
"""
|
|
404
|
+
coerced = {}
|
|
405
|
+
meta = {i.name: i for i in self.model.get_inputs()}
|
|
406
|
+
|
|
407
|
+
for name, arr in float_arrays_inputs.items():
|
|
408
|
+
if name not in meta:
|
|
409
|
+
# Keep as-is unless extra inputs are disallowed
|
|
410
|
+
coerced[name] = arr
|
|
411
|
+
continue
|
|
412
|
+
|
|
413
|
+
info = meta[name]
|
|
414
|
+
onnx_type = info.type
|
|
415
|
+
want_dtype = ONNX_TYPE_TO_NP.get(onnx_type)
|
|
416
|
+
|
|
417
|
+
if want_dtype is None:
|
|
418
|
+
raise TypeError(f"Unsupported ONNX input type: {onnx_type}")
|
|
419
|
+
|
|
420
|
+
# Cast dtype if needed
|
|
421
|
+
if arr.dtype != want_dtype:
|
|
422
|
+
arr = arr.astype(want_dtype, copy=False)
|
|
423
|
+
|
|
424
|
+
coerced[name] = arr
|
|
425
|
+
|
|
426
|
+
# Verify required inputs are present
|
|
427
|
+
missing = [n for n in meta if n not in coerced]
|
|
428
|
+
if missing:
|
|
429
|
+
raise KeyError(f"Missing required input(s): {sorted(missing)}")
|
|
430
|
+
|
|
431
|
+
return coerced
|
|
432
|
+
|
|
83
433
|
# onnx runtime interface
|
|
84
434
|
def run(self, output_names, input_dict):
|
|
85
|
-
|
|
435
|
+
corrected_type_inputs = self._convert_onnx_inputs_to_correct_type(input_dict)
|
|
436
|
+
ret = self.model.run(output_names, corrected_type_inputs)
|
|
437
|
+
self.validate_declared_prediction_types(ret)
|
|
438
|
+
return ret
|
|
86
439
|
|
|
87
|
-
|
|
440
|
+
def get_inputs(self):
|
|
441
|
+
return self.model.get_inputs()
|
|
442
|
+
|
|
443
|
+
model_placeholder = ModelPlaceholder(prediction_types)
|
|
444
|
+
if not _call_from_tl_platform:
|
|
445
|
+
update_env_params_func("tensorleap_load_model", "v")
|
|
446
|
+
return model_placeholder
|
|
88
447
|
|
|
89
448
|
def mapping_inner():
|
|
90
449
|
class ModelOutputPlaceholder:
|
|
@@ -96,15 +455,20 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
|
|
|
96
455
|
f'Expected key to be an int, got {type(key)} instead.'
|
|
97
456
|
|
|
98
457
|
ret = TempMapping()
|
|
99
|
-
|
|
458
|
+
try:
|
|
459
|
+
ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
|
|
460
|
+
except ValueError as e:
|
|
461
|
+
raise Exception(f'Tensorleap currently supports models with no more then 10 active predictions,'
|
|
462
|
+
f' {key} not supported.')
|
|
100
463
|
return ret
|
|
101
464
|
|
|
102
465
|
class ModelPlaceholder:
|
|
466
|
+
|
|
103
467
|
# keras interface
|
|
104
468
|
def __call__(self, arg):
|
|
105
469
|
if isinstance(arg, list):
|
|
106
470
|
for i, elem in enumerate(arg):
|
|
107
|
-
elem.node_mapping.type =
|
|
471
|
+
elem.node_mapping.type = _safe_get_item(i)
|
|
108
472
|
else:
|
|
109
473
|
arg.node_mapping.type = NodeMappingType.Input0
|
|
110
474
|
|
|
@@ -115,18 +479,38 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
|
|
|
115
479
|
assert output_names is None
|
|
116
480
|
assert isinstance(input_dict, dict), \
|
|
117
481
|
f'Expected input_dict to be a dict, got {type(input_dict)} instead.'
|
|
118
|
-
for i, elem in enumerate(input_dict.
|
|
119
|
-
|
|
482
|
+
for i, (input_key, elem) in enumerate(input_dict.items()):
|
|
483
|
+
if isinstance(input_key, NodeMappingType):
|
|
484
|
+
elem.node_mapping.type = input_key
|
|
485
|
+
else:
|
|
486
|
+
elem.node_mapping.type = _safe_get_item(i)
|
|
120
487
|
|
|
121
488
|
return ModelOutputPlaceholder()
|
|
122
489
|
|
|
490
|
+
def get_inputs(self):
|
|
491
|
+
class FollowIndex:
|
|
492
|
+
def __init__(self, index):
|
|
493
|
+
self.name = _safe_get_item(index)
|
|
494
|
+
|
|
495
|
+
class FollowInputIndex:
|
|
496
|
+
def __init__(self):
|
|
497
|
+
pass
|
|
498
|
+
|
|
499
|
+
def __getitem__(self, index):
|
|
500
|
+
assert isinstance(index, int), \
|
|
501
|
+
f'Expected key to be an int, got {type(index)} instead.'
|
|
502
|
+
|
|
503
|
+
return FollowIndex(index)
|
|
504
|
+
|
|
505
|
+
return FollowInputIndex()
|
|
506
|
+
|
|
123
507
|
return ModelPlaceholder()
|
|
124
508
|
|
|
125
|
-
def final_inner():
|
|
509
|
+
def final_inner(*args, **kwargs):
|
|
126
510
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
127
511
|
return mapping_inner()
|
|
128
512
|
else:
|
|
129
|
-
return inner()
|
|
513
|
+
return inner(*args, **kwargs)
|
|
130
514
|
|
|
131
515
|
return final_inner
|
|
132
516
|
|
|
@@ -134,83 +518,186 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
|
|
|
134
518
|
|
|
135
519
|
|
|
136
520
|
def tensorleap_custom_metric(name: str,
|
|
137
|
-
direction: Union[MetricDirection, Dict[str, MetricDirection]] =
|
|
521
|
+
direction: Union[MetricDirection, Dict[str, MetricDirection]] = _UNSET,
|
|
138
522
|
compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
|
|
139
523
|
connects_to=None):
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
524
|
+
name_to_unique_name = defaultdict(set)
|
|
525
|
+
|
|
526
|
+
def decorating_function(
|
|
527
|
+
user_function: Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs,
|
|
528
|
+
ConfusionMatrixCallableInterfaceMultiArgs]):
|
|
529
|
+
nonlocal direction
|
|
530
|
+
|
|
531
|
+
direction_was_provided = direction is not _UNSET
|
|
532
|
+
|
|
533
|
+
if not direction_was_provided:
|
|
534
|
+
direction = MetricDirection.Downward
|
|
535
|
+
if not _call_from_tl_platform:
|
|
536
|
+
store_warning_by_param(
|
|
537
|
+
param_name="direction",
|
|
538
|
+
user_func_name=user_function.__name__,
|
|
539
|
+
default_value=direction,
|
|
540
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/custom-metrics"
|
|
541
|
+
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
def _validate_decorators_signature():
|
|
545
|
+
err_message = f"{user_function.__name__} validation failed.\n"
|
|
546
|
+
if not isinstance(name, str):
|
|
547
|
+
raise TypeError(err_message + f"`name` must be a string, got type {type(name).__name__}.")
|
|
548
|
+
valid_directions = {MetricDirection.Upward, MetricDirection.Downward}
|
|
549
|
+
if isinstance(direction, MetricDirection):
|
|
550
|
+
if direction not in valid_directions:
|
|
551
|
+
raise ValueError(
|
|
552
|
+
err_message +
|
|
553
|
+
f"Invalid MetricDirection: {direction}. Must be one of {valid_directions}, "
|
|
554
|
+
f"got type {type(direction).__name__}."
|
|
555
|
+
)
|
|
556
|
+
elif isinstance(direction, dict):
|
|
557
|
+
if not all(isinstance(k, str) for k in direction.keys()):
|
|
558
|
+
invalid_keys = {k: type(k).__name__ for k in direction.keys() if not isinstance(k, str)}
|
|
559
|
+
raise TypeError(
|
|
560
|
+
err_message +
|
|
561
|
+
f"All keys in `direction` must be strings, got invalid key types: {invalid_keys}."
|
|
562
|
+
)
|
|
563
|
+
for k, v in direction.items():
|
|
564
|
+
if v not in valid_directions:
|
|
565
|
+
raise ValueError(
|
|
566
|
+
err_message +
|
|
567
|
+
f"Invalid direction for key '{k}': {v}. Must be one of {valid_directions}, "
|
|
568
|
+
f"got type {type(v).__name__}."
|
|
569
|
+
)
|
|
570
|
+
else:
|
|
571
|
+
raise TypeError(
|
|
572
|
+
err_message +
|
|
573
|
+
f"`direction` must be a MetricDirection or a Dict[str, MetricDirection], "
|
|
574
|
+
f"got type {type(direction).__name__}."
|
|
575
|
+
)
|
|
576
|
+
if compute_insights is not None:
|
|
577
|
+
if not isinstance(compute_insights, (bool, dict)):
|
|
578
|
+
raise TypeError(
|
|
579
|
+
err_message +
|
|
580
|
+
f"`compute_insights` must be a bool or a Dict[str, bool], "
|
|
581
|
+
f"got type {type(compute_insights).__name__}."
|
|
582
|
+
)
|
|
583
|
+
if isinstance(compute_insights, dict):
|
|
584
|
+
if not all(isinstance(k, str) for k in compute_insights.keys()):
|
|
585
|
+
invalid_keys = {k: type(k).__name__ for k in compute_insights.keys() if not isinstance(k, str)}
|
|
586
|
+
raise TypeError(
|
|
587
|
+
err_message +
|
|
588
|
+
f"All keys in `compute_insights` must be strings, got invalid key types: {invalid_keys}."
|
|
589
|
+
)
|
|
590
|
+
for k, v in compute_insights.items():
|
|
591
|
+
if not isinstance(v, bool):
|
|
592
|
+
raise TypeError(
|
|
593
|
+
err_message +
|
|
594
|
+
f"Invalid type for compute_insights['{k}']: expected bool, got type {type(v).__name__}."
|
|
595
|
+
)
|
|
596
|
+
if connects_to is not None:
|
|
597
|
+
valid_types = (str, list, tuple, set)
|
|
598
|
+
if not isinstance(connects_to, valid_types):
|
|
599
|
+
raise TypeError(
|
|
600
|
+
err_message +
|
|
601
|
+
f"`connects_to` must be one of {valid_types}, got type {type(connects_to).__name__}."
|
|
602
|
+
)
|
|
603
|
+
if isinstance(connects_to, (list, tuple, set)):
|
|
604
|
+
invalid_elems = [f"{type(e).__name__}" for e in connects_to if not isinstance(e, str)]
|
|
605
|
+
if invalid_elems:
|
|
606
|
+
raise TypeError(
|
|
607
|
+
err_message +
|
|
608
|
+
f"All elements in `connects_to` must be strings, "
|
|
609
|
+
f"but found element types: {invalid_elems}."
|
|
610
|
+
)
|
|
611
|
+
|
|
612
|
+
_validate_decorators_signature()
|
|
613
|
+
|
|
143
614
|
for metric_handler in leap_binder.setup_container.metrics:
|
|
144
615
|
if metric_handler.metric_handler_data.name == name:
|
|
145
616
|
raise Exception(f'Metric with name {name} already exists. '
|
|
146
617
|
f'Please choose another')
|
|
147
618
|
|
|
148
619
|
def _validate_input_args(*args, **kwargs) -> None:
|
|
620
|
+
assert len(args) + len(kwargs) > 0, (
|
|
621
|
+
f"{user_function.__name__}() validation failed: "
|
|
622
|
+
f"Expected at least one positional|key-word argument of type np.ndarray, "
|
|
623
|
+
f"but received none. "
|
|
624
|
+
f"Correct usage example: tensorleap_custom_metric(input_array: np.ndarray, ...)"
|
|
625
|
+
)
|
|
149
626
|
for i, arg in enumerate(args):
|
|
150
627
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
151
|
-
f'
|
|
628
|
+
f'{user_function.__name__}() validation failed: '
|
|
152
629
|
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
153
630
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
154
631
|
assert arg.shape[0] == leap_binder.batch_size_to_validate, \
|
|
155
|
-
(f'
|
|
632
|
+
(f'{user_function.__name__}() validation failed: Argument #{i} '
|
|
156
633
|
f'first dim should be as the batch size. Got {arg.shape[0]} '
|
|
157
634
|
f'instead of {leap_binder.batch_size_to_validate}')
|
|
158
635
|
|
|
159
636
|
for _arg_name, arg in kwargs.items():
|
|
160
637
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
161
|
-
f'
|
|
638
|
+
f'{user_function.__name__}() validation failed: '
|
|
162
639
|
f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
163
640
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
164
641
|
assert arg.shape[0] == leap_binder.batch_size_to_validate, \
|
|
165
|
-
(f'
|
|
642
|
+
(f'{user_function.__name__}() validation failed: Argument {_arg_name} '
|
|
166
643
|
f'first dim should be as the batch size. Got {arg.shape[0]} '
|
|
167
644
|
f'instead of {leap_binder.batch_size_to_validate}')
|
|
168
645
|
|
|
169
646
|
def _validate_result(result) -> None:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
647
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
648
|
+
expected_type_name="List[float | int | None | List[ConfusionMatrixElement] ] | NDArray[np.float32] or dictonary with one of these types as its values types")
|
|
649
|
+
supported_types_message = (f'{user_function.__name__}() validation failed: '
|
|
650
|
+
f'{user_function.__name__}() has returned unsupported type.\nSupported types are List[float|int|None], '
|
|
651
|
+
f'List[List[ConfusionMatrixElement]], NDArray[np.float32] or dictonary with one of these types as its values types. ')
|
|
173
652
|
|
|
174
|
-
def _validate_single_metric(single_metric_result):
|
|
653
|
+
def _validate_single_metric(single_metric_result, key=None):
|
|
175
654
|
if isinstance(single_metric_result, list):
|
|
176
655
|
if isinstance(single_metric_result[0], list):
|
|
177
|
-
assert isinstance(single_metric_result[0]
|
|
178
|
-
f
|
|
656
|
+
assert all(isinstance(cm, ConfusionMatrixElement) for cm in single_metric_result[0]), (
|
|
657
|
+
f"{supported_types_message} "
|
|
658
|
+
f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
|
|
659
|
+
f"List[List[{', '.join(type(cm).__name__ for cm in single_metric_result[0])}]]."
|
|
660
|
+
)
|
|
661
|
+
|
|
179
662
|
else:
|
|
180
|
-
assert isinstance(single_metric_result
|
|
181
|
-
|
|
182
|
-
|
|
663
|
+
assert all(isinstance(v, (float, int, type(None), np.float32)) for v in single_metric_result), (
|
|
664
|
+
f"{supported_types_message}\n"
|
|
665
|
+
f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
|
|
666
|
+
f"List[{', '.join(type(v).__name__ for v in single_metric_result)}]."
|
|
667
|
+
)
|
|
183
668
|
else:
|
|
184
669
|
assert isinstance(single_metric_result,
|
|
185
|
-
np.ndarray), f'{supported_types_message}
|
|
186
|
-
assert len(single_metric_result.shape) == 1, (f'
|
|
670
|
+
np.ndarray), f'{supported_types_message}\nGot {type(single_metric_result)}.'
|
|
671
|
+
assert len(single_metric_result.shape) == 1, (f'{user_function.__name__}() validation failed: '
|
|
187
672
|
f'The return shape should be 1D. Got {len(single_metric_result.shape)}D.')
|
|
188
673
|
|
|
189
674
|
if leap_binder.batch_size_to_validate:
|
|
190
675
|
assert len(single_metric_result) == leap_binder.batch_size_to_validate, \
|
|
191
|
-
f'
|
|
676
|
+
f'{user_function.__name__}() validation failed: The return len {f"of srt{key} value" if key is not None else ""} should be as the batch size.'
|
|
192
677
|
|
|
193
678
|
if isinstance(result, dict):
|
|
194
679
|
for key, value in result.items():
|
|
680
|
+
_validate_single_metric(value, key)
|
|
681
|
+
|
|
195
682
|
assert isinstance(key, str), \
|
|
196
|
-
(f'
|
|
683
|
+
(f'{user_function.__name__}() validation failed: '
|
|
197
684
|
f'Keys in the return dict should be of type str. Got {type(key)}.')
|
|
198
685
|
_validate_single_metric(value)
|
|
199
686
|
|
|
200
687
|
if isinstance(direction, dict):
|
|
201
688
|
for direction_key in direction:
|
|
202
689
|
assert direction_key in result, \
|
|
203
|
-
(f'
|
|
690
|
+
(f'{user_function.__name__}() validation failed: '
|
|
204
691
|
f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
|
|
205
692
|
|
|
206
693
|
if compute_insights is not None:
|
|
207
694
|
assert isinstance(compute_insights, dict), \
|
|
208
|
-
(f'
|
|
695
|
+
(f'{user_function.__name__}() validation failed: '
|
|
209
696
|
f'compute_insights should be dict if using the dict results. Got {type(compute_insights)}.')
|
|
210
697
|
|
|
211
698
|
for ci_key in compute_insights:
|
|
212
699
|
assert ci_key in result, \
|
|
213
|
-
(f'
|
|
700
|
+
(f'{user_function.__name__}() validation failed: '
|
|
214
701
|
f'Keys in the compute_insights mapping should be part of result keys. Got key {ci_key}.')
|
|
215
702
|
|
|
216
703
|
else:
|
|
@@ -218,7 +705,7 @@ def tensorleap_custom_metric(name: str,
|
|
|
218
705
|
|
|
219
706
|
if compute_insights is not None:
|
|
220
707
|
assert isinstance(compute_insights, bool), \
|
|
221
|
-
(f'
|
|
708
|
+
(f'{user_function.__name__}() validation failed: '
|
|
222
709
|
f'compute_insights should be boolean. Got {type(compute_insights)}.')
|
|
223
710
|
|
|
224
711
|
@functools.wraps(user_function)
|
|
@@ -245,11 +732,15 @@ def tensorleap_custom_metric(name: str,
|
|
|
245
732
|
_add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
|
|
246
733
|
|
|
247
734
|
def inner(*args, **kwargs):
|
|
735
|
+
if not _call_from_tl_platform:
|
|
736
|
+
set_current('tensorleap_custom_metric')
|
|
248
737
|
_validate_input_args(*args, **kwargs)
|
|
249
738
|
|
|
250
739
|
result = inner_without_validate(*args, **kwargs)
|
|
251
740
|
|
|
252
741
|
_validate_result(result)
|
|
742
|
+
if not _call_from_tl_platform:
|
|
743
|
+
update_env_params_func("tensorleap_custom_metric", "v")
|
|
253
744
|
return result
|
|
254
745
|
|
|
255
746
|
def mapping_inner(*args, **kwargs):
|
|
@@ -259,6 +750,11 @@ def tensorleap_custom_metric(name: str,
|
|
|
259
750
|
|
|
260
751
|
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
261
752
|
ordered_connections = list(args) + ordered_connections
|
|
753
|
+
|
|
754
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
755
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
756
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
757
|
+
|
|
262
758
|
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
263
759
|
mapping_inner.name, NodeMappingType.Metric)
|
|
264
760
|
|
|
@@ -281,29 +777,40 @@ def tensorleap_custom_metric(name: str,
|
|
|
281
777
|
def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
282
778
|
heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
|
|
283
779
|
connects_to=None):
|
|
780
|
+
name_to_unique_name = defaultdict(set)
|
|
781
|
+
|
|
284
782
|
def decorating_function(user_function: VisualizerCallableInterface):
|
|
783
|
+
assert isinstance(visualizer_type, LeapDataType), (f"{user_function.__name__} validation failed: "
|
|
784
|
+
f"visualizer_type should be of type {LeapDataType.__name__} but got {type(visualizer_type)}"
|
|
785
|
+
)
|
|
285
786
|
for viz_handler in leap_binder.setup_container.visualizers:
|
|
286
787
|
if viz_handler.visualizer_handler_data.name == name:
|
|
287
788
|
raise Exception(f'Visualizer with name {name} already exists. '
|
|
288
789
|
f'Please choose another')
|
|
289
790
|
|
|
290
791
|
def _validate_input_args(*args, **kwargs):
|
|
792
|
+
assert len(args) + len(kwargs) > 0, (
|
|
793
|
+
f"{user_function.__name__}() validation failed: "
|
|
794
|
+
f"Expected at least one positional|key-word argument of type np.ndarray, "
|
|
795
|
+
f"but received none. "
|
|
796
|
+
f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
|
|
797
|
+
)
|
|
291
798
|
for i, arg in enumerate(args):
|
|
292
799
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
293
|
-
f'
|
|
800
|
+
f'{user_function.__name__}() validation failed: '
|
|
294
801
|
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
295
802
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
296
803
|
assert arg.shape[0] != leap_binder.batch_size_to_validate, \
|
|
297
|
-
(f'
|
|
804
|
+
(f'{user_function.__name__}() validation failed: '
|
|
298
805
|
f'Argument #{i} should be without batch dimension. ')
|
|
299
806
|
|
|
300
807
|
for _arg_name, arg in kwargs.items():
|
|
301
808
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
302
|
-
f'
|
|
809
|
+
f'{user_function.__name__}() validation failed: '
|
|
303
810
|
f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
304
811
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
305
812
|
assert arg.shape[0] != leap_binder.batch_size_to_validate, \
|
|
306
|
-
(f'
|
|
813
|
+
(f'{user_function.__name__}() validation failed: Argument {_arg_name} '
|
|
307
814
|
f'should be without batch dimension. ')
|
|
308
815
|
|
|
309
816
|
def _validate_result(result):
|
|
@@ -317,8 +824,11 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
|
317
824
|
LeapDataType.ImageWithBBox: LeapImageWithBBox,
|
|
318
825
|
LeapDataType.ImageWithHeatmap: LeapImageWithHeatmap
|
|
319
826
|
}
|
|
827
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
828
|
+
expected_type_name=result_type_map[visualizer_type])
|
|
829
|
+
|
|
320
830
|
assert isinstance(result, result_type_map[visualizer_type]), \
|
|
321
|
-
(f'
|
|
831
|
+
(f'{user_function.__name__}() validation failed: '
|
|
322
832
|
f'The return type should be {result_type_map[visualizer_type]}. Got {type(result)}.')
|
|
323
833
|
|
|
324
834
|
@functools.wraps(user_function)
|
|
@@ -345,11 +855,15 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
|
345
855
|
_add_mapping_connections(connects_to, arg_names, NodeMappingType.Visualizer, name)
|
|
346
856
|
|
|
347
857
|
def inner(*args, **kwargs):
|
|
858
|
+
if not _call_from_tl_platform:
|
|
859
|
+
set_current('tensorleap_custom_visualizer')
|
|
348
860
|
_validate_input_args(*args, **kwargs)
|
|
349
861
|
|
|
350
862
|
result = inner_without_validate(*args, **kwargs)
|
|
351
863
|
|
|
352
864
|
_validate_result(result)
|
|
865
|
+
if not _call_from_tl_platform:
|
|
866
|
+
update_env_params_func("tensorleap_custom_visualizer", "v")
|
|
353
867
|
return result
|
|
354
868
|
|
|
355
869
|
def mapping_inner(*args, **kwargs):
|
|
@@ -357,6 +871,10 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
|
357
871
|
if 'user_unique_name' in kwargs:
|
|
358
872
|
user_unique_name = kwargs['user_unique_name']
|
|
359
873
|
|
|
874
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
875
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
876
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
877
|
+
|
|
360
878
|
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
361
879
|
ordered_connections = list(args) + ordered_connections
|
|
362
880
|
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
@@ -387,30 +905,26 @@ def tensorleap_metadata(
|
|
|
387
905
|
f'Please choose another')
|
|
388
906
|
|
|
389
907
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
390
|
-
assert isinstance(sample_id, (int, str)), \
|
|
391
|
-
(f'tensorleap_metadata validation failed: '
|
|
392
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
393
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
394
|
-
(f'tensorleap_metadata validation failed: '
|
|
395
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
396
908
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
397
|
-
(f'
|
|
909
|
+
(f'{user_function.__name__}() validation failed: '
|
|
398
910
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
399
911
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
400
912
|
|
|
401
913
|
def _validate_result(result):
|
|
402
914
|
supported_result_types = (type(None), int, str, bool, float, dict, np.floating,
|
|
403
915
|
np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
|
|
916
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
917
|
+
expected_type_name=supported_result_types)
|
|
404
918
|
assert isinstance(result, supported_result_types), \
|
|
405
|
-
(f'
|
|
919
|
+
(f'{user_function.__name__}() validation failed: '
|
|
406
920
|
f'Unsupported return type. Got {type(result)}. should be any of {str(supported_result_types)}')
|
|
407
921
|
if isinstance(result, dict):
|
|
408
922
|
for key, value in result.items():
|
|
409
923
|
assert isinstance(key, str), \
|
|
410
|
-
(f'
|
|
924
|
+
(f'{user_function.__name__}() validation failed: '
|
|
411
925
|
f'Keys in the return dict should be of type str. Got {type(key)}.')
|
|
412
926
|
assert isinstance(value, supported_result_types), \
|
|
413
|
-
(f'
|
|
927
|
+
(f'{user_function.__name__}() validation failed: '
|
|
414
928
|
f'Values in the return dict should be of type {str(supported_result_types)}. Got {type(value)}.')
|
|
415
929
|
|
|
416
930
|
def inner_without_validate(sample_id, preprocess_response):
|
|
@@ -427,6 +941,60 @@ def tensorleap_metadata(
|
|
|
427
941
|
|
|
428
942
|
leap_binder.set_metadata(inner_without_validate, name, metadata_type)
|
|
429
943
|
|
|
944
|
+
def inner(*args, **kwargs):
|
|
945
|
+
if not _call_from_tl_platform:
|
|
946
|
+
set_current('tensorleap_metadata')
|
|
947
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
948
|
+
return None
|
|
949
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
950
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
951
|
+
sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
|
|
952
|
+
_validate_input_args(sample_id, preprocess_response)
|
|
953
|
+
|
|
954
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
955
|
+
|
|
956
|
+
_validate_result(result)
|
|
957
|
+
if not _call_from_tl_platform:
|
|
958
|
+
update_env_params_func("tensorleap_metadata", "v")
|
|
959
|
+
return result
|
|
960
|
+
|
|
961
|
+
return inner
|
|
962
|
+
|
|
963
|
+
return decorating_function
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
def tensorleap_custom_latent_space():
|
|
967
|
+
def decorating_function(user_function: SectionCallableInterface):
|
|
968
|
+
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
969
|
+
assert isinstance(sample_id, (int, str)), \
|
|
970
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
971
|
+
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
972
|
+
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
973
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
974
|
+
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
975
|
+
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
976
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
977
|
+
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
978
|
+
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
979
|
+
|
|
980
|
+
def _validate_result(result):
|
|
981
|
+
assert isinstance(result, np.ndarray), \
|
|
982
|
+
(f'tensorleap_custom_loss validation failed: '
|
|
983
|
+
f'The return type should be a numpy array. Got {type(result)}.')
|
|
984
|
+
|
|
985
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
986
|
+
global _called_from_inside_tl_decorator
|
|
987
|
+
_called_from_inside_tl_decorator += 1
|
|
988
|
+
|
|
989
|
+
try:
|
|
990
|
+
result = user_function(sample_id, preprocess_response)
|
|
991
|
+
finally:
|
|
992
|
+
_called_from_inside_tl_decorator -= 1
|
|
993
|
+
|
|
994
|
+
return result
|
|
995
|
+
|
|
996
|
+
leap_binder.set_custom_latent_space(inner_without_validate)
|
|
997
|
+
|
|
430
998
|
def inner(sample_id, preprocess_response):
|
|
431
999
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
432
1000
|
return None
|
|
@@ -448,30 +1016,45 @@ def tensorleap_preprocess():
|
|
|
448
1016
|
leap_binder.set_preprocess(user_function)
|
|
449
1017
|
|
|
450
1018
|
def _validate_input_args(*args, **kwargs):
|
|
451
|
-
assert len(args)
|
|
452
|
-
(f'
|
|
1019
|
+
assert len(args) + len(kwargs) == 0, \
|
|
1020
|
+
(f'{user_function.__name__}() validation failed: '
|
|
453
1021
|
f'The function should not take any arguments. Got {args} and {kwargs}.')
|
|
454
1022
|
|
|
455
1023
|
def _validate_result(result):
|
|
456
|
-
assert isinstance(result, list),
|
|
457
|
-
(
|
|
458
|
-
|
|
1024
|
+
assert isinstance(result, list), (
|
|
1025
|
+
f"{user_function.__name__}() validation failed: expected return type list[{PreprocessResponse.__name__}]"
|
|
1026
|
+
f"(e.g., [PreprocessResponse1, PreprocessResponse2, ...]), but returned type is {type(result).__name__}."
|
|
1027
|
+
if not isinstance(result, tuple)
|
|
1028
|
+
else f"{user_function.__name__}() validation failed: expected to return a single list[{PreprocessResponse.__name__}] object, "
|
|
1029
|
+
f"but returned {len(result)} objects instead."
|
|
1030
|
+
)
|
|
459
1031
|
for i, response in enumerate(result):
|
|
460
1032
|
assert isinstance(response, PreprocessResponse), \
|
|
461
|
-
(f'
|
|
1033
|
+
(f'{user_function.__name__}() validation failed: '
|
|
462
1034
|
f'Element #{i} in the return list should be a PreprocessResponse. Got {type(response)}.')
|
|
463
1035
|
assert len(set(result)) == len(result), \
|
|
464
|
-
(f'
|
|
1036
|
+
(f'{user_function.__name__}() validation failed: '
|
|
465
1037
|
f'The return list should not contain duplicate PreprocessResponse objects.')
|
|
466
1038
|
|
|
467
1039
|
def inner(*args, **kwargs):
|
|
1040
|
+
if not _call_from_tl_platform:
|
|
1041
|
+
set_current('tensorleap_metadata')
|
|
468
1042
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
469
1043
|
return [None, None, None, None]
|
|
470
1044
|
|
|
471
1045
|
_validate_input_args(*args, **kwargs)
|
|
472
|
-
|
|
473
1046
|
result = user_function()
|
|
474
1047
|
_validate_result(result)
|
|
1048
|
+
|
|
1049
|
+
# Emit integration test event once per test
|
|
1050
|
+
try:
|
|
1051
|
+
emit_integration_event_once(AnalyticsEvent.PREPROCESS_INTEGRATION_TEST, {
|
|
1052
|
+
'preprocess_responses_count': len(result)
|
|
1053
|
+
})
|
|
1054
|
+
except Exception as e:
|
|
1055
|
+
logger.debug(f"Failed to emit preprocess integration test event: {e}")
|
|
1056
|
+
if not _call_from_tl_platform:
|
|
1057
|
+
update_env_params_func("tensorleap_preprocess", "v")
|
|
475
1058
|
return result
|
|
476
1059
|
|
|
477
1060
|
return inner
|
|
@@ -480,37 +1063,55 @@ def tensorleap_preprocess():
|
|
|
480
1063
|
|
|
481
1064
|
|
|
482
1065
|
def tensorleap_element_instance_preprocess(
|
|
483
|
-
|
|
1066
|
+
instance_length_encoder: InstanceLengthCallableInterface, instance_mask_encoder: InstanceCallableInterface):
|
|
484
1067
|
def decorating_function(user_function: Callable[[], List[PreprocessResponse]]):
|
|
485
1068
|
def user_function_instance() -> List[PreprocessResponse]:
|
|
486
1069
|
result = user_function()
|
|
487
1070
|
for preprocess_response in result:
|
|
488
1071
|
sample_ids_to_instance_mappings = {}
|
|
489
1072
|
instance_to_sample_ids_mappings = {}
|
|
490
|
-
instance_ids_to_names = {}
|
|
491
1073
|
all_sample_ids = preprocess_response.sample_ids.copy()
|
|
492
1074
|
for sample_id in preprocess_response.sample_ids:
|
|
493
|
-
|
|
494
|
-
instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(
|
|
1075
|
+
instances_length = instance_length_encoder(sample_id, preprocess_response)
|
|
1076
|
+
instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(instances_length)]
|
|
495
1077
|
sample_ids_to_instance_mappings[sample_id] = instances_ids
|
|
496
1078
|
instance_to_sample_ids_mappings[sample_id] = sample_id
|
|
497
|
-
|
|
498
|
-
instance_ids_to_names[sample_id] = 'none'
|
|
499
|
-
for instance_id, instance_name in zip(instances_ids, instance_names):
|
|
1079
|
+
for instance_id in instances_ids:
|
|
500
1080
|
instance_to_sample_ids_mappings[instance_id] = sample_id
|
|
501
|
-
instance_ids_to_names[instance_id] = instance_name
|
|
502
1081
|
all_sample_ids.extend(instances_ids)
|
|
1082
|
+
preprocess_response.length = len(all_sample_ids)
|
|
503
1083
|
preprocess_response.sample_ids_to_instance_mappings = sample_ids_to_instance_mappings
|
|
504
1084
|
preprocess_response.instance_to_sample_ids_mappings = instance_to_sample_ids_mappings
|
|
505
|
-
preprocess_response.instance_ids_to_names = instance_ids_to_names
|
|
506
1085
|
preprocess_response.sample_ids = all_sample_ids
|
|
507
1086
|
return result
|
|
508
1087
|
|
|
1088
|
+
def extract_extra_instance_metadata():
|
|
1089
|
+
result = user_function()
|
|
1090
|
+
for preprocess_response in result:
|
|
1091
|
+
for sample_id in preprocess_response.sample_ids:
|
|
1092
|
+
instances_length = instance_length_encoder(sample_id, preprocess_response)
|
|
1093
|
+
if instances_length > 0:
|
|
1094
|
+
element_instance = instance_mask_encoder(sample_id, preprocess_response, 0)
|
|
1095
|
+
instance_metadata = element_instance.instance_metadata
|
|
1096
|
+
if instance_metadata is None:
|
|
1097
|
+
return {}
|
|
1098
|
+
return instance_metadata
|
|
1099
|
+
return {}
|
|
1100
|
+
|
|
1101
|
+
|
|
509
1102
|
def builtin_instance_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
|
|
510
1103
|
return {'is_instance': '0', 'original_sample_id': idx, 'instance_name': 'none'}
|
|
511
1104
|
|
|
1105
|
+
def builtin_instance_extra_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
|
|
1106
|
+
instance_metadata = extract_extra_instance_metadata()
|
|
1107
|
+
for key, value in instance_metadata.items():
|
|
1108
|
+
instance_metadata[key] = 'unset'
|
|
1109
|
+
return instance_metadata
|
|
1110
|
+
|
|
512
1111
|
leap_binder.set_preprocess(user_function_instance)
|
|
513
1112
|
leap_binder.set_metadata(builtin_instance_metadata, "builtin_instance_metadata")
|
|
1113
|
+
leap_binder.set_metadata(builtin_instance_extra_metadata, "builtin_instance_extra_metadata")
|
|
1114
|
+
|
|
514
1115
|
|
|
515
1116
|
def _validate_input_args(*args, **kwargs):
|
|
516
1117
|
assert len(args) == 0 and len(kwargs) == 0, \
|
|
@@ -537,6 +1138,8 @@ def tensorleap_element_instance_preprocess(
|
|
|
537
1138
|
|
|
538
1139
|
result = user_function_instance()
|
|
539
1140
|
_validate_result(result)
|
|
1141
|
+
if not _call_from_tl_platform:
|
|
1142
|
+
update_env_params_func("tensorleap_preprocess", "v")
|
|
540
1143
|
return result
|
|
541
1144
|
|
|
542
1145
|
return inner
|
|
@@ -571,7 +1174,7 @@ def tensorleap_unlabeled_preprocess():
|
|
|
571
1174
|
|
|
572
1175
|
def tensorleap_instances_masks_encoder(name: str):
|
|
573
1176
|
def decorating_function(user_function: InstanceCallableInterface):
|
|
574
|
-
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
|
|
1177
|
+
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse, instance_id: int):
|
|
575
1178
|
assert isinstance(sample_id, str), \
|
|
576
1179
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
577
1180
|
f'Argument sample_id should be str. Got {type(sample_id)}.')
|
|
@@ -582,18 +1185,21 @@ def tensorleap_instances_masks_encoder(name: str):
|
|
|
582
1185
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
583
1186
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
584
1187
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
1188
|
+
assert isinstance(instance_id, int), \
|
|
1189
|
+
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
1190
|
+
f'Argument instance_id should be int. Got {type(instance_id)}.')
|
|
585
1191
|
|
|
586
1192
|
def _validate_result(result):
|
|
587
|
-
assert isinstance(result,
|
|
1193
|
+
assert isinstance(result, ElementInstance) or (result is None), \
|
|
588
1194
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
589
|
-
f'Unsupported return type. Should be a
|
|
1195
|
+
f'Unsupported return type. Should be a ElementInstance or None. Got {type(result)}.')
|
|
590
1196
|
|
|
591
|
-
def inner_without_validate(sample_id, preprocess_response):
|
|
1197
|
+
def inner_without_validate(sample_id, preprocess_response, instance_id):
|
|
592
1198
|
global _called_from_inside_tl_decorator
|
|
593
1199
|
_called_from_inside_tl_decorator += 1
|
|
594
1200
|
|
|
595
1201
|
try:
|
|
596
|
-
result = user_function(sample_id, preprocess_response)
|
|
1202
|
+
result = user_function(sample_id, preprocess_response, instance_id)
|
|
597
1203
|
finally:
|
|
598
1204
|
_called_from_inside_tl_decorator -= 1
|
|
599
1205
|
|
|
@@ -601,6 +1207,52 @@ def tensorleap_instances_masks_encoder(name: str):
|
|
|
601
1207
|
|
|
602
1208
|
leap_binder.set_instance_masks(inner_without_validate, name)
|
|
603
1209
|
|
|
1210
|
+
def inner(sample_id, preprocess_response, instance_id):
|
|
1211
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1212
|
+
return None
|
|
1213
|
+
|
|
1214
|
+
_validate_input_args(sample_id, preprocess_response, instance_id)
|
|
1215
|
+
|
|
1216
|
+
result = inner_without_validate(sample_id, preprocess_response, instance_id)
|
|
1217
|
+
|
|
1218
|
+
_validate_result(result)
|
|
1219
|
+
return result
|
|
1220
|
+
|
|
1221
|
+
return inner
|
|
1222
|
+
|
|
1223
|
+
return decorating_function
|
|
1224
|
+
|
|
1225
|
+
|
|
1226
|
+
def tensorleap_instances_length_encoder(name: str):
|
|
1227
|
+
def decorating_function(user_function: InstanceLengthCallableInterface):
|
|
1228
|
+
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
|
|
1229
|
+
assert isinstance(sample_id, (str, int)), \
|
|
1230
|
+
(f'tensorleap_instances_length_encoder validation failed: '
|
|
1231
|
+
f'Argument sample_id should be str. Got {type(sample_id)}.')
|
|
1232
|
+
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
1233
|
+
(f'tensorleap_instances_length_encoder validation failed: '
|
|
1234
|
+
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
1235
|
+
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
1236
|
+
(f'tensorleap_instances_length_encoder validation failed: '
|
|
1237
|
+
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
1238
|
+
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
1239
|
+
|
|
1240
|
+
def _validate_result(result):
|
|
1241
|
+
assert isinstance(result, int), \
|
|
1242
|
+
(f'tensorleap_instances_length_encoder validation failed: '
|
|
1243
|
+
f'Unsupported return type. Should be a int. Got {type(result)}.')
|
|
1244
|
+
|
|
1245
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
1246
|
+
global _called_from_inside_tl_decorator
|
|
1247
|
+
_called_from_inside_tl_decorator += 1
|
|
1248
|
+
|
|
1249
|
+
try:
|
|
1250
|
+
result = user_function(sample_id, preprocess_response)
|
|
1251
|
+
finally:
|
|
1252
|
+
_called_from_inside_tl_decorator -= 1
|
|
1253
|
+
|
|
1254
|
+
return result
|
|
1255
|
+
|
|
604
1256
|
def inner(sample_id, preprocess_response):
|
|
605
1257
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
606
1258
|
return None
|
|
@@ -617,46 +1269,87 @@ def tensorleap_instances_masks_encoder(name: str):
|
|
|
617
1269
|
return decorating_function
|
|
618
1270
|
|
|
619
1271
|
|
|
620
|
-
def tensorleap_input_encoder(name: str, channel_dim
|
|
1272
|
+
def tensorleap_input_encoder(name: str, channel_dim=_UNSET, model_input_index=None):
|
|
621
1273
|
def decorating_function(user_function: SectionCallableInterface):
|
|
622
1274
|
for input_handler in leap_binder.setup_container.inputs:
|
|
623
1275
|
if input_handler.name == name:
|
|
624
1276
|
raise Exception(f'Input with name {name} already exists. '
|
|
625
1277
|
f'Please choose another')
|
|
1278
|
+
nonlocal channel_dim
|
|
1279
|
+
|
|
1280
|
+
channel_dim_was_provided = channel_dim is not _UNSET
|
|
1281
|
+
|
|
1282
|
+
if not channel_dim_was_provided:
|
|
1283
|
+
channel_dim = -1
|
|
1284
|
+
if not _call_from_tl_platform:
|
|
1285
|
+
store_warning_by_param(
|
|
1286
|
+
param_name="channel_dim",
|
|
1287
|
+
user_func_name=user_function.__name__,
|
|
1288
|
+
default_value=channel_dim,
|
|
1289
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/input-encoder"
|
|
1290
|
+
|
|
1291
|
+
)
|
|
1292
|
+
|
|
626
1293
|
if channel_dim <= 0 and channel_dim != -1:
|
|
627
1294
|
raise Exception(f"Channel dim for input {name} is expected to be either -1 or positive")
|
|
628
1295
|
|
|
629
|
-
leap_binder.set_input(user_function, name, channel_dim=channel_dim)
|
|
630
|
-
|
|
631
1296
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
632
|
-
assert isinstance(sample_id, (int, str)), \
|
|
633
|
-
(f'tensorleap_input_encoder validation failed: '
|
|
634
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
635
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
636
|
-
(f'tensorleap_input_encoder validation failed: '
|
|
637
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
638
1297
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
639
|
-
(f'
|
|
1298
|
+
(f'{user_function.__name__}() validation failed: '
|
|
640
1299
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
641
1300
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
642
1301
|
|
|
643
1302
|
def _validate_result(result):
|
|
1303
|
+
validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray")
|
|
644
1304
|
assert isinstance(result, np.ndarray), \
|
|
645
|
-
(f'
|
|
1305
|
+
(f'{user_function.__name__}() validation failed: '
|
|
646
1306
|
f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
|
|
647
1307
|
assert result.dtype == np.float32, \
|
|
648
|
-
(f'
|
|
1308
|
+
(f'{user_function.__name__}() validation failed: '
|
|
649
1309
|
f'The return type should be a numpy array of type float32. Got {result.dtype}.')
|
|
650
|
-
assert channel_dim - 1 <= len(result.shape), (f'
|
|
1310
|
+
assert channel_dim - 1 <= len(result.shape), (f'{user_function.__name__}() validation failed: '
|
|
651
1311
|
f'The channel_dim ({channel_dim}) should be <= to the rank of the resulting input rank ({len(result.shape)}).')
|
|
652
1312
|
|
|
653
|
-
def
|
|
1313
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
1314
|
+
global _called_from_inside_tl_decorator
|
|
1315
|
+
_called_from_inside_tl_decorator += 1
|
|
1316
|
+
|
|
1317
|
+
try:
|
|
1318
|
+
result = user_function(sample_id, preprocess_response)
|
|
1319
|
+
finally:
|
|
1320
|
+
_called_from_inside_tl_decorator -= 1
|
|
1321
|
+
|
|
1322
|
+
return result
|
|
1323
|
+
|
|
1324
|
+
leap_binder.set_input(inner_without_validate, name, channel_dim=channel_dim)
|
|
1325
|
+
|
|
1326
|
+
def inner(*args, **kwargs):
|
|
1327
|
+
if not _call_from_tl_platform:
|
|
1328
|
+
set_current("tensorleap_input_encoder")
|
|
1329
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
1330
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
1331
|
+
sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
|
|
654
1332
|
_validate_input_args(sample_id, preprocess_response)
|
|
655
|
-
|
|
1333
|
+
|
|
1334
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
1335
|
+
|
|
656
1336
|
_validate_result(result)
|
|
657
1337
|
|
|
658
|
-
if _called_from_inside_tl_decorator == 0:
|
|
1338
|
+
if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
|
|
1339
|
+
batch_warning(result, user_function.__name__)
|
|
659
1340
|
result = np.expand_dims(result, axis=0)
|
|
1341
|
+
# Emit integration test event once per test
|
|
1342
|
+
try:
|
|
1343
|
+
emit_integration_event_once(AnalyticsEvent.INPUT_ENCODER_INTEGRATION_TEST, {
|
|
1344
|
+
'encoder_name': name,
|
|
1345
|
+
'channel_dim': channel_dim,
|
|
1346
|
+
'model_input_index': model_input_index
|
|
1347
|
+
})
|
|
1348
|
+
except Exception as e:
|
|
1349
|
+
logger.debug(f"Failed to emit input_encoder integration test event: {e}")
|
|
1350
|
+
if not _call_from_tl_platform:
|
|
1351
|
+
update_env_params_func("tensorleap_input_encoder", "v")
|
|
1352
|
+
|
|
660
1353
|
return result
|
|
661
1354
|
|
|
662
1355
|
node_mapping_type = NodeMappingType.Input
|
|
@@ -664,22 +1357,23 @@ def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
|
|
|
664
1357
|
node_mapping_type = NodeMappingType(f'Input{str(model_input_index)}')
|
|
665
1358
|
inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
666
1359
|
|
|
667
|
-
def mapping_inner(
|
|
1360
|
+
def mapping_inner(*args, **kwargs):
|
|
668
1361
|
class TempMapping:
|
|
669
1362
|
pass
|
|
670
1363
|
|
|
671
1364
|
ret = TempMapping()
|
|
672
1365
|
ret.node_mapping = mapping_inner.node_mapping
|
|
673
1366
|
|
|
1367
|
+
leap_binder.mapping_connections.append(NodeConnection(mapping_inner.node_mapping, None))
|
|
674
1368
|
return ret
|
|
675
1369
|
|
|
676
1370
|
mapping_inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
677
1371
|
|
|
678
|
-
def final_inner(
|
|
1372
|
+
def final_inner(*args, **kwargs):
|
|
679
1373
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
680
|
-
return mapping_inner(
|
|
1374
|
+
return mapping_inner(*args, **kwargs)
|
|
681
1375
|
else:
|
|
682
|
-
return inner(
|
|
1376
|
+
return inner(*args, **kwargs)
|
|
683
1377
|
|
|
684
1378
|
final_inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
685
1379
|
|
|
@@ -695,40 +1389,64 @@ def tensorleap_gt_encoder(name: str):
|
|
|
695
1389
|
raise Exception(f'GT with name {name} already exists. '
|
|
696
1390
|
f'Please choose another')
|
|
697
1391
|
|
|
698
|
-
leap_binder.set_ground_truth(user_function, name)
|
|
699
|
-
|
|
700
1392
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
701
|
-
assert isinstance(sample_id, (int, str)), \
|
|
702
|
-
(f'tensorleap_gt_encoder validation failed: '
|
|
703
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
704
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
705
|
-
(f'tensorleap_gt_encoder validation failed: '
|
|
706
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
707
1393
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
708
|
-
(f'
|
|
1394
|
+
(f'{user_function.__name__}() validation failed: '
|
|
709
1395
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
710
1396
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
711
1397
|
|
|
712
1398
|
def _validate_result(result):
|
|
1399
|
+
validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray",
|
|
1400
|
+
gt_flag=True)
|
|
713
1401
|
assert isinstance(result, np.ndarray), \
|
|
714
|
-
(f'
|
|
1402
|
+
(f'{user_function.__name__}() validation failed: '
|
|
715
1403
|
f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
|
|
716
1404
|
assert result.dtype == np.float32, \
|
|
717
|
-
(f'
|
|
1405
|
+
(f'{user_function.__name__}() validation failed: '
|
|
718
1406
|
f'The return type should be a numpy array of type float32. Got {result.dtype}.')
|
|
719
1407
|
|
|
720
|
-
def
|
|
1408
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
1409
|
+
global _called_from_inside_tl_decorator
|
|
1410
|
+
_called_from_inside_tl_decorator += 1
|
|
1411
|
+
|
|
1412
|
+
try:
|
|
1413
|
+
result = user_function(sample_id, preprocess_response)
|
|
1414
|
+
finally:
|
|
1415
|
+
_called_from_inside_tl_decorator -= 1
|
|
1416
|
+
|
|
1417
|
+
return result
|
|
1418
|
+
|
|
1419
|
+
leap_binder.set_ground_truth(inner_without_validate, name)
|
|
1420
|
+
|
|
1421
|
+
def inner(*args, **kwargs):
|
|
1422
|
+
if not _call_from_tl_platform:
|
|
1423
|
+
set_current("tensorleap_gt_encoder")
|
|
1424
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
1425
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
1426
|
+
sample_id, preprocess_response = args
|
|
721
1427
|
_validate_input_args(sample_id, preprocess_response)
|
|
722
|
-
|
|
1428
|
+
|
|
1429
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
1430
|
+
|
|
723
1431
|
_validate_result(result)
|
|
724
1432
|
|
|
725
|
-
if _called_from_inside_tl_decorator == 0:
|
|
1433
|
+
if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
|
|
1434
|
+
batch_warning(result, user_function.__name__)
|
|
726
1435
|
result = np.expand_dims(result, axis=0)
|
|
1436
|
+
# Emit integration test event once per test
|
|
1437
|
+
try:
|
|
1438
|
+
emit_integration_event_once(AnalyticsEvent.GT_ENCODER_INTEGRATION_TEST, {
|
|
1439
|
+
'encoder_name': name
|
|
1440
|
+
})
|
|
1441
|
+
except Exception as e:
|
|
1442
|
+
logger.debug(f"Failed to emit gt_encoder integration test event: {e}")
|
|
1443
|
+
if not _call_from_tl_platform:
|
|
1444
|
+
update_env_params_func("tensorleap_gt_encoder", "v")
|
|
727
1445
|
return result
|
|
728
1446
|
|
|
729
1447
|
inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
730
1448
|
|
|
731
|
-
def mapping_inner(
|
|
1449
|
+
def mapping_inner(*args, **kwargs):
|
|
732
1450
|
class TempMapping:
|
|
733
1451
|
pass
|
|
734
1452
|
|
|
@@ -739,11 +1457,11 @@ def tensorleap_gt_encoder(name: str):
|
|
|
739
1457
|
|
|
740
1458
|
mapping_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
741
1459
|
|
|
742
|
-
def final_inner(
|
|
1460
|
+
def final_inner(*args, **kwargs):
|
|
743
1461
|
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
744
|
-
return mapping_inner(
|
|
1462
|
+
return mapping_inner(*args, **kwargs)
|
|
745
1463
|
else:
|
|
746
|
-
return inner(
|
|
1464
|
+
return inner(*args, **kwargs)
|
|
747
1465
|
|
|
748
1466
|
final_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
749
1467
|
|
|
@@ -753,6 +1471,8 @@ def tensorleap_gt_encoder(name: str):
|
|
|
753
1471
|
|
|
754
1472
|
|
|
755
1473
|
def tensorleap_custom_loss(name: str, connects_to=None):
|
|
1474
|
+
name_to_unique_name = defaultdict(set)
|
|
1475
|
+
|
|
756
1476
|
def decorating_function(user_function: CustomCallableInterface):
|
|
757
1477
|
for loss_handler in leap_binder.setup_container.custom_loss_handlers:
|
|
758
1478
|
if loss_handler.custom_loss_handler_data.name == name:
|
|
@@ -760,35 +1480,29 @@ def tensorleap_custom_loss(name: str, connects_to=None):
|
|
|
760
1480
|
f'Please choose another')
|
|
761
1481
|
|
|
762
1482
|
valid_types = (np.ndarray, SamplePreprocessResponse)
|
|
763
|
-
try:
|
|
764
|
-
import tensorflow as tf
|
|
765
|
-
valid_types = (np.ndarray, SamplePreprocessResponse, tf.Tensor)
|
|
766
|
-
except ImportError:
|
|
767
|
-
pass
|
|
768
1483
|
|
|
769
1484
|
def _validate_input_args(*args, **kwargs):
|
|
770
|
-
|
|
1485
|
+
assert len(args) + len(kwargs) > 0, (
|
|
1486
|
+
f"{user_function.__name__}() validation failed: "
|
|
1487
|
+
f"Expected at least one positional|key-word argument of the allowed types (np.ndarray|SamplePreprocessResponse|). "
|
|
1488
|
+
f"but received none. "
|
|
1489
|
+
f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
|
|
1490
|
+
)
|
|
771
1491
|
for i, arg in enumerate(args):
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
775
|
-
f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
|
|
776
|
-
else:
|
|
777
|
-
assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
778
|
-
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
1492
|
+
assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
|
|
1493
|
+
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
779
1494
|
for _arg_name, arg in kwargs.items():
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
783
|
-
f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
|
|
784
|
-
else:
|
|
785
|
-
assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
786
|
-
f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
1495
|
+
assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
|
|
1496
|
+
f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
787
1497
|
|
|
788
1498
|
def _validate_result(result):
|
|
789
|
-
|
|
790
|
-
|
|
1499
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
1500
|
+
expected_type_name="np.ndarray")
|
|
1501
|
+
assert isinstance(result, np.ndarray), \
|
|
1502
|
+
(f'{user_function.__name__} validation failed: '
|
|
791
1503
|
f'The return type should be a numpy array. Got {type(result)}.')
|
|
1504
|
+
assert result.ndim < 2, (f'{user_function.__name__} validation failed: '
|
|
1505
|
+
f'The return type should be a 1Dim numpy array but got {result.ndim}Dim.')
|
|
792
1506
|
|
|
793
1507
|
@functools.wraps(user_function)
|
|
794
1508
|
def inner_without_validate(*args, **kwargs):
|
|
@@ -814,11 +1528,16 @@ def tensorleap_custom_loss(name: str, connects_to=None):
|
|
|
814
1528
|
_add_mapping_connections(connects_to, arg_names, NodeMappingType.CustomLoss, name)
|
|
815
1529
|
|
|
816
1530
|
def inner(*args, **kwargs):
|
|
1531
|
+
if not _call_from_tl_platform:
|
|
1532
|
+
set_current("tensorleap_custom_loss")
|
|
817
1533
|
_validate_input_args(*args, **kwargs)
|
|
818
1534
|
|
|
819
1535
|
result = inner_without_validate(*args, **kwargs)
|
|
820
1536
|
|
|
821
1537
|
_validate_result(result)
|
|
1538
|
+
if not _call_from_tl_platform:
|
|
1539
|
+
update_env_params_func("tensorleap_custom_loss", "v")
|
|
1540
|
+
|
|
822
1541
|
return result
|
|
823
1542
|
|
|
824
1543
|
def mapping_inner(*args, **kwargs):
|
|
@@ -826,6 +1545,10 @@ def tensorleap_custom_loss(name: str, connects_to=None):
|
|
|
826
1545
|
if 'user_unique_name' in kwargs:
|
|
827
1546
|
user_unique_name = kwargs['user_unique_name']
|
|
828
1547
|
|
|
1548
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
1549
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
1550
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
1551
|
+
|
|
829
1552
|
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
830
1553
|
ordered_connections = list(args) + ordered_connections
|
|
831
1554
|
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
@@ -871,3 +1594,180 @@ def tensorleap_custom_layer(name: str):
|
|
|
871
1594
|
return custom_layer
|
|
872
1595
|
|
|
873
1596
|
return decorating_function
|
|
1597
|
+
|
|
1598
|
+
|
|
1599
|
+
def tensorleap_status_table():
|
|
1600
|
+
import atexit
|
|
1601
|
+
import sys
|
|
1602
|
+
import traceback
|
|
1603
|
+
from typing import Any
|
|
1604
|
+
|
|
1605
|
+
CHECK = "✅"
|
|
1606
|
+
CROSS = "❌"
|
|
1607
|
+
UNKNOWN = "❔"
|
|
1608
|
+
|
|
1609
|
+
code_mapping_failure = [0]
|
|
1610
|
+
|
|
1611
|
+
table = [
|
|
1612
|
+
{"name": "tensorleap_preprocess", "Added to integration": UNKNOWN},
|
|
1613
|
+
{"name": "tensorleap_integration_test", "Added to integration": UNKNOWN},
|
|
1614
|
+
{"name": "tensorleap_input_encoder", "Added to integration": UNKNOWN},
|
|
1615
|
+
{"name": "tensorleap_gt_encoder", "Added to integration": UNKNOWN},
|
|
1616
|
+
{"name": "tensorleap_load_model", "Added to integration": UNKNOWN},
|
|
1617
|
+
{"name": "tensorleap_custom_loss", "Added to integration": UNKNOWN},
|
|
1618
|
+
{"name": "tensorleap_custom_metric (optional)", "Added to integration": UNKNOWN},
|
|
1619
|
+
{"name": "tensorleap_metadata (optional)", "Added to integration": UNKNOWN},
|
|
1620
|
+
{"name": "tensorleap_custom_visualizer (optional)", "Added to integration": UNKNOWN},
|
|
1621
|
+
]
|
|
1622
|
+
|
|
1623
|
+
_finalizer_called = {"done": False}
|
|
1624
|
+
_crashed = {"value": False}
|
|
1625
|
+
_current_func = {"name": None}
|
|
1626
|
+
|
|
1627
|
+
def _link(url: str) -> str:
|
|
1628
|
+
return f"\033{url}\033"
|
|
1629
|
+
|
|
1630
|
+
def _remove_suffix(s: str, suffix: str) -> str:
|
|
1631
|
+
if suffix and s.endswith(suffix):
|
|
1632
|
+
return s[:-len(suffix)]
|
|
1633
|
+
return s
|
|
1634
|
+
|
|
1635
|
+
def _find_row(name: str):
|
|
1636
|
+
for row in table:
|
|
1637
|
+
if _remove_suffix(row["name"], " (optional)") == name:
|
|
1638
|
+
return row
|
|
1639
|
+
return None
|
|
1640
|
+
|
|
1641
|
+
def _set_status(name: str, status_symbol: str):
|
|
1642
|
+
row = _find_row(name)
|
|
1643
|
+
if not row:
|
|
1644
|
+
return
|
|
1645
|
+
|
|
1646
|
+
cur = row["Added to integration"]
|
|
1647
|
+
if status_symbol == UNKNOWN:
|
|
1648
|
+
return
|
|
1649
|
+
if cur == CHECK and status_symbol != CHECK:
|
|
1650
|
+
return
|
|
1651
|
+
|
|
1652
|
+
row["Added to integration"] = status_symbol
|
|
1653
|
+
|
|
1654
|
+
def _mark_unknowns_as_cross():
|
|
1655
|
+
for row in table:
|
|
1656
|
+
if row["Added to integration"] == UNKNOWN:
|
|
1657
|
+
row["Added to integration"] = CROSS
|
|
1658
|
+
|
|
1659
|
+
def _format_default_value(v: Any) -> str:
|
|
1660
|
+
if hasattr(v, "name"):
|
|
1661
|
+
return str(v.name)
|
|
1662
|
+
if isinstance(v, str):
|
|
1663
|
+
return v
|
|
1664
|
+
s = repr(v)
|
|
1665
|
+
return s if len(s) <= 120 else s[:120] + "..."
|
|
1666
|
+
|
|
1667
|
+
def _print_param_default_warnings():
|
|
1668
|
+
data = _get_param_default_warnings()
|
|
1669
|
+
if not data:
|
|
1670
|
+
return
|
|
1671
|
+
|
|
1672
|
+
print("\nWarnings (Default use. It is recommended to set values explicitly):")
|
|
1673
|
+
for param_name in sorted(data.keys()):
|
|
1674
|
+
default_value = data[param_name]["default_value"]
|
|
1675
|
+
funcs = ", ".join(sorted(data[param_name]["funcs"]))
|
|
1676
|
+
dv = _format_default_value(default_value)
|
|
1677
|
+
|
|
1678
|
+
docs_link = data[param_name].get("link_to_docs")
|
|
1679
|
+
docs_part = f" {_link(docs_link)}" if docs_link else ""
|
|
1680
|
+
print(
|
|
1681
|
+
f" ⚠️ Parameter '{param_name}' defaults to {dv} in the following functions: [{funcs}]. "
|
|
1682
|
+
f"For more information, check {docs_part}")
|
|
1683
|
+
print("\nIf this isn’t the intended behaviour, set them explicitly.")
|
|
1684
|
+
|
|
1685
|
+
def _print_table():
|
|
1686
|
+
_print_param_default_warnings()
|
|
1687
|
+
|
|
1688
|
+
if not started_from("leap_integration.py"):
|
|
1689
|
+
return
|
|
1690
|
+
|
|
1691
|
+
ready_mess = "\nAll parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system."
|
|
1692
|
+
not_ready_mess = "\nSome mandatory components have not yet been added to the Integration test. Recommended next interface to add is: "
|
|
1693
|
+
mandatory_ready_mess = "\nAll mandatory parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system or continue to the next optional reccomeded interface,adding: "
|
|
1694
|
+
code_mapping_failure_mes = "Tensorleap_integration_test code flow failed, check raised exception."
|
|
1695
|
+
|
|
1696
|
+
name_width = max(len(row["name"]) for row in table)
|
|
1697
|
+
status_width = max(len(row["Added to integration"]) for row in table)
|
|
1698
|
+
|
|
1699
|
+
header = f"{'Decorator Name'.ljust(name_width)} | {'Added to integration'.ljust(status_width)}"
|
|
1700
|
+
sep = "-" * len(header)
|
|
1701
|
+
|
|
1702
|
+
print("\n" + header)
|
|
1703
|
+
print(sep)
|
|
1704
|
+
|
|
1705
|
+
ready = True
|
|
1706
|
+
next_step = None
|
|
1707
|
+
|
|
1708
|
+
for row in table:
|
|
1709
|
+
print(f"{row['name'].ljust(name_width)} | {row['Added to integration'].ljust(status_width)}")
|
|
1710
|
+
if not _crashed["value"] and ready:
|
|
1711
|
+
if row["Added to integration"] != CHECK:
|
|
1712
|
+
ready = False
|
|
1713
|
+
next_step = row["name"]
|
|
1714
|
+
|
|
1715
|
+
if _crashed["value"]:
|
|
1716
|
+
print(f"\nScript crashed before completing all steps. crashed at function '{_current_func['name']}'.")
|
|
1717
|
+
return
|
|
1718
|
+
|
|
1719
|
+
if code_mapping_failure[0]:
|
|
1720
|
+
print(f"\n{CROSS + code_mapping_failure_mes}.")
|
|
1721
|
+
return
|
|
1722
|
+
|
|
1723
|
+
print(ready_mess) if ready else print(
|
|
1724
|
+
mandatory_ready_mess + next_step
|
|
1725
|
+
) if (next_step and "optional" in next_step) else print(not_ready_mess + (next_step or ""))
|
|
1726
|
+
|
|
1727
|
+
def set_current(name: str):
|
|
1728
|
+
_current_func["name"] = name
|
|
1729
|
+
|
|
1730
|
+
def update_env_params(name: str, status: str = "v"):
|
|
1731
|
+
if name == "code_mapping":
|
|
1732
|
+
code_mapping_failure[0] = 1
|
|
1733
|
+
if status == "v":
|
|
1734
|
+
_set_status(name, CHECK)
|
|
1735
|
+
else:
|
|
1736
|
+
_set_status(name, CROSS)
|
|
1737
|
+
|
|
1738
|
+
def run_on_exit():
|
|
1739
|
+
if _finalizer_called["done"]:
|
|
1740
|
+
return
|
|
1741
|
+
_finalizer_called["done"] = True
|
|
1742
|
+
if not _crashed["value"]:
|
|
1743
|
+
_mark_unknowns_as_cross()
|
|
1744
|
+
|
|
1745
|
+
_print_table()
|
|
1746
|
+
|
|
1747
|
+
def handle_exception(exc_type, exc_value, exc_traceback):
|
|
1748
|
+
_crashed["value"] = True
|
|
1749
|
+
crashed_name = _current_func["name"]
|
|
1750
|
+
if crashed_name:
|
|
1751
|
+
row = _find_row(crashed_name)
|
|
1752
|
+
if row and row["Added to integration"] != CHECK:
|
|
1753
|
+
row["Added to integration"] = CROSS
|
|
1754
|
+
|
|
1755
|
+
traceback.print_exception(exc_type, exc_value, exc_traceback)
|
|
1756
|
+
run_on_exit()
|
|
1757
|
+
|
|
1758
|
+
atexit.register(run_on_exit)
|
|
1759
|
+
sys.excepthook = handle_exception
|
|
1760
|
+
|
|
1761
|
+
return set_current, update_env_params
|
|
1762
|
+
|
|
1763
|
+
|
|
1764
|
+
if not _call_from_tl_platform:
|
|
1765
|
+
set_current, update_env_params_func = tensorleap_status_table()
|
|
1766
|
+
|
|
1767
|
+
|
|
1768
|
+
|
|
1769
|
+
|
|
1770
|
+
|
|
1771
|
+
|
|
1772
|
+
|
|
1773
|
+
|