code-loader 1.0.117__py3-none-any.whl → 1.0.153.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,25 +1,201 @@
1
1
  # mypy: ignore-errors
2
2
  import os
3
- from typing import Optional, Union, Callable, List, Dict
3
+ import warnings
4
+ import logging
5
+ from collections import defaultdict
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+ from typing import Optional, Union, Callable, List, Dict, Set, Any
9
+ from typing import Optional, Union, Callable, List, Dict, get_args, get_origin, DefaultDict
4
10
 
5
11
  import numpy as np
6
12
  import numpy.typing as npt
7
13
 
14
+ logger = logging.getLogger(__name__)
15
+
8
16
  from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
9
17
  CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
10
18
  VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
11
- ConfusionMatrixElement, SamplePreprocessResponse, PredictionTypeHandler, InstanceCallableInterface, ElementInstance
12
- from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType
19
+ ConfusionMatrixElement, SamplePreprocessResponse, PredictionTypeHandler, InstanceCallableInterface, ElementInstance, \
20
+ InstanceLengthCallableInterface
21
+ from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType, DataStateType
13
22
  from code_loader import leap_binder
14
23
  from code_loader.contract.mapping import NodeMapping, NodeMappingType, NodeConnection
15
24
  from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
16
25
  LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
17
26
  from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
27
+ from code_loader.mixpanel_tracker import clear_integration_events, AnalyticsEvent, emit_integration_event_once
18
28
 
19
29
  import inspect
20
30
  import functools
31
+ from pathlib import Path
21
32
 
22
33
# ---- decorator call-state flags (module-level) ----
# Integer counter, initialised to 0.
# NOTE(review): presumably a nesting depth for decorator calls - confirm against its users.
_called_from_inside_tl_decorator = 0
# Set to True by the integration-test wrapper while it runs, and reset to False afterwards.
_called_from_inside_tl_integration_test_decorator = False
# True only when the IS_TENSORLEAP_PLATFORM environment variable is exactly 'true'.
_call_from_tl_platform = os.getenv('IS_TENSORLEAP_PLATFORM') == 'true'

# ---- warnings store (module-level) ----
# Sentinel used as a parameter default so that "argument omitted" can be told apart from None.
_UNSET = object()
# Collected warning records; _get_stored_warnings() hands out a copy of this list.
_STORED_WARNINGS: List[Dict[str, Any]] = list()
# NOTE(review): presumably de-duplication keys for _STORED_WARNINGS - confirm against its users.
_STORED_WARNING_KEYS: Set[tuple] = set()
# param_name -> names of the user functions that fell back to the default
_PARAM_DEFAULT_FUNCS: DefaultDict[str, Set[str]] = defaultdict(set)
# param_name -> the default value that was applied
_PARAM_DEFAULT_VALUE: Dict[str, Any] = dict()
# param_name -> documentation link for that parameter
_PARAM_DEFAULT_DOCS: Dict[str, str] = dict()
46
+
47
+
48
def get_entry_script_path() -> str:
    """Return the resolved path of the script that started this process.

    ``sys.argv[0]`` is used when it is present and non-empty; otherwise the
    ``__file__`` attribute of the ``__main__`` module is used. An empty string
    is returned when neither source yields a path.
    """
    import sys

    candidate = next(iter(sys.argv), "")
    if not candidate:
        # Fall back to the module that is running as the program entry point.
        import __main__
        candidate = getattr(__main__, "__file__", "") or ""
    if not candidate:
        return ""
    return str(Path(candidate).resolve())
56
+
57
def started_from(filename: str) -> bool:
    """Tell whether the process entry script has the given file name.

    Only the final path component of the entry script is compared with
    ``filename``. Returns False when the entry script path is unknown.
    """
    entry_path = get_entry_script_path()
    if not entry_path:
        return False
    return Path(entry_path).name == filename
60
+
61
+
62
def store_warning_by_param(
        *,
        param_name: str,
        user_func_name: str,
        default_value: Any,
        link_to_docs: Optional[str] = None,
) -> None:
    """Record that ``user_func_name`` relied on the default for ``param_name``.

    Args:
        param_name: Name of the parameter that was not supplied.
        user_func_name: Name of the function that fell back to the default.
        default_value: The default that was applied. Only the first value
            recorded for a given ``param_name`` is kept.
        link_to_docs: Optional documentation link. Only the first non-empty
            link recorded for a given ``param_name`` is kept.
    """
    _PARAM_DEFAULT_FUNCS[param_name].add(user_func_name)

    # setdefault keeps the first recorded value and ignores later ones.
    _PARAM_DEFAULT_VALUE.setdefault(param_name, default_value)

    if link_to_docs:
        _PARAM_DEFAULT_DOCS.setdefault(param_name, link_to_docs)
76
+
77
+
78
def _get_param_default_warnings() -> Dict[str, Dict[str, Any]]:
    """Build a per-parameter summary of the recorded default-value fallbacks.

    Returns:
        A dict keyed by parameter name. Each value holds ``default_value``
        (None when nothing was recorded), ``funcs`` (a fresh set of the
        function names) and ``link_to_docs`` (None when no link was recorded).
    """
    return {
        param: {
            "default_value": _PARAM_DEFAULT_VALUE.get(param, None),
            # Copy so callers cannot mutate the module-level store.
            "funcs": set(func_names),
            "link_to_docs": _PARAM_DEFAULT_DOCS.get(param),
        }
        for param, func_names in _PARAM_DEFAULT_FUNCS.items()
    }
87
+
88
+
89
def _get_stored_warnings() -> List[Dict[str, Any]]:
    """Return a shallow copy of the module-level list of stored warnings."""
    return _STORED_WARNINGS[:]
91
+
92
+
93
def validate_args_structure(*args, types_order, func_name, expected_names, **kwargs):
    """Check the number and types of the arguments passed to a decorated function.

    Args:
        *args: Positional arguments supplied by the caller.
        types_order: Expected type of each argument, in order. An entry may be
            a ``typing.Union`` of several accepted types.
        func_name: Name used as the prefix of every error message.
        expected_names: Argument names, in order. When non-empty, arguments
            that were not passed positionally are looked up in ``kwargs`` by
            name. When empty, only positional arguments are considered.
        **kwargs: Keyword arguments supplied by the caller.

    Raises:
        AssertionError: If a named argument is missing, if the number of
            arguments differs from ``len(types_order)`` (including surplus
            positional arguments), or if an argument has an unexpected type.
    """
    def _type_to_str(t):
        if get_origin(t) is Union:
            return " | ".join(_type_to_str(member) for member in get_args(t))
        # Some typing constructs have no __name__; fall back to their repr.
        return getattr(t, "__name__", str(t))

    def _format_types(types, names=None):
        return ", ".join(
            f"{names[i] if names else f'arg{i}'}: {_type_to_str(ty)}"
            for i, ty in enumerate(types)
        )

    def _count_error(received):
        expected = _format_types(types_order, expected_names)
        got_types = ", ".join(type(arg).__name__ for arg in received)
        return AssertionError(
            f"{func_name} validation failed: "
            f"Expected exactly {len(types_order)} arguments ({expected}), "
            f"but got {len(received)} argument(s) of type(s): ({got_types}). "
            f"Correct usage example: {func_name}({expected})"
        )

    if expected_names:
        # Surplus positional arguments must be reported, not silently dropped
        # while matching arguments to names below.
        if len(args) > len(expected_names):
            raise _count_error(args)

        normalized_args = []
        for i, name in enumerate(expected_names):
            if i < len(args):
                normalized_args.append(args[i])
            elif name in kwargs:
                normalized_args.append(kwargs[name])
            else:
                raise AssertionError(
                    f"{func_name} validation failed: "
                    f"Missing required argument '{name}'. "
                    f"Expected arguments: {expected_names}."
                )
    else:
        normalized_args = list(args)

    if len(normalized_args) != len(types_order):
        raise _count_error(normalized_args)

    for i, (arg, expected_type) in enumerate(zip(normalized_args, types_order)):
        if get_origin(expected_type) is Union:
            allowed_types = get_args(expected_type)
        else:
            allowed_types = (expected_type,)

        if not isinstance(arg, allowed_types):
            allowed_str = " | ".join(_type_to_str(t) for t in allowed_types)
            raise AssertionError(
                f"{func_name} validation failed: "
                f"Argument '{expected_names[i] if expected_names else f'arg{i}'}' "
                f"expected type {allowed_str}, but got {type(arg).__name__}. "
                f"Correct usage example: {func_name}({_format_types(types_order, expected_names)})"
            )
149
+
150
+
151
def validate_output_structure(result, func_name: str, expected_type_name="np.ndarray", gt_flag=False):
    """Reject missing or multi-valued results returned by a user function.

    Args:
        result: The value the user function returned.
        func_name: Name used as the prefix of every error message.
        expected_type_name: Human-readable name of the expected return type,
            used only in error messages.
        gt_flag: When True, the "missing result" message additionally explains
            how to return an empty array for unlabeled data.

    Raises:
        AssertionError: If ``result`` is None, a float NaN, or a tuple.
    """
    is_missing = result is None or (isinstance(result, float) and np.isnan(result))
    if is_missing:
        if gt_flag:
            raise AssertionError(
                f"{func_name} validation failed: "
                f"The function returned {result!r}. "
                f"If you are working with an unlabeled dataset and no ground truth is available, "
                f"use 'return np.array([], dtype=np.float32)' instead. "
                f"Otherwise, {func_name} expected a single {expected_type_name} object. "
                f"Make sure the function ends with 'return <{expected_type_name}>'."
            )

        # Report the actual value: this branch is also reached for NaN, not only None.
        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned {result!r}. "
            f"Expected a single {expected_type_name} object. "
            f"Make sure the function ends with 'return <{expected_type_name}>'."
        )

    if isinstance(result, tuple):
        element_summary = "\n ".join(
            f"[{i}] type: {type(r).__name__}"
            for i, r in enumerate(result)
        )

        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned multiple outputs ({len(result)} values), "
            f"but only a single {expected_type_name} is allowed.\n\n"
            f"Returned elements:\n"
            f" {element_summary}\n\n"
            f"Correct usage example:\n"
            f" def {func_name}(...):\n"
            f" return <{expected_type_name}>\n\n"
            f"If you intended to return multiple values, combine them into a single "
            f"{expected_type_name} (e.g., by concatenation or stacking)."
        )
188
+
189
+
190
def batch_warning(result, func_name):
    """Warn when ``result`` already has a leading axis of size 1.

    Nothing happens when ``result.shape`` is empty or its first entry is not 1.

    Args:
        result: Object exposing a ``shape`` sequence.
        func_name: Name of the user function, used in the warning text.
    """
    shape = result.shape
    if len(shape) == 0 or shape[0] != 1:
        return

    message = (
        f"{func_name} warning: Tensorleap will add a batch dimension at axis 0 to the output of {func_name}, "
        f"although the detected size of axis 0 is already 1. "
        f"This may lead to an extra batch dimension (e.g., shape (1, 1, ...)). "
        f"Please ensure that the output of '{func_name}' is not already batched "
        f"to avoid computation errors."
    )
    warnings.warn(message)
23
199
 
24
200
 
25
201
  def _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type):
@@ -40,51 +216,234 @@ def _add_mapping_connections(connects_to, arg_names, node_mapping_type, name):
40
216
  _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type)
41
217
 
42
218
 
43
- def integration_test():
219
def tensorleap_integration_test():
    """Decorator factory for the user's integration-test function.

    The decorated function is registered on ``leap_binder.integration_test_func``.
    Calling the returned wrapper:

    1. validates the ``(idx, preprocess)`` arguments;
    2. runs the user function with the real arguments;
    3. runs it a second time with the mapping-runtime environment variable set
       and placeholder arguments, so the code mapping can be recorded;
    4. calls ``leap_binder.check()``.

    Raises (from the wrapper):
        AssertionError: If the arguments fail validation.
        Exception: If the mapping-mode run fails; the message names the file
            and line of the failure, and the original error is chained as the cause.
    """
    def decorating_function(integration_test_function: Callable):
        leap_binder.integration_test_func = integration_test_function

        def _validate_input_args(*args, **kwargs):
            # NOTE(review): only positional arguments are unpacked here, so keyword
            # calls pass validate_args_structure but fail on this line - confirm intent.
            sample_id, preprocess_response = args
            assert type(sample_id) == preprocess_response.sample_id_type, (
                f"tensorleap_integration_test validation failed: "
                f"sample_id type ({type(sample_id).__name__}) does not match the expected "
                f"type ({preprocess_response.sample_id_type}) from the PreprocessResponse."
            )

        def _run_in_mapping_mode():
            """Re-run the user function with the mapping-runtime env var set."""
            try:
                os.environ[mapping_runtime_mode_env_var_mame] = 'True'
                integration_test_function(None, PreprocessResponse(state=DataStateType.training, length=0))
            except Exception as e:
                import traceback
                # Last traceback entry is the frame where the error was raised.
                first_tb = traceback.extract_tb(e.__traceback__)[-1]
                file_name = Path(first_tb.filename).name
                line_number = first_tb.lineno
                update_env_params_func("code_mapping", "x")
                if isinstance(e, TypeError) and 'is not subscriptable' in str(e):
                    message = (f'Invalid integration code. File {file_name}, line {line_number}: '
                               f"indexing is supported only on the model's predictions inside the integration test. Please remove this indexing operation usage from the integration test code.")
                else:
                    message = (f'Invalid integration code. File {file_name}, line {line_number}: '
                               f'Integration test is only allowed to call Tensorleap decorators. '
                               f'Ensure any arithmetics, external library use, Python logic is placed within Tensorleap decoders')
                # A bare string cannot be raised; wrap it so the message reaches the user.
                raise Exception(message) from e
            finally:
                if mapping_runtime_mode_env_var_mame in os.environ:
                    del os.environ[mapping_runtime_mode_env_var_mame]

        def inner(*args, **kwargs):
            global _called_from_inside_tl_integration_test_decorator

            if not _call_from_tl_platform:
                set_current('tensorleap_integration_test')
            validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
                                    func_name='integration_test', expected_names=["idx", "preprocess"], **kwargs)
            _validate_input_args(*args, **kwargs)

            # Clear integration test events for new test
            try:
                clear_integration_events()
            except Exception as e:
                # Best effort: a failure to clear events must not fail the test.
                logger.debug(f"Failed to clear integration events: {e}")

            try:
                _called_from_inside_tl_integration_test_decorator = True
                if not _call_from_tl_platform:
                    # Marked before running, otherwise it would become "v" only
                    # if the whole script finishes.
                    update_env_params_func("tensorleap_integration_test", "v")
                ret = integration_test_function(*args, **kwargs)

                _run_in_mapping_mode()
            finally:
                _called_from_inside_tl_integration_test_decorator = False

            leap_binder.check()

        return inner

    return decorating_function
63
280
 
64
281
 
65
- def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = None):
282
def _safe_get_item(key):
    """Return the ``NodeMappingType`` member named ``Input<key>``.

    Args:
        key: Input index; it is converted with ``str`` to build the member name.

    Raises:
        Exception: If no member with that name exists.
    """
    try:
        return NodeMappingType[f'Input{str(key)}']
    except (KeyError, ValueError) as err:
        # Lookup by member name raises KeyError; ValueError is kept for compatibility.
        raise Exception(f'Tensorleap currently supports models with no more then 10 inputs') from err
287
+
288
+
289
+ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = _UNSET):
290
+ prediction_types_was_provided = prediction_types is not _UNSET
291
+
292
+ if not prediction_types_was_provided:
293
+ prediction_types = []
294
+ if not _call_from_tl_platform:
295
+ store_warning_by_param(
296
+ param_name="prediction_types",
297
+ user_func_name="tensorleap_load_model",
298
+ default_value=prediction_types,
299
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
300
+ )
301
+ assert isinstance(prediction_types, list), (
302
+ f"tensorleap_load_model validation failed: "
303
+ f" prediction_types is an optional argument of type List[PredictionTypeHandler]] but got {type(prediction_types).__name__}."
304
+ )
66
305
  for i, prediction_type in enumerate(prediction_types):
306
+ assert isinstance(prediction_type, PredictionTypeHandler), (f"tensorleap_load_model validation failed: "
307
+ f" prediction_types at position {i} must be of type PredictionTypeHandler but got {type(prediction_types[i]).__name__}.")
308
+ prediction_type_channel_dim_was_provided = prediction_type.channel_dim != "tl_default_value"
309
+ if not prediction_type_channel_dim_was_provided:
310
+ prediction_types[i].channel_dim = -1
311
+ if not _call_from_tl_platform:
312
+ store_warning_by_param(
313
+ param_name=f"prediction_types[{i}].channel_dim",
314
+ user_func_name="tensorleap_load_model",
315
+ default_value=prediction_types[i].channel_dim,
316
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
317
+ )
67
318
  leap_binder.add_prediction(prediction_type.name, prediction_type.labels, prediction_type.channel_dim, i)
68
319
 
69
- def decorating_function(load_model_func):
320
+ def _validate_result(result) -> None:
321
+ valid_types = ["onnxruntime", "keras"]
322
+ err_message = f"tensorleap_load_model validation failed:\nSupported models are Keras and onnxruntime only and non of them was returned."
323
+ validate_output_structure(result, func_name="tensorleap_load_model",
324
+ expected_type_name=[" | ".join(t for t in valid_types)][0])
325
+ try:
326
+ import keras
327
+ except ImportError:
328
+ keras = None
329
+ try:
330
+ import tensorflow as tf
331
+ except ImportError:
332
+ tf = None
333
+ try:
334
+ import onnxruntime
335
+ except ImportError:
336
+ onnxruntime = None
337
+
338
+ if not keras and not onnxruntime:
339
+ raise AssertionError(err_message)
340
+
341
+ is_keras_model = (
342
+ bool(keras and isinstance(result, getattr(keras, "Model", tuple())))
343
+ or bool(tf and isinstance(result, getattr(tf.keras, "Model", tuple())))
344
+ )
345
+ is_onnx_model = bool(onnxruntime and isinstance(result, onnxruntime.InferenceSession))
346
+
347
+ if not any([is_keras_model, is_onnx_model]):
348
+ raise AssertionError(err_message)
349
+
350
+ def decorating_function(load_model_func, prediction_types=prediction_types):
70
351
  class TempMapping:
71
352
  pass
72
353
 
73
- def inner():
354
+ @lru_cache()
355
+ def inner(*args, **kwargs):
356
+ if not _call_from_tl_platform:
357
+ set_current('tensorleap_load_model')
358
+ validate_args_structure(*args, types_order=[],
359
+ func_name='tensorleap_load_model', expected_names=[], **kwargs)
360
+
74
361
  class ModelPlaceholder:
75
- def __init__(self):
76
- self.model = load_model_func()
362
+ def __init__(self, prediction_types):
363
+ self.model = load_model_func() # TODO- check why this fails on onnx model
364
+ self.prediction_types = prediction_types
365
+ _validate_result(self.model)
77
366
 
78
367
  # keras interface
79
368
  def __call__(self, arg):
80
369
  ret = self.model(arg)
370
+ self.validate_declared_prediction_types(ret)
371
+ if isinstance(ret, list):
372
+ return [r.numpy() for r in ret]
81
373
  return ret.numpy()
82
374
 
375
+ def validate_declared_prediction_types(self, ret):
376
+ if not (len(self.prediction_types) == len(ret) if isinstance(ret, list) else 1) and len(
377
+ self.prediction_types) != 0:
378
+ if not _call_from_tl_platform:
379
+ update_env_params_func("tensorleap_load_model", "x")
380
+ raise Exception(
381
+ f"tensorleap_load_model validation failed: number of declared prediction types({len(prediction_types)}) != number of model outputs({len(ret) if isinstance(ret, list) else 1})")
382
+
383
+ def _convert_onnx_inputs_to_correct_type(
384
+ self, float_arrays_inputs: Dict[str, np.ndarray]
385
+ ) -> Dict[str, np.ndarray]:
386
+ ONNX_TYPE_TO_NP = {
387
+ "tensor(float)": np.float32,
388
+ "tensor(double)": np.float64,
389
+ "tensor(int64)": np.int64,
390
+ "tensor(int32)": np.int32,
391
+ "tensor(int16)": np.int16,
392
+ "tensor(int8)": np.int8,
393
+ "tensor(uint64)": np.uint64,
394
+ "tensor(uint32)": np.uint32,
395
+ "tensor(uint16)": np.uint16,
396
+ "tensor(uint8)": np.uint8,
397
+ "tensor(bool)": np.bool_,
398
+ }
399
+
400
+ """
401
+ Cast user-provided NumPy inputs to match the dtypes/shapes
402
+ expected by an ONNX Runtime InferenceSession.
403
+ """
404
+ coerced = {}
405
+ meta = {i.name: i for i in self.model.get_inputs()}
406
+
407
+ for name, arr in float_arrays_inputs.items():
408
+ if name not in meta:
409
+ # Keep as-is unless extra inputs are disallowed
410
+ coerced[name] = arr
411
+ continue
412
+
413
+ info = meta[name]
414
+ onnx_type = info.type
415
+ want_dtype = ONNX_TYPE_TO_NP.get(onnx_type)
416
+
417
+ if want_dtype is None:
418
+ raise TypeError(f"Unsupported ONNX input type: {onnx_type}")
419
+
420
+ # Cast dtype if needed
421
+ if arr.dtype != want_dtype:
422
+ arr = arr.astype(want_dtype, copy=False)
423
+
424
+ coerced[name] = arr
425
+
426
+ # Verify required inputs are present
427
+ missing = [n for n in meta if n not in coerced]
428
+ if missing:
429
+ raise KeyError(f"Missing required input(s): {sorted(missing)}")
430
+
431
+ return coerced
432
+
83
433
  # onnx runtime interface
84
434
  def run(self, output_names, input_dict):
85
- return self.model.run(output_names, input_dict)
435
+ corrected_type_inputs = self._convert_onnx_inputs_to_correct_type(input_dict)
436
+ ret = self.model.run(output_names, corrected_type_inputs)
437
+ self.validate_declared_prediction_types(ret)
438
+ return ret
86
439
 
87
- return ModelPlaceholder()
440
+ def get_inputs(self):
441
+ return self.model.get_inputs()
442
+
443
+ model_placeholder = ModelPlaceholder(prediction_types)
444
+ if not _call_from_tl_platform:
445
+ update_env_params_func("tensorleap_load_model", "v")
446
+ return model_placeholder
88
447
 
89
448
  def mapping_inner():
90
449
  class ModelOutputPlaceholder:
@@ -96,15 +455,20 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
96
455
  f'Expected key to be an int, got {type(key)} instead.'
97
456
 
98
457
  ret = TempMapping()
99
- ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
458
+ try:
459
+ ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
460
+ except ValueError as e:
461
+ raise Exception(f'Tensorleap currently supports models with no more then 10 active predictions,'
462
+ f' {key} not supported.')
100
463
  return ret
101
464
 
102
465
  class ModelPlaceholder:
466
+
103
467
  # keras interface
104
468
  def __call__(self, arg):
105
469
  if isinstance(arg, list):
106
470
  for i, elem in enumerate(arg):
107
- elem.node_mapping.type = NodeMappingType[f'Input{str(i)}']
471
+ elem.node_mapping.type = _safe_get_item(i)
108
472
  else:
109
473
  arg.node_mapping.type = NodeMappingType.Input0
110
474
 
@@ -115,18 +479,38 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
115
479
  assert output_names is None
116
480
  assert isinstance(input_dict, dict), \
117
481
  f'Expected input_dict to be a dict, got {type(input_dict)} instead.'
118
- for i, elem in enumerate(input_dict.values()):
119
- elem.node_mapping.type = NodeMappingType[f'Input{str(i)}']
482
+ for i, (input_key, elem) in enumerate(input_dict.items()):
483
+ if isinstance(input_key, NodeMappingType):
484
+ elem.node_mapping.type = input_key
485
+ else:
486
+ elem.node_mapping.type = _safe_get_item(i)
120
487
 
121
488
  return ModelOutputPlaceholder()
122
489
 
490
+ def get_inputs(self):
491
+ class FollowIndex:
492
+ def __init__(self, index):
493
+ self.name = _safe_get_item(index)
494
+
495
+ class FollowInputIndex:
496
+ def __init__(self):
497
+ pass
498
+
499
+ def __getitem__(self, index):
500
+ assert isinstance(index, int), \
501
+ f'Expected key to be an int, got {type(index)} instead.'
502
+
503
+ return FollowIndex(index)
504
+
505
+ return FollowInputIndex()
506
+
123
507
  return ModelPlaceholder()
124
508
 
125
- def final_inner():
509
+ def final_inner(*args, **kwargs):
126
510
  if os.environ.get(mapping_runtime_mode_env_var_mame):
127
511
  return mapping_inner()
128
512
  else:
129
- return inner()
513
+ return inner(*args, **kwargs)
130
514
 
131
515
  return final_inner
132
516
 
@@ -134,83 +518,186 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
134
518
 
135
519
 
136
520
  def tensorleap_custom_metric(name: str,
137
- direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward,
521
+ direction: Union[MetricDirection, Dict[str, MetricDirection]] = _UNSET,
138
522
  compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
139
523
  connects_to=None):
140
- def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
141
- CustomMultipleReturnCallableInterfaceMultiArgs,
142
- ConfusionMatrixCallableInterfaceMultiArgs]):
524
+ name_to_unique_name = defaultdict(set)
525
+
526
+ def decorating_function(
527
+ user_function: Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs,
528
+ ConfusionMatrixCallableInterfaceMultiArgs]):
529
+ nonlocal direction
530
+
531
+ direction_was_provided = direction is not _UNSET
532
+
533
+ if not direction_was_provided:
534
+ direction = MetricDirection.Downward
535
+ if not _call_from_tl_platform:
536
+ store_warning_by_param(
537
+ param_name="direction",
538
+ user_func_name=user_function.__name__,
539
+ default_value=direction,
540
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/custom-metrics"
541
+
542
+ )
543
+
544
+ def _validate_decorators_signature():
545
+ err_message = f"{user_function.__name__} validation failed.\n"
546
+ if not isinstance(name, str):
547
+ raise TypeError(err_message + f"`name` must be a string, got type {type(name).__name__}.")
548
+ valid_directions = {MetricDirection.Upward, MetricDirection.Downward}
549
+ if isinstance(direction, MetricDirection):
550
+ if direction not in valid_directions:
551
+ raise ValueError(
552
+ err_message +
553
+ f"Invalid MetricDirection: {direction}. Must be one of {valid_directions}, "
554
+ f"got type {type(direction).__name__}."
555
+ )
556
+ elif isinstance(direction, dict):
557
+ if not all(isinstance(k, str) for k in direction.keys()):
558
+ invalid_keys = {k: type(k).__name__ for k in direction.keys() if not isinstance(k, str)}
559
+ raise TypeError(
560
+ err_message +
561
+ f"All keys in `direction` must be strings, got invalid key types: {invalid_keys}."
562
+ )
563
+ for k, v in direction.items():
564
+ if v not in valid_directions:
565
+ raise ValueError(
566
+ err_message +
567
+ f"Invalid direction for key '{k}': {v}. Must be one of {valid_directions}, "
568
+ f"got type {type(v).__name__}."
569
+ )
570
+ else:
571
+ raise TypeError(
572
+ err_message +
573
+ f"`direction` must be a MetricDirection or a Dict[str, MetricDirection], "
574
+ f"got type {type(direction).__name__}."
575
+ )
576
+ if compute_insights is not None:
577
+ if not isinstance(compute_insights, (bool, dict)):
578
+ raise TypeError(
579
+ err_message +
580
+ f"`compute_insights` must be a bool or a Dict[str, bool], "
581
+ f"got type {type(compute_insights).__name__}."
582
+ )
583
+ if isinstance(compute_insights, dict):
584
+ if not all(isinstance(k, str) for k in compute_insights.keys()):
585
+ invalid_keys = {k: type(k).__name__ for k in compute_insights.keys() if not isinstance(k, str)}
586
+ raise TypeError(
587
+ err_message +
588
+ f"All keys in `compute_insights` must be strings, got invalid key types: {invalid_keys}."
589
+ )
590
+ for k, v in compute_insights.items():
591
+ if not isinstance(v, bool):
592
+ raise TypeError(
593
+ err_message +
594
+ f"Invalid type for compute_insights['{k}']: expected bool, got type {type(v).__name__}."
595
+ )
596
+ if connects_to is not None:
597
+ valid_types = (str, list, tuple, set)
598
+ if not isinstance(connects_to, valid_types):
599
+ raise TypeError(
600
+ err_message +
601
+ f"`connects_to` must be one of {valid_types}, got type {type(connects_to).__name__}."
602
+ )
603
+ if isinstance(connects_to, (list, tuple, set)):
604
+ invalid_elems = [f"{type(e).__name__}" for e in connects_to if not isinstance(e, str)]
605
+ if invalid_elems:
606
+ raise TypeError(
607
+ err_message +
608
+ f"All elements in `connects_to` must be strings, "
609
+ f"but found element types: {invalid_elems}."
610
+ )
611
+
612
+ _validate_decorators_signature()
613
+
143
614
  for metric_handler in leap_binder.setup_container.metrics:
144
615
  if metric_handler.metric_handler_data.name == name:
145
616
  raise Exception(f'Metric with name {name} already exists. '
146
617
  f'Please choose another')
147
618
 
148
619
  def _validate_input_args(*args, **kwargs) -> None:
620
+ assert len(args) + len(kwargs) > 0, (
621
+ f"{user_function.__name__}() validation failed: "
622
+ f"Expected at least one positional|key-word argument of type np.ndarray, "
623
+ f"but received none. "
624
+ f"Correct usage example: tensorleap_custom_metric(input_array: np.ndarray, ...)"
625
+ )
149
626
  for i, arg in enumerate(args):
150
627
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
151
- f'tensorleap_custom_metric validation failed: '
628
+ f'{user_function.__name__}() validation failed: '
152
629
  f'Argument #{i} should be a numpy array. Got {type(arg)}.')
153
630
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
154
631
  assert arg.shape[0] == leap_binder.batch_size_to_validate, \
155
- (f'tensorleap_custom_metric validation failed: Argument #{i} '
632
+ (f'{user_function.__name__}() validation failed: Argument #{i} '
156
633
  f'first dim should be as the batch size. Got {arg.shape[0]} '
157
634
  f'instead of {leap_binder.batch_size_to_validate}')
158
635
 
159
636
  for _arg_name, arg in kwargs.items():
160
637
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
161
- f'tensorleap_custom_metric validation failed: '
638
+ f'{user_function.__name__}() validation failed: '
162
639
  f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
163
640
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
164
641
  assert arg.shape[0] == leap_binder.batch_size_to_validate, \
165
- (f'tensorleap_custom_metric validation failed: Argument {_arg_name} '
642
+ (f'{user_function.__name__}() validation failed: Argument {_arg_name} '
166
643
  f'first dim should be as the batch size. Got {arg.shape[0]} '
167
644
  f'instead of {leap_binder.batch_size_to_validate}')
168
645
 
169
646
  def _validate_result(result) -> None:
170
- supported_types_message = (f'tensorleap_custom_metric validation failed: '
171
- f'Metric has returned unsupported type. Supported types are List[float], '
172
- f'List[List[ConfusionMatrixElement]], NDArray[np.float32]. ')
647
+ validate_output_structure(result, func_name=user_function.__name__,
648
+ expected_type_name="List[float | int | None | List[ConfusionMatrixElement] ] | NDArray[np.float32] or dictonary with one of these types as its values types")
649
+ supported_types_message = (f'{user_function.__name__}() validation failed: '
650
+ f'{user_function.__name__}() has returned unsupported type.\nSupported types are List[float|int|None], '
651
+ f'List[List[ConfusionMatrixElement]], NDArray[np.float32] or dictonary with one of these types as its values types. ')
173
652
 
174
- def _validate_single_metric(single_metric_result):
653
+ def _validate_single_metric(single_metric_result, key=None):
175
654
  if isinstance(single_metric_result, list):
176
655
  if isinstance(single_metric_result[0], list):
177
- assert isinstance(single_metric_result[0][0], ConfusionMatrixElement), \
178
- f'{supported_types_message}Got List[List[{type(single_metric_result[0][0])}]].'
656
+ assert all(isinstance(cm, ConfusionMatrixElement) for cm in single_metric_result[0]), (
657
+ f"{supported_types_message} "
658
+ f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
659
+ f"List[List[{', '.join(type(cm).__name__ for cm in single_metric_result[0])}]]."
660
+ )
661
+
179
662
  else:
180
- assert isinstance(single_metric_result[0], (
181
- float, int,
182
- type(None))), f'{supported_types_message}Got List[{type(single_metric_result[0])}].'
663
+ assert all(isinstance(v, (float, int, type(None), np.float32)) for v in single_metric_result), (
664
+ f"{supported_types_message}\n"
665
+ f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
666
+ f"List[{', '.join(type(v).__name__ for v in single_metric_result)}]."
667
+ )
183
668
  else:
184
669
  assert isinstance(single_metric_result,
185
- np.ndarray), f'{supported_types_message}Got {type(single_metric_result)}.'
186
- assert len(single_metric_result.shape) == 1, (f'tensorleap_custom_metric validation failed: '
670
+ np.ndarray), f'{supported_types_message}\nGot {type(single_metric_result)}.'
671
+ assert len(single_metric_result.shape) == 1, (f'{user_function.__name__}() validation failed: '
187
672
  f'The return shape should be 1D. Got {len(single_metric_result.shape)}D.')
188
673
 
189
674
  if leap_binder.batch_size_to_validate:
190
675
  assert len(single_metric_result) == leap_binder.batch_size_to_validate, \
191
- f'tensorleap_custom_metrix validation failed: The return len should be as the batch size.'
676
+ f'{user_function.__name__}() validation failed: The return len {f"of srt{key} value" if key is not None else ""} should be as the batch size.'
192
677
 
193
678
  if isinstance(result, dict):
194
679
  for key, value in result.items():
680
+ _validate_single_metric(value, key)
681
+
195
682
  assert isinstance(key, str), \
196
- (f'tensorleap_custom_metric validation failed: '
683
+ (f'{user_function.__name__}() validation failed: '
197
684
  f'Keys in the return dict should be of type str. Got {type(key)}.')
198
685
  _validate_single_metric(value)
199
686
 
200
687
  if isinstance(direction, dict):
201
688
  for direction_key in direction:
202
689
  assert direction_key in result, \
203
- (f'tensorleap_custom_metric validation failed: '
690
+ (f'{user_function.__name__}() validation failed: '
204
691
  f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
205
692
 
206
693
  if compute_insights is not None:
207
694
  assert isinstance(compute_insights, dict), \
208
- (f'tensorleap_custom_metric validation failed: '
695
+ (f'{user_function.__name__}() validation failed: '
209
696
  f'compute_insights should be dict if using the dict results. Got {type(compute_insights)}.')
210
697
 
211
698
  for ci_key in compute_insights:
212
699
  assert ci_key in result, \
213
- (f'tensorleap_custom_metric validation failed: '
700
+ (f'{user_function.__name__}() validation failed: '
214
701
  f'Keys in the compute_insights mapping should be part of result keys. Got key {ci_key}.')
215
702
 
216
703
  else:
@@ -218,7 +705,7 @@ def tensorleap_custom_metric(name: str,
218
705
 
219
706
  if compute_insights is not None:
220
707
  assert isinstance(compute_insights, bool), \
221
- (f'tensorleap_custom_metric validation failed: '
708
+ (f'{user_function.__name__}() validation failed: '
222
709
  f'compute_insights should be boolean. Got {type(compute_insights)}.')
223
710
 
224
711
  @functools.wraps(user_function)
@@ -245,11 +732,15 @@ def tensorleap_custom_metric(name: str,
245
732
  _add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
246
733
 
247
734
  def inner(*args, **kwargs):
735
+ if not _call_from_tl_platform:
736
+ set_current('tensorleap_custom_metric')
248
737
  _validate_input_args(*args, **kwargs)
249
738
 
250
739
  result = inner_without_validate(*args, **kwargs)
251
740
 
252
741
  _validate_result(result)
742
+ if not _call_from_tl_platform:
743
+ update_env_params_func("tensorleap_custom_metric", "v")
253
744
  return result
254
745
 
255
746
  def mapping_inner(*args, **kwargs):
@@ -259,6 +750,11 @@ def tensorleap_custom_metric(name: str,
259
750
 
260
751
  ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
261
752
  ordered_connections = list(args) + ordered_connections
753
+
754
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
755
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
756
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
757
+
262
758
  _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
263
759
  mapping_inner.name, NodeMappingType.Metric)
264
760
 
@@ -281,29 +777,40 @@ def tensorleap_custom_metric(name: str,
281
777
  def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
282
778
  heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
283
779
  connects_to=None):
780
+ name_to_unique_name = defaultdict(set)
781
+
284
782
  def decorating_function(user_function: VisualizerCallableInterface):
783
+ assert isinstance(visualizer_type, LeapDataType), (f"{user_function.__name__} validation failed: "
784
+ f"visualizer_type should be of type {LeapDataType.__name__} but got {type(visualizer_type)}"
785
+ )
285
786
  for viz_handler in leap_binder.setup_container.visualizers:
286
787
  if viz_handler.visualizer_handler_data.name == name:
287
788
  raise Exception(f'Visualizer with name {name} already exists. '
288
789
  f'Please choose another')
289
790
 
290
791
  def _validate_input_args(*args, **kwargs):
792
+ assert len(args) + len(kwargs) > 0, (
793
+ f"{user_function.__name__}() validation failed: "
794
+ f"Expected at least one positional|key-word argument of type np.ndarray, "
795
+ f"but received none. "
796
+ f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
797
+ )
291
798
  for i, arg in enumerate(args):
292
799
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
293
- f'tensorleap_custom_visualizer validation failed: '
800
+ f'{user_function.__name__}() validation failed: '
294
801
  f'Argument #{i} should be a numpy array. Got {type(arg)}.')
295
802
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
296
803
  assert arg.shape[0] != leap_binder.batch_size_to_validate, \
297
- (f'tensorleap_custom_visualizer validation failed: '
804
+ (f'{user_function.__name__}() validation failed: '
298
805
  f'Argument #{i} should be without batch dimension. ')
299
806
 
300
807
  for _arg_name, arg in kwargs.items():
301
808
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
302
- f'tensorleap_custom_visualizer validation failed: '
809
+ f'{user_function.__name__}() validation failed: '
303
810
  f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
304
811
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
305
812
  assert arg.shape[0] != leap_binder.batch_size_to_validate, \
306
- (f'tensorleap_custom_visualizer validation failed: Argument {_arg_name} '
813
+ (f'{user_function.__name__}() validation failed: Argument {_arg_name} '
307
814
  f'should be without batch dimension. ')
308
815
 
309
816
  def _validate_result(result):
@@ -317,8 +824,11 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
317
824
  LeapDataType.ImageWithBBox: LeapImageWithBBox,
318
825
  LeapDataType.ImageWithHeatmap: LeapImageWithHeatmap
319
826
  }
827
+ validate_output_structure(result, func_name=user_function.__name__,
828
+ expected_type_name=result_type_map[visualizer_type])
829
+
320
830
  assert isinstance(result, result_type_map[visualizer_type]), \
321
- (f'tensorleap_custom_visualizer validation failed: '
831
+ (f'{user_function.__name__}() validation failed: '
322
832
  f'The return type should be {result_type_map[visualizer_type]}. Got {type(result)}.')
323
833
 
324
834
  @functools.wraps(user_function)
@@ -345,11 +855,15 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
345
855
  _add_mapping_connections(connects_to, arg_names, NodeMappingType.Visualizer, name)
346
856
 
347
857
  def inner(*args, **kwargs):
858
+ if not _call_from_tl_platform:
859
+ set_current('tensorleap_custom_visualizer')
348
860
  _validate_input_args(*args, **kwargs)
349
861
 
350
862
  result = inner_without_validate(*args, **kwargs)
351
863
 
352
864
  _validate_result(result)
865
+ if not _call_from_tl_platform:
866
+ update_env_params_func("tensorleap_custom_visualizer", "v")
353
867
  return result
354
868
 
355
869
  def mapping_inner(*args, **kwargs):
@@ -357,6 +871,10 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
357
871
  if 'user_unique_name' in kwargs:
358
872
  user_unique_name = kwargs['user_unique_name']
359
873
 
874
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
875
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
876
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
877
+
360
878
  ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
361
879
  ordered_connections = list(args) + ordered_connections
362
880
  _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
@@ -387,30 +905,26 @@ def tensorleap_metadata(
387
905
  f'Please choose another')
388
906
 
389
907
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
390
- assert isinstance(sample_id, (int, str)), \
391
- (f'tensorleap_metadata validation failed: '
392
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
393
- assert isinstance(preprocess_response, PreprocessResponse), \
394
- (f'tensorleap_metadata validation failed: '
395
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
396
908
  assert type(sample_id) == preprocess_response.sample_id_type, \
397
- (f'tensorleap_metadata validation failed: '
909
+ (f'{user_function.__name__}() validation failed: '
398
910
  f'Argument sample_id should be as the same type as defined in the preprocess response '
399
911
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
400
912
 
401
913
  def _validate_result(result):
402
914
  supported_result_types = (type(None), int, str, bool, float, dict, np.floating,
403
915
  np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
916
+ validate_output_structure(result, func_name=user_function.__name__,
917
+ expected_type_name=supported_result_types)
404
918
  assert isinstance(result, supported_result_types), \
405
- (f'tensorleap_metadata validation failed: '
919
+ (f'{user_function.__name__}() validation failed: '
406
920
  f'Unsupported return type. Got {type(result)}. should be any of {str(supported_result_types)}')
407
921
  if isinstance(result, dict):
408
922
  for key, value in result.items():
409
923
  assert isinstance(key, str), \
410
- (f'tensorleap_metadata validation failed: '
924
+ (f'{user_function.__name__}() validation failed: '
411
925
  f'Keys in the return dict should be of type str. Got {type(key)}.')
412
926
  assert isinstance(value, supported_result_types), \
413
- (f'tensorleap_metadata validation failed: '
927
+ (f'{user_function.__name__}() validation failed: '
414
928
  f'Values in the return dict should be of type {str(supported_result_types)}. Got {type(value)}.')
415
929
 
416
930
  def inner_without_validate(sample_id, preprocess_response):
@@ -427,6 +941,60 @@ def tensorleap_metadata(
427
941
 
428
942
  leap_binder.set_metadata(inner_without_validate, name, metadata_type)
429
943
 
944
+ def inner(*args, **kwargs):
945
+ if not _call_from_tl_platform:
946
+ set_current('tensorleap_metadata')
947
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
948
+ return None
949
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
950
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
951
+ sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
952
+ _validate_input_args(sample_id, preprocess_response)
953
+
954
+ result = inner_without_validate(sample_id, preprocess_response)
955
+
956
+ _validate_result(result)
957
+ if not _call_from_tl_platform:
958
+ update_env_params_func("tensorleap_metadata", "v")
959
+ return result
960
+
961
+ return inner
962
+
963
+ return decorating_function
964
+
965
+
966
+ def tensorleap_custom_latent_space():
967
+ def decorating_function(user_function: SectionCallableInterface):
968
+ def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
969
+ assert isinstance(sample_id, (int, str)), \
970
+ (f'tensorleap_custom_latent_space validation failed: '
971
+ f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
972
+ assert isinstance(preprocess_response, PreprocessResponse), \
973
+ (f'tensorleap_custom_latent_space validation failed: '
974
+ f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
975
+ assert type(sample_id) == preprocess_response.sample_id_type, \
976
+ (f'tensorleap_custom_latent_space validation failed: '
977
+ f'Argument sample_id should be as the same type as defined in the preprocess response '
978
+ f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
979
+
980
+ def _validate_result(result):
981
+ assert isinstance(result, np.ndarray), \
982
+ (f'tensorleap_custom_loss validation failed: '
983
+ f'The return type should be a numpy array. Got {type(result)}.')
984
+
985
+ def inner_without_validate(sample_id, preprocess_response):
986
+ global _called_from_inside_tl_decorator
987
+ _called_from_inside_tl_decorator += 1
988
+
989
+ try:
990
+ result = user_function(sample_id, preprocess_response)
991
+ finally:
992
+ _called_from_inside_tl_decorator -= 1
993
+
994
+ return result
995
+
996
+ leap_binder.set_custom_latent_space(inner_without_validate)
997
+
430
998
  def inner(sample_id, preprocess_response):
431
999
  if os.environ.get(mapping_runtime_mode_env_var_mame):
432
1000
  return None
@@ -448,30 +1016,45 @@ def tensorleap_preprocess():
448
1016
  leap_binder.set_preprocess(user_function)
449
1017
 
450
1018
  def _validate_input_args(*args, **kwargs):
451
- assert len(args) == 0 and len(kwargs) == 0, \
452
- (f'tensorleap_preprocess validation failed: '
1019
+ assert len(args) + len(kwargs) == 0, \
1020
+ (f'{user_function.__name__}() validation failed: '
453
1021
  f'The function should not take any arguments. Got {args} and {kwargs}.')
454
1022
 
455
1023
  def _validate_result(result):
456
- assert isinstance(result, list), \
457
- (f'tensorleap_preprocess validation failed: '
458
- f'The return type should be a list. Got {type(result)}.')
1024
+ assert isinstance(result, list), (
1025
+ f"{user_function.__name__}() validation failed: expected return type list[{PreprocessResponse.__name__}]"
1026
+ f"(e.g., [PreprocessResponse1, PreprocessResponse2, ...]), but returned type is {type(result).__name__}."
1027
+ if not isinstance(result, tuple)
1028
+ else f"{user_function.__name__}() validation failed: expected to return a single list[{PreprocessResponse.__name__}] object, "
1029
+ f"but returned {len(result)} objects instead."
1030
+ )
459
1031
  for i, response in enumerate(result):
460
1032
  assert isinstance(response, PreprocessResponse), \
461
- (f'tensorleap_preprocess validation failed: '
1033
+ (f'{user_function.__name__}() validation failed: '
462
1034
  f'Element #{i} in the return list should be a PreprocessResponse. Got {type(response)}.')
463
1035
  assert len(set(result)) == len(result), \
464
- (f'tensorleap_preprocess validation failed: '
1036
+ (f'{user_function.__name__}() validation failed: '
465
1037
  f'The return list should not contain duplicate PreprocessResponse objects.')
466
1038
 
467
1039
  def inner(*args, **kwargs):
1040
+ if not _call_from_tl_platform:
1041
+ set_current('tensorleap_metadata')
468
1042
  if os.environ.get(mapping_runtime_mode_env_var_mame):
469
1043
  return [None, None, None, None]
470
1044
 
471
1045
  _validate_input_args(*args, **kwargs)
472
-
473
1046
  result = user_function()
474
1047
  _validate_result(result)
1048
+
1049
+ # Emit integration test event once per test
1050
+ try:
1051
+ emit_integration_event_once(AnalyticsEvent.PREPROCESS_INTEGRATION_TEST, {
1052
+ 'preprocess_responses_count': len(result)
1053
+ })
1054
+ except Exception as e:
1055
+ logger.debug(f"Failed to emit preprocess integration test event: {e}")
1056
+ if not _call_from_tl_platform:
1057
+ update_env_params_func("tensorleap_preprocess", "v")
475
1058
  return result
476
1059
 
477
1060
  return inner
@@ -480,37 +1063,55 @@ def tensorleap_preprocess():
480
1063
 
481
1064
 
482
1065
  def tensorleap_element_instance_preprocess(
483
- instance_mask_encoder: Callable[[str, PreprocessResponse], List[ElementInstance]]):
1066
+ instance_length_encoder: InstanceLengthCallableInterface, instance_mask_encoder: InstanceCallableInterface):
484
1067
  def decorating_function(user_function: Callable[[], List[PreprocessResponse]]):
485
1068
  def user_function_instance() -> List[PreprocessResponse]:
486
1069
  result = user_function()
487
1070
  for preprocess_response in result:
488
1071
  sample_ids_to_instance_mappings = {}
489
1072
  instance_to_sample_ids_mappings = {}
490
- instance_ids_to_names = {}
491
1073
  all_sample_ids = preprocess_response.sample_ids.copy()
492
1074
  for sample_id in preprocess_response.sample_ids:
493
- instances_masks = instance_mask_encoder(sample_id, preprocess_response)
494
- instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(len(instances_masks))]
1075
+ instances_length = instance_length_encoder(sample_id, preprocess_response)
1076
+ instances_ids = [f'{sample_id}_{instance_id}' for instance_id in range(instances_length)]
495
1077
  sample_ids_to_instance_mappings[sample_id] = instances_ids
496
1078
  instance_to_sample_ids_mappings[sample_id] = sample_id
497
- instance_names = [instance.name for instance in instances_masks]
498
- instance_ids_to_names[sample_id] = 'none'
499
- for instance_id, instance_name in zip(instances_ids, instance_names):
1079
+ for instance_id in instances_ids:
500
1080
  instance_to_sample_ids_mappings[instance_id] = sample_id
501
- instance_ids_to_names[instance_id] = instance_name
502
1081
  all_sample_ids.extend(instances_ids)
1082
+ preprocess_response.length = len(all_sample_ids)
503
1083
  preprocess_response.sample_ids_to_instance_mappings = sample_ids_to_instance_mappings
504
1084
  preprocess_response.instance_to_sample_ids_mappings = instance_to_sample_ids_mappings
505
- preprocess_response.instance_ids_to_names = instance_ids_to_names
506
1085
  preprocess_response.sample_ids = all_sample_ids
507
1086
  return result
508
1087
 
1088
+ def extract_extra_instance_metadata():
1089
+ result = user_function()
1090
+ for preprocess_response in result:
1091
+ for sample_id in preprocess_response.sample_ids:
1092
+ instances_length = instance_length_encoder(sample_id, preprocess_response)
1093
+ if instances_length > 0:
1094
+ element_instance = instance_mask_encoder(sample_id, preprocess_response, 0)
1095
+ instance_metadata = element_instance.instance_metadata
1096
+ if instance_metadata is None:
1097
+ return {}
1098
+ return instance_metadata
1099
+ return {}
1100
+
1101
+
509
1102
  def builtin_instance_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
510
1103
  return {'is_instance': '0', 'original_sample_id': idx, 'instance_name': 'none'}
511
1104
 
1105
+ def builtin_instance_extra_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
1106
+ instance_metadata = extract_extra_instance_metadata()
1107
+ for key, value in instance_metadata.items():
1108
+ instance_metadata[key] = 'unset'
1109
+ return instance_metadata
1110
+
512
1111
  leap_binder.set_preprocess(user_function_instance)
513
1112
  leap_binder.set_metadata(builtin_instance_metadata, "builtin_instance_metadata")
1113
+ leap_binder.set_metadata(builtin_instance_extra_metadata, "builtin_instance_extra_metadata")
1114
+
514
1115
 
515
1116
  def _validate_input_args(*args, **kwargs):
516
1117
  assert len(args) == 0 and len(kwargs) == 0, \
@@ -537,6 +1138,8 @@ def tensorleap_element_instance_preprocess(
537
1138
 
538
1139
  result = user_function_instance()
539
1140
  _validate_result(result)
1141
+ if not _call_from_tl_platform:
1142
+ update_env_params_func("tensorleap_preprocess", "v")
540
1143
  return result
541
1144
 
542
1145
  return inner
@@ -571,7 +1174,7 @@ def tensorleap_unlabeled_preprocess():
571
1174
 
572
1175
  def tensorleap_instances_masks_encoder(name: str):
573
1176
  def decorating_function(user_function: InstanceCallableInterface):
574
- def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
1177
+ def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse, instance_id: int):
575
1178
  assert isinstance(sample_id, str), \
576
1179
  (f'tensorleap_instances_masks_encoder validation failed: '
577
1180
  f'Argument sample_id should be str. Got {type(sample_id)}.')
@@ -582,18 +1185,21 @@ def tensorleap_instances_masks_encoder(name: str):
582
1185
  (f'tensorleap_instances_masks_encoder validation failed: '
583
1186
  f'Argument sample_id should be as the same type as defined in the preprocess response '
584
1187
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
1188
+ assert isinstance(instance_id, int), \
1189
+ (f'tensorleap_instances_masks_encoder validation failed: '
1190
+ f'Argument instance_id should be int. Got {type(instance_id)}.')
585
1191
 
586
1192
  def _validate_result(result):
587
- assert isinstance(result, list), \
1193
+ assert isinstance(result, ElementInstance) or (result is None), \
588
1194
  (f'tensorleap_instances_masks_encoder validation failed: '
589
- f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
1195
+ f'Unsupported return type. Should be a ElementInstance or None. Got {type(result)}.')
590
1196
 
591
- def inner_without_validate(sample_id, preprocess_response):
1197
+ def inner_without_validate(sample_id, preprocess_response, instance_id):
592
1198
  global _called_from_inside_tl_decorator
593
1199
  _called_from_inside_tl_decorator += 1
594
1200
 
595
1201
  try:
596
- result = user_function(sample_id, preprocess_response)
1202
+ result = user_function(sample_id, preprocess_response, instance_id)
597
1203
  finally:
598
1204
  _called_from_inside_tl_decorator -= 1
599
1205
 
@@ -601,6 +1207,52 @@ def tensorleap_instances_masks_encoder(name: str):
601
1207
 
602
1208
  leap_binder.set_instance_masks(inner_without_validate, name)
603
1209
 
1210
+ def inner(sample_id, preprocess_response, instance_id):
1211
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1212
+ return None
1213
+
1214
+ _validate_input_args(sample_id, preprocess_response, instance_id)
1215
+
1216
+ result = inner_without_validate(sample_id, preprocess_response, instance_id)
1217
+
1218
+ _validate_result(result)
1219
+ return result
1220
+
1221
+ return inner
1222
+
1223
+ return decorating_function
1224
+
1225
+
1226
+ def tensorleap_instances_length_encoder(name: str):
1227
+ def decorating_function(user_function: InstanceLengthCallableInterface):
1228
+ def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
1229
+ assert isinstance(sample_id, (str, int)), \
1230
+ (f'tensorleap_instances_length_encoder validation failed: '
1231
+ f'Argument sample_id should be str. Got {type(sample_id)}.')
1232
+ assert isinstance(preprocess_response, PreprocessResponse), \
1233
+ (f'tensorleap_instances_length_encoder validation failed: '
1234
+ f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
1235
+ assert type(sample_id) == preprocess_response.sample_id_type, \
1236
+ (f'tensorleap_instances_length_encoder validation failed: '
1237
+ f'Argument sample_id should be as the same type as defined in the preprocess response '
1238
+ f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
1239
+
1240
+ def _validate_result(result):
1241
+ assert isinstance(result, int), \
1242
+ (f'tensorleap_instances_length_encoder validation failed: '
1243
+ f'Unsupported return type. Should be a int. Got {type(result)}.')
1244
+
1245
+ def inner_without_validate(sample_id, preprocess_response):
1246
+ global _called_from_inside_tl_decorator
1247
+ _called_from_inside_tl_decorator += 1
1248
+
1249
+ try:
1250
+ result = user_function(sample_id, preprocess_response)
1251
+ finally:
1252
+ _called_from_inside_tl_decorator -= 1
1253
+
1254
+ return result
1255
+
604
1256
  def inner(sample_id, preprocess_response):
605
1257
  if os.environ.get(mapping_runtime_mode_env_var_mame):
606
1258
  return None
@@ -617,46 +1269,87 @@ def tensorleap_instances_masks_encoder(name: str):
617
1269
  return decorating_function
618
1270
 
619
1271
 
620
- def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
1272
+ def tensorleap_input_encoder(name: str, channel_dim=_UNSET, model_input_index=None):
621
1273
  def decorating_function(user_function: SectionCallableInterface):
622
1274
  for input_handler in leap_binder.setup_container.inputs:
623
1275
  if input_handler.name == name:
624
1276
  raise Exception(f'Input with name {name} already exists. '
625
1277
  f'Please choose another')
1278
+ nonlocal channel_dim
1279
+
1280
+ channel_dim_was_provided = channel_dim is not _UNSET
1281
+
1282
+ if not channel_dim_was_provided:
1283
+ channel_dim = -1
1284
+ if not _call_from_tl_platform:
1285
+ store_warning_by_param(
1286
+ param_name="channel_dim",
1287
+ user_func_name=user_function.__name__,
1288
+ default_value=channel_dim,
1289
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/input-encoder"
1290
+
1291
+ )
1292
+
626
1293
  if channel_dim <= 0 and channel_dim != -1:
627
1294
  raise Exception(f"Channel dim for input {name} is expected to be either -1 or positive")
628
1295
 
629
- leap_binder.set_input(user_function, name, channel_dim=channel_dim)
630
-
631
1296
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
632
- assert isinstance(sample_id, (int, str)), \
633
- (f'tensorleap_input_encoder validation failed: '
634
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
635
- assert isinstance(preprocess_response, PreprocessResponse), \
636
- (f'tensorleap_input_encoder validation failed: '
637
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
638
1297
  assert type(sample_id) == preprocess_response.sample_id_type, \
639
- (f'tensorleap_input_encoder validation failed: '
1298
+ (f'{user_function.__name__}() validation failed: '
640
1299
  f'Argument sample_id should be as the same type as defined in the preprocess response '
641
1300
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
642
1301
 
643
1302
  def _validate_result(result):
1303
+ validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray")
644
1304
  assert isinstance(result, np.ndarray), \
645
- (f'tensorleap_input_encoder validation failed: '
1305
+ (f'{user_function.__name__}() validation failed: '
646
1306
  f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
647
1307
  assert result.dtype == np.float32, \
648
- (f'tensorleap_input_encoder validation failed: '
1308
+ (f'{user_function.__name__}() validation failed: '
649
1309
  f'The return type should be a numpy array of type float32. Got {result.dtype}.')
650
- assert channel_dim - 1 <= len(result.shape), (f'tensorleap_input_encoder validation failed: '
1310
+ assert channel_dim - 1 <= len(result.shape), (f'{user_function.__name__}() validation failed: '
651
1311
  f'The channel_dim ({channel_dim}) should be <= to the rank of the resulting input rank ({len(result.shape)}).')
652
1312
 
653
- def inner(sample_id, preprocess_response):
1313
+ def inner_without_validate(sample_id, preprocess_response):
1314
+ global _called_from_inside_tl_decorator
1315
+ _called_from_inside_tl_decorator += 1
1316
+
1317
+ try:
1318
+ result = user_function(sample_id, preprocess_response)
1319
+ finally:
1320
+ _called_from_inside_tl_decorator -= 1
1321
+
1322
+ return result
1323
+
1324
+ leap_binder.set_input(inner_without_validate, name, channel_dim=channel_dim)
1325
+
1326
+ def inner(*args, **kwargs):
1327
+ if not _call_from_tl_platform:
1328
+ set_current("tensorleap_input_encoder")
1329
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
1330
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
1331
+ sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
654
1332
  _validate_input_args(sample_id, preprocess_response)
655
- result = user_function(sample_id, preprocess_response)
1333
+
1334
+ result = inner_without_validate(sample_id, preprocess_response)
1335
+
656
1336
  _validate_result(result)
657
1337
 
658
- if _called_from_inside_tl_decorator == 0:
1338
+ if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
1339
+ batch_warning(result, user_function.__name__)
659
1340
  result = np.expand_dims(result, axis=0)
1341
+ # Emit integration test event once per test
1342
+ try:
1343
+ emit_integration_event_once(AnalyticsEvent.INPUT_ENCODER_INTEGRATION_TEST, {
1344
+ 'encoder_name': name,
1345
+ 'channel_dim': channel_dim,
1346
+ 'model_input_index': model_input_index
1347
+ })
1348
+ except Exception as e:
1349
+ logger.debug(f"Failed to emit input_encoder integration test event: {e}")
1350
+ if not _call_from_tl_platform:
1351
+ update_env_params_func("tensorleap_input_encoder", "v")
1352
+
660
1353
  return result
661
1354
 
662
1355
  node_mapping_type = NodeMappingType.Input
@@ -664,22 +1357,23 @@ def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
664
1357
  node_mapping_type = NodeMappingType(f'Input{str(model_input_index)}')
665
1358
  inner.node_mapping = NodeMapping(name, node_mapping_type)
666
1359
 
667
- def mapping_inner(sample_id, preprocess_response):
1360
+ def mapping_inner(*args, **kwargs):
668
1361
  class TempMapping:
669
1362
  pass
670
1363
 
671
1364
  ret = TempMapping()
672
1365
  ret.node_mapping = mapping_inner.node_mapping
673
1366
 
1367
+ leap_binder.mapping_connections.append(NodeConnection(mapping_inner.node_mapping, None))
674
1368
  return ret
675
1369
 
676
1370
  mapping_inner.node_mapping = NodeMapping(name, node_mapping_type)
677
1371
 
678
- def final_inner(sample_id, preprocess_response):
1372
+ def final_inner(*args, **kwargs):
679
1373
  if os.environ.get(mapping_runtime_mode_env_var_mame):
680
- return mapping_inner(sample_id, preprocess_response)
1374
+ return mapping_inner(*args, **kwargs)
681
1375
  else:
682
- return inner(sample_id, preprocess_response)
1376
+ return inner(*args, **kwargs)
683
1377
 
684
1378
  final_inner.node_mapping = NodeMapping(name, node_mapping_type)
685
1379
 
@@ -695,40 +1389,64 @@ def tensorleap_gt_encoder(name: str):
695
1389
  raise Exception(f'GT with name {name} already exists. '
696
1390
  f'Please choose another')
697
1391
 
698
- leap_binder.set_ground_truth(user_function, name)
699
-
700
1392
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
701
- assert isinstance(sample_id, (int, str)), \
702
- (f'tensorleap_gt_encoder validation failed: '
703
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
704
- assert isinstance(preprocess_response, PreprocessResponse), \
705
- (f'tensorleap_gt_encoder validation failed: '
706
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
707
1393
  assert type(sample_id) == preprocess_response.sample_id_type, \
708
- (f'tensorleap_gt_encoder validation failed: '
1394
+ (f'{user_function.__name__}() validation failed: '
709
1395
  f'Argument sample_id should be as the same type as defined in the preprocess response '
710
1396
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
711
1397
 
712
1398
  def _validate_result(result):
1399
+ validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray",
1400
+ gt_flag=True)
713
1401
  assert isinstance(result, np.ndarray), \
714
- (f'tensorleap_gt_encoder validation failed: '
1402
+ (f'{user_function.__name__}() validation failed: '
715
1403
  f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
716
1404
  assert result.dtype == np.float32, \
717
- (f'tensorleap_gt_encoder validation failed: '
1405
+ (f'{user_function.__name__}() validation failed: '
718
1406
  f'The return type should be a numpy array of type float32. Got {result.dtype}.')
719
1407
 
720
- def inner(sample_id, preprocess_response):
1408
+ def inner_without_validate(sample_id, preprocess_response):
1409
+ global _called_from_inside_tl_decorator
1410
+ _called_from_inside_tl_decorator += 1
1411
+
1412
+ try:
1413
+ result = user_function(sample_id, preprocess_response)
1414
+ finally:
1415
+ _called_from_inside_tl_decorator -= 1
1416
+
1417
+ return result
1418
+
1419
+ leap_binder.set_ground_truth(inner_without_validate, name)
1420
+
1421
+ def inner(*args, **kwargs):
1422
+ if not _call_from_tl_platform:
1423
+ set_current("tensorleap_gt_encoder")
1424
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
1425
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
1426
+ sample_id, preprocess_response = args
721
1427
  _validate_input_args(sample_id, preprocess_response)
722
- result = user_function(sample_id, preprocess_response)
1428
+
1429
+ result = inner_without_validate(sample_id, preprocess_response)
1430
+
723
1431
  _validate_result(result)
724
1432
 
725
- if _called_from_inside_tl_decorator == 0:
1433
+ if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
1434
+ batch_warning(result, user_function.__name__)
726
1435
  result = np.expand_dims(result, axis=0)
1436
+ # Emit integration test event once per test
1437
+ try:
1438
+ emit_integration_event_once(AnalyticsEvent.GT_ENCODER_INTEGRATION_TEST, {
1439
+ 'encoder_name': name
1440
+ })
1441
+ except Exception as e:
1442
+ logger.debug(f"Failed to emit gt_encoder integration test event: {e}")
1443
+ if not _call_from_tl_platform:
1444
+ update_env_params_func("tensorleap_gt_encoder", "v")
727
1445
  return result
728
1446
 
729
1447
  inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
730
1448
 
731
- def mapping_inner(sample_id, preprocess_response):
1449
+ def mapping_inner(*args, **kwargs):
732
1450
  class TempMapping:
733
1451
  pass
734
1452
 
@@ -739,11 +1457,11 @@ def tensorleap_gt_encoder(name: str):
739
1457
 
740
1458
  mapping_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
741
1459
 
742
- def final_inner(sample_id, preprocess_response):
1460
+ def final_inner(*args, **kwargs):
743
1461
  if os.environ.get(mapping_runtime_mode_env_var_mame):
744
- return mapping_inner(sample_id, preprocess_response)
1462
+ return mapping_inner(*args, **kwargs)
745
1463
  else:
746
- return inner(sample_id, preprocess_response)
1464
+ return inner(*args, **kwargs)
747
1465
 
748
1466
  final_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
749
1467
 
@@ -753,6 +1471,8 @@ def tensorleap_gt_encoder(name: str):
753
1471
 
754
1472
 
755
1473
  def tensorleap_custom_loss(name: str, connects_to=None):
1474
+ name_to_unique_name = defaultdict(set)
1475
+
756
1476
  def decorating_function(user_function: CustomCallableInterface):
757
1477
  for loss_handler in leap_binder.setup_container.custom_loss_handlers:
758
1478
  if loss_handler.custom_loss_handler_data.name == name:
@@ -760,35 +1480,29 @@ def tensorleap_custom_loss(name: str, connects_to=None):
760
1480
  f'Please choose another')
761
1481
 
762
1482
  valid_types = (np.ndarray, SamplePreprocessResponse)
763
- try:
764
- import tensorflow as tf
765
- valid_types = (np.ndarray, SamplePreprocessResponse, tf.Tensor)
766
- except ImportError:
767
- pass
768
1483
 
769
1484
  def _validate_input_args(*args, **kwargs):
770
-
1485
+ assert len(args) + len(kwargs) > 0, (
1486
+ f"{user_function.__name__}() validation failed: "
1487
+ f"Expected at least one positional|key-word argument of the allowed types (np.ndarray|SamplePreprocessResponse|). "
1488
+ f"but received none. "
1489
+ f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
1490
+ )
771
1491
  for i, arg in enumerate(args):
772
- if isinstance(arg, list):
773
- for y, elem in enumerate(arg):
774
- assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
775
- f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
776
- else:
777
- assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
778
- f'Argument #{i} should be a numpy array. Got {type(arg)}.')
1492
+ assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
1493
+ f'Argument #{i} should be a numpy array. Got {type(arg)}.')
779
1494
  for _arg_name, arg in kwargs.items():
780
- if isinstance(arg, list):
781
- for y, elem in enumerate(arg):
782
- assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
783
- f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
784
- else:
785
- assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
786
- f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
1495
+ assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
1496
+ f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
787
1497
 
788
1498
  def _validate_result(result):
789
- assert isinstance(result, valid_types), \
790
- (f'tensorleap_custom_loss validation failed: '
1499
+ validate_output_structure(result, func_name=user_function.__name__,
1500
+ expected_type_name="np.ndarray")
1501
+ assert isinstance(result, np.ndarray), \
1502
+ (f'{user_function.__name__} validation failed: '
791
1503
  f'The return type should be a numpy array. Got {type(result)}.')
1504
+ assert result.ndim < 2, (f'{user_function.__name__} validation failed: '
1505
+ f'The return type should be a 1Dim numpy array but got {result.ndim}Dim.')
792
1506
 
793
1507
  @functools.wraps(user_function)
794
1508
  def inner_without_validate(*args, **kwargs):
@@ -814,11 +1528,16 @@ def tensorleap_custom_loss(name: str, connects_to=None):
814
1528
  _add_mapping_connections(connects_to, arg_names, NodeMappingType.CustomLoss, name)
815
1529
 
816
1530
  def inner(*args, **kwargs):
1531
+ if not _call_from_tl_platform:
1532
+ set_current("tensorleap_custom_loss")
817
1533
  _validate_input_args(*args, **kwargs)
818
1534
 
819
1535
  result = inner_without_validate(*args, **kwargs)
820
1536
 
821
1537
  _validate_result(result)
1538
+ if not _call_from_tl_platform:
1539
+ update_env_params_func("tensorleap_custom_loss", "v")
1540
+
822
1541
  return result
823
1542
 
824
1543
  def mapping_inner(*args, **kwargs):
@@ -826,6 +1545,10 @@ def tensorleap_custom_loss(name: str, connects_to=None):
826
1545
  if 'user_unique_name' in kwargs:
827
1546
  user_unique_name = kwargs['user_unique_name']
828
1547
 
1548
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
1549
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
1550
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
1551
+
829
1552
  ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
830
1553
  ordered_connections = list(args) + ordered_connections
831
1554
  _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
@@ -871,3 +1594,180 @@ def tensorleap_custom_layer(name: str):
871
1594
  return custom_layer
872
1595
 
873
1596
  return decorating_function
1597
+
1598
+
1599
def tensorleap_status_table():
    """Build the local integration-status tracker and hook it into interpreter exit.

    The tracker keeps one row per Tensorleap decorator; every row starts as
    "unknown". Calling this function has process-wide side effects: it
    registers an ``atexit`` callback and replaces ``sys.excepthook`` so that a
    summary is printed when the script finishes or dies with an uncaught
    exception.

    Returns:
        tuple: ``(set_current, update_env_params)``.
        ``set_current(name)`` records which decorator is currently running.
        ``update_env_params(name, status="v")`` marks the matching row as
        passed when ``status == "v"`` and as failed for any other value.
    """
    import atexit
    import sys
    import traceback
    from typing import Any, Dict, Optional

    CHECK = "✅"
    CROSS = "❌"
    UNKNOWN = "❔"
    STATUS_KEY = "Added to integration"
    OPTIONAL_SUFFIX = " (optional)"

    # Single-element list so the nested functions can mutate the flag in place.
    code_mapping_failure = [0]

    # Row order matters: the first row that is not CHECK becomes the
    # "recommended next interface" in the printed summary.
    table = [
        {"name": decorator_name, STATUS_KEY: UNKNOWN}
        for decorator_name in (
            "tensorleap_preprocess",
            "tensorleap_integration_test",
            "tensorleap_input_encoder",
            "tensorleap_gt_encoder",
            "tensorleap_load_model",
            "tensorleap_custom_loss",
            "tensorleap_custom_metric (optional)",
            "tensorleap_metadata (optional)",
            "tensorleap_custom_visualizer (optional)",
        )
    ]

    _finalizer_called = {"done": False}
    _crashed = {"value": False}
    _current_func = {"name": None}

    def _link(url: str) -> str:
        """Wrap *url* in an OSC 8 terminal hyperlink, using the URL as its label."""
        # Layout: ESC ] 8 ; ; URL ST  <label>  ESC ] 8 ; ; ST   (ST = ESC \).
        # The URL is also the visible label, so it stays readable in terminals
        # that do not render hyperlinks.
        return f"\033]8;;{url}\033\\{url}\033]8;;\033\\"

    def _remove_suffix(s: str, suffix: str) -> str:
        """Return *s* without a trailing *suffix* (hand-rolled ``str.removesuffix``)."""
        if suffix and s.endswith(suffix):
            return s[:-len(suffix)]
        return s

    def _find_row(name: str) -> Optional[Dict[str, str]]:
        """Return the table row for *name*, ignoring the " (optional)" marker, or None."""
        for row in table:
            if _remove_suffix(row["name"], OPTIONAL_SUFFIX) == name:
                return row
        return None

    def _set_status(name: str, status_symbol: str) -> None:
        """Update the row for *name*; a row that is already CHECK is never downgraded."""
        row = _find_row(name)
        if not row:
            return
        if status_symbol == UNKNOWN:
            return
        if row[STATUS_KEY] == CHECK and status_symbol != CHECK:
            return
        row[STATUS_KEY] = status_symbol

    def _mark_unknowns_as_cross() -> None:
        """Turn every row that was never reported into a failure."""
        for row in table:
            if row[STATUS_KEY] == UNKNOWN:
                row[STATUS_KEY] = CROSS

    def _format_default_value(v: Any) -> str:
        """Render a default value for display, truncating long reprs to 120 chars."""
        if hasattr(v, "name"):
            return str(v.name)
        if isinstance(v, str):
            return v
        s = repr(v)
        return s if len(s) <= 120 else s[:120] + "..."

    def _print_param_default_warnings() -> None:
        """Print one warning per parameter that fell back to its default value."""
        data = _get_param_default_warnings()
        if not data:
            return

        print("\nWarnings (Default use. It is recommended to set values explicitly):")
        for param_name in sorted(data.keys()):
            info = data[param_name]
            funcs = ", ".join(sorted(info["funcs"]))
            dv = _format_default_value(info["default_value"])

            message = (
                f" ⚠️ Parameter '{param_name}' defaults to {dv} in the following functions: [{funcs}]. "
            )
            docs_link = info.get("link_to_docs")
            if docs_link:
                # Only point at the docs when there is actually a link to show;
                # otherwise the sentence would end in a dangling "check ".
                docs_part = f" {_link(docs_link)}"
                message += f"For more information, check {docs_part}"
            print(message)
        print("\nIf this isn’t the intended behaviour, set them explicitly.")

    def _print_table() -> None:
        """Print the default-value warnings and, for integration runs, the status table."""
        _print_param_default_warnings()

        if not started_from("leap_integration.py"):
            return

        ready_mess = "\nAll parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system."
        not_ready_mess = "\nSome mandatory components have not yet been added to the Integration test. Recommended next interface to add is: "
        mandatory_ready_mess = "\nAll mandatory parts have been successfully set. If no errors accured, you can now push the project to the Tensorleap system or continue to the next optional reccomeded interface,adding: "
        code_mapping_failure_mes = "Tensorleap_integration_test code flow failed, check raised exception."

        name_width = max(len(row["name"]) for row in table)
        status_width = max(len(row[STATUS_KEY]) for row in table)

        header = f"{'Decorator Name'.ljust(name_width)} | {'Added to integration'.ljust(status_width)}"
        sep = "-" * len(header)

        print("\n" + header)
        print(sep)

        # Name of the first row that is not CHECK; None means everything passed.
        next_step = None
        for row in table:
            print(f"{row['name'].ljust(name_width)} | {row[STATUS_KEY].ljust(status_width)}")
            if next_step is None and row[STATUS_KEY] != CHECK:
                next_step = row["name"]

        if _crashed["value"]:
            print(f"\nScript crashed before completing all steps. crashed at function '{_current_func['name']}'.")
            return

        if code_mapping_failure[0]:
            # The message already ends with a period; do not append another one.
            print(f"\n{CROSS + code_mapping_failure_mes}")
            return

        if next_step is None:
            print(ready_mess)
        elif "optional" in next_step:
            print(mandatory_ready_mess + next_step)
        else:
            print(not_ready_mess + next_step)

    def set_current(name: str) -> None:
        """Remember which decorator is running so a crash can be attributed to it."""
        _current_func["name"] = name

    def update_env_params(name: str, status: str = "v") -> None:
        """Record the outcome for decorator *name* ("v" means success)."""
        # NOTE(review): "code_mapping" raises the failure flag whatever `status`
        # is, and has no table row of its own - confirm callers only report it
        # on failure.
        if name == "code_mapping":
            code_mapping_failure[0] = 1
        if status == "v":
            _set_status(name, CHECK)
        else:
            _set_status(name, CROSS)

    def run_on_exit() -> None:
        """Print the summary once, from either the atexit hook or the crash hook."""
        if _finalizer_called["done"]:
            return
        _finalizer_called["done"] = True
        if not _crashed["value"]:
            _mark_unknowns_as_cross()

        _print_table()

    def handle_exception(exc_type, exc_value, exc_traceback) -> None:
        """``sys.excepthook`` replacement: flag the crash, show the traceback, print the summary."""
        _crashed["value"] = True
        crashed_name = _current_func["name"]
        if crashed_name:
            row = _find_row(crashed_name)
            if row and row[STATUS_KEY] != CHECK:
                row[STATUS_KEY] = CROSS

        traceback.print_exception(exc_type, exc_value, exc_traceback)
        run_on_exit()

    atexit.register(run_on_exit)
    sys.excepthook = handle_exception

    return set_current, update_env_params
1762
+
1763
+
1764
# Install the status-table tracker at import time unless `_call_from_tl_platform` is set.
# `set_current` and `update_env_params_func` are bound only in that case; the decorator
# wrappers visible in this module (e.g. gt encoder, custom loss) call them behind the
# same `not _call_from_tl_platform` guard, so the names are never looked up when unbound.
# NOTE(review): presumably the flag means "running inside the Tensorleap platform" - confirm.
if not _call_from_tl_platform:
    set_current, update_env_params_func = tensorleap_status_table()
1766
+
1767
+
1768
+
1769
+
1770
+
1771
+
1772
+
1773
+