dtlpy 1.117.6__py3-none-any.whl → 1.118.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dtlpy/__version__.py CHANGED
@@ -1 +1 @@
1
- version = '1.117.6'
1
+ version = '1.118.13'
@@ -6,7 +6,7 @@ import os
6
6
  import sys
7
7
  import jwt
8
8
 
9
- from .. import exceptions, entities, repositories, utilities, assets
9
+ from .. import exceptions, entities, repositories, utilities, assets, miscellaneous
10
10
 
11
11
  logger = logging.getLogger(name='dtlpy')
12
12
 
@@ -76,8 +76,16 @@ class CommandExecutor:
76
76
  url = 'dtlpy'
77
77
  if args.url is None:
78
78
  try:
79
- payload = jwt.decode(self.dl.client_api.token, algorithms=['HS256'],
80
- verify=False, options={'verify_signature': False})
79
+ # oxsec-disable jwt-signature-disabled - Client-side SDK: signature verification disabled intentionally to check admin role; server validates on API calls
80
+ payload = jwt.decode(
81
+ self.dl.client_api.token,
82
+ options={
83
+ "verify_signature": False,
84
+ "verify_exp": False,
85
+ "verify_aud": False,
86
+ "verify_iss": False,
87
+ }
88
+ )
81
89
  if 'admin' in payload['https://dataloop.ai/authorization']['roles']:
82
90
  url = "https://storage.googleapis.com/dtlpy/dev/dtlpy-latest-py3-none-any.whl"
83
91
  except Exception:
@@ -235,6 +243,13 @@ class CommandExecutor:
235
243
  project = self.dl.projects.get(project_name=args.project_name)
236
244
  dataset = project.datasets.get(dataset_name=args.dataset_name)
237
245
 
246
+ # Validate local_path and local_annotations_path to prevent path traversal
247
+ miscellaneous.PathUtils.validate_paths(
248
+ [args.local_path, args.local_annotations_path],
249
+ base_path=os.getcwd(),
250
+ must_exist=True
251
+ )
252
+
238
253
  dataset.items.upload(local_path=args.local_path,
239
254
  remote_path=args.remote_path,
240
255
  file_types=args.file_types,
@@ -277,6 +292,13 @@ class CommandExecutor:
277
292
  remote_path.pop(remote_path.index(item))
278
293
  filters.add(field="dir", values=remote_path, operator=entities.FiltersOperations.IN, method='or')
279
294
 
295
+ # Validate local_path to prevent path traversal
296
+ miscellaneous.PathUtils.validate_directory_path(
297
+ args.local_path,
298
+ base_path=os.getcwd(),
299
+ must_exist=False
300
+ )
301
+
280
302
  if not args.without_binaries:
281
303
  dataset.items.download(filters=filters,
282
304
  local_path=args.local_path,
@@ -325,6 +347,9 @@ class CommandExecutor:
325
347
  args.split_seconds = int(args.split_seconds)
326
348
  if isinstance(args.split_times, str):
327
349
  args.split_times = [int(sec) for sec in args.split_times.split(",")]
350
+ # Validate filepath to prevent path traversal
351
+ miscellaneous.PathUtils.validate_file_path(args.filename)
352
+
328
353
  self.dl.utilities.videos.Videos.split_and_upload(
329
354
  project_name=args.project_name,
330
355
  dataset_name=args.dataset_name,
@@ -407,6 +432,8 @@ class CommandExecutor:
407
432
  def deploy(self, args):
408
433
  project = self.dl.projects.get(project_name=args.project_name)
409
434
  json_filepath = args.json_file
435
+ # Validate file path to prevent path traversal
436
+ miscellaneous.PathUtils.validate_file_path(json_filepath)
410
437
  deployed_services, package = self.dl.packages.deploy_from_file(project=project, json_filepath=json_filepath)
411
438
  logger.info("Successfully deployed {} from file: {}\nServices: {}".format(len(deployed_services),
412
439
  json_filepath,
@@ -464,6 +491,13 @@ class CommandExecutor:
464
491
  elif args.packages == "push":
465
492
  packages = self.utils.get_packages_repo(args=args)
466
493
 
494
+ # Validate src_path to prevent path traversal
495
+ miscellaneous.PathUtils.validate_directory_path(
496
+ args.src_path,
497
+ base_path=os.getcwd(),
498
+ must_exist=True
499
+ )
500
+
467
501
  package = packages.push(src_path=args.src_path,
468
502
  package_name=args.package_name,
469
503
  checkout=args.checkout)
@@ -568,7 +602,13 @@ class CommandExecutor:
568
602
  answers = inquirer.prompt(questions)
569
603
  #####
570
604
  # create a dir for that panel
571
- os.makedirs(answers.get('name'), exist_ok=True)
605
+ # Validate panel name to prevent path traversal
606
+ panel_name = answers.get('name')
607
+ # Validate panel name to prevent path traversal
608
+ miscellaneous.PathUtils.validate_directory_name(panel_name)
609
+ # Create directory in current working directory
610
+ panel_dir = os.path.join(os.getcwd(), panel_name)
611
+ os.makedirs(panel_dir, exist_ok=True)
572
612
  # dump to dataloop.json
573
613
  app_filename = assets.paths.APP_JSON_FILENAME
574
614
  if not os.path.isfile(app_filename):
@@ -630,12 +670,22 @@ class CommandExecutor:
630
670
  directory = args.dir
631
671
  if directory == '..':
632
672
  directory = os.path.split(os.getcwd())[0]
673
+ # Validate path to prevent path traversal
674
+ miscellaneous.PathUtils.validate_directory_path(
675
+ directory,
676
+ base_path=os.getcwd(),
677
+ must_exist=True
678
+ )
633
679
  os.chdir(directory)
634
680
  print(os.getcwd())
635
681
 
636
682
  @staticmethod
637
683
  def mkdir(args):
638
- os.mkdir(args.name)
684
+ # Validate directory name to prevent path traversal
685
+ miscellaneous.PathUtils.validate_directory_name(args.name)
686
+ # Create directory in current working directory
687
+ dir_path = os.path.join(os.getcwd(), args.name)
688
+ os.mkdir(dir_path)
639
689
 
640
690
  # noinspection PyUnusedLocal
641
691
  @staticmethod
@@ -1827,7 +1827,7 @@ class FrameAnnotation(entities.BaseEntity):
1827
1827
  return frame
1828
1828
 
1829
1829
  @classmethod
1830
- def from_snapshot(cls, annotation, _json, fps):
1830
+ def from_snapshot(cls, annotation, _json, fps=None):
1831
1831
  """
1832
1832
  new frame state to annotation
1833
1833
 
dtlpy/entities/app.py CHANGED
@@ -93,7 +93,7 @@ class App(entities.BaseEntity):
93
93
  .. code-block:: python
94
94
  succeed = app.uninstall()
95
95
  """
96
- return self.apps.uninstall(self.id)
96
+ return self.apps.uninstall(app=self)
97
97
 
98
98
  def update(self):
99
99
  """
dtlpy/entities/compute.py CHANGED
@@ -13,6 +13,7 @@ class ClusterProvider(str, Enum):
13
13
  LOCAL = 'local'
14
14
  RANCHER_K3S = 'rancher-k3s'
15
15
  RANCHER_RKE = 'rancher-rke'
16
+ OPENSHIFT = 'openshift'
16
17
 
17
18
 
18
19
  class ComputeType(str, Enum):
dtlpy/entities/dataset.py CHANGED
@@ -86,6 +86,9 @@ class Dataset(entities.BaseEntity):
86
86
  # api
87
87
  _client_api = attr.ib(type=ApiClient, repr=False)
88
88
 
89
+ # syncing status
90
+ is_syncing = attr.ib(default=False, repr=False)
91
+
89
92
  # entities
90
93
  _project = attr.ib(default=None, repr=False)
91
94
 
@@ -183,6 +186,7 @@ class Dataset(entities.BaseEntity):
183
186
  expiration_options=expiration_options,
184
187
  index_driver=_json.get('indexDriver', None),
185
188
  enable_sync_with_cloned=_json.get('enableSyncWithCloned', None),
189
+ is_syncing=_json.get('isSyncing', False),
186
190
  src_dataset=_json.get('srcDataset', None))
187
191
  inst.is_fetched = is_fetched
188
192
  return inst
@@ -215,6 +219,7 @@ class Dataset(entities.BaseEntity):
215
219
  attr.fields(Dataset).items_count,
216
220
  attr.fields(Dataset).index_driver,
217
221
  attr.fields(Dataset).enable_sync_with_cloned,
222
+ attr.fields(Dataset).is_syncing,
218
223
  attr.fields(Dataset).src_dataset,
219
224
  ))
220
225
  _json.update({'items': self.items_url})
@@ -231,6 +236,7 @@ class Dataset(entities.BaseEntity):
231
236
  _json['expirationOptions'] = self.expiration_options.to_json()
232
237
  if self.enable_sync_with_cloned is not None:
233
238
  _json['enableSyncWithCloned'] = self.enable_sync_with_cloned
239
+ _json['isSyncing'] = self.is_syncing
234
240
  if self.src_dataset is not None:
235
241
  _json['srcDataset'] = self.src_dataset
236
242
  return _json
@@ -288,12 +294,15 @@ class Dataset(entities.BaseEntity):
288
294
  def set_repositories(self):
289
295
  reps = namedtuple('repositories',
290
296
  field_names=['items', 'recipes', 'datasets', 'assignments', 'tasks', 'annotations',
291
- 'ontologies', 'features', 'settings', 'schema', 'collections'])
297
+ 'ontologies', 'features', 'feature_sets', 'settings', 'schema', 'collections'])
298
+ _project_id = None
292
299
  if self._project is None:
293
300
  datasets = repositories.Datasets(client_api=self._client_api, project=self._project)
301
+ if self.projects is not None and len(self.projects) > 0:
302
+ _project_id = self.projects[0]
294
303
  else:
295
304
  datasets = self._project.datasets
296
-
305
+ _project_id = self._project.id
297
306
  return reps(
298
307
  items=repositories.Items(client_api=self._client_api, dataset=self, datasets=datasets),
299
308
  recipes=repositories.Recipes(client_api=self._client_api, dataset=self),
@@ -303,6 +312,7 @@ class Dataset(entities.BaseEntity):
303
312
  datasets=datasets,
304
313
  ontologies=repositories.Ontologies(client_api=self._client_api, dataset=self),
305
314
  features=repositories.Features(client_api=self._client_api, project=self._project, dataset=self),
315
+ feature_sets=repositories.FeatureSets(client_api=self._client_api, project=self._project, project_id=_project_id, dataset=self),
306
316
  settings=repositories.Settings(client_api=self._client_api, dataset=self),
307
317
  schema=repositories.Schema(client_api=self._client_api, dataset=self),
308
318
  collections=repositories.Collections(client_api=self._client_api, dataset=self)
@@ -353,6 +363,11 @@ class Dataset(entities.BaseEntity):
353
363
  assert isinstance(self._repositories.features, repositories.Features)
354
364
  return self._repositories.features
355
365
 
366
+ @property
367
+ def feature_sets(self):
368
+ assert isinstance(self._repositories.feature_sets, repositories.FeatureSets)
369
+ return self._repositories.feature_sets
370
+
356
371
  @property
357
372
  def collections(self):
358
373
  assert isinstance(self._repositories.collections, repositories.Collections)
dtlpy/entities/model.py CHANGED
@@ -423,7 +423,7 @@ class Model(entities.BaseEntity):
423
423
  # default
424
424
  if 'id_to_label_map' not in self.configuration:
425
425
  if not (self.dataset_id == 'null' or self.dataset_id is None):
426
- self.labels = [label.tag for label in self.dataset.labels]
426
+ self.labels = [flat_key for flat_key, _ in self.dataset.labels_flat_dict.items()]
427
427
  self.configuration['id_to_label_map'] = {int(idx): lbl for idx, lbl in enumerate(self.labels)}
428
428
  # use existing
429
429
  else:
@@ -9,7 +9,7 @@ from typing import Optional, List, Any
9
9
  import attr
10
10
 
11
11
  from .filters import FiltersOperations, FiltersOrderByDirection, FiltersResource
12
- from .. import miscellaneous
12
+ from .. import miscellaneous, exceptions
13
13
  from ..services.api_client import ApiClient
14
14
 
15
15
  logger = logging.getLogger(name='dtlpy')
@@ -243,8 +243,12 @@ class PagedEntities:
243
243
  :param page_offset: page offset (for offset-based)
244
244
  :param page_size: page size
245
245
  """
246
- items = self.return_page(page_offset=page_offset, page_size=page_size)
247
- self.items = items
246
+ try:
247
+ items = self.return_page(page_offset=page_offset, page_size=page_size)
248
+ self.items = items
249
+ except exceptions.BadRequest as e:
250
+ logger.warning(f"BadRequest error received: {str(e)}")
251
+ self.items = miscellaneous.List(list())
248
252
 
249
253
  def next_page(self) -> None:
250
254
  """
dtlpy/entities/service.py CHANGED
@@ -142,7 +142,7 @@ class KubernetesRuntime(ServiceRuntime):
142
142
  num_replicas=DEFAULT_NUM_REPLICAS,
143
143
  concurrency=DEFAULT_CONCURRENCY,
144
144
  dynamic_concurrency=None,
145
- concurrency_update_method=None,
145
+ dynamic_concurrency_config=None,
146
146
  runner_image=None,
147
147
  autoscaler=None,
148
148
  **kwargs):
@@ -156,7 +156,7 @@ class KubernetesRuntime(ServiceRuntime):
156
156
  self.single_agent = kwargs.get('singleAgent', None)
157
157
  self.preemptible = kwargs.get('preemptible', None)
158
158
  self.dynamic_concurrency = kwargs.get('dynamicConcurrency', dynamic_concurrency)
159
- self.concurrency_update_method = kwargs.get('concurrencyUpdateMethod', concurrency_update_method)
159
+ self.dynamic_concurrency_config = kwargs.get('dynamicConcurrencyConfig', dynamic_concurrency_config)
160
160
 
161
161
  self.autoscaler = kwargs.get('autoscaler', autoscaler)
162
162
  if self.autoscaler is not None and isinstance(self.autoscaler, dict):
@@ -191,8 +191,8 @@ class KubernetesRuntime(ServiceRuntime):
191
191
  if self.dynamic_concurrency is not None:
192
192
  _json['dynamicConcurrency'] = self.dynamic_concurrency
193
193
 
194
- if self.concurrency_update_method is not None:
195
- _json['concurrencyUpdateMethod'] = self.concurrency_update_method
194
+ if self.dynamic_concurrency_config is not None:
195
+ _json['dynamicConcurrencyConfig'] = self.dynamic_concurrency_config
196
196
 
197
197
  return _json
198
198
 
@@ -18,3 +18,4 @@ from .git_utils import GitUtils
18
18
  from .zipping import Zipping
19
19
  from .list_print import List
20
20
  from .json_utils import JsonUtils
21
+ from .path_utils import PathUtils
@@ -0,0 +1,264 @@
1
+ import os
2
+ import tempfile
3
+ from pathlib import Path
4
+ from .. import exceptions
5
+ from ..services import service_defaults
6
+
7
+
8
+ class PathUtils:
9
+ """
10
+ Utility class for path validation and sanitization to prevent path traversal attacks.
11
+ """
12
+ allowed_roots = [tempfile.gettempdir(), service_defaults.DATALOOP_PATH]
13
+
14
+ @staticmethod
15
+ def _contains_traversal(path: str) -> bool:
16
+ """
17
+ Check if path contains path traversal sequences.
18
+
19
+ :param str path: Path to check
20
+ :return: True if path contains traversal sequences
21
+ :rtype: bool
22
+ """
23
+ if not path:
24
+ return False
25
+
26
+ # Normalize the path to handle different separators
27
+ normalized = os.path.normpath(path)
28
+
29
+ # Check for parent directory references
30
+ parts = Path(normalized).parts
31
+ if '..' in parts:
32
+ return True
33
+
34
+ # Check for encoded traversal sequences (evasion attempts)
35
+ if '%2e%2e' in path.lower() or '..%2f' in path.lower() or '..%5c' in path.lower():
36
+ return True
37
+
38
+ return False
39
+
40
+ @staticmethod
41
+ def _is_within_base(resolved_path: str, base_path: str) -> bool:
42
+ """
43
+ Check if resolved_path is within base_path.
44
+
45
+ :param str resolved_path: Absolute resolved path
46
+ :param str base_path: Base directory path
47
+ :return: True if resolved_path is within base_path
48
+ :rtype: bool
49
+ """
50
+ try:
51
+ resolved = os.path.abspath(os.path.normpath(resolved_path))
52
+ base = os.path.abspath(os.path.normpath(base_path))
53
+
54
+ # Get common path
55
+ common = os.path.commonpath([resolved, base])
56
+ return common == base
57
+ except (ValueError, OSError):
58
+ # On Windows, if paths are on different drives, commonpath raises ValueError
59
+ return False
60
+
61
+ @staticmethod
62
+ def _is_allowed_path(resolved_path: str, base_path: str) -> bool:
63
+ """
64
+ Check if resolved_path is within base_path or any allowed_root.
65
+
66
+ :param str resolved_path: Absolute resolved path
67
+ :param str base_path: Base directory path
68
+ :return: True if resolved_path is within base_path or any allowed_root
69
+ :rtype: bool
70
+ """
71
+ for allowed_root in [base_path] + PathUtils.allowed_roots:
72
+ if PathUtils._is_within_base(resolved_path, allowed_root):
73
+ return True
74
+ return False
75
+
76
+ @staticmethod
77
+ def validate_directory_name(name: str) -> str:
78
+ """
79
+ Validate a directory name to ensure it doesn't contain path traversal sequences.
80
+
81
+ :param str name: Directory name to validate
82
+ :return: Validated directory name
83
+ :rtype: str
84
+ :raises PlatformException: If name contains invalid characters or traversal sequences
85
+ """
86
+ if not name:
87
+ raise exceptions.PlatformException(
88
+ error='400',
89
+ message='Directory name cannot be empty'
90
+ )
91
+
92
+ # Check for path separators
93
+ if os.sep in name or (os.altsep and os.altsep in name):
94
+ raise exceptions.PlatformException(
95
+ error='400',
96
+ message='Directory name cannot contain path separators'
97
+ )
98
+
99
+ # Check for traversal sequences
100
+ if PathUtils._contains_traversal(name):
101
+ raise exceptions.PlatformException(
102
+ error='400',
103
+ message='Directory name cannot contain path traversal sequences'
104
+ )
105
+
106
+ return name
107
+
108
+ @staticmethod
109
+ def _validate_single_path(path, base_path: str, must_exist: bool):
110
+ """
111
+ Internal method to validate a single path string.
112
+
113
+ :param path: Path to validate (str or Path object)
114
+ :param str base_path: Base directory to restrict path to
115
+ :param bool must_exist: If True, path must exist
116
+ :raises PlatformException: If path is invalid or contains traversal sequences
117
+ """
118
+ # Convert Path objects to strings
119
+ if isinstance(path, Path):
120
+ path = str(path)
121
+ if isinstance(base_path, Path):
122
+ base_path = str(base_path)
123
+
124
+ # Skip validation if not a string
125
+ if not isinstance(path, str):
126
+ return
127
+
128
+ # Skip validation for URLs and external paths
129
+ if path.startswith(('http://', 'https://', 'external://')):
130
+ return
131
+
132
+ # Empty string check
133
+ if not path:
134
+ raise exceptions.PlatformException(
135
+ error='400',
136
+ message='Path cannot be empty'
137
+ )
138
+
139
+ # Check for traversal sequences in the original path
140
+ if PathUtils._contains_traversal(path):
141
+ raise exceptions.PlatformException(
142
+ error='400',
143
+ message='Path contains invalid traversal sequences'
144
+ )
145
+
146
+ # Resolve path (absolute paths allowed if within base_path)
147
+ if os.path.isabs(path):
148
+ resolved = os.path.abspath(os.path.normpath(path))
149
+ else:
150
+ resolved = os.path.abspath(os.path.normpath(os.path.join(base_path, path)))
151
+
152
+ # Reject if path is outside base_path or allowed_roots
153
+ if not PathUtils._is_allowed_path(resolved, base_path):
154
+ raise exceptions.PlatformException(
155
+ error='400',
156
+ message='Path resolves outside allowed directory'
157
+ )
158
+
159
+ # Check if path must exist
160
+ if must_exist and not os.path.exists(resolved):
161
+ raise exceptions.PlatformException(
162
+ error='404',
163
+ message='Path does not exist: {}'.format(path)
164
+ )
165
+
166
+ @staticmethod
167
+ def validate_paths(paths, base_path = None, must_exist: bool = False):
168
+ """
169
+ Validate file or directory paths against path traversal attacks.
170
+ Accepts a list of paths and validates each one.
171
+ Skips validation if path is None or not a string.
172
+ Skips validation for URLs (http://, https://) and external paths (external://).
173
+
174
+ :param paths: Path(s) to validate - can be str, Path, list of str/Path, or None
175
+ :param base_path: Optional base directory to restrict path to (str or Path). If None, uses current working directory
176
+ :param bool must_exist: If True, path must exist
177
+ :raises PlatformException: If any path is invalid or contains traversal sequences
178
+ """
179
+ # Handle None - skip validation
180
+ if paths is None:
181
+ return
182
+
183
+ # Convert base_path Path object to string
184
+ if isinstance(base_path, Path):
185
+ base_path = str(base_path)
186
+
187
+ # Resolve base_path
188
+ if base_path is None:
189
+ base_path = os.getcwd()
190
+
191
+ # Handle list of paths
192
+ if isinstance(paths, list):
193
+ for path in paths:
194
+ PathUtils._validate_single_path(path, base_path, must_exist)
195
+ else:
196
+ # Single path
197
+ PathUtils._validate_single_path(paths, base_path, must_exist)
198
+
199
+ @staticmethod
200
+ def validate_file_path(file_path, base_path = None, must_exist: bool = True):
201
+ """
202
+ Validate a file path against path traversal attacks.
203
+
204
+ :param file_path: File path to validate (str or Path object)
205
+ :param base_path: Optional base directory to restrict path to (str or Path). If None, uses current working directory
206
+ :param bool must_exist: If True, file must exist (default: True)
207
+ :raises PlatformException: If path is invalid, contains traversal sequences, or is not a file
208
+ """
209
+ # Convert Path objects to strings
210
+ if isinstance(file_path, Path):
211
+ file_path = str(file_path)
212
+ if isinstance(base_path, Path):
213
+ base_path = str(base_path)
214
+
215
+ PathUtils.validate_paths(file_path, base_path=base_path, must_exist=must_exist)
216
+
217
+ if must_exist and isinstance(file_path, str) and not file_path.startswith(('http://', 'https://', 'external://')):
218
+ # Resolve path to check if it's a file
219
+ if base_path is None:
220
+ base_path = os.getcwd()
221
+ if os.path.isabs(file_path):
222
+ resolved = os.path.abspath(os.path.normpath(file_path))
223
+ else:
224
+ resolved = os.path.abspath(os.path.normpath(os.path.join(base_path, file_path)))
225
+
226
+ if not os.path.isfile(resolved):
227
+ raise exceptions.PlatformException(
228
+ error='400',
229
+ message='Path is not a file: {}'.format(file_path)
230
+ )
231
+
232
+ @staticmethod
233
+ def validate_directory_path(dir_path, base_path = None, must_exist: bool = True):
234
+ """
235
+ Validate a directory path against path traversal attacks.
236
+
237
+ :param dir_path: Directory path to validate (str or Path object)
238
+ :param base_path: Optional base directory to restrict path to (str or Path). If None, uses current working directory
239
+ :param bool must_exist: If True, directory must exist (default: True)
240
+ :raises PlatformException: If path is invalid, contains traversal sequences, or is not a directory
241
+ """
242
+ # Convert Path objects to strings
243
+ if isinstance(dir_path, Path):
244
+ dir_path = str(dir_path)
245
+ if isinstance(base_path, Path):
246
+ base_path = str(base_path)
247
+
248
+ PathUtils.validate_paths(dir_path, base_path=base_path, must_exist=must_exist)
249
+
250
+ if must_exist and isinstance(dir_path, str) and not dir_path.startswith(('http://', 'https://', 'external://')):
251
+ # Resolve path to check if it's a directory
252
+ if base_path is None:
253
+ base_path = os.getcwd()
254
+ if os.path.isabs(dir_path):
255
+ resolved = os.path.abspath(os.path.normpath(dir_path))
256
+ else:
257
+ resolved = os.path.abspath(os.path.normpath(os.path.join(base_path, dir_path)))
258
+
259
+ if not os.path.isdir(resolved):
260
+ raise exceptions.PlatformException(
261
+ error='400',
262
+ message='Path is not a directory: {}'.format(dir_path)
263
+ )
264
+
@@ -472,31 +472,32 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
472
472
  self.logger.debug("Downloading subset {!r} of {}".format(subset, self.model_entity.dataset.name))
473
473
 
474
474
  annotation_filters = None
475
- if subset in annotations_subsets:
476
- annotation_filters = entities.Filters(
477
- use_defaults=False,
478
- resource=entities.FiltersResource.ANNOTATION,
479
- custom_filter=annotations_subsets[subset],
480
- )
481
- # if user provided annotation_filters, skip the default filters
482
- elif self.model_entity.output_type is not None and self.model_entity.output_type != "embedding":
483
- annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
484
- if self.model_entity.output_type in [
485
- entities.AnnotationType.SEGMENTATION,
486
- entities.AnnotationType.POLYGON,
487
- ]:
488
- model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
489
- else:
490
- model_output_types = [self.model_entity.output_type]
491
-
492
- annotation_filters.add(
493
- field=entities.FiltersKnownFields.TYPE,
494
- values=model_output_types,
495
- operator=entities.FiltersOperations.IN,
496
- )
475
+ if self.model_entity.output_type != "embedding":
476
+ if subset in annotations_subsets:
477
+ annotation_filters = entities.Filters(
478
+ use_defaults=False,
479
+ resource=entities.FiltersResource.ANNOTATION,
480
+ custom_filter=annotations_subsets[subset],
481
+ )
482
+ # if user provided annotation_filters, skip the default filters
483
+ elif self.model_entity.output_type is not None:
484
+ annotation_filters = entities.Filters(resource=entities.FiltersResource.ANNOTATION, use_defaults=False)
485
+ if self.model_entity.output_type in [
486
+ entities.AnnotationType.SEGMENTATION,
487
+ entities.AnnotationType.POLYGON,
488
+ ]:
489
+ model_output_types = [entities.AnnotationType.SEGMENTATION, entities.AnnotationType.POLYGON]
490
+ else:
491
+ model_output_types = [self.model_entity.output_type]
492
+
493
+ annotation_filters.add(
494
+ field=entities.FiltersKnownFields.TYPE,
495
+ values=model_output_types,
496
+ operator=entities.FiltersOperations.IN,
497
+ )
497
498
 
498
- annotation_filters = self.__include_model_annotations(annotation_filters)
499
- annotations_subsets[subset] = annotation_filters.prepare()
499
+ annotation_filters = self.__include_model_annotations(annotation_filters)
500
+ annotations_subsets[subset] = annotation_filters.prepare()
500
501
 
501
502
  ret_list = self.__download_items(
502
503
  dataset=dataset,
@@ -709,7 +710,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
709
710
  valid_vectors = []
710
711
  items_to_upload = []
711
712
  vectors_to_upload = []
712
-
713
+
713
714
  for item, vector in zip(_items, vectors):
714
715
  # Check if vector is valid
715
716
  if vector is None or len(vector) != embeddings_size:
@@ -719,25 +720,25 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
719
720
  # Item and vector are valid
720
721
  valid_items.append(item)
721
722
  valid_vectors.append(vector)
722
-
723
+
723
724
  # Check if item should be skipped (prompt items)
724
725
  _system_metadata = getattr(item, 'system', dict())
725
726
  is_prompt = _system_metadata.get('shebang', dict()).get('dltype', '') == 'prompt'
726
727
  if skip_default_items and is_prompt:
727
728
  self.logger.debug(f"Skipping feature upload for prompt item {item.id}")
728
729
  continue
729
-
730
+
730
731
  # Items were not skipped - should be uploaded
731
732
  items_to_upload.append(item)
732
733
  vectors_to_upload.append(vector)
733
-
734
+
734
735
  # Update the original lists with valid items only
735
736
  _items[:] = valid_items
736
737
  vectors[:] = valid_vectors
737
-
738
+
738
739
  if len(_items) != len(vectors):
739
740
  raise ValueError(f"The number of items ({len(_items)}) is not equal to the number of vectors ({len(vectors)}).")
740
-
741
+
741
742
  self.logger.debug(f"Uploading {len(items_to_upload)} items' feature vectors for model {self.model_entity.name}.")
742
743
  try:
743
744
  start_time = time.time()
@@ -830,7 +831,7 @@ class BaseModelAdapter(utilities.BaseServiceRunner):
830
831
  logger.info("Received {s} for training".format(s=model.id))
831
832
  model = model.wait_for_model_ready()
832
833
  if model.status == 'failed':
833
- raise ValueError("Model is in failed state, cannot train.")
834
+ logger.warning("Model failed. New training will attempt to resume from previous checkpoints.")
834
835
 
835
836
  ##############
836
837
  # Set status #