unstructured-ingest 0.3.7__py3-none-any.whl → 0.3.9__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

Potentially problematic release.

Files changed (64)
  1. test/integration/chunkers/test_chunkers.py +0 -11
  2. test/integration/connectors/conftest.py +11 -1
  3. test/integration/connectors/databricks_tests/test_volumes_native.py +4 -3
  4. test/integration/connectors/duckdb/conftest.py +14 -0
  5. test/integration/connectors/duckdb/test_duckdb.py +51 -44
  6. test/integration/connectors/duckdb/test_motherduck.py +37 -48
  7. test/integration/connectors/elasticsearch/test_elasticsearch.py +26 -4
  8. test/integration/connectors/elasticsearch/test_opensearch.py +26 -3
  9. test/integration/connectors/sql/test_postgres.py +102 -91
  10. test/integration/connectors/sql/test_singlestore.py +111 -99
  11. test/integration/connectors/sql/test_snowflake.py +142 -117
  12. test/integration/connectors/sql/test_sqlite.py +86 -75
  13. test/integration/connectors/test_astradb.py +22 -1
  14. test/integration/connectors/test_azure_ai_search.py +25 -3
  15. test/integration/connectors/test_chroma.py +120 -0
  16. test/integration/connectors/test_confluence.py +4 -4
  17. test/integration/connectors/test_delta_table.py +1 -0
  18. test/integration/connectors/test_kafka.py +4 -4
  19. test/integration/connectors/test_milvus.py +21 -0
  20. test/integration/connectors/test_mongodb.py +3 -3
  21. test/integration/connectors/test_neo4j.py +236 -0
  22. test/integration/connectors/test_pinecone.py +25 -1
  23. test/integration/connectors/test_qdrant.py +25 -2
  24. test/integration/connectors/test_s3.py +9 -6
  25. test/integration/connectors/utils/docker.py +6 -0
  26. test/integration/connectors/utils/validation/__init__.py +0 -0
  27. test/integration/connectors/utils/validation/destination.py +88 -0
  28. test/integration/connectors/utils/validation/equality.py +75 -0
  29. test/integration/connectors/utils/{validation.py → validation/source.py} +15 -91
  30. test/integration/connectors/utils/validation/utils.py +36 -0
  31. unstructured_ingest/__version__.py +1 -1
  32. unstructured_ingest/utils/chunking.py +11 -0
  33. unstructured_ingest/utils/data_prep.py +36 -0
  34. unstructured_ingest/v2/interfaces/upload_stager.py +70 -6
  35. unstructured_ingest/v2/interfaces/uploader.py +11 -2
  36. unstructured_ingest/v2/pipeline/steps/stage.py +3 -1
  37. unstructured_ingest/v2/processes/connectors/astradb.py +8 -30
  38. unstructured_ingest/v2/processes/connectors/azure_ai_search.py +16 -40
  39. unstructured_ingest/v2/processes/connectors/chroma.py +36 -59
  40. unstructured_ingest/v2/processes/connectors/couchbase.py +42 -52
  41. unstructured_ingest/v2/processes/connectors/delta_table.py +11 -33
  42. unstructured_ingest/v2/processes/connectors/duckdb/base.py +26 -26
  43. unstructured_ingest/v2/processes/connectors/duckdb/duckdb.py +29 -20
  44. unstructured_ingest/v2/processes/connectors/duckdb/motherduck.py +37 -44
  45. unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py +5 -30
  46. unstructured_ingest/v2/processes/connectors/gitlab.py +32 -31
  47. unstructured_ingest/v2/processes/connectors/google_drive.py +32 -29
  48. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +2 -4
  49. unstructured_ingest/v2/processes/connectors/kdbai.py +44 -70
  50. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +8 -10
  51. unstructured_ingest/v2/processes/connectors/local.py +13 -2
  52. unstructured_ingest/v2/processes/connectors/milvus.py +16 -57
  53. unstructured_ingest/v2/processes/connectors/mongodb.py +4 -8
  54. unstructured_ingest/v2/processes/connectors/neo4j.py +381 -0
  55. unstructured_ingest/v2/processes/connectors/pinecone.py +23 -65
  56. unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py +32 -41
  57. unstructured_ingest/v2/processes/connectors/sql/sql.py +41 -40
  58. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +9 -31
  59. {unstructured_ingest-0.3.7.dist-info → unstructured_ingest-0.3.9.dist-info}/METADATA +21 -17
  60. {unstructured_ingest-0.3.7.dist-info → unstructured_ingest-0.3.9.dist-info}/RECORD +64 -56
  61. {unstructured_ingest-0.3.7.dist-info → unstructured_ingest-0.3.9.dist-info}/LICENSE.md +0 -0
  62. {unstructured_ingest-0.3.7.dist-info → unstructured_ingest-0.3.9.dist-info}/WHEEL +0 -0
  63. {unstructured_ingest-0.3.7.dist-info → unstructured_ingest-0.3.9.dist-info}/entry_points.txt +0 -0
  64. {unstructured_ingest-0.3.7.dist-info → unstructured_ingest-0.3.9.dist-info}/top_level.txt +0 -0
test/integration/connectors/test_neo4j.py
@@ -0,0 +1,236 @@
+ import json
+ import time
+ import uuid
+ from datetime import datetime
+ from pathlib import Path
+
+ import pytest
+ from neo4j import AsyncGraphDatabase, Driver, GraphDatabase
+ from neo4j.exceptions import ServiceUnavailable
+ from pytest_check import check
+
+ from test.integration.connectors.utils.constants import DESTINATION_TAG
+ from test.integration.connectors.utils.docker import container_context
+ from unstructured_ingest.error import DestinationConnectionError
+ from unstructured_ingest.utils.chunking import elements_from_base64_gzipped_json
+ from unstructured_ingest.v2.interfaces.file_data import (
+     FileData,
+     FileDataSourceMetadata,
+     SourceIdentifiers,
+ )
+ from unstructured_ingest.v2.processes.connectors.neo4j import (
+     CONNECTOR_TYPE,
+     Label,
+     Neo4jAccessConfig,
+     Neo4jConnectionConfig,
+     Neo4jUploader,
+     Neo4jUploaderConfig,
+     Neo4jUploadStager,
+     Relationship,
+ )
+
+ USERNAME = "neo4j"
+ PASSWORD = "password"
+ URI = "neo4j://localhost:7687"
+ DATABASE = "neo4j"
+
+ EXPECTED_DOCUMENT_COUNT = 1
+
+
+ # NOTE: Precheck tests are read-only so we utilize the same container for all tests.
+ # If new tests require clean neo4j container, this fixture's scope should be adjusted.
+ @pytest.fixture(autouse=True, scope="module")
+ def _neo4j_server():
+     with container_context(
+         image="neo4j:latest", environment={"NEO4J_AUTH": "neo4j/password"}, ports={"7687": "7687"}
+     ):
+         driver = GraphDatabase.driver(uri=URI, auth=(USERNAME, PASSWORD))
+         wait_for_connection(driver)
+         driver.close()
+         yield
+
+
+ @pytest.mark.asyncio
+ @pytest.mark.tags(DESTINATION_TAG, CONNECTOR_TYPE)
+ async def test_neo4j_destination(upload_file: Path, tmp_path: Path):
+     stager = Neo4jUploadStager()
+     uploader = Neo4jUploader(
+         connection_config=Neo4jConnectionConfig(
+             access_config=Neo4jAccessConfig(password=PASSWORD),  # type: ignore
+             username=USERNAME,
+             uri=URI,
+             database=DATABASE,
+         ),
+         upload_config=Neo4jUploaderConfig(),
+     )
+     file_data = FileData(
+         identifier="mock-file-data",
+         connector_type="neo4j",
+         source_identifiers=SourceIdentifiers(
+             filename=upload_file.name,
+             fullpath=upload_file.name,
+         ),
+         metadata=FileDataSourceMetadata(
+             date_created=str(datetime(2022, 1, 1).timestamp()),
+             date_modified=str(datetime(2022, 1, 2).timestamp()),
+         ),
+     )
+     staged_filepath = stager.run(
+         upload_file,
+         file_data=file_data,
+         output_dir=tmp_path,
+         output_filename=upload_file.name,
+     )
+
+     await uploader.run_async(staged_filepath, file_data)
+     await validate_uploaded_graph(upload_file)
+
+     modified_upload_file = tmp_path / f"modified-{upload_file.name}"
+     with open(upload_file) as file:
+         elements = json.load(file)
+         for element in elements:
+             element["element_id"] = str(uuid.uuid4())
+
+     with open(modified_upload_file, "w") as file:
+         json.dump(elements, file, indent=4)
+
+     staged_filepath = stager.run(
+         modified_upload_file,
+         file_data=file_data,
+         output_dir=tmp_path,
+         output_filename=modified_upload_file.name,
+     )
+     await uploader.run_async(staged_filepath, file_data)
+     await validate_uploaded_graph(modified_upload_file)
+
+
+ @pytest.mark.tags(DESTINATION_TAG, CONNECTOR_TYPE)
+ class TestPrecheck:
+     @pytest.fixture
+     def configured_uploader(self) -> Neo4jUploader:
+         return Neo4jUploader(
+             connection_config=Neo4jConnectionConfig(
+                 access_config=Neo4jAccessConfig(password=PASSWORD),  # type: ignore
+                 username=USERNAME,
+                 uri=URI,
+                 database=DATABASE,
+             ),
+             upload_config=Neo4jUploaderConfig(),
+         )
+
+     def test_succeeds(self, configured_uploader: Neo4jUploader):
+         configured_uploader.precheck()
+
+     def test_fails_on_invalid_password(self, configured_uploader: Neo4jUploader):
+         configured_uploader.connection_config.access_config.get_secret_value().password = (
+             "invalid-password"
+         )
+         with pytest.raises(
+             DestinationConnectionError,
+             match="{code: Neo.ClientError.Security.Unauthorized}",
+         ):
+             configured_uploader.precheck()
+
+     def test_fails_on_invalid_username(self, configured_uploader: Neo4jUploader):
+         configured_uploader.connection_config.username = "invalid-username"
+         with pytest.raises(
+             DestinationConnectionError, match="{code: Neo.ClientError.Security.Unauthorized}"
+         ):
+             configured_uploader.precheck()
+
+     @pytest.mark.parametrize(
+         ("uri", "expected_error_msg"),
+         [
+             ("neo4j://localhst:7687", "Cannot resolve address"),
+             ("neo4j://localhost:7777", "Unable to retrieve routing information"),
+         ],
+     )
+     def test_fails_on_invalid_uri(
+         self, configured_uploader: Neo4jUploader, uri: str, expected_error_msg: str
+     ):
+         configured_uploader.connection_config.uri = uri
+         with pytest.raises(DestinationConnectionError, match=expected_error_msg):
+             configured_uploader.precheck()
+
+     def test_fails_on_invalid_database(self, configured_uploader: Neo4jUploader):
+         configured_uploader.connection_config.database = "invalid-database"
+         with pytest.raises(
+             DestinationConnectionError, match="{code: Neo.ClientError.Database.DatabaseNotFound}"
+         ):
+             configured_uploader.precheck()
+
+
+ def wait_for_connection(driver: Driver, retries: int = 10, delay_seconds: int = 2):
+     attempts = 0
+     while attempts < retries:
+         try:
+             driver.verify_connectivity()
+             return
+         except ServiceUnavailable:
+             time.sleep(delay_seconds)
+             attempts += 1
+
+     pytest.fail("Failed to connect with Neo4j server.")
+
+
+ async def validate_uploaded_graph(upload_file: Path):
+     with open(upload_file) as file:
+         elements = json.load(file)
+
+     for element in elements:
+         if "orig_elements" in element["metadata"]:
+             element["metadata"]["orig_elements"] = elements_from_base64_gzipped_json(
+                 element["metadata"]["orig_elements"]
+             )
+         else:
+             element["metadata"]["orig_elements"] = []
+
+     expected_chunks_count = len(elements)
+     expected_element_count = len(
+         {
+             origin_element["element_id"]
+             for chunk in elements
+             for origin_element in chunk["metadata"]["orig_elements"]
+         }
+     )
+     expected_nodes_count = expected_chunks_count + expected_element_count + EXPECTED_DOCUMENT_COUNT
+
+     driver = AsyncGraphDatabase.driver(uri=URI, auth=(USERNAME, PASSWORD))
+     try:
+         nodes_count = len((await driver.execute_query("MATCH (n) RETURN n"))[0])
+         chunk_nodes_count = len(
+             (await driver.execute_query(f"MATCH (n: {Label.CHUNK}) RETURN n"))[0]
+         )
+         document_nodes_count = len(
+             (await driver.execute_query(f"MATCH (n: {Label.DOCUMENT}) RETURN n"))[0]
+         )
+         element_nodes_count = len(
+             (await driver.execute_query(f"MATCH (n: {Label.UNSTRUCTURED_ELEMENT}) RETURN n"))[0]
+         )
+         with check:
+             assert nodes_count == expected_nodes_count
+         with check:
+             assert document_nodes_count == EXPECTED_DOCUMENT_COUNT
+         with check:
+             assert chunk_nodes_count == expected_chunks_count
+         with check:
+             assert element_nodes_count == expected_element_count
+
+         records, _, _ = await driver.execute_query(
+             f"MATCH ()-[r:{Relationship.PART_OF_DOCUMENT}]->(:{Label.DOCUMENT}) RETURN r"
+         )
+         part_of_document_count = len(records)
+
+         records, _, _ = await driver.execute_query(
+             f"MATCH (:{Label.CHUNK})-[r:{Relationship.NEXT_CHUNK}]->(:{Label.CHUNK}) RETURN r"
+         )
+         next_chunk_count = len(records)
+
+         if not check.any_failures():
+             with check:
+                 assert part_of_document_count == expected_chunks_count + expected_element_count
+             with check:
+                 assert next_chunk_count == expected_chunks_count - 1
+
+     finally:
+         await driver.close()
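The test above exercises the new Neo4j destination end to end: stage, upload, re-upload with fresh element ids, then validate node and relationship counts. As a minimal sketch of that same flow outside pytest — assuming only the names the test itself imports, plus a hypothetical local elements.json of partitioned elements — direct use of the connector might look like this:

# Minimal sketch, not the canonical pipeline entry point. Assumes the classes
# exported by unstructured_ingest.v2.processes.connectors.neo4j in 0.3.9 and a
# hypothetical "elements.json" input file.
import asyncio
from pathlib import Path

from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
from unstructured_ingest.v2.processes.connectors.neo4j import (
    Neo4jAccessConfig,
    Neo4jConnectionConfig,
    Neo4jUploader,
    Neo4jUploaderConfig,
    Neo4jUploadStager,
)


async def main() -> None:
    elements_file = Path("elements.json")  # hypothetical input
    stager = Neo4jUploadStager()
    uploader = Neo4jUploader(
        connection_config=Neo4jConnectionConfig(
            access_config=Neo4jAccessConfig(password="password"),
            username="neo4j",
            uri="neo4j://localhost:7687",
            database="neo4j",
        ),
        upload_config=Neo4jUploaderConfig(),
    )
    file_data = FileData(
        identifier="example",
        connector_type="neo4j",
        source_identifiers=SourceIdentifiers(
            filename=elements_file.name, fullpath=elements_file.name
        ),
    )
    # Fail fast on bad credentials, URI, or database, as TestPrecheck does.
    uploader.precheck()
    staged = stager.run(
        elements_file,
        file_data=file_data,
        output_dir=Path("."),
        output_filename=elements_file.name,
    )
    await uploader.run_async(staged, file_data)


if __name__ == "__main__":
    asyncio.run(main())

Calling precheck() first mirrors the reason the TestPrecheck suite exists: connection problems surface as DestinationConnectionError before any data is staged or uploaded.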
test/integration/connectors/test_pinecone.py
@@ -8,12 +8,17 @@ from typing import Generator
  from uuid import uuid4

  import pytest
+ from _pytest.fixtures import TopRequest
  from pinecone import Pinecone, ServerlessSpec
  from pinecone.core.openapi.shared.exceptions import NotFoundException

  from test.integration.connectors.utils.constants import (
      DESTINATION_TAG,
  )
+ from test.integration.connectors.utils.validation.destination import (
+     StagerValidationConfigs,
+     stager_validation,
+ )
  from test.integration.utils import requires_env
  from unstructured_ingest.error import DestinationConnectionError
  from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
@@ -251,7 +256,10 @@ def test_large_metadata(pinecone_index: str, tmp_path: Path, upload_file: Path):
          identifier="mock-file-data",
      )
      staged_file = stager.run(
-         file_data, large_metadata_upload_file, tmp_path, large_metadata_upload_file.name
+         elements_filepath=large_metadata_upload_file,
+         file_data=file_data,
+         output_dir=tmp_path,
+         output_filename=large_metadata_upload_file.name,
      )
      try:
          uploader.run(staged_file, file_data)
@@ -262,3 +270,19 @@ def test_large_metadata(pinecone_index: str, tmp_path: Path, upload_file: Path):
          raise pytest.fail("Upload request failed due to metadata exceeding limits.")

      validate_pinecone_index(pinecone_index, 1, interval=5)
+
+
+ @pytest.mark.parametrize("upload_file_str", ["upload_file_ndjson", "upload_file"])
+ def test_pinecone_stager(
+     request: TopRequest,
+     upload_file_str: str,
+     tmp_path: Path,
+ ):
+     upload_file: Path = request.getfixturevalue(upload_file_str)
+     stager = PineconeUploadStager()
+     stager_validation(
+         configs=StagerValidationConfigs(test_id=CONNECTOR_TYPE, expected_count=22),
+         input_file=upload_file,
+         stager=stager,
+         tmp_dir=tmp_path,
+     )
test/integration/connectors/test_qdrant.py
@@ -6,10 +6,15 @@ from pathlib import Path
  from typing import AsyncGenerator

  import pytest
+ from _pytest.fixtures import TopRequest
  from qdrant_client import AsyncQdrantClient

  from test.integration.connectors.utils.constants import DESTINATION_TAG
  from test.integration.connectors.utils.docker import container_context
+ from test.integration.connectors.utils.validation.destination import (
+     StagerValidationConfigs,
+     stager_validation,
+ )
  from test.integration.utils import requires_env
  from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
  from unstructured_ingest.v2.processes.connectors.qdrant.cloud import (
@@ -138,7 +143,7 @@ async def test_qdrant_destination_server(upload_file: Path, tmp_path: Path, dock
          output_dir=tmp_path,
          output_filename=upload_file.name,
      )
-
+     uploader.precheck()
      if uploader.is_async():
          await uploader.run_async(path=staged_upload_file, file_data=file_data)
      else:
@@ -183,10 +188,28 @@ async def test_qdrant_destination_cloud(upload_file: Path, tmp_path: Path):
          output_dir=tmp_path,
          output_filename=upload_file.name,
      )
-
+     uploader.precheck()
      if uploader.is_async():
          await uploader.run_async(path=staged_upload_file, file_data=file_data)
      else:
          uploader.run(path=staged_upload_file, file_data=file_data)
      async with qdrant_client(connection_kwargs) as client:
          await validate_upload(client=client, upload_file=upload_file)
+
+
+ @pytest.mark.parametrize("upload_file_str", ["upload_file_ndjson", "upload_file"])
+ def test_qdrant_stager(
+     request: TopRequest,
+     upload_file_str: str,
+     tmp_path: Path,
+ ):
+     upload_file: Path = request.getfixturevalue(upload_file_str)
+     stager = LocalQdrantUploadStager(
+         upload_stager_config=LocalQdrantUploadStagerConfig(),
+     )
+     stager_validation(
+         configs=StagerValidationConfigs(test_id=LOCAL_CONNECTOR_TYPE, expected_count=22),
+         input_file=upload_file,
+         stager=stager,
+         tmp_dir=tmp_path,
+     )
test/integration/connectors/test_s3.py
@@ -11,8 +11,8 @@ from test.integration.connectors.utils.constants import (
      env_setup_path,
  )
  from test.integration.connectors.utils.docker_compose import docker_compose_context
- from test.integration.connectors.utils.validation import (
-     ValidationConfigs,
+ from test.integration.connectors.utils.validation.source import (
+     SourceValidationConfigs,
      source_connector_validation,
  )
  from test.integration.utils import requires_env
@@ -62,7 +62,7 @@ async def test_s3_source(anon_connection_config: S3ConnectionConfig):
      await source_connector_validation(
          indexer=indexer,
          downloader=downloader,
-         configs=ValidationConfigs(
+         configs=SourceValidationConfigs(
              test_id="s3",
              predownload_file_data_check=validate_predownload_file_data,
              postdownload_file_data_check=validate_postdownload_file_data,
@@ -85,7 +85,7 @@ async def test_s3_source_special_char(anon_connection_config: S3ConnectionConfig
      await source_connector_validation(
          indexer=indexer,
          downloader=downloader,
-         configs=ValidationConfigs(
+         configs=SourceValidationConfigs(
              test_id="s3-specialchar",
              predownload_file_data_check=validate_predownload_file_data,
              postdownload_file_data_check=validate_postdownload_file_data,
@@ -121,7 +121,7 @@ async def test_s3_minio_source(anon_connection_config: S3ConnectionConfig):
      await source_connector_validation(
          indexer=indexer,
          downloader=downloader,
-         configs=ValidationConfigs(
+         configs=SourceValidationConfigs(
              test_id="s3-minio",
              predownload_file_data_check=validate_predownload_file_data,
              postdownload_file_data_check=validate_postdownload_file_data,
@@ -165,11 +165,14 @@ async def test_s3_destination(upload_file: Path):
          identifier="mock file data",
      )
      try:
+         uploader.precheck()
          if uploader.is_async():
              await uploader.run_async(path=upload_file, file_data=file_data)
          else:
              uploader.run(path=upload_file, file_data=file_data)
-         uploaded_files = s3fs.ls(path=destination_path)
+         uploaded_files = [
+             Path(file) for file in s3fs.ls(path=destination_path) if Path(file).name != "_empty"
+         ]
          assert len(uploaded_files) == 1
      finally:
          s3fs.rm(path=destination_path, recursive=True)
test/integration/connectors/utils/docker.py
@@ -44,6 +44,7 @@ def get_container(
      docker_client: docker.DockerClient,
      image: str,
      ports: dict,
+     name: Optional[str] = "connector_test",
      environment: Optional[dict] = None,
      volumes: Optional[dict] = None,
      healthcheck: Optional[HealthCheck] = None,
@@ -59,6 +60,8 @@
          run_kwargs["volumes"] = volumes
      if healthcheck:
          run_kwargs["healthcheck"] = healthcheck.model_dump()
+     if name:
+         run_kwargs["name"] = name
      container: Container = docker_client.containers.run(**run_kwargs)
      return container

@@ -112,6 +115,7 @@
      healthcheck: Optional[HealthCheck] = None,
      healthcheck_retries: int = 30,
      docker_client: Optional[docker.DockerClient] = None,
+     name: Optional[str] = "connector_test",
  ):
      docker_client = docker_client or docker.from_env()
      print(f"pulling image {image}")
@@ -125,6 +129,7 @@
          environment=environment,
          volumes=volumes,
          healthcheck=healthcheck,
+         name=name,
      )
      if healthcheck_data := get_healthcheck(container):
          # Mirror whatever healthcheck config set on container
@@ -143,3 +148,4 @@
      finally:
          if container:
              container.kill()
+             container.remove()
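With these changes, get_container and container_context accept a container name (defaulting to "connector_test"), and the context manager now removes the container after killing it, so stale containers from interrupted runs no longer collide with the next session. A short sketch of the updated call, reusing the neo4j fixture's arguments; the overriding name value here is hypothetical:

# Sketch of the updated container_context; omitting name uses the new default
# "connector_test". The container is killed *and* removed on exit.
from test.integration.connectors.utils.docker import container_context

with container_context(
    image="neo4j:latest",
    environment={"NEO4J_AUTH": "neo4j/password"},
    ports={"7687": "7687"},
    name="connector_test_neo4j",  # hypothetical override
):
    ...  # connect to neo4j://localhost:7687 and run assertions

One consequence of the fixed default name: two test modules cannot run their containers concurrently unless at least one of them overrides name.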
test/integration/connectors/utils/validation/destination.py
@@ -0,0 +1,88 @@
+ import json
+ import os
+ import shutil
+ from pathlib import Path
+
+ import ndjson
+
+ from test.integration.connectors.utils.validation.utils import ValidationConfig
+ from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers, UploadStager
+
+
+ class StagerValidationConfigs(ValidationConfig):
+     expected_count: int
+
+     def stager_output_dir(self) -> Path:
+         dir = self.test_output_dir() / "stager"
+         dir.mkdir(exist_ok=True, parents=True)
+         return dir
+
+     def stager_output_path(self, input_path: Path) -> Path:
+         return self.stager_output_dir() / input_path.name
+
+
+ def run_all_stager_validations(
+     configs: StagerValidationConfigs, input_file: Path, staged_filepath: Path
+ ):
+     # Validate matching extensions
+     assert input_file.suffix == staged_filepath.suffix
+
+     # Validate length
+     staged_data = get_data(staged_filepath=staged_filepath)
+     assert len(staged_data) == configs.expected_count
+
+     # Validate file
+     expected_filepath = configs.stager_output_path(input_path=input_file)
+     assert expected_filepath.exists(), f"{expected_filepath} does not exist"
+     assert expected_filepath.is_file(), f"{expected_filepath} is not a file"
+     if configs.detect_diff(expected_filepath=expected_filepath, current_filepath=staged_filepath):
+         raise AssertionError(
+             f"Current file ({staged_filepath}) does not match expected file: {expected_filepath}"
+         )
+
+
+ def update_stager_fixtures(stager_output_path: Path, staged_filepath: Path):
+     copied_filepath = stager_output_path / staged_filepath.name
+     shutil.copy(staged_filepath, copied_filepath)
+
+
+ def get_data(staged_filepath: Path) -> list[dict]:
+     if staged_filepath.suffix == ".json":
+         with staged_filepath.open() as f:
+             return json.load(f)
+     elif staged_filepath.suffix == ".ndjson":
+         with staged_filepath.open() as f:
+             return ndjson.load(f)
+     else:
+         raise ValueError(f"Unsupported file type: {staged_filepath.suffix}")
+
+
+ def stager_validation(
+     stager: UploadStager,
+     tmp_dir: Path,
+     input_file: Path,
+     configs: StagerValidationConfigs,
+     overwrite_fixtures: bool = os.getenv("OVERWRITE_FIXTURES", "False").lower() == "true",
+ ) -> None:
+     # Run stager
+     file_data = FileData(
+         source_identifiers=SourceIdentifiers(fullpath=input_file.name, filename=input_file.name),
+         connector_type=configs.test_id,
+         identifier="mock file data",
+     )
+     staged_filepath = stager.run(
+         elements_filepath=input_file,
+         file_data=file_data,
+         output_dir=tmp_dir,
+         output_filename=input_file.name,
+     )
+     if not overwrite_fixtures:
+         print("Running validation")
+         run_all_stager_validations(
+             configs=configs, input_file=input_file, staged_filepath=staged_filepath
+         )
+     else:
+         print("Running fixtures update")
+         update_stager_fixtures(
+             stager_output_path=configs.stager_output_dir(), staged_filepath=staged_filepath
+         )
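stager_validation stages a file and either diffs the result against a checked-in fixture or, when OVERWRITE_FIXTURES=true, refreshes that fixture in place. Note that the overwrite_fixtures default is evaluated when the module is imported, so the environment variable must be set before pytest starts. A sketch of a direct call mirroring the pinecone test above; the input path is hypothetical:

from pathlib import Path

from test.integration.connectors.utils.validation.destination import (
    StagerValidationConfigs,
    stager_validation,
)
from unstructured_ingest.v2.processes.connectors.pinecone import PineconeUploadStager

# expected_count=22 mirrors the shared upload fixtures used by the pinecone and
# qdrant stager tests; the input file path below is hypothetical.
stager_validation(
    configs=StagerValidationConfigs(test_id="pinecone", expected_count=22),
    input_file=Path("test/assets/upload.ndjson"),
    stager=PineconeUploadStager(),
    tmp_dir=Path("/tmp/stager-out"),
)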
test/integration/connectors/utils/validation/equality.py
@@ -0,0 +1,75 @@
+ import json
+ from pathlib import Path
+
+ import ndjson
+ from bs4 import BeautifulSoup
+ from deepdiff import DeepDiff
+
+
+ def json_equality_check(expected_filepath: Path, current_filepath: Path) -> bool:
+     with expected_filepath.open() as f:
+         expected_data = json.load(f)
+     with current_filepath.open() as f:
+         current_data = json.load(f)
+     diff = DeepDiff(expected_data, current_data)
+     if diff:
+         print("diff between expected and current json")
+         print(diff.to_json(indent=2))
+         return False
+     return True
+
+
+ def ndjson_equality_check(expected_filepath: Path, current_filepath: Path) -> bool:
+     with expected_filepath.open() as f:
+         expected_data = ndjson.load(f)
+     with current_filepath.open() as f:
+         current_data = ndjson.load(f)
+     if len(current_data) != len(expected_data):
+         print(
+             f"expected data length {len(expected_data)} "
+             f"didn't match current results: {len(current_data)}"
+         )
+     for i in range(len(expected_data)):
+         e = expected_data[i]
+         r = current_data[i]
+         if e != r:
+             print(f"{i}th element doesn't match:\nexpected {e}\ncurrent {r}")
+             return False
+     return True
+
+
+ def html_equality_check(expected_filepath: Path, current_filepath: Path) -> bool:
+     with expected_filepath.open() as expected_f:
+         expected_soup = BeautifulSoup(expected_f, "html.parser")
+     with current_filepath.open() as current_f:
+         current_soup = BeautifulSoup(current_f, "html.parser")
+     return expected_soup.text == current_soup.text
+
+
+ def txt_equality_check(expected_filepath: Path, current_filepath: Path) -> bool:
+     with expected_filepath.open() as expected_f:
+         expected_text_lines = expected_f.readlines()
+     with current_filepath.open() as current_f:
+         current_text_lines = current_f.readlines()
+     if len(expected_text_lines) != len(current_text_lines):
+         print(
+             f"Lines in expected text file ({len(expected_text_lines)}) "
+             f"don't match current text file ({len(current_text_lines)})"
+         )
+         return False
+     expected_text = "\n".join(expected_text_lines)
+     current_text = "\n".join(current_text_lines)
+     if expected_text == current_text:
+         return True
+     print("txt content don't match:")
+     print(f"expected: {expected_text}")
+     print(f"current: {current_text}")
+     return False
+
+
+ file_type_equality_check = {
+     ".json": json_equality_check,
+     ".ndjson": ndjson_equality_check,
+     ".html": html_equality_check,
+     ".txt": txt_equality_check,
+ }
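file_type_equality_check maps a file suffix to the comparison routine for that format. The detect_diff hook that StagerValidationConfigs relies on lives in validation/utils.py, which this diff does not expand, so the wiring below is an assumption rather than the released implementation — a sketch of how the dispatch table would be consumed:

from pathlib import Path

from test.integration.connectors.utils.validation.equality import file_type_equality_check


def files_match(expected: Path, current: Path) -> bool:
    # Pick the checker registered for this suffix; suffixes outside the table
    # (.json/.ndjson/.html/.txt) are treated as an error rather than a pass.
    checker = file_type_equality_check.get(expected.suffix)
    if checker is None:
        raise ValueError(f"No equality check registered for {expected.suffix}")
    return checker(expected, current)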