code-loader 1.0.173.dev2__tar.gz → 1.0.174.dev2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/PKG-INFO +1 -1
  2. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/contract/datasetclasses.py +7 -0
  3. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/inner_leap_binder/leapbinder_decorators.py +209 -0
  4. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/pyproject.toml +1 -1
  5. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/LICENSE +0 -0
  6. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/README.md +0 -0
  7. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/__init__.py +0 -0
  8. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/contract/__init__.py +0 -0
  9. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/contract/enums.py +0 -0
  10. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/contract/exceptions.py +0 -0
  11. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/contract/mapping.py +0 -0
  12. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/contract/responsedataclasses.py +0 -0
  13. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/contract/visualizer_classes.py +0 -0
  14. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/default_losses.py +0 -0
  15. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/default_metrics.py +0 -0
  16. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/__init__.py +0 -0
  17. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/api.py +0 -0
  18. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/cli_config_utils.py +0 -0
  19. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/client.py +0 -0
  20. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/epoch.py +0 -0
  21. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/experiment.py +0 -0
  22. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/experiment_context.py +0 -0
  23. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/types.py +0 -0
  24. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/utils.py +0 -0
  25. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/experiment_api/workingspace_config_utils.py +0 -0
  26. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/inner_leap_binder/__init__.py +0 -0
  27. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/inner_leap_binder/leapbinder.py +0 -0
  28. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/leaploader.py +0 -0
  29. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/leaploaderbase.py +0 -0
  30. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/mixpanel_tracker.py +0 -0
  31. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/plot_functions/__init__.py +0 -0
  32. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/plot_functions/plot_functions.py +0 -0
  33. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/plot_functions/visualize.py +0 -0
  34. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/utils.py +0 -0
  35. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/visualizers/__init__.py +0 -0
  36. {code_loader-1.0.173.dev2 → code_loader-1.0.174.dev2}/code_loader/visualizers/default_visualizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.173.dev2
3
+ Version: 1.0.174.dev2
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -171,6 +171,12 @@ class MetricHandler:
171
171
  function: Union[CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]
172
172
 
173
173
 
174
@dataclass
class InstanceMetricHandler:
    """Registration record for a per-instance custom metric.

    Pairs the metric's descriptive data with the user callable that computes
    it. The callable is typed as a multiple-return custom metric interface.
    """

    # Metadata describing the metric (the decorators read `.name`,
    # `.arg_names` and `.direction` from this object).
    metric_handler_data: MetricHandlerData
    # User-supplied callable producing the per-instance metric values.
    function: CustomMultipleReturnCallableInterfaceMultiArgs
174
180
  @dataclass
175
181
  class RawInputsForHeatmap:
176
182
  raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
@@ -258,6 +264,7 @@ class DatasetIntegrationSetup:
258
264
  prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
259
265
  custom_loss_handlers: List[CustomLossHandler] = field(default_factory=list)
260
266
  metrics: List[MetricHandler] = field(default_factory=list)
267
+ instance_metrics: List[InstanceMetricHandler] = field(default_factory=list)
261
268
  custom_layers: Dict[str, CustomLayerHandler] = field(default_factory=dict)
262
269
  custom_latent_space: Optional[CustomLatentSpaceHandler] = None
263
270
 
@@ -877,6 +877,215 @@ def tensorleap_custom_metric(name: str,
877
877
  return decorating_function
878
878
 
879
879
 
880
def tensorleap_custom_instances_metric(name: str,
                                       direction: Union[MetricDirection, Dict[str, MetricDirection]] = _UNSET,
                                       compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
                                       connects_to=None):
    """Decorator factory registering a per-instance custom metric with ``leap_binder``.

    Called directly, the wrapped user function must receive at least one argument;
    every argument must be an ``np.ndarray`` or a ``SamplePreprocessResponse``. It
    must return a ``dict`` mapping ``int`` instance numbers to 1-D ``np.ndarray``
    values. When ``leap_binder.batch_size_to_validate`` is set, array arguments'
    first dimension and every returned array's length must equal it.

    Args:
        name: Metric name. It must be a string not already used by any handler in
            ``setup_container.metrics`` or ``setup_container.instance_metrics``.
        direction: ``MetricDirection.Upward``/``Downward``, or a dict of those
            keyed by string. It may be left unset (``_UNSET``).
        compute_insights: Optional bool, or dict of bools keyed by string.
        connects_to: Optional string, or list/tuple/set of strings. It is used to
            add mapping connections for the metric's arguments.

    Returns:
        A decorator. The decorated function records a mapping connection (and
        returns ``None``) when the environment variable named by
        ``mapping_runtime_mode_env_var_mame`` is set. Otherwise it runs the user
        function with input/output validation.

    Raises:
        TypeError, ValueError: at decoration time, for invalid decorator arguments.
        Exception: at decoration time, if ``name`` is already registered.
        AssertionError: at call time, when inputs or the result fail validation.
    """
    # For each metric name, the node names already produced by mapping_inner.
    # A repeated name gets a numeric suffix so mapping nodes stay distinct.
    name_to_unique_name = defaultdict(set)

    def decorating_function(user_function: CustomMultipleReturnCallableInterfaceMultiArgs):
        func_name = user_function.__name__
        call_err_prefix = f'{func_name}() validation failed: '

        def _validate_decorators_signature():
            # Type/value checks on the decorator's own arguments.
            err_message = f"{func_name} validation failed.\n"
            if not isinstance(name, str):
                raise TypeError(err_message + f"`name` must be a string, got type {type(name).__name__}.")
            valid_directions = {MetricDirection.Upward, MetricDirection.Downward}
            if direction is _UNSET:
                pass
            elif isinstance(direction, MetricDirection):
                if direction not in valid_directions:
                    raise ValueError(
                        err_message +
                        f"Invalid MetricDirection: {direction}. Must be one of {valid_directions}, "
                        f"got type {type(direction).__name__}."
                    )
            elif isinstance(direction, dict):
                if not all(isinstance(k, str) for k in direction.keys()):
                    invalid_keys = {k: type(k).__name__ for k in direction.keys() if not isinstance(k, str)}
                    raise TypeError(
                        err_message +
                        f"All keys in `direction` must be strings, got invalid key types: {invalid_keys}."
                    )
                for k, v in direction.items():
                    if v not in valid_directions:
                        raise ValueError(
                            err_message +
                            f"Invalid direction for key '{k}': {v}. Must be one of {valid_directions}, "
                            f"got type {type(v).__name__}."
                        )
            else:
                raise TypeError(
                    err_message +
                    f"`direction` must be a MetricDirection or a Dict[str, MetricDirection], "
                    f"got type {type(direction).__name__}."
                )
            if compute_insights is not None:
                if not isinstance(compute_insights, (bool, dict)):
                    raise TypeError(
                        err_message +
                        f"`compute_insights` must be a bool or a Dict[str, bool], "
                        f"got type {type(compute_insights).__name__}."
                    )
                if isinstance(compute_insights, dict):
                    if not all(isinstance(k, str) for k in compute_insights.keys()):
                        invalid_keys = {k: type(k).__name__ for k in compute_insights.keys() if not isinstance(k, str)}
                        raise TypeError(
                            err_message +
                            f"All keys in `compute_insights` must be strings, got invalid key types: {invalid_keys}."
                        )
                    for k, v in compute_insights.items():
                        if not isinstance(v, bool):
                            raise TypeError(
                                err_message +
                                f"Invalid type for compute_insights['{k}']: expected bool, got type {type(v).__name__}."
                            )
            if connects_to is not None:
                valid_types = (str, list, tuple, set)
                if not isinstance(connects_to, valid_types):
                    raise TypeError(
                        err_message +
                        f"`connects_to` must be one of {valid_types}, got type {type(connects_to).__name__}."
                    )
                if isinstance(connects_to, (list, tuple, set)):
                    invalid_elems = [f"{type(e).__name__}" for e in connects_to if not isinstance(e, str)]
                    if invalid_elems:
                        raise TypeError(
                            err_message +
                            f"All elements in `connects_to` must be strings, "
                            f"but found element types: {invalid_elems}."
                        )

        _validate_decorators_signature()

        setup_container = leap_binder.setup_container
        # The handler is registered through add_custom_metric below, which places it
        # in setup_container.metrics. A clash with a regular metric is therefore a
        # real collision, so both lists are checked.
        for metric_handler in (*setup_container.metrics, *setup_container.instance_metrics):
            if metric_handler.metric_handler_data.name == name:
                raise Exception(f'Metric with name {name} already exists. '
                                f'Please choose another')

        def _validate_single_arg(label, arg) -> None:
            # Shared check for one argument. `label` is "#<index>" for positional
            # args and the keyword name for keyword args.
            assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
                f'{call_err_prefix}'
                f'Argument {label} should be a numpy array. Got {type(arg)}.')
            if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
                assert arg.shape[0] == leap_binder.batch_size_to_validate, \
                    (f'{call_err_prefix}Argument {label} '
                     f'first dim should be as the batch size. Got {arg.shape[0]} '
                     f'instead of {leap_binder.batch_size_to_validate}')

        def _validate_input_args(*args, **kwargs) -> None:
            assert len(args) + len(kwargs) > 0, (
                f"{call_err_prefix}"
                f"Expected at least one positional|key-word argument of type np.ndarray, "
                f"but received none. "
                f"Correct usage example: tensorleap_custom_instances_metric(input_array: np.ndarray, ...)"
            )
            for i, arg in enumerate(args):
                _validate_single_arg(f'#{i}', arg)
            for arg_name, arg in kwargs.items():
                _validate_single_arg(arg_name, arg)

        supported_types_message = (f'{call_err_prefix}'
                                   f'{func_name}() has returned unsupported type.\n'
                                   f'Supported type is Dict[int, NDArray[np.float32]]')

        def _validate_single_metric(single_metric_result, key) -> None:
            # Each value must be a 1-D array with one entry per batch sample.
            assert isinstance(single_metric_result, np.ndarray), \
                f'{supported_types_message}\nGot {type(single_metric_result)}.'
            assert single_metric_result.ndim == 1, (f'{call_err_prefix}'
                                                    f'The return shape should be 1D. '
                                                    f'Got {single_metric_result.ndim}D.')
            if leap_binder.batch_size_to_validate:
                assert len(single_metric_result) == leap_binder.batch_size_to_validate, \
                    (f'{call_err_prefix}The return len of the value for key {key} should be as the '
                     f'batch size. Got {len(single_metric_result)} instead of '
                     f'{leap_binder.batch_size_to_validate}.')

        def _validate_result(result) -> None:
            assert isinstance(result, dict), f'{supported_types_message}\nGot {type(result)}.'
            for key, value in result.items():
                # Check the key first so a wrong key type is reported as such,
                # rather than as a problem with its value.
                assert isinstance(key, int), \
                    (f'{call_err_prefix}'
                     f'Keys in the return dict should be of type int (instance number). Got {type(key)}.')
                _validate_single_metric(value, key)

        @functools.wraps(user_function)
        def inner_without_validate(*args, **kwargs):
            # Track nesting so code running inside the user function can tell it
            # is being called from a Tensorleap decorator.
            global _called_from_inside_tl_decorator
            _called_from_inside_tl_decorator += 1
            try:
                return user_function(*args, **kwargs)
            finally:
                _called_from_inside_tl_decorator -= 1

        try:
            inner_without_validate.__signature__ = inspect.signature(user_function)
        except (TypeError, ValueError):
            # Some callables (e.g. certain builtins) expose no signature; keep what wraps() set.
            pass

        leap_binder.add_custom_metric(inner_without_validate, name, direction, compute_insights)
        # The handler created by add_custom_metric is read back as the last entry of
        # setup_container.metrics.
        # NOTE(review): setup_container.instance_metrics is not populated by this
        # decorator. Confirm whether a dedicated registration path is intended.
        arg_names = leap_binder.setup_container.metrics[-1].metric_handler_data.arg_names

        if connects_to is not None:
            _add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)

        def inner(*args, **kwargs):
            if not _call_from_tl_platform:
                set_current('tensorleap_custom_instances_metric')
            _validate_input_args(*args, **kwargs)

            result = inner_without_validate(*args, **kwargs)

            _validate_result(result)
            if not _call_from_tl_platform:
                update_env_params_func("tensorleap_custom_instances_metric", "v")
            return result

        def mapping_inner(*args, **kwargs):
            # Mapping mode: record a graph connection instead of computing the metric.
            user_unique_name = kwargs.get('user_unique_name', mapping_inner.name)

            ordered_connections = list(args) + [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]

            used_names = name_to_unique_name[mapping_inner.name]
            if user_unique_name in used_names:
                user_unique_name = f'{user_unique_name}_{len(used_names)}'
            used_names.add(user_unique_name)

            _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
                                    mapping_inner.name, NodeMappingType.Metric)

            return None

        mapping_inner.arg_names = arg_names
        mapping_inner.name = name

        def final_inner(*args, **kwargs):
            if os.environ.get(mapping_runtime_mode_env_var_mame):
                return mapping_inner(*args, **kwargs)
            return inner(*args, **kwargs)

        return final_inner

    return decorating_function
1087
+
1088
+
880
1089
  def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
881
1090
  heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
882
1091
  connects_to=None):
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
- version = "1.0.173.dev2"
3
+ version = "1.0.174.dev2"
4
4
  description = ""
5
5
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
6
6
  license = "MIT"