esd-services-api-client 2.0.0__py3-none-any.whl → 2.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. esd_services_api_client/_version.py +1 -1
  2. esd_services_api_client/crystal/_connector.py +7 -7
  3. esd_services_api_client/nexus/README.md +280 -0
  4. esd_services_api_client/nexus/__init__.py +18 -0
  5. esd_services_api_client/nexus/abstractions/__init__.py +18 -0
  6. esd_services_api_client/nexus/abstractions/logger_factory.py +64 -0
  7. esd_services_api_client/nexus/abstractions/nexus_object.py +64 -0
  8. esd_services_api_client/nexus/abstractions/socket_provider.py +51 -0
  9. esd_services_api_client/nexus/algorithms/__init__.py +22 -0
  10. esd_services_api_client/nexus/algorithms/_baseline_algorithm.py +56 -0
  11. esd_services_api_client/nexus/algorithms/distributed.py +53 -0
  12. esd_services_api_client/nexus/algorithms/minimalistic.py +44 -0
  13. esd_services_api_client/nexus/algorithms/recursive.py +58 -0
  14. esd_services_api_client/nexus/core/__init__.py +18 -0
  15. esd_services_api_client/nexus/core/app_core.py +259 -0
  16. esd_services_api_client/nexus/core/app_dependencies.py +205 -0
  17. esd_services_api_client/nexus/exceptions/__init__.py +20 -0
  18. esd_services_api_client/nexus/exceptions/_nexus_error.py +30 -0
  19. esd_services_api_client/nexus/exceptions/input_reader_error.py +51 -0
  20. esd_services_api_client/nexus/exceptions/startup_error.py +48 -0
  21. esd_services_api_client/nexus/input/__init__.py +22 -0
  22. esd_services_api_client/nexus/input/input_processor.py +94 -0
  23. esd_services_api_client/nexus/input/input_reader.py +109 -0
  24. esd_services_api_client/nexus/input/payload_reader.py +83 -0
  25. {esd_services_api_client-2.0.0.dist-info → esd_services_api_client-2.0.2.dist-info}/METADATA +5 -2
  26. esd_services_api_client-2.0.2.dist-info/RECORD +43 -0
  27. esd_services_api_client-2.0.0.dist-info/RECORD +0 -21
  28. {esd_services_api_client-2.0.0.dist-info → esd_services_api_client-2.0.2.dist-info}/LICENSE +0 -0
  29. {esd_services_api_client-2.0.0.dist-info → esd_services_api_client-2.0.2.dist-info}/WHEEL +0 -0
@@ -1 +1 @@
- __version__ = '2.0.0'
+ __version__ = '2.0.2'
@@ -282,19 +282,19 @@ class CrystalConnector:
              "sasUri": result.sas_uri,
          }
 
-         if debug and self._logger is not None:
+         if not debug:
+             run_response = self._http.post(url=get_api_path(), json=payload)
+             # raise if not successful
+             run_response.raise_for_status()
+             return
+
+         if self._logger is not None:
              self._logger.debug(
                  "Submitting result to {submission_url}, payload {payload}",
                  submission_url=get_api_path(),
                  payload=json.dumps(payload),
              )
 
-         else:
-             run_response = self._http.post(url=get_api_path(), json=payload)
-
-             # raise if not successful
-             run_response.raise_for_status()
-
      @staticmethod
      def read_input(
          *,
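
A condensed, hypothetical restatement of the reworked submission flow in `_connector.py` shown above; `http`, `api_path`, `payload`, `debug` and `logger` stand in for the connector's real attributes and are assumptions, not part of the package:

```python
# Sketch only: when debug is off, submit and raise on failure, then return early;
# when debug is on, only log the payload using adapta-style templated logging.
import json
from typing import Any, Optional

import requests


def submit_result(
    http: requests.Session,
    api_path: str,
    payload: dict,
    debug: bool,
    logger: Optional[Any] = None,
) -> None:
    if not debug:
        run_response = http.post(url=api_path, json=payload)
        # raise if not successful
        run_response.raise_for_status()
        return

    if logger is not None:
        logger.debug(
            "Submitting result to {submission_url}, payload {payload}",
            submission_url=api_path,
            payload=json.dumps(payload),
        )
```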
@@ -0,0 +1,280 @@
+ ## Nexus
+ Set the following environment variables for Azure:
+ ```
+ IS_LOCAL_RUN=1
+ NEXUS__ALGORITHM_OUTPUT_PATH=abfss://container@account.dfs.core.windows.net/path/to/result
+ NEXUS__METRIC_PROVIDER_CONFIGURATION={"metric_namespace": "test"}
+ NEXUS__QES_CONNECTION_STRING=qes://engine\=DELTA\;plaintext_credentials\={"auth_client_class":"adapta.security.clients.AzureClient"}\;settings\={}
+ NEXUS__STORAGE_CLIENT_CLASS=adapta.storage.blob.azure_storage_client.AzureStorageClient
+ NEXUS__ALGORITHM_INPUT_EXTERNAL_DATA_SOCKETS=[{"alias": "x", "data_path": "test/x", "data_format": "test"}, {"alias": "y", "data_path": "test/y", "data_format": "test"}]
+ PROTEUS__USE_AZURE_CREDENTIAL=1
+ ```
+
+ Example usage:
+
+ ```python
+ import asyncio
+ import json
+ import socketserver
+ import threading
+ from dataclasses import dataclass
+ from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
+ from typing import Dict, Optional
+
+ import pandas
+ from adapta.metrics import MetricsProvider
+ from adapta.storage.query_enabled_store import QueryEnabledStore
+ from dataclasses_json import DataClassJsonMixin
+ from injector import inject
+
+ from esd_services_api_client.nexus.abstractions.logger_factory import LoggerFactory
+ from esd_services_api_client.nexus.abstractions.socket_provider import (
+     ExternalSocketProvider,
+ )
+ from esd_services_api_client.nexus.core.app_core import Nexus
+ from esd_services_api_client.nexus.algorithms import MinimalisticAlgorithm
+ from esd_services_api_client.nexus.input import InputReader, InputProcessor
+ from pandas import DataFrame as PandasDataFrame
+
+ from esd_services_api_client.nexus.input.payload_reader import AlgorithmPayload
+
+
+ async def my_on_complete_func_1(**kwargs):
+     pass
+
+
+ async def my_on_complete_func_2(**kwargs):
+     pass
+
+
+ @dataclass
+ class MyAlgorithmPayload(AlgorithmPayload, DataClassJsonMixin):
+     x: Optional[list[int]] = None
+     y: Optional[list[int]] = None
+
+
+ @dataclass
+ class MyAlgorithmPayload2(AlgorithmPayload, DataClassJsonMixin):
+     z: list[int]
+     x: Optional[list[int]] = None
+     y: Optional[list[int]] = None
+
+
+ class MockRequestHandler(BaseHTTPRequestHandler):
+     """
+     HTTPServer Mock Request handler
+     """
+
+     def __init__(
+         self,
+         request: bytes,
+         client_address: tuple[str, int],
+         server: socketserver.BaseServer,
+     ):
+         """
+         Initialize request handler
+         :param request:
+         :param client_address:
+         :param server:
+         """
+         self._responses = {
+             "some/payload": (
+                 {
+                     # "x": [-1, 0, 2],
+                     # "y": [10, 11, 12],
+                     "z": [1, 2, 3]
+                 },
+                 200,
+             )
+         }
+         super().__init__(request, client_address, server)
+
+     def do_GET(self):  # pylint: disable=invalid-name
+         """Handle POST requests"""
+         current_url = self.path.removeprefix("/")
+
+         if current_url not in self._responses:
+             self.send_response(500, "Unknown URL")
+             return
+
+         self.send_response(self._responses[current_url][1])
+         self.send_header("Content-Type", "application/json")
+         self.end_headers()
+         self.wfile.write(json.dumps(self._responses[current_url][0]).encode("utf-8"))
+
+     def log_request(self, code=None, size=None):
+         """
+         Don't log anything
+         :param code:
+         :param size:
+         :return:
+         """
+         pass
+
+
+ class XReader(InputReader[MyAlgorithmPayload]):
+     async def _context_open(self):
+         pass
+
+     async def _context_close(self):
+         pass
+
+     @inject
+     def __init__(
+         self,
+         store: QueryEnabledStore,
+         metrics_provider: MetricsProvider,
+         logger_factory: LoggerFactory,
+         payload: MyAlgorithmPayload,
+         socket_provider: ExternalSocketProvider,
+         *readers: "InputReader"
+     ):
+         super().__init__(
+             socket_provider.socket("x"),
+             store,
+             metrics_provider,
+             logger_factory,
+             payload,
+             *readers
+         )
+
+     async def _read_input(self) -> PandasDataFrame:
+         self._logger.info(
+             "Payload: {payload}; Socket path: {socket_path}",
+             payload=self._payload.to_json(),
+             socket_path=self.socket.data_path,
+         )
+         return pandas.DataFrame([{"a": 1, "b": 2}, {"a": 2, "b": 3}])
+
+
+ class YReader(InputReader[MyAlgorithmPayload2]):
+     async def _context_open(self):
+         pass
+
+     async def _context_close(self):
+         pass
+
+     @inject
+     def __init__(
+         self,
+         store: QueryEnabledStore,
+         metrics_provider: MetricsProvider,
+         logger_factory: LoggerFactory,
+         payload: MyAlgorithmPayload2,
+         socket_provider: ExternalSocketProvider,
+         *readers: "InputReader"
+     ):
+         super().__init__(
+             socket_provider.socket("y"),
+             store,
+             metrics_provider,
+             logger_factory,
+             payload,
+             *readers
+         )
+
+     async def _read_input(self) -> PandasDataFrame:
+         self._logger.info(
+             "Payload: {payload}; Socket path: {socket_path}",
+             payload=self._payload.to_json(),
+             socket_path=self.socket.data_path,
+         )
+         return pandas.DataFrame([{"a": 10, "b": 12}, {"a": 11, "b": 13}])
+
+
+ class MyInputProcessor(InputProcessor):
+     async def _context_open(self):
+         pass
+
+     async def _context_close(self):
+         pass
+
+     @inject
+     def __init__(
+         self,
+         x: XReader,
+         y: YReader,
+         metrics_provider: MetricsProvider,
+         logger_factory: LoggerFactory,
+     ):
+         super().__init__(
+             x,
+             y,
+             metrics_provider=metrics_provider,
+             logger_factory=logger_factory,
+             payload=None,
+         )
+
+     async def process_input(self, **_) -> Dict[str, PandasDataFrame]:
+         inputs = await self._read_input()
+         return {
+             "x_ready": inputs["x"].assign(c=[-1, 1]),
+             "y_ready": inputs["y"].assign(c=[-1, 1]),
+         }
+
+
+ class MyAlgorithm(MinimalisticAlgorithm):
+     async def _context_open(self):
+         pass
+
+     async def _context_close(self):
+         pass
+
+     @inject
+     def __init__(
+         self,
+         input_processor: MyInputProcessor,
+         metrics_provider: MetricsProvider,
+         logger_factory: LoggerFactory,
+     ):
+         super().__init__(input_processor, metrics_provider, logger_factory)
+
+     async def _run(
+         self, x_ready: PandasDataFrame, y_ready: PandasDataFrame, **kwargs
+     ) -> PandasDataFrame:
+         return pandas.concat([x_ready, y_ready])
+
+
+ async def main():
+     """
+     Mock HTTP Server
+     :return:
+     """
+     with ThreadingHTTPServer(("localhost", 9876), MockRequestHandler) as server:
+         server_thread = threading.Thread(target=server.serve_forever)
+         server_thread.daemon = True
+         server_thread.start()
+         nexus = (
+             await Nexus.create()
+             .add_reader(XReader)
+             .add_reader(YReader)
+             .use_processor(MyInputProcessor)
+             .use_algorithm(MyAlgorithm)
+             .inject_payload(MyAlgorithmPayload, MyAlgorithmPayload2)
+         )
+
+         await nexus.activate()
+         server.shutdown()
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
+
+ ```
+
+ Run this code as `sample.py`:
+
+ ```shell
+ python3 sample.py --sas-uri http://localhost:9876/some/payload --request-id test
+ ```
+
+ Produces the following:
+
+ ```
+ Running _read
+ Payload: {"x": null, "y": null}; Socket path: test/x
+ Finished reading X from path test/x in 0.00s seconds
+ Running _read
+ Payload: {"z": [1, 2, 3], "x": null, "y": null}; Socket path: test/y
+ Finished reading Y from path test/y in 0.00s seconds
+ ```
@@ -0,0 +1,18 @@
+ """
+ Import index.
+ """
+
+ # Copyright (c) 2023. ECCO Sneaks & Data
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
@@ -0,0 +1,18 @@
+ """
+ Import index.
+ """
+
+ # Copyright (c) 2023. ECCO Sneaks & Data
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
@@ -0,0 +1,64 @@
+ """
+ Logger factory for async loggers.
+ """
+
+ # Copyright (c) 2023. ECCO Sneaks & Data
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import json
+ import os
+ from logging import StreamHandler
+ from typing import final, Type, TypeVar, Optional, Dict
+
+ from adapta.logs._async_logger import _AsyncLogger, create_async_logger
+ from adapta.logs.handlers.datadog_api_handler import DataDogApiHandler
+ from adapta.logs.models import LogLevel
+
+ TLogger = TypeVar("TLogger")  # pylint: disable=C0103:
+
+
+ @final
+ class LoggerFactory:
+     """
+     Async logger provisioner.
+     """
+
+     def __init__(self):
+         self._log_handlers = [
+             StreamHandler(),
+         ]
+         if "NEXUS__DATADOG_LOGGER_CONFIGURATION" in os.environ:
+             self._log_handlers.append(
+                 DataDogApiHandler(
+                     **json.loads(os.getenv("NEXUS__DATADOG_LOGGER_CONFIGURATION"))
+                 )
+             )
+
+     def create_logger(
+         self,
+         logger_type: Type[TLogger],
+         fixed_template: Optional[Dict[str, Dict[str, str]]] = None,
+         fixed_template_delimiter=", ",
+     ) -> _AsyncLogger[TLogger]:
+         """
+         Creates an async-safe logger for the provided class name.
+         """
+         return create_async_logger(
+             logger_type=logger_type,
+             log_handlers=self._log_handlers,
+             min_log_level=LogLevel(os.getenv("NEXUS__LOG_LEVEL", "INFO")),
+             fixed_template=fixed_template,
+             fixed_template_delimiter=fixed_template_delimiter,
+         )
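
A minimal usage sketch of `LoggerFactory`, assuming the adapta async logger API used elsewhere in this diff (`start`/`stop` plus templated log calls); `Worker` is an illustrative class, not part of the package:

```python
# Sketch only: create an async-safe logger for a class, mirroring how NexusObject
# wires it up in the next file of this diff.
from esd_services_api_client.nexus.abstractions.logger_factory import LoggerFactory


class Worker:
    """Illustrative component that wants its own logger."""


factory = LoggerFactory()
logger = factory.create_logger(logger_type=Worker)

logger.start()
logger.info("Worker {name} is ready", name="worker-1")
logger.stop()
```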
@@ -0,0 +1,64 @@
+ """
+ Base classes for all objects used by Nexus.
+ """
+
+ # Copyright (c) 2023. ECCO Sneaks & Data
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+
+ from abc import ABC, abstractmethod
+ from typing import Generic, TypeVar
+
+ from adapta.metrics import MetricsProvider
+
+ from esd_services_api_client.nexus.abstractions.logger_factory import LoggerFactory
+
+
+ TPayload = TypeVar("TPayload")  # pylint: disable=C0103
+
+
+ class NexusObject(Generic[TPayload], ABC):
+     """
+     Base class for all Nexus objects.
+     """
+
+     def __init__(
+         self,
+         metrics_provider: MetricsProvider,
+         logger_factory: LoggerFactory,
+     ):
+         self._metrics_provider = metrics_provider
+         self._logger = logger_factory.create_logger(logger_type=self.__class__)
+
+     async def __aenter__(self):
+         self._logger.start()
+         await self._context_open()
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         self._logger.stop()
+         await self._context_close()
+
+     @abstractmethod
+     async def _context_open(self):
+         """
+         Optional actions to perform on context activation.
+         """
+
+     @abstractmethod
+     async def _context_close(self):
+         """
+         Optional actions to perform on context closure.
+         """
@@ -0,0 +1,51 @@
+ """
+ Socket provider for all data sockets used by algorithms.
+ """
+ import json
+
+ # Copyright (c) 2023. ECCO Sneaks & Data
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ from typing import final, Optional
+
+ from adapta.process_communication import DataSocket
+
+
+ @final
+ class ExternalSocketProvider:
+     """
+     Wraps a socket collection
+     """
+
+     def __init__(self, *sockets: DataSocket):
+         self._sockets = {socket.alias: socket for socket in sockets}
+
+     def socket(self, name: str) -> Optional[DataSocket]:
+         """
+         Retrieve a socket if it exists.
+         """
+         return self._sockets.get(name, None)
+
+     @classmethod
+     def from_serialized(cls, socket_list_ser: str) -> "ExternalSocketProvider":
+         """
+         Creates a SocketProvider from a list of serialized sockets
+         """
+         return cls(
+             *[
+                 DataSocket.from_dict(socket_dict)
+                 for socket_dict in json.loads(socket_list_ser)
+             ]
+         )
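
A small usage sketch, assuming a serialized socket list shaped like the `NEXUS__ALGORITHM_INPUT_EXTERNAL_DATA_SOCKETS` value shown in the README above:

```python
# Sketch only: build the provider from a JSON socket list and look up sockets by
# alias; unknown aliases return None rather than raising.
from esd_services_api_client.nexus.abstractions.socket_provider import ExternalSocketProvider

provider = ExternalSocketProvider.from_serialized(
    '[{"alias": "x", "data_path": "test/x", "data_format": "test"}]'
)

x_socket = provider.socket("x")    # DataSocket with data_path "test/x"
missing = provider.socket("nope")  # None
```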
@@ -0,0 +1,22 @@
+ """
+ Import index.
+ """
+
+ # Copyright (c) 2023. ECCO Sneaks & Data
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ from esd_services_api_client.nexus.algorithms.minimalistic import *
+ from esd_services_api_client.nexus.algorithms.recursive import *
+ from esd_services_api_client.nexus.algorithms.distributed import *
@@ -0,0 +1,56 @@
+ """
+ Base algorithm
+ """
+
+ # Copyright (c) 2023. ECCO Sneaks & Data
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+
+ from abc import abstractmethod
+
+ from adapta.metrics import MetricsProvider
+ from pandas import DataFrame as PandasDataFrame
+
+ from esd_services_api_client.nexus.abstractions.nexus_object import NexusObject
+ from esd_services_api_client.nexus.abstractions.logger_factory import LoggerFactory
+ from esd_services_api_client.nexus.input.input_processor import InputProcessor
+
+
+ class BaselineAlgorithm(NexusObject):
+     """
+     Base class for all algorithm implementations.
+     """
+
+     def __init__(
+         self,
+         input_processor: InputProcessor,
+         metrics_provider: MetricsProvider,
+         logger_factory: LoggerFactory,
+     ):
+         super().__init__(metrics_provider, logger_factory)
+         self._input_processor = input_processor
+
+     @abstractmethod
+     async def _run(self, **kwargs) -> PandasDataFrame:
+         """
+         Core logic for this algorithm. Implementing this method is mandatory.
+         """
+
+     async def run(self, **kwargs) -> PandasDataFrame:
+         """
+         Coroutine that executes the algorithm logic.
+         """
+         async with self._input_processor as input_processor:
+             return await self._run(**(await input_processor.process_input(**kwargs)))
@@ -0,0 +1,53 @@
+ """
+ Algorithm that supports splitting the problem into sub-problems and combining the results.
+ Sub-problems can be Distributed, Minimalistic or Recursive as well.
+ """
+
+ # Copyright (c) 2023. ECCO Sneaks & Data
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import asyncio
+ from abc import ABC, abstractmethod
+
+ from pandas import DataFrame as PandasDataFrame
+ from esd_services_api_client.nexus.algorithms._baseline_algorithm import (
+     BaselineAlgorithm,
+ )
+
+
+ class DistributedAlgorithm(BaselineAlgorithm, ABC):
+     """
+     Distributed algorithm base class.
+     """
+
+     @abstractmethod
+     async def _split(self, **_) -> list[BaselineAlgorithm]:
+         """
+         Sub-problem generator.
+         """
+
+     @abstractmethod
+     async def _fold(self, *split_tasks: asyncio.Task) -> PandasDataFrame:
+         """
+         Sub-problem result aggregator.
+         """
+
+     async def _run(self, **kwargs) -> PandasDataFrame:
+         splits = await self._split(**kwargs)
+         tasks = [asyncio.create_task(split.run(**kwargs)) for split in splits]
+
+         await asyncio.wait(*tasks)
+
+         return await self._fold(*tasks)
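
A hedged sketch of a concrete `DistributedAlgorithm`; `PartitionedAlgorithm` and its `sub_algorithms` constructor argument are illustrative assumptions, showing only how `_split` and `_fold` plug into the `_run` above:

```python
# Sketch only: _split returns pre-built sub-algorithms to run concurrently, and
# _fold concatenates the dataframes produced by their finished tasks.
import asyncio
from typing import Iterable

import pandas
from adapta.metrics import MetricsProvider
from pandas import DataFrame as PandasDataFrame

from esd_services_api_client.nexus.abstractions.logger_factory import LoggerFactory
from esd_services_api_client.nexus.algorithms._baseline_algorithm import BaselineAlgorithm
from esd_services_api_client.nexus.algorithms.distributed import DistributedAlgorithm
from esd_services_api_client.nexus.input.input_processor import InputProcessor


class PartitionedAlgorithm(DistributedAlgorithm):
    """Runs a fixed set of sub-algorithms concurrently and concatenates their results."""

    def __init__(
        self,
        input_processor: InputProcessor,
        metrics_provider: MetricsProvider,
        logger_factory: LoggerFactory,
        sub_algorithms: Iterable[BaselineAlgorithm],
    ):
        super().__init__(input_processor, metrics_provider, logger_factory)
        self._sub_algorithms = list(sub_algorithms)

    async def _context_open(self):
        pass

    async def _context_close(self):
        pass

    async def _split(self, **_) -> list[BaselineAlgorithm]:
        return self._sub_algorithms

    async def _fold(self, *split_tasks: asyncio.Task) -> PandasDataFrame:
        # each task result is the PandasDataFrame returned by a sub-algorithm's run()
        return pandas.concat([task.result() for task in split_tasks])
```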