unstructured-ingest 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of unstructured-ingest might be problematic.

Files changed (30)
  1. test/integration/connectors/conftest.py +13 -0
  2. test/integration/connectors/databricks_tests/test_volumes_native.py +8 -4
  3. test/integration/connectors/sql/test_postgres.py +6 -10
  4. test/integration/connectors/sql/test_snowflake.py +205 -0
  5. test/integration/connectors/sql/test_sqlite.py +6 -10
  6. test/integration/connectors/test_delta_table.py +138 -0
  7. test/integration/connectors/utils/docker.py +78 -0
  8. test/integration/connectors/utils/validation.py +93 -2
  9. unstructured_ingest/__version__.py +1 -1
  10. unstructured_ingest/v2/cli/utils/click.py +32 -1
  11. unstructured_ingest/v2/cli/utils/model_conversion.py +10 -3
  12. unstructured_ingest/v2/interfaces/indexer.py +4 -1
  13. unstructured_ingest/v2/pipeline/pipeline.py +10 -2
  14. unstructured_ingest/v2/pipeline/steps/index.py +18 -1
  15. unstructured_ingest/v2/processes/connectors/__init__.py +10 -0
  16. unstructured_ingest/v2/processes/connectors/databricks/volumes.py +1 -1
  17. unstructured_ingest/v2/processes/connectors/delta_table.py +185 -0
  18. unstructured_ingest/v2/processes/connectors/slack.py +248 -0
  19. unstructured_ingest/v2/processes/connectors/sql/__init__.py +10 -2
  20. unstructured_ingest/v2/processes/connectors/sql/postgres.py +33 -37
  21. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +164 -0
  22. unstructured_ingest/v2/processes/connectors/sql/sql.py +38 -10
  23. unstructured_ingest/v2/processes/connectors/sql/sqlite.py +31 -32
  24. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/METADATA +14 -12
  25. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/RECORD +29 -24
  26. unstructured_ingest/v2/processes/connectors/databricks_volumes.py +0 -250
  27. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/LICENSE.md +0 -0
  28. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/WHEEL +0 -0
  29. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/entry_points.txt +0 -0
  30. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/top_level.txt +0 -0
test/integration/connectors/conftest.py

@@ -1,7 +1,11 @@
+import tempfile
 from pathlib import Path
+from typing import Generator
 
 import pytest
 
+from unstructured_ingest.v2.logger import logger
+
 FILENAME = "DA-1p-with-duplicate-pages.pdf.json"
 
 
@@ -13,3 +17,12 @@ def upload_file() -> Path:
     assert upload_file.exists()
     assert upload_file.is_file()
     return upload_file
+
+
+@pytest.fixture
+def temp_dir() -> Generator[Path, None, None]:
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_path = Path(temp_dir)
+        logger.info(f"Created temp dir '{temp_path}'")
+        yield temp_path
+        logger.info(f"Removing temp dir '{temp_path}'")
test/integration/connectors/databricks_tests/test_volumes_native.py

@@ -16,7 +16,7 @@ from test.integration.connectors.utils.validation import (
     source_connector_validation,
 )
 from test.integration.utils import requires_env
-from unstructured_ingest.v2.interfaces import FileData
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
 from unstructured_ingest.v2.processes.connectors.databricks.volumes_native import (
     CONNECTOR_TYPE,
     DatabricksNativeVolumesAccessConfig,
@@ -139,7 +139,11 @@ def validate_upload(client: WorkspaceClient, catalog: str, volume: str, volume_p
 async def test_volumes_native_destination(upload_file: Path):
     env_data = get_env_data()
     volume_path = f"databricks-volumes-test-output-{uuid.uuid4()}"
-    mock_file_data = FileData(identifier="mock file data", connector_type=CONNECTOR_TYPE)
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
     with databricks_destination_context(
         volume="test-platform", volume_path=volume_path, env_data=env_data
     ) as workspace_client:
@@ -153,9 +157,9 @@ async def test_volumes_native_destination(upload_file: Path):
             ),
         )
         if uploader.is_async():
-            await uploader.run_async(path=upload_file, file_data=mock_file_data)
+            await uploader.run_async(path=upload_file, file_data=file_data)
         else:
-            uploader.run(path=upload_file, file_data=mock_file_data)
+            uploader.run(path=upload_file, file_data=file_data)
 
         validate_upload(
             client=workspace_client,
test/integration/connectors/sql/test_postgres.py

@@ -2,7 +2,6 @@ import tempfile
 from contextlib import contextmanager
 from pathlib import Path
 
-import faker
 import pandas as pd
 import pytest
 from psycopg2 import connect
@@ -26,9 +25,7 @@ from unstructured_ingest.v2.processes.connectors.sql.postgres import (
     PostgresUploadStager,
 )
 
-faker = faker.Faker()
-
-SEED_DATA_ROWS = 40
+SEED_DATA_ROWS = 20
 
 
 @contextmanager
@@ -42,11 +39,8 @@ def postgres_download_setup() -> None:
         port=5433,
     )
     with connection.cursor() as cursor:
-        for _ in range(SEED_DATA_ROWS):
-            sql_statment = (
-                f"INSERT INTO cars (brand, price) VALUES "
-                f"('{faker.word()}', {faker.random_int()})"
-            )
+        for i in range(SEED_DATA_ROWS):
+            sql_statment = f"INSERT INTO cars (brand, price) VALUES " f"('brand_{i}', {i})"
             cursor.execute(sql_statment)
         connection.commit()
     yield
@@ -88,7 +82,9 @@ async def test_postgres_source():
         downloader=downloader,
         configs=ValidationConfigs(
             test_id="postgres",
-            expected_num_files=40,
+            expected_num_files=SEED_DATA_ROWS,
+            expected_number_indexed_file_data=4,
+            validate_downloaded_files=True,
         ),
     )
 
test/integration/connectors/sql/test_snowflake.py

@@ -0,0 +1,205 @@
+import os
+import tempfile
+from pathlib import Path
+
+import docker
+import pandas as pd
+import pytest
+import snowflake.connector as sf
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, SOURCE_TAG, env_setup_path
+from test.integration.connectors.utils.docker import container_context
+from test.integration.connectors.utils.validation import (
+    ValidationConfigs,
+    source_connector_validation,
+)
+from test.integration.utils import requires_env
+from unstructured_ingest.v2.interfaces import FileData
+from unstructured_ingest.v2.processes.connectors.sql.snowflake import (
+    CONNECTOR_TYPE,
+    SnowflakeAccessConfig,
+    SnowflakeConnectionConfig,
+    SnowflakeDownloader,
+    SnowflakeDownloaderConfig,
+    SnowflakeIndexer,
+    SnowflakeIndexerConfig,
+    SnowflakeUploader,
+    SnowflakeUploadStager,
+)
+
+SEED_DATA_ROWS = 20
+
+
+def seed_data():
+    conn = sf.connect(
+        user="test",
+        password="test",
+        account="test",
+        database="test",
+        host="snowflake.localhost.localstack.cloud",
+    )
+
+    file = Path(env_setup_path / "sql" / "snowflake" / "source" / "snowflake-schema.sql")
+
+    with file.open() as f:
+        sql = f.read()
+
+    cur = conn.cursor()
+    cur.execute(sql)
+    for i in range(SEED_DATA_ROWS):
+        sql_statment = f"INSERT INTO cars (brand, price) VALUES " f"('brand_{i}', {i})"
+        cur.execute(sql_statment)
+
+    cur.close()
+    conn.close()
+
+
+def init_db_destination():
+    conn = sf.connect(
+        user="test",
+        password="test",
+        account="test",
+        database="test",
+        host="snowflake.localhost.localstack.cloud",
+    )
+
+    file = Path(env_setup_path / "sql" / "snowflake" / "destination" / "snowflake-schema.sql")
+
+    with file.open() as f:
+        sql = f.read()
+
+    cur = conn.cursor()
+    cur.execute(sql)
+
+    cur.close()
+    conn.close()
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, "sql")
+@requires_env("LOCALSTACK_AUTH_TOKEN")
+async def test_snowflake_source():
+    docker_client = docker.from_env()
+    token = os.getenv("LOCALSTACK_AUTH_TOKEN")
+    with container_context(
+        docker_client=docker_client,
+        image="localstack/snowflake",
+        environment={"LOCALSTACK_AUTH_TOKEN": token, "EXTRA_CORS_ALLOWED_ORIGINS": "*"},
+        ports={4566: 4566, 443: 443},
+        healthcheck_timeout=30,
+    ):
+        seed_data()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            connection_config = SnowflakeConnectionConfig(
+                access_config=SnowflakeAccessConfig(password="test"),
+                account="test",
+                user="test",
+                database="test",
+                host="snowflake.localhost.localstack.cloud",
+            )
+            indexer = SnowflakeIndexer(
+                connection_config=connection_config,
+                index_config=SnowflakeIndexerConfig(
+                    table_name="cars", id_column="CAR_ID", batch_size=5
+                ),
+            )
+            downloader = SnowflakeDownloader(
+                connection_config=connection_config,
+                download_config=SnowflakeDownloaderConfig(
+                    fields=["CAR_ID", "BRAND"], download_dir=Path(tmpdir)
+                ),
+            )
+            await source_connector_validation(
+                indexer=indexer,
+                downloader=downloader,
+                configs=ValidationConfigs(
+                    test_id="snowflake",
+                    expected_num_files=SEED_DATA_ROWS,
+                    expected_number_indexed_file_data=4,
+                    validate_downloaded_files=True,
+                ),
+            )
+
+
+def validate_destination(
+    connect_params: dict,
+    expected_num_elements: int,
+):
+    # Run the following validations:
+    # * Check that the number of records in the table match the expected value
+    # * Given the embedding, make sure it matches the associated text it belongs to
+    conn = sf.connect(**connect_params)
+    cursor = conn.cursor()
+    try:
+        query = "select count(*) from elements;"
+        cursor.execute(query)
+        count = cursor.fetchone()[0]
+        assert (
+            count == expected_num_elements
+        ), f"dest check failed: got {count}, expected {expected_num_elements}"
+    finally:
+        cursor.close()
+        conn.close()
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, "sql")
+@requires_env("LOCALSTACK_AUTH_TOKEN")
+async def test_snowflake_destination(upload_file: Path):
+    # the postgres destination connector doesn't leverage the file data but is required as an input,
+    # mocking it with arbitrary values to meet the base requirements:
+    mock_file_data = FileData(identifier="mock file data", connector_type=CONNECTOR_TYPE)
+    docker_client = docker.from_env()
+    token = os.getenv("LOCALSTACK_AUTH_TOKEN")
+    with container_context(
+        docker_client=docker_client,
+        image="localstack/snowflake",
+        environment={"LOCALSTACK_AUTH_TOKEN": token, "EXTRA_CORS_ALLOWED_ORIGINS": "*"},
+        ports={4566: 4566, 443: 443},
+        healthcheck_timeout=30,
+    ):
+        init_db_destination()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            stager = SnowflakeUploadStager()
+            stager_params = {
+                "elements_filepath": upload_file,
+                "file_data": mock_file_data,
+                "output_dir": Path(tmpdir),
+                "output_filename": "test_db",
+            }
+            if stager.is_async():
+                staged_path = await stager.run_async(**stager_params)
+            else:
+                staged_path = stager.run(**stager_params)
+
+            # The stager should append the `.json` suffix to the output filename passed in.
+            assert staged_path.name == "test_db.json"
+
+            connect_params = {
+                "user": "test",
+                "password": "test",
+                "account": "test",
+                "database": "test",
+                "host": "snowflake.localhost.localstack.cloud",
+            }
+
+            uploader = SnowflakeUploader(
+                connection_config=SnowflakeConnectionConfig(
+                    access_config=SnowflakeAccessConfig(password=connect_params["password"]),
+                    account=connect_params["account"],
+                    user=connect_params["user"],
+                    database=connect_params["database"],
+                    host=connect_params["host"],
+                )
+            )
+            if uploader.is_async():
+                await uploader.run_async(path=staged_path, file_data=mock_file_data)
+            else:
+                uploader.run(path=staged_path, file_data=mock_file_data)
+
+            staged_df = pd.read_json(staged_path, orient="records", lines=True)
+            expected_num_elements = len(staged_df)
+            validate_destination(
+                connect_params=connect_params,
+                expected_num_elements=expected_num_elements,
+            )
test/integration/connectors/sql/test_sqlite.py

@@ -3,7 +3,6 @@ import tempfile
 from contextlib import contextmanager
 from pathlib import Path
 
-import faker
 import pandas as pd
 import pytest
 
@@ -24,9 +23,7 @@ from unstructured_ingest.v2.processes.connectors.sql.sqlite import (
     SQLiteUploadStager,
 )
 
-faker = faker.Faker()
-
-SEED_DATA_ROWS = 40
+SEED_DATA_ROWS = 20
 
 
 @contextmanager
@@ -41,11 +38,8 @@ def sqlite_download_setup() -> Path:
     with db_init_path.open("r") as f:
         query = f.read()
         cursor.executescript(query)
-        for _ in range(SEED_DATA_ROWS):
-            sql_statment = (
-                f"INSERT INTO cars (brand, price) "
-                f"VALUES ('{faker.word()}', {faker.random_int()})"
-            )
+        for i in range(SEED_DATA_ROWS):
+            sql_statment = f"INSERT INTO cars (brand, price) " f"VALUES ('brand{i}', {i})"
             cursor.execute(sql_statment)
 
     sqlite_connection.commit()
@@ -76,7 +70,9 @@ async def test_sqlite_source():
         downloader=downloader,
        configs=ValidationConfigs(
            test_id="sqlite",
-            expected_num_files=40,
+            expected_num_files=SEED_DATA_ROWS,
+            expected_number_indexed_file_data=4,
+            validate_downloaded_files=True,
        ),
    )
 
test/integration/connectors/test_delta_table.py

@@ -0,0 +1,138 @@
+import multiprocessing
+import os
+from pathlib import Path
+
+import pytest
+from deltalake import DeltaTable
+from fsspec import get_filesystem_class
+
+from test.integration.connectors.utils.constants import (
+    DESTINATION_TAG,
+)
+from test.integration.utils import requires_env
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.delta_table import (
+    CONNECTOR_TYPE,
+    DeltaTableAccessConfig,
+    DeltaTableConnectionConfig,
+    DeltaTableUploader,
+    DeltaTableUploaderConfig,
+    DeltaTableUploadStager,
+    DeltaTableUploadStagerConfig,
+)
+
+multiprocessing.set_start_method("spawn")
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+async def test_delta_table_destination_local(upload_file: Path, temp_dir: Path):
+    destination_path = str(temp_dir)
+    connection_config = DeltaTableConnectionConfig(
+        access_config=DeltaTableAccessConfig(),
+        table_uri=destination_path,
+    )
+    stager_config = DeltaTableUploadStagerConfig()
+    stager = DeltaTableUploadStager(upload_stager_config=stager_config)
+    new_upload_file = stager.run(
+        elements_filepath=upload_file,
+        output_dir=temp_dir,
+        output_filename=upload_file.name,
+    )
+
+    upload_config = DeltaTableUploaderConfig()
+    uploader = DeltaTableUploader(connection_config=connection_config, upload_config=upload_config)
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(
+            fullpath=upload_file.name, filename=new_upload_file.name
+        ),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+
+    if uploader.is_async():
+        await uploader.run_async(path=new_upload_file, file_data=file_data)
+    else:
+        uploader.run(path=new_upload_file, file_data=file_data)
+    delta_table_path = os.path.join(destination_path, upload_file.name)
+    delta_table = DeltaTable(table_uri=delta_table_path)
+    df = delta_table.to_pandas()
+
+    EXPECTED_COLUMNS = 10
+    EXPECTED_ROWS = 22
+    assert (
+        len(df) == EXPECTED_ROWS
+    ), f"Number of rows in table vs expected: {len(df)}/{EXPECTED_ROWS}"
+    assert (
+        len(df.columns) == EXPECTED_COLUMNS
+    ), f"Number of columns in table vs expected: {len(df.columns)}/{EXPECTED_COLUMNS}"
+
+
+def get_aws_credentials() -> dict:
+    access_key = os.getenv("S3_INGEST_TEST_ACCESS_KEY", None)
+    assert access_key
+    secret_key = os.getenv("S3_INGEST_TEST_SECRET_KEY", None)
+    assert secret_key
+    return {
+        "AWS_ACCESS_KEY_ID": access_key,
+        "AWS_SECRET_ACCESS_KEY": secret_key,
+        "AWS_REGION": "us-east-2",
+    }
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@requires_env("S3_INGEST_TEST_ACCESS_KEY", "S3_INGEST_TEST_SECRET_KEY")
+async def test_delta_table_destination_s3(upload_file: Path, temp_dir: Path):
+    aws_credentials = get_aws_credentials()
+    s3_bucket = "s3://utic-platform-test-destination"
+    destination_path = f"{s3_bucket}/destination/test"
+    connection_config = DeltaTableConnectionConfig(
+        access_config=DeltaTableAccessConfig(
+            aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+            aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+        ),
+        aws_region=aws_credentials["AWS_REGION"],
+        table_uri=destination_path,
+    )
+    stager_config = DeltaTableUploadStagerConfig()
+    stager = DeltaTableUploadStager(upload_stager_config=stager_config)
+    new_upload_file = stager.run(
+        elements_filepath=upload_file,
+        output_dir=temp_dir,
+        output_filename=upload_file.name,
+    )
+
+    upload_config = DeltaTableUploaderConfig()
+    uploader = DeltaTableUploader(connection_config=connection_config, upload_config=upload_config)
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(
+            fullpath=upload_file.name, filename=new_upload_file.name
+        ),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+
+    try:
+        if uploader.is_async():
+            await uploader.run_async(path=new_upload_file, file_data=file_data)
+        else:
+            uploader.run(path=new_upload_file, file_data=file_data)
+        delta_table_path = os.path.join(destination_path, upload_file.name)
+        delta_table = DeltaTable(table_uri=delta_table_path, storage_options=aws_credentials)
+        df = delta_table.to_pandas()
+
+        EXPECTED_COLUMNS = 10
+        EXPECTED_ROWS = 22
+        assert (
+            len(df) == EXPECTED_ROWS
+        ), f"Number of rows in table vs expected: {len(df)}/{EXPECTED_ROWS}"
+        assert (
+            len(df.columns) == EXPECTED_COLUMNS
+        ), f"Number of columns in table vs expected: {len(df.columns)}/{EXPECTED_COLUMNS}"
+    finally:
+        s3fs = get_filesystem_class("s3")(
+            key=aws_credentials["AWS_ACCESS_KEY_ID"],
+            secret=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+        )
+        s3fs.rm(path=destination_path, recursive=True)
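As the local test reads it back, the uploader appears to write one Delta table per staged file under `table_uri/<upload file name>`. A minimal sketch for inspecting such a table outside the test harness (the output path here is illustrative, not part of this release):

    from deltalake import DeltaTable

    # Illustrative path: table named after the uploaded fixture file, under the configured table_uri.
    table = DeltaTable(table_uri="/tmp/delta-out/DA-1p-with-duplicate-pages.pdf.json")
    df = table.to_pandas()
    print(df.shape)  # the test above expects 22 rows and 10 columns for the bundled fixture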
test/integration/connectors/utils/docker.py

@@ -0,0 +1,78 @@
+import time
+from contextlib import contextmanager
+from typing import Optional
+
+import docker
+from docker.models.containers import Container
+
+
+def get_container(
+    docker_client: docker.DockerClient,
+    image: str,
+    ports: dict,
+    environment: Optional[dict] = None,
+    volumes: Optional[dict] = None,
+    healthcheck: Optional[dict] = None,
+) -> Container:
+    run_kwargs = {
+        "image": image,
+        "detach": True,
+        "ports": ports,
+    }
+    if environment:
+        run_kwargs["environment"] = environment
+    if volumes:
+        run_kwargs["volumes"] = volumes
+    if healthcheck:
+        run_kwargs["healthcheck"] = healthcheck
+    container: Container = docker_client.containers.run(**run_kwargs)
+    return container
+
+
+def has_healthcheck(container: Container) -> bool:
+    return container.attrs.get("Config", {}).get("Healthcheck", None) is not None
+
+
+def healthcheck_wait(container: Container, timeout: int = 10) -> None:
+    health = container.health
+    start = time.time()
+    while health != "healthy" and time.time() - start < timeout:
+        time.sleep(1)
+        container.reload()
+        health = container.health
+    if health != "healthy":
+        health_dict = container.attrs.get("State", {}).get("Health", {})
+        raise TimeoutError(f"Docker container never came up healthy: {health_dict}")
+
+
+@contextmanager
+def container_context(
+    docker_client: docker.DockerClient,
+    image: str,
+    ports: dict,
+    environment: Optional[dict] = None,
+    volumes: Optional[dict] = None,
+    healthcheck: Optional[dict] = None,
+    healthcheck_timeout: int = 10,
+):
+    container: Optional[Container] = None
+    try:
+        container = get_container(
+            docker_client=docker_client,
+            image=image,
+            ports=ports,
+            environment=environment,
+            volumes=volumes,
+            healthcheck=healthcheck,
+        )
+        if has_healthcheck(container):
+            healthcheck_wait(container=container, timeout=healthcheck_timeout)
+        yield container
+    except AssertionError as e:
+        if container:
+            logs = container.logs()
+            print(logs.decode("utf-8"))
+        raise e
+    finally:
+        if container:
+            container.kill()
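The `container_context` helper above starts a detached container, optionally waits for its healthcheck to pass, prints the container logs when an assertion fails inside the block, and always kills the container on exit. A minimal sketch of using it for a throwaway Postgres instance in another test (the image, port, and environment values are illustrative, not taken from this release):

    import docker

    from test.integration.connectors.utils.docker import container_context

    docker_client = docker.from_env()
    with container_context(
        docker_client=docker_client,
        image="postgres:15",
        ports={5432: 5432},  # container port -> host port, per docker-py's ports mapping
        environment={"POSTGRES_PASSWORD": "test"},
    ) as container:
        # container is running here (and healthy, if the image defines a healthcheck)
        print(container.short_id)
    # on exit the container is killed, even if the test body raises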