code-loader 1.0.77.dev4__py3-none-any.whl → 1.0.77.dev30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,8 +4,7 @@ import re
4
4
  import numpy as np
5
5
  import numpy.typing as npt
6
6
 
7
- from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, \
8
- MetricDirection, DatasetMetadataType
7
+ from code_loader.contract.enums import DataStateType, DataStateEnum, LeapDataType, ConfusionMatrixValue, MetricDirection
9
8
  from code_loader.contract.visualizer_classes import LeapImage, LeapText, LeapGraph, LeapHorizontalBar, \
10
9
  LeapTextMask, LeapImageMask, LeapImageWithBBox, LeapImageWithHeatmap
11
10
 
@@ -197,7 +196,6 @@ class GroundTruthHandler(DatasetBaseHandler):
197
196
  class MetadataHandler:
198
197
  name: str
199
198
  function: MetadataSectionCallableInterface
200
- metadata_type: Optional[Union[DatasetMetadataType, Dict[str, DatasetMetadataType]]] = None
201
199
 
202
200
 
203
201
  @dataclass
@@ -234,7 +232,6 @@ class DatasetIntegrationSetup:
234
232
  class DatasetSample:
235
233
  inputs: Dict[str, npt.NDArray[np.float32]]
236
234
  gt: Optional[Dict[str, npt.NDArray[np.float32]]]
237
- metadata: Dict[str, Union[Optional[str], int, bool, Optional[float]]]
238
- metadata_is_none: Dict[str, bool]
235
+ metadata: Dict[str, Union[str, int, bool, float]]
239
236
  index: Union[int, str]
240
237
  state: DataStateEnum
@@ -0,0 +1,34 @@
1
+
2
+ from typing import List, Dict, Union, Callable
3
+ from enum import Enum, IntEnum
4
+ from dataclasses import dataclass, field
5
+ from typing import Union
6
+
7
class FunctionType(IntEnum):
    """Kinds of integration functions that can appear as nodes in a mapping."""

    gt = 0
    visualizer = 1
    input = 2
    metric = 3
    loss = 4
    model_input = 5
    model_output = 6


@dataclass(frozen=True)
class TLFunction:
    """Hashable tag identifying one integration function by kind and name.

    Frozen, so instances are hashable and can be used as dict keys
    (``MappingList`` is keyed by them).
    """

    FuncType: FunctionType
    FuncName: str


# Each element maps a sink TLFunction to ``{node_name: connected items}``.
# NOTE(review): callers iterate each value as a sequence of callables/ints
# (set_model_inputs stores one-element tuples), so the inner value type below
# looks too narrow — confirm and widen to a Sequence[...] type.
MappingList = List[Dict[TLFunction, Dict[str, Union[Callable, int]]]]


def _placeholder(func_type: FunctionType, idx) -> Callable[[], None]:
    """Return a no-op function tagged with ``tl_func = TLFunction(func_type, idx)``."""
    def dummy():
        return None
    dummy.tl_func = TLFunction(func_type, idx)  # type: ignore[attr-defined]
    return dummy


def leap_input(idx):
    """Return a no-op function tagged as model input ``idx`` (FunctionType.model_input)."""
    return _placeholder(FunctionType.model_input, idx)


def leap_output(idx):
    """Return a no-op function tagged as model output ``idx`` (FunctionType.model_output)."""
    return _placeholder(FunctionType.model_output, idx)
34
+
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Callable, List, Optional, Dict, Any, Type, Union
2
+ from typing import Callable, List, Optional, Dict, Any, Type, Union, Tuple
3
3
 
4
4
  import numpy as np
5
5
  import numpy.typing as npt
@@ -11,7 +11,7 @@ from code_loader.contract.datasetclasses import SectionCallableInterface, InputH
11
11
  CustomCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, LeapData, \
12
12
  CustomMultipleReturnCallableInterfaceMultiArgs, DatasetBaseHandler, custom_latent_space_attribute, \
13
13
  RawInputsForHeatmap, VisualizerHandlerData, MetricHandlerData, CustomLossHandlerData, SamplePreprocessResponse
14
- from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection, DatasetMetadataType
14
+ from code_loader.contract.enums import LeapDataType, DataStateEnum, DataStateType, MetricDirection
15
15
  from code_loader.contract.responsedataclasses import DatasetTestResultPayload
16
16
  from code_loader.contract.visualizer_classes import map_leap_data_type_to_visualizer_class
17
17
  from code_loader.default_losses import loss_name_to_function
@@ -21,6 +21,7 @@ from code_loader.visualizers.default_visualizers import DefaultVisualizer, \
21
21
  default_graph_visualizer, \
22
22
  default_image_visualizer, default_horizontal_bar_visualizer, default_word_visualizer, \
23
23
  default_image_mask_visualizer, default_text_mask_visualizer, default_raw_data_visualizer
24
+ from code_loader.inner_leap_binder.inner_classes import FunctionType, TLFunction, MappingList, leap_output, leap_input
24
25
 
25
26
 
26
27
  class LeapBinder:
@@ -37,6 +38,7 @@ class LeapBinder:
37
38
  def __init__(self) -> None:
38
39
  self.setup_container = DatasetIntegrationSetup()
39
40
  self.cache_container: Dict[str, Any] = {"word_to_index": {}}
41
+ self._mappings: MappingList = list()
40
42
  self._visualizer_names: List[str] = list()
41
43
  self._encoder_names: List[str] = list()
42
44
  self._extend_with_default_visualizers()
@@ -72,7 +74,8 @@ class LeapBinder:
72
74
  def set_visualizer(self, function: VisualizerCallableInterface,
73
75
  name: str,
74
76
  visualizer_type: LeapDataType,
75
- heatmap_visualizer: Optional[Callable[..., npt.NDArray[np.float32]]] = None) -> None:
77
+ heatmap_visualizer: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
78
+ connects_to = None) -> None: #TODO add types
76
79
  """
77
80
  Set a visualizer for a specific data type.
78
81
 
@@ -139,10 +142,11 @@ class LeapBinder:
139
142
  raise Exception(
140
143
  f'The return type of function {function.__name__} is invalid. current return type: {return_type}, '
141
144
  f'should be {expected_return_type}')
142
-
143
145
  self.setup_container.visualizers.append(
144
146
  VisualizerHandler(VisualizerHandlerData(name, visualizer_type, arg_names), function, heatmap_visualizer))
145
147
  self._visualizer_names.append(name)
148
+ if connects_to is not None:
149
+ self._mappings.append({TLFunction(FunctionType.visualizer, name): connects_to})
146
150
 
147
151
  def set_preprocess(self, function: Callable[[], List[PreprocessResponse]]) -> None:
148
152
  """
@@ -259,7 +263,8 @@ class LeapBinder:
259
263
  name: str,
260
264
  direction: Optional[
261
265
  Union[MetricDirection, Dict[str, MetricDirection]]] = MetricDirection.Downward,
262
- compute_insights: Union[bool, Dict[str, bool]] = True) -> None:
266
+ compute_insights: Union[bool, Dict[str, bool]] = True,
267
+ connects_to = None) -> None: #TODO TOM add types
263
268
  """
264
269
  Add a custom metric to the setup.
265
270
 
@@ -285,6 +290,8 @@ class LeapBinder:
285
290
  arg_names = inspect.getfullargspec(function)[0]
286
291
  metric_handler_data = MetricHandlerData(name, arg_names, direction, compute_insights)
287
292
  self.setup_container.metrics.append(MetricHandler(metric_handler_data, function))
293
+ if connects_to is not None:
294
+ self._mappings.append({TLFunction(FunctionType.metric, name): connects_to})
288
295
 
289
296
  def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
290
297
  """
@@ -333,8 +340,7 @@ class LeapBinder:
333
340
 
334
341
  self._encoder_names.append(name)
335
342
 
336
- def set_metadata(self, function: MetadataSectionCallableInterface, name: str,
337
- metadata_type: Optional[Union[DatasetMetadataType, Dict[str, DatasetMetadataType]]] = None) -> None:
343
+ def set_metadata(self, function: MetadataSectionCallableInterface, name: str) -> None:
338
344
  """
339
345
  Set the metadata handler function. This function is used for measuring and analyzing external variable values per sample, which is recommended for analysis within the Tensorleap platform.
340
346
 
@@ -369,7 +375,7 @@ class LeapBinder:
369
375
  leap_binder.set_metadata(metadata_handler_index, name='metadata_index')
370
376
  leap_binder.set_metadata(metadata_handler_image_mean, name='metadata_image_mean')
371
377
  """
372
- self.setup_container.metadata.append(MetadataHandler(name, function, metadata_type))
378
+ self.setup_container.metadata.append(MetadataHandler(name, function))
373
379
 
374
380
  def set_custom_layer(self, custom_layer: Type[Any], name: str, inspect_layer: bool = False,
375
381
  kernel_index: Optional[int] = None, use_custom_latent_space: bool = False) -> None:
@@ -466,55 +472,23 @@ class LeapBinder:
466
472
  @staticmethod
467
473
  def check_handler(
468
474
  preprocess_response: PreprocessResponse, test_result: List[DatasetTestResultPayload],
469
- dataset_base_handler: Union[DatasetBaseHandler, MetadataHandler], state: DataStateEnum) -> List[DatasetTestResultPayload]:
475
+ dataset_base_handler: Union[DatasetBaseHandler, MetadataHandler]) -> List[DatasetTestResultPayload]:
470
476
  assert preprocess_response.sample_ids is not None
471
477
  raw_result = dataset_base_handler.function(preprocess_response.sample_ids[0], preprocess_response)
472
478
  handler_type = 'metadata' if isinstance(dataset_base_handler, MetadataHandler) else None
473
- if isinstance(dataset_base_handler, MetadataHandler):
474
- if isinstance(raw_result, dict):
475
- metadata_test_result_payloads = [
476
- DatasetTestResultPayload(f'{dataset_base_handler.name}_{single_metadata_name}')
477
- for single_metadata_name, single_metadata_result in raw_result.items()
478
- ]
479
- for i, (single_metadata_name, single_metadata_result) in enumerate(raw_result.items()):
480
- metadata_test_result = metadata_test_result_payloads[i]
481
-
482
- metadata_type = None
483
- if single_metadata_result is None:
484
- if state != DataStateEnum.training and test_result[i].name == f'{dataset_base_handler.name}_{single_metadata_name}':
485
- metadata_test_result_payloads[i] = test_result[i]
486
- continue
487
-
488
- if dataset_base_handler.metadata_type is None:
489
- raise Exception(f"Metadata {single_metadata_name} is None and no metadata type is provided")
490
- elif isinstance(dataset_base_handler.metadata_type, dict):
491
- if single_metadata_name not in dataset_base_handler.metadata_type:
492
- raise Exception(f"Metadata {single_metadata_name} is None and no metadata type is provided")
493
- metadata_type = dataset_base_handler.metadata_type[single_metadata_name]
494
- else:
495
- raise Exception(f"Metadata {single_metadata_name} is None and no metadata type is provided")
496
-
497
- result_shape = get_shape(single_metadata_result)
498
- metadata_test_result.shape = result_shape
499
- metadata_test_result.raw_result = (
500
- single_metadata_result) if single_metadata_result is not None else metadata_type
501
- metadata_test_result.handler_type = handler_type
502
- test_result = metadata_test_result_payloads
503
- else:
504
- if raw_result is None:
505
- if state != DataStateEnum.training:
506
- return test_result
507
-
508
- if dataset_base_handler.metadata_type is None:
509
- raise Exception(f"Metadata {dataset_base_handler.name} is None and no metadata type is provided")
510
- elif isinstance(dataset_base_handler.metadata_type, dict):
511
- raise Exception(f"Metadata {dataset_base_handler.name} is None and no metadata type is provided")
512
- metadata_type = dataset_base_handler.metadata_type
513
-
514
- result_shape = get_shape(raw_result)
515
- test_result[0].shape = result_shape
516
- test_result[0].raw_result = raw_result if raw_result is not None else metadata_type
517
- test_result[0].handler_type = handler_type
479
+ if isinstance(dataset_base_handler, MetadataHandler) and isinstance(raw_result, dict):
480
+ metadata_test_result_payloads = [
481
+ DatasetTestResultPayload(f'{dataset_base_handler.name}_{single_metadata_name}')
482
+ for single_metadata_name, single_metadata_result in raw_result.items()
483
+ ]
484
+ for i, (single_metadata_name, single_metadata_result) in enumerate(raw_result.items()):
485
+ metadata_test_result = metadata_test_result_payloads[i]
486
+
487
+ result_shape = get_shape(single_metadata_result)
488
+ metadata_test_result.shape = result_shape
489
+ metadata_test_result.raw_result = single_metadata_result
490
+ metadata_test_result.handler_type = handler_type
491
+ test_result = metadata_test_result_payloads
518
492
  else:
519
493
  result_shape = get_shape(raw_result)
520
494
  test_result[0].shape = result_shape
@@ -535,6 +509,9 @@ class LeapBinder:
535
509
  continue
536
510
  self.check_handler(preprocess_response, test_result, dataset_base_handler)
537
511
 
512
def set_model_inputs(self, connections: Tuple[Callable[..., Any], ...]) -> None:
    """Record which function feeds each model input.

    Appends one entry to ``self._mappings`` keyed by
    ``TLFunction(FunctionType.model_input, "model")``. Its value maps each
    input position, as a string (``"0"``, ``"1"``, ...), to a one-element
    tuple holding the function connected at that position.

    :param connections: Functions connected to the model inputs, in input order.
    """
    self._mappings.append({
        TLFunction(FunctionType.model_input, "model"): {
            str(i): (connection,) for i, connection in enumerate(connections)
        }
    })
514
+
538
515
  def check(self) -> None:
539
516
  preprocess_result = self.get_preprocess_result()
540
517
  self.check_preprocess(preprocess_result)
@@ -9,15 +9,16 @@ from code_loader.contract.datasetclasses import CustomCallableInterfaceMultiArgs
9
9
  CustomMultipleReturnCallableInterfaceMultiArgs, ConfusionMatrixCallableInterfaceMultiArgs, CustomCallableInterface, \
10
10
  VisualizerCallableInterface, MetadataSectionCallableInterface, PreprocessResponse, SectionCallableInterface, \
11
11
  ConfusionMatrixElement, SamplePreprocessResponse
12
- from code_loader.contract.enums import MetricDirection, LeapDataType, DatasetMetadataType
12
+ from code_loader.contract.enums import MetricDirection, LeapDataType
13
13
  from code_loader import leap_binder
14
14
  from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
15
15
  LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
16
-
16
+ from code_loader.inner_leap_binder.inner_classes import TLFunction, FunctionType
17
17
 
18
18
  def tensorleap_custom_metric(name: str,
19
19
  direction: Union[MetricDirection, Dict[str, MetricDirection]] = MetricDirection.Downward,
20
- compute_insights: Union[bool, Dict[str, bool]] = True):
20
+ compute_insights: Union[bool, Dict[str, bool]] = True,
21
+ connects_to = None):
21
22
  def decorating_function(user_function: Union[CustomCallableInterfaceMultiArgs,
22
23
  CustomMultipleReturnCallableInterfaceMultiArgs,
23
24
  ConfusionMatrixCallableInterfaceMultiArgs]):
@@ -26,7 +27,7 @@ def tensorleap_custom_metric(name: str,
26
27
  raise Exception(f'Metric with name {name} already exists. '
27
28
  f'Please choose another')
28
29
 
29
- leap_binder.add_custom_metric(user_function, name, direction, compute_insights)
30
+ leap_binder.add_custom_metric(user_function, name, direction, compute_insights, connects_to=connects_to)
30
31
 
31
32
  def _validate_input_args(*args, **kwargs) -> None:
32
33
  for i, arg in enumerate(args):
@@ -106,14 +107,15 @@ def tensorleap_custom_metric(name: str,
106
107
 
107
108
 
108
109
  def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
109
- heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None):
110
+ heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
111
+ connects_to = None): #TODO TOM ADD TYPES
110
112
  def decorating_function(user_function: VisualizerCallableInterface):
111
113
  for viz_handler in leap_binder.setup_container.visualizers:
112
114
  if viz_handler.visualizer_handler_data.name == name:
113
115
  raise Exception(f'Visualizer with name {name} already exists. '
114
116
  f'Please choose another')
115
117
 
116
- leap_binder.set_visualizer(user_function, name, visualizer_type, heatmap_function)
118
+ leap_binder.set_visualizer(user_function, name, visualizer_type, heatmap_function, connects_to)
117
119
 
118
120
  def _validate_input_args(*args, **kwargs):
119
121
  for i, arg in enumerate(args):
@@ -154,21 +156,21 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
154
156
  result = user_function(*args, **kwargs)
155
157
  _validate_result(result)
156
158
  return result
159
+ inner.tl_func = TLFunction(FunctionType.visualizer, name)
157
160
 
158
161
  return inner
159
162
 
160
163
  return decorating_function
161
164
 
162
165
 
163
- def tensorleap_metadata(
164
- name: str, metadata_type: Optional[Union[DatasetMetadataType, Dict[str, DatasetMetadataType]]] = None):
166
+ def tensorleap_metadata(name: str):
165
167
  def decorating_function(user_function: MetadataSectionCallableInterface):
166
168
  for metadata_handler in leap_binder.setup_container.metadata:
167
169
  if metadata_handler.name == name:
168
170
  raise Exception(f'Metadata with name {name} already exists. '
169
171
  f'Please choose another')
170
172
 
171
- leap_binder.set_metadata(user_function, name, metadata_type)
173
+ leap_binder.set_metadata(user_function, name)
172
174
 
173
175
  def _validate_input_args(sample_id: Union[int, str], preprocess_response: PreprocessResponse):
174
176
  assert isinstance(sample_id, (int, str)), \
@@ -183,7 +185,7 @@ def tensorleap_metadata(
183
185
  f'{preprocess_response.sample_id_type}. Got {type(sample_id)}.')
184
186
 
185
187
  def _validate_result(result):
186
- supported_result_types = (type(None), int, str, bool, float, dict, np.floating,
188
+ supported_result_types = (int, str, bool, float, dict, np.floating,
187
189
  np.bool_, np.unsignedinteger, np.signedinteger, np.integer)
188
190
  assert isinstance(result, supported_result_types), \
189
191
  (f'tensorleap_metadata validation failed: '
@@ -303,6 +305,7 @@ def tensorleap_input_encoder(name: str, channel_dim=-1):
303
305
  result = user_function(sample_id, preprocess_response)
304
306
  _validate_result(result)
305
307
  return result
308
+ inner.tl_func = TLFunction(FunctionType.input, name)
306
309
 
307
310
  return inner
308
311
 
@@ -343,7 +346,7 @@ def tensorleap_gt_encoder(name: str):
343
346
  result = user_function(sample_id, preprocess_response)
344
347
  _validate_result(result)
345
348
  return result
346
-
349
+ inner.tl_func = TLFunction(FunctionType.gt, name)
347
350
  return inner
348
351
 
349
352
  return decorating_function
@@ -395,6 +398,7 @@ def tensorleap_custom_loss(name: str):
395
398
  _validate_result(result)
396
399
  return result
397
400
 
401
+ inner.tl_func = TLFunction(FunctionType.loss, name)
398
402
  return inner
399
403
 
400
404
  return decorating_function
@@ -0,0 +1,219 @@
1
+ from typing import Optional, Callable, Any, Dict, List, Set, OrderedDict
2
+ from copy import deepcopy
3
+ import numpy as np
4
+ from collections import OrderedDict as od
5
+ from code_loader.inner_leap_binder.inner_classes import FunctionType, TLFunction, MappingList
6
+ from code_loader.contract.datasetclasses import DatasetIntegrationSetup, VisualizerHandler, MetricHandler
7
+ from code_loader.contract.datasetclasses import DatasetBaseHandler, GroundTruthHandler, InputHandler
8
+
9
+
10
class LeapMapping:
    """Accumulate a json-like node graph describing how integration functions connect.

    ``mapping_json`` collects one dict per created node, with keys
    ``operation``, ``data``, ``id``, ``inputs`` and ``outputs``. Nodes backed by
    a registered handler (ground truth, input, visualizer, metric) take their
    id from ``current_id``, which starts at ``max_id``. Model input/output
    placeholder nodes use a name-derived id instead.
    """

    def __init__(self, setup_container: DatasetIntegrationSetup, mapping: MappingList):
        self.container: DatasetIntegrationSetup = setup_container
        self.mapping: MappingList = mapping
        self.max_id = 20000
        # Advanced by one on every create_node call.
        self.current_id = self.max_id
        self.mapping_json: List[Dict[str, Any]] = []
        # Registered handlers searched by _find_handle, per function type.
        # NOTE(review): FunctionType.loss has no entry yet, so loss nodes cannot be resolved.
        self.handle_dict = {
            FunctionType.gt: self.container.ground_truths,
            FunctionType.metric: self.container.metrics,
            FunctionType.visualizer: self.container.visualizers,
            FunctionType.input: self.container.inputs,
        }
        self.handle_to_op_type = {
            GroundTruthHandler: "GroundTruth",
            InputHandler: "Input",
            VisualizerHandler: "Visualizer",
        }
        # Every json node created for a TLFunction, in creation order.
        self.function_to_json: Dict[TLFunction, List[Dict[str, Any]]] = {}

    def create_mapping_list(self, connections_list):  # TODO add types
        """Not implemented yet; does nothing."""

    @staticmethod
    def _tl_functions(connected_funcs) -> List[TLFunction]:
        """Return the ``tl_func`` tag of each connected entry, skipping int entries."""
        return [func.tl_func for func in connected_funcs if not isinstance(func, int)]

    def create_connect_nodes(self, output_path: Optional[str] = 'temp.yaml') -> None:
        """Create one node per sink instance and connect its sources to it.

        Each mapping element maps a sink TLFunction to ``{node_name: items}``.
        For every ``node_name`` that has at least one non-int item, a single
        sink node is created and the i-th tagged source is wired into the
        sink's i-th argument. Source nodes must already exist
        (see ``create_partial_mapping``).

        :param output_path: File the accumulated ``mapping_json`` is dumped to
            as YAML under a ``decoder`` key; pass None to skip writing.
        """
        for element in self.mapping:
            for sink, connections in element.items():
                for node_name, connected_func_list in connections.items():
                    sources = self._tl_functions(connected_func_list)
                    if not sources:
                        continue
                    # Create the sink once: creating it per source duplicated the
                    # node and left each copy with only a single connected input.
                    sink_node = self.create_node(sink, name=node_name)
                    for input_idx, source in enumerate(sources):
                        self.add_connections(sink_node, source, input_idx=input_idx)
        if output_path is not None:
            import yaml
            with open(output_path, 'w', encoding='utf-8') as yf:
                yaml.safe_dump({'decoder': self.mapping_json}, yf,
                               default_flow_style=False, sort_keys=False)

    def create_partial_mapping(self):  # TODO add types
        """Create nodes for every referenced source, then create and connect the sinks."""
        # Iterate in first-reference order. Iterating the set from
        # _get_referenced_sinks tied id assignment to string hash randomization.
        for func in self._ordered_referenced_sources():
            kwargs = {'name': func.FuncName} if func.FuncType == FunctionType.model_output else {}
            self.create_node(func, **kwargs)
        self.create_connect_nodes()

    def _ordered_referenced_sources(self) -> List[TLFunction]:
        """Referenced input / gt / model-output functions, deduplicated in first-seen order."""
        references_types = {FunctionType.input, FunctionType.gt, FunctionType.model_output}
        ordered: Dict[TLFunction, None] = {}
        for element in self.mapping:
            for connections in element.values():
                for connected_func_list in connections.values():
                    for connected_func in self._tl_functions(connected_func_list):
                        if connected_func.FuncType in references_types:
                            ordered.setdefault(connected_func, None)
        return list(ordered)

    def _get_referenced_sinks(self) -> Set[TLFunction]:
        """Return the set of input / gt / model-output functions referenced in the mapping."""
        return set(self._ordered_referenced_sources())

    def create_gt_input(self, handle: DatasetBaseHandler, **kwargs) -> Dict[str, Any]:
        """Build a json node for a ground-truth or input handler.

        The node has no inputs and one output keyed by the handler's name.
        Extra keyword arguments are accepted for a uniform creator signature and ignored.
        """
        op_type = self.handle_to_op_type[type(handle)]
        return {'operation': op_type,
                'data': {'type': op_type, 'output_name': handle.name},
                'id': str(self.current_id),
                'inputs': {},
                'outputs': {handle.name: []}}

    def create_visualize(self, visualize_handle: VisualizerHandler, **kwargs) -> Dict[str, Any]:
        """Build a json node for a visualizer handler.

        ``kwargs['name']`` is required and becomes ``user_unique_name``. One
        empty input slot is created per visualizer argument name.
        """
        op_type = "Visualizer"
        return {'operation': op_type,
                'data': {'type': op_type,
                         'name': visualize_handle.name,
                         'visualizer_name': visualize_handle.name,
                         'visualizer_type': visualize_handle.type.name,
                         # copy so nodes never share the handler's list object
                         'arg_names': deepcopy(visualize_handle.arg_names),
                         'user_unique_name': kwargs['name']},
                'id': str(self.current_id),
                'inputs': {arg: [] for arg in visualize_handle.arg_names},
                'outputs': {}}

    def create_metric(self, metric_handle: MetricHandler, **kwargs) -> Dict[str, Any]:
        """Build a json node for a metric handler.

        ``kwargs['name']`` is required and becomes ``user_unique_name``. One
        empty input slot is created per metric argument name.
        """
        op_type = "Metric"
        return {'operation': op_type,
                'data': {'type': op_type,
                         'name': metric_handle.name,
                         'metric_name': metric_handle.name,
                         'arg_names': deepcopy(metric_handle.arg_names),
                         'user_unique_name': kwargs['name']},
                'id': str(self.current_id),
                'inputs': {arg: [] for arg in metric_handle.arg_names},
                'outputs': {}}

    def create_temp_model_node(self, node_type: TLFunction, **kwargs) -> Dict[str, Any]:
        """Build a PlaceHolder node for a model input or output.

        The default name (derived from the type and ``kwargs['name']``) is
        used as the node id. A model-input placeholder gets one input slot;
        any other type gets one output slot, both keyed by that name.
        """
        default_name = self.get_default_name(node_type.FuncType, kwargs['name'])
        op_type = "PlaceHolder"
        json_node = {'operation': op_type,
                     'data': {'type': op_type,
                              'name': default_name,
                              'arg_names': [default_name],
                              'user_unique_name': default_name},
                     'id': default_name,
                     'inputs': {},
                     'outputs': {}}
        if node_type.FuncType == FunctionType.model_input:
            json_node['inputs'] = {default_name: []}
        else:
            json_node['outputs'] = {default_name: []}
        return json_node

    def _find_handle(self, node: TLFunction) -> DatasetBaseHandler:
        """Return the registered handler whose name equals ``node.FuncName`` (last match wins)."""
        handle_list = self.handle_dict[node.FuncType]
        chosen_handle: Optional[DatasetBaseHandler] = None
        for handle in handle_list:
            if handle.name == node.FuncName:
                chosen_handle = handle
        assert chosen_handle is not None, f"No registered handler found for {node}"
        return chosen_handle

    def _find_json_node(self, node: TLFunction, **kwargs) -> Dict[str, Any]:
        """Return the single json node created for ``node``; raise if there are none or several."""
        candidates = self.function_to_json[node]
        if len(candidates) == 1:
            return candidates[0]
        elif len(candidates) > 1:
            raise Exception("More than 1 candidate - solve")
        else:
            raise Exception(f"No json node creted for node {node}")

    def _get_node_id(self, node: TLFunction, **kwargs):
        """Return the id of the single json node created for ``node``."""
        return self._find_json_node(node, **kwargs)['id']

    def create_node(self, node_type: TLFunction, **kwargs) -> Dict[str, Any]:
        """Create, register and return a json node for ``node_type``.

        Model inputs/outputs become placeholder nodes. Other types are built
        from their registered handler. Every call advances ``current_id``.
        """
        if node_type.FuncType in [FunctionType.model_output, FunctionType.model_input]:
            json_node = self.create_temp_model_node(node_type, **kwargs)
        else:
            node_creator_map: Dict[FunctionType, Callable[..., Any]] = {
                FunctionType.gt: self.create_gt_input,
                FunctionType.input: self.create_gt_input,
                FunctionType.visualizer: self.create_visualize,
                FunctionType.metric: self.create_metric,
            }
            handle = self._find_handle(node_type)
            json_node = node_creator_map[node_type.FuncType](handle, **kwargs)
        self.current_id += 1
        self.function_to_json[node_type] = self.function_to_json.get(node_type, []) + [json_node]
        self.mapping_json.append(json_node)
        return json_node

    def get_default_name(self, f_type: FunctionType, name: str):
        """Return ``"<type name>_<name>"``, e.g. ``"model_output_0"``."""
        # Use the member name: str() of an IntEnum member is the bare integer
        # on Python 3.11+ but "FunctionType.<name>" on earlier versions.
        return f"{f_type.name}_{name}"

    def _get_node_output_name(self, node: TLFunction):
        """Return the output key for ``node``: its handler name, or its default placeholder name."""
        if node.FuncType not in {FunctionType.model_output, FunctionType.model_input}:
            return self._find_handle(node).name
        return self.get_default_name(node.FuncType, node.FuncName)

    def add_connections(self, to_json: Dict[str, Any], from_node: TLFunction, input_idx: int):
        """Connect ``from_node``'s output to argument ``input_idx`` of ``to_json``.

        Appends the edge on both sides: an entry in the sink's ``inputs`` and a
        matching entry in the source node's ``outputs``.

        :raises Exception: if ``to_json`` has no argument at ``input_idx``.
        """
        try:
            input_name = to_json['data']['arg_names'][input_idx]
        except (KeyError, IndexError, TypeError) as e:
            raise Exception(f"arg names mismatch when adding input for {to_json['operation']}") from e
        output_key = self._get_node_output_name(from_node)
        from_json = self._find_json_node(from_node)
        to_json['inputs'][input_name].append({'outputKey': output_key,
                                              'operation': from_json['operation'],
                                              'id': from_json['id']})
        from_json['outputs'][output_key].append({'inputKey': input_name,
                                                 'operation': to_json['operation'],
                                                 'id': to_json['id']})

    def add_output(self, input_node, output_node):
        """Not implemented yet; does nothing."""
+
198
def create_mapping(model_parser_map, connections_list):
    """Unfinished mapping builder: currently only walks ``connections_list``.

    For every connected item that is not an int, its ``tl_func`` attribute is
    read, so an item lacking it raises AttributeError. Nothing is stored or
    returned, and ``model_parser_map`` is not used yet.
    """
    for mapping_element in connections_list:
        for sink_func in mapping_element.keys():
            for connected in mapping_element[sink_func].values():
                _ = [item if isinstance(item, int) else item.tl_func for item in connected]

    # Check Max Node ID
    # Get Preds Node ID

    # TODO Create Inputs (if referenced)
    # TODO Create Gts (if referenced)
    # TODO Create Visualizers
    # TODO Create Metrics
    # TODO Create losses
code_loader/leaploader.py CHANGED
@@ -6,7 +6,7 @@ import sys
6
6
  from contextlib import redirect_stdout
7
7
  from functools import lru_cache
8
8
  from pathlib import Path
9
- from typing import Dict, List, Iterable, Union, Any, Type, Optional, Callable, Tuple
9
+ from typing import Dict, List, Iterable, Union, Any, Type, Optional, Callable
10
10
 
11
11
  import numpy as np
12
12
  import numpy.typing as npt
@@ -46,7 +46,7 @@ class LeapLoader(LeapLoaderBase):
46
46
 
47
47
  def evaluate_module(self) -> None:
48
48
  def append_path_recursively(full_path: str) -> None:
49
- if '/' not in full_path or full_path == '/':
49
+ if self.code_path not in full_path or full_path == '/':
50
50
  return
51
51
 
52
52
  parent_path = str(Path(full_path).parent)
@@ -140,11 +140,9 @@ class LeapLoader(LeapLoaderBase):
140
140
  if state == DataStateEnum.unlabeled and sample_id not in preprocess_result[state].sample_ids:
141
141
  self._preprocess_result(update_unlabeled_preprocess=True)
142
142
 
143
- metadata, metadata_is_none = self._get_metadata(state, sample_id)
144
143
  sample = DatasetSample(inputs=self._get_inputs(state, sample_id),
145
144
  gt=None if state == DataStateEnum.unlabeled else self._get_gt(state, sample_id),
146
- metadata=metadata,
147
- metadata_is_none=metadata_is_none,
145
+ metadata=self._get_metadata(state, sample_id),
148
146
  index=sample_id,
149
147
  state=state)
150
148
  return sample
@@ -212,7 +210,7 @@ class LeapLoader(LeapLoaderBase):
212
210
  state_name = state.name
213
211
  try:
214
212
  test_result = global_leap_binder.check_handler(
215
- preprocess_response, test_result, dataset_base_handler, state)
213
+ preprocess_response, test_result, dataset_base_handler)
216
214
  except Exception as e:
217
215
  line_number, file_name, stacktrace = get_root_exception_file_and_line_number()
218
216
  test_result[0].display[
@@ -341,11 +339,6 @@ class LeapLoader(LeapLoaderBase):
341
339
  continue
342
340
  if hasattr(handler_test_payload.raw_result, 'tolist'):
343
341
  handler_test_payload.raw_result = handler_test_payload.raw_result.tolist()
344
- if isinstance(handler_test_payload.raw_result, DatasetMetadataType):
345
- dataset_metadata_type = handler_test_payload.raw_result
346
- metadata_instances.append(DatasetMetadataInstance(name=handler_test_payload.name,
347
- type=dataset_metadata_type))
348
- continue
349
342
  metadata_type = type(handler_test_payload.raw_result)
350
343
  if metadata_type == int or isinstance(handler_test_payload.raw_result,
351
344
  (np.unsignedinteger, np.signedinteger)):
@@ -445,7 +438,7 @@ class LeapLoader(LeapLoaderBase):
445
438
  }
446
439
  return metadata_name_to_type
447
440
 
448
- def _convert_metadata_to_correct_type(self, metadata_name: str, value: Any) -> Tuple[Any, bool]:
441
+ def _convert_metadata_to_correct_type(self, metadata_name: str, value: Any) -> Any:
449
442
  metadata_name_to_type = self._metadata_name_to_type()
450
443
  metadata_type_to_python_type = {
451
444
  DatasetMetadataType.float: float,
@@ -454,26 +447,22 @@ class LeapLoader(LeapLoaderBase):
454
447
  DatasetMetadataType.int: int
455
448
  }
456
449
  metadata_type_to_default_value = {
457
- DatasetMetadataType.float: -1.0,
458
- DatasetMetadataType.string: 'None',
450
+ DatasetMetadataType.float: -1,
451
+ DatasetMetadataType.string: "",
459
452
  DatasetMetadataType.boolean: False,
460
453
  DatasetMetadataType.int: -1
461
454
  }
462
455
 
463
456
  try:
464
- is_none = False
465
- if value is None:
466
- raise ValueError()
467
457
  converted_value = metadata_type_to_python_type[metadata_name_to_type[metadata_name]](value)
468
458
  except ValueError:
469
- is_none = True
470
459
  converted_value = metadata_type_to_default_value[metadata_name_to_type[metadata_name]]
471
460
 
472
- return converted_value, is_none
461
+ return converted_value
473
462
 
474
- def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Tuple[Dict[str, Union[str, int, bool, float]], Dict[str, bool]]:
463
+ def _get_metadata(self, state: DataStateEnum, sample_id: Union[int, str]) -> Dict[
464
+ str, Union[str, int, bool, float]]:
475
465
  result_agg = {}
476
- is_none = {}
477
466
  preprocess_result = self._preprocess_result()
478
467
  preprocess_state = preprocess_result[state]
479
468
  for handler in global_leap_binder.setup_container.metadata:
@@ -481,14 +470,13 @@ class LeapLoader(LeapLoaderBase):
481
470
  if isinstance(handler_result, dict):
482
471
  for single_metadata_name, single_metadata_result in handler_result.items():
483
472
  handler_name = f'{handler.name}_{single_metadata_name}'
484
- result_agg[handler_name], is_none[handler_name] = self._convert_metadata_to_correct_type(
485
- handler_name, single_metadata_result)
473
+ result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name,
474
+ single_metadata_result)
486
475
  else:
487
476
  handler_name = handler.name
488
- result_agg[handler_name], is_none[handler_name] = self._convert_metadata_to_correct_type(
489
- handler_name, handler_result)
477
+ result_agg[handler_name] = self._convert_metadata_to_correct_type(handler_name, handler_result)
490
478
 
491
- return result_agg, is_none
479
+ return result_agg
492
480
 
493
481
  @lru_cache()
494
482
  def get_sample_id_type(self) -> Type:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.77.dev4
3
+ Version: 1.0.77.dev30
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -1,7 +1,7 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- code_loader/contract/datasetclasses.py,sha256=BjT7NoOBgWtBXtMRedKwwnC-vbHg-KHKqFL11GdgUx8,7902
4
+ code_loader/contract/datasetclasses.py,sha256=t1pvW68dVIbYB0bllqxHXO6jOgHCJI5NGn9rSrYC6Vw,7722
5
5
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
6
6
  code_loader/contract/exceptions.py,sha256=jWqu5i7t-0IG0jGRsKF4DjJdrsdpJjIYpUkN1F4RiyQ,51
7
7
  code_loader/contract/responsedataclasses.py,sha256=RSx9m_R3LawhK5o1nAcO3hfp2F9oJYtxZr_bpP3bTmw,4005
@@ -19,14 +19,16 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
19
19
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
20
20
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
21
21
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
22
- code_loader/inner_leap_binder/leapbinder.py,sha256=bn7H99bjttzAv3UaXR4cRludctUyTjeOk1-0-qDV8mw,29360
23
- code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=HUIVrI-I0aB-nQl7lZErE5vuhml7A8MMGaw_R4U4QCk,22053
24
- code_loader/leaploader.py,sha256=JfQsKjXWiy6K6ANj_HOqcb8_dcuagetmeoCVZ4MajE0,25668
22
+ code_loader/inner_leap_binder/inner_classes.py,sha256=DHaIaRUdplQPCaukKLG8RboW9KIkdO6kekeGw9PCmT4,726
23
+ code_loader/inner_leap_binder/leapbinder.py,sha256=jkSsoJwnNEmn8vlibjT4GczNuj4NAIM3_v-KZCLjPuw,27832
24
+ code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=1oUrEbcY3GZkkaxVAdzlV4Vl3Y5XQ4aGwPWvSUb2hWM,22388
25
+ code_loader/inner_leap_binder/mapping.py,sha256=WD3p_dNMp8SQckUq62Yl0bpIBxoaZAkJszCl4IzWSms,10198
26
+ code_loader/leaploader.py,sha256=KkAPKyzZxZYI135_67q63ICVVERu7WH60liB0pboA-c,24983
25
27
  code_loader/leaploaderbase.py,sha256=VH0vddRmkqLtcDlYPCO7hfz1_VbKo43lUdHDAbd4iJc,4198
26
28
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
27
29
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
30
  code_loader/visualizers/default_visualizers.py,sha256=Ffx5VHVOe5ujBOsjBSxN_aIEVwFSQ6gbhTMG5aUS-po,2305
29
- code_loader-1.0.77.dev4.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
- code_loader-1.0.77.dev4.dist-info/METADATA,sha256=2RGRvw7LlvX2szuGL2WqQYctTeINmq98iaf1NzUCkiA,854
31
- code_loader-1.0.77.dev4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
- code_loader-1.0.77.dev4.dist-info/RECORD,,
31
+ code_loader-1.0.77.dev30.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
32
+ code_loader-1.0.77.dev30.dist-info/METADATA,sha256=ntsmieF-9tHybrMeZ_YqnkSvI5k9Pp3eUKi0uO-GZ0w,855
33
+ code_loader-1.0.77.dev30.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
34
+ code_loader-1.0.77.dev30.dist-info/RECORD,,