code-loader 1.0.175__tar.gz → 1.0.176__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {code_loader-1.0.175 → code_loader-1.0.176}/PKG-INFO +3 -3
  2. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/contract/datasetclasses.py +7 -0
  3. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/inner_leap_binder/leapbinder.py +42 -1
  4. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/inner_leap_binder/leapbinder_decorators.py +209 -0
  5. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/leaploader.py +33 -1
  6. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/leaploaderbase.py +10 -0
  7. {code_loader-1.0.175 → code_loader-1.0.176}/pyproject.toml +1 -1
  8. {code_loader-1.0.175 → code_loader-1.0.176}/LICENSE +0 -0
  9. {code_loader-1.0.175 → code_loader-1.0.176}/README.md +0 -0
  10. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/__init__.py +0 -0
  11. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/contract/__init__.py +0 -0
  12. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/contract/enums.py +0 -0
  13. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/contract/exceptions.py +0 -0
  14. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/contract/mapping.py +0 -0
  15. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/contract/responsedataclasses.py +0 -0
  16. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/contract/visualizer_classes.py +0 -0
  17. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/default_losses.py +0 -0
  18. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/default_metrics.py +0 -0
  19. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/__init__.py +0 -0
  20. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/api.py +0 -0
  21. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/cli_config_utils.py +0 -0
  22. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/client.py +0 -0
  23. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/epoch.py +0 -0
  24. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/experiment.py +0 -0
  25. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/experiment_context.py +0 -0
  26. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/types.py +0 -0
  27. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/utils.py +0 -0
  28. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/experiment_api/workingspace_config_utils.py +0 -0
  29. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/inner_leap_binder/__init__.py +0 -0
  30. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/mixpanel_tracker.py +0 -0
  31. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/plot_functions/__init__.py +0 -0
  32. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/plot_functions/plot_functions.py +0 -0
  33. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/plot_functions/visualize.py +0 -0
  34. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/utils.py +0 -0
  35. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/visualizers/__init__.py +0 -0
  36. {code_loader-1.0.175 → code_loader-1.0.176}/code_loader/visualizers/default_visualizers.py +0 -0
@@ -1,7 +1,8 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.175
3
+ Version: 1.0.176
4
4
  Summary:
5
+ Home-page: https://github.com/tensorleap/code-loader
5
6
  License: MIT
6
7
  Author: dorhar
7
8
  Author-email: doron.harnoy@tensorleap.ai
@@ -19,7 +20,6 @@ Requires-Dist: numpy (>=2.3.2,<3.0.0) ; python_version >= "3.11" and python_vers
19
20
  Requires-Dist: psutil (>=5.9.5,<6.0.0)
20
21
  Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
21
22
  Requires-Dist: requests (>=2.32.3,<3.0.0)
22
- Project-URL: Homepage, https://github.com/tensorleap/code-loader
23
23
  Project-URL: Repository, https://github.com/tensorleap/code-loader
24
24
  Description-Content-Type: text/markdown
25
25
 
@@ -171,6 +171,12 @@ class MetricHandler:
171
171
  function: Union[CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs]
172
172
 
173
173
 
174
+ @dataclass
175
+ class InstanceMetricHandler:
176
+ metric_handler_data: MetricHandlerData
177
+ function: CustomMultipleReturnCallableInterfaceMultiArgs
178
+
179
+
174
180
  @dataclass
175
181
  class RawInputsForHeatmap:
176
182
  raw_input_by_vizualizer_arg_name: Dict[str, npt.NDArray[np.float32]]
@@ -258,6 +264,7 @@ class DatasetIntegrationSetup:
258
264
  prediction_types: List[PredictionTypeHandler] = field(default_factory=list)
259
265
  custom_loss_handlers: List[CustomLossHandler] = field(default_factory=list)
260
266
  metrics: List[MetricHandler] = field(default_factory=list)
267
+ instance_metrics: List[InstanceMetricHandler] = field(default_factory=list)
261
268
  custom_layers: Dict[str, CustomLayerHandler] = field(default_factory=dict)
262
269
  custom_latent_space: Optional[CustomLatentSpaceHandler] = None
263
270
 
@@ -12,7 +12,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
12
12
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
13
13
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
14
14
  RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData, SamplePreprocessResponse, \
15
- ElementInstanceMasksHandler, InstanceCallableInterface, CustomLatentSpaceHandler
15
+ ElementInstanceMasksHandler, InstanceCallableInterface, CustomLatentSpaceHandler, InstanceMetricHandler
16
16
  from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, DatasetMetadataType
17
17
  from code_loader.contract.mapping import NodeConnection, NodeMapping, NodeMappingType
18
18
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload, LeapAnalysisConfiguration
@@ -335,6 +335,47 @@ class LeapBinder:
335
335
  metric_handler_data = MetricHandlerData(name, regular_arg_names, direction, compute_insights)
336
336
  self.setup_container.metrics.append(MetricHandler(metric_handler_data, function))
337
337
 
338
+ def add_custom_instance_metric(self,
339
+ function: CustomMultipleReturnCallableInterfaceMultiArgs,
340
+ name: str,
341
+ direction: Optional[
342
+ Union[MetricDirection, Dict[str, MetricDirection]]] = MetricDirection.Downward,
343
+ compute_insights: Optional[Union[bool, Dict[str, bool]]] = None) -> None:
344
+ """
345
+ Add a custom metric to the setup.
346
+
347
+ Args:
348
+ function (CustomMultipleReturnCallableInterfaceMultiArgs): The custom metric function returning Dict[int, np.ndarray].
349
+ name (str): The name of the custom metric.
350
+ direction (Optional[Union[MetricDirection, Dict[str, MetricDirection]]]): The direction of the metric, either
351
+ MetricDirection.Upward or MetricDirection.Downward, in case custom metric return a dictionary of metrics we can
352
+ supply a dictionary of directions correspondingly.
353
+ - MetricDirection.Upward: Indicates that higher values of the metric are better and should be maximized.
354
+ - MetricDirection.Downward: Indicates that lower values of the metric are better and should be minimized.
355
+ compute_insights (Union[bool, Dict[str, bool]]): Whether to compute insights or not. in case custom metric
356
+ return a dictionary of metrics we can supply a dictionary of values correspondingly
357
+
358
+
359
+
360
+ Example:
361
+ def custom_metric_function(y_true, y_pred):
362
+ return np.mean(np.abs(y_true - y_pred))
363
+
364
+ leap_binder.add_custom_metric(custom_metric_function, name='custom_metric', direction=MetricDirection.Downward)
365
+ """
366
+
367
+ regular_arg_names = inspect.getfullargspec(function)[0]
368
+ preprocess_response_arg_name = None
369
+ for arg_name, arg_type in inspect.getfullargspec(function).annotations.items():
370
+ if arg_type == SamplePreprocessResponse:
371
+ if preprocess_response_arg_name is not None:
372
+ raise Exception("only one argument can be of type SamplePreprocessResponse")
373
+ preprocess_response_arg_name = arg_name
374
+ regular_arg_names.remove(arg_name)
375
+
376
+ metric_handler_data = MetricHandlerData(name, regular_arg_names, direction, compute_insights)
377
+ self.setup_container.instance_metrics.append(InstanceMetricHandler(metric_handler_data, function))
378
+
338
379
  def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1, prediction_index: Optional[int]=None) -> None:
339
380
  """
340
381
  Add prediction labels to the setup.
@@ -877,6 +877,215 @@ def tensorleap_custom_metric(name: str,
877
877
  return decorating_function
878
878
 
879
879
 
880
+ def tensorleap_custom_instances_metric(name: str,
881
+ direction: Union[MetricDirection, Dict[str, MetricDirection]] = _UNSET,
882
+ compute_insights: Optional[Union[bool, Dict[str, bool]]] = None,
883
+ connects_to=None):
884
+ name_to_unique_name = defaultdict(set)
885
+
886
+ def decorating_function(
887
+ user_function: CustomMultipleReturnCallableInterfaceMultiArgs):
888
+ nonlocal direction
889
+
890
+ direction_was_provided = direction is not _UNSET
891
+
892
+ def _validate_decorators_signature():
893
+ err_message = f"{user_function.__name__} validation failed.\n"
894
+ if not isinstance(name, str):
895
+ raise TypeError(err_message + f"`name` must be a string, got type {type(name).__name__}.")
896
+ valid_directions = {MetricDirection.Upward, MetricDirection.Downward}
897
+ if direction is _UNSET:
898
+ pass
899
+ elif isinstance(direction, MetricDirection):
900
+ if direction not in valid_directions:
901
+ raise ValueError(
902
+ err_message +
903
+ f"Invalid MetricDirection: {direction}. Must be one of {valid_directions}, "
904
+ f"got type {type(direction).__name__}."
905
+ )
906
+ elif isinstance(direction, dict):
907
+ if not all(isinstance(k, str) for k in direction.keys()):
908
+ invalid_keys = {k: type(k).__name__ for k in direction.keys() if not isinstance(k, str)}
909
+ raise TypeError(
910
+ err_message +
911
+ f"All keys in `direction` must be strings, got invalid key types: {invalid_keys}."
912
+ )
913
+ for k, v in direction.items():
914
+ if v not in valid_directions:
915
+ raise ValueError(
916
+ err_message +
917
+ f"Invalid direction for key '{k}': {v}. Must be one of {valid_directions}, "
918
+ f"got type {type(v).__name__}."
919
+ )
920
+ else:
921
+ raise TypeError(
922
+ err_message +
923
+ f"`direction` must be a MetricDirection or a Dict[str, MetricDirection], "
924
+ f"got type {type(direction).__name__}."
925
+ )
926
+ if compute_insights is not None:
927
+ if not isinstance(compute_insights, (bool, dict)):
928
+ raise TypeError(
929
+ err_message +
930
+ f"`compute_insights` must be a bool or a Dict[str, bool], "
931
+ f"got type {type(compute_insights).__name__}."
932
+ )
933
+ if isinstance(compute_insights, dict):
934
+ if not all(isinstance(k, str) for k in compute_insights.keys()):
935
+ invalid_keys = {k: type(k).__name__ for k in compute_insights.keys() if not isinstance(k, str)}
936
+ raise TypeError(
937
+ err_message +
938
+ f"All keys in `compute_insights` must be strings, got invalid key types: {invalid_keys}."
939
+ )
940
+ for k, v in compute_insights.items():
941
+ if not isinstance(v, bool):
942
+ raise TypeError(
943
+ err_message +
944
+ f"Invalid type for compute_insights['{k}']: expected bool, got type {type(v).__name__}."
945
+ )
946
+ if connects_to is not None:
947
+ valid_types = (str, list, tuple, set)
948
+ if not isinstance(connects_to, valid_types):
949
+ raise TypeError(
950
+ err_message +
951
+ f"`connects_to` must be one of {valid_types}, got type {type(connects_to).__name__}."
952
+ )
953
+ if isinstance(connects_to, (list, tuple, set)):
954
+ invalid_elems = [f"{type(e).__name__}" for e in connects_to if not isinstance(e, str)]
955
+ if invalid_elems:
956
+ raise TypeError(
957
+ err_message +
958
+ f"All elements in `connects_to` must be strings, "
959
+ f"but found element types: {invalid_elems}."
960
+ )
961
+
962
+ _validate_decorators_signature()
963
+
964
+ for metric_handler in leap_binder.setup_container.instance_metrics:
965
+ if metric_handler.metric_handler_data.name == name:
966
+ raise Exception(f'Metric with name {name} already exists. '
967
+ f'Please choose another')
968
+
969
+ def _validate_input_args(*args, **kwargs) -> None:
970
+ assert len(args) + len(kwargs) > 0, (
971
+ f"{user_function.__name__}() validation failed: "
972
+ f"Expected at least one positional|key-word argument of type np.ndarray, "
973
+ f"but received none. "
974
+ f"Correct usage example: tensorleap_custom_metric(input_array: np.ndarray, ...)"
975
+ )
976
+ for i, arg in enumerate(args):
977
+ assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
978
+ f'{user_function.__name__}() validation failed: '
979
+ f'Argument #{i} should be a numpy array. Got {type(arg)}.')
980
+ if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
981
+ assert arg.shape[0] == leap_binder.batch_size_to_validate, \
982
+ (f'{user_function.__name__}() validation failed: Argument #{i} '
983
+ f'first dim should be as the batch size. Got {arg.shape[0]} '
984
+ f'instead of {leap_binder.batch_size_to_validate}')
985
+
986
+ for _arg_name, arg in kwargs.items():
987
+ assert isinstance(arg, (np.ndarray, SamplePreprocessResponse)), (
988
+ f'{user_function.__name__}() validation failed: '
989
+ f'Argument {_arg_name} should be a numpy array. Got {type(arg)}.')
990
+ if leap_binder.batch_size_to_validate and isinstance(arg, np.ndarray):
991
+ assert arg.shape[0] == leap_binder.batch_size_to_validate, \
992
+ (f'{user_function.__name__}() validation failed: Argument {_arg_name} '
993
+ f'first dim should be as the batch size. Got {arg.shape[0]} '
994
+ f'instead of {leap_binder.batch_size_to_validate}')
995
+
996
+ def _validate_result(result) -> None:
997
+ nonlocal direction
998
+ supported_types_message = (f'{user_function.__name__}() validation failed: '
999
+ f'{user_function.__name__}() has returned unsupported type.\nSupported type is Dict[List[NDArray[np.float32]]]')
1000
+
1001
+ def _validate_single_metric(single_metric_result, key=None):
1002
+ assert isinstance(single_metric_result,
1003
+ np.ndarray), f'{supported_types_message}\nGot {type(single_metric_result)}.'
1004
+ assert len(single_metric_result.shape) == 1, (f'{user_function.__name__}() validation failed: '
1005
+ f'The return shape should be 1D. Got {len(single_metric_result.shape)}D.')
1006
+
1007
+ if leap_binder.batch_size_to_validate:
1008
+ assert len(single_metric_result) == leap_binder.batch_size_to_validate, \
1009
+ f'{user_function.__name__}() validation failed: The return len {f"of srt{key} value" if key is not None else ""} should be as the batch size.'
1010
+
1011
+ assert isinstance(result, dict), f'{supported_types_message}\nGot {type(result)}.'
1012
+
1013
+ result_keys = set(result.keys())
1014
+ for key, value in result.items():
1015
+ _validate_single_metric(value, key)
1016
+
1017
+ assert isinstance(key, int), \
1018
+ (f'{user_function.__name__}() validation failed: '
1019
+ f'Keys in the return dict should be of type int (instance number). Got {type(key)}.')
1020
+
1021
+ leap_binder.setup_container.instance_metrics[-1].metric_handler_data.direction = direction
1022
+
1023
+ @functools.wraps(user_function)
1024
+ def inner_without_validate(*args, **kwargs):
1025
+ global _called_from_inside_tl_decorator
1026
+ _called_from_inside_tl_decorator += 1
1027
+
1028
+ try:
1029
+ result = user_function(*args, **kwargs)
1030
+ finally:
1031
+ _called_from_inside_tl_decorator -= 1
1032
+
1033
+ return result
1034
+
1035
+ try:
1036
+ inner_without_validate.__signature__ = inspect.signature(user_function)
1037
+ except (TypeError, ValueError):
1038
+ pass
1039
+
1040
+ leap_binder.add_custom_instance_metric(inner_without_validate, name, direction, compute_insights)
1041
+
1042
+ if connects_to is not None:
1043
+ arg_names = leap_binder.setup_container.instance_metrics[-1].metric_handler_data.arg_names
1044
+ _add_mapping_connections(connects_to, arg_names, NodeMappingType.Metric, name)
1045
+
1046
+ def inner(*args, **kwargs):
1047
+ if not _call_from_tl_platform:
1048
+ set_current('tensorleap_custom_instances_metric')
1049
+ _validate_input_args(*args, **kwargs)
1050
+
1051
+ result = inner_without_validate(*args, **kwargs)
1052
+
1053
+ _validate_result(result)
1054
+ if not _call_from_tl_platform:
1055
+ update_env_params_func("tensorleap_custom_instances_metric", "v")
1056
+ return result
1057
+
1058
+ def mapping_inner(*args, **kwargs):
1059
+ user_unique_name = mapping_inner.name
1060
+ if 'user_unique_name' in kwargs:
1061
+ user_unique_name = kwargs['user_unique_name']
1062
+
1063
+ ordered_connections = [kwargs[n] for n in mapping_inner.arg_names if n in kwargs]
1064
+ ordered_connections = list(args) + ordered_connections
1065
+
1066
+ if user_unique_name in name_to_unique_name[mapping_inner.name]:
1067
+ user_unique_name = f'{user_unique_name}_{len(name_to_unique_name[mapping_inner.name])}'
1068
+ name_to_unique_name[mapping_inner.name].add(user_unique_name)
1069
+
1070
+ _add_mapping_connection(user_unique_name, ordered_connections, mapping_inner.arg_names,
1071
+ mapping_inner.name, NodeMappingType.Metric)
1072
+
1073
+ return None
1074
+
1075
+ mapping_inner.arg_names = leap_binder.setup_container.instance_metrics[-1].metric_handler_data.arg_names
1076
+ mapping_inner.name = name
1077
+
1078
+ def final_inner(*args, **kwargs):
1079
+ if os.environ.get(mapping_runtime_mode_env_var_mame):
1080
+ return mapping_inner(*args, **kwargs)
1081
+ else:
1082
+ return inner(*args, **kwargs)
1083
+
1084
+ return final_inner
1085
+
1086
+ return decorating_function
1087
+
1088
+
880
1089
  def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
881
1090
  heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
882
1091
  connects_to=None):
@@ -16,7 +16,7 @@ from code_loader.contract.datasetclasses import DatasetSample, DatasetBaseHandle
16
16
  PreprocessResponse, VisualizerHandler, LeapData, \
17
17
  PredictionTypeHandler, MetadataHandler, CustomLayerHandler, MetricHandler, VisualizerHandlerData, MetricHandlerData, \
18
18
  MetricCallableReturnType, CustomLossHandlerData, CustomLossHandler, RawInputsForHeatmap, SamplePreprocessResponse, \
19
- ElementInstance, custom_latent_space_attribute, DatasetIntegrationSetup
19
+ ElementInstance, custom_latent_space_attribute, DatasetIntegrationSetup, InstanceMetricHandler
20
20
  from code_loader.contract.enums import DataStateEnum, TestingSectionEnum, DataStateType, DatasetMetadataType
21
21
  from code_loader.contract.exceptions import DatasetScriptException
22
22
  from code_loader.contract.responsedataclasses import DatasetIntegParseResult, DatasetTestResultPayload, \
@@ -109,6 +109,24 @@ class LeapLoader(LeapLoaderBase):
109
109
  for metric_handler in setup.metrics
110
110
  }
111
111
 
112
+ @lru_cache()
113
+ def instance_metric_by_name(self) -> Dict[str, MetricHandlerData]:
114
+ self.exec_script()
115
+ setup = global_leap_binder.setup_container
116
+ return {
117
+ handler.metric_handler_data.name: handler.metric_handler_data
118
+ for handler in setup.instance_metrics
119
+ }
120
+
121
+ @lru_cache()
122
+ def _instance_metric_handler_by_name(self) -> Dict[str, InstanceMetricHandler]:
123
+ self.exec_script()
124
+ setup = global_leap_binder.setup_container
125
+ return {
126
+ handler.metric_handler_data.name: handler
127
+ for handler in setup.instance_metrics
128
+ }
129
+
112
130
  @lru_cache()
113
131
  def visualizer_by_name(self) -> Dict[str, VisualizerHandlerData]:
114
132
  self.exec_script()
@@ -293,6 +311,20 @@ class LeapLoader(LeapLoaderBase):
293
311
 
294
312
  return metric_handler.function(**input_tensors_by_arg_name)
295
313
 
314
+ def run_instance_metric(self, metric_name: str, sample_ids: np.array, state: DataStateEnum,
315
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
316
+ ) -> Dict[int, npt.NDArray[np.float32]]:
317
+ self._preprocess_result()
318
+
319
+ handler = self._instance_metric_handler_by_name()[metric_name]
320
+ preprocess_response_arg_name = self._get_preprocess_response_arg_name(handler.function)
321
+
322
+ if preprocess_response_arg_name is not None:
323
+ input_tensors_by_arg_name[preprocess_response_arg_name] = SamplePreprocessResponse(
324
+ sample_ids, self._preprocess_result()[state])
325
+
326
+ return handler.function(**input_tensors_by_arg_name)
327
+
296
328
  @staticmethod
297
329
  def _get_preprocess_response_arg_name(
298
330
  func: Callable) -> Optional[str]:
@@ -44,6 +44,10 @@ class LeapLoaderBase:
44
44
  def metric_by_name(self) -> Dict[str, MetricHandlerData]:
45
45
  pass
46
46
 
47
+ @abstractmethod
48
+ def instance_metric_by_name(self) -> Dict[str, MetricHandlerData]:
49
+ pass
50
+
47
51
  @abstractmethod
48
52
  def visualizer_by_name(self) -> Dict[str, VisualizerHandlerData]:
49
53
  pass
@@ -100,6 +104,12 @@ class LeapLoaderBase:
100
104
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]) -> MetricCallableReturnType:
101
105
  pass
102
106
 
107
+ @abstractmethod
108
+ def run_instance_metric(self, metric_name: str, sample_ids: np.array, state: DataStateEnum,
109
+ input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]
110
+ ) -> Dict[int, npt.NDArray[np.float32]]:
111
+ pass
112
+
103
113
  @abstractmethod
104
114
  def run_custom_loss(self, custom_loss_name: str, sample_ids: np.array, state: DataStateEnum,
105
115
  input_tensors_by_arg_name: Dict[str, npt.NDArray[np.float32]]):
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
- version = "1.0.175"
3
+ version = "1.0.176"
4
4
  description = ""
5
5
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
6
6
  license = "MIT"
File without changes
File without changes