zenml-nightly 0.70.0.dev20241120__py3-none-any.whl → 0.70.0.dev20241125__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. zenml/VERSION +1 -1
  2. zenml/artifacts/artifact_config.py +32 -4
  3. zenml/artifacts/utils.py +12 -24
  4. zenml/cli/base.py +1 -1
  5. zenml/client.py +4 -19
  6. zenml/constants.py +1 -0
  7. zenml/integrations/kubernetes/orchestrators/kube_utils.py +8 -7
  8. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +52 -1
  9. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -1
  10. zenml/integrations/kubernetes/orchestrators/manifest_utils.py +3 -3
  11. zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +2 -1
  12. zenml/model/utils.py +0 -24
  13. zenml/models/__init__.py +6 -1
  14. zenml/models/v2/core/artifact_version.py +25 -2
  15. zenml/models/v2/core/model_version.py +0 -4
  16. zenml/models/v2/core/model_version_artifact.py +19 -76
  17. zenml/models/v2/core/model_version_pipeline_run.py +6 -39
  18. zenml/models/v2/core/service_connector.py +4 -0
  19. zenml/models/v2/misc/server_models.py +23 -0
  20. zenml/orchestrators/step_launcher.py +0 -1
  21. zenml/orchestrators/step_run_utils.py +4 -17
  22. zenml/orchestrators/step_runner.py +3 -1
  23. zenml/zen_server/deploy/helm/templates/_environment.tpl +117 -0
  24. zenml/zen_server/deploy/helm/templates/server-db-job.yaml +3 -14
  25. zenml/zen_server/deploy/helm/templates/server-deployment.yaml +16 -4
  26. zenml/zen_server/deploy/helm/templates/server-secret.yaml +2 -17
  27. zenml/zen_server/routers/model_versions_endpoints.py +59 -0
  28. zenml/zen_server/routers/server_endpoints.py +47 -0
  29. zenml/zen_server/routers/workspaces_endpoints.py +0 -130
  30. zenml/zen_server/zen_server_api.py +45 -6
  31. zenml/zen_stores/base_zen_store.py +2 -1
  32. zenml/zen_stores/migrations/utils.py +40 -24
  33. zenml/zen_stores/migrations/versions/ec6307720f92_simplify_model_version_links.py +118 -0
  34. zenml/zen_stores/rest_zen_store.py +42 -5
  35. zenml/zen_stores/schemas/model_schemas.py +10 -94
  36. zenml/zen_stores/schemas/user_schemas.py +0 -8
  37. zenml/zen_stores/schemas/workspace_schemas.py +0 -14
  38. {zenml_nightly-0.70.0.dev20241120.dist-info → zenml_nightly-0.70.0.dev20241125.dist-info}/METADATA +1 -1
  39. {zenml_nightly-0.70.0.dev20241120.dist-info → zenml_nightly-0.70.0.dev20241125.dist-info}/RECORD +42 -41
  40. {zenml_nightly-0.70.0.dev20241120.dist-info → zenml_nightly-0.70.0.dev20241125.dist-info}/LICENSE +0 -0
  41. {zenml_nightly-0.70.0.dev20241120.dist-info → zenml_nightly-0.70.0.dev20241125.dist-info}/WHEEL +0 -0
  42. {zenml_nightly-0.70.0.dev20241120.dist-info → zenml_nightly-0.70.0.dev20241125.dist-info}/entry_points.txt +0 -0
@@ -63,14 +63,14 @@ spec:
63
63
  value: "True"
64
64
  {{- end }}
65
65
  {{- if .Values.zenml.database.url }}
66
- - name: ZENML_STORE_TYPE
67
- value: sql
68
66
  - name: DISABLE_DATABASE_MIGRATION
69
67
  value: "True"
70
- - name: ZENML_STORE_SSL_VERIFY_SERVER_CERT
71
- value: {{ .Values.zenml.database.sslVerifyServerCert | default "false" | quote }}
72
68
  {{- end }}
73
69
 
70
+ {{- range $k, $v := include "zenml.storeEnvVariables" . | fromYaml }}
71
+ - name: {{ $k }}
72
+ value: {{ $v | quote }}
73
+ {{- end }}
74
74
  {{- range $k, $v := include "zenml.serverEnvVariables" . | fromYaml }}
75
75
  - name: {{ $k }}
76
76
  value: {{ $v | quote }}
@@ -104,6 +104,18 @@ spec:
104
104
  httpGet:
105
105
  path: /health
106
106
  port: http
107
+ lifecycle:
108
+ preStop:
109
+ exec:
110
+ # Give the process 15 more seconds before the SIGTERM signal is
111
+ # sent. This allows the endpoint removal to reach the ingress
112
+ # controller in time and for traffic to be routed away from the
113
+ # pod before it is shut down. This reduces the number of 502
114
+ # errors returned to the user.
115
+ #
116
+ # See https://learnk8s.io/graceful-shutdown for more information.
117
+ #
118
+ command: ["sleep", "15"]
107
119
  resources:
108
120
  {{- toYaml .Values.resources | nindent 12 }}
109
121
  {{- with .Values.nodeSelector }}
@@ -10,23 +10,8 @@ data:
10
10
  {{- else }}
11
11
  ZENML_SERVER_JWT_SECRET_KEY: {{ $prevServerSecret.data.ZENML_SERVER_JWT_SECRET_KEY | default (randAlphaNum 32 | b64enc | quote) }}
12
12
  {{- end }}
13
- {{- if .Values.zenml.database.url }}
14
- ZENML_STORE_URL: {{ .Values.zenml.database.url | b64enc | quote }}
15
- {{- if .Values.zenml.database.sslCa }}
16
- ZENML_STORE_SSL_CA: {{ .Files.Get .Values.zenml.database.sslCa | b64enc }}
17
- {{- end }}
18
- {{- if .Values.zenml.database.sslCert }}
19
- ZENML_STORE_SSL_CERT: {{ .Files.Get .Values.zenml.database.sslCert | b64enc }}
20
- {{- end }}
21
- {{- if .Values.zenml.database.sslKey }}
22
- ZENML_STORE_SSL_KEY: {{ .Files.Get .Values.zenml.database.sslKey | b64enc }}
23
- {{- end }}
24
- {{- if .Values.zenml.database.poolSize }}
25
- ZENML_STORE_POOL_SIZE: {{ .Values.zenml.database.poolSize | b64enc | quote }}
26
- {{- end }}
27
- {{- if .Values.zenml.database.maxOverflow }}
28
- ZENML_STORE_MAX_OVERFLOW: {{ .Values.zenml.database.maxOverflow | b64enc | quote }}
29
- {{- end }}
13
+ {{- range $k, $v := include "zenml.storeSecretEnvVariables" . | fromYaml}}
14
+ {{ $k }}: {{ $v | b64enc | quote }}
30
15
  {{- end }}
31
16
  {{- range $k, $v := include "zenml.secretsStoreSecretEnvVariables" . | fromYaml}}
32
17
  {{ $k }}: {{ $v | b64enc | quote }}
@@ -29,9 +29,11 @@ from zenml.constants import (
29
29
  )
30
30
  from zenml.models import (
31
31
  ModelVersionArtifactFilter,
32
+ ModelVersionArtifactRequest,
32
33
  ModelVersionArtifactResponse,
33
34
  ModelVersionFilter,
34
35
  ModelVersionPipelineRunFilter,
36
+ ModelVersionPipelineRunRequest,
35
37
  ModelVersionPipelineRunResponse,
36
38
  ModelVersionResponse,
37
39
  ModelVersionUpdate,
@@ -198,6 +200,34 @@ model_version_artifacts_router = APIRouter(
198
200
  )
199
201
 
200
202
 
203
+ @model_version_artifacts_router.post(
204
+ "",
205
+ responses={401: error_response, 409: error_response, 422: error_response},
206
+ )
207
+ @handle_exceptions
208
+ def create_model_version_artifact_link(
209
+ model_version_artifact_link: ModelVersionArtifactRequest,
210
+ _: AuthContext = Security(authorize),
211
+ ) -> ModelVersionArtifactResponse:
212
+ """Create a new model version to artifact link.
213
+
214
+ Args:
215
+ model_version_artifact_link: The model version to artifact link to create.
216
+
217
+ Returns:
218
+ The created model version to artifact link.
219
+ """
220
+ model_version = zen_store().get_model_version(
221
+ model_version_artifact_link.model_version
222
+ )
223
+ verify_permission_for_model(model_version, action=Action.UPDATE)
224
+
225
+ mv = zen_store().create_model_version_artifact_link(
226
+ model_version_artifact_link
227
+ )
228
+ return mv
229
+
230
+
201
231
  @model_version_artifacts_router.get(
202
232
  "",
203
233
  response_model=Page[ModelVersionArtifactResponse],
@@ -291,6 +321,35 @@ model_version_pipeline_runs_router = APIRouter(
291
321
  )
292
322
 
293
323
 
324
+ @model_version_pipeline_runs_router.post(
325
+ "",
326
+ responses={401: error_response, 409: error_response, 422: error_response},
327
+ )
328
+ @handle_exceptions
329
+ def create_model_version_pipeline_run_link(
330
+ model_version_pipeline_run_link: ModelVersionPipelineRunRequest,
331
+ _: AuthContext = Security(authorize),
332
+ ) -> ModelVersionPipelineRunResponse:
333
+ """Create a new model version to pipeline run link.
334
+
335
+ Args:
336
+ model_version_pipeline_run_link: The model version to pipeline run link to create.
337
+
338
+ Returns:
339
+ - If Model Version to Pipeline Run Link already exists - returns the existing link.
340
+ - Otherwise, returns the newly created model version to pipeline run link.
341
+ """
342
+ model_version = zen_store().get_model_version(
343
+ model_version_pipeline_run_link.model_version, hydrate=False
344
+ )
345
+ verify_permission_for_model(model_version, action=Action.UPDATE)
346
+
347
+ mv = zen_store().create_model_version_pipeline_run_link(
348
+ model_version_pipeline_run_link
349
+ )
350
+ return mv
351
+
352
+
294
353
  @model_version_pipeline_runs_router.get(
295
354
  "",
296
355
  response_model=Page[ModelVersionPipelineRunResponse],
@@ -22,6 +22,7 @@ from zenml.constants import (
22
22
  ACTIVATE,
23
23
  API,
24
24
  INFO,
25
+ LOAD_INFO,
25
26
  ONBOARDING_STATE,
26
27
  SERVER_SETTINGS,
27
28
  VERSION_1,
@@ -30,6 +31,7 @@ from zenml.enums import AuthScheme
30
31
  from zenml.exceptions import IllegalOperationError
31
32
  from zenml.models import (
32
33
  ServerActivationRequest,
34
+ ServerLoadInfo,
33
35
  ServerModel,
34
36
  ServerSettingsResponse,
35
37
  ServerSettingsUpdate,
@@ -71,6 +73,51 @@ def server_info() -> ServerModel:
71
73
  return zen_store().get_store_info()
72
74
 
73
75
 
76
+ @router.get(
77
+ LOAD_INFO,
78
+ response_model=ServerLoadInfo,
79
+ )
80
+ @handle_exceptions
81
+ def server_load_info(_: AuthContext = Security(authorize)) -> ServerLoadInfo:
82
+ """Get information about the server load.
83
+
84
+ Returns:
85
+ Information about the server load.
86
+ """
87
+ import threading
88
+
89
+ # Get the current number of threads
90
+ num_threads = len(threading.enumerate())
91
+
92
+ store = zen_store()
93
+
94
+ if store.config.driver == "sqlite":
95
+ # SQLite doesn't have a connection pool
96
+ return ServerLoadInfo(
97
+ threads=num_threads,
98
+ db_connections_total=0,
99
+ db_connections_active=0,
100
+ db_connections_overflow=0,
101
+ )
102
+
103
+ from sqlalchemy.pool import QueuePool
104
+
105
+ # Get the number of connections
106
+ pool = store.engine.pool
107
+ assert isinstance(pool, QueuePool)
108
+ idle_conn = pool.checkedin()
109
+ active_conn = pool.checkedout()
110
+ overflow_conn = max(0, pool.overflow())
111
+ total_conn = idle_conn + active_conn
112
+
113
+ return ServerLoadInfo(
114
+ threads=num_threads,
115
+ db_connections_total=total_conn,
116
+ db_connections_active=active_conn,
117
+ db_connections_overflow=overflow_conn,
118
+ )
119
+
120
+
74
121
  @router.get(
75
122
  ONBOARDING_STATE,
76
123
  responses={
@@ -20,7 +20,6 @@ from fastapi import APIRouter, Depends, Security
20
20
 
21
21
  from zenml.constants import (
22
22
  API,
23
- ARTIFACTS,
24
23
  CODE_REPOSITORIES,
25
24
  GET_OR_CREATE,
26
25
  MODEL_VERSIONS,
@@ -54,10 +53,6 @@ from zenml.models import (
54
53
  ComponentResponse,
55
54
  ModelRequest,
56
55
  ModelResponse,
57
- ModelVersionArtifactRequest,
58
- ModelVersionArtifactResponse,
59
- ModelVersionPipelineRunRequest,
60
- ModelVersionPipelineRunResponse,
61
56
  ModelVersionRequest,
62
57
  ModelVersionResponse,
63
58
  Page,
@@ -1442,131 +1437,6 @@ def create_model_version(
1442
1437
  )
1443
1438
 
1444
1439
 
1445
- @router.post(
1446
- WORKSPACES
1447
- + "/{workspace_name_or_id}"
1448
- + MODEL_VERSIONS
1449
- + "/{model_version_id}"
1450
- + ARTIFACTS,
1451
- response_model=ModelVersionArtifactResponse,
1452
- responses={401: error_response, 409: error_response, 422: error_response},
1453
- )
1454
- @handle_exceptions
1455
- def create_model_version_artifact_link(
1456
- workspace_name_or_id: Union[str, UUID],
1457
- model_version_id: UUID,
1458
- model_version_artifact_link: ModelVersionArtifactRequest,
1459
- auth_context: AuthContext = Security(authorize),
1460
- ) -> ModelVersionArtifactResponse:
1461
- """Create a new model version to artifact link.
1462
-
1463
- Args:
1464
- workspace_name_or_id: Name or ID of the workspace.
1465
- model_version_id: ID of the model version.
1466
- model_version_artifact_link: The model version to artifact link to create.
1467
- auth_context: Authentication context.
1468
-
1469
- Returns:
1470
- The created model version to artifact link.
1471
-
1472
- Raises:
1473
- IllegalOperationError: If the workspace or user specified in the
1474
- model version does not match the current workspace or authenticated
1475
- user.
1476
- """
1477
- workspace = zen_store().get_workspace(workspace_name_or_id)
1478
- if str(model_version_id) != str(model_version_artifact_link.model_version):
1479
- raise IllegalOperationError(
1480
- f"The model version id in your path `{model_version_id}` does not "
1481
- f"match the model version specified in the request model "
1482
- f"`{model_version_artifact_link.model_version}`"
1483
- )
1484
-
1485
- if model_version_artifact_link.workspace != workspace.id:
1486
- raise IllegalOperationError(
1487
- "Creating model version to artifact links outside of the workspace scope "
1488
- f"of this endpoint `{workspace_name_or_id}` is "
1489
- f"not supported."
1490
- )
1491
- if model_version_artifact_link.user != auth_context.user.id:
1492
- raise IllegalOperationError(
1493
- "Creating model to artifact links for a user other than yourself "
1494
- "is not supported."
1495
- )
1496
-
1497
- model_version = zen_store().get_model_version(model_version_id)
1498
- verify_permission_for_model(model_version, action=Action.UPDATE)
1499
-
1500
- mv = zen_store().create_model_version_artifact_link(
1501
- model_version_artifact_link
1502
- )
1503
- return mv
1504
-
1505
-
1506
- @router.post(
1507
- WORKSPACES
1508
- + "/{workspace_name_or_id}"
1509
- + MODEL_VERSIONS
1510
- + "/{model_version_id}"
1511
- + RUNS,
1512
- response_model=ModelVersionPipelineRunResponse,
1513
- responses={401: error_response, 409: error_response, 422: error_response},
1514
- )
1515
- @handle_exceptions
1516
- def create_model_version_pipeline_run_link(
1517
- workspace_name_or_id: Union[str, UUID],
1518
- model_version_id: UUID,
1519
- model_version_pipeline_run_link: ModelVersionPipelineRunRequest,
1520
- auth_context: AuthContext = Security(authorize),
1521
- ) -> ModelVersionPipelineRunResponse:
1522
- """Create a new model version to pipeline run link.
1523
-
1524
- Args:
1525
- workspace_name_or_id: Name or ID of the workspace.
1526
- model_version_id: ID of the model version.
1527
- model_version_pipeline_run_link: The model version to pipeline run link to create.
1528
- auth_context: Authentication context.
1529
-
1530
- Returns:
1531
- - If Model Version to Pipeline Run Link already exists - returns the existing link.
1532
- - Otherwise, returns the newly created model version to pipeline run link.
1533
-
1534
- Raises:
1535
- IllegalOperationError: If the workspace or user specified in the
1536
- model version does not match the current workspace or authenticated
1537
- user.
1538
- """
1539
- workspace = zen_store().get_workspace(workspace_name_or_id)
1540
- if str(model_version_id) != str(
1541
- model_version_pipeline_run_link.model_version
1542
- ):
1543
- raise IllegalOperationError(
1544
- f"The model version id in your path `{model_version_id}` does not "
1545
- f"match the model version specified in the request model "
1546
- f"`{model_version_pipeline_run_link.model_version}`"
1547
- )
1548
-
1549
- if model_version_pipeline_run_link.workspace != workspace.id:
1550
- raise IllegalOperationError(
1551
- "Creating model versions outside of the workspace scope "
1552
- f"of this endpoint `{workspace_name_or_id}` is "
1553
- f"not supported."
1554
- )
1555
- if model_version_pipeline_run_link.user != auth_context.user.id:
1556
- raise IllegalOperationError(
1557
- "Creating models for a user other than yourself "
1558
- "is not supported."
1559
- )
1560
-
1561
- model_version = zen_store().get_model_version(model_version_id)
1562
- verify_permission_for_model(model_version, action=Action.UPDATE)
1563
-
1564
- mv = zen_store().create_model_version_pipeline_run_link(
1565
- model_version_pipeline_run_link
1566
- )
1567
- return mv
1568
-
1569
-
1570
1440
  @router.post(
1571
1441
  WORKSPACES + "/{workspace_name_or_id}" + SERVICES,
1572
1442
  response_model=ServiceResponse,
@@ -180,7 +180,15 @@ class RequestBodyLimit(BaseHTTPMiddleware):
180
180
  if content_length := request.headers.get("content-length"):
181
181
  if int(content_length) > self.max_bytes:
182
182
  return Response(status_code=413) # Request Entity Too Large
183
- return await call_next(request)
183
+
184
+ try:
185
+ return await call_next(request)
186
+ except Exception:
187
+ logger.exception("An error occurred while processing the request")
188
+ return JSONResponse(
189
+ status_code=500,
190
+ content={"detail": "An unexpected error occurred."},
191
+ )
184
192
 
185
193
 
186
194
  class RestrictFileUploadsMiddleware(BaseHTTPMiddleware):
@@ -220,7 +228,15 @@ class RestrictFileUploadsMiddleware(BaseHTTPMiddleware):
220
228
  "detail": "File uploads are not allowed on this endpoint."
221
229
  },
222
230
  )
223
- return await call_next(request)
231
+
232
+ try:
233
+ return await call_next(request)
234
+ except Exception:
235
+ logger.exception("An error occurred while processing the request")
236
+ return JSONResponse(
237
+ status_code=500,
238
+ content={"detail": "An unexpected error occurred."},
239
+ )
224
240
 
225
241
 
226
242
  ALLOWED_FOR_FILE_UPLOAD: Set[str] = set()
@@ -252,13 +268,21 @@ async def set_secure_headers(request: Request, call_next: Any) -> Any:
252
268
  Returns:
253
269
  The response with secure headers set.
254
270
  """
271
+ try:
272
+ response = await call_next(request)
273
+ except Exception:
274
+ logger.exception("An error occurred while processing the request")
275
+ response = JSONResponse(
276
+ status_code=500,
277
+ content={"detail": "An unexpected error occurred."},
278
+ )
279
+
255
280
  # If the request is for the openAPI docs, don't set secure headers
256
281
  if request.url.path.startswith("/docs") or request.url.path.startswith(
257
282
  "/redoc"
258
283
  ):
259
- return await call_next(request)
284
+ return response
260
285
 
261
- response = await call_next(request)
262
286
  secure_headers().framework.fastapi(response)
263
287
  return response
264
288
 
@@ -298,7 +322,15 @@ async def track_last_user_activity(request: Request, call_next: Any) -> Any:
298
322
  zen_store()._update_last_user_activity_timestamp(
299
323
  last_user_activity=last_user_activity
300
324
  )
301
- return await call_next(request)
325
+
326
+ try:
327
+ return await call_next(request)
328
+ except Exception:
329
+ logger.exception("An error occurred while processing the request")
330
+ return JSONResponse(
331
+ status_code=500,
332
+ content={"detail": "An unexpected error occurred."},
333
+ )
302
334
 
303
335
 
304
336
  @app.middleware("http")
@@ -330,7 +362,14 @@ async def infer_source_context(request: Request, call_next: Any) -> Any:
330
362
  )
331
363
  source_context.set(SourceContextTypes.API)
332
364
 
333
- return await call_next(request)
365
+ try:
366
+ return await call_next(request)
367
+ except Exception:
368
+ logger.exception("An error occurred while processing the request")
369
+ return JSONResponse(
370
+ status_code=500,
371
+ content={"detail": "An unexpected error occurred."},
372
+ )
334
373
 
335
374
 
336
375
  @app.on_event("startup")
@@ -42,6 +42,7 @@ from zenml.enums import (
42
42
  SecretsStoreType,
43
43
  StoreType,
44
44
  )
45
+ from zenml.exceptions import IllegalOperationError
45
46
  from zenml.logger import get_logger
46
47
  from zenml.models import (
47
48
  ServerDatabaseType,
@@ -335,7 +336,7 @@ class BaseZenStore(
335
336
  # Ensure that the active stack is still valid
336
337
  try:
337
338
  active_stack = self.get_stack(stack_id=active_stack_id)
338
- except KeyError:
339
+ except (KeyError, IllegalOperationError):
339
340
  logger.warning(
340
341
  "The current %s active stack is no longer available. "
341
342
  "Resetting the active stack to default.",
@@ -273,30 +273,25 @@ class MigrationUtils(BaseModel):
273
273
  + "\n);"
274
274
  )
275
275
 
276
+ # Detect self-referential foreign keys from the table schema
277
+ has_self_referential_foreign_keys = False
278
+ for fk in table.foreign_keys:
279
+ # Check if the foreign key points to the same table
280
+ if fk.column.table == table:
281
+ has_self_referential_foreign_keys = True
282
+ break
283
+
276
284
  # Store the table schema
277
285
  store_db_info(
278
- dict(table=table.name, create_stmt=create_table_stmt)
286
+ dict(
287
+ table=table.name,
288
+ create_stmt=create_table_stmt,
289
+ self_references=has_self_referential_foreign_keys,
290
+ )
279
291
  )
280
292
 
281
293
  # 2. extract the table data in batches
282
-
283
- # If the table has a `created` column, we use it to sort
284
- # the rows in the table starting with the oldest rows.
285
- # This is to ensure that the rows are inserted in the
286
- # correct order, since some tables have inner foreign key
287
- # constraints.
288
- if "created" in table.columns:
289
- order_by = [table.columns["created"]]
290
- else:
291
- order_by = []
292
- if "id" in table.columns:
293
- # If the table has an `id` column, we also use it to sort
294
- # the rows in the table, even if we already use "created"
295
- # to sort the rows. We need a unique field to sort the rows,
296
- # to break the tie between rows with the same "created"
297
- # date, otherwise the same entry might end up multiple times
298
- # in subsequent pages.
299
- order_by.append(table.columns["id"])
294
+ order_by = [col for col in table.primary_key]
300
295
 
301
296
  # Fetch the number of rows in the table
302
297
  row_count = conn.scalar(
@@ -305,7 +300,7 @@ class MigrationUtils(BaseModel):
305
300
 
306
301
  # Fetch the data from the table in batches
307
302
  if row_count is not None:
308
- batch_size = 50
303
+ batch_size = 100
309
304
  for i in range(0, row_count, batch_size):
310
305
  rows = conn.execute(
311
306
  table.select()
@@ -349,6 +344,7 @@ class MigrationUtils(BaseModel):
349
344
 
350
345
  with self.engine.begin() as connection:
351
346
  # read the DB information one JSON object at a time
347
+ self_references: Dict[str, bool] = {}
352
348
  for table_dump in load_db_info():
353
349
  table_name = table_dump["table"]
354
350
  if "create_stmt" in table_dump:
@@ -356,10 +352,22 @@ class MigrationUtils(BaseModel):
356
352
  connection.execute(text(table_dump["create_stmt"]))
357
353
  # Reload the database metadata after creating the table
358
354
  metadata.reflect(bind=self.engine)
355
+ self_references[table_name] = table_dump.get(
356
+ "self_references", False
357
+ )
359
358
 
360
359
  if "data" in table_dump:
361
360
  # insert the data into the database
362
361
  table = metadata.tables[table_name]
362
+ if self_references.get(table_name, False):
363
+ # If the table has self-referential foreign keys, we
364
+ # need to disable the foreign key checks before inserting
365
+ # the rows and re-enable them afterwards. This is because
366
+ # the rows need to be inserted in the correct order to
367
+ # satisfy the foreign key constraints and we don't sort
368
+ # the rows by creation time in the backup.
369
+ connection.execute(text("SET FOREIGN_KEY_CHECKS = 0"))
370
+
363
371
  for row in table_dump["data"]:
364
372
  # Convert column values to the correct type
365
373
  for column in table.columns:
@@ -372,10 +380,18 @@ class MigrationUtils(BaseModel):
372
380
  row[column.name], "utf-8"
373
381
  )
374
382
 
375
- # Insert the rows into the table
376
- connection.execute(
377
- table.insert().values(table_dump["data"])
378
- )
383
+ # Insert the rows into the table in batches
384
+ batch_size = 100
385
+ for i in range(0, len(table_dump["data"]), batch_size):
386
+ connection.execute(
387
+ table.insert().values(
388
+ table_dump["data"][i : i + batch_size]
389
+ )
390
+ )
391
+
392
+ if self_references.get(table_name, False):
393
+ # Re-enable FK checks, using the same per-table flag that disabled them
394
+ connection.execute(text("SET FOREIGN_KEY_CHECKS = 1"))
379
395
 
380
396
  def backup_database_to_file(self, dump_file: str) -> None:
381
397
  """Backup the database to a file.