synapse-sdk 1.0.0a31__py3-none-any.whl → 1.0.0a33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of synapse-sdk might be problematic.

Files changed (28)
  1. synapse_sdk/clients/backend/__init__.py +2 -0
  2. synapse_sdk/clients/backend/annotation.py +4 -4
  3. synapse_sdk/clients/backend/dataset.py +57 -5
  4. synapse_sdk/clients/backend/hitl.py +17 -0
  5. synapse_sdk/clients/backend/integration.py +3 -1
  6. synapse_sdk/clients/backend/models.py +44 -0
  7. synapse_sdk/clients/base.py +61 -16
  8. synapse_sdk/plugins/categories/base.py +40 -0
  9. synapse_sdk/plugins/categories/export/actions/export.py +168 -28
  10. synapse_sdk/plugins/categories/export/templates/plugin/export.py +43 -33
  11. synapse_sdk/plugins/categories/smart_tool/templates/config.yaml +2 -0
  12. synapse_sdk/plugins/categories/upload/actions/upload.py +292 -0
  13. synapse_sdk/plugins/categories/upload/templates/config.yaml +6 -0
  14. synapse_sdk/plugins/categories/upload/templates/plugin/__init__.py +0 -0
  15. synapse_sdk/plugins/categories/upload/templates/plugin/upload.py +44 -0
  16. synapse_sdk/plugins/enums.py +3 -1
  17. synapse_sdk/plugins/models.py +16 -0
  18. synapse_sdk/utils/storage/__init__.py +20 -2
  19. {synapse_sdk-1.0.0a31.dist-info → synapse_sdk-1.0.0a33.dist-info}/METADATA +3 -2
  20. {synapse_sdk-1.0.0a31.dist-info → synapse_sdk-1.0.0a33.dist-info}/RECORD +26 -22
  21. {synapse_sdk-1.0.0a31.dist-info → synapse_sdk-1.0.0a33.dist-info}/WHEEL +1 -1
  22. synapse_sdk/plugins/categories/export/actions/utils.py +0 -5
  23. synapse_sdk/plugins/categories/import/actions/import.py +0 -10
  24. /synapse_sdk/plugins/categories/{import → upload}/__init__.py +0 -0
  25. /synapse_sdk/plugins/categories/{import → upload}/actions/__init__.py +0 -0
  26. {synapse_sdk-1.0.0a31.dist-info → synapse_sdk-1.0.0a33.dist-info}/entry_points.txt +0 -0
  27. {synapse_sdk-1.0.0a31.dist-info → synapse_sdk-1.0.0a33.dist-info/licenses}/LICENSE +0 -0
  28. {synapse_sdk-1.0.0a31.dist-info → synapse_sdk-1.0.0a33.dist-info}/top_level.txt +0 -0
synapse_sdk/clients/backend/__init__.py
@@ -1,6 +1,7 @@
 from synapse_sdk.clients.backend.annotation import AnnotationClientMixin
 from synapse_sdk.clients.backend.core import CoreClientMixin
 from synapse_sdk.clients.backend.dataset import DatasetClientMixin
+from synapse_sdk.clients.backend.hitl import HITLClientMixin
 from synapse_sdk.clients.backend.integration import IntegrationClientMixin
 from synapse_sdk.clients.backend.ml import MLClientMixin

@@ -11,6 +12,7 @@ class BackendClient(
     DatasetClientMixin,
     IntegrationClientMixin,
     MLClientMixin,
+    HITLClientMixin,
 ):
     name = 'Backend'
     token = None
synapse_sdk/clients/backend/annotation.py
@@ -11,14 +11,14 @@ class AnnotationClientMixin(BaseClient):
         path = f'task_tags/{pk}/'
         return self._get(path)

-    def list_task_tags(self, data):
+    def list_task_tags(self, params):
         path = 'task_tags/'
-        return self._list(path, data=data)
+        return self._list(path, params=params)

-    def list_tasks(self, data, url_conversion=None, list_all=False):
+    def list_tasks(self, params=None, url_conversion=None, list_all=False):
         path = 'tasks/'
         url_conversion = get_default_url_conversion(url_conversion, files_fields=['files'])
-        return self._list(path, data=data, url_conversion=url_conversion, list_all=list_all)
+        return self._list(path, params=params, url_conversion=url_conversion, list_all=list_all)

     def create_tasks(self, data):
         path = 'tasks/'
synapse_sdk/clients/backend/dataset.py
@@ -1,4 +1,6 @@
 from multiprocessing import Pool
+from pathlib import Path
+from typing import Dict, Optional

 from tqdm import tqdm

@@ -11,21 +13,59 @@ class DatasetClientMixin(BaseClient):
         path = 'datasets/'
         return self._list(path)

-    def create_data_file(self, file_path):
+    def get_dataset(self, dataset_id):
+        """Get dataset from synapse-backend.
+
+        Args:
+            dataset_id: The dataset id to get.
+        """
+        path = f'datasets/{dataset_id}/?expand=file_specifications'
+        return self._get(path)
+
+    def create_data_file(self, file_path: Path):
+        """Create data file to synapse-backend.
+
+        Args:
+            file_path: The file pathlib object to upload.
+        """
         path = 'data_files/'
         return self._post(path, files={'file': file_path})

     def create_data_units(self, data):
+        """Create data units to synapse-backend.
+
+        Args:
+            data: The data bindings to upload from create_data_file interface.
+        """
         path = 'data_units/'
         return self._post(path, data=data)

-    def import_dataset(self, dataset_id, dataset, project_id=None, batch_size=1000, process_pool=10):
+    def upload_dataset(
+        self,
+        dataset_id: int,
+        dataset: Dict,
+        project_id: Optional[int] = None,
+        batch_size: int = 1000,
+        process_pool: int = 10,
+    ):
+        """Upload dataset to synapse-backend.
+
+        Args:
+            dataset_id: The dataset id to upload the data to.
+            dataset: The dataset to upload.
+                * structure:
+                    - files: The files to upload. (key: file name, value: file pathlib object)
+                    - meta: The meta data to upload.
+            project_id: The project id to upload the data to.
+            batch_size: The batch size to upload the data.
+            process_pool: The process pool to upload the data.
+        """
         # TODO validate dataset with schema

         params = [(data, dataset_id) for data in dataset]

         with Pool(processes=process_pool) as pool:
-            dataset = pool.starmap(self.import_data_file, tqdm(params))
+            dataset = pool.starmap(self.upload_data_file, tqdm(params))

         batches = get_batched_list(dataset, batch_size)

@@ -36,13 +76,25 @@ class DatasetClientMixin(BaseClient):
             tasks_data = []
             for data, data_unit in zip(batch, data_units):
                 task_data = {'project': project_id, 'data_unit': data_unit['id']}
-                # TODO: 추후 import Task data 저장 필요 해당 로직 추가 필요.
+                # TODO: Additional logic needed here if task data storage is required during import.

                 tasks_data.append(task_data)

             self.create_tasks(tasks_data)

-    def import_data_file(self, data, dataset_id):
+    def upload_data_file(self, data: Dict, dataset_id: int) -> Dict:
+        """Upload files to synapse-backend.
+
+        Args:
+            data: The data to upload.
+                * structure:
+                    - files: The files to upload. (key: file name, value: file pathlib object)
+                    - meta: The meta data to upload.
+            dataset_id: The dataset id to upload the data to.
+
+        Returns:
+            Dict: The result of the upload.
+        """
         for name, path in data['files'].items():
             data_file = self.create_data_file(path)
             data['dataset'] = dataset_id
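
Taken together, these hunks rename the import_* helpers to upload_* and document the expected payload shape. A minimal usage sketch under those docstrings; the BackendClient constructor arguments are hypothetical, since this diff does not show its __init__:

```python
from pathlib import Path

from synapse_sdk.clients.backend import BackendClient

# Hypothetical constructor arguments; not shown in this diff.
client = BackendClient(base_url='https://backend.example.com', token='...')

# Each entry follows the documented structure: a files mapping
# (name -> pathlib object) plus its metadata. upload_dataset uploads the
# files in a process pool, then creates data units (and tasks) in batches.
dataset = [
    {'files': {'image': Path('samples/0001.jpg')}, 'meta': {'source': 'camera-a'}},
    {'files': {'image': Path('samples/0002.jpg')}, 'meta': {'source': 'camera-a'}},
]

client.upload_dataset(dataset_id=1, dataset=dataset, project_id=2)
```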
synapse_sdk/clients/backend/hitl.py
@@ -0,0 +1,17 @@
+from synapse_sdk.clients.base import BaseClient
+from synapse_sdk.clients.utils import get_default_url_conversion
+
+
+class HITLClientMixin(BaseClient):
+    def get_assignment(self, pk):
+        path = f'assignments/{pk}/'
+        return self._get(path)
+
+    def list_assignments(self, params=None, url_conversion=None, list_all=False):
+        path = 'assignments/'
+        url_conversion = get_default_url_conversion(url_conversion, files_fields=['files'])
+        return self._list(path, params=params, url_conversion=url_conversion, list_all=list_all)
+
+    def set_tags_assignments(self, data, params=None):
+        path = 'assignments/set_tags/'
+        return self._post(path, payload=data, params=params)
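
Because HITLClientMixin is mixed into BackendClient (first hunk above), the assignment endpoints become available on the shared client. A sketch under the same hypothetical constructor; the set_tags payload shape is also an assumption, as only the endpoint path appears in this diff:

```python
from synapse_sdk.clients.backend import BackendClient

# Hypothetical constructor arguments; not shown in this diff.
client = BackendClient(base_url='https://backend.example.com', token='...')

# With list_all=True, the export handlers below unpack the return value as
# (results, count), so the same shape is assumed here.
assignments, count = client.list_assignments(params={'project': 1}, list_all=True)

# Hypothetical payload shape for the new assignments/set_tags/ endpoint.
client.set_tags_assignments(data={'assignments': [a['id'] for a in assignments], 'tags': ['reviewed']})
```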
synapse_sdk/clients/backend/integration.py
@@ -1,3 +1,4 @@
+from synapse_sdk.clients.backend.models import Storage
 from synapse_sdk.clients.base import BaseClient
 from synapse_sdk.utils.file import convert_file_to_base64

@@ -79,5 +80,6 @@ class IntegrationClientMixin(BaseClient):
         return self._list(path, params=params, list_all=list_all)

     def get_storage(self, pk):
+        """Get specific storage data from synapse backend."""
         path = f'storages/{pk}/'
-        return self._get(path)
+        return self._get(path, pydantic_model=Storage)
synapse_sdk/clients/backend/models.py
@@ -0,0 +1,44 @@
+from enum import Enum
+from typing import Dict
+
+from pydantic import BaseModel
+
+
+class StorageCategory(str, Enum):
+    """Synapse Backend Storage Category Enum."""
+
+    INTERNAL = 'internal'
+    EXTERNAL = 'external'
+
+
+class StorageProvider(str, Enum):
+    """Synapse Backend Storage Provider Enum."""
+
+    AMAZON_S3 = 'amazon_s3'
+    AZURE = 'azure'
+    DIGITAL_OCEAN = 'digital_ocean'
+    FILE_SYSTEM = 'file_system'
+    FTP = 'ftp'
+    SFTP = 'sftp'
+    MINIO = 'minio'
+    GCP = 'gcp'
+
+
+class Storage(BaseModel):
+    """Synapse Backend Storage Model.
+
+    Attrs:
+        id (int): The storage pk.
+        name (str): The storage name.
+        category (str): The storage category. (ex: internal, external)
+        provider (str): The storage provider. (ex: s3, gcp)
+        configuration (Dict): The storage configuration.
+        is_default (bool): The storage is default for Synapse backend workspace.
+    """
+
+    id: int
+    name: str
+    category: StorageCategory
+    provider: StorageProvider
+    configuration: Dict
+    is_default: bool
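
Since Storage subclasses pydantic's BaseModel, a backend payload can be checked directly. A minimal sketch with illustrative values (model_validate requires pydantic v2):

```python
from synapse_sdk.clients.backend.models import Storage

# Example payload shaped like a storage record; the values are illustrative.
payload = {
    'id': 1,
    'name': 'default-bucket',
    'category': 'internal',
    'provider': 'amazon_s3',
    'configuration': {'bucket': 'synapse-data'},
    'is_default': True,
}

storage = Storage.model_validate(payload)
assert storage.provider.value == 'amazon_s3'  # str Enum members compare by value
```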
synapse_sdk/clients/base.py
@@ -47,15 +47,12 @@ class BaseClient:
         # If files are included in the request, open them as binary files
         if kwargs.get('files') is not None:
             for name, file in kwargs['files'].items():
-                # If file is a path string, bind it as a Path object and open
+                # Handle both string and Path object cases
                 if isinstance(file, str):
-                    opened_file = Path(file).open(mode='rb')
-                    kwargs['files'][name] = opened_file
-                    opened_files.append(opened_file)
-                # If file is a Path object, open it directly
-                elif isinstance(file, Path):
+                    file = Path(file)
+                if isinstance(file, Path):
                     opened_file = file.open(mode='rb')
-                    kwargs['files'][name] = opened_file
+                    kwargs['files'][name] = (file.name, opened_file)
                     opened_files.append(opened_file)
         if 'data' in kwargs:
             for name, value in kwargs['data'].items():
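
The behavioral change in this hunk is the (file.name, opened_file) tuple: requests interprets a 2-tuple in files as (filename, fileobj), so the multipart part keeps the file's real name instead of falling back to the field name. For illustration, outside the SDK (URL is a placeholder):

```python
import requests

# A 2-tuple is read by requests as (filename, fileobj), so the server
# receives 'report.pdf' as the uploaded part's filename.
with open('report.pdf', 'rb') as fh:
    requests.post('https://example.com/upload', files={'file': ('report.pdf', fh)})
```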
@@ -67,6 +64,7 @@ class BaseClient:
             kwargs['data'] = json.dumps(kwargs['data'])

         try:
+            # Send request
             response = getattr(self.requests_session, method)(url, headers=headers, **kwargs)
             if not response.ok:
                 raise ClientError(
@@ -87,26 +85,59 @@
         except ValueError:
             return response.text

-    def _get(self, path, url_conversion=None, **kwargs):
+    def _get(self, path, url_conversion=None, pydantic_model=None, **kwargs):
+        """
+        Perform a GET request and optionally convert response to a pydantic model.
+
+        Args:
+            path (str): URL path to request.
+            url_conversion (dict, optional): Configuration for URL to path conversion.
+            pydantic_model (Type, optional): Pydantic model to convert the response to.
+            **kwargs: Additional keyword arguments to pass to the request.
+
+        Returns:
+            The response data, optionally converted to a pydantic model.
+        """
         response = self._request('get', path, **kwargs)
+
         if url_conversion:
             if url_conversion['is_list']:
                 files_url_to_path_from_objs(response['results'], **url_conversion, is_async=True)
             else:
                 files_url_to_path_from_objs(response, **url_conversion)
+
+        if pydantic_model:
+            return self._validate_response_with_pydantic_model(response, pydantic_model)
+
         return response

-    def _post(self, path, **kwargs):
-        return self._request('post', path, **kwargs)
+    def _post(self, path, pydantic_model=None, **kwargs):
+        response = self._request('post', path, **kwargs)
+        if pydantic_model:
+            return self._validate_response_with_pydantic_model(response, pydantic_model)
+        else:
+            return response

-    def _put(self, path, **kwargs):
-        return self._request('put', path, **kwargs)
+    def _put(self, path, pydantic_model=None, **kwargs):
+        response = self._request('put', path, **kwargs)
+        if pydantic_model:
+            return self._validate_response_with_pydantic_model(response, pydantic_model)
+        else:
+            return response

-    def _patch(self, path, **kwargs):
-        return self._request('patch', path, **kwargs)
+    def _patch(self, path, pydantic_model=None, **kwargs):
+        response = self._request('patch', path, **kwargs)
+        if pydantic_model:
+            return self._validate_response_with_pydantic_model(response, pydantic_model)
+        else:
+            return response

-    def _delete(self, path, **kwargs):
-        return self._request('delete', path, **kwargs)
+    def _delete(self, path, pydantic_model=None, **kwargs):
+        response = self._request('delete', path, **kwargs)
+        if pydantic_model:
+            return self._validate_response_with_pydantic_model(response, pydantic_model)
+        else:
+            return response

     def _list(self, path, url_conversion=None, list_all=False, **kwargs):
         response = self._get(path, **kwargs)
@@ -123,3 +154,17 @@ class BaseClient:

     def exists(self, api, *args, **kwargs):
         return getattr(self, api)(*args, **kwargs)['count'] > 0
+
+    def _validate_response_with_pydantic_model(self, response, pydantic_model):
+        """Validate a response with a pydantic model."""
+        # Check if model is a pydantic model (has the __pydantic_model__ attribute)
+        if (
+            hasattr(pydantic_model, '__pydantic_model__')
+            or hasattr(pydantic_model, 'model_validate')
+            or hasattr(pydantic_model, 'parse_obj')
+        ):
+            pydantic_model.model_validate(response)
+            return response
+        else:
+            # Not a pydantic model
+            raise TypeError('The provided model is not a pydantic model')
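
Note that _validate_response_with_pydantic_model discards the parsed model instance and returns the raw response, so pydantic_model acts as a validation gate rather than a deserializer. A sketch of the effect through get_storage (constructor arguments again hypothetical):

```python
from synapse_sdk.clients.backend import BackendClient

# Hypothetical constructor arguments; not shown in this diff.
client = BackendClient(base_url='https://backend.example.com', token='...')

# get_storage passes pydantic_model=Storage: an invalid payload raises a
# pydantic ValidationError, but a valid one still comes back as a plain dict.
storage = client.get_storage(1)
print(storage['provider'])
```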
synapse_sdk/plugins/categories/base.py
@@ -17,6 +17,30 @@ from synapse_sdk.utils.pydantic.errors import pydantic_to_drf_error


 class Action:
+    """Base class for all plugin actions.
+
+    Attrs:
+        name (str): The name of the action.
+        category (PluginCategory): The category of the action.
+        method (RunMethod): The method to run of the action.
+        run_class (Run): The class to run the action.
+        params_model (BaseModel): The model to validate the params.
+        progress_categories (List[str]): The categories to update the progress.
+        params (Dict): The params to run the action.
+        plugin_config (Dict): The plugin config.
+        plugin_release (PluginRelease): The plugin release.
+        config (Dict): The action config.
+        requirements (List[str]): The requirements to install.
+        job_id (str): The job id.
+        direct (bool): The flag to run the action directly.
+        debug (bool): The flag to run the action in debug mode.
+        envs (Dict): The runtime envs.
+        run (Run): The run instance.
+
+    Raises:
+        ActionError: If the action fails.
+    """
+
     # class 변수
     name = None
     category = None
@@ -159,11 +183,19 @@ class Action:
         return getattr(self, f'start_by_{self.method.value}')()

     def start(self):
+        """Start the action.
+
+        TODO: Specify the return type of start method for overrided methods.
+        """
         if self.method == RunMethod.JOB:
             return self.entrypoint(self.run, **self.params)
         return self.entrypoint(**self.params)

     def start_by_task(self):
+        """Ray Task based execution.
+
+        * A task method that simply executes the entrypoint without job management functionality.
+        """
         import ray
         from ray.exceptions import RayTaskError

@@ -195,6 +227,10 @@ class Action:
             raise ActionError(e.cause)

     def start_by_job(self):
+        """Ray Job based execution.
+
+        * Executes the entrypoint with Ray job. Ray job manages the entrypoint execution and stores the results.
+        """
         main_options = []
         options = ['run', '--direct']
         arguments = [self.name, f'{json.dumps(json.dumps(self.params))}']
@@ -215,6 +251,10 @@ class Action:
         )

     def start_by_restapi(self):
+        """Ray Serve based execution.
+
+        * This method executes a Fastapi endpoint defined within the Plugin.
+        """
         path = self.params.pop('path', '')
         method = self.params.pop('method')
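
For orientation, the documented attributes map onto subclass declarations like ExportAction below. A hypothetical minimal subclass; the register_action import path and the PluginCategory member are assumptions not shown in this diff (RunMethod.JOB does appear in start() above):

```python
from pydantic import BaseModel

from synapse_sdk.plugins.categories.base import Action
# Hypothetical import path; the diff shows @register_action in use but not its origin.
from synapse_sdk.plugins.categories.decorators import register_action
from synapse_sdk.plugins.enums import PluginCategory, RunMethod


class GreetParams(BaseModel):
    name: str


@register_action
class GreetAction(Action):
    name = 'greet'
    category = PluginCategory.EXPORT  # assumed enum member
    method = RunMethod.JOB
    params_model = GreetParams
```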
synapse_sdk/plugins/categories/export/actions/export.py
@@ -1,3 +1,6 @@
+from abc import ABC, abstractmethod
+from typing import Any, Literal
+
 from pydantic import BaseModel, field_validator
 from pydantic_core import PydanticCustomError

@@ -9,11 +12,158 @@ from synapse_sdk.plugins.enums import PluginCategory, RunMethod
 from synapse_sdk.utils.storage import get_pathlib


+class ExportTargetHandler(ABC):
+    """
+    Abstract base class for handling export targets.
+
+    This class defines the blueprint for export target handlers, requiring the implementation
+    of methods to validate filters, retrieve results, and process collections of results.
+    """
+
+    @abstractmethod
+    def validate_filter(self, value: dict, client: Any):
+        """
+        Validate filter query params to request original data from api.
+
+        Args:
+            value (dict): The filter criteria to validate.
+            client (Any): The client used to validate the filter.
+
+        Raises:
+            PydanticCustomError: If the filter criteria are invalid.
+
+        Returns:
+            dict: The validated filter criteria.
+        """
+        pass
+
+    @abstractmethod
+    def get_results(self, client: Any, filters: dict):
+        """
+        Retrieve original data from target sources.
+
+        Args:
+            client (Any): The client used to retrieve the results.
+            filters (dict): The filter criteria to apply.
+
+        Returns:
+            tuple: A tuple containing the results and the total count of results.
+        """
+        pass
+
+    @abstractmethod
+    def get_export_item(self, results):
+        """
+        Providing elements to build export data.
+
+        Args:
+            results (list): The results to process.
+
+        Yields:
+            generator: A generator that yields processed data items.
+        """
+        pass
+
+
+class AssignmentExportTargetHandler(ExportTargetHandler):
+    def validate_filter(self, value: dict, client: Any):
+        if 'project' not in value:
+            raise PydanticCustomError('missing_field', _('Project is required for Assignment.'))
+        try:
+            client.list_assignments(params=value)
+        except ClientError:
+            raise PydanticCustomError('client_error', _('Unable to get Assignment.'))
+        return value
+
+    def get_results(self, client: Any, filters: dict):
+        return client.list_assignments(params=filters, list_all=True)
+
+    def get_export_item(self, results):
+        for result in results:
+            yield {
+                'data': result['data'],
+                'files': result['file'],
+                'id': result['id'],
+            }
+
+
+class GroundTruthExportTargetHandler(ExportTargetHandler):
+    def validate_filter(self, value: dict, client: Any):
+        if 'ground_truth_dataset_version' not in value:
+            raise PydanticCustomError('missing_field', _('Ground Truth dataset version is required.'))
+        try:
+            client.get_ground_truth_version(value['ground_truth_dataset_version'])
+        except ClientError:
+            raise PydanticCustomError('client_error', _('Unable to get Ground Truth dataset version.'))
+        return value
+
+    def get_results(self, client: Any, filters: dict):
+        filters['ground_truth_dataset_versions'] = filters.pop('ground_truth_dataset_version')
+        return client.list_ground_truth_events(params=filters, list_all=True)
+
+    def get_export_item(self, results):
+        for result in results:
+            files_key = next(iter(result['data_unit']['files']))
+            yield {
+                'data': result['data'],
+                'files': result['data_unit']['files'][files_key],
+                'id': result['ground_truth'],
+            }
+
+
+class TaskExportTargetHandler(ExportTargetHandler):
+    def validate_filter(self, value: dict, client: Any):
+        if 'project' not in value:
+            raise PydanticCustomError('missing_field', _('Project is required for Task.'))
+        try:
+            client.list_tasks(params=value)
+        except ClientError:
+            raise PydanticCustomError('client_error', _('Unable to get Task.'))
+        return value
+
+    def get_results(self, client: Any, filters: dict):
+        filters['expand'] = 'data_unit'
+        return client.list_tasks(params=filters, list_all=True)
+
+    def get_export_item(self, results):
+        for result in results:
+            files_key = next(iter(result['data_unit']['files']))
+            yield {
+                'data': result['data'],
+                'files': result['data_unit']['files'][files_key],
+                'id': result['id'],
+            }
+
+
+class TargetHandlerFactory:
+    @staticmethod
+    def get_handler(target: str) -> ExportTargetHandler:
+        if target == 'assignment':
+            return AssignmentExportTargetHandler()
+        elif target == 'ground_truth':
+            return GroundTruthExportTargetHandler()
+        elif target == 'task':
+            return TaskExportTargetHandler()
+        else:
+            raise ValueError(f'Unknown target: {target}')
+
+
 class ExportParams(BaseModel):
+    """
+    Parameters for the export action.
+
+    Attributes:
+        storage (int): The storage ID to save the exported data.
+        save_original_file (bool): Whether to save the original file.
+        path (str): The path to save the exported data.
+        target (str): The target source to export data from. (ex. ground_truth, assignment, task)
+        filter (dict): The filter criteria to apply.
+    """
+
     storage: int
     save_original_file: bool = True
     path: str
-    ground_truth_dataset_version: int
+    target: Literal['assignment', 'ground_truth', 'task']
     filter: dict

     @field_validator('storage')
@@ -27,16 +177,14 @@ class ExportParams(BaseModel):
             raise PydanticCustomError('client_error', _('Unable to get storage from Synapse backend.'))
         return value

-    @field_validator('ground_truth_dataset_version')
+    @field_validator('filter')
     @staticmethod
-    def check_ground_truth_dataset_version_exists(value, info):
+    def check_filter_by_target(value, info):
         action = info.context['action']
         client = action.client
-        try:
-            client.get_ground_truth_version(value)
-        except ClientError:
-            raise PydanticCustomError('client_error', _('Unable to get Ground Truth dataset version.'))
-        return value
+        target = action.params['target']
+        handler = TargetHandlerFactory.get_handler(target)
+        return handler.validate_filter(value, client)


 @register_action
@@ -51,32 +199,24 @@ class ExportAction(Action):
         }
     }

-    def get_dataset(self, results):
-        """Get dataset for export."""
-        for result in results:
-            yield {
-                'data': result['data'],
-                'files': result['data_unit']['files'],
-                'id': result['ground_truth'],
-            }
-
-    def get_filtered_results(self):
-        """Get filtered ground truth events."""
-        self.params['filter']['ground_truth_dataset_versions'] = self.params['ground_truth_dataset_version']
-        filters = {'expand': 'data', **self.params['filter']}
-
+    def get_filtered_results(self, filters, handler):
+        """Get filtered target results."""
         try:
-            gt_dataset_events_list = self.client.list_ground_truth_events(params=filters, list_all=True)
-            results = gt_dataset_events_list[0]
-            count = gt_dataset_events_list[1]
+            result_list = handler.get_results(self.client, filters)
+            results = result_list[0]
+            count = result_list[1]
         except ClientError:
             raise PydanticCustomError('client_error', _('Unable to get Ground Truth dataset.'))
         return results, count

     def start(self):
-        self.params['results'], self.params['count'] = self.get_filtered_results()
-        dataset = self.get_dataset(self.params['results'])
+        filters = {'expand': 'data', **self.params['filter']}
+        target = self.params['target']
+        handler = TargetHandlerFactory.get_handler(target)
+
+        self.params['results'], self.params['count'] = self.get_filtered_results(filters, handler)
+        export_items = handler.get_export_item(self.params['results'])

         storage = self.client.get_storage(self.params['storage'])
         pathlib_cwd = get_pathlib(storage, self.params['path'])
-        return self.entrypoint(self.run, dataset, pathlib_cwd, **self.params)
+        return self.entrypoint(self.run, export_items, pathlib_cwd, **self.params)
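
The handler abstraction can be exercised without a live backend by stubbing the one client method each handler calls. A sketch using the task target; the stub's data shapes follow what get_export_item reads:

```python
from synapse_sdk.plugins.categories.export.actions.export import TargetHandlerFactory


# Stub standing in for the backend client; it returns (results, count)
# like list_tasks(..., list_all=True) is assumed to above.
class StubClient:
    def list_tasks(self, params=None, list_all=False):
        results = [{'id': 1, 'data': {}, 'data_unit': {'files': {'image': 'tasks/1/img.jpg'}}}]
        return results, len(results)


handler = TargetHandlerFactory.get_handler('task')
results, count = handler.get_results(StubClient(), {'project': 1})
for item in handler.get_export_item(results):
    print(item['id'], item['files'])  # -> 1 tasks/1/img.jpg
```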