code-loader 1.0.78__tar.gz → 1.0.80__tar.gz

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (31)
  1. {code_loader-1.0.78 → code_loader-1.0.80}/PKG-INFO +1 -1
  2. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/contract/datasetclasses.py +1 -1
  3. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/inner_leap_binder/leapbinder.py +29 -5
  4. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/inner_leap_binder/leapbinder_decorators.py +12 -2
  5. {code_loader-1.0.78 → code_loader-1.0.80}/pyproject.toml +1 -1
  6. {code_loader-1.0.78 → code_loader-1.0.80}/LICENSE +0 -0
  7. {code_loader-1.0.78 → code_loader-1.0.80}/README.md +0 -0
  8. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/__init__.py +0 -0
  9. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/contract/__init__.py +0 -0
  10. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/contract/enums.py +0 -0
  11. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/contract/exceptions.py +0 -0
  12. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/contract/responsedataclasses.py +0 -0
  13. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/contract/visualizer_classes.py +0 -0
  14. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/default_losses.py +0 -0
  15. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/default_metrics.py +0 -0
  16. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/__init__.py +0 -0
  17. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/api.py +0 -0
  18. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/cli_config_utils.py +0 -0
  19. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/client.py +0 -0
  20. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/epoch.py +0 -0
  21. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/experiment.py +0 -0
  22. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/experiment_context.py +0 -0
  23. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/types.py +0 -0
  24. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/utils.py +0 -0
  25. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/experiment_api/workingspace_config_utils.py +0 -0
  26. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/inner_leap_binder/__init__.py +0 -0
  27. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/leaploader.py +0 -0
  28. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/leaploaderbase.py +0 -0
  29. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/utils.py +0 -0
  30. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/visualizers/__init__.py +0 -0
  31. {code_loader-1.0.78 → code_loader-1.0.80}/code_loader/visualizers/default_visualizers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.78
3
+ Version: 1.0.80
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -142,7 +142,7 @@ class MetricHandlerData:
142
142
  name: str
143
143
  arg_names: List[str]
144
144
  direction: Union[None, MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward
145
- compute_insights: Union[bool, Dict[str, bool]] = True
145
+ compute_insights: Optional[Union[bool, Dict[str, bool]]] = None
146
146
 
147
147
 
148
148
  @dataclass
@@ -103,7 +103,17 @@ class LeapBinder:
103
103
  heatmap_visualizer=image_resize_heatmap_visualizer
104
104
  )
105
105
  """
106
- arg_names = inspect.getfullargspec(function)[0]
106
+
107
+ regular_arg_names = inspect.getfullargspec(function)[0]
108
+ preprocess_response_arg_name = None
109
+ for arg_name, arg_type in inspect.getfullargspec(function).annotations.items():
110
+ if arg_type == SamplePreprocessResponse:
111
+ if preprocess_response_arg_name is not None:
112
+ raise Exception("only one argument can be of type SamplePreprocessResponse")
113
+ preprocess_response_arg_name = arg_name
114
+ regular_arg_names.remove(arg_name)
115
+
116
+ arg_names = regular_arg_names
107
117
  if heatmap_visualizer:
108
118
  visualizer_arg_names_set = set(arg_names)
109
119
  heatmap_visualizer_inspection = inspect.getfullargspec(heatmap_visualizer)
@@ -259,7 +269,7 @@ class LeapBinder:
259
269
  name: str,
260
270
  direction: Optional[
261
271
  Union[MetricDirection, Dict[str, MetricDirection]]] = MetricDirection.Downward,
262
- compute_insights: Union[bool, Dict[str, bool]] = True) -> None:
272
+ compute_insights: Optional[Union[bool, Dict[str, bool]]] = None) -> None:
263
273
  """
264
274
  Add a custom metric to the setup.
265
275
 
@@ -282,8 +292,17 @@ class LeapBinder:
282
292
 
283
293
  leap_binder.add_custom_metric(custom_metric_function, name='custom_metric', direction=MetricDirection.Downward)
284
294
  """
285
- arg_names = inspect.getfullargspec(function)[0]
286
- metric_handler_data = MetricHandlerData(name, arg_names, direction, compute_insights)
295
+
296
+ regular_arg_names = inspect.getfullargspec(function)[0]
297
+ preprocess_response_arg_name = None
298
+ for arg_name, arg_type in inspect.getfullargspec(function).annotations.items():
299
+ if arg_type == SamplePreprocessResponse:
300
+ if preprocess_response_arg_name is not None:
301
+ raise Exception("only one argument can be of type SamplePreprocessResponse")
302
+ preprocess_response_arg_name = arg_name
303
+ regular_arg_names.remove(arg_name)
304
+
305
+ metric_handler_data = MetricHandlerData(name, regular_arg_names, direction, compute_insights)
287
306
  self.setup_container.metrics.append(MetricHandler(metric_handler_data, function))
288
307
 
289
308
  def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
@@ -448,7 +467,12 @@ class LeapBinder:
448
467
  if DataStateEnum.validation not in preprocess_result_dict:
449
468
  raise Exception("Validation data is required")
450
469
 
451
- return preprocess_result_dict
470
+ preprocess_result_dict_in_correct_order = {}
471
+ for state_enum in DataStateEnum:
472
+ if state_enum in preprocess_result_dict:
473
+ preprocess_result_dict_in_correct_order[state_enum] = preprocess_result_dict[state_enum]
474
+
475
+ return preprocess_result_dict_in_correct_order
452
476
 
453
477
  def get_preprocess_unlabeled_result(self) -> Optional[PreprocessResponse]:
454
478
  unlabeled_preprocess = self.setup_container.unlabeled_data_preprocess
@@ -17,7 +17,7 @@ from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, Le
17
17
 
18
18
  def tensorleap_custom_metric(name: str,
19
19
  direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward,
20
- compute_insights: Union[bool, Dict[str, bool]] = True):
20
+ compute_insights: Optional[Union[bool, Dict[str, bool]]] = None):
21
21
  def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
22
22
  CustomMultipleReturnCallableInterfaceMultiArgs,
23
23
  ConfusionMatrixCallableInterfaceMultiArgs]):
@@ -85,7 +85,11 @@ def tensorleap_custom_metric(name: str,
85
85
  (f'tensorleap_custom_metric validation failed: '
86
86
  f'Keys in the direction mapping should be part of result keys. Got key {direction_key}.')
87
87
 
88
- if isinstance(compute_insights, dict):
88
+ if compute_insights is not None:
89
+ assert isinstance(compute_insights, dict), \
90
+ (f'tensorleap_custom_metric validation failed: '
91
+ f'compute_insights should be dict if using the dict results. Got {type(compute_insights)}.')
92
+
89
93
  for ci_key in compute_insights:
90
94
  assert ci_key in result, \
91
95
  (f'tensorleap_custom_metric validation failed: '
@@ -94,6 +98,12 @@ def tensorleap_custom_metric(name: str,
94
98
  else:
95
99
  _validate_single_metric(result)
96
100
 
101
+ if compute_insights is not None:
102
+ assert isinstance(compute_insights, bool), \
103
+ (f'tensorleap_custom_metric validation failed: '
104
+ f'compute_insights should be boolean. Got {type(compute_insights)}.')
105
+
106
+
97
107
  def inner(*args, **kwargs):
98
108
  _validate_input_args(*args, **kwargs)
99
109
  result = user_function(*args, **kwargs)
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "code-loader"
3
- version = "1.0.78"
3
+ version = "1.0.80"
4
4
  description = ""
5
5
  authors = ["dorhar <doron.harnoy@tensorleap.ai>"]
6
6
  license = "MIT"
File without changes
File without changes