unstructured-ingest 0.5.8__py3-none-any.whl → 0.5.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This is a potentially problematic release.


This version of unstructured-ingest might be problematic. See the details below.

@@ -0,0 +1,401 @@
1
+ from unittest.mock import MagicMock
2
+
3
+ import pytest
4
+ from pydantic import ValidationError
5
+ from pytest_mock import MockerFixture
6
+
7
+ from unstructured_ingest.v2.processes.connectors.jira import (
8
+ FieldGetter,
9
+ JiraAccessConfig,
10
+ JiraConnectionConfig,
11
+ JiraIndexer,
12
+ JiraIndexerConfig,
13
+ JiraIssueMetadata,
14
+ issues_fetcher_wrapper,
15
+ nested_object_to_field_getter,
16
+ )
17
+
18
+
19
+ @pytest.fixture
20
+ def jira_connection_config():
21
+ access_config = JiraAccessConfig(password="password")
22
+ return JiraConnectionConfig(
23
+ url="http://localhost:1234",
24
+ username="test@example.com",
25
+ access_config=access_config,
26
+ )
27
+
28
+
29
+ @pytest.fixture
30
+ def jira_indexer(jira_connection_config: JiraConnectionConfig):
31
+ indexer_config = JiraIndexerConfig(projects=["TEST1"], boards=["2"], issues=["TEST2-1"])
32
+ return JiraIndexer(connection_config=jira_connection_config, index_config=indexer_config)
33
+
34
+
35
+ @pytest.fixture
36
+ def mock_jira(mocker: MockerFixture):
37
+ mock_client = mocker.patch.object(JiraConnectionConfig, "get_client", autospec=True)
38
+ mock_jira = mocker.MagicMock()
39
+ mock_client.return_value.__enter__.return_value = mock_jira
40
+ return mock_jira
41
+
42
+
43
+ def test_jira_indexer_precheck_success(
44
+ caplog: pytest.LogCaptureFixture,
45
+ mocker: MockerFixture,
46
+ jira_indexer: JiraIndexer,
47
+ mock_jira: MagicMock,
48
+ ):
49
+ get_permissions = mocker.MagicMock()
50
+ get_permissions.return_value = {"permissions": {"BROWSE_PROJECTS": {"havePermission": True}}}
51
+ mock_jira.get_permissions = get_permissions
52
+
53
+ with caplog.at_level("INFO"):
54
+ jira_indexer.precheck()
55
+ assert "Connection to Jira successful." in caplog.text
56
+
57
+ get_permissions.assert_called_once()
58
+
59
+
60
+ def test_jira_indexer_precheck_no_permission(
61
+ mocker: MockerFixture,
62
+ jira_indexer: JiraIndexer,
63
+ mock_jira: MagicMock,
64
+ ):
65
+ get_permissions = mocker.MagicMock()
66
+ get_permissions.return_value = {"permissions": {"BROWSE_PROJECTS": {"havePermission": False}}}
67
+ mock_jira.get_permissions = get_permissions
68
+
69
+ with pytest.raises(ValueError):
70
+ jira_indexer.precheck()
71
+
72
+ get_permissions.assert_called_once()
73
+
74
+
75
+ @pytest.mark.parametrize(
76
+ ("project_issues_count", "expected_issues_count"), [(2, 2), ({"total": 2}, 2), (0, 0)]
77
+ )
78
+ def test_jira_indexer_get_issues_within_single_project(
79
+ jira_indexer: JiraIndexer,
80
+ mock_jira: MagicMock,
81
+ project_issues_count,
82
+ expected_issues_count,
83
+ ):
84
+ mock_jira.get_project_issues_count.return_value = project_issues_count
85
+ mock_jira.get_all_project_issues.return_value = [
86
+ {"id": "1", "key": "TEST-1"},
87
+ {"id": "2", "key": "TEST-2"},
88
+ ]
89
+
90
+ issues = jira_indexer._get_issues_within_single_project("TEST1")
91
+ assert len(issues) == expected_issues_count
92
+
93
+ if issues:
94
+ assert issues[0].id == "1"
95
+ assert issues[0].key == "TEST-1"
96
+ assert issues[1].id == "2"
97
+ assert issues[1].key == "TEST-2"
98
+
99
+
100
+ def test_jira_indexer_get_issues_within_single_project_error(
101
+ jira_indexer: JiraIndexer,
102
+ mock_jira: MagicMock,
103
+ ):
104
+ mock_jira.get_project_issues_count.return_value = {}
105
+
106
+ with pytest.raises(KeyError):
107
+ jira_indexer._get_issues_within_single_project("TEST1")
108
+
109
+
110
+ def test_jira_indexer_get_issues_within_projects_with_projects(
111
+ jira_indexer: JiraIndexer,
112
+ mock_jira: MagicMock,
113
+ ):
114
+ mock_jira.get_project_issues_count.return_value = 2
115
+ mock_jira.get_all_project_issues.return_value = [
116
+ {"id": "1", "key": "TEST-1"},
117
+ {"id": "2", "key": "TEST-2"},
118
+ ]
119
+
120
+ issues = jira_indexer._get_issues_within_projects()
121
+ assert len(issues) == 2
122
+ assert issues[0].id == "1"
123
+ assert issues[0].key == "TEST-1"
124
+ assert issues[1].id == "2"
125
+ assert issues[1].key == "TEST-2"
126
+
127
+
128
+ def test_jira_indexer_get_issues_within_projects_no_projects_with_boards_or_issues(
129
+ mocker: MockerFixture,
130
+ jira_indexer: JiraIndexer,
131
+ ):
132
+ jira_indexer.index_config.projects = None
133
+ jira_indexer.index_config.boards = ["2"]
134
+ mocker.patch.object(JiraConnectionConfig, "get_client", autospec=True)
135
+
136
+ issues = jira_indexer._get_issues_within_projects()
137
+ assert issues == []
138
+
139
+
140
+ def test_jira_indexer_get_issues_within_projects_no_projects_no_boards_no_issues(
141
+ jira_indexer: JiraIndexer,
142
+ mock_jira: MagicMock,
143
+ ):
144
+ jira_indexer.index_config.projects = None
145
+ jira_indexer.index_config.boards = None
146
+ jira_indexer.index_config.issues = None
147
+ mock_jira.projects.return_value = [{"key": "TEST1"}, {"key": "TEST2"}]
148
+ mock_jira.get_project_issues_count.return_value = 2
149
+ mock_jira.get_all_project_issues.return_value = [
150
+ {"id": "1", "key": "TEST-1"},
151
+ {"id": "2", "key": "TEST-2"},
152
+ ]
153
+
154
+ issues = jira_indexer._get_issues_within_projects()
155
+ assert len(issues) == 4
156
+ assert issues[0].id == "1"
157
+ assert issues[0].key == "TEST-1"
158
+ assert issues[1].id == "2"
159
+ assert issues[1].key == "TEST-2"
160
+ assert issues[2].id == "1"
161
+ assert issues[2].key == "TEST-1"
162
+ assert issues[3].id == "2"
163
+ assert issues[3].key == "TEST-2"
164
+
165
+
166
+ def test_jira_indexer_get_issues_within_boards(
167
+ jira_indexer: JiraIndexer,
168
+ mock_jira: MagicMock,
169
+ ):
170
+ mock_jira.get_issues_for_board.return_value = [
171
+ {"id": "1", "key": "TEST-1"},
172
+ {"id": "2", "key": "TEST-2"},
173
+ ]
174
+
175
+ issues = jira_indexer._get_issues_within_boards()
176
+ assert len(issues) == 2
177
+ assert issues[0].id == "1"
178
+ assert issues[0].key == "TEST-1"
179
+ assert issues[1].id == "2"
180
+ assert issues[1].key == "TEST-2"
181
+
182
+
183
+ def test_jira_indexer_get_issues_within_single_board(
184
+ jira_indexer: JiraIndexer,
185
+ mock_jira: MagicMock,
186
+ ):
187
+ mock_jira.get_issues_for_board.return_value = [
188
+ {"id": "1", "key": "TEST-1"},
189
+ {"id": "2", "key": "TEST-2"},
190
+ ]
191
+
192
+ issues = jira_indexer._get_issues_within_single_board("1")
193
+ assert len(issues) == 2
194
+ assert issues[0].id == "1"
195
+ assert issues[0].key == "TEST-1"
196
+ assert issues[0].board_id == "1"
197
+ assert issues[1].id == "2"
198
+ assert issues[1].key == "TEST-2"
199
+ assert issues[1].board_id == "1"
200
+
201
+
202
+ def test_jira_indexer_get_issues_within_single_board_no_issues(
203
+ jira_indexer: JiraIndexer,
204
+ mock_jira: MagicMock,
205
+ ):
206
+ mock_jira.get_issues_for_board.return_value = []
207
+
208
+ issues = jira_indexer._get_issues_within_single_board("1")
209
+ assert len(issues) == 0
210
+
211
+
212
+ def test_jira_indexer_get_issues(
213
+ jira_indexer: JiraIndexer,
214
+ mock_jira: MagicMock,
215
+ ):
216
+ jira_indexer.index_config.issues = ["TEST2-1", "TEST2-2"]
217
+ mock_jira.get_issue.return_value = {
218
+ "id": "ISSUE_ID",
219
+ "key": "ISSUE_KEY",
220
+ }
221
+
222
+ issues = jira_indexer._get_issues()
223
+ assert len(issues) == 2
224
+ assert issues[0].id == "ISSUE_ID"
225
+ assert issues[0].key == "ISSUE_KEY"
226
+
227
+
228
+ def test_jira_indexer_get_issues_unique_issues(mocker: MockerFixture, jira_indexer: JiraIndexer):
229
+ mocker.patch.object(
230
+ JiraIndexer,
231
+ "_get_issues_within_boards",
232
+ return_value=[
233
+ JiraIssueMetadata(id="1", key="TEST-1", board_id="1"),
234
+ JiraIssueMetadata(id="2", key="TEST-2", board_id="1"),
235
+ ],
236
+ )
237
+ mocker.patch.object(
238
+ JiraIndexer,
239
+ "_get_issues_within_projects",
240
+ return_value=[
241
+ JiraIssueMetadata(id="1", key="TEST-1"),
242
+ JiraIssueMetadata(id="3", key="TEST-3"),
243
+ ],
244
+ )
245
+ mocker.patch.object(
246
+ JiraIndexer,
247
+ "_get_issues",
248
+ return_value=[
249
+ JiraIssueMetadata(id="4", key="TEST-4"),
250
+ JiraIssueMetadata(id="2", key="TEST-2"),
251
+ ],
252
+ )
253
+
254
+ issues = jira_indexer.get_issues()
255
+ assert len(issues) == 4
256
+ assert issues[0].id == "1"
257
+ assert issues[0].key == "TEST-1"
258
+ assert issues[0].board_id == "1"
259
+ assert issues[1].id == "2"
260
+ assert issues[1].key == "TEST-2"
261
+ assert issues[1].board_id == "1"
262
+ assert issues[2].id == "3"
263
+ assert issues[2].key == "TEST-3"
264
+ assert issues[3].id == "4"
265
+ assert issues[3].key == "TEST-4"
266
+
267
+
268
+ def test_jira_indexer_get_issues_no_duplicates(mocker: MockerFixture, jira_indexer: JiraIndexer):
269
+ mocker.patch.object(
270
+ JiraIndexer,
271
+ "_get_issues_within_boards",
272
+ return_value=[
273
+ JiraIssueMetadata(id="1", key="TEST-1", board_id="1"),
274
+ ],
275
+ )
276
+ mocker.patch.object(
277
+ JiraIndexer,
278
+ "_get_issues_within_projects",
279
+ return_value=[
280
+ JiraIssueMetadata(id="2", key="TEST-2"),
281
+ ],
282
+ )
283
+ mocker.patch.object(
284
+ JiraIndexer,
285
+ "_get_issues",
286
+ return_value=[
287
+ JiraIssueMetadata(id="3", key="TEST-3"),
288
+ ],
289
+ )
290
+
291
+ issues = jira_indexer.get_issues()
292
+ assert len(issues) == 3
293
+ assert issues[0].id == "1"
294
+ assert issues[0].key == "TEST-1"
295
+ assert issues[0].board_id == "1"
296
+ assert issues[1].id == "2"
297
+ assert issues[1].key == "TEST-2"
298
+ assert issues[2].id == "3"
299
+ assert issues[2].key == "TEST-3"
300
+
301
+
302
+ def test_jira_indexer_get_issues_empty(mocker: MockerFixture, jira_indexer: JiraIndexer):
303
+ mocker.patch.object(JiraIndexer, "_get_issues_within_boards", return_value=[])
304
+ mocker.patch.object(JiraIndexer, "_get_issues_within_projects", return_value=[])
305
+ mocker.patch.object(JiraIndexer, "_get_issues", return_value=[])
306
+
307
+ issues = jira_indexer.get_issues()
308
+ assert len(issues) == 0
309
+
310
+
311
+ def test_connection_config_multiple_auth():
312
+ with pytest.raises(ValidationError):
313
+ JiraConnectionConfig(
314
+ access_config=JiraAccessConfig(
315
+ password="api_token",
316
+ token="access_token",
317
+ ),
318
+ username="user_email",
319
+ url="url",
320
+ )
321
+
322
+
323
+ def test_connection_config_no_auth():
324
+ with pytest.raises(ValidationError):
325
+ JiraConnectionConfig(access_config=JiraAccessConfig(), url="url")
326
+
327
+
328
+ def test_connection_config_basic_auth():
329
+ JiraConnectionConfig(
330
+ access_config=JiraAccessConfig(password="api_token"),
331
+ url="url",
332
+ username="user_email",
333
+ )
334
+
335
+
336
+ def test_connection_config_pat_auth():
337
+ JiraConnectionConfig(
338
+ access_config=JiraAccessConfig(token="access_token"),
339
+ url="url",
340
+ )
341
+
342
+
343
+ def test_jira_issue_metadata_object():
344
+ expected = {"id": "10000", "key": "TEST-1", "board_id": "1", "project_id": "TEST"}
345
+ metadata = JiraIssueMetadata(id="10000", key="TEST-1", board_id="1")
346
+ assert expected == metadata.to_dict()
347
+
348
+
349
+ def test_nested_object_to_field_getter():
350
+ obj = {"a": 1, "b": {"c": 2}}
351
+ fg = nested_object_to_field_getter(obj)
352
+ assert isinstance(fg, FieldGetter)
353
+ assert fg["a"] == 1
354
+ assert isinstance(fg["b"], FieldGetter)
355
+ assert fg["b"]["c"] == 2
356
+ assert isinstance(fg["b"]["d"], FieldGetter)
357
+ assert fg["b"]["d"]["e"] == {}
358
+
359
+
360
+ def test_issues_fetcher_wrapper():
361
+ test_issues_to_fetch = 250
362
+ test_issues = [{"id": i} for i in range(0, test_issues_to_fetch)]
363
+
364
+ def mock_func(limit, start):
365
+ return {"results": test_issues[start : start + limit]}
366
+
367
+ wrapped_func = issues_fetcher_wrapper(mock_func, number_of_issues_to_fetch=test_issues_to_fetch)
368
+ results = wrapped_func()
369
+ assert len(results) == 250
370
+ assert results[0]["id"] == 0
371
+ assert results[-1]["id"] == 249
372
+
373
+ test_issues_to_fetch = 150
374
+ test_issues = [{"id": i} for i in range(0, test_issues_to_fetch)]
375
+
376
+ def mock_func_list(limit, start):
377
+ return test_issues[start : start + limit]
378
+
379
+ wrapped_func_list = issues_fetcher_wrapper(
380
+ mock_func_list, number_of_issues_to_fetch=test_issues_to_fetch
381
+ )
382
+ results_list = wrapped_func_list()
383
+ assert len(results_list) == 150
384
+ assert results_list[0]["id"] == 0
385
+ assert results_list[-1]["id"] == 149
386
+
387
+ def mock_func_invalid(limit, start):
388
+ return "invalid"
389
+
390
+ wrapped_func_invalid = issues_fetcher_wrapper(mock_func_invalid, number_of_issues_to_fetch=50)
391
+ with pytest.raises(TypeError):
392
+ wrapped_func_invalid()
393
+
394
+ def mock_func_key_error(limit, start):
395
+ return {"wrong_key": []}
396
+
397
+ wrapped_func_key_error = issues_fetcher_wrapper(
398
+ mock_func_key_error, number_of_issues_to_fetch=50
399
+ )
400
+ with pytest.raises(KeyError):
401
+ wrapped_func_key_error()
@@ -1 +1 @@
1
- __version__ = "0.5.8" # pragma: no cover
1
+ __version__ = "0.5.10" # pragma: no cover
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import TYPE_CHECKING
2
+ from typing import TYPE_CHECKING, Optional
3
3
 
4
4
  from pydantic import Field, SecretStr
5
5
 
@@ -26,6 +26,7 @@ if TYPE_CHECKING:
26
26
  class OpenAIEmbeddingConfig(EmbeddingConfig):
27
27
  api_key: SecretStr
28
28
  embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name")
29
+ base_url: Optional[str] = None
29
30
 
30
31
  def wrap_error(self, e: Exception) -> Exception:
31
32
  if is_internal_error(e=e):
@@ -57,13 +58,13 @@ class OpenAIEmbeddingConfig(EmbeddingConfig):
57
58
  def get_client(self) -> "OpenAI":
58
59
  from openai import OpenAI
59
60
 
60
- return OpenAI(api_key=self.api_key.get_secret_value())
61
+ return OpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)
61
62
 
62
63
  @requires_dependencies(["openai"], extras="openai")
63
64
  def get_async_client(self) -> "AsyncOpenAI":
64
65
  from openai import AsyncOpenAI
65
66
 
66
- return AsyncOpenAI(api_key=self.api_key.get_secret_value())
67
+ return AsyncOpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)
67
68
 
68
69
 
69
70
  @dataclass
@@ -1,9 +1,12 @@
1
1
  import json
2
+ import re
2
3
  import typing as t
3
4
  from datetime import datetime
4
5
 
5
6
  from dateutil import parser
6
7
 
8
+ from unstructured_ingest.v2.logger import logger
9
+
7
10
 
8
11
  def json_to_dict(json_string: str) -> t.Union[str, t.Dict[str, t.Any]]:
9
12
  """Helper function attempts to deserialize json string to a dictionary."""
@@ -47,3 +50,25 @@ def truncate_string_bytes(string: str, max_bytes: int, encoding: str = "utf-8")
47
50
  if len(encoded_string) <= max_bytes:
48
51
  return string
49
52
  return encoded_string[:max_bytes].decode(encoding, errors="ignore")
53
+
54
+
55
+ def fix_unescaped_unicode(text: str, encoding: str = "utf-8") -> str:
56
+ """
57
+ Fix unescaped Unicode sequences in text.
58
+ """
59
+ try:
60
+ _text: str = json.dumps(text)
61
+
62
+ # Pattern to match unescaped Unicode sequences like \\uXXXX
63
+ pattern = r"\\\\u([0-9A-Fa-f]{4})"
64
+ # Replace with properly escaped Unicode sequences \uXXXX
65
+ _text = re.sub(pattern, r"\\u\1", _text)
66
+ _text = json.loads(_text)
67
+
68
+ # Encode the text to check for encoding errors
69
+ _text.encode(encoding)
70
+ return _text
71
+ except Exception as e:
72
+ # Return original text if encoding fails
73
+ logger.warning(f"Failed to fix unescaped Unicode sequences: {e}", exc_info=True)
74
+ return text
@@ -34,6 +34,8 @@ from .gitlab import CONNECTOR_TYPE as GITLAB_CONNECTOR_TYPE
34
34
  from .gitlab import gitlab_source_entry
35
35
  from .google_drive import CONNECTOR_TYPE as GOOGLE_DRIVE_CONNECTOR_TYPE
36
36
  from .google_drive import google_drive_source_entry
37
+ from .jira import CONNECTOR_TYPE as JIRA_CONNECTOR_TYPE
38
+ from .jira import jira_source_entry
37
39
  from .kdbai import CONNECTOR_TYPE as KDBAI_CONNECTOR_TYPE
38
40
  from .kdbai import kdbai_destination_entry
39
41
  from .local import CONNECTOR_TYPE as LOCAL_CONNECTOR_TYPE
@@ -115,3 +117,5 @@ add_source_entry(source_type=CONFLUENCE_CONNECTOR_TYPE, entry=confluence_source_
115
117
 
116
118
  add_source_entry(source_type=DISCORD_CONNECTOR_TYPE, entry=discord_source_entry)
117
119
  add_destination_entry(destination_type=REDIS_CONNECTOR_TYPE, entry=redis_destination_entry)
120
+
121
+ add_source_entry(source_type=JIRA_CONNECTOR_TYPE, entry=jira_source_entry)
@@ -1,5 +1,6 @@
1
1
  import csv
2
2
  import hashlib
3
+ import re
3
4
  from dataclasses import dataclass, field
4
5
  from pathlib import Path
5
6
  from time import time
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
48
49
  from astrapy import AsyncCollection as AstraDBAsyncCollection
49
50
  from astrapy import Collection as AstraDBCollection
50
51
  from astrapy import DataAPIClient as AstraDBClient
52
+ from astrapy import Database as AstraDB
51
53
 
52
54
 
53
55
  CONNECTOR_TYPE = "astradb"
@@ -85,11 +87,10 @@ class AstraDBConnectionConfig(ConnectionConfig):
85
87
  )
86
88
 
87
89
 
88
- def get_astra_collection(
90
+ def get_astra_db(
89
91
  connection_config: AstraDBConnectionConfig,
90
- collection_name: str,
91
92
  keyspace: str,
92
- ) -> "AstraDBCollection":
93
+ ) -> "AstraDB":
93
94
  # Build the Astra DB object.
94
95
  access_configs = connection_config.access_config.get_secret_value()
95
96
 
@@ -103,9 +104,20 @@ def get_astra_collection(
103
104
  token=access_configs.token,
104
105
  keyspace=keyspace,
105
106
  )
107
+ return astra_db
108
+
106
109
 
107
- # Connect to the collection
110
+ def get_astra_collection(
111
+ connection_config: AstraDBConnectionConfig,
112
+ collection_name: str,
113
+ keyspace: str,
114
+ ) -> "AstraDBCollection":
115
+
116
+ astra_db = get_astra_db(connection_config=connection_config, keyspace=keyspace)
117
+
118
+ # astradb will return a collection object in all cases (even if it doesn't exist)
108
119
  astra_db_collection = astra_db.get_collection(name=collection_name)
120
+
109
121
  return astra_db_collection
110
122
 
111
123
 
@@ -151,10 +163,11 @@ class AstraDBDownloaderConfig(DownloaderConfig):
151
163
 
152
164
 
153
165
  class AstraDBUploaderConfig(UploaderConfig):
154
- collection_name: str = Field(
166
+ collection_name: Optional[str] = Field(
155
167
  description="The name of the Astra DB collection. "
156
168
  "Note that the collection name must only include letters, "
157
- "numbers, and underscores."
169
+ "numbers, and underscores.",
170
+ default=None,
158
171
  )
159
172
  keyspace: Optional[str] = Field(default=None, description="The Astra DB connection keyspace.")
160
173
  requested_indexing_policy: Optional[dict[str, Any]] = Field(
@@ -337,25 +350,84 @@ class AstraDBUploader(Uploader):
337
350
  upload_config: AstraDBUploaderConfig
338
351
  connector_type: str = CONNECTOR_TYPE
339
352
 
353
+ def init(self, **kwargs: Any) -> None:
354
+ self.create_destination(**kwargs)
355
+
340
356
  def precheck(self) -> None:
341
357
  try:
342
- get_astra_collection(
343
- connection_config=self.connection_config,
344
- collection_name=self.upload_config.collection_name,
345
- keyspace=self.upload_config.keyspace,
346
- ).options()
358
+ if self.upload_config.collection_name:
359
+ self.get_collection(collection_name=self.upload_config.collection_name).options()
360
+ else:
361
+ # check for db connection only if collection name is not provided
362
+ get_astra_db(
363
+ connection_config=self.connection_config,
364
+ keyspace=self.upload_config.keyspace,
365
+ )
347
366
  except Exception as e:
348
367
  logger.error(f"Failed to validate connection {e}", exc_info=True)
349
368
  raise DestinationConnectionError(f"failed to validate connection: {e}")
350
369
 
351
370
  @requires_dependencies(["astrapy"], extras="astradb")
352
- def get_collection(self) -> "AstraDBCollection":
371
+ def get_collection(self, collection_name: Optional[str] = None) -> "AstraDBCollection":
353
372
  return get_astra_collection(
354
373
  connection_config=self.connection_config,
355
- collection_name=self.upload_config.collection_name,
374
+ collection_name=collection_name or self.upload_config.collection_name,
356
375
  keyspace=self.upload_config.keyspace,
357
376
  )
358
377
 
378
+ def _collection_exists(self, collection_name: str):
379
+ from astrapy.exceptions import CollectionNotFoundException
380
+
381
+ collection = get_astra_collection(
382
+ connection_config=self.connection_config,
383
+ collection_name=collection_name,
384
+ keyspace=self.upload_config.keyspace,
385
+ )
386
+
387
+ try:
388
+ collection.options()
389
+ return True
390
+ except CollectionNotFoundException:
391
+ return False
392
+ except Exception as e:
393
+ logger.error(f"failed to check if astra collection exists : {e}")
394
+ raise DestinationConnectionError(f"failed to check if astra collection exists : {e}")
395
+
396
+ def format_destination_name(self, destination_name: str) -> str:
397
+ # AstraDB collection naming requirements:
398
+ # must be below 50 characters
399
+ # must be lowercase alphanumeric and underscores only
400
+ formatted = re.sub(r"[^a-z0-9]", "_", destination_name.lower())
401
+ return formatted
402
+
403
+ def create_destination(
404
+ self,
405
+ vector_length: int,
406
+ destination_name: str = "unstructuredautocreated",
407
+ similarity_metric: Optional[str] = "cosine",
408
+ **kwargs: Any,
409
+ ) -> bool:
410
+ destination_name = self.format_destination_name(destination_name)
411
+ collection_name = self.upload_config.collection_name or destination_name
412
+ self.upload_config.collection_name = collection_name
413
+
414
+ if not self._collection_exists(collection_name):
415
+ astra_db = get_astra_db(
416
+ connection_config=self.connection_config, keyspace=self.upload_config.keyspace
417
+ )
418
+ logger.info(
419
+ f"creating default astra collection '{collection_name}' with dimension "
420
+ f"{vector_length} and metric {similarity_metric}"
421
+ )
422
+ astra_db.create_collection(
423
+ collection_name,
424
+ dimension=vector_length,
425
+ metric=similarity_metric,
426
+ )
427
+ return True
428
+ logger.debug(f"collection with name '{collection_name}' already exists, skipping creation")
429
+ return False
430
+
359
431
  def delete_by_record_id(self, collection: "AstraDBCollection", file_data: FileData):
360
432
  logger.debug(
361
433
  f"deleting records from collection {collection.name} "
@@ -8,6 +8,7 @@ from pydantic import Field, Secret
8
8
  from unstructured_ingest.error import SourceConnectionError
9
9
  from unstructured_ingest.utils.dep_check import requires_dependencies
10
10
  from unstructured_ingest.utils.html import HtmlMixin
11
+ from unstructured_ingest.utils.string_and_date_utils import fix_unescaped_unicode
11
12
  from unstructured_ingest.v2.interfaces import (
12
13
  AccessConfig,
13
14
  ConnectionConfig,
@@ -224,7 +225,6 @@ class ConfluenceDownloader(Downloader):
224
225
  page_id=doc_id,
225
226
  expand="history.lastUpdated,version,body.view",
226
227
  )
227
-
228
228
  except Exception as e:
229
229
  logger.error(f"Failed to retrieve page with ID {doc_id}: {e}", exc_info=True)
230
230
  raise SourceConnectionError(f"Failed to retrieve page with ID {doc_id}: {e}")
@@ -233,10 +233,10 @@ class ConfluenceDownloader(Downloader):
233
233
  raise ValueError(f"Page with ID {doc_id} does not exist.")
234
234
 
235
235
  content = page["body"]["view"]["value"]
236
- # This supports v2 html parsing in unstructured
237
236
  title = page["title"]
238
- title_html = f"<title>{title}</title>"
239
- content = f"<body class='Document' >{title_html}{content}</body>"
237
+ # Using h1 for title is supported by both v1 and v2 html parsing in unstructured
238
+ title_html = f"<h1>{title}</h1>"
239
+ content = fix_unescaped_unicode(f"<body class='Document' >{title_html}{content}</body>")
240
240
  if self.download_config.extract_images:
241
241
  with self.connection_config.get_client() as client:
242
242
  content = self.download_config.extract_html_images(
@@ -92,6 +92,7 @@ class DeltaTableUploadStager(UploadStager):
92
92
  output_path = Path(output_dir) / Path(f"{output_filename}.parquet")
93
93
 
94
94
  df = convert_to_pandas_dataframe(elements_dict=elements_contents)
95
+ df = df.dropna(axis=1, how="all")
95
96
  df.to_parquet(output_path)
96
97
 
97
98
  return output_path
@@ -153,6 +154,7 @@ class DeltaTableUploader(Uploader):
153
154
  "table_or_uri": updated_upload_path,
154
155
  "data": df,
155
156
  "mode": "overwrite",
157
+ "schema_mode": "merge",
156
158
  "storage_options": storage_options,
157
159
  }
158
160
  queue = Queue()