code-loader 1.0.112.dev6__py3-none-any.whl → 1.0.153.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,26 +1,201 @@
1
1
  # mypy: ignore-errors
2
2
  import os
3
- from typing import Optional, Union, Callable, List, Dict
3
+ import warnings
4
+ import logging
5
+ from collections import defaultdict
6
+ from functools import lru_cache
7
+ from pathlib import Path
8
+ from typing import Optional, Union, Callable, List, Dict, Set, Any
9
+ from typing import Optional, Union, Callable, List, Dict, get_args, get_origin, DefaultDict
4
10
 
5
11
  import numpy as np
6
12
  import numpy.typing as npt
7
13
 
14
+ logger = logging.getLogger(__name__)
15
+
8
16
  from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs, \
9
17
  CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
10
18
  VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
11
19
  ConfusionMatrixElement, SamplePreprocessResponse, PredictionTypeHandler, InstanceCallableInterface, ElementInstance, \
12
20
  InstanceLengthCallableInterface
13
- from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType
21
+ from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType, DataStateType
14
22
  from code_loader import leap_binder
15
23
  from code_loader.contract.mapping import NodeMapping, NodeMappingType, NodeConnection
16
24
  from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
17
25
  LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
18
26
  from code_loader.inner_leap_binder.leapbinder import mapping_runtime_mode_env_var_mame
27
+ from code_loader.mixpanel_tracker import clear_integration_events, AnalyticsEvent, emit_integration_event_once
19
28
 
20
29
  import inspect
21
30
  import functools
31
+ from pathlib import Path
22
32
 
23
33
  _called_from_inside_tl_decorator = 0
34
_called_from_inside_tl_integration_test_decorator = False
# True only when the IS_TENSORLEAP_PLATFORM env var is exactly the string 'true'
# (evaluated once, when this module is imported).
_call_from_tl_platform = os.environ.get('IS_TENSORLEAP_PLATFORM') == 'true'


# ---- warnings store (module-level) ----
class _UnsetType:
    """Type of the ``_UNSET`` sentinel, meaning "argument was not provided".

    A dedicated sentinel (instead of ``None``) lets decorators tell an omitted
    argument apart from an explicit ``None``. The readable ``repr`` keeps
    function signatures that use ``_UNSET`` as a default legible in
    ``help()`` / ``inspect.signature`` output.
    """

    __slots__ = ()

    def __repr__(self) -> str:
        return "<UNSET>"


# Singleton sentinel; compare with ``is`` / ``is not``.
_UNSET = _UnsetType()
_STORED_WARNINGS: List[Dict[str, Any]] = []
_STORED_WARNING_KEYS: Set[tuple] = set()
# param_name -> set(user_func_name)
_PARAM_DEFAULT_FUNCS: DefaultDict[str, Set[str]] = defaultdict(set)
# param_name -> default_value used (repr-able)
_PARAM_DEFAULT_VALUE: Dict[str, Any] = {}
# param_name -> documentation link recorded for that parameter
_PARAM_DEFAULT_DOCS: Dict[str, str] = {}
46
+
47
+
48
def get_entry_script_path() -> str:
    """Return the resolved path of the script that started the process.

    ``sys.argv[0]`` is used when it is non-empty; otherwise the function falls
    back to ``__main__.__file__``. Returns an empty string when neither is
    available.
    """
    import sys

    entry_arg = sys.argv[0] if sys.argv else ""
    if not entry_arg:
        import __main__
        entry_arg = getattr(__main__, "__file__", "") or ""
        if not entry_arg:
            return ""
    return str(Path(entry_arg).resolve())
56
+
57
def started_from(filename: str) -> bool:
    """Return True when the entry script's base name equals *filename*.

    Returns False when the entry script path cannot be determined.
    """
    entry_path = get_entry_script_path()
    if not entry_path:
        return False
    return Path(entry_path).name == filename
60
+
61
+
62
def store_warning_by_param(
        *,
        param_name: str,
        user_func_name: str,
        default_value: Any,
        link_to_docs: Optional[str] = None,
) -> None:
    """Record that *user_func_name* fell back to the default of *param_name*.

    For each ``param_name``, the first default value recorded and the first
    non-empty docs link recorded are kept. Later calls only add
    ``user_func_name`` to the set of functions that used the default.

    Args:
        param_name: Name of the parameter whose default value was applied.
        user_func_name: Name of the function that applied the default.
        default_value: The default value that was applied.
        link_to_docs: Optional documentation URL for the parameter.
    """
    _PARAM_DEFAULT_FUNCS[param_name].add(user_func_name)

    # setdefault keeps the first recorded value, ignoring later ones.
    _PARAM_DEFAULT_VALUE.setdefault(param_name, default_value)

    if link_to_docs:
        _PARAM_DEFAULT_DOCS.setdefault(param_name, link_to_docs)
76
+
77
+
78
def _get_param_default_warnings() -> Dict[str, Dict[str, Any]]:
    """Build a snapshot of every parameter whose default value was used.

    Returns:
        A dict keyed by parameter name. Each value holds ``default_value``,
        ``funcs`` (a copy of the set of function names that used the default)
        and ``link_to_docs`` (``None`` when no link was recorded).
    """
    return {
        param: {
            "default_value": _PARAM_DEFAULT_VALUE.get(param, None),
            "funcs": set(func_names),
            "link_to_docs": _PARAM_DEFAULT_DOCS.get(param),
        }
        for param, func_names in _PARAM_DEFAULT_FUNCS.items()
    }
87
+
88
+
89
def _get_stored_warnings() -> List[Dict[str, Any]]:
    """Return a shallow copy of the module-level stored-warnings list."""
    return _STORED_WARNINGS[:]
91
+
92
+
93
def validate_args_structure(*args, types_order, func_name, expected_names, **kwargs):
    """Check that a call's arguments match an expected ordered list of types.

    Args:
        *args: Positional arguments received by the decorated function.
        types_order: Expected type for each argument, in order. ``Union[...]``
            entries accept any of their member types.
        func_name: Name used in error messages.
        expected_names: Argument names, in order. When non-empty, each name is
            taken from ``args`` by position, or else from ``kwargs``. When
            empty or ``None``, only ``args`` are checked.
        **kwargs: Keyword arguments received by the decorated function.

    Raises:
        AssertionError: if a named argument is missing, the argument count
            differs from ``len(types_order)``, or an argument has the wrong type.
    """

    def _describe_type(tp):
        if get_origin(tp) is Union:
            return " | ".join(member.__name__ for member in get_args(tp))
        if hasattr(tp, "__name__"):
            return tp.__name__
        return str(tp)

    def _label(idx, names):
        return (names[idx] + ': ') if names else f'arg{idx}: '

    def _signature_hint(types, names=None):
        return ", ".join(_label(idx, names) + _describe_type(tp) for idx, tp in enumerate(types))

    if not expected_names:
        collected = list(args)
    else:
        collected = []
        for idx, arg_name in enumerate(expected_names):
            if idx < len(args):
                collected.append(args[idx])
                continue
            if arg_name not in kwargs:
                raise AssertionError(
                    f"{func_name} validation failed: "
                    f"Missing required argument '{arg_name}'. "
                    f"Expected arguments: {expected_names}."
                )
            collected.append(kwargs[arg_name])

    if len(collected) != len(types_order):
        hint = _signature_hint(types_order, expected_names)
        received = ", ".join(type(value).__name__ for value in collected)
        raise AssertionError(
            f"{func_name} validation failed: "
            f"Expected exactly {len(types_order)} arguments ({hint}), "
            f"but got {len(collected)} argument(s) of type(s): ({received}). "
            f"Correct usage example: {func_name}({hint})"
        )

    for idx, (value, wanted) in enumerate(zip(collected, types_order)):
        accepted = get_args(wanted) if get_origin(wanted) is Union else (wanted,)
        if isinstance(value, accepted):
            continue
        accepted_str = " | ".join(t.__name__ for t in accepted)
        shown_name = expected_names[idx] if expected_names else f'arg{idx}'
        raise AssertionError(
            f"{func_name} validation failed: "
            f"Argument '{shown_name}' "
            f"expected type {accepted_str}, but got {type(value).__name__}. "
            f"Correct usage example: {func_name}({_signature_hint(types_order, expected_names)})"
        )
149
+
150
+
151
def validate_output_structure(result, func_name: str, expected_type_name="np.ndarray", gt_flag=False):
    """Assert that a decorated user function returned exactly one value.

    Args:
        result: The value returned by the user function.
        func_name: Name used in error messages.
        expected_type_name: Human-readable name of the expected return type.
            It is used only in error messages.
        gt_flag: When True, a missing result (``None`` or a float NaN) raises
            an error whose message also explains how to return an empty array
            for unlabeled datasets.

    Raises:
        AssertionError: if *result* is ``None``, a float NaN, or a tuple
            (meaning the function returned several values).
    """
    is_missing = result is None or (isinstance(result, float) and np.isnan(result))
    if is_missing:
        if gt_flag:
            raise AssertionError(
                f"{func_name} validation failed: "
                f"The function returned {result!r}. "
                f"If you are working with an unlabeled dataset and no ground truth is available, "
                f"use 'return np.array([], dtype=np.float32)' instead. "
                f"Otherwise, {func_name} expected a single {expected_type_name} object. "
                f"Make sure the function ends with 'return <{expected_type_name}>'."
            )

        # Report the actual value (None or nan) instead of always claiming "None".
        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned {result!r}. "
            f"Expected a single {expected_type_name} object. "
            f"Make sure the function ends with 'return <{expected_type_name}>'."
        )
    if isinstance(result, tuple):
        # A bare `return a, b` produces a tuple; list each element's type to help the user.
        element_descriptions = [
            f"[{i}] type: {type(r).__name__}"
            for i, r in enumerate(result)
        ]
        element_summary = "\n ".join(element_descriptions)

        raise AssertionError(
            f"{func_name} validation failed: "
            f"The function returned multiple outputs ({len(result)} values), "
            f"but only a single {expected_type_name} is allowed.\n\n"
            f"Returned elements:\n"
            f" {element_summary}\n\n"
            f"Correct usage example:\n"
            f" def {func_name}(...):\n"
            f" return <{expected_type_name}>\n\n"
            f"If you intended to return multiple values, combine them into a single "
            f"{expected_type_name} (e.g., by concatenation or stacking)."
        )
188
+
189
+
190
def batch_warning(result, func_name):
    """Emit a warning when *result* already has a leading axis of size 1.

    Tensorleap adds a batch axis at position 0 to this output, so an existing
    size-1 leading axis may indicate that the output is already batched.
    Results with a 0-d shape, or whose first axis is not 1, produce no warning.
    """
    dims = result.shape
    if len(dims) == 0 or not dims[0] == 1:
        return
    message = (
        f"{func_name} warning: Tensorleap will add a batch dimension at axis 0 to the output of {func_name}, "
        f"although the detected size of axis 0 is already 1. "
        f"This may lead to an extra batch dimension (e.g., shape (1, 1, ...)). "
        f"Please ensure that the output of '{func_name}' is not already batched "
        f"to avoid computation errors."
    )
    warnings.warn(message)
24
199
 
25
200
 
26
201
  def _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type):
@@ -41,51 +216,234 @@ def _add_mapping_connections(connects_to, arg_names, node_mapping_type, name):
41
216
  _add_mapping_connection(user_unique_name, connection_destinations, arg_names, name, node_mapping_type)
42
217
 
43
218
 
44
- def integration_test():
219
def tensorleap_integration_test():
    """Decorator factory that registers and wraps the user's integration test.

    The decorated function is stored on ``leap_binder.integration_test_func``.
    Calling the returned wrapper:

    1. Outside the Tensorleap platform, validates the ``(idx, preprocess)``
       arguments.
    2. Runs the integration test with the caller's arguments.
    3. Re-runs it with ``mapping_runtime_mode_env_var_mame`` set, passing
       ``(None, PreprocessResponse(state=DataStateType.training, length=0))``.
       The environment variable is removed afterwards.
    4. Calls ``leap_binder.check()``.

    Raises:
        Exception: if the re-run in step 3 fails. The message names the file
            and line of the failing frame, and the original error is chained
            as ``__cause__``.
    """

    def decorating_function(integration_test_function: Callable):
        leap_binder.integration_test_func = integration_test_function

        def _validate_input_args(*args, **kwargs):
            sample_id, preprocess_response = args
            assert type(sample_id) == preprocess_response.sample_id_type, (
                f"tensorleap_integration_test validation failed: "
                f"sample_id type ({type(sample_id).__name__}) does not match the expected "
                f"type ({preprocess_response.sample_id_type}) from the PreprocessResponse."
            )

        def inner(*args, **kwargs):
            global _called_from_inside_tl_integration_test_decorator

            if not _call_from_tl_platform:
                set_current('tensorleap_integration_test')
                validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
                                        func_name='integration_test', expected_names=["idx", "preprocess"], **kwargs)
                _validate_input_args(*args, **kwargs)

            # Clear integration test events for new test
            try:
                clear_integration_events()
            except Exception as e:
                # Best effort: failing to clear events is only logged.
                logger.debug(f"Failed to clear integration events: {e}")
            try:
                _called_from_inside_tl_integration_test_decorator = True
                if not _call_from_tl_platform:
                    # Set before running the test body; otherwise it would become "v"
                    # only if the whole script finishes.
                    update_env_params_func("tensorleap_integration_test", "v")
                integration_test_function(*args, **kwargs)

                try:
                    os.environ[mapping_runtime_mode_env_var_mame] = 'True'
                    integration_test_function(None, PreprocessResponse(state=DataStateType.training, length=0))
                except Exception as e:
                    import traceback
                    # Last entry of the traceback = innermost frame where the error was raised.
                    innermost_frame = traceback.extract_tb(e.__traceback__)[-1]
                    file_name = Path(innermost_frame.filename).name
                    line_number = innermost_frame.lineno
                    update_env_params_func("code_mapping", "x")
                    if isinstance(e, TypeError) and 'is not subscriptable' in str(e):
                        reason = ("indexing is supported only on the model's predictions inside the integration test. "
                                  "Please remove this indexing operation usage from the integration test code.")
                    else:
                        reason = ('Integration test is only allowed to call Tensorleap decorators. '
                                  'Ensure any arithmetics, external library use, Python logic is placed within '
                                  'Tensorleap decoders')
                    # A real exception object is required: raising a bare str fails with TypeError.
                    raise Exception(
                        f'Invalid integration code. File {file_name}, line {line_number}: {reason}'
                    ) from e
                finally:
                    if mapping_runtime_mode_env_var_mame in os.environ:
                        del os.environ[mapping_runtime_mode_env_var_mame]
            finally:
                _called_from_inside_tl_integration_test_decorator = False

            leap_binder.check()

        return inner

    return decorating_function
64
280
 
65
281
 
66
- def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = None):
282
def _safe_get_item(key):
    """Return the ``NodeMappingType`` member named ``Input<key>``.

    Args:
        key: Input index (or anything whose ``str()`` completes the member name).

    Raises:
        Exception: if ``NodeMappingType`` has no member with that name. The
            lookup error is chained as ``__cause__``.
    """
    try:
        return NodeMappingType[f'Input{str(key)}']
    except (KeyError, ValueError) as err:
        # Enum name lookup (``Enum[name]``) raises KeyError for unknown names;
        # ValueError is still caught for backward compatibility.
        raise Exception(f'Tensorleap currently supports models with no more then 10 inputs') from err
287
+
288
+
289
+ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]] = _UNSET):
290
+ prediction_types_was_provided = prediction_types is not _UNSET
291
+
292
+ if not prediction_types_was_provided:
293
+ prediction_types = []
294
+ if not _call_from_tl_platform:
295
+ store_warning_by_param(
296
+ param_name="prediction_types",
297
+ user_func_name="tensorleap_load_model",
298
+ default_value=prediction_types,
299
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
300
+ )
301
+ assert isinstance(prediction_types, list), (
302
+ f"tensorleap_load_model validation failed: "
303
+ f" prediction_types is an optional argument of type List[PredictionTypeHandler]] but got {type(prediction_types).__name__}."
304
+ )
67
305
  for i, prediction_type in enumerate(prediction_types):
306
+ assert isinstance(prediction_type, PredictionTypeHandler), (f"tensorleap_load_model validation failed: "
307
+ f" prediction_types at position {i} must be of type PredictionTypeHandler but got {type(prediction_types[i]).__name__}.")
308
+ prediction_type_channel_dim_was_provided = prediction_type.channel_dim != "tl_default_value"
309
+ if not prediction_type_channel_dim_was_provided:
310
+ prediction_types[i].channel_dim = -1
311
+ if not _call_from_tl_platform:
312
+ store_warning_by_param(
313
+ param_name=f"prediction_types[{i}].channel_dim",
314
+ user_func_name="tensorleap_load_model",
315
+ default_value=prediction_types[i].channel_dim,
316
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/integration-test#tensorleap_load_model"
317
+ )
68
318
  leap_binder.add_prediction(prediction_type.name, prediction_type.labels, prediction_type.channel_dim, i)
69
319
 
70
- def decorating_function(load_model_func):
320
+ def _validate_result(result) -> None:
321
+ valid_types = ["onnxruntime", "keras"]
322
+ err_message = f"tensorleap_load_model validation failed:\nSupported models are Keras and onnxruntime only and non of them was returned."
323
+ validate_output_structure(result, func_name="tensorleap_load_model",
324
+ expected_type_name=[" | ".join(t for t in valid_types)][0])
325
+ try:
326
+ import keras
327
+ except ImportError:
328
+ keras = None
329
+ try:
330
+ import tensorflow as tf
331
+ except ImportError:
332
+ tf = None
333
+ try:
334
+ import onnxruntime
335
+ except ImportError:
336
+ onnxruntime = None
337
+
338
+ if not keras and not onnxruntime:
339
+ raise AssertionError(err_message)
340
+
341
+ is_keras_model = (
342
+ bool(keras and isinstance(result, getattr(keras, "Model", tuple())))
343
+ or bool(tf and isinstance(result, getattr(tf.keras, "Model", tuple())))
344
+ )
345
+ is_onnx_model = bool(onnxruntime and isinstance(result, onnxruntime.InferenceSession))
346
+
347
+ if not any([is_keras_model, is_onnx_model]):
348
+ raise AssertionError(err_message)
349
+
350
+ def decorating_function(load_model_func, prediction_types=prediction_types):
71
351
  class TempMapping:
72
352
  pass
73
353
 
74
- def inner():
354
+ @lru_cache()
355
+ def inner(*args, **kwargs):
356
+ if not _call_from_tl_platform:
357
+ set_current('tensorleap_load_model')
358
+ validate_args_structure(*args, types_order=[],
359
+ func_name='tensorleap_load_model', expected_names=[], **kwargs)
360
+
75
361
  class ModelPlaceholder:
76
- def __init__(self):
77
- self.model = load_model_func()
362
+ def __init__(self, prediction_types):
363
+ self.model = load_model_func() # TODO- check why this fails on onnx model
364
+ self.prediction_types = prediction_types
365
+ _validate_result(self.model)
78
366
 
79
367
  # keras interface
80
368
  def __call__(self, arg):
81
369
  ret = self.model(arg)
370
+ self.validate_declared_prediction_types(ret)
371
+ if isinstance(ret, list):
372
+ return [r.numpy() for r in ret]
82
373
  return ret.numpy()
83
374
 
375
+ def validate_declared_prediction_types(self, ret):
376
+ if not (len(self.prediction_types) == len(ret) if isinstance(ret, list) else 1) and len(
377
+ self.prediction_types) != 0:
378
+ if not _call_from_tl_platform:
379
+ update_env_params_func("tensorleap_load_model", "x")
380
+ raise Exception(
381
+ f"tensorleap_load_model validation failed: number of declared prediction types({len(prediction_types)}) != number of model outputs({len(ret) if isinstance(ret, list) else 1})")
382
+
383
+ def _convert_onnx_inputs_to_correct_type(
384
+ self, float_arrays_inputs: Dict[str, np.ndarray]
385
+ ) -> Dict[str, np.ndarray]:
386
+ ONNX_TYPE_TO_NP = {
387
+ "tensor(float)": np.float32,
388
+ "tensor(double)": np.float64,
389
+ "tensor(int64)": np.int64,
390
+ "tensor(int32)": np.int32,
391
+ "tensor(int16)": np.int16,
392
+ "tensor(int8)": np.int8,
393
+ "tensor(uint64)": np.uint64,
394
+ "tensor(uint32)": np.uint32,
395
+ "tensor(uint16)": np.uint16,
396
+ "tensor(uint8)": np.uint8,
397
+ "tensor(bool)": np.bool_,
398
+ }
399
+
400
+ """
401
+ Cast user-provided NumPy inputs to match the dtypes/shapes
402
+ expected by an ONNX Runtime InferenceSession.
403
+ """
404
+ coerced = {}
405
+ meta = {i.name: i for i in self.model.get_inputs()}
406
+
407
+ for name, arr in float_arrays_inputs.items():
408
+ if name not in meta:
409
+ # Keep as-is unless extra inputs are disallowed
410
+ coerced[name] = arr
411
+ continue
412
+
413
+ info = meta[name]
414
+ onnx_type = info.type
415
+ want_dtype = ONNX_TYPE_TO_NP.get(onnx_type)
416
+
417
+ if want_dtype is None:
418
+ raise TypeError(f"Unsupported ONNX input type: {onnx_type}")
419
+
420
+ # Cast dtype if needed
421
+ if arr.dtype != want_dtype:
422
+ arr = arr.astype(want_dtype, copy=False)
423
+
424
+ coerced[name] = arr
425
+
426
+ # Verify required inputs are present
427
+ missing = [n for n in meta if n not in coerced]
428
+ if missing:
429
+ raise KeyError(f"Missing required input(s): {sorted(missing)}")
430
+
431
+ return coerced
432
+
84
433
  # onnx runtime interface
85
434
  def run(self, output_names, input_dict):
86
- return self.model.run(output_names, input_dict)
435
+ corrected_type_inputs = self._convert_onnx_inputs_to_correct_type(input_dict)
436
+ ret = self.model.run(output_names, corrected_type_inputs)
437
+ self.validate_declared_prediction_types(ret)
438
+ return ret
87
439
 
88
- return ModelPlaceholder()
440
+ def get_inputs(self):
441
+ return self.model.get_inputs()
442
+
443
+ model_placeholder = ModelPlaceholder(prediction_types)
444
+ if not _call_from_tl_platform:
445
+ update_env_params_func("tensorleap_load_model", "v")
446
+ return model_placeholder
89
447
 
90
448
  def mapping_inner():
91
449
  class ModelOutputPlaceholder:
@@ -97,15 +455,20 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
97
455
  f'Expected key to be an int, got {type(key)} instead.'
98
456
 
99
457
  ret = TempMapping()
100
- ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
458
+ try:
459
+ ret.node_mapping = NodeMapping('', NodeMappingType(f'Prediction{str(key)}'))
460
+ except ValueError as e:
461
+ raise Exception(f'Tensorleap currently supports models with no more then 10 active predictions,'
462
+ f' {key} not supported.')
101
463
  return ret
102
464
 
103
465
  class ModelPlaceholder:
466
+
104
467
  # keras interface
105
468
  def __call__(self, arg):
106
469
  if isinstance(arg, list):
107
470
  for i, elem in enumerate(arg):
108
- elem.node_mapping.type = NodeMappingType[f'Input{str(i)}']
471
+ elem.node_mapping.type = _safe_get_item(i)
109
472
  else:
110
473
  arg.node_mapping.type = NodeMappingType.Input0
111
474
 
@@ -116,18 +479,38 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
116
479
  assert output_names is None
117
480
  assert isinstance(input_dict, dict), \
118
481
  f'Expected input_dict to be a dict, got {type(input_dict)} instead.'
119
- for i, elem in enumerate(input_dict.values()):
120
- elem.node_mapping.type = NodeMappingType[f'Input{str(i)}']
482
+ for i, (input_key, elem) in enumerate(input_dict.items()):
483
+ if isinstance(input_key, NodeMappingType):
484
+ elem.node_mapping.type = input_key
485
+ else:
486
+ elem.node_mapping.type = _safe_get_item(i)
121
487
 
122
488
  return ModelOutputPlaceholder()
123
489
 
490
+ def get_inputs(self):
491
+ class FollowIndex:
492
+ def __init__(self, index):
493
+ self.name = _safe_get_item(index)
494
+
495
+ class FollowInputIndex:
496
+ def __init__(self):
497
+ pass
498
+
499
+ def __getitem__(self, index):
500
+ assert isinstance(index, int), \
501
+ f'Expected key to be an int, got {type(index)} instead.'
502
+
503
+ return FollowIndex(index)
504
+
505
+ return FollowInputIndex()
506
+
124
507
  return ModelPlaceholder()
125
508
 
126
- def final_inner():
509
+ def final_inner(*args, **kwargs):
127
510
  if os.environ.get(mapping_runtime_mode_env_var_mame):
128
511
  return mapping_inner()
129
512
  else:
130
- return inner()
513
+ return inner(*args, **kwargs)
131
514
 
132
515
  return final_inner
133
516
 
@@ -135,83 +518,186 @@ def tensorleap_load_model(prediction_types: Optional[List[PredictionTypeHandler]
135
518
 
136
519
 
137
520
  def tensorleap_custom_metric(name: str,
138
- direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward,
521
+ direction: Union[MetricDirection, Dict[str, MetricDirection]] = _UNSET,
139
522
  compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
140
523
  connects_to=None):
141
- def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
142
- CustomMultipleReturnCallableInterfaceMultiArgs,
143
- ConfusionMatrixCallableInterfaceMultiArgs]):
524
+ name_to_unique_name = defaultdict(set)
525
+
526
+ def decorating_function(
527
+ user_function: Union[CustomCallableInterfaceMultiArgs, CustomMultipleReturnCallableInterfaceMultiArgs,
528
+ ConfusionMatrixCallableInterfaceMultiArgs]):
529
+ nonlocal direction
530
+
531
+ direction_was_provided = direction is not _UNSET
532
+
533
+ if not direction_was_provided:
534
+ direction = MetricDirection.Downward
535
+ if not _call_from_tl_platform:
536
+ store_warning_by_param(
537
+ param_name="direction",
538
+ user_func_name=user_function.__name__,
539
+ default_value=direction,
540
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/custom-metrics"
541
+
542
+ )
543
+
544
+ def _validate_decorators_signature():
545
+ err_message = f"{user_function.__name__} validation failed.\n"
546
+ if not isinstance(name, str):
547
+ raise TypeError(err_message + f"`name` must be a string, got type {type(name).__name__}.")
548
+ valid_directions = {MetricDirection.Upward, MetricDirection.Downward}
549
+ if isinstance(direction, MetricDirection):
550
+ if direction not in valid_directions:
551
+ raise ValueError(
552
+ err_message +
553
+ f"Invalid MetricDirection: {direction}. Must be one of {valid_directions}, "
554
+ f"got type {type(direction).__name__}."
555
+ )
556
+ elif isinstance(direction, dict):
557
+ if not all(isinstance(k, str) for k in direction.keys()):
558
+ invalid_keys = {k: type(k).__name__ for k in direction.keys() if not isinstance(k, str)}
559
+ raise TypeError(
560
+ err_message +
561
+ f"All keys in `direction` must be strings, got invalid key types: {invalid_keys}."
562
+ )
563
+ for k, v in direction.items():
564
+ if v not in valid_directions:
565
+ raise ValueError(
566
+ err_message +
567
+ f"Invalid direction for key '{k}': {v}. Must be one of {valid_directions}, "
568
+ f"got type {type(v).__name__}."
569
+ )
570
+ else:
571
+ raise TypeError(
572
+ err_message +
573
+ f"`direction` must be a MetricDirection or a Dict[str, MetricDirection], "
574
+ f"got type {type(direction).__name__}."
575
+ )
576
+ if compute_insights is not None:
577
+ if not isinstance(compute_insights, (bool, dict)):
578
+ raise TypeError(
579
+ err_message +
580
+ f"`compute_insights` must be a bool or a Dict[str, bool], "
581
+ f"got type {type(compute_insights).__name__}."
582
+ )
583
+ if isinstance(compute_insights, dict):
584
+ if not all(isinstance(k, str) for k in compute_insights.keys()):
585
+ invalid_keys = {k: type(k).__name__ for k in compute_insights.keys() if not isinstance(k, str)}
586
+ raise TypeError(
587
+ err_message +
588
+ f"All keys in `compute_insights` must be strings, got invalid key types: {invalid_keys}."
589
+ )
590
+ for k, v in compute_insights.items():
591
+ if not isinstance(v, bool):
592
+ raise TypeError(
593
+ err_message +
594
+ f"Invalid type for compute_insights['{k}']: expected bool, got type {type(v).__name__}."
595
+ )
596
+ if connects_to is not None:
597
+ valid_types = (str, list, tuple, set)
598
+ if not isinstance(connects_to, valid_types):
599
+ raise TypeError(
600
+ err_message +
601
+ f"`connects_to` must be one of {valid_types}, got type {type(connects_to).__name__}."
602
+ )
603
+ if isinstance(connects_to, (list, tuple, set)):
604
+ invalid_elems = [f"{type(e).__name__}" for e in connects_to if not isinstance(e, str)]
605
+ if invalid_elems:
606
+ raise TypeError(
607
+ err_message +
608
+ f"All elements in `connects_to` must be strings, "
609
+ f"but found element types: {invalid_elems}."
610
+ )
611
+
612
+ _validate_decorators_signature()
613
+
144
614
  for metric_handler in leap_binder.setup_container.metrics:
145
615
  if metric_handler.metric_handler_data.name == name:
146
616
  raise Exception(f'Metric with name {name} already exists. '
147
617
  f'Please choose another')
148
618
 
149
619
  def _validate_input_args(*args, **kwargs) -> None:
620
+ assert len(args) + len(kwargs) > 0, (
621
+ f"{user_function.__name__}() validation failed: "
622
+ f"Expected at least one positional|key-word argument of type np.ndarray, "
623
+ f"but received none. "
624
+ f"Correct usage example: tensorleap_custom_metric(input_array: np.ndarray, ...)"
625
+ )
150
626
  for i, arg in enumerate(args):
151
627
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
152
- f'tensorleap_custom_metric validation failed: '
628
+ f'{user_function.__name__}() validation failed: '
153
629
  f'Argument #{i} should be a numpy array. Got {type(arg)}.')
154
630
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
155
631
  assert arg.shape[0] == leap_binder.batch_size_to_validate, \
156
- (f'tensorleap_custom_metric validation failed: Argument #{i} '
632
+ (f'{user_function.__name__}() validation failed: Argument #{i} '
157
633
  f'first dim should be as the batch size. Got {arg.shape[0]} '
158
634
  f'instead of {leap_binder.batch_size_to_validate}')
159
635
 
160
636
  for _arg_name, arg in kwargs.items():
161
637
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
162
- f'tensorleap_custom_metric validation failed: '
638
+ f'{user_function.__name__}() validation failed: '
163
639
  f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
164
640
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
165
641
  assert arg.shape[0] == leap_binder.batch_size_to_validate, \
166
- (f'tensorleap_custom_metric validation failed: Argument {_arg_name} '
642
+ (f'{user_function.__name__}() validation failed: Argument {_arg_name} '
167
643
  f'first dim should be as the batch size. Got {arg.shape[0]} '
168
644
  f'instead of {leap_binder.batch_size_to_validate}')
169
645
 
170
646
  def _validate_result(result) -> None:
171
- supported_types_message = (f'tensorleap_custom_metric validation failed: '
172
- f'Metric has returned unsupported type. Supported types are List[float], '
173
- f'List[List[ConfusionMatrixElement]], NDArray[np.float32]. ')
647
+ validate_output_structure(result, func_name=user_function.__name__,
648
+ expected_type_name="List[float | int | None | List[ConfusionMatrixElement] ] | NDArray[np.float32] or dictonary with one of these types as its values types")
649
+ supported_types_message = (f'{user_function.__name__}() validation failed: '
650
+ f'{user_function.__name__}() has returned unsupported type.\nSupported types are List[float|int|None], '
651
+ f'List[List[ConfusionMatrixElement]], NDArray[np.float32] or dictonary with one of these types as its values types. ')
174
652
 
175
- def _validate_single_metric(single_metric_result):
653
+ def _validate_single_metric(single_metric_result, key=None):
176
654
  if isinstance(single_metric_result, list):
177
655
  if isinstance(single_metric_result[0], list):
178
- assert isinstance(single_metric_result[0][0], ConfusionMatrixElement), \
179
- f'{supported_types_message}Got List[List[{type(single_metric_result[0][0])}]].'
656
+ assert all(isinstance(cm, ConfusionMatrixElement) for cm in single_metric_result[0]), (
657
+ f"{supported_types_message} "
658
+ f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
659
+ f"List[List[{', '.join(type(cm).__name__ for cm in single_metric_result[0])}]]."
660
+ )
661
+
180
662
  else:
181
- assert isinstance(single_metric_result[0], (
182
- float, int,
183
- type(None))), f'{supported_types_message}Got List[{type(single_metric_result[0])}].'
663
+ assert all(isinstance(v, (float, int, type(None), np.float32)) for v in single_metric_result), (
664
+ f"{supported_types_message}\n"
665
+ f"Got {'a dict where the value of ' + str(key) + ' is of type ' if key is not None else ''}"
666
+ f"List[{', '.join(type(v).__name__ for v in single_metric_result)}]."
667
+ )
184
668
  else:
185
669
  assert isinstance(single_metric_result,
186
- np.ndarray), f'{supported_types_message}Got {type(single_metric_result)}.'
187
- assert len(single_metric_result.shape) == 1, (f'tensorleap_custom_metric validation failed: '
670
+ np.ndarray), f'{supported_types_message}\nGot {type(single_metric_result)}.'
671
+ assert len(single_metric_result.shape) == 1, (f'{user_function.__name__}() validation failed: '
188
672
  f'The return shape should be 1D. Got {len(single_metric_result.shape)}D.')
189
673
 
190
674
  if leap_binder.batch_size_to_validate:
191
675
  assert len(single_metric_result) == leap_binder.batch_size_to_validate, \
192
- f'tensorleap_custom_metrix validation failed: The return len should be as the batch size.'
676
+ f'{user_function.__name__}() validation failed: The return len {f"of srt{key} value" if key is not None else ""} should be as the batch size.'
193
677
 
194
678
  if isinstance(result, dict):
195
679
  for key, value in result.items():
680
+ _validate_single_metric(value, key)
681
+
196
682
  assert isinstance(key, str), \
197
- (f'tensorleap_custom_metric validation failed: '
683
+ (f'{user_function.__name__}() validation failed: '
198
684
  f'Keys in the return dict should be of type str. Got {type(key)}.')
199
685
  _validate_single_metric(value)
200
686
 
201
687
  if isinstance(direction, dict):
202
688
  for direction_key in direction:
203
689
  assert direction_key in result, \
204
- (f'tensorleap_custom_metric validation failed: '
690
+ (f'{user_function.__name__}() validation failed: '
205
691
  f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
206
692
 
207
693
  if compute_insights is not None:
208
694
  assert isinstance(compute_insights, dict), \
209
- (f'tensorleap_custom_metric validation failed: '
695
+ (f'{user_function.__name__}() validation failed: '
210
696
  f'compute_insights should be dict if using the dict results. Got {type(compute_insights)}.')
211
697
 
212
698
  for ci_key in compute_insights:
213
699
  assert ci_key in result, \
214
- (f'tensorleap_custom_metric validation failed: '
700
+ (f'{user_function.__name__}() validation failed: '
215
701
  f'Keys in the compute_insights mapping should be part of result keys. Got key {ci_key}.')
216
702
 
217
703
  else:
@@ -219,7 +705,7 @@ def tensorleap_custom_metric(name: str,
219
705
 
220
706
  if compute_insights is not None:
221
707
  assert isinstance(compute_insights, bool), \
222
- (f'tensorleap_custom_metric validation failed: '
708
+ (f'{user_function.__name__}() validation failed: '
223
709
  f'compute_insights should be boolean. Got {type(compute_insights)}.')
224
710
 
225
711
  @functools.wraps(user_function)
@@ -246,11 +732,15 @@ def tensorleap_custom_metric(name: str,
246
732
  _add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
247
733
 
248
734
  def inner(*args, **kwargs):
735
+ if not _call_from_tl_platform:
736
+ set_current('tensorleap_custom_metric')
249
737
  _validate_input_args(*args, **kwargs)
250
738
 
251
739
  result = inner_without_validate(*args, **kwargs)
252
740
 
253
741
  _validate_result(result)
742
+ if not _call_from_tl_platform:
743
+ update_env_params_func("tensorleap_custom_metric", "v")
254
744
  return result
255
745
 
256
746
  def mapping_inner(*args, **kwargs):
@@ -260,6 +750,11 @@ def tensorleap_custom_metric(name: str,
260
750
 
261
751
  ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
262
752
  ordered_connections = list(args) + ordered_connections
753
+
754
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
755
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
756
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
757
+
263
758
  _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
264
759
  mapping_inner.name, NodeMappingType.Metric)
265
760
 
@@ -282,29 +777,40 @@ def tensorleap_custom_metric(name: str,
282
777
  def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
283
778
  heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
284
779
  connects_to=None):
780
+ name_to_unique_name = defaultdict(set)
781
+
285
782
  def decorating_function(user_function: VisualizerCallableInterface):
783
+ assert isinstance(visualizer_type, LeapDataType), (f"{user_function.__name__} validation failed: "
784
+ f"visualizer_type should be of type {LeapDataType.__name__} but got {type(visualizer_type)}"
785
+ )
286
786
  for viz_handler in leap_binder.setup_container.visualizers:
287
787
  if viz_handler.visualizer_handler_data.name == name:
288
788
  raise Exception(f'Visualizer with name {name} already exists. '
289
789
  f'Please choose another')
290
790
 
291
791
  def _validate_input_args(*args, **kwargs):
792
+ assert len(args) + len(kwargs) > 0, (
793
+ f"{user_function.__name__}() validation failed: "
794
+ f"Expected at least one positional|key-word argument of type np.ndarray, "
795
+ f"but received none. "
796
+ f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
797
+ )
292
798
  for i, arg in enumerate(args):
293
799
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
294
- f'tensorleap_custom_visualizer validation failed: '
800
+ f'{user_function.__name__}() validation failed: '
295
801
  f'Argument #{i} should be a numpy array. Got {type(arg)}.')
296
802
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
297
803
  assert arg.shape[0] != leap_binder.batch_size_to_validate, \
298
- (f'tensorleap_custom_visualizer validation failed: '
804
+ (f'{user_function.__name__}() validation failed: '
299
805
  f'Argument #{i} should be without batch dimension. ')
300
806
 
301
807
  for _arg_name, arg in kwargs.items():
302
808
  assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
303
- f'tensorleap_custom_visualizer validation failed: '
809
+ f'{user_function.__name__}() validation failed: '
304
810
  f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
305
811
  if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
306
812
  assert arg.shape[0] != leap_binder.batch_size_to_validate, \
307
- (f'tensorleap_custom_visualizer validation failed: Argument {_arg_name} '
813
+ (f'{user_function.__name__}() validation failed: Argument {_arg_name} '
308
814
  f'should be without batch dimension. ')
309
815
 
310
816
  def _validate_result(result):
@@ -318,8 +824,11 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
318
824
  LeapDataType.ImageWithBBox: LeapImageWithBBox,
319
825
  LeapDataType.ImageWithHeatmap: LeapImageWithHeatmap
320
826
  }
827
+ validate_output_structure(result, func_name=user_function.__name__,
828
+ expected_type_name=result_type_map[visualizer_type])
829
+
321
830
  assert isinstance(result, result_type_map[visualizer_type]), \
322
- (f'tensorleap_custom_visualizer validation failed: '
831
+ (f'{user_function.__name__}() validation failed: '
323
832
  f'The return type should be {result_type_map[visualizer_type]}. Got {type(result)}.')
324
833
 
325
834
  @functools.wraps(user_function)
@@ -346,11 +855,15 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
346
855
  _add_mapping_connections(connects_to, arg_names, NodeMappingType.Visualizer, name)
347
856
 
348
857
  def inner(*args, **kwargs):
858
+ if not _call_from_tl_platform:
859
+ set_current('tensorleap_custom_visualizer')
349
860
  _validate_input_args(*args, **kwargs)
350
861
 
351
862
  result = inner_without_validate(*args, **kwargs)
352
863
 
353
864
  _validate_result(result)
865
+ if not _call_from_tl_platform:
866
+ update_env_params_func("tensorleap_custom_visualizer", "v")
354
867
  return result
355
868
 
356
869
  def mapping_inner(*args, **kwargs):
@@ -358,6 +871,10 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
358
871
  if 'user_unique_name' in kwargs:
359
872
  user_unique_name = kwargs['user_unique_name']
360
873
 
874
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
875
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
876
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
877
+
361
878
  ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
362
879
  ordered_connections = list(args) + ordered_connections
363
880
  _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
@@ -388,30 +905,26 @@ def tensorleap_metadata(
388
905
  f'Please choose another')
389
906
 
390
907
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
391
- assert isinstance(sample_id, (int, str)), \
392
- (f'tensorleap_metadata validation failed: '
393
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
394
- assert isinstance(preprocess_response, PreprocessResponse), \
395
- (f'tensorleap_metadata validation failed: '
396
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
397
908
  assert type(sample_id) == preprocess_response.sample_id_type, \
398
- (f'tensorleap_metadata validation failed: '
909
+ (f'{user_function.__name__}() validation failed: '
399
910
  f'Argument sample_id should be as the same type as defined in the preprocess response '
400
911
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
401
912
 
402
913
  def _validate_result(result):
403
914
  supported_result_types = (type(None), int, str, bool, float, dict, np.floating,
404
915
  np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
916
+ validate_output_structure(result, func_name=user_function.__name__,
917
+ expected_type_name=supported_result_types)
405
918
  assert isinstance(result, supported_result_types), \
406
- (f'tensorleap_metadata validation failed: '
919
+ (f'{user_function.__name__}() validation failed: '
407
920
  f'Unsupported return type. Got {type(result)}. should be any of {str(supported_result_types)}')
408
921
  if isinstance(result, dict):
409
922
  for key, value in result.items():
410
923
  assert isinstance(key, str), \
411
- (f'tensorleap_metadata validation failed: '
924
+ (f'{user_function.__name__}() validation failed: '
412
925
  f'Keys in the return dict should be of type str. Got {type(key)}.')
413
926
  assert isinstance(value, supported_result_types), \
414
- (f'tensorleap_metadata validation failed: '
927
+ (f'{user_function.__name__}() validation failed: '
415
928
  f'Values in the return dict should be of type {str(supported_result_types)}. Got {type(value)}.')
416
929
 
417
930
  def inner_without_validate(sample_id, preprocess_response):
@@ -428,6 +941,60 @@ def tensorleap_metadata(
428
941
 
429
942
  leap_binder.set_metadata(inner_without_validate, name, metadata_type)
430
943
 
944
+ def inner(*args, **kwargs):
945
+ if not _call_from_tl_platform:
946
+ set_current('tensorleap_metadata')
947
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
948
+ return None
949
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
950
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
951
+ sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
952
+ _validate_input_args(sample_id, preprocess_response)
953
+
954
+ result = inner_without_validate(sample_id, preprocess_response)
955
+
956
+ _validate_result(result)
957
+ if not _call_from_tl_platform:
958
+ update_env_params_func("tensorleap_metadata", "v")
959
+ return result
960
+
961
+ return inner
962
+
963
+ return decorating_function
964
+
965
+
966
+ def tensorleap_custom_latent_space():
967
+ def decorating_function(user_function: SectionCallableInterface):
968
+ def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
969
+ assert isinstance(sample_id, (int, str)), \
970
+ (f'tensorleap_custom_latent_space validation failed: '
971
+ f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
972
+ assert isinstance(preprocess_response, PreprocessResponse), \
973
+ (f'tensorleap_custom_latent_space validation failed: '
974
+ f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
975
+ assert type(sample_id) == preprocess_response.sample_id_type, \
976
+ (f'tensorleap_custom_latent_space validation failed: '
977
+ f'Argument sample_id should be as the same type as defined in the preprocess response '
978
+ f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
979
+
980
+ def _validate_result(result):
981
+ assert isinstance(result, np.ndarray), \
982
+ (f'tensorleap_custom_loss validation failed: '
983
+ f'The return type should be a numpy array. Got {type(result)}.')
984
+
985
+ def inner_without_validate(sample_id, preprocess_response):
986
+ global _called_from_inside_tl_decorator
987
+ _called_from_inside_tl_decorator += 1
988
+
989
+ try:
990
+ result = user_function(sample_id, preprocess_response)
991
+ finally:
992
+ _called_from_inside_tl_decorator -= 1
993
+
994
+ return result
995
+
996
+ leap_binder.set_custom_latent_space(inner_without_validate)
997
+
431
998
  def inner(sample_id, preprocess_response):
432
999
  if os.environ.get(mapping_runtime_mode_env_var_mame):
433
1000
  return None
@@ -449,30 +1016,45 @@ def tensorleap_preprocess():
449
1016
  leap_binder.set_preprocess(user_function)
450
1017
 
451
1018
  def _validate_input_args(*args, **kwargs):
452
- assert len(args) == 0 and len(kwargs) == 0, \
453
- (f'tensorleap_preprocess validation failed: '
1019
+ assert len(args) + len(kwargs) == 0, \
1020
+ (f'{user_function.__name__}() validation failed: '
454
1021
  f'The function should not take any arguments. Got {args} and {kwargs}.')
455
1022
 
456
1023
  def _validate_result(result):
457
- assert isinstance(result, list), \
458
- (f'tensorleap_preprocess validation failed: '
459
- f'The return type should be a list. Got {type(result)}.')
1024
+ assert isinstance(result, list), (
1025
+ f"{user_function.__name__}() validation failed: expected return type list[{PreprocessResponse.__name__}]"
1026
+ f"(e.g., [PreprocessResponse1, PreprocessResponse2, ...]), but returned type is {type(result).__name__}."
1027
+ if not isinstance(result, tuple)
1028
+ else f"{user_function.__name__}() validation failed: expected to return a single list[{PreprocessResponse.__name__}] object, "
1029
+ f"but returned {len(result)} objects instead."
1030
+ )
460
1031
  for i, response in enumerate(result):
461
1032
  assert isinstance(response, PreprocessResponse), \
462
- (f'tensorleap_preprocess validation failed: '
1033
+ (f'{user_function.__name__}() validation failed: '
463
1034
  f'Element #{i} in the return list should be a PreprocessResponse. Got {type(response)}.')
464
1035
  assert len(set(result)) == len(result), \
465
- (f'tensorleap_preprocess validation failed: '
1036
+ (f'{user_function.__name__}() validation failed: '
466
1037
  f'The return list should not contain duplicate PreprocessResponse objects.')
467
1038
 
468
1039
  def inner(*args, **kwargs):
1040
+ if not _call_from_tl_platform:
1041
+ set_current('tensorleap_metadata')
469
1042
  if os.environ.get(mapping_runtime_mode_env_var_mame):
470
1043
  return [None, None, None, None]
471
1044
 
472
1045
  _validate_input_args(*args, **kwargs)
473
-
474
1046
  result = user_function()
475
1047
  _validate_result(result)
1048
+
1049
+ # Emit integration test event once per test
1050
+ try:
1051
+ emit_integration_event_once(AnalyticsEvent.PREPROCESS_INTEGRATION_TEST, {
1052
+ 'preprocess_responses_count': len(result)
1053
+ })
1054
+ except Exception as e:
1055
+ logger.debug(f"Failed to emit preprocess integration test event: {e}")
1056
+ if not _call_from_tl_platform:
1057
+ update_env_params_func("tensorleap_preprocess", "v")
476
1058
  return result
477
1059
 
478
1060
  return inner
@@ -481,7 +1063,7 @@ def tensorleap_preprocess():
481
1063
 
482
1064
 
483
1065
  def tensorleap_element_instance_preprocess(
484
- instance_length_encoder: InstanceLengthCallableInterface):
1066
+ instance_length_encoder: InstanceLengthCallableInterface, instance_mask_encoder: InstanceCallableInterface):
485
1067
  def decorating_function(user_function: Callable[[], List[PreprocessResponse]]):
486
1068
  def user_function_instance() -> List[PreprocessResponse]:
487
1069
  result = user_function()
@@ -497,16 +1079,39 @@ def tensorleap_element_instance_preprocess(
497
1079
  for instance_id in instances_ids:
498
1080
  instance_to_sample_ids_mappings[instance_id] = sample_id
499
1081
  all_sample_ids.extend(instances_ids)
1082
+ preprocess_response.length = len(all_sample_ids)
500
1083
  preprocess_response.sample_ids_to_instance_mappings = sample_ids_to_instance_mappings
501
1084
  preprocess_response.instance_to_sample_ids_mappings = instance_to_sample_ids_mappings
502
1085
  preprocess_response.sample_ids = all_sample_ids
503
1086
  return result
504
1087
 
1088
+ def extract_extra_instance_metadata():
1089
+ result = user_function()
1090
+ for preprocess_response in result:
1091
+ for sample_id in preprocess_response.sample_ids:
1092
+ instances_length = instance_length_encoder(sample_id, preprocess_response)
1093
+ if instances_length > 0:
1094
+ element_instance = instance_mask_encoder(sample_id, preprocess_response, 0)
1095
+ instance_metadata = element_instance.instance_metadata
1096
+ if instance_metadata is None:
1097
+ return {}
1098
+ return instance_metadata
1099
+ return {}
1100
+
1101
+
505
1102
  def builtin_instance_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
506
1103
  return {'is_instance': '0', 'original_sample_id': idx, 'instance_name': 'none'}
507
1104
 
1105
+ def builtin_instance_extra_metadata(idx: str, preprocess: PreprocessResponse) -> Dict[str, str]:
1106
+ instance_metadata = extract_extra_instance_metadata()
1107
+ for key, value in instance_metadata.items():
1108
+ instance_metadata[key] = 'unset'
1109
+ return instance_metadata
1110
+
508
1111
  leap_binder.set_preprocess(user_function_instance)
509
1112
  leap_binder.set_metadata(builtin_instance_metadata, "builtin_instance_metadata")
1113
+ leap_binder.set_metadata(builtin_instance_extra_metadata, "builtin_instance_extra_metadata")
1114
+
510
1115
 
511
1116
  def _validate_input_args(*args, **kwargs):
512
1117
  assert len(args) == 0 and len(kwargs) == 0, \
@@ -533,6 +1138,8 @@ def tensorleap_element_instance_preprocess(
533
1138
 
534
1139
  result = user_function_instance()
535
1140
  _validate_result(result)
1141
+ if not _call_from_tl_platform:
1142
+ update_env_params_func("tensorleap_preprocess", "v")
536
1143
  return result
537
1144
 
538
1145
  return inner
@@ -567,7 +1174,7 @@ def tensorleap_unlabeled_preprocess():
567
1174
 
568
1175
  def tensorleap_instances_masks_encoder(name: str):
569
1176
  def decorating_function(user_function: InstanceCallableInterface):
570
- def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse, instance_id: str):
1177
+ def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse, instance_id: int):
571
1178
  assert isinstance(sample_id, str), \
572
1179
  (f'tensorleap_instances_masks_encoder validation failed: '
573
1180
  f'Argument sample_id should be str. Got {type(sample_id)}.')
@@ -578,13 +1185,9 @@ def tensorleap_instances_masks_encoder(name: str):
578
1185
  (f'tensorleap_instances_masks_encoder validation failed: '
579
1186
  f'Argument sample_id should be as the same type as defined in the preprocess response '
580
1187
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
581
- assert isinstance(instance_id, str), \
582
- (f'tensorleap_instances_masks_encoder validation failed: '
583
- f'Argument instance_id should be str. Got {type(instance_id)}.')
584
- assert type(instance_id) == preprocess_response.sample_id_type, \
1188
+ assert isinstance(instance_id, int), \
585
1189
  (f'tensorleap_instances_masks_encoder validation failed: '
586
- f'Argument instance_id should be as the same type as defined in the preprocess response '
587
- f'{preprocess_response.sample_id_type}. Got {type(instance_id)}.')
1190
+ f'Argument instance_id should be int. Got {type(instance_id)}.')
588
1191
 
589
1192
  def _validate_result(result):
590
1193
  assert isinstance(result, ElementInstance) or (result is None), \
@@ -619,10 +1222,11 @@ def tensorleap_instances_masks_encoder(name: str):
619
1222
 
620
1223
  return decorating_function
621
1224
 
1225
+
622
1226
  def tensorleap_instances_length_encoder(name: str):
623
1227
  def decorating_function(user_function: InstanceLengthCallableInterface):
624
1228
  def _validate_input_args(sample_id: str, preprocess_response: PreprocessResponse):
625
- assert isinstance(sample_id, str), \
1229
+ assert isinstance(sample_id, (str, int)), \
626
1230
  (f'tensorleap_instances_length_encoder validation failed: '
627
1231
  f'Argument sample_id should be str. Got {type(sample_id)}.')
628
1232
  assert isinstance(preprocess_response, PreprocessResponse), \
@@ -649,8 +1253,6 @@ def tensorleap_instances_length_encoder(name: str):
649
1253
 
650
1254
  return result
651
1255
 
652
- # leap_binder.set_instance_masks(inner_without_validate, name). # TODO: do i need this?
653
-
654
1256
  def inner(sample_id, preprocess_response):
655
1257
  if os.environ.get(mapping_runtime_mode_env_var_mame):
656
1258
  return None
@@ -666,46 +1268,88 @@ def tensorleap_instances_length_encoder(name: str):
666
1268
 
667
1269
  return decorating_function
668
1270
 
669
- def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
1271
+
1272
+ def tensorleap_input_encoder(name: str, channel_dim=_UNSET, model_input_index=None):
670
1273
  def decorating_function(user_function: SectionCallableInterface):
671
1274
  for input_handler in leap_binder.setup_container.inputs:
672
1275
  if input_handler.name == name:
673
1276
  raise Exception(f'Input with name {name} already exists. '
674
1277
  f'Please choose another')
1278
+ nonlocal channel_dim
1279
+
1280
+ channel_dim_was_provided = channel_dim is not _UNSET
1281
+
1282
+ if not channel_dim_was_provided:
1283
+ channel_dim = -1
1284
+ if not _call_from_tl_platform:
1285
+ store_warning_by_param(
1286
+ param_name="channel_dim",
1287
+ user_func_name=user_function.__name__,
1288
+ default_value=channel_dim,
1289
+ link_to_docs="https://docs.tensorleap.ai/tensorleap-integration/writing-integration-code/input-encoder"
1290
+
1291
+ )
1292
+
675
1293
  if channel_dim <= 0 and channel_dim != -1:
676
1294
  raise Exception(f"Channel dim for input {name} is expected to be either -1 or positive")
677
1295
 
678
- leap_binder.set_input(user_function, name, channel_dim=channel_dim)
679
-
680
1296
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
681
- assert isinstance(sample_id, (int, str)), \
682
- (f'tensorleap_input_encoder validation failed: '
683
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
684
- assert isinstance(preprocess_response, PreprocessResponse), \
685
- (f'tensorleap_input_encoder validation failed: '
686
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
687
1297
  assert type(sample_id) == preprocess_response.sample_id_type, \
688
- (f'tensorleap_input_encoder validation failed: '
1298
+ (f'{user_function.__name__}() validation failed: '
689
1299
  f'Argument sample_id should be as the same type as defined in the preprocess response '
690
1300
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
691
1301
 
692
1302
  def _validate_result(result):
1303
+ validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray")
693
1304
  assert isinstance(result, np.ndarray), \
694
- (f'tensorleap_input_encoder validation failed: '
1305
+ (f'{user_function.__name__}() validation failed: '
695
1306
  f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
696
1307
  assert result.dtype == np.float32, \
697
- (f'tensorleap_input_encoder validation failed: '
1308
+ (f'{user_function.__name__}() validation failed: '
698
1309
  f'The return type should be a numpy array of type float32. Got {result.dtype}.')
699
- assert channel_dim - 1 <= len(result.shape), (f'tensorleap_input_encoder validation failed: '
1310
+ assert channel_dim - 1 <= len(result.shape), (f'{user_function.__name__}() validation failed: '
700
1311
  f'The channel_dim ({channel_dim}) should be <= to the rank of the resulting input rank ({len(result.shape)}).')
701
1312
 
702
- def inner(sample_id, preprocess_response):
1313
+ def inner_without_validate(sample_id, preprocess_response):
1314
+ global _called_from_inside_tl_decorator
1315
+ _called_from_inside_tl_decorator += 1
1316
+
1317
+ try:
1318
+ result = user_function(sample_id, preprocess_response)
1319
+ finally:
1320
+ _called_from_inside_tl_decorator -= 1
1321
+
1322
+ return result
1323
+
1324
+ leap_binder.set_input(inner_without_validate, name, channel_dim=channel_dim)
1325
+
1326
+ def inner(*args, **kwargs):
1327
+ if not _call_from_tl_platform:
1328
+ set_current("tensorleap_input_encoder")
1329
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
1330
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
1331
+ sample_id, preprocess_response = args if len(args) != 0 else kwargs.values()
703
1332
  _validate_input_args(sample_id, preprocess_response)
704
- result = user_function(sample_id, preprocess_response)
1333
+
1334
+ result = inner_without_validate(sample_id, preprocess_response)
1335
+
705
1336
  _validate_result(result)
706
1337
 
707
- if _called_from_inside_tl_decorator == 0:
1338
+ if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
1339
+ batch_warning(result, user_function.__name__)
708
1340
  result = np.expand_dims(result, axis=0)
1341
+ # Emit integration test event once per test
1342
+ try:
1343
+ emit_integration_event_once(AnalyticsEvent.INPUT_ENCODER_INTEGRATION_TEST, {
1344
+ 'encoder_name': name,
1345
+ 'channel_dim': channel_dim,
1346
+ 'model_input_index': model_input_index
1347
+ })
1348
+ except Exception as e:
1349
+ logger.debug(f"Failed to emit input_encoder integration test event: {e}")
1350
+ if not _call_from_tl_platform:
1351
+ update_env_params_func("tensorleap_input_encoder", "v")
1352
+
709
1353
  return result
710
1354
 
711
1355
  node_mapping_type = NodeMappingType.Input
@@ -713,22 +1357,23 @@ def tensorleap_input_encoder(name: str, channel_dim=-1, model_input_index=None):
713
1357
  node_mapping_type = NodeMappingType(f'Input{str(model_input_index)}')
714
1358
  inner.node_mapping = NodeMapping(name, node_mapping_type)
715
1359
 
716
- def mapping_inner(sample_id, preprocess_response):
1360
+ def mapping_inner(*args, **kwargs):
717
1361
  class TempMapping:
718
1362
  pass
719
1363
 
720
1364
  ret = TempMapping()
721
1365
  ret.node_mapping = mapping_inner.node_mapping
722
1366
 
1367
+ leap_binder.mapping_connections.append(NodeConnection(mapping_inner.node_mapping, None))
723
1368
  return ret
724
1369
 
725
1370
  mapping_inner.node_mapping = NodeMapping(name, node_mapping_type)
726
1371
 
727
- def final_inner(sample_id, preprocess_response):
1372
+ def final_inner(*args, **kwargs):
728
1373
  if os.environ.get(mapping_runtime_mode_env_var_mame):
729
- return mapping_inner(sample_id, preprocess_response)
1374
+ return mapping_inner(*args, **kwargs)
730
1375
  else:
731
- return inner(sample_id, preprocess_response)
1376
+ return inner(*args, **kwargs)
732
1377
 
733
1378
  final_inner.node_mapping = NodeMapping(name, node_mapping_type)
734
1379
 
@@ -744,40 +1389,64 @@ def tensorleap_gt_encoder(name: str):
744
1389
  raise Exception(f'GT with name {name} already exists. '
745
1390
  f'Please choose another')
746
1391
 
747
- leap_binder.set_ground_truth(user_function, name)
748
-
749
1392
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
750
- assert isinstance(sample_id, (int, str)), \
751
- (f'tensorleap_gt_encoder validation failed: '
752
- f'Argument sample_id should be either int or str. Got {type(sample_id)}.')
753
- assert isinstance(preprocess_response, PreprocessResponse), \
754
- (f'tensorleap_gt_encoder validation failed: '
755
- f'Argument preprocess_response should be a PreprocessResponse. Got {type(preprocess_response)}.')
756
1393
  assert type(sample_id) == preprocess_response.sample_id_type, \
757
- (f'tensorleap_gt_encoder validation failed: '
1394
+ (f'{user_function.__name__}() validation failed: '
758
1395
  f'Argument sample_id should be as the same type as defined in the preprocess response '
759
1396
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
760
1397
 
761
1398
  def _validate_result(result):
1399
+ validate_output_structure(result, func_name=user_function.__name__, expected_type_name="np.ndarray",
1400
+ gt_flag=True)
762
1401
  assert isinstance(result, np.ndarray), \
763
- (f'tensorleap_gt_encoder validation failed: '
1402
+ (f'{user_function.__name__}() validation failed: '
764
1403
  f'Unsupported return type. Should be a numpy array. Got {type(result)}.')
765
1404
  assert result.dtype == np.float32, \
766
- (f'tensorleap_gt_encoder validation failed: '
1405
+ (f'{user_function.__name__}() validation failed: '
767
1406
  f'The return type should be a numpy array of type float32. Got {result.dtype}.')
768
1407
 
769
- def inner(sample_id, preprocess_response):
1408
+ def inner_without_validate(sample_id, preprocess_response):
1409
+ global _called_from_inside_tl_decorator
1410
+ _called_from_inside_tl_decorator += 1
1411
+
1412
+ try:
1413
+ result = user_function(sample_id, preprocess_response)
1414
+ finally:
1415
+ _called_from_inside_tl_decorator -= 1
1416
+
1417
+ return result
1418
+
1419
+ leap_binder.set_ground_truth(inner_without_validate, name)
1420
+
1421
+ def inner(*args, **kwargs):
1422
+ if not _call_from_tl_platform:
1423
+ set_current("tensorleap_gt_encoder")
1424
+ validate_args_structure(*args, types_order=[Union[int, str], PreprocessResponse],
1425
+ func_name=user_function.__name__, expected_names=["idx", "preprocess"], **kwargs)
1426
+ sample_id, preprocess_response = args
770
1427
  _validate_input_args(sample_id, preprocess_response)
771
- result = user_function(sample_id, preprocess_response)
1428
+
1429
+ result = inner_without_validate(sample_id, preprocess_response)
1430
+
772
1431
  _validate_result(result)
773
1432
 
774
- if _called_from_inside_tl_decorator == 0:
1433
+ if _called_from_inside_tl_decorator == 0 and _called_from_inside_tl_integration_test_decorator:
1434
+ batch_warning(result, user_function.__name__)
775
1435
  result = np.expand_dims(result, axis=0)
1436
+ # Emit integration test event once per test
1437
+ try:
1438
+ emit_integration_event_once(AnalyticsEvent.GT_ENCODER_INTEGRATION_TEST, {
1439
+ 'encoder_name': name
1440
+ })
1441
+ except Exception as e:
1442
+ logger.debug(f"Failed to emit gt_encoder integration test event: {e}")
1443
+ if not _call_from_tl_platform:
1444
+ update_env_params_func("tensorleap_gt_encoder", "v")
776
1445
  return result
777
1446
 
778
1447
  inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
779
1448
 
780
- def mapping_inner(sample_id, preprocess_response):
1449
+ def mapping_inner(*args, **kwargs):
781
1450
  class TempMapping:
782
1451
  pass
783
1452
 
@@ -788,11 +1457,11 @@ def tensorleap_gt_encoder(name: str):
788
1457
 
789
1458
  mapping_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
790
1459
 
791
- def final_inner(sample_id, preprocess_response):
1460
+ def final_inner(*args, **kwargs):
792
1461
  if os.environ.get(mapping_runtime_mode_env_var_mame):
793
- return mapping_inner(sample_id, preprocess_response)
1462
+ return mapping_inner(*args, **kwargs)
794
1463
  else:
795
- return inner(sample_id, preprocess_response)
1464
+ return inner(*args, **kwargs)
796
1465
 
797
1466
  final_inner.node_mapping = NodeMapping(name, NodeMappingType.GroundTruth)
798
1467
 
@@ -802,6 +1471,8 @@ def tensorleap_gt_encoder(name: str):
802
1471
 
803
1472
 
804
1473
  def tensorleap_custom_loss(name: str, connects_to=None):
1474
+ name_to_unique_name = defaultdict(set)
1475
+
805
1476
  def decorating_function(user_function: CustomCallableInterface):
806
1477
  for loss_handler in leap_binder.setup_container.custom_loss_handlers:
807
1478
  if loss_handler.custom_loss_handler_data.name == name:
@@ -809,35 +1480,29 @@ def tensorleap_custom_loss(name: str, connects_to=None):
809
1480
  f'Please choose another')
810
1481
 
811
1482
  valid_types = (np.ndarray, SamplePreprocessResponse)
812
- try:
813
- import tensorflow as tf
814
- valid_types = (np.ndarray, SamplePreprocessResponse, tf.Tensor)
815
- except ImportError:
816
- pass
817
1483
 
818
1484
  def _validate_input_args(*args, **kwargs):
819
-
1485
+ assert len(args) + len(kwargs) > 0, (
1486
+ f"{user_function.__name__}() validation failed: "
1487
+ f"Expected at least one positional|key-word argument of the allowed types (np.ndarray|SamplePreprocessResponse|). "
1488
+ f"but received none. "
1489
+ f"Correct usage example: {user_function.__name__}(input_array: np.ndarray, ...)"
1490
+ )
820
1491
  for i, arg in enumerate(args):
821
- if isinstance(arg, list):
822
- for y, elem in enumerate(arg):
823
- assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
824
- f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
825
- else:
826
- assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
827
- f'Argument #{i} should be a numpy array. Got {type(arg)}.')
1492
+ assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
1493
+ f'Argument #{i} should be a numpy array. Got {type(arg)}.')
828
1494
  for _arg_name, arg in kwargs.items():
829
- if isinstance(arg, list):
830
- for y, elem in enumerate(arg):
831
- assert isinstance(elem, valid_types), (f'tensorleap_custom_loss validation failed: '
832
- f'Element #{y} of list should be a numpy array. Got {type(elem)}.')
833
- else:
834
- assert isinstance(arg, valid_types), (f'tensorleap_custom_loss validation failed: '
835
- f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
1495
+ assert isinstance(arg, valid_types), (f'{user_function.__name__}() validation failed: '
1496
+ f'Argument #{_arg_name} should be a numpy array. Got {type(arg)}.')
836
1497
 
837
1498
  def _validate_result(result):
838
- assert isinstance(result, valid_types), \
839
- (f'tensorleap_custom_loss validation failed: '
1499
+ validate_output_structure(result, func_name=user_function.__name__,
1500
+ expected_type_name="np.ndarray")
1501
+ assert isinstance(result, np.ndarray), \
1502
+ (f'{user_function.__name__} validation failed: '
840
1503
  f'The return type should be a numpy array. Got {type(result)}.')
1504
+ assert result.ndim < 2, (f'{user_function.__name__} validation failed: '
1505
+ f'The return type should be a 1Dim numpy array but got {result.ndim}Dim.')
841
1506
 
842
1507
  @functools.wraps(user_function)
843
1508
  def inner_without_validate(*args, **kwargs):
@@ -863,11 +1528,16 @@ def tensorleap_custom_loss(name: str, connects_to=None):
863
1528
  _add_mapping_connections(connects_to, arg_names, NodeMappingType.CustomLoss, name)
864
1529
 
865
1530
  def inner(*args, **kwargs):
1531
+ if not _call_from_tl_platform:
1532
+ set_current("tensorleap_custom_loss")
866
1533
  _validate_input_args(*args, **kwargs)
867
1534
 
868
1535
  result = inner_without_validate(*args, **kwargs)
869
1536
 
870
1537
  _validate_result(result)
1538
+ if not _call_from_tl_platform:
1539
+ update_env_params_func("tensorleap_custom_loss", "v")
1540
+
871
1541
  return result
872
1542
 
873
1543
  def mapping_inner(*args, **kwargs):
@@ -875,6 +1545,10 @@ def tensorleap_custom_loss(name: str, connects_to=None):
875
1545
  if 'user_unique_name' in kwargs:
876
1546
  user_unique_name = kwargs['user_unique_name']
877
1547
 
1548
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
1549
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
1550
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
1551
+
878
1552
  ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
879
1553
  ordered_connections = list(args) + ordered_connections
880
1554
  _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
@@ -920,3 +1594,180 @@ def tensorleap_custom_layer(name: str):
920
1594
  return custom_layer
921
1595
 
922
1596
  return decorating_function
1597
+
1598
+
1599
def tensorleap_status_table():
    """Install exit/crash hooks that print an integration-progress summary.

    Builds a table of the Tensorleap decorators an integration script is
    expected to use, then registers two hooks:

    * An ``atexit`` handler. It marks every row still unknown as missing (❌)
      and prints the summary.
    * A ``sys.excepthook`` wrapper. It records that the script crashed and
      marks the function that was running as failed. It then reports the
      exception through the previously installed hook and prints the summary.

    The summary starts with any warnings returned by
    ``_get_param_default_warnings()``. The status table itself is printed
    only when ``started_from("leap_integration.py")`` is true.

    Returns:
        tuple: ``(set_current, update_env_params)``.

        * ``set_current(name)`` records which decorated function is
          currently running.
        * ``update_env_params(name, status="v")`` marks the row ``name`` as
          passed (``"v"``) or failed (any other status).
    """
    import atexit
    import sys
    from typing import Any

    CHECK = "✅"
    CROSS = "❌"
    UNKNOWN = "❔"
    STATUS = "Added to integration"  # row key and table column header
    OPTIONAL_SUFFIX = " (optional)"

    # Rows are printed in this order. The first row that is not CHECK is
    # suggested as the next step, so the mandatory decorators are listed first.
    table = [
        {"name": "tensorleap_preprocess", STATUS: UNKNOWN},
        {"name": "tensorleap_integration_test", STATUS: UNKNOWN},
        {"name": "tensorleap_input_encoder", STATUS: UNKNOWN},
        {"name": "tensorleap_gt_encoder", STATUS: UNKNOWN},
        {"name": "tensorleap_load_model", STATUS: UNKNOWN},
        {"name": "tensorleap_custom_loss", STATUS: UNKNOWN},
        {"name": "tensorleap_custom_metric (optional)", STATUS: UNKNOWN},
        {"name": "tensorleap_metadata (optional)", STATUS: UNKNOWN},
        {"name": "tensorleap_custom_visualizer (optional)", STATUS: UNKNOWN},
    ]

    finalized = False  # the summary is printed at most once (excepthook and atexit both call run_on_exit)
    crashed = False  # set by the excepthook on an uncaught exception
    current_func = None  # last name passed to set_current()
    code_mapping_failed = False  # set by update_env_params("code_mapping", ...)
    previous_excepthook = sys.excepthook  # chained so earlier hooks still report the exception

    def _link(url: str) -> str:
        # Render url as an OSC 8 terminal hyperlink:
        #   ESC ] 8 ; ; URL ESC \  TEXT  ESC ] 8 ; ; ESC \
        # The previous form surrounded the URL with bare ESC bytes. Terminals
        # read ESC plus the next character as an escape sequence, which
        # swallowed the URL's first character. When stdout is not a terminal
        # (for example, redirected logs), return the plain URL instead.
        isatty = getattr(sys.stdout, "isatty", None)
        if not (callable(isatty) and isatty()):
            return url
        return f"\033]8;;{url}\033\\{url}\033]8;;\033\\"

    def _remove_suffix(s: str, suffix: str) -> str:
        # str.removesuffix equivalent that also works on Python < 3.9.
        if suffix and s.endswith(suffix):
            return s[:-len(suffix)]
        return s

    def _find_row(name: str):
        # Match on the bare decorator name, ignoring the " (optional)" label.
        for row in table:
            if _remove_suffix(row["name"], OPTIONAL_SUFFIX) == name:
                return row
        return None

    def _set_status(name: str, status_symbol: str):
        # Unknown names and UNKNOWN updates are ignored. Once a row has
        # passed (CHECK), a later failure does not downgrade it.
        row = _find_row(name)
        if not row or status_symbol == UNKNOWN:
            return
        if row[STATUS] == CHECK and status_symbol != CHECK:
            return
        row[STATUS] = status_symbol

    def _mark_unknowns_as_cross():
        for row in table:
            if row[STATUS] == UNKNOWN:
                row[STATUS] = CROSS

    def _format_default_value(v: Any) -> str:
        # Objects with a ``name`` attribute (e.g. enum members) print by name.
        # Other non-string values use repr(), capped at 120 characters.
        if hasattr(v, "name"):
            return str(v.name)
        if isinstance(v, str):
            return v
        s = repr(v)
        return s if len(s) <= 120 else s[:120] + "..."

    def _print_param_default_warnings():
        # data maps a parameter name to a dict with "default_value", "funcs"
        # and an optional "link_to_docs".
        data = _get_param_default_warnings()
        if not data:
            return

        print("\nWarnings (Default use. It is recommended to set values explicitly):")
        for param_name in sorted(data):
            info = data[param_name]
            funcs = ", ".join(sorted(info["funcs"]))
            dv = _format_default_value(info["default_value"])
            message = f" ⚠️ Parameter '{param_name}' defaults to {dv} in the following functions: [{funcs}]."
            docs_link = info.get("link_to_docs")
            if docs_link:
                # Only point at the documentation when there is a link to show.
                message += f" For more information, check {_link(docs_link)}"
            print(message)
        print("\nIf this isn’t the intended behaviour, set them explicitly.")

    def _print_table():
        _print_param_default_warnings()

        if not started_from("leap_integration.py"):
            return

        ready_mess = ("\nAll parts have been successfully set. If no errors occurred, "
                      "you can now push the project to the Tensorleap system.")
        not_ready_mess = ("\nSome mandatory components have not yet been added to the Integration test. "
                          "Recommended next interface to add is: ")
        mandatory_ready_mess = ("\nAll mandatory parts have been successfully set. If no errors occurred, "
                                "you can now push the project to the Tensorleap system or continue to the "
                                "next optional recommended interface, adding: ")
        code_mapping_failure_mes = "Tensorleap_integration_test code flow failed, check raised exception."

        name_width = max(len(row["name"]) for row in table)
        status_width = max(len(row[STATUS]) for row in table)

        header = f"{'Decorator Name'.ljust(name_width)} | {STATUS.ljust(status_width)}"
        print("\n" + header)
        print("-" * len(header))
        for row in table:
            print(f"{row['name'].ljust(name_width)} | {row[STATUS].ljust(status_width)}")

        if crashed:
            if current_func:
                print(f"\nScript crashed before completing all steps. crashed at function '{current_func}'.")
            else:
                print("\nScript crashed before completing all steps.")
            return

        if code_mapping_failed:
            # The message already ends with a period; do not append another.
            print(f"\n{CROSS}{code_mapping_failure_mes}")
            return

        # The first row (in table order) that has not passed is the suggested next step.
        next_step = next((row["name"] for row in table if row[STATUS] != CHECK), None)
        if next_step is None:
            print(ready_mess)
        elif next_step.endswith(OPTIONAL_SUFFIX):
            print(mandatory_ready_mess + next_step)
        else:
            print(not_ready_mess + next_step)

    def set_current(name: str):
        """Record ``name`` as the decorated function currently running."""
        nonlocal current_func
        current_func = name

    def update_env_params(name: str, status: str = "v"):
        """Mark ``name`` as passed (``status == "v"``) or failed (anything else)."""
        nonlocal code_mapping_failed
        # NOTE(review): any report for "code_mapping" is treated as a failure,
        # whatever the status. "code_mapping" has no table row, so the
        # _set_status call below does nothing for it. Confirm against callers.
        if name == "code_mapping":
            code_mapping_failed = True
        _set_status(name, CHECK if status == "v" else CROSS)

    def run_on_exit():
        nonlocal finalized
        if finalized:
            return
        finalized = True
        # After a crash, rows never reached stay UNKNOWN rather than being
        # reported as missing.
        if not crashed:
            _mark_unknowns_as_cross()
        _print_table()

    def handle_exception(exc_type, exc_value, exc_traceback):
        nonlocal crashed
        crashed = True
        if current_func:
            # _set_status keeps an already-passed row at CHECK.
            _set_status(current_func, CROSS)
        # Delegate the traceback report to whichever hook was installed before ours.
        previous_excepthook(exc_type, exc_value, exc_traceback)
        run_on_exit()

    atexit.register(run_on_exit)
    sys.excepthook = handle_exception

    return set_current, update_env_params
1762
+
1763
+
1764
# When not invoked from the Tensorleap platform (per _call_from_tl_platform),
# install the status-table hooks. This exposes the two callbacks that the
# decorator wrappers use to report which function is running and whether it
# passed.
if not _call_from_tl_platform:
    _status_hooks = tensorleap_status_table()
    set_current = _status_hooks[0]
    update_env_params_func = _status_hooks[1]
    del _status_hooks
1766
+
1767
+
1768
+
1769
+
1770
+
1771
+
1772
+
1773
+