code-loader 1.0.94.dev11__py3-none-any.whl → 1.0.153.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of code-loader might be problematic; see the registry's advisory for this release for more details.
- code_loader/__init__.py +6 -0
- code_loader/contract/datasetclasses.py +30 -11
- code_loader/contract/mapping.py +11 -0
- code_loader/inner_leap_binder/leapbinder.py +21 -3
- code_loader/inner_leap_binder/leapbinder_decorators.py +1368 -147
- code_loader/leaploader.py +65 -25
- code_loader/leaploaderbase.py +30 -5
- code_loader/mixpanel_tracker.py +230 -0
- code_loader/plot_functions/__init__.py +0 -0
- code_loader/plot_functions/plot_functions.py +444 -0
- code_loader/plot_functions/visualize.py +22 -0
- code_loader/utils.py +6 -6
- code_loader/visualizers/default_visualizers.py +3 -0
- {code_loader-1.0.94.dev11.dist-info → code_loader-1.0.153.dev3.dist-info}/METADATA +6 -3
- {code_loader-1.0.94.dev11.dist-info → code_loader-1.0.153.dev3.dist-info}/RECORD +17 -13
- {code_loader-1.0.94.dev11.dist-info → code_loader-1.0.153.dev3.dist-info}/LICENSE +0 -0
- {code_loader-1.0.94.dev11.dist-info → code_loader-1.0.153.dev3.dist-info}/WHEEL +0 -0
|
@@ -1,114 +1,703 @@
|
|
|
1
1
|
# mypy: ignore-errors
|
|
2
|
-
|
|
3
|
-
|
|
2
|
+
import os
|
|
3
|
+
import warnings
|
|
4
|
+
import logging
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from functools import lru_cache
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Optional, Union, Callable, List, Dict, Set, Any
|
|
9
|
+
from typing import Optional, Union, Callable, List, Dict, get_args, get_origin, DefaultDict
|
|
4
10
|
|
|
5
11
|
import numpy as np
|
|
6
12
|
import numpy.typing as npt
|
|
7
13
|
|
|
14
|
+
logger = logging.getLogger(__name__)
|
|
15
|
+
|
|
8
16
|
from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
|
|
9
17
|
CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
|
|
10
18
|
VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
|
|
11
|
-
ConfusionMatrixElement, SamplePreprocessResponse, InstanceCallableInterface
|
|
12
|
-
|
|
19
|
+
ConfusionMatrixElement, SamplePreprocessResponse, PredictionTypeHandler, InstanceCallableInterface, ElementInstance, \
|
|
20
|
+
InstanceLengthCallableInterface
|
|
21
|
+
from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType, DataStateType
|
|
13
22
|
from code_loader import leap_binder
|
|
14
23
|
from code_loader.contract.mapping import NodeMapping, NodeMappingType, NodeConnection
|
|
15
24
|
from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
|
|
16
25
|
LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
|
|
26
|
+
from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
|
|
27
|
+
from code_loader.mixpanel_tracker import clear_integration_events, AnalyticsEvent, emit_integration_event_once
|
|
28
|
+
|
|
29
|
+
import inspect
|
|
30
|
+
import functools
|
|
31
|
+
from pathlib import Path
|
|
32
|
+
|
|
33
|
+
# Counter flag for Tensorleap decorator nesting; not read or written in the code visible
# here — presumably maintained by decorator wrappers elsewhere in this module (TODO confirm).
_called_from_inside_tl_decorator = 0
# Set to True while a @tensorleap_integration_test-wrapped function is executing and reset
# to False (in a `finally`) when it returns or raises.
_called_from_inside_tl_integration_test_decorator = False
# True when running on the Tensorleap platform (env var IS_TENSORLEAP_PLATFORM == 'true').
# Local-only validation and default-value warning bookkeeping are skipped in that case.
_call_from_tl_platform = os.environ.get('IS_TENSORLEAP_PLATFORM') == 'true'

# ---- warnings store (module-level) ----
# Sentinel default for decorator parameters: lets a decorator tell "argument not passed"
# apart from any explicit value (including None).
_UNSET = object()
# Collected warning records; _get_stored_warnings() returns a copy of this list.
_STORED_WARNINGS: List[Dict[str, Any]] = []
# NOTE(review): not used in the visible code; presumably de-duplication keys for
# _STORED_WARNINGS — confirm.
_STORED_WARNING_KEYS: Set[tuple] = set()
# param_name -> set(user_func_name)
_PARAM_DEFAULT_FUNCS: DefaultDict[str, Set[str]] = defaultdict(set)
# param_name -> default_value used (repr-able)
_PARAM_DEFAULT_VALUE: Dict[str, Any] = {}
# param_name -> documentation link (first non-empty link registered wins)
_PARAM_DEFAULT_DOCS: Dict[str, str] = {}
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def get_entry_script_path() -> str:
    """Return the absolute path of the script that started the interpreter.

    ``sys.argv[0]`` is preferred; if it is missing or empty, ``__main__.__file__``
    is used instead. Returns ``""`` when neither source provides a path.
    """
    import sys

    script = sys.argv[0] if sys.argv else ""
    if not script:
        import __main__
        script = getattr(__main__, "__file__", "") or ""
        if not script:
            return ""
    return str(Path(script).resolve())
|
|
56
|
+
|
|
57
|
+
def started_from(filename: str) -> bool:
    """Return True if the entry script's base file name equals ``filename``."""
    entry_path = get_entry_script_path()
    if not entry_path:
        return False
    return Path(entry_path).name == filename
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def store_warning_by_param(
        *,
        param_name: str,
        user_func_name: str,
        default_value: Any,
        link_to_docs: Optional[str] = None,
) -> None:
    """Record that ``user_func_name`` fell back to the default of ``param_name``.

    Args:
        param_name: Name of the parameter whose default was used.
        user_func_name: Function (decorator or user function) that used it.
        default_value: The default value that was applied. Only the first value
            registered for a given ``param_name`` is kept.
        link_to_docs: Optional documentation URL. Only the first non-empty link
            registered for a given ``param_name`` is kept.
    """
    _PARAM_DEFAULT_FUNCS[param_name].add(user_func_name)
    # First registration wins for both the default value and the docs link.
    _PARAM_DEFAULT_VALUE.setdefault(param_name, default_value)
    if link_to_docs:
        _PARAM_DEFAULT_DOCS.setdefault(param_name, link_to_docs)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _get_param_default_warnings() -> Dict[str, Dict[str, Any]]:
    """Snapshot the collected default-value warnings, keyed by parameter name.

    Each value holds ``default_value``, ``funcs`` (a fresh copy of the set of
    functions that used the default) and ``link_to_docs`` (may be None).
    """
    return {
        param: {
            "default_value": _PARAM_DEFAULT_VALUE.get(param, None),
            "funcs": set(user_funcs),
            "link_to_docs": _PARAM_DEFAULT_DOCS.get(param),
        }
        for param, user_funcs in _PARAM_DEFAULT_FUNCS.items()
    }
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _get_stored_warnings() -> List[Dict[str, Any]]:
    """Return a shallow copy of the stored warning records."""
    return _STORED_WARNINGS[:]
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def validate_args_structure(*args, types_order, func_name, expected_names, **kwargs):
    """Validate the count and types of arguments passed to a decorated function.

    Args:
        *args: Positional arguments as received by the wrapper.
        types_order: Expected type (or ``Union[...]``) for each argument, in order.
        func_name: Name used in error messages.
        expected_names: Parameter names in order. When non-empty, each argument
            may be supplied positionally or by keyword; when empty/None only the
            positional ``args`` are validated.
        **kwargs: Keyword arguments as received by the wrapper.

    Raises:
        AssertionError: if a named argument is missing, the number of arguments
            (including surplus positional ones) differs from ``types_order``, or
            an argument's type is not allowed.
    """
    def _name_of(t):
        # Some typing constructs (e.g. ``List[int]``) have no ``__name__``.
        return getattr(t, "__name__", str(t))

    def _type_to_str(t):
        if get_origin(t) is Union:
            return " | ".join(_name_of(tt) for tt in get_args(t))
        return _name_of(t)

    def _format_types(types, names=None):
        return ", ".join(
            f"{(names[i] + ': ') if names else f'arg{i}: '}{_type_to_str(ty)}"
            for i, ty in enumerate(types)
        )

    if expected_names:
        normalized_args = []
        for i, name in enumerate(expected_names):
            if i < len(args):
                normalized_args.append(args[i])
            elif name in kwargs:
                normalized_args.append(kwargs[name])
            else:
                raise AssertionError(
                    f"{func_name} validation failed: "
                    f"Missing required argument '{name}'. "
                    f"Expected arguments: {expected_names}."
                )
        # Keep surplus positional arguments so the count check below reports them
        # instead of silently ignoring them.
        normalized_args.extend(args[len(expected_names):])
    else:
        normalized_args = list(args)

    if len(normalized_args) != len(types_order):
        expected = _format_types(types_order, expected_names)
        got_types = ", ".join(type(arg).__name__ for arg in normalized_args)
        raise AssertionError(
            f"{func_name} validation failed: "
            f"Expected exactly {len(types_order)} arguments ({expected}), "
            f"but got {len(normalized_args)} argument(s) of type(s): ({got_types}). "
            f"Correct usage example: {func_name}({expected})"
        )

    for i, (arg, expected_type) in enumerate(zip(normalized_args, types_order)):
        if get_origin(expected_type) is Union:
            allowed_types = get_args(expected_type)
        else:
            allowed_types = (expected_type,)

        if not isinstance(arg, allowed_types):
            allowed_str = " | ".join(_name_of(t) for t in allowed_types)
            raise AssertionError(
                f"{func_name} validation failed: "
                f"Argument '{expected_names[i] if expected_names else f'arg{i}'}' "
                f"expected type {allowed_str}, but got {type(arg).__name__}. "
                f"Correct usage example: {func_name}({_format_types(types_order, expected_names)})"
            )
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def validate_output_structure(result, func_name: str, expected_type_name="np.ndarray", gt_flag=False):
    """Reject empty or multi-valued results returned by a user function.

    Args:
        result: The value returned by the user function.
        func_name: Name used in error messages.
        expected_type_name: Human-readable expected type, used in error messages.
        gt_flag: When True, the error for a None/NaN result explains how to
            return an empty ground truth for unlabeled datasets.

    Raises:
        AssertionError: if ``result`` is None, a float NaN, or a tuple.
    """
    if result is None or (isinstance(result, float) and np.isnan(result)):
        if gt_flag:
            raise AssertionError(
                f"{func_name} validation failed: "
                f"The function returned {result!r}. "
                f"If you are working with an unlabeled dataset and no ground truth is available, "
                f"use 'return np.array([], dtype=np.float32)' instead. "
                f"Otherwise, {func_name} expected a single {expected_type_name} object. "
                f"Make sure the function ends with 'return <{expected_type_name}>'."
            )

        # Report the actual value: this branch is also reached for NaN, not only None.
        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned {result!r}. "
            f"Expected a single {expected_type_name} object. "
            f"Make sure the function ends with 'return <{expected_type_name}>'."
        )
    if isinstance(result, tuple):
        element_descriptions = [
            f"[{i}] type: {type(r).__name__}"
            for i, r in enumerate(result)
        ]
        element_summary = "\n ".join(element_descriptions)

        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned multiple outputs ({len(result)} values), "
            f"but only a single {expected_type_name} is allowed.\n\n"
            f"Returned elements:\n"
            f" {element_summary}\n\n"
            f"Correct usage example:\n"
            f" def {func_name}(...):\n"
            f" return <{expected_type_name}>\n\n"
            f"If you intended to return multiple values, combine them into a single "
            f"{expected_type_name} (e.g., by concatenation or stacking)."
        )
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def batch_warning(result, func_name):
    """Warn when ``result`` already appears batched (axis 0 has size 1).

    Tensorleap adds its own batch dimension at axis 0, so an output whose
    leading dimension is already 1 would end up with shape ``(1, 1, ...)``.
    """
    shape = result.shape
    if len(shape) == 0 or shape[0] != 1:
        return
    message = (
        f"{func_name} warning: Tensorleap will add a batch dimension at axis 0 to the output of {func_name}, "
        f"although the detected size of axis 0 is already 1. "
        f"This may lead to an extra batch dimension (e.g., shape (1, 1, ...)). "
        f"Please ensure that the output of '{func_name}' is not already batched "
        f"to avoid computation errors."
    )
    warnings.warn(message)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type):
    """Append one ``NodeConnection`` for node ``name`` to ``leap_binder``.

    ``SamplePreprocessResponse`` destinations are skipped. The remaining
    destinations are paired positionally with ``arg_names``; surplus entries on
    either side are ignored, as with ``zip``.
    """
    targets = [dest for dest in connection_destinations
               if not isinstance(dest, SamplePreprocessResponse)]
    main_node_mapping = NodeMapping(name, node_mapping_type, user_unique_name, arg_names=arg_names)
    node_inputs = dict(zip(arg_names, (dest.node_mapping for dest in targets)))
    leap_binder.mapping_connections.append(NodeConnection(main_node_mapping, node_inputs))
|
|
17
212
|
|
|
18
213
|
|
|
19
214
|
def _add_mapping_connections(connects_to, arg_names, node_mapping_type, name):
    """Register one mapping connection per entry of ``connects_to``.

    ``connects_to`` maps a user-chosen unique name to that connection's
    destinations; each pair is forwarded to ``_add_mapping_connection``.
    """
    for unique_name, destinations in connects_to.items():
        _add_mapping_connection(unique_name, destinations, arg_names, name, node_mapping_type)
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def tensorleap_integration_test():
    """Decorator factory for the user's Tensorleap integration-test function.

    The decorated function is stored on ``leap_binder.integration_test_func``.
    Calling the wrapper:

    1. (outside the platform) validates the ``(idx, preprocess)`` arguments,
    2. runs the user function once with the given arguments,
    3. runs it again with the mapping-runtime env var set, ``idx=None`` and an
       empty training ``PreprocessResponse``, so that Tensorleap decorators
       switch to their mapping placeholders (e.g. ``tensorleap_load_model``),
    4. calls ``leap_binder.check()``.

    Returns:
        A decorator whose wrapper returns the user function's return value.

    Raises:
        Exception: if the mapping run fails, naming the file/line of the failure.
    """
    def decorating_function(integration_test_function: Callable):
        leap_binder.integration_test_func = integration_test_function

        def _validate_input_args(*args, **kwargs):
            # Arguments may be passed positionally or by keyword ("idx", "preprocess"),
            # as accepted by validate_args_structure, which has already checked presence.
            sample_id = args[0] if len(args) > 0 else kwargs["idx"]
            preprocess_response = args[1] if len(args) > 1 else kwargs["preprocess"]
            assert type(sample_id) == preprocess_response.sample_id_type, (
                f"tensorleap_integration_test validation failed: "
                f"sample_id type ({type(sample_id).__name__}) does not match the expected "
                f"type ({preprocess_response.sample_id_type}) from the PreprocessResponse."
            )

        def inner(*args, **kwargs):
            global _called_from_inside_tl_integration_test_decorator

            if not _call_from_tl_platform:
                set_current('tensorleap_integration_test')
                validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
                                        func_name='integration_test', expected_names=["idx", "preprocess"], **kwargs)
                _validate_input_args(*args, **kwargs)

            # Clear integration test events for new test (best effort).
            try:
                clear_integration_events()
            except Exception as e:
                logger.debug(f"Failed to clear integration events: {e}")

            try:
                _called_from_inside_tl_integration_test_decorator = True
                if not _call_from_tl_platform:
                    # Set up front; otherwise it would only become "v" once the whole script finishes.
                    update_env_params_func("tensorleap_integration_test", "v")
                ret = integration_test_function(*args, **kwargs)

                try:
                    os.environ[mapping_runtime_mode_env_var_mame] = 'True'
                    integration_test_function(None, PreprocessResponse(state=DataStateType.training, length=0))
                except Exception as e:
                    import traceback
                    # Innermost frame of the failure: where the offending operation ran.
                    last_frame = traceback.extract_tb(e.__traceback__)[-1]
                    file_name = Path(last_frame.filename).name
                    line_number = last_frame.lineno
                    update_env_params_func("code_mapping", "x")
                    # Raise real exceptions (raising a bare str is itself a TypeError that
                    # would hide these messages), chained to the original failure.
                    if isinstance(e, TypeError) and 'is not subscriptable' in str(e):
                        raise Exception(
                            f'Invalid integration code. File {file_name}, line {line_number}: '
                            f"indexing is supported only on the model's predictions inside the integration test. Please remove this indexing operation usage from the integration test code."
                        ) from e
                    raise Exception(
                        f'Invalid integration code. File {file_name}, line {line_number}: '
                        f'Integration test is only allowed to call Tensorleap decorators. '
                        f'Ensure any arithmetics, external library use, Python logic is placed within Tensorleap decoders'
                    ) from e
                finally:
                    os.environ.pop(mapping_runtime_mode_env_var_mame, None)
            finally:
                _called_from_inside_tl_integration_test_decorator = False

            leap_binder.check()
            return ret

        return inner

    return decorating_function
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _safe_get_item(key):
    """Return the ``NodeMappingType`` member named ``Input<key>``.

    Args:
        key: Model input index (anything whose ``str()`` completes the member name).

    Raises:
        Exception: if ``NodeMappingType`` has no member with that name.
    """
    try:
        return NodeMappingType[f'Input{str(key)}']
    except (KeyError, ValueError) as e:
        # Enum lookup by name (``Enum[name]``) raises KeyError for an unknown name.
        raise Exception(f'Tensorleap currently supports models with no more then 10 inputs') from e
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = _UNSET):
    """Decorator factory for the user's model-loading function.

    Args:
        prediction_types: One ``PredictionTypeHandler`` per model output. When
            omitted it defaults to ``[]`` (and, outside the platform, a
            default-value warning is recorded). Handlers whose ``channel_dim`` is
            still ``"tl_default_value"`` are set to ``-1`` in place. Each handler
            is registered with ``leap_binder.add_prediction``.

    Returns:
        A decorator. Normally the decorated function returns a cached wrapper
        around the loaded model that exposes the Keras call interface
        (``model(x)``) and the ONNX Runtime interface (``run`` / ``get_inputs``).
        When the mapping-runtime env var is set, it instead returns a placeholder
        that records input/prediction node mappings.
    """
    prediction_types_was_provided = prediction_types is not _UNSET

    if not prediction_types_was_provided:
        prediction_types = []
        if not _call_from_tl_platform:
            store_warning_by_param(
                param_name="prediction_types",
                user_func_name="tensorleap_load_model",
                default_value=prediction_types,
                link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
            )
    assert isinstance(prediction_types, list), (
        f"tensorleap_load_model validation failed: "
        f" prediction_types is an optional argument of type List[PredictionTypeHandler]] but got {type(prediction_types).__name__}."
    )
    for i, prediction_type in enumerate(prediction_types):
        assert isinstance(prediction_type, PredictionTypeHandler), (f"tensorleap_load_model validation failed: "
                                                                    f" prediction_types at position {i} must be of type PredictionTypeHandler but got {type(prediction_types[i]).__name__}.")
        # "tl_default_value" marks a channel_dim the user did not set.
        if prediction_type.channel_dim == "tl_default_value":
            prediction_types[i].channel_dim = -1
            if not _call_from_tl_platform:
                store_warning_by_param(
                    param_name=f"prediction_types[{i}].channel_dim",
                    user_func_name="tensorleap_load_model",
                    default_value=prediction_types[i].channel_dim,
                    link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
                )
        leap_binder.add_prediction(prediction_type.name, prediction_type.labels, prediction_type.channel_dim, i)

    def _validate_result(result) -> None:
        """Ensure the loaded model is a Keras model or an onnxruntime InferenceSession."""
        valid_types = ["onnxruntime", "keras"]
        err_message = f"tensorleap_load_model validation failed:\nSupported models are Keras and onnxruntime only and non of them was returned."
        validate_output_structure(result, func_name="tensorleap_load_model",
                                  expected_type_name=" | ".join(valid_types))
        try:
            import keras
        except ImportError:
            keras = None
        try:
            import tensorflow as tf
        except ImportError:
            tf = None
        try:
            import onnxruntime
        except ImportError:
            onnxruntime = None

        if not keras and not onnxruntime:
            raise AssertionError(err_message)

        is_keras_model = (
                bool(keras and isinstance(result, getattr(keras, "Model", tuple())))
                or bool(tf and isinstance(result, getattr(tf.keras, "Model", tuple())))
        )
        is_onnx_model = bool(onnxruntime and isinstance(result, onnxruntime.InferenceSession))

        if not any([is_keras_model, is_onnx_model]):
            raise AssertionError(err_message)

    def decorating_function(load_model_func, prediction_types=prediction_types):
        class TempMapping:
            pass

        @lru_cache()
        def inner(*args, **kwargs):
            if not _call_from_tl_platform:
                set_current('tensorleap_load_model')
                validate_args_structure(*args, types_order=[],
                                        func_name='tensorleap_load_model', expected_names=[], **kwargs)

            class ModelPlaceholder:
                def __init__(self, prediction_types):
                    self.model = load_model_func()  # TODO- check why this fails on onnx model
                    self.prediction_types = prediction_types
                    _validate_result(self.model)

                # keras interface
                def __call__(self, arg):
                    ret = self.model(arg)
                    self.validate_declared_prediction_types(ret)
                    if isinstance(ret, (list, tuple)):
                        return [r.numpy() for r in ret]
                    return ret.numpy()

                def validate_declared_prediction_types(self, ret):
                    # A list/tuple result holds one entry per model output; anything else
                    # is a single output. An empty declaration skips the check.
                    num_outputs = len(ret) if isinstance(ret, (list, tuple)) else 1
                    if self.prediction_types and len(self.prediction_types) != num_outputs:
                        if not _call_from_tl_platform:
                            update_env_params_func("tensorleap_load_model", "x")
                        raise Exception(
                            f"tensorleap_load_model validation failed: number of declared prediction types({len(self.prediction_types)}) != number of model outputs({num_outputs})")

                def _convert_onnx_inputs_to_correct_type(
                        self, float_arrays_inputs: Dict[str, np.ndarray]
                ) -> Dict[str, np.ndarray]:
                    """
                    Cast user-provided NumPy inputs to match the dtypes/shapes
                    expected by an ONNX Runtime InferenceSession.

                    Inputs not declared by the session are passed through unchanged.

                    Raises:
                        TypeError: if a declared input has an unsupported ONNX type.
                        KeyError: if a declared input is missing.
                    """
                    ONNX_TYPE_TO_NP = {
                        "tensor(float)": np.float32,
                        "tensor(double)": np.float64,
                        "tensor(int64)": np.int64,
                        "tensor(int32)": np.int32,
                        "tensor(int16)": np.int16,
                        "tensor(int8)": np.int8,
                        "tensor(uint64)": np.uint64,
                        "tensor(uint32)": np.uint32,
                        "tensor(uint16)": np.uint16,
                        "tensor(uint8)": np.uint8,
                        "tensor(bool)": np.bool_,
                    }

                    coerced = {}
                    meta = {i.name: i for i in self.model.get_inputs()}

                    for name, arr in float_arrays_inputs.items():
                        if name not in meta:
                            # Keep as-is unless extra inputs are disallowed
                            coerced[name] = arr
                            continue

                        onnx_type = meta[name].type
                        want_dtype = ONNX_TYPE_TO_NP.get(onnx_type)

                        if want_dtype is None:
                            raise TypeError(f"Unsupported ONNX input type: {onnx_type}")

                        # Cast dtype if needed
                        if arr.dtype != want_dtype:
                            arr = arr.astype(want_dtype, copy=False)

                        coerced[name] = arr

                    # Verify required inputs are present
                    missing = [n for n in meta if n not in coerced]
                    if missing:
                        raise KeyError(f"Missing required input(s): {sorted(missing)}")

                    return coerced

                # onnx runtime interface
                def run(self, output_names, input_dict):
                    corrected_type_inputs = self._convert_onnx_inputs_to_correct_type(input_dict)
                    ret = self.model.run(output_names, corrected_type_inputs)
                    self.validate_declared_prediction_types(ret)
                    return ret

                def get_inputs(self):
                    return self.model.get_inputs()

            model_placeholder = ModelPlaceholder(prediction_types)
            if not _call_from_tl_platform:
                update_env_params_func("tensorleap_load_model", "v")
            return model_placeholder

        def mapping_inner():
            class ModelOutputPlaceholder:
                def __init__(self):
                    self.node_mapping = NodeMapping('', NodeMappingType.Prediction0)

                def __getitem__(self, key):
                    assert isinstance(key, int), \
                        f'Expected key to be an int, got {type(key)} instead.'

                    ret = TempMapping()
                    try:
                        ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
                    except ValueError as e:
                        raise Exception(f'Tensorleap currently supports models with no more then 10 active predictions,'
                                        f' {key} not supported.') from e
                    return ret

            class ModelPlaceholder:

                # keras interface
                def __call__(self, arg):
                    if isinstance(arg, list):
                        for i, elem in enumerate(arg):
                            elem.node_mapping.type = _safe_get_item(i)
                    else:
                        arg.node_mapping.type = NodeMappingType.Input0

                    return ModelOutputPlaceholder()

                # onnx runtime interface
                def run(self, output_names, input_dict):
                    assert output_names is None
                    assert isinstance(input_dict, dict), \
                        f'Expected input_dict to be a dict, got {type(input_dict)} instead.'
                    for i, (input_key, elem) in enumerate(input_dict.items()):
                        if isinstance(input_key, NodeMappingType):
                            elem.node_mapping.type = input_key
                        else:
                            elem.node_mapping.type = _safe_get_item(i)

                    return ModelOutputPlaceholder()

                def get_inputs(self):
                    class FollowIndex:
                        def __init__(self, index):
                            self.name = _safe_get_item(index)

                    class FollowInputIndex:
                        def __init__(self):
                            pass

                        def __getitem__(self, index):
                            assert isinstance(index, int), \
                                f'Expected key to be an int, got {type(index)} instead.'

                            return FollowIndex(index)

                    return FollowInputIndex()

            return ModelPlaceholder()

        def final_inner(*args, **kwargs):
            if os.environ.get(mapping_runtime_mode_env_var_mame):
                return mapping_inner()
            else:
                return inner(*args, **kwargs)

        return final_inner

    return decorating_function
|
|
27
518
|
|
|
28
519
|
|
|
29
520
|
def tensorleap_custom_metric(name: str,
|
|
30
|
-
direction: Union[MetricDirection, Dict[str, MetricDirection]] =
|
|
521
|
+
direction: Union[MetricDirection, Dict[str, MetricDirection]] = _UNSET,
|
|
31
522
|
compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
|
|
32
523
|
connects_to=None):
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
524
|
+
name_to_unique_name = defaultdict(set)
|
|
525
|
+
|
|
526
|
+
def decorating_function(
|
|
527
|
+
user_function: Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs,
|
|
528
|
+
ConfusionMatrixCallableInterfaceMultiArgs]):
|
|
529
|
+
nonlocal direction
|
|
530
|
+
|
|
531
|
+
direction_was_provided = direction is not _UNSET
|
|
532
|
+
|
|
533
|
+
if not direction_was_provided:
|
|
534
|
+
direction = MetricDirection.Downward
|
|
535
|
+
if not _call_from_tl_platform:
|
|
536
|
+
store_warning_by_param(
|
|
537
|
+
param_name="direction",
|
|
538
|
+
user_func_name=user_function.__name__,
|
|
539
|
+
default_value=direction,
|
|
540
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/custom-metrics"
|
|
541
|
+
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
def _validate_decorators_signature():
|
|
545
|
+
err_message = f"{user_function.__name__} validation failed.\n"
|
|
546
|
+
if not isinstance(name, str):
|
|
547
|
+
raise TypeError(err_message + f"`name` must be a string, got type {type(name).__name__}.")
|
|
548
|
+
valid_directions = {MetricDirection.Upward, MetricDirection.Downward}
|
|
549
|
+
if isinstance(direction, MetricDirection):
|
|
550
|
+
if direction not in valid_directions:
|
|
551
|
+
raise ValueError(
|
|
552
|
+
err_message +
|
|
553
|
+
f"Invalid MetricDirection: {direction}. Must be one of {valid_directions}, "
|
|
554
|
+
f"got type {type(direction).__name__}."
|
|
555
|
+
)
|
|
556
|
+
elif isinstance(direction, dict):
|
|
557
|
+
if not all(isinstance(k, str) for k in direction.keys()):
|
|
558
|
+
invalid_keys = {k: type(k).__name__ for k in direction.keys() if not isinstance(k, str)}
|
|
559
|
+
raise TypeError(
|
|
560
|
+
err_message +
|
|
561
|
+
f"All keys in `direction` must be strings, got invalid key types: {invalid_keys}."
|
|
562
|
+
)
|
|
563
|
+
for k, v in direction.items():
|
|
564
|
+
if v not in valid_directions:
|
|
565
|
+
raise ValueError(
|
|
566
|
+
err_message +
|
|
567
|
+
f"Invalid direction for key '{k}': {v}. Must be one of {valid_directions}, "
|
|
568
|
+
f"got type {type(v).__name__}."
|
|
569
|
+
)
|
|
570
|
+
else:
|
|
571
|
+
raise TypeError(
|
|
572
|
+
err_message +
|
|
573
|
+
f"`direction` must be a MetricDirection or a Dict[str, MetricDirection], "
|
|
574
|
+
f"got type {type(direction).__name__}."
|
|
575
|
+
)
|
|
576
|
+
if compute_insights is not None:
|
|
577
|
+
if not isinstance(compute_insights, (bool, dict)):
|
|
578
|
+
raise TypeError(
|
|
579
|
+
err_message +
|
|
580
|
+
f"`compute_insights` must be a bool or a Dict[str, bool], "
|
|
581
|
+
f"got type {type(compute_insights).__name__}."
|
|
582
|
+
)
|
|
583
|
+
if isinstance(compute_insights, dict):
|
|
584
|
+
if not all(isinstance(k, str) for k in compute_insights.keys()):
|
|
585
|
+
invalid_keys = {k: type(k).__name__ for k in compute_insights.keys() if not isinstance(k, str)}
|
|
586
|
+
raise TypeError(
|
|
587
|
+
err_message +
|
|
588
|
+
f"All keys in `compute_insights` must be strings, got invalid key types: {invalid_keys}."
|
|
589
|
+
)
|
|
590
|
+
for k, v in compute_insights.items():
|
|
591
|
+
if not isinstance(v, bool):
|
|
592
|
+
raise TypeError(
|
|
593
|
+
err_message +
|
|
594
|
+
f"Invalid type for compute_insights['{k}']: expected bool, got type {type(v).__name__}."
|
|
595
|
+
)
|
|
596
|
+
if connects_to is not None:
|
|
597
|
+
valid_types = (str, list, tuple, set)
|
|
598
|
+
if not isinstance(connects_to, valid_types):
|
|
599
|
+
raise TypeError(
|
|
600
|
+
err_message +
|
|
601
|
+
f"`connects_to` must be one of {valid_types}, got type {type(connects_to).__name__}."
|
|
602
|
+
)
|
|
603
|
+
if isinstance(connects_to, (list, tuple, set)):
|
|
604
|
+
invalid_elems = [f"{type(e).__name__}" for e in connects_to if not isinstance(e, str)]
|
|
605
|
+
if invalid_elems:
|
|
606
|
+
raise TypeError(
|
|
607
|
+
err_message +
|
|
608
|
+
f"All elements in `connects_to` must be strings, "
|
|
609
|
+
f"but found element types: {invalid_elems}."
|
|
610
|
+
)
|
|
611
|
+
|
|
612
|
+
_validate_decorators_signature()
|
|
613
|
+
|
|
36
614
|
for metric_handler in leap_binder.setup_container.metrics:
|
|
37
615
|
if metric_handler.metric_handler_data.name == name:
|
|
38
616
|
raise Exception(f'Metric with name {name} already exists. '
|
|
39
617
|
f'Please choose another')
|
|
40
618
|
|
|
41
|
-
leap_binder.add_custom_metric(user_function, name, direction, compute_insights)
|
|
42
|
-
|
|
43
|
-
if connects_to is not None:
|
|
44
|
-
arg_names = leap_binder.setup_container.metrics[-1].metric_handler_data.arg_names
|
|
45
|
-
_add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
|
|
46
|
-
|
|
47
619
|
def _validate_input_args(*args, **kwargs) -> None:
|
|
620
|
+
assert len(args) + len(kwargs) > 0, (
|
|
621
|
+
f"{user_function.__name__}() validation failed: "
|
|
622
|
+
f"Expected at least one positional|key-word argument of type np.ndarray, "
|
|
623
|
+
f"but received none. "
|
|
624
|
+
f"Correct usage example: tensorleap_custom_metric(input_array: np.ndarray, ...)"
|
|
625
|
+
)
|
|
48
626
|
for i, arg in enumerate(args):
|
|
49
627
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
50
|
-
f'
|
|
628
|
+
f'{user_function.__name__}() validation failed: '
|
|
51
629
|
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
52
630
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
53
631
|
assert arg.shape[0] == leap_binder.batch_size_to_validate, \
|
|
54
|
-
(f'
|
|
632
|
+
(f'{user_function.__name__}() validation failed: Argument #{i} '
|
|
55
633
|
f'first dim should be as the batch size. Got {arg.shape[0]} '
|
|
56
634
|
f'instead of {leap_binder.batch_size_to_validate}')
|
|
57
635
|
|
|
58
636
|
for _arg_name, arg in kwargs.items():
|
|
59
637
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
60
|
-
f'
|
|
638
|
+
f'{user_function.__name__}() validation failed: '
|
|
61
639
|
f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
62
640
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
63
641
|
assert arg.shape[0] == leap_binder.batch_size_to_validate, \
|
|
64
|
-
(f'
|
|
642
|
+
(f'{user_function.__name__}() validation failed: Argument {_arg_name} '
|
|
65
643
|
f'first dim should be as the batch size. Got {arg.shape[0]} '
|
|
66
644
|
f'instead of {leap_binder.batch_size_to_validate}')
|
|
67
645
|
|
|
68
646
|
def _validate_result(result) -> None:
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
647
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
648
|
+
expected_type_name="List[float | int | None | List[ConfusionMatrixElement] ] | NDArray[np.float32] or dictonary with one of these types as its values types")
|
|
649
|
+
supported_types_message = (f'{user_function.__name__}() validation failed: '
|
|
650
|
+
f'{user_function.__name__}() has returned unsupported type.\nSupported types are List[float|int|None], '
|
|
651
|
+
f'List[List[ConfusionMatrixElement]], NDArray[np.float32] or dictonary with one of these types as its values types. ')
|
|
72
652
|
|
|
73
|
-
def _validate_single_metric(single_metric_result):
|
|
653
|
+
def _validate_single_metric(single_metric_result, key=None):
|
|
74
654
|
if isinstance(single_metric_result, list):
|
|
75
655
|
if isinstance(single_metric_result[0], list):
|
|
76
|
-
assert isinstance(single_metric_result[0]
|
|
77
|
-
f
|
|
656
|
+
assert all(isinstance(cm, ConfusionMatrixElement) for cm in single_metric_result[0]), (
|
|
657
|
+
f"{supported_types_message} "
|
|
658
|
+
f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
|
|
659
|
+
f"List[List[{', '.join(type(cm).__name__ for cm in single_metric_result[0])}]]."
|
|
660
|
+
)
|
|
661
|
+
|
|
78
662
|
else:
|
|
79
|
-
assert isinstance(single_metric_result
|
|
80
|
-
|
|
663
|
+
assert all(isinstance(v, (float, int, type(None), np.float32)) for v in single_metric_result), (
|
|
664
|
+
f"{supported_types_message}\n"
|
|
665
|
+
f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
|
|
666
|
+
f"List[{', '.join(type(v).__name__ for v in single_metric_result)}]."
|
|
667
|
+
)
|
|
81
668
|
else:
|
|
82
669
|
assert isinstance(single_metric_result,
|
|
83
|
-
np.ndarray), f'{supported_types_message}
|
|
84
|
-
assert len(single_metric_result.shape) == 1, (f'
|
|
670
|
+
np.ndarray), f'{supported_types_message}\nGot {type(single_metric_result)}.'
|
|
671
|
+
assert len(single_metric_result.shape) == 1, (f'{user_function.__name__}() validation failed: '
|
|
85
672
|
f'The return shape should be 1D. Got {len(single_metric_result.shape)}D.')
|
|
86
673
|
|
|
87
674
|
if leap_binder.batch_size_to_validate:
|
|
88
675
|
assert len(single_metric_result) == leap_binder.batch_size_to_validate, \
|
|
89
|
-
f'
|
|
676
|
+
f'{user_function.__name__}() validation failed: The return len {f"of srt{key} value" if key is not None else ""} should be as the batch size.'
|
|
90
677
|
|
|
91
678
|
if isinstance(result, dict):
|
|
92
679
|
for key, value in result.items():
|
|
680
|
+
_validate_single_metric(value, key)
|
|
681
|
+
|
|
93
682
|
assert isinstance(key, str), \
|
|
94
|
-
(f'
|
|
683
|
+
(f'{user_function.__name__}() validation failed: '
|
|
95
684
|
f'Keys in the return dict should be of type str. Got {type(key)}.')
|
|
96
685
|
_validate_single_metric(value)
|
|
97
686
|
|
|
98
687
|
if isinstance(direction, dict):
|
|
99
688
|
for direction_key in direction:
|
|
100
689
|
assert direction_key in result, \
|
|
101
|
-
(f'
|
|
690
|
+
(f'{user_function.__name__}() validation failed: '
|
|
102
691
|
f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
|
|
103
692
|
|
|
104
693
|
if compute_insights is not None:
|
|
105
694
|
assert isinstance(compute_insights, dict), \
|
|
106
|
-
(f'
|
|
695
|
+
(f'{user_function.__name__}() validation failed: '
|
|
107
696
|
f'compute_insights should be dict if using the dict results. Got {type(compute_insights)}.')
|
|
108
697
|
|
|
109
698
|
for ci_key in compute_insights:
|
|
110
699
|
assert ci_key in result, \
|
|
111
|
-
(f'
|
|
700
|
+
(f'{user_function.__name__}() validation failed: '
|
|
112
701
|
f'Keys in the compute_insights mapping should be part of result keys. Got key {ci_key}.')
|
|
113
702
|
|
|
114
703
|
else:
|
|
@@ -116,17 +705,71 @@ def tensorleap_custom_metric(name: str,
|
|
|
116
705
|
|
|
117
706
|
if compute_insights is not None:
|
|
118
707
|
assert isinstance(compute_insights, bool), \
|
|
119
|
-
(f'
|
|
708
|
+
(f'{user_function.__name__}() validation failed: '
|
|
120
709
|
f'compute_insights should be boolean. Got {type(compute_insights)}.')
|
|
121
710
|
|
|
711
|
+
@functools.wraps(user_function)
|
|
712
|
+
def inner_without_validate(*args, **kwargs):
|
|
713
|
+
global _called_from_inside_tl_decorator
|
|
714
|
+
_called_from_inside_tl_decorator += 1
|
|
715
|
+
|
|
716
|
+
try:
|
|
717
|
+
result = user_function(*args, **kwargs)
|
|
718
|
+
finally:
|
|
719
|
+
_called_from_inside_tl_decorator -= 1
|
|
720
|
+
|
|
721
|
+
return result
|
|
722
|
+
|
|
723
|
+
try:
|
|
724
|
+
inner_without_validate.__signature__ = inspect.signature(user_function)
|
|
725
|
+
except (TypeError, ValueError):
|
|
726
|
+
pass
|
|
727
|
+
|
|
728
|
+
leap_binder.add_custom_metric(inner_without_validate, name, direction, compute_insights)
|
|
729
|
+
|
|
730
|
+
if connects_to is not None:
|
|
731
|
+
arg_names = leap_binder.setup_container.metrics[-1].metric_handler_data.arg_names
|
|
732
|
+
_add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
|
|
122
733
|
|
|
123
734
|
def inner(*args, **kwargs):
|
|
735
|
+
if not _call_from_tl_platform:
|
|
736
|
+
set_current('tensorleap_custom_metric')
|
|
124
737
|
_validate_input_args(*args, **kwargs)
|
|
125
|
-
|
|
738
|
+
|
|
739
|
+
result = inner_without_validate(*args, **kwargs)
|
|
740
|
+
|
|
126
741
|
_validate_result(result)
|
|
742
|
+
if not _call_from_tl_platform:
|
|
743
|
+
update_env_params_func("tensorleap_custom_metric", "v")
|
|
127
744
|
return result
|
|
128
745
|
|
|
129
|
-
|
|
746
|
+
def mapping_inner(*args, **kwargs):
|
|
747
|
+
user_unique_name = mapping_inner.name
|
|
748
|
+
if 'user_unique_name' in kwargs:
|
|
749
|
+
user_unique_name = kwargs['user_unique_name']
|
|
750
|
+
|
|
751
|
+
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
752
|
+
ordered_connections = list(args) + ordered_connections
|
|
753
|
+
|
|
754
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
755
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
756
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
757
|
+
|
|
758
|
+
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
759
|
+
mapping_inner.name, NodeMappingType.Metric)
|
|
760
|
+
|
|
761
|
+
return None
|
|
762
|
+
|
|
763
|
+
mapping_inner.arg_names = leap_binder.setup_container.metrics[-1].metric_handler_data.arg_names
|
|
764
|
+
mapping_inner.name = name
|
|
765
|
+
|
|
766
|
+
def final_inner(*args, **kwargs):
|
|
767
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
768
|
+
return mapping_inner(*args, **kwargs)
|
|
769
|
+
else:
|
|
770
|
+
return inner(*args, **kwargs)
|
|
771
|
+
|
|
772
|
+
return final_inner
|
|
130
773
|
|
|
131
774
|
return decorating_function
|
|
132
775
|
|
|
@@ -134,35 +777,40 @@ def tensorleap_custom_metric(name: str,
|
|
|
134
777
|
def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
135
778
|
heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
|
|
136
779
|
connects_to=None):
|
|
780
|
+
name_to_unique_name = defaultdict(set)
|
|
781
|
+
|
|
137
782
|
def decorating_function(user_function: VisualizerCallableInterface):
|
|
783
|
+
assert isinstance(visualizer_type, LeapDataType), (f"{user_function.__name__} validation failed: "
|
|
784
|
+
f"visualizer_type should be of type {LeapDataType.__name__} but got {type(visualizer_type)}"
|
|
785
|
+
)
|
|
138
786
|
for viz_handler in leap_binder.setup_container.visualizers:
|
|
139
787
|
if viz_handler.visualizer_handler_data.name == name:
|
|
140
788
|
raise Exception(f'Visualizer with name {name} already exists. '
|
|
141
789
|
f'Please choose another')
|
|
142
790
|
|
|
143
|
-
leap_binder.set_visualizer(user_function, name, visualizer_type, heatmap_function)
|
|
144
|
-
|
|
145
|
-
if connects_to is not None:
|
|
146
|
-
arg_names = leap_binder.setup_container.visualizers[-1].visualizer_handler_data.arg_names
|
|
147
|
-
_add_mapping_connections(connects_to, arg_names, NodeMappingType.Visualizer, name)
|
|
148
|
-
|
|
149
791
|
def _validate_input_args(*args, **kwargs):
|
|
792
|
+
assert len(args) + len(kwargs) > 0, (
|
|
793
|
+
f"{user_function.__name__}() validation failed: "
|
|
794
|
+
f"Expected at least one positional|key-word argument of type np.ndarray, "
|
|
795
|
+
f"but received none. "
|
|
796
|
+
f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
|
|
797
|
+
)
|
|
150
798
|
for i, arg in enumerate(args):
|
|
151
799
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
152
|
-
f'
|
|
800
|
+
f'{user_function.__name__}() validation failed: '
|
|
153
801
|
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
154
802
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
155
803
|
assert arg.shape[0] != leap_binder.batch_size_to_validate, \
|
|
156
|
-
(f'
|
|
804
|
+
(f'{user_function.__name__}() validation failed: '
|
|
157
805
|
f'Argument #{i} should be without batch dimension. ')
|
|
158
806
|
|
|
159
807
|
for _arg_name, arg in kwargs.items():
|
|
160
808
|
assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
|
|
161
|
-
f'
|
|
809
|
+
f'{user_function.__name__}() validation failed: '
|
|
162
810
|
f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
163
811
|
if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
|
|
164
812
|
assert arg.shape[0] != leap_binder.batch_size_to_validate, \
|
|
165
|
-
(f'
|
|
813
|
+
(f'{user_function.__name__}() validation failed: Argument {_arg_name} '
|
|
166
814
|
f'should be without batch dimension. ')
|
|
167
815
|
|
|
168
816
|
def _validate_result(result):
|
|
@@ -176,17 +824,74 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
|
|
|
176
824
|
LeapDataType.ImageWithBBox: LeapImageWithBBox,
|
|
177
825
|
LeapDataType.ImageWithHeatmap: LeapImageWithHeatmap
|
|
178
826
|
}
|
|
827
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
828
|
+
expected_type_name=result_type_map[visualizer_type])
|
|
829
|
+
|
|
179
830
|
assert isinstance(result, result_type_map[visualizer_type]), \
|
|
180
|
-
(f'
|
|
831
|
+
(f'{user_function.__name__}() validation failed: '
|
|
181
832
|
f'The return type should be {result_type_map[visualizer_type]}. Got {type(result)}.')
|
|
182
833
|
|
|
834
|
+
@functools.wraps(user_function)
|
|
835
|
+
def inner_without_validate(*args, **kwargs):
|
|
836
|
+
global _called_from_inside_tl_decorator
|
|
837
|
+
_called_from_inside_tl_decorator += 1
|
|
838
|
+
|
|
839
|
+
try:
|
|
840
|
+
result = user_function(*args, **kwargs)
|
|
841
|
+
finally:
|
|
842
|
+
_called_from_inside_tl_decorator -= 1
|
|
843
|
+
|
|
844
|
+
return result
|
|
845
|
+
|
|
846
|
+
try:
|
|
847
|
+
inner_without_validate.__signature__ = inspect.signature(user_function)
|
|
848
|
+
except (TypeError, ValueError):
|
|
849
|
+
pass
|
|
850
|
+
|
|
851
|
+
leap_binder.set_visualizer(inner_without_validate, name, visualizer_type, heatmap_function)
|
|
852
|
+
|
|
853
|
+
if connects_to is not None:
|
|
854
|
+
arg_names = leap_binder.setup_container.visualizers[-1].visualizer_handler_data.arg_names
|
|
855
|
+
_add_mapping_connections(connects_to, arg_names, NodeMappingType.Visualizer, name)
|
|
856
|
+
|
|
183
857
|
def inner(*args, **kwargs):
|
|
858
|
+
if not _call_from_tl_platform:
|
|
859
|
+
set_current('tensorleap_custom_visualizer')
|
|
184
860
|
_validate_input_args(*args, **kwargs)
|
|
185
|
-
|
|
861
|
+
|
|
862
|
+
result = inner_without_validate(*args, **kwargs)
|
|
863
|
+
|
|
186
864
|
_validate_result(result)
|
|
865
|
+
if not _call_from_tl_platform:
|
|
866
|
+
update_env_params_func("tensorleap_custom_visualizer", "v")
|
|
187
867
|
return result
|
|
188
868
|
|
|
189
|
-
|
|
869
|
+
def mapping_inner(*args, **kwargs):
|
|
870
|
+
user_unique_name = mapping_inner.name
|
|
871
|
+
if 'user_unique_name' in kwargs:
|
|
872
|
+
user_unique_name = kwargs['user_unique_name']
|
|
873
|
+
|
|
874
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
875
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
876
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
877
|
+
|
|
878
|
+
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
879
|
+
ordered_connections = list(args) + ordered_connections
|
|
880
|
+
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
881
|
+
mapping_inner.name, NodeMappingType.Visualizer)
|
|
882
|
+
|
|
883
|
+
return None
|
|
884
|
+
|
|
885
|
+
mapping_inner.arg_names = leap_binder.setup_container.visualizers[-1].visualizer_handler_data.arg_names
|
|
886
|
+
mapping_inner.name = name
|
|
887
|
+
|
|
888
|
+
def final_inner(*args, **kwargs):
|
|
889
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
890
|
+
return mapping_inner(*args, **kwargs)
|
|
891
|
+
else:
|
|
892
|
+
return inner(*args, **kwargs)
|
|
893
|
+
|
|
894
|
+
return final_inner
|
|
190
895
|
|
|
191
896
|
return decorating_function
|
|
192
897
|
|
|
@@ -199,38 +904,105 @@ def tensorleap_metadata(
|
|
|
199
904
|
raise Exception(f'Metadata with name {name} already exists. '
|
|
200
905
|
f'Please choose another')
|
|
201
906
|
|
|
202
|
-
leap_binder.set_metadata(user_function, name, metadata_type)
|
|
203
|
-
|
|
204
907
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
205
|
-
assert isinstance(sample_id, (int, str)), \
|
|
206
|
-
(f'tensorleap_metadata validation failed: '
|
|
207
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
208
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
209
|
-
(f'tensorleap_metadata validation failed: '
|
|
210
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
211
908
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
212
|
-
(f'
|
|
909
|
+
(f'{user_function.__name__}() validation failed: '
|
|
213
910
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
214
911
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
215
912
|
|
|
216
913
|
def _validate_result(result):
|
|
217
914
|
supported_result_types = (type(None), int, str, bool, float, dict, np.floating,
|
|
218
915
|
np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
|
|
916
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
917
|
+
expected_type_name=supported_result_types)
|
|
219
918
|
assert isinstance(result, supported_result_types), \
|
|
220
|
-
(f'
|
|
919
|
+
(f'{user_function.__name__}() validation failed: '
|
|
221
920
|
f'Unsupported return type. Got {type(result)}. should be any of {str(supported_result_types)}')
|
|
222
921
|
if isinstance(result, dict):
|
|
223
922
|
for key, value in result.items():
|
|
224
923
|
assert isinstance(key, str), \
|
|
225
|
-
(f'
|
|
924
|
+
(f'{user_function.__name__}() validation failed: '
|
|
226
925
|
f'Keys in the return dict should be of type str. Got {type(key)}.')
|
|
227
926
|
assert isinstance(value, supported_result_types), \
|
|
228
|
-
(f'
|
|
927
|
+
(f'{user_function.__name__}() validation failed: '
|
|
229
928
|
f'Values in the return dict should be of type {str(supported_result_types)}. Got {type(value)}.')
|
|
230
929
|
|
|
930
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
931
|
+
|
|
932
|
+
global _called_from_inside_tl_decorator
|
|
933
|
+
_called_from_inside_tl_decorator += 1
|
|
934
|
+
|
|
935
|
+
try:
|
|
936
|
+
result = user_function(sample_id, preprocess_response)
|
|
937
|
+
finally:
|
|
938
|
+
_called_from_inside_tl_decorator -= 1
|
|
939
|
+
|
|
940
|
+
return result
|
|
941
|
+
|
|
942
|
+
leap_binder.set_metadata(inner_without_validate, name, metadata_type)
|
|
943
|
+
|
|
944
|
+
def inner(*args, **kwargs):
|
|
945
|
+
if not _call_from_tl_platform:
|
|
946
|
+
set_current('tensorleap_metadata')
|
|
947
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
948
|
+
return None
|
|
949
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
950
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
951
|
+
sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
|
|
952
|
+
_validate_input_args(sample_id, preprocess_response)
|
|
953
|
+
|
|
954
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
955
|
+
|
|
956
|
+
_validate_result(result)
|
|
957
|
+
if not _call_from_tl_platform:
|
|
958
|
+
update_env_params_func("tensorleap_metadata", "v")
|
|
959
|
+
return result
|
|
960
|
+
|
|
961
|
+
return inner
|
|
962
|
+
|
|
963
|
+
return decorating_function
|
|
964
|
+
|
|
965
|
+
|
|
966
|
+
def tensorleap_custom_latent_space():
|
|
967
|
+
def decorating_function(user_function: SectionCallableInterface):
|
|
968
|
+
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
969
|
+
assert isinstance(sample_id, (int, str)), \
|
|
970
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
971
|
+
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
972
|
+
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
973
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
974
|
+
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
975
|
+
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
976
|
+
(f'tensorleap_custom_latent_space validation failed: '
|
|
977
|
+
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
978
|
+
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
979
|
+
|
|
980
|
+
def _validate_result(result):
|
|
981
|
+
assert isinstance(result, np.ndarray), \
|
|
982
|
+
(f'tensorleap_custom_loss validation failed: '
|
|
983
|
+
f'The return type should be a numpy array. Got {type(result)}.')
|
|
984
|
+
|
|
985
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
986
|
+
global _called_from_inside_tl_decorator
|
|
987
|
+
_called_from_inside_tl_decorator += 1
|
|
988
|
+
|
|
989
|
+
try:
|
|
990
|
+
result = user_function(sample_id, preprocess_response)
|
|
991
|
+
finally:
|
|
992
|
+
_called_from_inside_tl_decorator -= 1
|
|
993
|
+
|
|
994
|
+
return result
|
|
995
|
+
|
|
996
|
+
leap_binder.set_custom_latent_space(inner_without_validate)
|
|
997
|
+
|
|
231
998
|
def inner(sample_id, preprocess_response):
|
|
999
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1000
|
+
return None
|
|
1001
|
+
|
|
232
1002
|
_validate_input_args(sample_id, preprocess_response)
|
|
233
|
-
|
|
1003
|
+
|
|
1004
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
1005
|
+
|
|
234
1006
|
_validate_result(result)
|
|
235
1007
|
return result
|
|
236
1008
|
|
|
@@ -244,33 +1016,54 @@ def tensorleap_preprocess():
|
|
|
244
1016
|
leap_binder.set_preprocess(user_function)
|
|
245
1017
|
|
|
246
1018
|
def _validate_input_args(*args, **kwargs):
|
|
247
|
-
assert len(args)
|
|
248
|
-
(f'
|
|
1019
|
+
assert len(args) + len(kwargs) == 0, \
|
|
1020
|
+
(f'{user_function.__name__}() validation failed: '
|
|
249
1021
|
f'The function should not take any arguments. Got {args} and {kwargs}.')
|
|
250
1022
|
|
|
251
1023
|
def _validate_result(result):
|
|
252
|
-
assert isinstance(result, list),
|
|
253
|
-
(
|
|
254
|
-
|
|
1024
|
+
assert isinstance(result, list), (
|
|
1025
|
+
f"{user_function.__name__}() validation failed: expected return type list[{PreprocessResponse.__name__}]"
|
|
1026
|
+
f"(e.g., [PreprocessResponse1, PreprocessResponse2, ...]), but returned type is {type(result).__name__}."
|
|
1027
|
+
if not isinstance(result, tuple)
|
|
1028
|
+
else f"{user_function.__name__}() validation failed: expected to return a single list[{PreprocessResponse.__name__}] object, "
|
|
1029
|
+
f"but returned {len(result)} objects instead."
|
|
1030
|
+
)
|
|
255
1031
|
for i, response in enumerate(result):
|
|
256
1032
|
assert isinstance(response, PreprocessResponse), \
|
|
257
|
-
(f'
|
|
1033
|
+
(f'{user_function.__name__}() validation failed: '
|
|
258
1034
|
f'Element #{i} in the return list should be a PreprocessResponse. Got {type(response)}.')
|
|
259
1035
|
assert len(set(result)) == len(result), \
|
|
260
|
-
(f'
|
|
1036
|
+
(f'{user_function.__name__}() validation failed: '
|
|
261
1037
|
f'The return list should not contain duplicate PreprocessResponse objects.')
|
|
262
1038
|
|
|
263
1039
|
def inner(*args, **kwargs):
|
|
1040
|
+
if not _call_from_tl_platform:
|
|
1041
|
+
set_current('tensorleap_metadata')
|
|
1042
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1043
|
+
return [None, None, None, None]
|
|
1044
|
+
|
|
264
1045
|
_validate_input_args(*args, **kwargs)
|
|
265
1046
|
result = user_function()
|
|
266
1047
|
_validate_result(result)
|
|
1048
|
+
|
|
1049
|
+
# Emit integration test event once per test
|
|
1050
|
+
try:
|
|
1051
|
+
emit_integration_event_once(AnalyticsEvent.PREPROCESS_INTEGRATION_TEST, {
|
|
1052
|
+
'preprocess_responses_count': len(result)
|
|
1053
|
+
})
|
|
1054
|
+
except Exception as e:
|
|
1055
|
+
logger.debug(f"Failed to emit preprocess integration test event: {e}")
|
|
1056
|
+
if not _call_from_tl_platform:
|
|
1057
|
+
update_env_params_func("tensorleap_preprocess", "v")
|
|
267
1058
|
return result
|
|
268
1059
|
|
|
269
1060
|
return inner
|
|
270
1061
|
|
|
271
1062
|
return decorating_function
|
|
272
1063
|
|
|
273
|
-
|
|
1064
|
+
|
|
1065
|
+
def tensorleap_element_instance_preprocess(
|
|
1066
|
+
instance_length_encoder: InstanceLengthCallableInterface):
|
|
274
1067
|
def decorating_function(user_function: Callable[[], List[PreprocessResponse]]):
|
|
275
1068
|
def user_function_instance() -> List[PreprocessResponse]:
|
|
276
1069
|
result = user_function()
|
|
@@ -279,22 +1072,46 @@ def tensorleap_element_instance_preprocess(instance_mask_encoder: Callable[[str,
|
|
|
279
1072
|
instance_to_sample_ids_mappings = {}
|
|
280
1073
|
all_sample_ids = preprocess_response.sample_ids.copy()
|
|
281
1074
|
for sample_id in preprocess_response.sample_ids:
|
|
282
|
-
|
|
283
|
-
instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(
|
|
1075
|
+
instances_length = instance_length_encoder(sample_id, preprocess_response)
|
|
1076
|
+
instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(instances_length)]
|
|
284
1077
|
sample_ids_to_instance_mappings[sample_id] = instances_ids
|
|
285
1078
|
instance_to_sample_ids_mappings[sample_id] = sample_id
|
|
286
1079
|
for instance_id in instances_ids:
|
|
287
1080
|
instance_to_sample_ids_mappings[instance_id] = sample_id
|
|
288
1081
|
all_sample_ids.extend(instances_ids)
|
|
1082
|
+
preprocess_response.length = len(all_sample_ids)
|
|
289
1083
|
preprocess_response.sample_ids_to_instance_mappings = sample_ids_to_instance_mappings
|
|
290
1084
|
preprocess_response.instance_to_sample_ids_mappings = instance_to_sample_ids_mappings
|
|
291
1085
|
preprocess_response.sample_ids = all_sample_ids
|
|
292
1086
|
return result
|
|
293
1087
|
|
|
294
|
-
def
|
|
295
|
-
|
|
1088
|
+
# def extract_extra_instance_metadata():
|
|
1089
|
+
# result = user_function()
|
|
1090
|
+
# for preprocess_response in result:
|
|
1091
|
+
# for sample_id in preprocess_response.sample_ids:
|
|
1092
|
+
# instances_length = instance_length_encoder(sample_id, preprocess_response)
|
|
1093
|
+
# if instances_length > 0:
|
|
1094
|
+
# element_instance = instance_mask_encoder(sample_id, preprocess_response, 0)
|
|
1095
|
+
# instance_metadata = element_instance.instance_metadata
|
|
1096
|
+
# if instance_metadata is None:
|
|
1097
|
+
# return {}
|
|
1098
|
+
# return instance_metadata
|
|
1099
|
+
# return {}
|
|
1100
|
+
|
|
1101
|
+
|
|
1102
|
+
def builtin_instance_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
|
|
1103
|
+
return {'is_instance': '0', 'original_sample_id': idx, 'instance_name': 'none'}
|
|
1104
|
+
|
|
1105
|
+
# def builtin_instance_extra_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
|
|
1106
|
+
# instance_metadata = extract_extra_instance_metadata()
|
|
1107
|
+
# for key, value in instance_metadata.items():
|
|
1108
|
+
# instance_metadata[key] = 'unset'
|
|
1109
|
+
# return instance_metadata
|
|
1110
|
+
|
|
296
1111
|
leap_binder.set_preprocess(user_function_instance)
|
|
297
|
-
leap_binder.set_metadata(
|
|
1112
|
+
leap_binder.set_metadata(builtin_instance_metadata, "builtin_instance_metadata")
|
|
1113
|
+
# leap_binder.set_metadata(builtin_instance_extra_metadata, "builtin_instance_extra_metadata")
|
|
1114
|
+
|
|
298
1115
|
|
|
299
1116
|
def _validate_input_args(*args, **kwargs):
|
|
300
1117
|
assert len(args) == 0 and len(kwargs) == 0, \
|
|
@@ -314,9 +1131,15 @@ def tensorleap_element_instance_preprocess(instance_mask_encoder: Callable[[str,
|
|
|
314
1131
|
f'The return list should not contain duplicate PreprocessResponse objects.')
|
|
315
1132
|
|
|
316
1133
|
def inner(*args, **kwargs):
|
|
1134
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1135
|
+
return [None, None, None, None]
|
|
1136
|
+
|
|
317
1137
|
_validate_input_args(*args, **kwargs)
|
|
1138
|
+
|
|
318
1139
|
result = user_function_instance()
|
|
319
1140
|
_validate_result(result)
|
|
1141
|
+
if not _call_from_tl_platform:
|
|
1142
|
+
update_env_params_func("tensorleap_preprocess", "v")
|
|
320
1143
|
return result
|
|
321
1144
|
|
|
322
1145
|
return inner
|
|
@@ -351,9 +1174,7 @@ def tensorleap_unlabeled_preprocess():
|
|
|
351
1174
|
|
|
352
1175
|
def tensorleap_instances_masks_encoder(name: str):
|
|
353
1176
|
def decorating_function(user_function: InstanceCallableInterface):
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
|
|
1177
|
+
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse, instance_id: int):
|
|
357
1178
|
assert isinstance(sample_id, str), \
|
|
358
1179
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
359
1180
|
f'Argument sample_id should be str. Got {type(sample_id)}.')
|
|
@@ -364,15 +1185,82 @@ def tensorleap_instances_masks_encoder(name: str):
|
|
|
364
1185
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
365
1186
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
366
1187
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
1188
|
+
assert isinstance(instance_id, int), \
|
|
1189
|
+
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
1190
|
+
f'Argument instance_id should be int. Got {type(instance_id)}.')
|
|
367
1191
|
|
|
368
1192
|
def _validate_result(result):
|
|
369
|
-
assert isinstance(result,
|
|
1193
|
+
assert isinstance(result, ElementInstance) or (result is None), \
|
|
370
1194
|
(f'tensorleap_instances_masks_encoder validation failed: '
|
|
371
|
-
f'Unsupported return type. Should be a
|
|
1195
|
+
f'Unsupported return type. Should be a ElementInstance or None. Got {type(result)}.')
|
|
1196
|
+
|
|
1197
|
+
def inner_without_validate(sample_id, preprocess_response, instance_id):
|
|
1198
|
+
global _called_from_inside_tl_decorator
|
|
1199
|
+
_called_from_inside_tl_decorator += 1
|
|
1200
|
+
|
|
1201
|
+
try:
|
|
1202
|
+
result = user_function(sample_id, preprocess_response, instance_id)
|
|
1203
|
+
finally:
|
|
1204
|
+
_called_from_inside_tl_decorator -= 1
|
|
1205
|
+
|
|
1206
|
+
return result
|
|
1207
|
+
|
|
1208
|
+
leap_binder.set_instance_masks(inner_without_validate, name)
|
|
1209
|
+
|
|
1210
|
+
def inner(sample_id, preprocess_response, instance_id):
|
|
1211
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1212
|
+
return None
|
|
1213
|
+
|
|
1214
|
+
_validate_input_args(sample_id, preprocess_response, instance_id)
|
|
1215
|
+
|
|
1216
|
+
result = inner_without_validate(sample_id, preprocess_response, instance_id)
|
|
1217
|
+
|
|
1218
|
+
_validate_result(result)
|
|
1219
|
+
return result
|
|
1220
|
+
|
|
1221
|
+
return inner
|
|
1222
|
+
|
|
1223
|
+
return decorating_function
|
|
1224
|
+
|
|
1225
|
+
|
|
1226
|
+
def tensorleap_instances_length_encoder(name: str):
|
|
1227
|
+
def decorating_function(user_function: InstanceLengthCallableInterface):
|
|
1228
|
+
def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
|
|
1229
|
+
assert isinstance(sample_id, (str, int)), \
|
|
1230
|
+
(f'tensorleap_instances_length_encoder validation failed: '
|
|
1231
|
+
f'Argument sample_id should be str. Got {type(sample_id)}.')
|
|
1232
|
+
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
1233
|
+
(f'tensorleap_instances_length_encoder validation failed: '
|
|
1234
|
+
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
1235
|
+
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
1236
|
+
(f'tensorleap_instances_length_encoder validation failed: '
|
|
1237
|
+
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
1238
|
+
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
1239
|
+
|
|
1240
|
+
def _validate_result(result):
|
|
1241
|
+
assert isinstance(result, int), \
|
|
1242
|
+
(f'tensorleap_instances_length_encoder validation failed: '
|
|
1243
|
+
f'Unsupported return type. Should be a int. Got {type(result)}.')
|
|
1244
|
+
|
|
1245
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
1246
|
+
global _called_from_inside_tl_decorator
|
|
1247
|
+
_called_from_inside_tl_decorator += 1
|
|
1248
|
+
|
|
1249
|
+
try:
|
|
1250
|
+
result = user_function(sample_id, preprocess_response)
|
|
1251
|
+
finally:
|
|
1252
|
+
_called_from_inside_tl_decorator -= 1
|
|
1253
|
+
|
|
1254
|
+
return result
|
|
372
1255
|
|
|
373
1256
|
def inner(sample_id, preprocess_response):
|
|
1257
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1258
|
+
return None
|
|
1259
|
+
|
|
374
1260
|
_validate_input_args(sample_id, preprocess_response)
|
|
375
|
-
|
|
1261
|
+
|
|
1262
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
1263
|
+
|
|
376
1264
|
_validate_result(result)
|
|
377
1265
|
return result
|
|
378
1266
|
|
|
@@ -381,43 +1269,87 @@ def tensorleap_instances_masks_encoder(name: str):
|
|
|
381
1269
|
return decorating_function
|
|
382
1270
|
|
|
383
1271
|
|
|
384
|
-
def tensorleap_input_encoder(name: str, channel_dim
|
|
1272
|
+
def tensorleap_input_encoder(name: str, channel_dim=_UNSET, model_input_index=None):
|
|
385
1273
|
def decorating_function(user_function: SectionCallableInterface):
|
|
386
1274
|
for input_handler in leap_binder.setup_container.inputs:
|
|
387
1275
|
if input_handler.name == name:
|
|
388
1276
|
raise Exception(f'Input with name {name} already exists. '
|
|
389
1277
|
f'Please choose another')
|
|
1278
|
+
nonlocal channel_dim
|
|
1279
|
+
|
|
1280
|
+
channel_dim_was_provided = channel_dim is not _UNSET
|
|
1281
|
+
|
|
1282
|
+
if not channel_dim_was_provided:
|
|
1283
|
+
channel_dim = -1
|
|
1284
|
+
if not _call_from_tl_platform:
|
|
1285
|
+
store_warning_by_param(
|
|
1286
|
+
param_name="channel_dim",
|
|
1287
|
+
user_func_name=user_function.__name__,
|
|
1288
|
+
default_value=channel_dim,
|
|
1289
|
+
link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/input-encoder"
|
|
1290
|
+
|
|
1291
|
+
)
|
|
1292
|
+
|
|
390
1293
|
if channel_dim <= 0 and channel_dim != -1:
|
|
391
1294
|
raise Exception(f"Channel dim for input {name} is expected to be either -1 or positive")
|
|
392
1295
|
|
|
393
|
-
leap_binder.set_input(user_function, name, channel_dim=channel_dim)
|
|
394
|
-
|
|
395
1296
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
396
|
-
assert isinstance(sample_id, (int, str)), \
|
|
397
|
-
(f'tensorleap_input_encoder validation failed: '
|
|
398
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
399
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
400
|
-
(f'tensorleap_input_encoder validation failed: '
|
|
401
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
402
1297
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
403
|
-
(f'
|
|
1298
|
+
(f'{user_function.__name__}() validation failed: '
|
|
404
1299
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
405
1300
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
406
1301
|
|
|
407
1302
|
def _validate_result(result):
|
|
1303
|
+
validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray")
|
|
408
1304
|
assert isinstance(result, np.ndarray), \
|
|
409
|
-
(f'
|
|
1305
|
+
(f'{user_function.__name__}() validation failed: '
|
|
410
1306
|
f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
|
|
411
1307
|
assert result.dtype == np.float32, \
|
|
412
|
-
(f'
|
|
1308
|
+
(f'{user_function.__name__}() validation failed: '
|
|
413
1309
|
f'The return type should be a numpy array of type float32. Got {result.dtype}.')
|
|
414
|
-
assert channel_dim - 1 <= len(result.shape), (f'
|
|
1310
|
+
assert channel_dim - 1 <= len(result.shape), (f'{user_function.__name__}() validation failed: '
|
|
415
1311
|
f'The channel_dim ({channel_dim}) should be <= to the rank of the resulting input rank ({len(result.shape)}).')
|
|
416
1312
|
|
|
417
|
-
def
|
|
1313
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
1314
|
+
global _called_from_inside_tl_decorator
|
|
1315
|
+
_called_from_inside_tl_decorator += 1
|
|
1316
|
+
|
|
1317
|
+
try:
|
|
1318
|
+
result = user_function(sample_id, preprocess_response)
|
|
1319
|
+
finally:
|
|
1320
|
+
_called_from_inside_tl_decorator -= 1
|
|
1321
|
+
|
|
1322
|
+
return result
|
|
1323
|
+
|
|
1324
|
+
leap_binder.set_input(inner_without_validate, name, channel_dim=channel_dim)
|
|
1325
|
+
|
|
1326
|
+
def inner(*args, **kwargs):
|
|
1327
|
+
if not _call_from_tl_platform:
|
|
1328
|
+
set_current("tensorleap_input_encoder")
|
|
1329
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
1330
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
1331
|
+
sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
|
|
418
1332
|
_validate_input_args(sample_id, preprocess_response)
|
|
419
|
-
|
|
1333
|
+
|
|
1334
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
1335
|
+
|
|
420
1336
|
_validate_result(result)
|
|
1337
|
+
|
|
1338
|
+
if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
|
|
1339
|
+
batch_warning(result, user_function.__name__)
|
|
1340
|
+
result = np.expand_dims(result, axis=0)
|
|
1341
|
+
# Emit integration test event once per test
|
|
1342
|
+
try:
|
|
1343
|
+
emit_integration_event_once(AnalyticsEvent.INPUT_ENCODER_INTEGRATION_TEST, {
|
|
1344
|
+
'encoder_name': name,
|
|
1345
|
+
'channel_dim': channel_dim,
|
|
1346
|
+
'model_input_index': model_input_index
|
|
1347
|
+
})
|
|
1348
|
+
except Exception as e:
|
|
1349
|
+
logger.debug(f"Failed to emit input_encoder integration test event: {e}")
|
|
1350
|
+
if not _call_from_tl_platform:
|
|
1351
|
+
update_env_params_func("tensorleap_input_encoder", "v")
|
|
1352
|
+
|
|
421
1353
|
return result
|
|
422
1354
|
|
|
423
1355
|
node_mapping_type = NodeMappingType.Input
|
|
@@ -425,7 +1357,27 @@ def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
|
|
|
425
1357
|
node_mapping_type = NodeMappingType(f'Input{str(model_input_index)}')
|
|
426
1358
|
inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
427
1359
|
|
|
428
|
-
|
|
1360
|
+
def mapping_inner(*args, **kwargs):
|
|
1361
|
+
class TempMapping:
|
|
1362
|
+
pass
|
|
1363
|
+
|
|
1364
|
+
ret = TempMapping()
|
|
1365
|
+
ret.node_mapping = mapping_inner.node_mapping
|
|
1366
|
+
|
|
1367
|
+
leap_binder.mapping_connections.append(NodeConnection(mapping_inner.node_mapping, None))
|
|
1368
|
+
return ret
|
|
1369
|
+
|
|
1370
|
+
mapping_inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
1371
|
+
|
|
1372
|
+
def final_inner(*args, **kwargs):
|
|
1373
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1374
|
+
return mapping_inner(*args, **kwargs)
|
|
1375
|
+
else:
|
|
1376
|
+
return inner(*args, **kwargs)
|
|
1377
|
+
|
|
1378
|
+
final_inner.node_mapping = NodeMapping(name, node_mapping_type)
|
|
1379
|
+
|
|
1380
|
+
return final_inner
|
|
429
1381
|
|
|
430
1382
|
return decorating_function
|
|
431
1383
|
|
|
@@ -437,95 +1389,187 @@ def tensorleap_gt_encoder(name: str):
|
|
|
437
1389
|
raise Exception(f'GT with name {name} already exists. '
|
|
438
1390
|
f'Please choose another')
|
|
439
1391
|
|
|
440
|
-
leap_binder.set_ground_truth(user_function, name)
|
|
441
|
-
|
|
442
1392
|
def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
|
|
443
|
-
assert isinstance(sample_id, (int, str)), \
|
|
444
|
-
(f'tensorleap_gt_encoder validation failed: '
|
|
445
|
-
f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
|
|
446
|
-
assert isinstance(preprocess_response, PreprocessResponse), \
|
|
447
|
-
(f'tensorleap_gt_encoder validation failed: '
|
|
448
|
-
f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
|
|
449
1393
|
assert type(sample_id) == preprocess_response.sample_id_type, \
|
|
450
|
-
(f'
|
|
1394
|
+
(f'{user_function.__name__}() validation failed: '
|
|
451
1395
|
f'Argument sample_id should be as the same type as defined in the preprocess response '
|
|
452
1396
|
f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
|
|
453
1397
|
|
|
454
1398
|
def _validate_result(result):
|
|
1399
|
+
validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray",
|
|
1400
|
+
gt_flag=True)
|
|
455
1401
|
assert isinstance(result, np.ndarray), \
|
|
456
|
-
(f'
|
|
1402
|
+
(f'{user_function.__name__}() validation failed: '
|
|
457
1403
|
f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
|
|
458
1404
|
assert result.dtype == np.float32, \
|
|
459
|
-
(f'
|
|
1405
|
+
(f'{user_function.__name__}() validation failed: '
|
|
460
1406
|
f'The return type should be a numpy array of type float32. Got {result.dtype}.')
|
|
461
1407
|
|
|
462
|
-
def
|
|
1408
|
+
def inner_without_validate(sample_id, preprocess_response):
|
|
1409
|
+
global _called_from_inside_tl_decorator
|
|
1410
|
+
_called_from_inside_tl_decorator += 1
|
|
1411
|
+
|
|
1412
|
+
try:
|
|
1413
|
+
result = user_function(sample_id, preprocess_response)
|
|
1414
|
+
finally:
|
|
1415
|
+
_called_from_inside_tl_decorator -= 1
|
|
1416
|
+
|
|
1417
|
+
return result
|
|
1418
|
+
|
|
1419
|
+
leap_binder.set_ground_truth(inner_without_validate, name)
|
|
1420
|
+
|
|
1421
|
+
def inner(*args, **kwargs):
|
|
1422
|
+
if not _call_from_tl_platform:
|
|
1423
|
+
set_current("tensorleap_gt_encoder")
|
|
1424
|
+
validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
|
|
1425
|
+
func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
|
|
1426
|
+
sample_id, preprocess_response = args
|
|
463
1427
|
_validate_input_args(sample_id, preprocess_response)
|
|
464
|
-
|
|
1428
|
+
|
|
1429
|
+
result = inner_without_validate(sample_id, preprocess_response)
|
|
1430
|
+
|
|
465
1431
|
_validate_result(result)
|
|
1432
|
+
|
|
1433
|
+
if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
|
|
1434
|
+
batch_warning(result, user_function.__name__)
|
|
1435
|
+
result = np.expand_dims(result, axis=0)
|
|
1436
|
+
# Emit integration test event once per test
|
|
1437
|
+
try:
|
|
1438
|
+
emit_integration_event_once(AnalyticsEvent.GT_ENCODER_INTEGRATION_TEST, {
|
|
1439
|
+
'encoder_name': name
|
|
1440
|
+
})
|
|
1441
|
+
except Exception as e:
|
|
1442
|
+
logger.debug(f"Failed to emit gt_encoder integration test event: {e}")
|
|
1443
|
+
if not _call_from_tl_platform:
|
|
1444
|
+
update_env_params_func("tensorleap_gt_encoder", "v")
|
|
466
1445
|
return result
|
|
467
1446
|
|
|
468
1447
|
inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
469
1448
|
|
|
470
|
-
|
|
1449
|
+
def mapping_inner(*args, **kwargs):
|
|
1450
|
+
class TempMapping:
|
|
1451
|
+
pass
|
|
471
1452
|
|
|
1453
|
+
ret = TempMapping()
|
|
1454
|
+
ret.node_mapping = mapping_inner.node_mapping
|
|
472
1455
|
|
|
1456
|
+
return ret
|
|
1457
|
+
|
|
1458
|
+
mapping_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
1459
|
+
|
|
1460
|
+
def final_inner(*args, **kwargs):
|
|
1461
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1462
|
+
return mapping_inner(*args, **kwargs)
|
|
1463
|
+
else:
|
|
1464
|
+
return inner(*args, **kwargs)
|
|
1465
|
+
|
|
1466
|
+
final_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
|
|
1467
|
+
|
|
1468
|
+
return final_inner
|
|
473
1469
|
|
|
474
1470
|
return decorating_function
|
|
475
1471
|
|
|
476
1472
|
|
|
477
1473
|
def tensorleap_custom_loss(name: str, connects_to=None):
|
|
1474
|
+
name_to_unique_name = defaultdict(set)
|
|
1475
|
+
|
|
478
1476
|
def decorating_function(user_function: CustomCallableInterface):
|
|
479
1477
|
for loss_handler in leap_binder.setup_container.custom_loss_handlers:
|
|
480
1478
|
if loss_handler.custom_loss_handler_data.name == name:
|
|
481
1479
|
raise Exception(f'Custom loss with name {name} already exists. '
|
|
482
1480
|
f'Please choose another')
|
|
483
1481
|
|
|
484
|
-
leap_binder.add_custom_loss(user_function, name)
|
|
485
|
-
|
|
486
|
-
if connects_to is not None:
|
|
487
|
-
arg_names = leap_binder.setup_container.custom_loss_handlers[-1].custom_loss_handler_data.arg_names
|
|
488
|
-
_add_mapping_connections(connects_to, arg_names, NodeMappingType.CustomLoss, name)
|
|
489
|
-
|
|
490
|
-
|
|
491
1482
|
valid_types = (np.ndarray, SamplePreprocessResponse)
|
|
492
|
-
try:
|
|
493
|
-
import tensorflow as tf
|
|
494
|
-
valid_types = (np.ndarray, SamplePreprocessResponse, tf.Tensor)
|
|
495
|
-
except ImportError:
|
|
496
|
-
pass
|
|
497
1483
|
|
|
498
1484
|
def _validate_input_args(*args, **kwargs):
|
|
499
|
-
|
|
1485
|
+
assert len(args) + len(kwargs) > 0, (
|
|
1486
|
+
f"{user_function.__name__}() validation failed: "
|
|
1487
|
+
f"Expected at least one positional|key-word argument of the allowed types (np.ndarray|SamplePreprocessResponse|). "
|
|
1488
|
+
f"but received none. "
|
|
1489
|
+
f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
|
|
1490
|
+
)
|
|
500
1491
|
for i, arg in enumerate(args):
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
504
|
-
f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
|
|
505
|
-
else:
|
|
506
|
-
assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
507
|
-
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
1492
|
+
assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
|
|
1493
|
+
f'Argument #{i} should be a numpy array. Got {type(arg)}.')
|
|
508
1494
|
for _arg_name, arg in kwargs.items():
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
512
|
-
f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
|
|
513
|
-
else:
|
|
514
|
-
assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
|
|
515
|
-
f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
1495
|
+
assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
|
|
1496
|
+
f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
|
|
516
1497
|
|
|
517
1498
|
def _validate_result(result):
|
|
518
|
-
|
|
519
|
-
|
|
1499
|
+
validate_output_structure(result, func_name=user_function.__name__,
|
|
1500
|
+
expected_type_name="np.ndarray")
|
|
1501
|
+
assert isinstance(result, np.ndarray), \
|
|
1502
|
+
(f'{user_function.__name__} validation failed: '
|
|
520
1503
|
f'The return type should be a numpy array. Got {type(result)}.')
|
|
1504
|
+
assert result.ndim < 2, (f'{user_function.__name__} validation failed: '
|
|
1505
|
+
f'The return type should be a 1Dim numpy array but got {result.ndim}Dim.')
|
|
1506
|
+
|
|
1507
|
+
@functools.wraps(user_function)
|
|
1508
|
+
def inner_without_validate(*args, **kwargs):
|
|
1509
|
+
global _called_from_inside_tl_decorator
|
|
1510
|
+
_called_from_inside_tl_decorator += 1
|
|
1511
|
+
|
|
1512
|
+
try:
|
|
1513
|
+
result = user_function(*args, **kwargs)
|
|
1514
|
+
finally:
|
|
1515
|
+
_called_from_inside_tl_decorator -= 1
|
|
1516
|
+
|
|
1517
|
+
return result
|
|
1518
|
+
|
|
1519
|
+
try:
|
|
1520
|
+
inner_without_validate.__signature__ = inspect.signature(user_function)
|
|
1521
|
+
except (TypeError, ValueError):
|
|
1522
|
+
pass
|
|
1523
|
+
|
|
1524
|
+
leap_binder.add_custom_loss(inner_without_validate, name)
|
|
1525
|
+
|
|
1526
|
+
if connects_to is not None:
|
|
1527
|
+
arg_names = leap_binder.setup_container.custom_loss_handlers[-1].custom_loss_handler_data.arg_names
|
|
1528
|
+
_add_mapping_connections(connects_to, arg_names, NodeMappingType.CustomLoss, name)
|
|
521
1529
|
|
|
522
1530
|
def inner(*args, **kwargs):
|
|
1531
|
+
if not _call_from_tl_platform:
|
|
1532
|
+
set_current("tensorleap_custom_loss")
|
|
523
1533
|
_validate_input_args(*args, **kwargs)
|
|
524
|
-
|
|
1534
|
+
|
|
1535
|
+
result = inner_without_validate(*args, **kwargs)
|
|
1536
|
+
|
|
525
1537
|
_validate_result(result)
|
|
1538
|
+
if not _call_from_tl_platform:
|
|
1539
|
+
update_env_params_func("tensorleap_custom_loss", "v")
|
|
1540
|
+
|
|
526
1541
|
return result
|
|
527
1542
|
|
|
528
|
-
|
|
1543
|
+
def mapping_inner(*args, **kwargs):
|
|
1544
|
+
user_unique_name = mapping_inner.name
|
|
1545
|
+
if 'user_unique_name' in kwargs:
|
|
1546
|
+
user_unique_name = kwargs['user_unique_name']
|
|
1547
|
+
|
|
1548
|
+
if user_unique_name in name_to_unique_name[mapping_inner.name]:
|
|
1549
|
+
user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
|
|
1550
|
+
name_to_unique_name[mapping_inner.name].add(user_unique_name)
|
|
1551
|
+
|
|
1552
|
+
ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
|
|
1553
|
+
ordered_connections = list(args) + ordered_connections
|
|
1554
|
+
_add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
|
|
1555
|
+
mapping_inner.name, NodeMappingType.CustomLoss)
|
|
1556
|
+
|
|
1557
|
+
return None
|
|
1558
|
+
|
|
1559
|
+
mapping_inner.arg_names = leap_binder.setup_container.custom_loss_handlers[
|
|
1560
|
+
-1].custom_loss_handler_data.arg_names
|
|
1561
|
+
mapping_inner.name = name
|
|
1562
|
+
|
|
1563
|
+
def final_inner(*args, **kwargs):
|
|
1564
|
+
if os.environ.get(mapping_runtime_mode_env_var_mame):
|
|
1565
|
+
return mapping_inner(*args, **kwargs)
|
|
1566
|
+
else:
|
|
1567
|
+
return inner(*args, **kwargs)
|
|
1568
|
+
|
|
1569
|
+
final_inner.arg_names = leap_binder.setup_container.custom_loss_handlers[-1].custom_loss_handler_data.arg_names
|
|
1570
|
+
final_inner.name = name
|
|
1571
|
+
|
|
1572
|
+
return final_inner
|
|
529
1573
|
|
|
530
1574
|
return decorating_function
|
|
531
1575
|
|
|
@@ -550,3 +1594,180 @@ def tensorleap_custom_layer(name: str):
|
|
|
550
1594
|
return custom_layer
|
|
551
1595
|
|
|
552
1596
|
return decorating_function
|
|
1597
|
+
|
|
1598
|
+
|
|
1599
|
+
def tensorleap_status_table():
|
|
1600
|
+
import atexit
|
|
1601
|
+
import sys
|
|
1602
|
+
import traceback
|
|
1603
|
+
from typing import Any
|
|
1604
|
+
|
|
1605
|
+
CHECK = "✅"
|
|
1606
|
+
CROSS = "❌"
|
|
1607
|
+
UNKNOWN = "❔"
|
|
1608
|
+
|
|
1609
|
+
code_mapping_failure = [0]
|
|
1610
|
+
|
|
1611
|
+
table = [
|
|
1612
|
+
{"name": "tensorleap_preprocess", "Added to integration": UNKNOWN},
|
|
1613
|
+
{"name": "tensorleap_integration_test", "Added to integration": UNKNOWN},
|
|
1614
|
+
{"name": "tensorleap_input_encoder", "Added to integration": UNKNOWN},
|
|
1615
|
+
{"name": "tensorleap_gt_encoder", "Added to integration": UNKNOWN},
|
|
1616
|
+
{"name": "tensorleap_load_model", "Added to integration": UNKNOWN},
|
|
1617
|
+
{"name": "tensorleap_custom_loss", "Added to integration": UNKNOWN},
|
|
1618
|
+
{"name": "tensorleap_custom_metric (optional)", "Added to integration": UNKNOWN},
|
|
1619
|
+
{"name": "tensorleap_metadata (optional)", "Added to integration": UNKNOWN},
|
|
1620
|
+
{"name": "tensorleap_custom_visualizer (optional)", "Added to integration": UNKNOWN},
|
|
1621
|
+
]
|
|
1622
|
+
|
|
1623
|
+
_finalizer_called = {"done": False}
|
|
1624
|
+
_crashed = {"value": False}
|
|
1625
|
+
_current_func = {"name": None}
|
|
1626
|
+
|
|
1627
|
+
def _link(url: str) -> str:
|
|
1628
|
+
return f"\033{url}\033"
|
|
1629
|
+
|
|
1630
|
+
def _remove_suffix(s: str, suffix: str) -> str:
|
|
1631
|
+
if suffix and s.endswith(suffix):
|
|
1632
|
+
return s[:-len(suffix)]
|
|
1633
|
+
return s
|
|
1634
|
+
|
|
1635
|
+
def _find_row(name: str):
|
|
1636
|
+
for row in table:
|
|
1637
|
+
if _remove_suffix(row["name"], " (optional)") == name:
|
|
1638
|
+
return row
|
|
1639
|
+
return None
|
|
1640
|
+
|
|
1641
|
+
def _set_status(name: str, status_symbol: str):
|
|
1642
|
+
row = _find_row(name)
|
|
1643
|
+
if not row:
|
|
1644
|
+
return
|
|
1645
|
+
|
|
1646
|
+
cur = row["Added to integration"]
|
|
1647
|
+
if status_symbol == UNKNOWN:
|
|
1648
|
+
return
|
|
1649
|
+
if cur == CHECK and status_symbol != CHECK:
|
|
1650
|
+
return
|
|
1651
|
+
|
|
1652
|
+
row["Added to integration"] = status_symbol
|
|
1653
|
+
|
|
1654
|
+
def _mark_unknowns_as_cross():
|
|
1655
|
+
for row in table:
|
|
1656
|
+
if row["Added to integration"] == UNKNOWN:
|
|
1657
|
+
row["Added to integration"] = CROSS
|
|
1658
|
+
|
|
1659
|
+
def _format_default_value(v: Any) -> str:
|
|
1660
|
+
if hasattr(v, "name"):
|
|
1661
|
+
return str(v.name)
|
|
1662
|
+
if isinstance(v, str):
|
|
1663
|
+
return v
|
|
1664
|
+
s = repr(v)
|
|
1665
|
+
return s if len(s) <= 120 else s[:120] + "..."
|
|
1666
|
+
|
|
1667
|
+
def _print_param_default_warnings():
|
|
1668
|
+
data = _get_param_default_warnings()
|
|
1669
|
+
if not data:
|
|
1670
|
+
return
|
|
1671
|
+
|
|
1672
|
+
print("\nWarnings (Default use. It is recommended to set values explicitly):")
|
|
1673
|
+
for param_name in sorted(data.keys()):
|
|
1674
|
+
default_value = data[param_name]["default_value"]
|
|
1675
|
+
funcs = ", ".join(sorted(data[param_name]["funcs"]))
|
|
1676
|
+
dv = _format_default_value(default_value)
|
|
1677
|
+
|
|
1678
|
+
docs_link = data[param_name].get("link_to_docs")
|
|
1679
|
+
docs_part = f" {_link(docs_link)}" if docs_link else ""
|
|
1680
|
+
print(
|
|
1681
|
+
f" ⚠️ Parameter '{param_name}' defaults to {dv} in the following functions: [{funcs}]. "
|
|
1682
|
+
f"For more information, check {docs_part}")
|
|
1683
|
+
print("\nIf this isn’t the intended behaviour, set them explicitly.")
|
|
1684
|
+
|
|
1685
|
+
def _print_table():
|
|
1686
|
+
_print_param_default_warnings()
|
|
1687
|
+
|
|
1688
|
+
if not started_from("leap_integration.py"):
|
|
1689
|
+
return
|
|
1690
|
+
|
|
1691
|
+
ready_mess = "\nAll parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system."
|
|
1692
|
+
not_ready_mess = "\nSome mandatory components have not yet been added to the Integration test. Recommended next interface to add is: "
|
|
1693
|
+
mandatory_ready_mess = "\nAll mandatory parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system or continue to the next optional reccomeded interface,adding: "
|
|
1694
|
+
code_mapping_failure_mes = "Tensorleap_integration_test code flow failed, check raised exception."
|
|
1695
|
+
|
|
1696
|
+
name_width = max(len(row["name"]) for row in table)
|
|
1697
|
+
status_width = max(len(row["Added to integration"]) for row in table)
|
|
1698
|
+
|
|
1699
|
+
header = f"{'Decorator Name'.ljust(name_width)} | {'Added to integration'.ljust(status_width)}"
|
|
1700
|
+
sep = "-" * len(header)
|
|
1701
|
+
|
|
1702
|
+
print("\n" + header)
|
|
1703
|
+
print(sep)
|
|
1704
|
+
|
|
1705
|
+
ready = True
|
|
1706
|
+
next_step = None
|
|
1707
|
+
|
|
1708
|
+
for row in table:
|
|
1709
|
+
print(f"{row['name'].ljust(name_width)} | {row['Added to integration'].ljust(status_width)}")
|
|
1710
|
+
if not _crashed["value"] and ready:
|
|
1711
|
+
if row["Added to integration"] != CHECK:
|
|
1712
|
+
ready = False
|
|
1713
|
+
next_step = row["name"]
|
|
1714
|
+
|
|
1715
|
+
if _crashed["value"]:
|
|
1716
|
+
print(f"\nScript crashed before completing all steps. crashed at function '{_current_func['name']}'.")
|
|
1717
|
+
return
|
|
1718
|
+
|
|
1719
|
+
if code_mapping_failure[0]:
|
|
1720
|
+
print(f"\n{CROSS + code_mapping_failure_mes}.")
|
|
1721
|
+
return
|
|
1722
|
+
|
|
1723
|
+
print(ready_mess) if ready else print(
|
|
1724
|
+
mandatory_ready_mess + next_step
|
|
1725
|
+
) if (next_step and "optional" in next_step) else print(not_ready_mess + (next_step or ""))
|
|
1726
|
+
|
|
1727
|
+
def set_current(name: str):
|
|
1728
|
+
_current_func["name"] = name
|
|
1729
|
+
|
|
1730
|
+
def update_env_params(name: str, status: str = "v"):
|
|
1731
|
+
if name == "code_mapping":
|
|
1732
|
+
code_mapping_failure[0] = 1
|
|
1733
|
+
if status == "v":
|
|
1734
|
+
_set_status(name, CHECK)
|
|
1735
|
+
else:
|
|
1736
|
+
_set_status(name, CROSS)
|
|
1737
|
+
|
|
1738
|
+
def run_on_exit():
|
|
1739
|
+
if _finalizer_called["done"]:
|
|
1740
|
+
return
|
|
1741
|
+
_finalizer_called["done"] = True
|
|
1742
|
+
if not _crashed["value"]:
|
|
1743
|
+
_mark_unknowns_as_cross()
|
|
1744
|
+
|
|
1745
|
+
_print_table()
|
|
1746
|
+
|
|
1747
|
+
def handle_exception(exc_type, exc_value, exc_traceback):
|
|
1748
|
+
_crashed["value"] = True
|
|
1749
|
+
crashed_name = _current_func["name"]
|
|
1750
|
+
if crashed_name:
|
|
1751
|
+
row = _find_row(crashed_name)
|
|
1752
|
+
if row and row["Added to integration"] != CHECK:
|
|
1753
|
+
row["Added to integration"] = CROSS
|
|
1754
|
+
|
|
1755
|
+
traceback.print_exception(exc_type, exc_value, exc_traceback)
|
|
1756
|
+
run_on_exit()
|
|
1757
|
+
|
|
1758
|
+
atexit.register(run_on_exit)
|
|
1759
|
+
sys.excepthook = handle_exception
|
|
1760
|
+
|
|
1761
|
+
return set_current, update_env_params
|
|
1762
|
+
|
|
1763
|
+
|
|
1764
|
+
if not _call_from_tl_platform:
|
|
1765
|
+
set_current, update_env_params_func = tensorleap_status_table()
|
|
1766
|
+
|
|
1767
|
+
|
|
1768
|
+
|
|
1769
|
+
|
|
1770
|
+
|
|
1771
|
+
|
|
1772
|
+
|
|
1773
|
+
|