code-loader 1.0.65__py3-none-any.whl → 1.0.65.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,34 @@
1
+
2
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from typing import Callable, Dict, List, Sequence, Union
from typing import Union
6
+
7
class FunctionType(IntEnum):
    """Kind of Tensorleap integration function referenced by a ``TLFunction``.

    Being an ``IntEnum``, every member also compares equal to its integer value.
    """

    gt = 0            # tagged by tensorleap_gt_encoder
    visualizer = 1    # tagged by tensorleap_custom_visualizer / set_visualizer
    input = 2         # tagged by tensorleap_input_encoder
    metric = 3        # registered by add_custom_metric
    loss = 4          # tagged by tensorleap_custom_loss
    model_input = 5   # placeholder from leap_input / set_model_inputs
    model_output = 6  # placeholder from leap_output
15
+
16
@dataclass(frozen=True)
class TLFunction:
    """Immutable, hashable tag identifying one integration function.

    Instances are used as keys of mapping entries and are attached to
    decorated callables as their ``tl_func`` attribute. ``frozen=True`` (with
    the default ``eq=True``) makes dataclass generate ``__hash__``, so tags can
    be dict keys and set members.
    """

    # Which kind of function this tag refers to.
    FuncType: FunctionType
    # Registered name of the function. For model placeholders created by
    # leap_input / leap_output this is the integer index passed to them.
    FuncName: Union[str, int]
20
+
21
# One dict per registration, mapping the sink TLFunction (a visualizer, a metric,
# or the model-input placeholder) to {name: sequence of connected sources}.
# Each source is either a callable carrying a ``tl_func`` attribute or an int
# (ints are currently skipped by LeapMapping). The values are sequences, e.g.
# set_model_inputs stores 1-tuples, not single callables.
MappingList = List[Dict[TLFunction, Dict[str, Sequence[Union[Callable, int]]]]]
22
+
23
def leap_input(idx):
    """Return a no-op callable that stands in for model input ``idx``.

    Calling it returns None. Its ``tl_func`` attribute is
    ``TLFunction(FunctionType.model_input, idx)``, which is what mapping code
    reads when the callable is used as a connection source.
    """
    tag = TLFunction(FunctionType.model_input, idx)

    def dummy():
        return None

    setattr(dummy, "tl_func", tag)
    return dummy
28
+
29
def leap_output(idx):
    """Return a no-op callable that stands in for model output ``idx``.

    Calling it returns None. Its ``tl_func`` attribute is
    ``TLFunction(FunctionType.model_output, idx)``; LeapMapping reads that tag
    when the callable is used as a connection source.
    """
    tag = TLFunction(FunctionType.model_output, idx)

    def dummy():
        return None

    setattr(dummy, "tl_func", tag)
    return dummy
34
+
@@ -1,5 +1,5 @@
1
1
  import inspect
2
- from typing import Callable, List, Optional, Dict, Any, Type, Union
2
+ from typing import Callable, List, Optional, Dict, Any, Type, Union, Tuple
3
3
 
4
4
  import numpy as np
5
5
  import numpy.typing as npt
@@ -20,6 +20,7 @@ from code_loader.visualizers.default_visualizers import DefaultVisualizer, \
20
20
  default_graph_visualizer, \
21
21
  default_image_visualizer, default_horizontal_bar_visualizer, default_word_visualizer, \
22
22
  default_image_mask_visualizer, default_text_mask_visualizer, default_raw_data_visualizer
23
+ from code_loader.inner_leap_binder.inner_classes import FunctionType, TLFunction, MappingList, leap_output, leap_input
23
24
 
24
25
 
25
26
  class LeapBinder:
@@ -35,6 +36,7 @@ class LeapBinder:
35
36
  def __init__(self) -> None:
36
37
  self.setup_container = DatasetIntegrationSetup()
37
38
  self.cache_container: Dict[str, Any] = {"word_to_index": {}}
39
+ self._mappings: MappingList = list()
38
40
  self._visualizer_names: List[str] = list()
39
41
  self._encoder_names: List[str] = list()
40
42
  self._extend_with_default_visualizers()
@@ -65,7 +67,8 @@ class LeapBinder:
65
67
  def set_visualizer(self, function: VisualizerCallableInterface,
66
68
  name: str,
67
69
  visualizer_type: LeapDataType,
68
- heatmap_visualizer: Optional[Callable[..., npt.NDArray[np.float32]]] = None) -> None:
70
+ heatmap_visualizer: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
71
+ connects_to = None) -> None: #TODO add types
69
72
  """
70
73
  Set a visualizer for a specific data type.
71
74
 
@@ -132,10 +135,11 @@ class LeapBinder:
132
135
  raise Exception(
133
136
  f'The return type of function {function.__name__} is invalid. current return type: {return_type}, '
134
137
  f'should be {expected_return_type}')
135
-
136
138
  self.setup_container.visualizers.append(
137
139
  VisualizerHandler(VisualizerHandlerData(name, visualizer_type, arg_names), function, heatmap_visualizer))
138
140
  self._visualizer_names.append(name)
141
+ if connects_to is not None:
142
+ self._mappings.append({TLFunction(FunctionType.visualizer, name): connects_to})
139
143
 
140
144
  def set_preprocess(self, function: Callable[[], List[PreprocessResponse]]) -> None:
141
145
  """
@@ -240,7 +244,8 @@ class LeapBinder:
240
244
  CustomMultipleReturnCallableInterfaceMultiArgs,
241
245
  ConfusionMatrixCallableInterfaceMultiArgs],
242
246
  name: str,
243
- direction: Optional[MetricDirection] = MetricDirection.Downward) -> None:
247
+ direction: Optional[MetricDirection] = MetricDirection.Downward,
248
+ connects_to = None) -> None: #TODO TOM add types
244
249
  """
245
250
  Add a custom metric to the setup.
246
251
 
@@ -260,6 +265,8 @@ class LeapBinder:
260
265
  """
261
266
  arg_names = inspect.getfullargspec(function)[0]
262
267
  self.setup_container.metrics.append(MetricHandler(MetricHandlerData(name, arg_names, direction), function))
268
+ if connects_to is not None:
269
+ self._mappings.append({TLFunction(FunctionType.metric, name): connects_to})
263
270
 
264
271
  def add_prediction(self, name: str, labels: List[str], channel_dim: int = -1) -> None:
265
272
  """
@@ -476,6 +483,9 @@ class LeapBinder:
476
483
  continue
477
484
  self.check_handler(preprocess_response, test_result, dataset_base_handler)
478
485
 
486
def set_model_inputs(self, connections: Tuple[Callable[..., Any], ...]) -> None:
    """Record which functions feed the model's inputs, in input order.

    Appends one mapping entry keyed by
    ``TLFunction(FunctionType.model_input, "model")``. Model input ``i`` is
    stored under the key ``str(i)`` as the 1-tuple ``(connections[i],)``.

    Args:
        connections: callables carrying a ``tl_func`` attribute (for example
            functions decorated with ``tensorleap_input_encoder``), ordered
            like the model inputs. Any iterable is accepted.
    """
    inputs_by_index = {str(i): (connection,) for i, connection in enumerate(connections)}
    self._mappings.append({TLFunction(FunctionType.model_input, "model"): inputs_by_index})
+
479
489
  def check(self) -> None:
480
490
  preprocess_result = self.get_preprocess_result()
481
491
  self.check_preprocess(preprocess_result)
@@ -13,9 +13,10 @@ from code_loader.contract.enums import MetricDirection, LeapDataType
13
13
  from code_loader import leap_binder
14
14
  from code_loader.contract.visualizer_classes import LeapImage, LeapImageMask, LeapTextMask, LeapText, LeapGraph, \
15
15
  LeapHorizontalBar, LeapImageWithBBox, LeapImageWithHeatmap
16
+ from code_loader.inner_leap_binder.inner_classes import TLFunction, FunctionType
16
17
 
17
-
18
- def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = MetricDirection.Downward):
18
+ def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = MetricDirection.Downward,
19
+ connects_to = None): #TODO TOM add type
19
20
  def decorating_function(
20
21
  user_function: Union[CustomCallableInterfaceMultiArgs,
21
22
  CustomMultipleReturnCallableInterfaceMultiArgs,
@@ -26,7 +27,7 @@ def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = M
26
27
  raise Exception(f'Metric with name {name} already exists. '
27
28
  f'Please choose another')
28
29
 
29
- leap_binder.add_custom_metric(user_function, name, direction)
30
+ leap_binder.add_custom_metric(user_function, name, direction, connects_to=connects_to)
30
31
 
31
32
  def _validate_input_args(*args, **kwargs) -> None:
32
33
  for i, arg in enumerate(args):
@@ -91,14 +92,15 @@ def tensorleap_custom_metric(name: str, direction: Optional[MetricDirection] = M
91
92
 
92
93
 
93
94
  def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
94
- heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None):
95
+ heatmap_function: Optional[Callable[..., npt.NDArray[np.float32]]] = None,
96
+ connects_to = None): #TODO TOM ADD TYPES
95
97
  def decorating_function(user_function: VisualizerCallableInterface):
96
98
  for viz_handler in leap_binder.setup_container.visualizers:
97
99
  if viz_handler.visualizer_handler_data.name == name:
98
100
  raise Exception(f'Visualizer with name {name} already exists. '
99
101
  f'Please choose another')
100
102
 
101
- leap_binder.set_visualizer(user_function, name, visualizer_type, heatmap_function)
103
+ leap_binder.set_visualizer(user_function, name, visualizer_type, heatmap_function, connects_to)
102
104
 
103
105
  def _validate_input_args(*args, **kwargs):
104
106
  for i, arg in enumerate(args):
@@ -137,6 +139,7 @@ def tensorleap_custom_visualizer(name: str, visualizer_type: LeapDataType,
137
139
  result = user_function(*args, **kwargs)
138
140
  _validate_result(result)
139
141
  return result
142
+ inner.tl_func = TLFunction(FunctionType.visualizer, name)
140
143
 
141
144
  return inner
142
145
 
@@ -285,6 +288,7 @@ def tensorleap_input_encoder(name: str, channel_dim=-1):
285
288
  result = user_function(sample_id, preprocess_response)
286
289
  _validate_result(result)
287
290
  return result
291
+ inner.tl_func = TLFunction(FunctionType.input, name)
288
292
 
289
293
  return inner
290
294
 
@@ -325,7 +329,7 @@ def tensorleap_gt_encoder(name: str):
325
329
  result = user_function(sample_id, preprocess_response)
326
330
  _validate_result(result)
327
331
  return result
328
-
332
+ inner.tl_func = TLFunction(FunctionType.gt, name)
329
333
  return inner
330
334
 
331
335
  return decorating_function
@@ -379,6 +383,7 @@ def tensorleap_custom_loss(name: str):
379
383
  _validate_result(result)
380
384
  return result
381
385
 
386
+ inner.tl_func = TLFunction(FunctionType.loss, name)
382
387
  return inner
383
388
 
384
389
  return decorating_function
@@ -0,0 +1,219 @@
1
+ from typing import Optional, Callable, Any, Dict, List, Set, OrderedDict
2
+ from copy import deepcopy
3
+ import numpy as np
4
+ from collections import OrderedDict as od
5
+ from code_loader.inner_leap_binder.inner_classes import FunctionType, TLFunction, MappingList
6
+ from code_loader.contract.datasetclasses import DatasetIntegrationSetup, VisualizerHandler, MetricHandler
7
+ from code_loader.contract.datasetclasses import DatasetBaseHandler, GroundTruthHandler, InputHandler
8
+
9
+
10
class LeapMapping:
    """Build a JSON-like node graph from ``LeapBinder`` mapping entries.

    Each entry of ``mapping`` maps a sink ``TLFunction`` to
    ``{node name: sequence of connected sources}``. Nodes are dicts with
    ``operation``/``data``/``id``/``inputs``/``outputs`` keys, collected in
    ``mapping_json``. ``add_connections`` records each edge on both endpoints.
    """

    def __init__(self, setup_container: DatasetIntegrationSetup, mapping: MappingList):
        self.container: DatasetIntegrationSetup = setup_container
        self.mapping: MappingList = mapping
        # Handler-backed nodes get str(current_id) as their id, and current_id is
        # incremented after every created node. Placeholder model nodes use their
        # default name as id instead.
        self.max_id = 20000
        self.current_id = self.max_id
        self.mapping_json: List[Dict[str, Any]] = []
        # Registered handlers to search, per function type (loss is not supported yet).
        self.handle_dict = {
            FunctionType.gt: self.container.ground_truths,
            FunctionType.metric: self.container.metrics,
            FunctionType.visualizer: self.container.visualizers,
            FunctionType.input: self.container.inputs,
        }
        self.handle_to_op_type = {
            GroundTruthHandler: "GroundTruth",
            InputHandler: "Input",
            VisualizerHandler: "Visualizer",
        }
        # Every json node created for a function, in creation order.
        self.function_to_json: Dict[TLFunction, List[Dict[str, Any]]] = {}

    def create_mapping_list(self, connections_list):  # TODO add types
        """Not implemented yet; does nothing."""

    def create_connect_nodes(self, yaml_path: Optional[str] = 'temp.yaml') -> None:
        """Create one node per (sink, node name) and wire its sources into it.

        Source nodes must already exist (``create_partial_mapping`` creates them
        first). Int items among the connected sources are skipped, and the i-th
        remaining source is connected to the sink's i-th argument.

        Args:
            yaml_path: when not None, ``mapping_json`` is written there as YAML
                under the ``decoder`` key (relative paths resolve against the
                current working directory). Pass None to skip the dump.
        """
        for element in self.mapping:
            for sink, nodes in element.items():
                for node_name, connected_func_list in nodes.items():
                    sources = [connected_func.tl_func for connected_func in connected_func_list
                               if not isinstance(connected_func, int)]
                    if not sources:
                        continue
                    # A single sink node receives all of its sources, rather than
                    # one partially-connected node being created per source.
                    sink_node = self.create_node(sink, name=node_name)
                    for input_idx, source in enumerate(sources):
                        self.add_connections(sink_node, source, input_idx=input_idx)
        if yaml_path is not None:
            import yaml  # imported lazily: only needed for this debug dump
            with open(yaml_path, 'w') as yf:
                yaml.safe_dump({'decoder': self.mapping_json}, yf, default_flow_style=False, sort_keys=False)

    def create_partial_mapping(self, yaml_path: Optional[str] = 'temp.yaml') -> None:  # TODO add types
        """Create nodes for every referenced source, then connect all sinks.

        Args:
            yaml_path: forwarded to ``create_connect_nodes``.
        """
        referenced = self._get_referenced_sinks()
        # Sort so node ids are assigned in the same order on every run (set order
        # depends on string hashing, which is randomized per process).
        for func in sorted(referenced, key=lambda f: (int(f.FuncType), str(f.FuncName))):
            kwargs: Dict[str, Any] = {}
            if func.FuncType == FunctionType.model_output:
                # Placeholder name comes from FuncName (the index given to leap_output).
                kwargs['name'] = func.FuncName
            self.create_node(func, **kwargs)
        self.create_connect_nodes(yaml_path=yaml_path)

    def _get_referenced_sinks(self) -> Set[TLFunction]:
        """Return the input, gt and model-output functions used as sources.

        Despite the name, these are connection *sources*: the ``tl_func`` of every
        non-int connected item whose type is input, gt or model_output.
        """
        referenced: Set[TLFunction] = set()
        references_types = {FunctionType.input, FunctionType.gt, FunctionType.model_output}
        for element in self.mapping:
            for nodes in element.values():
                for connected_func_list in nodes.values():
                    for connected_func in connected_func_list:
                        if isinstance(connected_func, int):
                            continue
                        tl_func = connected_func.tl_func
                        if tl_func.FuncType in references_types:
                            referenced.add(tl_func)
        return referenced

    def create_gt_input(self, handle: DatasetBaseHandler, **kwargs) -> Dict[str, Any]:
        """Build a GroundTruth/Input node with one empty output named after the handle."""
        op_type = self.handle_to_op_type[type(handle)]
        json_node = {'operation': op_type,
                     'data': {'type': op_type, 'output_name': handle.name},
                     'id': str(self.current_id),
                     'inputs': {},
                     'outputs': {handle.name: []}}
        return json_node

    def create_visualize(self, visualize_handle: VisualizerHandler, **kwargs) -> Dict[str, Any]:
        """Build a Visualizer node with one empty input slot per argument name.

        ``kwargs['name']`` becomes ``user_unique_name``.
        NOTE(review): assumes the handler exposes ``name``/``type``/``arg_names``
        directly; the decorators read ``visualizer_handler_data.name`` — confirm.
        """
        op_type = "Visualizer"
        json_node = {'operation': op_type,
                     'data': {'type': op_type,
                              'name': visualize_handle.name,
                              'visualizer_name': visualize_handle.name,
                              'visualizer_type': visualize_handle.type.name,
                              'arg_names': deepcopy(visualize_handle.arg_names),
                              'user_unique_name': kwargs['name']},
                     'id': str(self.current_id),
                     'inputs': {arg: [] for arg in visualize_handle.arg_names},
                     'outputs': {}}
        return json_node

    def create_metric(self, metric_handle: MetricHandler, **kwargs) -> Dict[str, Any]:
        """Build a Metric node with one empty input slot per argument name.

        ``kwargs['name']`` becomes ``user_unique_name``.
        """
        op_type = "Metric"
        json_node = {'operation': op_type,
                     'data': {'type': op_type,
                              'name': metric_handle.name,
                              'metric_name': metric_handle.name,
                              'arg_names': deepcopy(metric_handle.arg_names),
                              'user_unique_name': kwargs['name']},
                     'id': str(self.current_id),
                     'inputs': {arg: [] for arg in metric_handle.arg_names},
                     'outputs': {}}
        return json_node

    def create_temp_model_node(self, node_type: TLFunction, **kwargs) -> Dict[str, Any]:
        """Build a PlaceHolder node standing for a model input or output.

        The node id, name and single arg name are all the default name derived from
        ``kwargs['name']``. A model-input placeholder gets one input slot; a
        model-output placeholder gets one output slot.
        """
        default_name = self.get_default_name(node_type.FuncType, kwargs['name'])
        op_type = "PlaceHolder"
        json_node = {'operation': op_type,
                     'data': {'type': op_type,
                              'name': default_name,
                              'arg_names': [default_name],
                              'user_unique_name': default_name},
                     'id': f'{default_name}',
                     'inputs': {},
                     'outputs': {}}
        if node_type.FuncType == FunctionType.model_input:
            json_node['inputs'] = {default_name: []}
        else:
            json_node['outputs'] = {default_name: []}
        return json_node

    def _find_handle(self, node: TLFunction) -> DatasetBaseHandler:
        """Return the registered handler whose name equals ``node.FuncName``.

        The last match wins if names repeat.

        Raises:
            KeyError: if ``node.FuncType`` has no handler list (e.g. loss).
            Exception: if no handler of that type has the requested name.
        """
        handle_list = self.handle_dict[node.FuncType]
        matches = [handle for handle in handle_list if handle.name == node.FuncName]
        if not matches:
            # Explicit raise rather than ``assert`` so the check survives ``python -O``.
            raise Exception(f"No {node.FuncType.name} handler named {node.FuncName!r} is registered")
        return matches[-1]

    def _find_json_node(self, node: TLFunction, **kwargs) -> Dict[str, Any]:
        """Return the single json node created for ``node``.

        Raises:
            Exception: if no node, or more than one node, was created for it.
        """
        # .get() so a never-created node reaches the explicit error below instead
        # of surfacing as a bare KeyError.
        candidates = self.function_to_json.get(node, [])
        if len(candidates) == 1:
            return candidates[0]
        if len(candidates) > 1:
            raise Exception("More than 1 candidate - solve")
        raise Exception(f"No json node created for node {node}")

    def _get_node_id(self, node: TLFunction, **kwargs):
        """Return the ``id`` of the single json node created for ``node``."""
        return self._find_json_node(node, **kwargs)['id']

    def create_node(self, node_type: TLFunction, **kwargs) -> Dict[str, Any]:
        """Create, register and return the json node for ``node_type``.

        Model inputs/outputs become placeholders. Other types are built from their
        registered handler. The node is appended to ``mapping_json`` and
        ``function_to_json``, and ``current_id`` is advanced.
        """
        if node_type.FuncType in (FunctionType.model_output, FunctionType.model_input):
            json_node = self.create_temp_model_node(node_type, **kwargs)
        else:
            node_creator_map: Dict[FunctionType, Callable[..., Any]] = {
                FunctionType.gt: self.create_gt_input,
                FunctionType.input: self.create_gt_input,
                FunctionType.visualizer: self.create_visualize,
                FunctionType.metric: self.create_metric,
            }
            handle = self._find_handle(node_type)
            json_node = node_creator_map[node_type.FuncType](handle, **kwargs)
        self.current_id += 1
        self.function_to_json.setdefault(node_type, []).append(json_node)
        self.mapping_json.append(json_node)
        return json_node

    def get_default_name(self, f_type: FunctionType, name: Any) -> str:
        """Return the placeholder name ``"FunctionType.<member>_<name>"``.

        Built from the member name on purpose: ``str()`` of an ``IntEnum`` member
        is ``"FunctionType.model_output"`` before Python 3.11 but ``"6"`` from
        3.11 on, which made generated names and ids depend on the interpreter.
        """
        return f"{type(f_type).__name__}.{f_type.name}_{name}"

    def _get_node_output_name(self, node: TLFunction):
        """Return the output key of ``node``: its handler name, or its placeholder name."""
        if node.FuncType not in {FunctionType.model_output, FunctionType.model_input}:
            return self._find_handle(node).name
        return self.get_default_name(node.FuncType, node.FuncName)

    def add_connections(self, to_json: Dict[str, Any], from_node: TLFunction, input_idx: int):
        """Connect ``from_node``'s output to argument ``input_idx`` of ``to_json``.

        The edge is recorded in ``to_json['inputs']`` and in the source node's
        ``outputs``.

        Raises:
            Exception: if ``to_json`` has no argument at ``input_idx``.
        """
        try:
            input_name = to_json['data']['arg_names'][input_idx]
        except (KeyError, IndexError) as e:
            raise Exception(f"arg names mismatch when adding input for {to_json['operation']}") from e
        output_key = self._get_node_output_name(from_node)
        from_json = self._find_json_node(from_node)
        # Add input on the sink side.
        to_json['inputs'][input_name].append({'outputKey': output_key,
                                              'operation': from_json['operation'],
                                              'id': from_json['id']})
        # Add output on the source side.
        from_json['outputs'][output_key].append({'inputKey': input_name,
                                                 'operation': to_json['operation'],
                                                 'id': to_json['id']})

    def add_output(self, input_node, output_node):
        """Not implemented yet; does nothing."""
197
+
198
def create_mapping(model_parser_map, connections_list):
    """Work-in-progress builder for a full mapping; currently returns None.

    For now it only walks ``connections_list`` (shaped like ``MappingList``) and
    resolves each connected item to its ``tl_func`` tag, keeping ints as they
    are. The resolved values are not used yet, and ``model_parser_map`` is
    ignored.
    """
    referenced_gt = []  # reserved for the TODO steps below
    referenced_inputs = []  # reserved for the TODO steps below
    for sink_to_nodes in connections_list:
        for sink in sink_to_nodes:
            for _node_name, sources in sink_to_nodes[sink].items():
                _resolved = [src if isinstance(src, int) else src.tl_func
                             for src in sources]

    # Check Max Node ID
    # Get Preds Node ID

    # TODO Create Inputs (if referenced)
    # TODO Create Gts (if referenced)
    # TODO Create Visualizers
    # TODO Create Metrics
    # TODO Create losses
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: code-loader
3
- Version: 1.0.65
3
+ Version: 1.0.65.dev1
4
4
  Summary:
5
5
  Home-page: https://github.com/tensorleap/code-loader
6
6
  License: MIT
@@ -1,6 +1,5 @@
1
1
  LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
2
2
  code_loader/__init__.py,sha256=6MMWr0ObOU7hkqQKgOqp4Zp3I28L7joGC9iCbQYtAJg,241
3
- code_loader/code_inegration_processes_manager.py,sha256=XslWOPeNQk4RAFJ_f3tP5Oe3EgcIR7BE7Y8r9Ty73-o,3261
4
3
  code_loader/contract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
4
  code_loader/contract/datasetclasses.py,sha256=L_fSdSvf-eKoez2uBJ8VjfrKedEP0szNOPvaUvsWeRQ,6973
6
5
  code_loader/contract/enums.py,sha256=6Lo7p5CUog68Fd31bCozIuOgIp_IhSiPqWWph2k3OGU,1602
@@ -19,14 +18,16 @@ code_loader/experiment_api/types.py,sha256=MY8xFARHwdVA7p4dxyhD60ShmttgTvb4qdp1o
19
18
  code_loader/experiment_api/utils.py,sha256=XZHtxge12TS4H4-8PjV3sKuhp8Ud6ojAiIzTZJEqBqc,3304
20
19
  code_loader/experiment_api/workingspace_config_utils.py,sha256=DLzXQCg4dgTV_YgaSbeTVzq-2ja_SQw4zi7LXwKL9cY,990
21
20
  code_loader/inner_leap_binder/__init__.py,sha256=koOlJyMNYzGbEsoIbXathSmQ-L38N_pEXH_HvL7beXU,99
22
- code_loader/inner_leap_binder/leapbinder.py,sha256=y_k7bRFYYmrZo4jCJrZ6mJykxc1slKDkODYCZ58OPs0,25691
23
- code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=I6ipji6QMN9qqFYxqQyNjtDtsnwpB-NahgKKpLPctMo,21067
21
+ code_loader/inner_leap_binder/inner_classes.py,sha256=DHaIaRUdplQPCaukKLG8RboW9KIkdO6kekeGw9PCmT4,726
22
+ code_loader/inner_leap_binder/leapbinder.py,sha256=jy3o-qiuealyovpKnpTSB89_o1jCUwC2kwD9UN2W2CE,26458
23
+ code_loader/inner_leap_binder/leapbinder_decorators.py,sha256=Q08khMTphdf75iHm_Wqata9RlshkpbUX3EdZJ7bNKdw,21570
24
+ code_loader/inner_leap_binder/mapping.py,sha256=WD3p_dNMp8SQckUq62Yl0bpIBxoaZAkJszCl4IzWSms,10198
24
25
  code_loader/leaploader.py,sha256=K__WKfqtKwEch40au177Po10EUX7gm0PJzcV6kpUMlo,22212
25
26
  code_loader/leaploaderbase.py,sha256=aHlqWDZRacIdBefeB9goYVnpApaNN2FT24uPIWKkCeQ,3090
26
27
  code_loader/utils.py,sha256=aw2i_fqW_ADjLB66FWZd9DfpCQ7mPdMyauROC5Nd51I,2197
27
28
  code_loader/visualizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
29
  code_loader/visualizers/default_visualizers.py,sha256=Ffx5VHVOe5ujBOsjBSxN_aIEVwFSQ6gbhTMG5aUS-po,2305
29
- code_loader-1.0.65.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
30
- code_loader-1.0.65.dist-info/METADATA,sha256=hWg-H6gagFGaIzqiHVFmxRXsk568WHJLKzby9YHWlYc,849
31
- code_loader-1.0.65.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
32
- code_loader-1.0.65.dist-info/RECORD,,
30
+ code_loader-1.0.65.dev1.dist-info/LICENSE,sha256=qIwWjdspQeSMTtnFZBC8MuT-95L02FPvzRUdWFxrwJY,1067
31
+ code_loader-1.0.65.dev1.dist-info/METADATA,sha256=nRVyHiEIMMtkav6kNkHO1CiiVceo7bxzRCuDOr6v-rg,854
32
+ code_loader-1.0.65.dev1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
33
+ code_loader-1.0.65.dev1.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.9.0
2
+ Generator: poetry-core 1.9.1
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,83 +0,0 @@
1
- # mypy: ignore-errors
2
- import traceback
3
- from dataclasses import dataclass
4
-
5
- from typing import List, Tuple, Optional
6
-
7
- from multiprocessing import Process, Queue
8
-
9
- from code_loader.leap_loader_parallelized_base import LeapLoaderParallelizedBase
10
- from code_loader.leaploader import LeapLoader
11
- from code_loader.contract.enums import DataStateEnum
12
- from code_loader.metric_calculator_parallelized import MetricCalculatorParallelized
13
- from code_loader.samples_generator_parallelized import SamplesGeneratorParallelized
14
-
15
-
16
- @dataclass
17
- class SampleSerializableError:
18
- state: DataStateEnum
19
- index: int
20
- leap_script_trace: str
21
- exception_as_str: str
22
-
23
-
24
- class CodeIntegrationProcessesManager:
25
- def __init__(self, code_path: str, code_entry_name: str, n_workers: Optional[int] = 2,
26
- max_samples_in_queue: int = 128) -> None:
27
- self.metric_calculator_parallelized = MetricCalculatorParallelized(code_path, code_entry_name)
28
- self.samples_generator_parallelized = SamplesGeneratorParallelized(code_path, code_entry_name)
29
-
30
- def _create_and_start_process(self) -> Process:
31
- process = self.multiprocessing_context.Process(
32
- target=CodeIntegrationProcessesManager._process_func,
33
- args=(self.code_path, self.code_entry_name, self._inputs_waiting_to_be_process,
34
- self._ready_processed_results))
35
- process.daemon = True
36
- process.start()
37
- return process
38
-
39
- def _run_and_warm_first_process(self):
40
- process = self._create_and_start_process()
41
- self.processes = [process]
42
-
43
- # needed in order to make sure the preprocess func runs once in nonparallel
44
- self._start_process_inputs([(DataStateEnum.training, 0)])
45
- self._get_next_ready_processed_result()
46
-
47
- def _operation_decider(self):
48
- if self.metric_calculator_parallelized._ready_processed_results.empty() and not \
49
- self.metric_calculator_parallelized._inputs_waiting_to_be_process.empty():
50
- return 'metric'
51
-
52
- if self.samples_generator_parallelized._ready_processed_results.empty() and not \
53
- self.samples_generator_parallelized._inputs_waiting_to_be_process.empty():
54
- return 'dataset'
55
-
56
-
57
-
58
-
59
- @staticmethod
60
- def _process_func(code_path: str, code_entry_name: str,
61
- samples_to_process: Queue, ready_samples: Queue,
62
- metrics_to_process: Queue, ready_metrics: Queue) -> None:
63
- import os
64
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
65
-
66
- leap_loader = LeapLoader(code_path, code_entry_name)
67
- while True:
68
-
69
- # decide on sample or metric to process
70
- state, idx = samples_to_process.get(block=True)
71
- leap_loader._preprocess_result()
72
- try:
73
- sample = leap_loader.get_sample(state, idx)
74
- except Exception as e:
75
- leap_script_trace = traceback.format_exc().split('File "<string>"')[-1]
76
- ready_samples.put(SampleSerializableError(state, idx, leap_script_trace, str(e)))
77
- continue
78
-
79
- ready_samples.put(sample)
80
-
81
- def generate_samples(self, sample_identities: List[Tuple[DataStateEnum, int]]):
82
- return self.start_process_inputs(sample_identities)
83
-