code-loader 1.0.94.dev11__py3-none-any.whl → 1.0.153.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of code-loader might be problematic. Click here for more details.

@@ -1,114 +1,703 @@
1
1
  # mypy: ignore-errors
2
-
3
- from typing import Optional, Union, Callable, List, Dict
2
+ import os
3
+ import warnings
4
+ import logging
5
+ from collections import defaultdict
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+ from typing import Optional, Union, Callable, List, Dict, Set, Any
9
+ from typing import Optional, Union, Callable, List, Dict, get_args, get_origin, DefaultDict
4
10
 
5
11
  import numpy as np
6
12
  import numpy.typing as npt
7
13
 
14
+ logger = logging.getLogger(__name__)
15
+
8
16
  from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
9
17
  CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
10
18
  VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
11
- ConfusionMatrixElement, SamplePreprocessResponse, InstanceCallableInterface
12
- from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType
19
+ ConfusionMatrixElement, SamplePreprocessResponse, PredictionTypeHandler, InstanceCallableInterface, ElementInstance, \
20
+ InstanceLengthCallableInterface
21
+ from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType, DataStateType
13
22
  from code_loader import leap_binder
14
23
  from code_loader.contract.mapping import NodeMapping, NodeMappingType, NodeConnection
15
24
  from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
16
25
  LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
26
+ from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
27
+ from code_loader.mixpanel_tracker import clear_integration_events, AnalyticsEvent, emit_integration_event_once
28
+
29
+ import inspect
30
+ import functools
31
+ from pathlib import Path
32
+
33
+ _called_from_inside_tl_decorator = 0
34
+ _called_from_inside_tl_integration_test_decorator = False
35
+ _call_from_tl_platform = os.environ.get('IS_TENSORLEAP_PLATFORM') == 'true'
36
+
37
+ # ---- warnings store (module-level) ----
38
+ _UNSET = object()
39
+ _STORED_WARNINGS: List[Dict[str, Any]] = []
40
+ _STORED_WARNING_KEYS: Set[tuple] = set()
41
+ # param_name -> set(user_func_name)
42
+ _PARAM_DEFAULT_FUNCS: DefaultDict[str, Set[str]] = defaultdict(set)
43
+ # param_name -> default_value used (repr-able)
44
+ _PARAM_DEFAULT_VALUE: Dict[str, Any] = {}
45
+ _PARAM_DEFAULT_DOCS: Dict[str, str] = {}
46
+
47
+
48
def get_entry_script_path() -> str:
    """Return the resolved path of the script that started this process.

    ``sys.argv[0]`` is used when it is non-empty; otherwise the ``__file__``
    attribute of the ``__main__`` module is consulted. An empty string is
    returned when neither source provides a path.
    """
    import sys

    leading = sys.argv[:1]
    if leading and leading[0]:
        return str(Path(leading[0]).resolve())

    # Fall back to the module that is running as __main__ (may lack __file__).
    import __main__
    fallback = getattr(__main__, "__file__", "") or ""
    if not fallback:
        return ""
    return str(Path(fallback).resolve())
56
+
57
def started_from(filename: str) -> bool:
    """Tell whether the entry script's file name equals ``filename``.

    Returns ``False`` when the entry script path cannot be determined.
    """
    entry_path = get_entry_script_path()
    if not entry_path:
        return False
    return Path(entry_path).name == filename
60
+
61
+
62
def store_warning_by_param(
        *,
        param_name: str,
        user_func_name: str,
        default_value: Any,
        link_to_docs: Optional[str] = None,
) -> None:
    """Record that ``user_func_name`` relied on the default of ``param_name``.

    Args:
        param_name: name of the parameter whose default was used.
        user_func_name: name of the function that relied on the default;
            added to the set kept for ``param_name``.
        default_value: the default that was applied. Only the first value
            recorded for a given ``param_name`` is kept.
        link_to_docs: optional documentation URL. Only the first non-empty
            link recorded for a given ``param_name`` is kept.
    """
    _PARAM_DEFAULT_FUNCS[param_name].add(user_func_name)

    # First writer wins for both the default value and the docs link.
    _PARAM_DEFAULT_VALUE.setdefault(param_name, default_value)
    if link_to_docs:
        _PARAM_DEFAULT_DOCS.setdefault(param_name, link_to_docs)
76
+
77
+
78
def _get_param_default_warnings() -> Dict[str, Dict[str, Any]]:
    """Return a snapshot of the recorded default-parameter usages.

    The result maps each parameter name to a dict holding the recorded
    ``default_value``, a copy of the set of function names (``funcs``) and the
    recorded ``link_to_docs`` (``None`` when no link was stored).
    """
    return {
        param: {
            "default_value": _PARAM_DEFAULT_VALUE.get(param, None),
            "funcs": set(func_names),
            "link_to_docs": _PARAM_DEFAULT_DOCS.get(param),
        }
        for param, func_names in _PARAM_DEFAULT_FUNCS.items()
    }
87
+
88
+
89
def _get_stored_warnings() -> List[Dict[str, Any]]:
    """Return a shallow copy of the module-level list of stored warnings."""
    snapshot = [*_STORED_WARNINGS]
    return snapshot
91
+
92
+
93
def validate_args_structure(*args, types_order, func_name, expected_names, **kwargs):
    """Validate the count and types of the arguments given to a decorated function.

    Positional arguments are matched to ``expected_names`` by position; names
    not covered positionally are looked up in ``kwargs``. When
    ``expected_names`` is empty only the positional arguments are checked.

    Args:
        *args: positional arguments received by the decorated function.
        types_order: expected type of each argument, in order. An entry may be
            a ``typing.Union`` of several types.
        func_name: function name used in the error messages.
        expected_names: argument names parallel to ``types_order`` (may be empty).
        **kwargs: keyword arguments received by the decorated function.

    Raises:
        AssertionError: when a required argument is missing, when the number
            of arguments differs from ``len(types_order)`` (including surplus
            positional arguments), or when an argument has an unexpected type.
    """

    def _type_to_str(t):
        if get_origin(t) is Union:
            return " | ".join(_type_to_str(member) for member in get_args(t))
        # Not every typing construct carries __name__; fall back to its str().
        return getattr(t, "__name__", None) or str(t)

    def _format_types(types, names=None):
        return ", ".join(
            f"{names[i] if names and i < len(names) else f'arg{i}'}: {_type_to_str(ty)}"
            for i, ty in enumerate(types)
        )

    # Surplus positional arguments must not be truncated to the expected names,
    # otherwise the count check below could never report them.
    if expected_names and len(args) <= len(expected_names):
        normalized_args = []
        for i, name in enumerate(expected_names):
            if i < len(args):
                normalized_args.append(args[i])
            elif name in kwargs:
                normalized_args.append(kwargs[name])
            else:
                raise AssertionError(
                    f"{func_name} validation failed: "
                    f"Missing required argument '{name}'. "
                    f"Expected arguments: {expected_names}."
                )
    else:
        normalized_args = list(args)

    if len(normalized_args) != len(types_order):
        expected = _format_types(types_order, expected_names)
        got_types = ", ".join(type(arg).__name__ for arg in normalized_args)
        raise AssertionError(
            f"{func_name} validation failed: "
            f"Expected exactly {len(types_order)} arguments ({expected}), "
            f"but got {len(normalized_args)} argument(s) of type(s): ({got_types}). "
            f"Correct usage example: {func_name}({expected})"
        )

    for i, (arg, expected_type) in enumerate(zip(normalized_args, types_order)):
        if get_origin(expected_type) is Union:
            allowed_types = get_args(expected_type)
        else:
            allowed_types = (expected_type,)

        if not isinstance(arg, allowed_types):
            allowed_str = " | ".join(_type_to_str(t) for t in allowed_types)
            raise AssertionError(
                f"{func_name} validation failed: "
                f"Argument '{expected_names[i] if expected_names else f'arg{i}'}' "
                f"expected type {allowed_str}, but got {type(arg).__name__}. "
                f"Correct usage example: {func_name}({_format_types(types_order, expected_names)})"
            )
149
+
150
+
151
def validate_output_structure(result, func_name: str, expected_type_name="np.ndarray", gt_flag=False):
    """Reject results that are missing (``None``/NaN) or that are tuples.

    Args:
        result: the value returned by the user function.
        func_name: function name used in the error messages.
        expected_type_name: human-readable name of the expected return type.
        gt_flag: when true, a missing result produces a message that also
            explains how to return an empty array for unlabeled data.

    Raises:
        AssertionError: when ``result`` is ``None``, a float NaN, or a tuple.
    """
    is_missing = result is None or (isinstance(result, float) and np.isnan(result))
    if is_missing:
        if gt_flag:
            raise AssertionError(
                f"{func_name} validation failed: "
                f"The function returned {result!r}. "
                f"If you are working with an unlabeled dataset and no ground truth is available, "
                f"use 'return np.array([], dtype=np.float32)' instead. "
                f"Otherwise, {func_name} expected a single {expected_type_name} object. "
                f"Make sure the function ends with 'return <{expected_type_name}>'."
            )

        # Report the actual value: this branch is also reached for NaN, not only None.
        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned {result!r}. "
            f"Expected a single {expected_type_name} object. "
            f"Make sure the function ends with 'return <{expected_type_name}>'."
        )

    if isinstance(result, tuple):
        element_descriptions = [
            f"[{i}] type: {type(r).__name__}"
            for i, r in enumerate(result)
        ]
        element_summary = "\n ".join(element_descriptions)

        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned multiple outputs ({len(result)} values), "
            f"but only a single {expected_type_name} is allowed.\n\n"
            f"Returned elements:\n"
            f" {element_summary}\n\n"
            f"Correct usage example:\n"
            f" def {func_name}(...):\n"
            f" return <{expected_type_name}>\n\n"
            f"If you intended to return multiple values, combine them into a single "
            f"{expected_type_name} (e.g., by concatenation or stacking)."
        )
188
+
189
+
190
def batch_warning(result, func_name):
    """Warn when ``result`` already has a leading axis of size 1.

    Nothing happens for results without dimensions or whose first axis has a
    size other than 1.
    """
    dims = result.shape
    if not dims or dims[0] != 1:
        return

    warnings.warn(
        f"{func_name} warning: Tensorleap will add a batch dimension at axis 0 to the output of {func_name}, "
        f"although the detected size of axis 0 is already 1. "
        f"This may lead to an extra batch dimension (e.g., shape (1, 1, ...)). "
        f"Please ensure that the output of '{func_name}' is not already batched "
        f"to avoid computation errors."
    )
199
+
200
+
201
def _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type):
    """Append one ``NodeConnection`` to ``leap_binder.mapping_connections``.

    ``SamplePreprocessResponse`` destinations are dropped first; the remaining
    destinations are paired with ``arg_names`` in order and their
    ``node_mapping`` attributes become the inputs of the new connection.
    """
    mappable_destinations = [
        dest for dest in connection_destinations
        if not isinstance(dest, SamplePreprocessResponse)
    ]

    main_node_mapping = NodeMapping(name, node_mapping_type, user_unique_name, arg_names=arg_names)

    node_inputs = {
        arg_name: dest.node_mapping
        for arg_name, dest in zip(arg_names, mappable_destinations)
    }

    leap_binder.mapping_connections.append(NodeConnection(main_node_mapping, node_inputs))
17
212
 
18
213
 
19
214
def _add_mapping_connections(connects_to, arg_names, node_mapping_type, name):
    """Register one mapping connection for every entry of ``connects_to``.

    ``connects_to`` maps a user-unique name to its connection destinations.
    """
    for unique_name, destinations in connects_to.items():
        _add_mapping_connection(unique_name, destinations, arg_names, name, node_mapping_type)
217
+
218
+
219
def tensorleap_integration_test():
    """Decorator factory wrapping the user's integration-test function.

    The decorated function is stored on ``leap_binder.integration_test_func``.
    Calling the returned wrapper:

    1. validates the ``(idx, preprocess)`` arguments,
    2. clears previously emitted integration events (best effort),
    3. runs the user's test with the given arguments,
    4. runs it a second time in mapping mode (the mapping env var is set and
       the test receives ``None`` plus an empty training ``PreprocessResponse``),
    5. calls ``leap_binder.check()``.

    Raises (from the wrapper):
        AssertionError: when the arguments fail validation.
        Exception: when the mapping-mode run fails; the message names the file
            and line of the innermost traceback frame.
    """

    def decorating_function(integration_test_function: Callable):
        leap_binder.integration_test_func = integration_test_function

        def _validate_input_args(*args, **kwargs):
            # The arguments may arrive positionally or by keyword; positional
            # values take precedence, as in validate_args_structure.
            named = {**kwargs, **dict(zip(("idx", "preprocess"), args))}
            sample_id, preprocess_response = named["idx"], named["preprocess"]
            assert type(sample_id) == preprocess_response.sample_id_type, (
                f"tensorleap_integration_test validation failed: "
                f"sample_id type ({type(sample_id).__name__}) does not match the expected "
                f"type ({preprocess_response.sample_id_type}) from the PreprocessResponse."
            )

        def inner(*args, **kwargs):
            if not _call_from_tl_platform:
                set_current('tensorleap_integration_test')
            validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
                                    func_name='integration_test', expected_names=["idx", "preprocess"], **kwargs)
            _validate_input_args(*args, **kwargs)

            global _called_from_inside_tl_integration_test_decorator
            # Clear integration test events for new test
            try:
                clear_integration_events()
            except Exception as e:
                logger.debug(f"Failed to clear integration events: {e}")
            try:
                _called_from_inside_tl_integration_test_decorator = True
                if not _call_from_tl_platform:
                    # Marked here because otherwise it would become "v" only if the whole script finishes.
                    update_env_params_func("tensorleap_integration_test", "v")
                integration_test_function(*args, **kwargs)

                try:
                    os.environ[mapping_runtime_mode_env_var_mame] = 'True'
                    integration_test_function(None, PreprocessResponse(state=DataStateType.training, length=0))
                except Exception as e:
                    import traceback
                    first_tb = traceback.extract_tb(e.__traceback__)[-1]
                    file_name = Path(first_tb.filename).name
                    line_number = first_tb.lineno
                    update_env_params_func("code_mapping", "x")
                    # A bare string cannot be raised (that yields an unrelated
                    # TypeError), so the message is wrapped in an Exception.
                    if isinstance(e, TypeError) and 'is not subscriptable' in str(e):
                        raise Exception(
                            f'Invalid integration code. File {file_name}, line {line_number}: '
                            f"indexing is supported only on the model's predictions inside the integration test. Please remove this indexing operation usage from the integration test code.") from e
                    raise Exception(
                        f'Invalid integration code. File {file_name}, line {line_number}: '
                        f'Integration test is only allowed to call Tensorleap decorators. '
                        f'Ensure any arithmetics, external library use, Python logic is placed within Tensorleap decoders') from e
                finally:
                    if mapping_runtime_mode_env_var_mame in os.environ:
                        del os.environ[mapping_runtime_mode_env_var_mame]
            finally:
                _called_from_inside_tl_integration_test_decorator = False

            # NOTE(review): the flattened source does not show the nesting of this
            # call; it is placed at the end of `inner` — confirm against the file.
            leap_binder.check()

        return inner

    return decorating_function
280
+
281
+
282
def _safe_get_item(key):
    """Return the ``NodeMappingType`` member named ``Input<key>``.

    Raises:
        Exception: when no such member exists.
    """
    try:
        return NodeMappingType[f'Input{str(key)}']
    except (KeyError, ValueError):
        # Lookup by member name raises KeyError on an Enum (assumes
        # NodeMappingType is an Enum — confirm); ValueError is kept for
        # compatibility with the previous handler.
        raise Exception(f'Tensorleap currently supports models with no more then 10 inputs')
287
+
288
+
289
def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = _UNSET):
    """Decorator factory for the user's model-loading function.

    At decoration-factory time every entry of ``prediction_types`` is validated
    and registered through ``leap_binder.add_prediction``; a ``channel_dim``
    left at ``"tl_default_value"`` is replaced by ``-1``. Defaults that were
    applied are recorded with ``store_warning_by_param`` when not running on
    the Tensorleap platform.

    The decorated function returns, depending on the mapping env var:

    * normal mode — a cached placeholder wrapping the loaded model. It exposes
      a keras-style ``__call__`` and an onnxruntime-style ``run``/``get_inputs``
      and checks the number of model outputs against the declared prediction
      types on every inference;
    * mapping mode — a lightweight placeholder that only records node-mapping
      types on the objects passed to it.

    Args:
        prediction_types: list of ``PredictionTypeHandler``; an empty list is
            used when omitted.

    Raises:
        AssertionError: when ``prediction_types`` or one of its items has the
            wrong type, or when the loaded model is not a supported type.
    """
    prediction_types_was_provided = prediction_types is not _UNSET

    if not prediction_types_was_provided:
        prediction_types = []
        if not _call_from_tl_platform:
            store_warning_by_param(
                param_name="prediction_types",
                user_func_name="tensorleap_load_model",
                default_value=prediction_types,
                link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
            )
    assert isinstance(prediction_types, list), (
        f"tensorleap_load_model validation failed: "
        f" prediction_types is an optional argument of type List[PredictionTypeHandler]] but got {type(prediction_types).__name__}."
    )
    for i, prediction_type in enumerate(prediction_types):
        assert isinstance(prediction_type, PredictionTypeHandler), (f"tensorleap_load_model validation failed: "
                                                                    f" prediction_types at position {i} must be of type PredictionTypeHandler but got {type(prediction_types[i]).__name__}.")
        prediction_type_channel_dim_was_provided = prediction_type.channel_dim != "tl_default_value"
        if not prediction_type_channel_dim_was_provided:
            prediction_types[i].channel_dim = -1
            if not _call_from_tl_platform:
                store_warning_by_param(
                    param_name=f"prediction_types[{i}].channel_dim",
                    user_func_name="tensorleap_load_model",
                    default_value=prediction_types[i].channel_dim,
                    link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
                )
        leap_binder.add_prediction(prediction_type.name, prediction_type.labels, prediction_type.channel_dim, i)

    def _validate_result(result) -> None:
        """Assert that ``result`` is a Keras model or an onnxruntime session."""
        valid_types = ["onnxruntime", "keras"]
        err_message = f"tensorleap_load_model validation failed:\nSupported models are Keras and onnxruntime only and non of them was returned."
        validate_output_structure(result, func_name="tensorleap_load_model",
                                  expected_type_name=" | ".join(valid_types))
        # Each framework is optional; a missing one simply cannot match.
        try:
            import keras
        except ImportError:
            keras = None
        try:
            import tensorflow as tf
        except ImportError:
            tf = None
        try:
            import onnxruntime
        except ImportError:
            onnxruntime = None

        if not keras and not onnxruntime:
            raise AssertionError(err_message)

        is_keras_model = (
                bool(keras and isinstance(result, getattr(keras, "Model", tuple())))
                or bool(tf and isinstance(result, getattr(tf.keras, "Model", tuple())))
        )
        is_onnx_model = bool(onnxruntime and isinstance(result, onnxruntime.InferenceSession))

        if not any([is_keras_model, is_onnx_model]):
            raise AssertionError(err_message)

    def decorating_function(load_model_func, prediction_types=prediction_types):
        class TempMapping:
            pass

        @lru_cache()
        def inner(*args, **kwargs):
            if not _call_from_tl_platform:
                set_current('tensorleap_load_model')
            validate_args_structure(*args, types_order=[],
                                    func_name='tensorleap_load_model', expected_names=[], **kwargs)

            class ModelPlaceholder:
                def __init__(self, prediction_types):
                    self.model = load_model_func()  # TODO- check why this fails on onnx model
                    self.prediction_types = prediction_types
                    _validate_result(self.model)

                # keras interface
                def __call__(self, arg):
                    ret = self.model(arg)
                    self.validate_declared_prediction_types(ret)
                    if isinstance(ret, list):
                        return [r.numpy() for r in ret]
                    return ret.numpy()

                def validate_declared_prediction_types(self, ret):
                    """Raise when declared prediction types do not match the model outputs."""
                    # A non-list result counts as a single output. The counts are
                    # computed separately so the comparison is not swallowed by the
                    # conditional expression's precedence.
                    num_outputs = len(ret) if isinstance(ret, list) else 1
                    num_declared = len(self.prediction_types)
                    # No declared types means "nothing to check".
                    if num_declared != 0 and num_declared != num_outputs:
                        if not _call_from_tl_platform:
                            update_env_params_func("tensorleap_load_model", "x")
                        raise Exception(
                            f"tensorleap_load_model validation failed: number of declared prediction types({num_declared}) != number of model outputs({num_outputs})")

                def _convert_onnx_inputs_to_correct_type(
                        self, float_arrays_inputs: Dict[str, np.ndarray]
                ) -> Dict[str, np.ndarray]:
                    """
                    Cast user-provided NumPy inputs to match the dtypes
                    expected by an ONNX Runtime InferenceSession.
                    """
                    ONNX_TYPE_TO_NP = {
                        "tensor(float)": np.float32,
                        "tensor(double)": np.float64,
                        "tensor(int64)": np.int64,
                        "tensor(int32)": np.int32,
                        "tensor(int16)": np.int16,
                        "tensor(int8)": np.int8,
                        "tensor(uint64)": np.uint64,
                        "tensor(uint32)": np.uint32,
                        "tensor(uint16)": np.uint16,
                        "tensor(uint8)": np.uint8,
                        "tensor(bool)": np.bool_,
                    }

                    coerced = {}
                    meta = {i.name: i for i in self.model.get_inputs()}

                    for name, arr in float_arrays_inputs.items():
                        if name not in meta:
                            # Keep as-is unless extra inputs are disallowed
                            coerced[name] = arr
                            continue

                        info = meta[name]
                        onnx_type = info.type
                        want_dtype = ONNX_TYPE_TO_NP.get(onnx_type)

                        if want_dtype is None:
                            raise TypeError(f"Unsupported ONNX input type: {onnx_type}")

                        # Cast dtype if needed
                        if arr.dtype != want_dtype:
                            arr = arr.astype(want_dtype, copy=False)

                        coerced[name] = arr

                    # Verify required inputs are present
                    missing = [n for n in meta if n not in coerced]
                    if missing:
                        raise KeyError(f"Missing required input(s): {sorted(missing)}")

                    return coerced

                # onnx runtime interface
                def run(self, output_names, input_dict):
                    corrected_type_inputs = self._convert_onnx_inputs_to_correct_type(input_dict)
                    ret = self.model.run(output_names, corrected_type_inputs)
                    self.validate_declared_prediction_types(ret)
                    return ret

                def get_inputs(self):
                    return self.model.get_inputs()

            model_placeholder = ModelPlaceholder(prediction_types)
            if not _call_from_tl_platform:
                update_env_params_func("tensorleap_load_model", "v")
            return model_placeholder

        def mapping_inner():
            class ModelOutputPlaceholder:
                def __init__(self):
                    self.node_mapping = NodeMapping('', NodeMappingType.Prediction0)

                def __getitem__(self, key):
                    assert isinstance(key, int), \
                        f'Expected key to be an int, got {type(key)} instead.'

                    ret = TempMapping()
                    try:
                        ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
                    except ValueError:
                        raise Exception(f'Tensorleap currently supports models with no more then 10 active predictions,'
                                        f' {key} not supported.')
                    return ret

            class ModelPlaceholder:

                # keras interface
                def __call__(self, arg):
                    if isinstance(arg, list):
                        for i, elem in enumerate(arg):
                            elem.node_mapping.type = _safe_get_item(i)
                    else:
                        arg.node_mapping.type = NodeMappingType.Input0

                    return ModelOutputPlaceholder()

                # onnx runtime interface
                def run(self, output_names, input_dict):
                    assert output_names is None
                    assert isinstance(input_dict, dict), \
                        f'Expected input_dict to be a dict, got {type(input_dict)} instead.'
                    for i, (input_key, elem) in enumerate(input_dict.items()):
                        if isinstance(input_key, NodeMappingType):
                            elem.node_mapping.type = input_key
                        else:
                            elem.node_mapping.type = _safe_get_item(i)

                    return ModelOutputPlaceholder()

                def get_inputs(self):
                    class FollowIndex:
                        def __init__(self, index):
                            self.name = _safe_get_item(index)

                    class FollowInputIndex:
                        def __init__(self):
                            pass

                        def __getitem__(self, index):
                            assert isinstance(index, int), \
                                f'Expected key to be an int, got {type(index)} instead.'

                            return FollowIndex(index)

                    return FollowInputIndex()

            return ModelPlaceholder()

        def final_inner(*args, **kwargs):
            # The mapping env var selects the placeholder-only code path.
            if os.environ.get(mapping_runtime_mode_env_var_mame):
                return mapping_inner()
            else:
                return inner(*args, **kwargs)

        return final_inner

    return decorating_function
27
518
 
28
519
 
29
520
  def tensorleap_custom_metric(name: str,
30
- direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward,
521
+ direction: Union[MetricDirection, Dict[str, MetricDirection]] = _UNSET,
31
522
  compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
32
523
  connects_to=None):
33
- def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
34
- CustomMultipleReturnCallableInterfaceMultiArgs,
35
- ConfusionMatrixCallableInterfaceMultiArgs]):
524
+ name_to_unique_name = defaultdict(set)
525
+
526
+ def decorating_function(
527
+ user_function: Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs,
528
+ ConfusionMatrixCallableInterfaceMultiArgs]):
529
+ nonlocal direction
530
+
531
+ direction_was_provided = direction is not _UNSET
532
+
533
+ if not direction_was_provided:
534
+ direction = MetricDirection.Downward
535
+ if not _call_from_tl_platform:
536
+ store_warning_by_param(
537
+ param_name="direction",
538
+ user_func_name=user_function.__name__,
539
+ default_value=direction,
540
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/custom-metrics"
541
+
542
+ )
543
+
544
+ def _validate_decorators_signature():
545
+ err_message = f"{user_function.__name__} validation failed.\n"
546
+ if not isinstance(name, str):
547
+ raise TypeError(err_message + f"`name` must be a string, got type {type(name).__name__}.")
548
+ valid_directions = {MetricDirection.Upward, MetricDirection.Downward}
549
+ if isinstance(direction, MetricDirection):
550
+ if direction not in valid_directions:
551
+ raise ValueError(
552
+ err_message +
553
+ f"Invalid MetricDirection: {direction}. Must be one of {valid_directions}, "
554
+ f"got type {type(direction).__name__}."
555
+ )
556
+ elif isinstance(direction, dict):
557
+ if not all(isinstance(k, str) for k in direction.keys()):
558
+ invalid_keys = {k: type(k).__name__ for k in direction.keys() if not isinstance(k, str)}
559
+ raise TypeError(
560
+ err_message +
561
+ f"All keys in `direction` must be strings, got invalid key types: {invalid_keys}."
562
+ )
563
+ for k, v in direction.items():
564
+ if v not in valid_directions:
565
+ raise ValueError(
566
+ err_message +
567
+ f"Invalid direction for key '{k}': {v}. Must be one of {valid_directions}, "
568
+ f"got type {type(v).__name__}."
569
+ )
570
+ else:
571
+ raise TypeError(
572
+ err_message +
573
+ f"`direction` must be a MetricDirection or a Dict[str, MetricDirection], "
574
+ f"got type {type(direction).__name__}."
575
+ )
576
+ if compute_insights is not None:
577
+ if not isinstance(compute_insights, (bool, dict)):
578
+ raise TypeError(
579
+ err_message +
580
+ f"`compute_insights` must be a bool or a Dict[str, bool], "
581
+ f"got type {type(compute_insights).__name__}."
582
+ )
583
+ if isinstance(compute_insights, dict):
584
+ if not all(isinstance(k, str) for k in compute_insights.keys()):
585
+ invalid_keys = {k: type(k).__name__ for k in compute_insights.keys() if not isinstance(k, str)}
586
+ raise TypeError(
587
+ err_message +
588
+ f"All keys in `compute_insights` must be strings, got invalid key types: {invalid_keys}."
589
+ )
590
+ for k, v in compute_insights.items():
591
+ if not isinstance(v, bool):
592
+ raise TypeError(
593
+ err_message +
594
+ f"Invalid type for compute_insights['{k}']: expected bool, got type {type(v).__name__}."
595
+ )
596
+ if connects_to is not None:
597
+ valid_types = (str, list, tuple, set)
598
+ if not isinstance(connects_to, valid_types):
599
+ raise TypeError(
600
+ err_message +
601
+ f"`connects_to` must be one of {valid_types}, got type {type(connects_to).__name__}."
602
+ )
603
+ if isinstance(connects_to, (list, tuple, set)):
604
+ invalid_elems = [f"{type(e).__name__}" for e in connects_to if not isinstance(e, str)]
605
+ if invalid_elems:
606
+ raise TypeError(
607
+ err_message +
608
+ f"All elements in `connects_to` must be strings, "
609
+ f"but found element types: {invalid_elems}."
610
+ )
611
+
612
+ _validate_decorators_signature()
613
+
36
614
  for metric_handler in leap_binder.setup_container.metrics:
37
615
  if metric_handler.metric_handler_data.name == name:
38
616
  raise Exception(f'Metric with name {name} already exists. '
39
617
  f'Please choose another')
40
618
 
41
- leap_binder.add_custom_metric(user_function, name, direction, compute_insights)
42
-
43
- if connects_to is not None:
44
- arg_names = leap_binder.setup_container.metrics[-1].metric_handler_data.arg_names
45
- _add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
46
-
47
619
  def _validate_input_args(*args, **kwargs) -> None:
620
+ assert len(args) + len(kwargs) > 0, (
621
+ f"{user_function.__name__}() validation failed: "
622
+ f"Expected at least one positional|key-word argument of type np.ndarray, "
623
+ f"but received none. "
624
+ f"Correct usage example: tensorleap_custom_metric(input_array: np.ndarray, ...)"
625
+ )
48
626
  for i, arg in enumerate(args):
49
627
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
50
- f'tensorleap_custom_metric validation failed: '
628
+ f'{user_function.__name__}() validation failed: '
51
629
  f'Argument #{i} should be a numpy array. Got {type(arg)}.')
52
630
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
53
631
  assert arg.shape[0] == leap_binder.batch_size_to_validate, \
54
- (f'tensorleap_custom_metric validation failed: Argument #{i} '
632
+ (f'{user_function.__name__}() validation failed: Argument #{i} '
55
633
  f'first dim should be as the batch size. Got {arg.shape[0]} '
56
634
  f'instead of {leap_binder.batch_size_to_validate}')
57
635
 
58
636
  for _arg_name, arg in kwargs.items():
59
637
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
60
- f'tensorleap_custom_metric validation failed: '
638
+ f'{user_function.__name__}() validation failed: '
61
639
  f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
62
640
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
63
641
  assert arg.shape[0] == leap_binder.batch_size_to_validate, \
64
- (f'tensorleap_custom_metric validation failed: Argument {_arg_name} '
642
+ (f'{user_function.__name__}() validation failed: Argument {_arg_name} '
65
643
  f'first dim should be as the batch size. Got {arg.shape[0]} '
66
644
  f'instead of {leap_binder.batch_size_to_validate}')
67
645
 
68
646
  def _validate_result(result) -> None:
69
- supported_types_message = (f'tensorleap_custom_metric validation failed: '
70
- f'Metric has returned unsupported type. Supported types are List[float], '
71
- f'List[List[ConfusionMatrixElement]], NDArray[np.float32]. ')
647
+ validate_output_structure(result, func_name=user_function.__name__,
648
+ expected_type_name="List[float | int | None | List[ConfusionMatrixElement] ] | NDArray[np.float32] or dictonary with one of these types as its values types")
649
+ supported_types_message = (f'{user_function.__name__}() validation failed: '
650
+ f'{user_function.__name__}() has returned unsupported type.\nSupported types are List[float|int|None], '
651
+ f'List[List[ConfusionMatrixElement]], NDArray[np.float32] or dictonary with one of these types as its values types. ')
72
652
 
73
- def _validate_single_metric(single_metric_result):
653
+ def _validate_single_metric(single_metric_result, key=None):
74
654
  if isinstance(single_metric_result, list):
75
655
  if isinstance(single_metric_result[0], list):
76
- assert isinstance(single_metric_result[0][0], ConfusionMatrixElement), \
77
- f'{supported_types_message}Got List[List[{type(single_metric_result[0][0])}]].'
656
+ assert all(isinstance(cm, ConfusionMatrixElement) for cm in single_metric_result[0]), (
657
+ f"{supported_types_message} "
658
+ f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
659
+ f"List[List[{', '.join(type(cm).__name__ for cm in single_metric_result[0])}]]."
660
+ )
661
+
78
662
  else:
79
- assert isinstance(single_metric_result[0], (
80
- float, int, type(None))), f'{supported_types_message}Got List[{type(single_metric_result[0])}].'
663
+ assert all(isinstance(v, (float, int, type(None), np.float32)) for v in single_metric_result), (
664
+ f"{supported_types_message}\n"
665
+ f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
666
+ f"List[{', '.join(type(v).__name__ for v in single_metric_result)}]."
667
+ )
81
668
  else:
82
669
  assert isinstance(single_metric_result,
83
- np.ndarray), f'{supported_types_message}Got {type(single_metric_result)}.'
84
- assert len(single_metric_result.shape) == 1, (f'tensorleap_custom_metric validation failed: '
670
+ np.ndarray), f'{supported_types_message}\nGot {type(single_metric_result)}.'
671
+ assert len(single_metric_result.shape) == 1, (f'{user_function.__name__}() validation failed: '
85
672
  f'The return shape should be 1D. Got {len(single_metric_result.shape)}D.')
86
673
 
87
674
  if leap_binder.batch_size_to_validate:
88
675
  assert len(single_metric_result) == leap_binder.batch_size_to_validate, \
89
- f'tensorleap_custom_metrix validation failed: The return len should be as the batch size.'
676
+ f'{user_function.__name__}() validation failed: The return len {f"of srt{key} value" if key is not None else ""} should be as the batch size.'
90
677
 
91
678
  if isinstance(result, dict):
92
679
  for key, value in result.items():
680
+ _validate_single_metric(value, key)
681
+
93
682
  assert isinstance(key, str), \
94
- (f'tensorleap_custom_metric validation failed: '
683
+ (f'{user_function.__name__}() validation failed: '
95
684
  f'Keys in the return dict should be of type str. Got {type(key)}.')
96
685
  _validate_single_metric(value)
97
686
 
98
687
  if isinstance(direction, dict):
99
688
  for direction_key in direction:
100
689
  assert direction_key in result, \
101
- (f'tensorleap_custom_metric validation failed: '
690
+ (f'{user_function.__name__}() validation failed: '
102
691
  f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
103
692
 
104
693
  if compute_insights is not None:
105
694
  assert isinstance(compute_insights, dict), \
106
- (f'tensorleap_custom_metric validation failed: '
695
+ (f'{user_function.__name__}() validation failed: '
107
696
  f'compute_insights should be dict if using the dict results. Got {type(compute_insights)}.')
108
697
 
109
698
  for ci_key in compute_insights:
110
699
  assert ci_key in result, \
111
- (f'tensorleap_custom_metric validation failed: '
700
+ (f'{user_function.__name__}() validation failed: '
112
701
  f'Keys in the compute_insights mapping should be part of result keys. Got key {ci_key}.')
113
702
 
114
703
  else:
@@ -116,17 +705,71 @@ def tensorleap_custom_metric(name: str,
116
705
 
117
706
  if compute_insights is not None:
118
707
  assert isinstance(compute_insights, bool), \
119
- (f'tensorleap_custom_metric validation failed: '
708
+ (f'{user_function.__name__}() validation failed: '
120
709
  f'compute_insights should be boolean. Got {type(compute_insights)}.')
121
710
 
711
+ @functools.wraps(user_function)
712
+ def inner_without_validate(*args, **kwargs):
713
+ global _called_from_inside_tl_decorator
714
+ _called_from_inside_tl_decorator += 1
715
+
716
+ try:
717
+ result = user_function(*args, **kwargs)
718
+ finally:
719
+ _called_from_inside_tl_decorator -= 1
720
+
721
+ return result
722
+
723
+ try:
724
+ inner_without_validate.__signature__ = inspect.signature(user_function)
725
+ except (TypeError, ValueError):
726
+ pass
727
+
728
+ leap_binder.add_custom_metric(inner_without_validate, name, direction, compute_insights)
729
+
730
+ if connects_to is not None:
731
+ arg_names = leap_binder.setup_container.metrics[-1].metric_handler_data.arg_names
732
+ _add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
122
733
 
123
734
  def inner(*args, **kwargs):
735
+ if not _call_from_tl_platform:
736
+ set_current('tensorleap_custom_metric')
124
737
  _validate_input_args(*args, **kwargs)
125
- result = user_function(*args, **kwargs)
738
+
739
+ result = inner_without_validate(*args, **kwargs)
740
+
126
741
  _validate_result(result)
742
+ if not _call_from_tl_platform:
743
+ update_env_params_func("tensorleap_custom_metric", "v")
127
744
  return result
128
745
 
129
- return inner
746
+ def mapping_inner(*args, **kwargs):
747
+ user_unique_name = mapping_inner.name
748
+ if 'user_unique_name' in kwargs:
749
+ user_unique_name = kwargs['user_unique_name']
750
+
751
+ ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
752
+ ordered_connections = list(args) + ordered_connections
753
+
754
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
755
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
756
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
757
+
758
+ _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
759
+ mapping_inner.name, NodeMappingType.Metric)
760
+
761
+ return None
762
+
763
+ mapping_inner.arg_names = leap_binder.setup_container.metrics[-1].metric_handler_data.arg_names
764
+ mapping_inner.name = name
765
+
766
+ def final_inner(*args, **kwargs):
767
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
768
+ return mapping_inner(*args, **kwargs)
769
+ else:
770
+ return inner(*args, **kwargs)
771
+
772
+ return final_inner
130
773
 
131
774
  return decorating_function
132
775
 
@@ -134,35 +777,40 @@ def tensorleap_custom_metric(name: str,
134
777
  def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
135
778
  heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
136
779
  connects_to=None):
780
+ name_to_unique_name = defaultdict(set)
781
+
137
782
  def decorating_function(user_function: VisualizerCallableInterface):
783
+ assert isinstance(visualizer_type, LeapDataType), (f"{user_function.__name__} validation failed: "
784
+ f"visualizer_type should be of type {LeapDataType.__name__} but got {type(visualizer_type)}"
785
+ )
138
786
  for viz_handler in leap_binder.setup_container.visualizers:
139
787
  if viz_handler.visualizer_handler_data.name == name:
140
788
  raise Exception(f'Visualizer with name {name} already exists. '
141
789
  f'Please choose another')
142
790
 
143
- leap_binder.set_visualizer(user_function, name, visualizer_type, heatmap_function)
144
-
145
- if connects_to is not None:
146
- arg_names = leap_binder.setup_container.visualizers[-1].visualizer_handler_data.arg_names
147
- _add_mapping_connections(connects_to, arg_names, NodeMappingType.Visualizer, name)
148
-
149
791
  def _validate_input_args(*args, **kwargs):
792
+ assert len(args) + len(kwargs) > 0, (
793
+ f"{user_function.__name__}() validation failed: "
794
+ f"Expected at least one positional|key-word argument of type np.ndarray, "
795
+ f"but received none. "
796
+ f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
797
+ )
150
798
  for i, arg in enumerate(args):
151
799
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
152
- f'tensorleap_custom_visualizer validation failed: '
800
+ f'{user_function.__name__}() validation failed: '
153
801
  f'Argument #{i} should be a numpy array. Got {type(arg)}.')
154
802
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
155
803
  assert arg.shape[0] != leap_binder.batch_size_to_validate, \
156
- (f'tensorleap_custom_visualizer validation failed: '
804
+ (f'{user_function.__name__}() validation failed: '
157
805
  f'Argument #{i} should be without batch dimension. ')
158
806
 
159
807
  for _arg_name, arg in kwargs.items():
160
808
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
161
- f'tensorleap_custom_visualizer validation failed: '
809
+ f'{user_function.__name__}() validation failed: '
162
810
  f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
163
811
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
164
812
  assert arg.shape[0] != leap_binder.batch_size_to_validate, \
165
- (f'tensorleap_custom_visualizer validation failed: Argument {_arg_name} '
813
+ (f'{user_function.__name__}() validation failed: Argument {_arg_name} '
166
814
  f'should be without batch dimension. ')
167
815
 
168
816
  def _validate_result(result):
@@ -176,17 +824,74 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
176
824
  LeapDataType.ImageWithBBox: LeapImageWithBBox,
177
825
  LeapDataType.ImageWithHeatmap: LeapImageWithHeatmap
178
826
  }
827
+ validate_output_structure(result, func_name=user_function.__name__,
828
+ expected_type_name=result_type_map[visualizer_type])
829
+
179
830
  assert isinstance(result, result_type_map[visualizer_type]), \
180
- (f'tensorleap_custom_visualizer validation failed: '
831
+ (f'{user_function.__name__}() validation failed: '
181
832
  f'The return type should be {result_type_map[visualizer_type]}. Got {type(result)}.')
182
833
 
834
+ @functools.wraps(user_function)
835
+ def inner_without_validate(*args, **kwargs):
836
+ global _called_from_inside_tl_decorator
837
+ _called_from_inside_tl_decorator += 1
838
+
839
+ try:
840
+ result = user_function(*args, **kwargs)
841
+ finally:
842
+ _called_from_inside_tl_decorator -= 1
843
+
844
+ return result
845
+
846
+ try:
847
+ inner_without_validate.__signature__ = inspect.signature(user_function)
848
+ except (TypeError, ValueError):
849
+ pass
850
+
851
+ leap_binder.set_visualizer(inner_without_validate, name, visualizer_type, heatmap_function)
852
+
853
+ if connects_to is not None:
854
+ arg_names = leap_binder.setup_container.visualizers[-1].visualizer_handler_data.arg_names
855
+ _add_mapping_connections(connects_to, arg_names, NodeMappingType.Visualizer, name)
856
+
183
857
  def inner(*args, **kwargs):
858
+ if not _call_from_tl_platform:
859
+ set_current('tensorleap_custom_visualizer')
184
860
  _validate_input_args(*args, **kwargs)
185
- result = user_function(*args, **kwargs)
861
+
862
+ result = inner_without_validate(*args, **kwargs)
863
+
186
864
  _validate_result(result)
865
+ if not _call_from_tl_platform:
866
+ update_env_params_func("tensorleap_custom_visualizer", "v")
187
867
  return result
188
868
 
189
- return inner
869
+ def mapping_inner(*args, **kwargs):
870
+ user_unique_name = mapping_inner.name
871
+ if 'user_unique_name' in kwargs:
872
+ user_unique_name = kwargs['user_unique_name']
873
+
874
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
875
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
876
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
877
+
878
+ ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
879
+ ordered_connections = list(args) + ordered_connections
880
+ _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
881
+ mapping_inner.name, NodeMappingType.Visualizer)
882
+
883
+ return None
884
+
885
+ mapping_inner.arg_names = leap_binder.setup_container.visualizers[-1].visualizer_handler_data.arg_names
886
+ mapping_inner.name = name
887
+
888
+ def final_inner(*args, **kwargs):
889
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
890
+ return mapping_inner(*args, **kwargs)
891
+ else:
892
+ return inner(*args, **kwargs)
893
+
894
+ return final_inner
190
895
 
191
896
  return decorating_function
192
897
 
@@ -199,38 +904,105 @@ def tensorleap_metadata(
199
904
  raise Exception(f'Metadata with name {name} already exists. '
200
905
  f'Please choose another')
201
906
 
202
- leap_binder.set_metadata(user_function, name, metadata_type)
203
-
204
907
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
205
- assert isinstance(sample_id, (int, str)), \
206
- (f'tensorleap_metadata validation failed: '
207
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
208
- assert isinstance(preprocess_response, PreprocessResponse), \
209
- (f'tensorleap_metadata validation failed: '
210
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
211
908
  assert type(sample_id) == preprocess_response.sample_id_type, \
212
- (f'tensorleap_metadata validation failed: '
909
+ (f'{user_function.__name__}() validation failed: '
213
910
  f'Argument sample_id should be as the same type as defined in the preprocess response '
214
911
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
215
912
 
216
913
  def _validate_result(result):
217
914
  supported_result_types = (type(None), int, str, bool, float, dict, np.floating,
218
915
  np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
916
+ validate_output_structure(result, func_name=user_function.__name__,
917
+ expected_type_name=supported_result_types)
219
918
  assert isinstance(result, supported_result_types), \
220
- (f'tensorleap_metadata validation failed: '
919
+ (f'{user_function.__name__}() validation failed: '
221
920
  f'Unsupported return type. Got {type(result)}. should be any of {str(supported_result_types)}')
222
921
  if isinstance(result, dict):
223
922
  for key, value in result.items():
224
923
  assert isinstance(key, str), \
225
- (f'tensorleap_metadata validation failed: '
924
+ (f'{user_function.__name__}() validation failed: '
226
925
  f'Keys in the return dict should be of type str. Got {type(key)}.')
227
926
  assert isinstance(value, supported_result_types), \
228
- (f'tensorleap_metadata validation failed: '
927
+ (f'{user_function.__name__}() validation failed: '
229
928
  f'Values in the return dict should be of type {str(supported_result_types)}. Got {type(value)}.')
230
929
 
930
+ def inner_without_validate(sample_id, preprocess_response):
931
+
932
+ global _called_from_inside_tl_decorator
933
+ _called_from_inside_tl_decorator += 1
934
+
935
+ try:
936
+ result = user_function(sample_id, preprocess_response)
937
+ finally:
938
+ _called_from_inside_tl_decorator -= 1
939
+
940
+ return result
941
+
942
+ leap_binder.set_metadata(inner_without_validate, name, metadata_type)
943
+
944
+ def inner(*args, **kwargs):
945
+ if not _call_from_tl_platform:
946
+ set_current('tensorleap_metadata')
947
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
948
+ return None
949
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
950
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
951
+ sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
952
+ _validate_input_args(sample_id, preprocess_response)
953
+
954
+ result = inner_without_validate(sample_id, preprocess_response)
955
+
956
+ _validate_result(result)
957
+ if not _call_from_tl_platform:
958
+ update_env_params_func("tensorleap_metadata", "v")
959
+ return result
960
+
961
+ return inner
962
+
963
+ return decorating_function
964
+
965
+
966
+ def tensorleap_custom_latent_space():
967
+ def decorating_function(user_function: SectionCallableInterface):
968
+ def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
969
+ assert isinstance(sample_id, (int, str)), \
970
+ (f'tensorleap_custom_latent_space validation failed: '
971
+ f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
972
+ assert isinstance(preprocess_response, PreprocessResponse), \
973
+ (f'tensorleap_custom_latent_space validation failed: '
974
+ f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
975
+ assert type(sample_id) == preprocess_response.sample_id_type, \
976
+ (f'tensorleap_custom_latent_space validation failed: '
977
+ f'Argument sample_id should be as the same type as defined in the preprocess response '
978
+ f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
979
+
980
+ def _validate_result(result):
981
+ assert isinstance(result, np.ndarray), \
982
+ (f'tensorleap_custom_loss validation failed: '
983
+ f'The return type should be a numpy array. Got {type(result)}.')
984
+
985
+ def inner_without_validate(sample_id, preprocess_response):
986
+ global _called_from_inside_tl_decorator
987
+ _called_from_inside_tl_decorator += 1
988
+
989
+ try:
990
+ result = user_function(sample_id, preprocess_response)
991
+ finally:
992
+ _called_from_inside_tl_decorator -= 1
993
+
994
+ return result
995
+
996
+ leap_binder.set_custom_latent_space(inner_without_validate)
997
+
231
998
  def inner(sample_id, preprocess_response):
999
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1000
+ return None
1001
+
232
1002
  _validate_input_args(sample_id, preprocess_response)
233
- result = user_function(sample_id, preprocess_response)
1003
+
1004
+ result = inner_without_validate(sample_id, preprocess_response)
1005
+
234
1006
  _validate_result(result)
235
1007
  return result
236
1008
 
@@ -244,33 +1016,54 @@ def tensorleap_preprocess():
244
1016
  leap_binder.set_preprocess(user_function)
245
1017
 
246
1018
  def _validate_input_args(*args, **kwargs):
247
- assert len(args) == 0 and len(kwargs) == 0, \
248
- (f'tensorleap_preprocess validation failed: '
1019
+ assert len(args) + len(kwargs) == 0, \
1020
+ (f'{user_function.__name__}() validation failed: '
249
1021
  f'The function should not take any arguments. Got {args} and {kwargs}.')
250
1022
 
251
1023
  def _validate_result(result):
252
- assert isinstance(result, list), \
253
- (f'tensorleap_preprocess validation failed: '
254
- f'The return type should be a list. Got {type(result)}.')
1024
+ assert isinstance(result, list), (
1025
+ f"{user_function.__name__}() validation failed: expected return type list[{PreprocessResponse.__name__}]"
1026
+ f"(e.g., [PreprocessResponse1, PreprocessResponse2, ...]), but returned type is {type(result).__name__}."
1027
+ if not isinstance(result, tuple)
1028
+ else f"{user_function.__name__}() validation failed: expected to return a single list[{PreprocessResponse.__name__}] object, "
1029
+ f"but returned {len(result)} objects instead."
1030
+ )
255
1031
  for i, response in enumerate(result):
256
1032
  assert isinstance(response, PreprocessResponse), \
257
- (f'tensorleap_preprocess validation failed: '
1033
+ (f'{user_function.__name__}() validation failed: '
258
1034
  f'Element #{i} in the return list should be a PreprocessResponse. Got {type(response)}.')
259
1035
  assert len(set(result)) == len(result), \
260
- (f'tensorleap_preprocess validation failed: '
1036
+ (f'{user_function.__name__}() validation failed: '
261
1037
  f'The return list should not contain duplicate PreprocessResponse objects.')
262
1038
 
263
1039
  def inner(*args, **kwargs):
1040
+ if not _call_from_tl_platform:
1041
+ set_current('tensorleap_metadata')
1042
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1043
+ return [None, None, None, None]
1044
+
264
1045
  _validate_input_args(*args, **kwargs)
265
1046
  result = user_function()
266
1047
  _validate_result(result)
1048
+
1049
+ # Emit integration test event once per test
1050
+ try:
1051
+ emit_integration_event_once(AnalyticsEvent.PREPROCESS_INTEGRATION_TEST, {
1052
+ 'preprocess_responses_count': len(result)
1053
+ })
1054
+ except Exception as e:
1055
+ logger.debug(f"Failed to emit preprocess integration test event: {e}")
1056
+ if not _call_from_tl_platform:
1057
+ update_env_params_func("tensorleap_preprocess", "v")
267
1058
  return result
268
1059
 
269
1060
  return inner
270
1061
 
271
1062
  return decorating_function
272
1063
 
273
- def tensorleap_element_instance_preprocess(instance_mask_encoder: Callable[[str, PreprocessResponse], List[PreprocessResponse]]):
1064
+
1065
+ def tensorleap_element_instance_preprocess(
1066
+ instance_length_encoder: InstanceLengthCallableInterface):
274
1067
  def decorating_function(user_function: Callable[[], List[PreprocessResponse]]):
275
1068
  def user_function_instance() -> List[PreprocessResponse]:
276
1069
  result = user_function()
@@ -279,22 +1072,46 @@ def tensorleap_element_instance_preprocess(instance_mask_encoder: Callable[[str,
279
1072
  instance_to_sample_ids_mappings = {}
280
1073
  all_sample_ids = preprocess_response.sample_ids.copy()
281
1074
  for sample_id in preprocess_response.sample_ids:
282
- instances_masks = instance_mask_encoder(sample_id, preprocess_response)
283
- instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(len(instances_masks))]
1075
+ instances_length = instance_length_encoder(sample_id, preprocess_response)
1076
+ instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(instances_length)]
284
1077
  sample_ids_to_instance_mappings[sample_id] = instances_ids
285
1078
  instance_to_sample_ids_mappings[sample_id] = sample_id
286
1079
  for instance_id in instances_ids:
287
1080
  instance_to_sample_ids_mappings[instance_id] = sample_id
288
1081
  all_sample_ids.extend(instances_ids)
1082
+ preprocess_response.length = len(all_sample_ids)
289
1083
  preprocess_response.sample_ids_to_instance_mappings = sample_ids_to_instance_mappings
290
1084
  preprocess_response.instance_to_sample_ids_mappings = instance_to_sample_ids_mappings
291
1085
  preprocess_response.sample_ids = all_sample_ids
292
1086
  return result
293
1087
 
294
- def metadata_is_instance(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
295
- return '0'
1088
+ # def extract_extra_instance_metadata():
1089
+ # result = user_function()
1090
+ # for preprocess_response in result:
1091
+ # for sample_id in preprocess_response.sample_ids:
1092
+ # instances_length = instance_length_encoder(sample_id, preprocess_response)
1093
+ # if instances_length > 0:
1094
+ # element_instance = instance_mask_encoder(sample_id, preprocess_response, 0)
1095
+ # instance_metadata = element_instance.instance_metadata
1096
+ # if instance_metadata is None:
1097
+ # return {}
1098
+ # return instance_metadata
1099
+ # return {}
1100
+
1101
+
1102
+ def builtin_instance_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
1103
+ return {'is_instance': '0', 'original_sample_id': idx, 'instance_name': 'none'}
1104
+
1105
+ # def builtin_instance_extra_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
1106
+ # instance_metadata = extract_extra_instance_metadata()
1107
+ # for key, value in instance_metadata.items():
1108
+ # instance_metadata[key] = 'unset'
1109
+ # return instance_metadata
1110
+
296
1111
  leap_binder.set_preprocess(user_function_instance)
297
- leap_binder.set_metadata(metadata_is_instance, "metadata_is_instance")
1112
+ leap_binder.set_metadata(builtin_instance_metadata, "builtin_instance_metadata")
1113
+ # leap_binder.set_metadata(builtin_instance_extra_metadata, "builtin_instance_extra_metadata")
1114
+
298
1115
 
299
1116
  def _validate_input_args(*args, **kwargs):
300
1117
  assert len(args) == 0 and len(kwargs) == 0, \
@@ -314,9 +1131,15 @@ def tensorleap_element_instance_preprocess(instance_mask_encoder: Callable[[str,
314
1131
  f'The return list should not contain duplicate PreprocessResponse objects.')
315
1132
 
316
1133
  def inner(*args, **kwargs):
1134
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1135
+ return [None, None, None, None]
1136
+
317
1137
  _validate_input_args(*args, **kwargs)
1138
+
318
1139
  result = user_function_instance()
319
1140
  _validate_result(result)
1141
+ if not _call_from_tl_platform:
1142
+ update_env_params_func("tensorleap_preprocess", "v")
320
1143
  return result
321
1144
 
322
1145
  return inner
@@ -351,9 +1174,7 @@ def tensorleap_unlabeled_preprocess():
351
1174
 
352
1175
  def tensorleap_instances_masks_encoder(name: str):
353
1176
  def decorating_function(user_function: InstanceCallableInterface):
354
- leap_binder.set_instance_masks(user_function, name)
355
-
356
- def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
1177
+ def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse, instance_id: int):
357
1178
  assert isinstance(sample_id, str), \
358
1179
  (f'tensorleap_instances_masks_encoder validation failed: '
359
1180
  f'Argument sample_id should be str. Got {type(sample_id)}.')
@@ -364,15 +1185,82 @@ def tensorleap_instances_masks_encoder(name: str):
364
1185
  (f'tensorleap_instances_masks_encoder validation failed: '
365
1186
  f'Argument sample_id should be as the same type as defined in the preprocess response '
366
1187
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
1188
+ assert isinstance(instance_id, int), \
1189
+ (f'tensorleap_instances_masks_encoder validation failed: '
1190
+ f'Argument instance_id should be int. Got {type(instance_id)}.')
367
1191
 
368
1192
  def _validate_result(result):
369
- assert isinstance(result, list), \
1193
+ assert isinstance(result, ElementInstance) or (result is None), \
370
1194
  (f'tensorleap_instances_masks_encoder validation failed: '
371
- f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
1195
+ f'Unsupported return type. Should be a ElementInstance or None. Got {type(result)}.')
1196
+
1197
+ def inner_without_validate(sample_id, preprocess_response, instance_id):
1198
+ global _called_from_inside_tl_decorator
1199
+ _called_from_inside_tl_decorator += 1
1200
+
1201
+ try:
1202
+ result = user_function(sample_id, preprocess_response, instance_id)
1203
+ finally:
1204
+ _called_from_inside_tl_decorator -= 1
1205
+
1206
+ return result
1207
+
1208
+ leap_binder.set_instance_masks(inner_without_validate, name)
1209
+
1210
+ def inner(sample_id, preprocess_response, instance_id):
1211
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1212
+ return None
1213
+
1214
+ _validate_input_args(sample_id, preprocess_response, instance_id)
1215
+
1216
+ result = inner_without_validate(sample_id, preprocess_response, instance_id)
1217
+
1218
+ _validate_result(result)
1219
+ return result
1220
+
1221
+ return inner
1222
+
1223
+ return decorating_function
1224
+
1225
+
1226
+ def tensorleap_instances_length_encoder(name: str):
1227
+ def decorating_function(user_function: InstanceLengthCallableInterface):
1228
+ def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
1229
+ assert isinstance(sample_id, (str, int)), \
1230
+ (f'tensorleap_instances_length_encoder validation failed: '
1231
+ f'Argument sample_id should be str. Got {type(sample_id)}.')
1232
+ assert isinstance(preprocess_response, PreprocessResponse), \
1233
+ (f'tensorleap_instances_length_encoder validation failed: '
1234
+ f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
1235
+ assert type(sample_id) == preprocess_response.sample_id_type, \
1236
+ (f'tensorleap_instances_length_encoder validation failed: '
1237
+ f'Argument sample_id should be as the same type as defined in the preprocess response '
1238
+ f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
1239
+
1240
+ def _validate_result(result):
1241
+ assert isinstance(result, int), \
1242
+ (f'tensorleap_instances_length_encoder validation failed: '
1243
+ f'Unsupported return type. Should be a int. Got {type(result)}.')
1244
+
1245
+ def inner_without_validate(sample_id, preprocess_response):
1246
+ global _called_from_inside_tl_decorator
1247
+ _called_from_inside_tl_decorator += 1
1248
+
1249
+ try:
1250
+ result = user_function(sample_id, preprocess_response)
1251
+ finally:
1252
+ _called_from_inside_tl_decorator -= 1
1253
+
1254
+ return result
372
1255
 
373
1256
  def inner(sample_id, preprocess_response):
1257
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1258
+ return None
1259
+
374
1260
  _validate_input_args(sample_id, preprocess_response)
375
- result = user_function(sample_id, preprocess_response)
1261
+
1262
+ result = inner_without_validate(sample_id, preprocess_response)
1263
+
376
1264
  _validate_result(result)
377
1265
  return result
378
1266
 
@@ -381,43 +1269,87 @@ def tensorleap_instances_masks_encoder(name: str):
381
1269
  return decorating_function
382
1270
 
383
1271
 
384
- def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
1272
+ def tensorleap_input_encoder(name: str, channel_dim=_UNSET, model_input_index=None):
385
1273
  def decorating_function(user_function: SectionCallableInterface):
386
1274
  for input_handler in leap_binder.setup_container.inputs:
387
1275
  if input_handler.name == name:
388
1276
  raise Exception(f'Input with name {name} already exists. '
389
1277
  f'Please choose another')
1278
+ nonlocal channel_dim
1279
+
1280
+ channel_dim_was_provided = channel_dim is not _UNSET
1281
+
1282
+ if not channel_dim_was_provided:
1283
+ channel_dim = -1
1284
+ if not _call_from_tl_platform:
1285
+ store_warning_by_param(
1286
+ param_name="channel_dim",
1287
+ user_func_name=user_function.__name__,
1288
+ default_value=channel_dim,
1289
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/input-encoder"
1290
+
1291
+ )
1292
+
390
1293
  if channel_dim <= 0 and channel_dim != -1:
391
1294
  raise Exception(f"Channel dim for input {name} is expected to be either -1 or positive")
392
1295
 
393
- leap_binder.set_input(user_function, name, channel_dim=channel_dim)
394
-
395
1296
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
396
- assert isinstance(sample_id, (int, str)), \
397
- (f'tensorleap_input_encoder validation failed: '
398
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
399
- assert isinstance(preprocess_response, PreprocessResponse), \
400
- (f'tensorleap_input_encoder validation failed: '
401
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
402
1297
  assert type(sample_id) == preprocess_response.sample_id_type, \
403
- (f'tensorleap_input_encoder validation failed: '
1298
+ (f'{user_function.__name__}() validation failed: '
404
1299
  f'Argument sample_id should be as the same type as defined in the preprocess response '
405
1300
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
406
1301
 
407
1302
  def _validate_result(result):
1303
+ validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray")
408
1304
  assert isinstance(result, np.ndarray), \
409
- (f'tensorleap_input_encoder validation failed: '
1305
+ (f'{user_function.__name__}() validation failed: '
410
1306
  f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
411
1307
  assert result.dtype == np.float32, \
412
- (f'tensorleap_input_encoder validation failed: '
1308
+ (f'{user_function.__name__}() validation failed: '
413
1309
  f'The return type should be a numpy array of type float32. Got {result.dtype}.')
414
- assert channel_dim - 1 <= len(result.shape), (f'tensorleap_input_encoder validation failed: '
1310
+ assert channel_dim - 1 <= len(result.shape), (f'{user_function.__name__}() validation failed: '
415
1311
  f'The channel_dim ({channel_dim}) should be <= to the rank of the resulting input rank ({len(result.shape)}).')
416
1312
 
417
- def inner(sample_id, preprocess_response):
1313
+ def inner_without_validate(sample_id, preprocess_response):
1314
+ global _called_from_inside_tl_decorator
1315
+ _called_from_inside_tl_decorator += 1
1316
+
1317
+ try:
1318
+ result = user_function(sample_id, preprocess_response)
1319
+ finally:
1320
+ _called_from_inside_tl_decorator -= 1
1321
+
1322
+ return result
1323
+
1324
+ leap_binder.set_input(inner_without_validate, name, channel_dim=channel_dim)
1325
+
1326
+ def inner(*args, **kwargs):
1327
+ if not _call_from_tl_platform:
1328
+ set_current("tensorleap_input_encoder")
1329
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
1330
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
1331
+ sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
418
1332
  _validate_input_args(sample_id, preprocess_response)
419
- result = user_function(sample_id, preprocess_response)
1333
+
1334
+ result = inner_without_validate(sample_id, preprocess_response)
1335
+
420
1336
  _validate_result(result)
1337
+
1338
+ if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
1339
+ batch_warning(result, user_function.__name__)
1340
+ result = np.expand_dims(result, axis=0)
1341
+ # Emit integration test event once per test
1342
+ try:
1343
+ emit_integration_event_once(AnalyticsEvent.INPUT_ENCODER_INTEGRATION_TEST, {
1344
+ 'encoder_name': name,
1345
+ 'channel_dim': channel_dim,
1346
+ 'model_input_index': model_input_index
1347
+ })
1348
+ except Exception as e:
1349
+ logger.debug(f"Failed to emit input_encoder integration test event: {e}")
1350
+ if not _call_from_tl_platform:
1351
+ update_env_params_func("tensorleap_input_encoder", "v")
1352
+
421
1353
  return result
422
1354
 
423
1355
  node_mapping_type = NodeMappingType.Input
@@ -425,7 +1357,27 @@ def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
425
1357
  node_mapping_type = NodeMappingType(f'Input{str(model_input_index)}')
426
1358
  inner.node_mapping = NodeMapping(name, node_mapping_type)
427
1359
 
428
- return inner
1360
+ def mapping_inner(*args, **kwargs):
1361
+ class TempMapping:
1362
+ pass
1363
+
1364
+ ret = TempMapping()
1365
+ ret.node_mapping = mapping_inner.node_mapping
1366
+
1367
+ leap_binder.mapping_connections.append(NodeConnection(mapping_inner.node_mapping, None))
1368
+ return ret
1369
+
1370
+ mapping_inner.node_mapping = NodeMapping(name, node_mapping_type)
1371
+
1372
+ def final_inner(*args, **kwargs):
1373
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1374
+ return mapping_inner(*args, **kwargs)
1375
+ else:
1376
+ return inner(*args, **kwargs)
1377
+
1378
+ final_inner.node_mapping = NodeMapping(name, node_mapping_type)
1379
+
1380
+ return final_inner
429
1381
 
430
1382
  return decorating_function
431
1383
 
@@ -437,95 +1389,187 @@ def tensorleap_gt_encoder(name: str):
437
1389
  raise Exception(f'GT with name {name} already exists. '
438
1390
  f'Please choose another')
439
1391
 
440
- leap_binder.set_ground_truth(user_function, name)
441
-
442
1392
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
443
- assert isinstance(sample_id, (int, str)), \
444
- (f'tensorleap_gt_encoder validation failed: '
445
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
446
- assert isinstance(preprocess_response, PreprocessResponse), \
447
- (f'tensorleap_gt_encoder validation failed: '
448
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
449
1393
  assert type(sample_id) == preprocess_response.sample_id_type, \
450
- (f'tensorleap_gt_encoder validation failed: '
1394
+ (f'{user_function.__name__}() validation failed: '
451
1395
  f'Argument sample_id should be as the same type as defined in the preprocess response '
452
1396
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
453
1397
 
454
1398
  def _validate_result(result):
1399
+ validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray",
1400
+ gt_flag=True)
455
1401
  assert isinstance(result, np.ndarray), \
456
- (f'tensorleap_gt_encoder validation failed: '
1402
+ (f'{user_function.__name__}() validation failed: '
457
1403
  f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
458
1404
  assert result.dtype == np.float32, \
459
- (f'tensorleap_gt_encoder validation failed: '
1405
+ (f'{user_function.__name__}() validation failed: '
460
1406
  f'The return type should be a numpy array of type float32. Got {result.dtype}.')
461
1407
 
462
- def inner(sample_id, preprocess_response):
1408
+ def inner_without_validate(sample_id, preprocess_response):
1409
+ global _called_from_inside_tl_decorator
1410
+ _called_from_inside_tl_decorator += 1
1411
+
1412
+ try:
1413
+ result = user_function(sample_id, preprocess_response)
1414
+ finally:
1415
+ _called_from_inside_tl_decorator -= 1
1416
+
1417
+ return result
1418
+
1419
+ leap_binder.set_ground_truth(inner_without_validate, name)
1420
+
1421
+ def inner(*args, **kwargs):
1422
+ if not _call_from_tl_platform:
1423
+ set_current("tensorleap_gt_encoder")
1424
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
1425
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
1426
+ sample_id, preprocess_response = args
463
1427
  _validate_input_args(sample_id, preprocess_response)
464
- result = user_function(sample_id, preprocess_response)
1428
+
1429
+ result = inner_without_validate(sample_id, preprocess_response)
1430
+
465
1431
  _validate_result(result)
1432
+
1433
+ if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
1434
+ batch_warning(result, user_function.__name__)
1435
+ result = np.expand_dims(result, axis=0)
1436
+ # Emit integration test event once per test
1437
+ try:
1438
+ emit_integration_event_once(AnalyticsEvent.GT_ENCODER_INTEGRATION_TEST, {
1439
+ 'encoder_name': name
1440
+ })
1441
+ except Exception as e:
1442
+ logger.debug(f"Failed to emit gt_encoder integration test event: {e}")
1443
+ if not _call_from_tl_platform:
1444
+ update_env_params_func("tensorleap_gt_encoder", "v")
466
1445
  return result
467
1446
 
468
1447
  inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
469
1448
 
470
- return inner
1449
+ def mapping_inner(*args, **kwargs):
1450
+ class TempMapping:
1451
+ pass
471
1452
 
1453
+ ret = TempMapping()
1454
+ ret.node_mapping = mapping_inner.node_mapping
472
1455
 
1456
+ return ret
1457
+
1458
+ mapping_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
1459
+
1460
+ def final_inner(*args, **kwargs):
1461
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1462
+ return mapping_inner(*args, **kwargs)
1463
+ else:
1464
+ return inner(*args, **kwargs)
1465
+
1466
+ final_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
1467
+
1468
+ return final_inner
473
1469
 
474
1470
  return decorating_function
475
1471
 
476
1472
 
477
1473
  def tensorleap_custom_loss(name: str, connects_to=None):
1474
+ name_to_unique_name = defaultdict(set)
1475
+
478
1476
  def decorating_function(user_function: CustomCallableInterface):
479
1477
  for loss_handler in leap_binder.setup_container.custom_loss_handlers:
480
1478
  if loss_handler.custom_loss_handler_data.name == name:
481
1479
  raise Exception(f'Custom loss with name {name} already exists. '
482
1480
  f'Please choose another')
483
1481
 
484
- leap_binder.add_custom_loss(user_function, name)
485
-
486
- if connects_to is not None:
487
- arg_names = leap_binder.setup_container.custom_loss_handlers[-1].custom_loss_handler_data.arg_names
488
- _add_mapping_connections(connects_to, arg_names, NodeMappingType.CustomLoss, name)
489
-
490
-
491
1482
  valid_types = (np.ndarray, SamplePreprocessResponse)
492
- try:
493
- import tensorflow as tf
494
- valid_types = (np.ndarray, SamplePreprocessResponse, tf.Tensor)
495
- except ImportError:
496
- pass
497
1483
 
498
1484
  def _validate_input_args(*args, **kwargs):
499
-
1485
+ assert len(args) + len(kwargs) > 0, (
1486
+ f"{user_function.__name__}() validation failed: "
1487
+ f"Expected at least one positional|key-word argument of the allowed types (np.ndarray|SamplePreprocessResponse|). "
1488
+ f"but received none. "
1489
+ f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
1490
+ )
500
1491
  for i, arg in enumerate(args):
501
- if isinstance(arg, list):
502
- for y, elem in enumerate(arg):
503
- assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
504
- f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
505
- else:
506
- assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
507
- f'Argument #{i} should be a numpy array. Got {type(arg)}.')
1492
+ assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
1493
+ f'Argument #{i} should be a numpy array. Got {type(arg)}.')
508
1494
  for _arg_name, arg in kwargs.items():
509
- if isinstance(arg, list):
510
- for y, elem in enumerate(arg):
511
- assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
512
- f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
513
- else:
514
- assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
515
- f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
1495
+ assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
1496
+ f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
516
1497
 
517
1498
  def _validate_result(result):
518
- assert isinstance(result, valid_types), \
519
- (f'tensorleap_custom_loss validation failed: '
1499
+ validate_output_structure(result, func_name=user_function.__name__,
1500
+ expected_type_name="np.ndarray")
1501
+ assert isinstance(result, np.ndarray), \
1502
+ (f'{user_function.__name__} validation failed: '
520
1503
  f'The return type should be a numpy array. Got {type(result)}.')
1504
+ assert result.ndim < 2, (f'{user_function.__name__} validation failed: '
1505
+ f'The return type should be a 1Dim numpy array but got {result.ndim}Dim.')
1506
+
1507
+ @functools.wraps(user_function)
1508
+ def inner_without_validate(*args, **kwargs):
1509
+ global _called_from_inside_tl_decorator
1510
+ _called_from_inside_tl_decorator += 1
1511
+
1512
+ try:
1513
+ result = user_function(*args, **kwargs)
1514
+ finally:
1515
+ _called_from_inside_tl_decorator -= 1
1516
+
1517
+ return result
1518
+
1519
+ try:
1520
+ inner_without_validate.__signature__ = inspect.signature(user_function)
1521
+ except (TypeError, ValueError):
1522
+ pass
1523
+
1524
+ leap_binder.add_custom_loss(inner_without_validate, name)
1525
+
1526
+ if connects_to is not None:
1527
+ arg_names = leap_binder.setup_container.custom_loss_handlers[-1].custom_loss_handler_data.arg_names
1528
+ _add_mapping_connections(connects_to, arg_names, NodeMappingType.CustomLoss, name)
521
1529
 
522
1530
  def inner(*args, **kwargs):
1531
+ if not _call_from_tl_platform:
1532
+ set_current("tensorleap_custom_loss")
523
1533
  _validate_input_args(*args, **kwargs)
524
- result = user_function(*args, **kwargs)
1534
+
1535
+ result = inner_without_validate(*args, **kwargs)
1536
+
525
1537
  _validate_result(result)
1538
+ if not _call_from_tl_platform:
1539
+ update_env_params_func("tensorleap_custom_loss", "v")
1540
+
526
1541
  return result
527
1542
 
528
- return inner
1543
+ def mapping_inner(*args, **kwargs):
1544
+ user_unique_name = mapping_inner.name
1545
+ if 'user_unique_name' in kwargs:
1546
+ user_unique_name = kwargs['user_unique_name']
1547
+
1548
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
1549
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
1550
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
1551
+
1552
+ ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
1553
+ ordered_connections = list(args) + ordered_connections
1554
+ _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
1555
+ mapping_inner.name, NodeMappingType.CustomLoss)
1556
+
1557
+ return None
1558
+
1559
+ mapping_inner.arg_names = leap_binder.setup_container.custom_loss_handlers[
1560
+ -1].custom_loss_handler_data.arg_names
1561
+ mapping_inner.name = name
1562
+
1563
+ def final_inner(*args, **kwargs):
1564
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1565
+ return mapping_inner(*args, **kwargs)
1566
+ else:
1567
+ return inner(*args, **kwargs)
1568
+
1569
+ final_inner.arg_names = leap_binder.setup_container.custom_loss_handlers[-1].custom_loss_handler_data.arg_names
1570
+ final_inner.name = name
1571
+
1572
+ return final_inner
529
1573
 
530
1574
  return decorating_function
531
1575
 
@@ -550,3 +1594,180 @@ def tensorleap_custom_layer(name: str):
550
1594
  return custom_layer
551
1595
 
552
1596
  return decorating_function
1597
+
1598
+
1599
+ def tensorleap_status_table():
1600
+ import atexit
1601
+ import sys
1602
+ import traceback
1603
+ from typing import Any
1604
+
1605
+ CHECK = "✅"
1606
+ CROSS = "❌"
1607
+ UNKNOWN = "❔"
1608
+
1609
+ code_mapping_failure = [0]
1610
+
1611
+ table = [
1612
+ {"name": "tensorleap_preprocess", "Added to integration": UNKNOWN},
1613
+ {"name": "tensorleap_integration_test", "Added to integration": UNKNOWN},
1614
+ {"name": "tensorleap_input_encoder", "Added to integration": UNKNOWN},
1615
+ {"name": "tensorleap_gt_encoder", "Added to integration": UNKNOWN},
1616
+ {"name": "tensorleap_load_model", "Added to integration": UNKNOWN},
1617
+ {"name": "tensorleap_custom_loss", "Added to integration": UNKNOWN},
1618
+ {"name": "tensorleap_custom_metric (optional)", "Added to integration": UNKNOWN},
1619
+ {"name": "tensorleap_metadata (optional)", "Added to integration": UNKNOWN},
1620
+ {"name": "tensorleap_custom_visualizer (optional)", "Added to integration": UNKNOWN},
1621
+ ]
1622
+
1623
+ _finalizer_called = {"done": False}
1624
+ _crashed = {"value": False}
1625
+ _current_func = {"name": None}
1626
+
1627
+ def _link(url: str) -> str:
1628
+ return f"\033{url}\033"
1629
+
1630
+ def _remove_suffix(s: str, suffix: str) -> str:
1631
+ if suffix and s.endswith(suffix):
1632
+ return s[:-len(suffix)]
1633
+ return s
1634
+
1635
+ def _find_row(name: str):
1636
+ for row in table:
1637
+ if _remove_suffix(row["name"], " (optional)") == name:
1638
+ return row
1639
+ return None
1640
+
1641
+ def _set_status(name: str, status_symbol: str):
1642
+ row = _find_row(name)
1643
+ if not row:
1644
+ return
1645
+
1646
+ cur = row["Added to integration"]
1647
+ if status_symbol == UNKNOWN:
1648
+ return
1649
+ if cur == CHECK and status_symbol != CHECK:
1650
+ return
1651
+
1652
+ row["Added to integration"] = status_symbol
1653
+
1654
+ def _mark_unknowns_as_cross():
1655
+ for row in table:
1656
+ if row["Added to integration"] == UNKNOWN:
1657
+ row["Added to integration"] = CROSS
1658
+
1659
+ def _format_default_value(v: Any) -> str:
1660
+ if hasattr(v, "name"):
1661
+ return str(v.name)
1662
+ if isinstance(v, str):
1663
+ return v
1664
+ s = repr(v)
1665
+ return s if len(s) <= 120 else s[:120] + "..."
1666
+
1667
+ def _print_param_default_warnings():
1668
+ data = _get_param_default_warnings()
1669
+ if not data:
1670
+ return
1671
+
1672
+ print("\nWarnings (Default use. It is recommended to set values explicitly):")
1673
+ for param_name in sorted(data.keys()):
1674
+ default_value = data[param_name]["default_value"]
1675
+ funcs = ", ".join(sorted(data[param_name]["funcs"]))
1676
+ dv = _format_default_value(default_value)
1677
+
1678
+ docs_link = data[param_name].get("link_to_docs")
1679
+ docs_part = f" {_link(docs_link)}" if docs_link else ""
1680
+ print(
1681
+ f" ⚠️ Parameter '{param_name}' defaults to {dv} in the following functions: [{funcs}]. "
1682
+ f"For more information, check {docs_part}")
1683
+ print("\nIf this isn’t the intended behaviour, set them explicitly.")
1684
+
1685
+ def _print_table():
1686
+ _print_param_default_warnings()
1687
+
1688
+ if not started_from("leap_integration.py"):
1689
+ return
1690
+
1691
+ ready_mess = "\nAll parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system."
1692
+ not_ready_mess = "\nSome mandatory components have not yet been added to the Integration test. Recommended next interface to add is: "
1693
+ mandatory_ready_mess = "\nAll mandatory parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system or continue to the next optional reccomeded interface,adding: "
1694
+ code_mapping_failure_mes = "Tensorleap_integration_test code flow failed, check raised exception."
1695
+
1696
+ name_width = max(len(row["name"]) for row in table)
1697
+ status_width = max(len(row["Added to integration"]) for row in table)
1698
+
1699
+ header = f"{'Decorator Name'.ljust(name_width)} | {'Added to integration'.ljust(status_width)}"
1700
+ sep = "-" * len(header)
1701
+
1702
+ print("\n" + header)
1703
+ print(sep)
1704
+
1705
+ ready = True
1706
+ next_step = None
1707
+
1708
+ for row in table:
1709
+ print(f"{row['name'].ljust(name_width)} | {row['Added to integration'].ljust(status_width)}")
1710
+ if not _crashed["value"] and ready:
1711
+ if row["Added to integration"] != CHECK:
1712
+ ready = False
1713
+ next_step = row["name"]
1714
+
1715
+ if _crashed["value"]:
1716
+ print(f"\nScript crashed before completing all steps. crashed at function '{_current_func['name']}'.")
1717
+ return
1718
+
1719
+ if code_mapping_failure[0]:
1720
+ print(f"\n{CROSS + code_mapping_failure_mes}.")
1721
+ return
1722
+
1723
+ print(ready_mess) if ready else print(
1724
+ mandatory_ready_mess + next_step
1725
+ ) if (next_step and "optional" in next_step) else print(not_ready_mess + (next_step or ""))
1726
+
1727
+ def set_current(name: str):
1728
+ _current_func["name"] = name
1729
+
1730
+ def update_env_params(name: str, status: str = "v"):
1731
+ if name == "code_mapping":
1732
+ code_mapping_failure[0] = 1
1733
+ if status == "v":
1734
+ _set_status(name, CHECK)
1735
+ else:
1736
+ _set_status(name, CROSS)
1737
+
1738
+ def run_on_exit():
1739
+ if _finalizer_called["done"]:
1740
+ return
1741
+ _finalizer_called["done"] = True
1742
+ if not _crashed["value"]:
1743
+ _mark_unknowns_as_cross()
1744
+
1745
+ _print_table()
1746
+
1747
+ def handle_exception(exc_type, exc_value, exc_traceback):
1748
+ _crashed["value"] = True
1749
+ crashed_name = _current_func["name"]
1750
+ if crashed_name:
1751
+ row = _find_row(crashed_name)
1752
+ if row and row["Added to integration"] != CHECK:
1753
+ row["Added to integration"] = CROSS
1754
+
1755
+ traceback.print_exception(exc_type, exc_value, exc_traceback)
1756
+ run_on_exit()
1757
+
1758
+ atexit.register(run_on_exit)
1759
+ sys.excepthook = handle_exception
1760
+
1761
+ return set_current, update_env_params
1762
+
1763
+
1764
+ if not _call_from_tl_platform:
1765
+ set_current, update_env_params_func = tensorleap_status_table()
1766
+
1767
+
1768
+
1769
+
1770
+
1771
+
1772
+
1773
+