datachain 0.30.2__py3-none-any.whl → 0.30.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -77,12 +77,15 @@ def to_database(
     on_conflict: Optional[str] = None,
     conflict_columns: Optional[list[str]] = None,
     column_mapping: Optional[dict[str, Optional[str]]] = None,
-) -> None:
+) -> int:
     """
     Implementation function for exporting DataChain to database tables.

     This is the core implementation that handles the actual database operations.
     For user-facing documentation, see DataChain.to_database() method.
+
+    Returns:
+        int: Number of rows affected (inserted/updated).
     """
     if on_conflict and on_conflict not in ("ignore", "update"):
         raise ValueError(
@@ -101,11 +104,16 @@ def to_database(
         all_columns, normalized_column_mapping
     )

+    normalized_conflict_columns = _normalize_conflict_columns(
+        conflict_columns, normalized_column_mapping
+    )
+
     with _connect(connection) as conn:
         metadata = sqlalchemy.MetaData()
         table = sqlalchemy.Table(table_name, metadata, *columns)

         table_existed_before = False
+        total_rows_affected = 0
         try:
             with conn.begin():
                 # Check if table exists to determine if we should clean up on error.
@@ -117,14 +125,18 @@ def to_database(

                 rows_iter = chain._leaf_values()
                 for batch in batched(rows_iter, batch_rows):
-                    _process_batch(
+                    rows_affected = _process_batch(
                         conn,
                         table,
                         batch,
                         on_conflict,
-                        conflict_columns,
+                        normalized_conflict_columns,
                         column_indices_and_names,
                     )
+                    if rows_affected < 0 or total_rows_affected < 0:
+                        total_rows_affected = -1
+                    else:
+                        total_rows_affected += rows_affected
         except Exception:
             if not table_existed_before:
                 try:
@@ -134,6 +146,8 @@ def to_database(
                     pass
             raise

+    return total_rows_affected
+

def _normalize_column_mapping(
    column_mapping: dict[str, Optional[str]],
@@ -174,6 +188,30 @@ def _normalize_column_mapping(
     return normalized_mapping


+def _normalize_conflict_columns(
+    conflict_columns: Optional[list[str]], column_mapping: dict[str, Optional[str]]
+) -> Optional[list[str]]:
+    """
+    Normalize conflict_columns by converting DataChain format to database format
+    and applying column mapping.
+    """
+    if not conflict_columns:
+        return None
+
+    normalized_columns = []
+    for col in conflict_columns:
+        db_col = ColumnMeta.to_db_name(col)
+
+        if db_col in column_mapping or hasattr(column_mapping, "default_factory"):
+            mapped_name = column_mapping[db_col]
+            if mapped_name:
+                normalized_columns.append(mapped_name)
+        else:
+            normalized_columns.append(db_col)
+
+    return normalized_columns
+
+
def _prepare_columns(all_columns, column_mapping):
    """Prepare column mapping and column definitions."""
    column_indices_and_names = []  # List of (index, target_name) tuples
@@ -192,8 +230,12 @@ def _prepare_columns(all_columns, column_mapping):

 def _process_batch(
     conn, table, batch, on_conflict, conflict_columns, column_indices_and_names
-):
-    """Process a batch of rows with conflict resolution."""
+) -> int:
+    """Process a batch of rows with conflict resolution.
+
+    Returns:
+        int: Number of rows affected by the insert operation.
+    """

     def prepare_row(row_values):
         """Convert a row tuple to a dictionary with proper DB column names."""
@@ -206,6 +248,7 @@ def _process_batch(

     supports_conflict = on_conflict and conn.engine.name in ("postgresql", "sqlite")

+    insert_stmt: Any  # Can be PostgreSQL, SQLite, or regular insert statement
     if supports_conflict:
         # Use dialect-specific insert for conflict resolution
         if conn.engine.name == "postgresql":
@@ -249,7 +292,8 @@ def _process_batch(
             stacklevel=2,
         )

-    conn.execute(insert_stmt, rows_to_insert)
+    result = conn.execute(insert_stmt, rows_to_insert)
+    return result.rowcount


 def read_database(
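The per-batch aggregation above treats a negative rowcount as "unknown": DBAPI drivers may return -1 from `rowcount` when they cannot determine how many rows were affected, and once any batch reports -1 the running total is pinned to -1 rather than mixing known and unknown counts. A minimal sketch of that folding logic in isolation (the function name is illustrative, not part of the package):

```py
from typing import Iterable


def fold_rowcounts(batch_rowcounts: Iterable[int]) -> int:
    """Mirror the diff's aggregation: any -1 batch poisons the total."""
    total = 0
    for rows_affected in batch_rowcounts:
        if rows_affected < 0 or total < 0:
            total = -1  # driver could not report a count for some batch
        else:
            total += rows_affected
    return total


assert fold_rowcounts([3, 5, 2]) == 10
assert fold_rowcounts([3, -1, 2]) == -1
```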
@@ -67,6 +67,7 @@ from .utils import (
     Sys,
     _get_merge_error_str,
     _validate_merge_on,
+    is_studio,
     resolve_columns,
 )

@@ -284,7 +285,11 @@ class DataChain:
         """Underlying dataset, if there is one."""
         if not self.name:
             return None
-        return self.session.catalog.get_dataset(self.name, self._query.project)
+        return self.session.catalog.get_dataset(
+            self.name,
+            namespace_name=self._query.project.namespace.name,
+            project_name=self._query.project.name,
+        )

     def __or__(self, other: "Self") -> "Self":
         """Return `self.union(other)`."""
@@ -605,7 +610,7 @@ class DataChain:
             project = self.session.catalog.metastore.get_project(
                 project_name,
                 namespace_name,
-                create=self.session.catalog.metastore.project_allowed_to_create,
+                create=is_studio(),
             )
         except ProjectNotFoundError as e:
             # not being able to create it as creation is not allowed
@@ -1180,17 +1185,13 @@ class DataChain:
         )

     def mutate(self, **kwargs) -> "Self":
-        """Create new signals based on existing signals.
-
-        This method cannot modify existing columns. If you need to modify an
-        existing column, use a different name for the new column and then use
-        `select()` to choose which columns to keep.
+        """Create or modify signals based on existing signals.

         This method is vectorized and more efficient compared to map(), and it does not
         extract or download any data from the internal database. However, it can only
         utilize predefined built-in functions and their combinations.

-        The supported functions:
+        Supported functions:
            Numerical:   +, -, *, /, rand(), avg(), count(), func(),
                         greatest(), least(), max(), min(), sum()
            String:      length(), split(), replace(), regexp_replace()
@@ -1217,13 +1218,20 @@ class DataChain:
         ```

         This method can be also used to rename signals. If the Column("name") provided
-        as value for the new signal - the old column will be dropped. Otherwise a new
-        column is created.
+        as value for the new signal - the old signal will be dropped. Otherwise a new
+        signal is created. Exception, if the old signal is nested one (e.g.
+        `C("file.path")`), it will be kept to keep the object intact.

         Example:
         ```py
         dc.mutate(
-            newkey=Column("oldkey")
+            newkey=Column("oldkey")  # drops oldkey
+        )
+        ```
+
+        ```py
+        dc.mutate(
+            size=Column("file.size")  # keeps `file.size`
         )
         ```
         """
@@ -1258,8 +1266,10 @@ class DataChain:
                 # adding new signal
                 mutated[name] = value

+        new_schema = schema.mutate(kwargs)
         return self._evolve(
-            query=self._query.mutate(**mutated), signal_schema=schema.mutate(kwargs)
+            query=self._query.mutate(new_schema=new_schema, **mutated),
+            signal_schema=new_schema,
         )

     @property
@@ -2298,13 +2308,17 @@ class DataChain:
         on_conflict: Optional[str] = None,
         conflict_columns: Optional[list[str]] = None,
         column_mapping: Optional[dict[str, Optional[str]]] = None,
-    ) -> None:
+    ) -> int:
         """Save chain to a database table using a given database connection.

         This method exports all DataChain records to a database table, creating the
         table if it doesn't exist and appending data if it does. The table schema
         is automatically inferred from the DataChain's signal schema.

+        For PostgreSQL, tables are created in the schema specified by the connection's
+        search_path (defaults to 'public'). Use URL parameters to target specific
+        schemas.
+
         Parameters:
             table_name: Name of the database table to create/write to.
             connection: SQLAlchemy connectable, str, or a sqlite3 connection
@@ -2328,20 +2342,26 @@ class DataChain:
                 - Set values to None to skip columns entirely, or use `defaultdict` to
                   skip all columns except those specified.

+        Returns:
+            int: Number of rows affected (inserted/updated). -1 if DB driver doesn't
+                support telemetry.
+
         Examples:
             Basic usage with PostgreSQL:
             ```py
-            import sqlalchemy as sa
             import datachain as dc

-            chain = dc.read_storage("s3://my-bucket/")
-            engine = sa.create_engine("postgresql://user:pass@localhost/mydb")
-            chain.to_database("files_table", engine)
+            rows_affected = (dc
+                .read_storage("s3://my-bucket/")
+                .to_database("files_table", "postgresql://user:pass@localhost/mydb")
+            )
+            print(f"Inserted/updated {rows_affected} rows")
             ```

             Using SQLite with connection string:
             ```py
-            chain.to_database("my_table", "sqlite:///data.db")
+            rows_affected = chain.to_database("my_table", "sqlite:///data.db")
+            print(f"Affected {rows_affected} rows")
             ```

             Column mapping and renaming:
@@ -2360,7 +2380,9 @@ class DataChain:
             chain.to_database("my_table", engine, on_conflict="ignore")

             # Update existing records
-            chain.to_database("my_table", engine, on_conflict="update")
+            chain.to_database(
+                "my_table", engine, on_conflict="update", conflict_columns=["id"]
+            )
             ```

             Working with different databases:
@@ -2372,10 +2394,16 @@ class DataChain:
             # SQLite in-memory
             chain.to_database("temp_table", "sqlite:///:memory:")
             ```
+
+            PostgreSQL with schema support:
+            ```py
+            pg_url = "postgresql://user:pass@host/db?options=-c search_path=analytics"
+            chain.to_database("processed_data", pg_url)
+            ```
         """
         from .database import to_database

-        to_database(
+        return to_database(
             self,
             table_name,
             connection,
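Taken together with the `) -> int:` signature change, the updated docstring examples read end to end like this (illustrative sketch; the bucket, table, and conflict column are placeholders, following the docstring's `on_conflict="update"` example):

```py
import datachain as dc

chain = dc.read_storage("s3://my-bucket/")  # placeholder source
rows_affected = chain.to_database(
    "files_table",
    "sqlite:///data.db",
    on_conflict="update",
    conflict_columns=["id"],  # assumed unique key column in the target table
)
if rows_affected == -1:
    print("driver did not report affected rows")
else:
    print(f"{rows_affected} rows inserted/updated")
```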
@@ -13,7 +13,7 @@ from datachain.lib.signal_schema import SignalSchema
 from datachain.query import Session
 from datachain.query.dataset import DatasetQuery

-from .utils import Sys
+from .utils import Sys, is_studio
 from .values import read_values

 if TYPE_CHECKING:
@@ -343,7 +343,7 @@ def delete_dataset(
         namespace_name=namespace,
     )

-    if not catalog.metastore.is_local_dataset(namespace_name) and studio:
+    if not is_studio() and studio:
         return remove_studio_dataset(
             None, name, namespace_name, project_name, version=version, force=force
         )
@@ -357,7 +357,14 @@ def delete_dataset(
         ) from None

     if not force:
-        version = version or catalog.get_dataset(name, ds_project).latest_version
+        version = (
+            version
+            or catalog.get_dataset(
+                name,
+                namespace_name=ds_project.namespace.name,
+                project_name=ds_project.name,
+            ).latest_version
+        )
     else:
         version = None
     catalog.remove_dataset(name, ds_project, version=version, force=force)
@@ -403,9 +410,7 @@ def move_dataset(
     namespace, project, name = catalog.get_full_dataset_name(src)
     dest_namespace, dest_project, dest_name = catalog.get_full_dataset_name(dest)

-    dataset = catalog.get_dataset(
-        name, catalog.metastore.get_project(project, namespace)
-    )
+    dataset = catalog.get_dataset(name, namespace_name=namespace, project_name=project)

     catalog.update_dataset(
         dataset,
@@ -413,6 +418,6 @@ def move_dataset(
         project_id=catalog.metastore.get_project(
             dest_project,
             dest_namespace,
-            create=catalog.metastore.project_allowed_to_create,
+            create=is_studio(),
         ).id,
     )
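The recurring change across `delete_dataset` and `move_dataset` (and the `dataset` property earlier) is one API migration: `catalog.get_dataset()` now takes `namespace_name`/`project_name` keywords instead of a resolved project object. Schematically, with the diff's own variable names:

```py
# before (0.30.2): resolve a Project object first
dataset = catalog.get_dataset(
    name, catalog.metastore.get_project(project, namespace)
)

# after (0.30.4): pass the names directly
dataset = catalog.get_dataset(name, namespace_name=namespace, project_name=project)
```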
datachain/lib/dc/utils.py CHANGED
@@ -15,6 +15,7 @@ from datachain.func.base import Function
 from datachain.lib.data_model import DataModel, DataType
 from datachain.lib.utils import DataChainParamsError
 from datachain.query.schema import DEFAULT_DELIMITER
+from datachain.utils import getenv_bool

 if TYPE_CHECKING:
     from typing_extensions import Concatenate, ParamSpec
@@ -26,6 +27,10 @@ if TYPE_CHECKING:
 D = TypeVar("D", bound="DataChain")


+def is_studio() -> bool:
+    return getenv_bool("DATACHAIN_IS_STUDIO", default=False)
+
+
 def resolve_columns(
     method: "Callable[Concatenate[D, P], D]",
 ) -> "Callable[Concatenate[D, P], D]":
@@ -28,7 +28,9 @@ def create(
     """
     session = Session.get(session)

-    if not session.catalog.metastore.namespace_allowed_to_create:
+    from datachain.lib.dc.utils import is_studio
+
+    if not is_studio():
         raise NamespaceCreateNotAllowedError("Creating namespace is not allowed")

     Namespace.validate_name(name)
datachain/lib/projects.py CHANGED
@@ -32,7 +32,9 @@ def create(
     """
     session = Session.get(session)

-    if not session.catalog.metastore.project_allowed_to_create:
+    from datachain.lib.dc.utils import is_studio
+
+    if not is_studio():
         raise ProjectCreateNotAllowedError("Creating project is not allowed")

     Project.validate_name(name)
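Both `create()` guards now follow the same shape: check the environment, raise otherwise. A self-contained sketch of the pattern (the error class here is a stand-in so the snippet runs without the package internals):

```py
from datachain.lib.dc.utils import is_studio


class CreateNotAllowedError(Exception):
    """Stand-in for datachain's *CreateNotAllowedError types."""


def create_guarded(kind: str, name: str) -> None:
    # Mirrors the diff: an env check replaces metastore.*_allowed_to_create.
    if not is_studio():
        raise CreateNotAllowedError(f"Creating {kind} is not allowed")
    print(f"would create {kind} {name!r}")
```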
@@ -34,7 +34,7 @@ from datachain.lib.data_model import DataModel, DataType, DataValue
 from datachain.lib.file import File
 from datachain.lib.model_store import ModelStore
 from datachain.lib.utils import DataChainParamsError
-from datachain.query.schema import DEFAULT_DELIMITER, Column, ColumnMeta
+from datachain.query.schema import DEFAULT_DELIMITER, C, Column, ColumnMeta
 from datachain.sql.types import SQLType

 if TYPE_CHECKING:
@@ -680,35 +680,46 @@ class SignalSchema:
         primitives = (bool, str, int, float)

         for name, value in args_map.items():
+            current_type = None
+
+            if C.is_nested(name):
+                try:
+                    current_type = self.get_column_type(name)
+                except SignalResolvingError as err:
+                    msg = f"Creating new nested columns directly is not allowed: {name}"
+                    raise ValueError(msg) from err
+
             if isinstance(value, Column) and value.name in self.values:
                 # renaming existing signal
+                # Note: it won't touch nested signals here (e.g. file__path)
+                # we don't allow removing nested columns to keep objects consistent
                 del new_values[value.name]
                 new_values[name] = self.values[value.name]
-                continue
-            if isinstance(value, Column):
+            elif isinstance(value, Column):
                 # adding new signal from existing signal field
-                try:
-                    new_values[name] = self.get_column_type(
-                        value.name, with_subtree=True
-                    )
-                    continue
-                except SignalResolvingError:
-                    pass
-            if isinstance(value, Func):
+                new_values[name] = self.get_column_type(value.name, with_subtree=True)
+            elif isinstance(value, Func):
                 # adding new signal with function
                 new_values[name] = value.get_result_type(self)
-                continue
-            if isinstance(value, primitives):
+            elif isinstance(value, primitives):
                 # For primitives, store the type, not the value
                 val = literal(value)
                 val.type = python_to_sql(type(value))()
                 new_values[name] = sql_to_python(val)
-                continue
-            if isinstance(value, ColumnElement):
+            elif isinstance(value, ColumnElement):
                 # adding new signal
                 new_values[name] = sql_to_python(value)
-                continue
-            new_values[name] = value
+            else:
+                new_values[name] = value
+
+            if C.is_nested(name):
+                if current_type != new_values[name]:
+                    msg = (
+                        f"Altering nested column type is not allowed: {name}, "
+                        f"current type: {current_type}, new type: {new_values[name]}"
+                    )
+                    raise ValueError(msg)
+                del new_values[name]

         return SignalSchema(new_values)

datachain/listing.py CHANGED
@@ -65,17 +65,13 @@ class Listing:

     @cached_property
     def dataset(self) -> "DatasetRecord":
-        from datachain.error import DatasetNotFoundError
-
         assert self.dataset_name
         project = self.metastore.listing_project
-        try:
-            return self.metastore.get_dataset(self.dataset_name, project.id)
-        except DatasetNotFoundError:
-            raise DatasetNotFoundError(
-                f"Dataset {self.dataset_name} not found in namespace"
-                f" {project.namespace.name} and project {project.name}"
-            ) from None
+        return self.metastore.get_dataset(
+            self.dataset_name,
+            namespace_name=project.namespace.name,
+            project_name=project.name,
+        )

     @cached_property
     def dataset_rows(self):
@@ -31,11 +31,11 @@ class YoloBBox(DataModel):
         if not summary:
             return YoloBBox(box=BBox())
         name = summary[0].get("name", "")
-        box = (
-            BBox.from_dict(summary[0]["box"], title=name)
-            if summary[0].get("box")
-            else BBox()
-        )
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = BBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = BBox()
         return YoloBBox(
             cls=summary[0]["class"],
             name=name,
@@ -70,7 +70,8 @@ class YoloBBoxes(DataModel):
             names.append(name)
             confidence.append(s["confidence"])
             if s.get("box"):
-                box.append(BBox.from_dict(s.get("box"), title=name))
+                assert isinstance(s["box"], dict)
+                box.append(BBox.from_dict(s["box"], title=name))
         return YoloBBoxes(
             cls=cls,
             name=names,
@@ -101,11 +102,11 @@ class YoloOBBox(DataModel):
         if not summary:
             return YoloOBBox(box=OBBox())
         name = summary[0].get("name", "")
-        box = (
-            OBBox.from_dict(summary[0]["box"], title=name)
-            if summary[0].get("box")
-            else OBBox()
-        )
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = OBBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = OBBox()
         return YoloOBBox(
             cls=summary[0]["class"],
             name=name,
@@ -140,7 +141,8 @@ class YoloOBBoxes(DataModel):
             names.append(name)
             confidence.append(s["confidence"])
             if s.get("box"):
-                box.append(OBBox.from_dict(s.get("box"), title=name))
+                assert isinstance(s["box"], dict)
+                box.append(OBBox.from_dict(s["box"], title=name))
         return YoloOBBoxes(
             cls=cls,
             name=names,
@@ -56,16 +56,16 @@ class YoloPose(DataModel):
         if not summary:
             return YoloPose(box=BBox(), pose=Pose3D())
         name = summary[0].get("name", "")
-        box = (
-            BBox.from_dict(summary[0]["box"], title=name)
-            if summary[0].get("box")
-            else BBox()
-        )
-        pose = (
-            Pose3D.from_dict(summary[0]["keypoints"])
-            if summary[0].get("keypoints")
-            else Pose3D()
-        )
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = BBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = BBox()
+        if summary[0].get("keypoints"):
+            assert isinstance(summary[0]["keypoints"], dict)
+            pose = Pose3D.from_dict(summary[0]["keypoints"])
+        else:
+            pose = Pose3D()
         return YoloPose(
             cls=summary[0]["class"],
             name=name,
@@ -103,9 +103,11 @@ class YoloPoses(DataModel):
             names.append(name)
             confidence.append(s["confidence"])
             if s.get("box"):
-                box.append(BBox.from_dict(s.get("box"), title=name))
+                assert isinstance(s["box"], dict)
+                box.append(BBox.from_dict(s["box"], title=name))
             if s.get("keypoints"):
-                pose.append(Pose3D.from_dict(s.get("keypoints")))
+                assert isinstance(s["keypoints"], dict)
+                pose.append(Pose3D.from_dict(s["keypoints"]))
         return YoloPoses(
             cls=cls,
             name=names,
@@ -34,16 +34,16 @@ class YoloSegment(DataModel):
         if not summary:
             return YoloSegment(box=BBox(), segment=Segment())
         name = summary[0].get("name", "")
-        box = (
-            BBox.from_dict(summary[0]["box"], title=name)
-            if summary[0].get("box")
-            else BBox()
-        )
-        segment = (
-            Segment.from_dict(summary[0]["segments"], title=name)
-            if summary[0].get("segments")
-            else Segment()
-        )
+        if summary[0].get("box"):
+            assert isinstance(summary[0]["box"], dict)
+            box = BBox.from_dict(summary[0]["box"], title=name)
+        else:
+            box = BBox()
+        if summary[0].get("segments"):
+            assert isinstance(summary[0]["segments"], dict)
+            segment = Segment.from_dict(summary[0]["segments"], title=name)
+        else:
+            segment = Segment()
         return YoloSegment(
             cls=summary[0]["class"],
             name=summary[0]["name"],
@@ -81,9 +81,11 @@ class YoloSegments(DataModel):
             names.append(name)
             confidence.append(s["confidence"])
             if s.get("box"):
-                box.append(BBox.from_dict(s.get("box"), title=name))
+                assert isinstance(s["box"], dict)
+                box.append(BBox.from_dict(s["box"], title=name))
             if s.get("segments"):
-                segment.append(Segment.from_dict(s.get("segments"), title=name))
+                assert isinstance(s["segments"], dict)
+                segment.append(Segment.from_dict(s["segments"], title=name))
         return YoloSegments(
             cls=cls,
             name=names,
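Every YOLO hunk applies the same refactor: a conditional expression built on `.get(...)` is replaced with an explicit `if`/`else` plus `assert isinstance(..., dict)`, so readers and type checkers see a concrete `dict` before `from_dict` is called (`.get()` types as optional, while indexing after the assert does not). The pattern in isolation, as a minimal sketch:

```py
from typing import Any


def narrow_box(summary: dict[str, Any]) -> dict:
    """Guard with .get(), assert the concrete type, then index with []."""
    if summary.get("box"):
        assert isinstance(summary["box"], dict)  # narrows the type
        return summary["box"]
    return {}  # the diff falls back to an empty model (BBox()/OBBox()/...)
```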