nucliadb-dataset 6.4.2.post4378-py3-none-any.whl → 6.10.0.post5677-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb_dataset/__init__.py +4 -5
- nucliadb_dataset/api.py +6 -10
- nucliadb_dataset/dataset.py +31 -29
- nucliadb_dataset/nuclia.py +6 -10
- nucliadb_dataset/settings.py +4 -5
- nucliadb_dataset/streamer.py +25 -14
- nucliadb_dataset/tasks.py +5 -4
- nucliadb_dataset/tests/fixtures.py +9 -70
- {nucliadb_dataset-6.4.2.post4378.dist-info → nucliadb_dataset-6.10.0.post5677.dist-info}/METADATA +5 -6
- nucliadb_dataset-6.10.0.post5677.dist-info/RECORD +18 -0
- nucliadb_dataset-6.4.2.post4378.dist-info/RECORD +0 -18
- {nucliadb_dataset-6.4.2.post4378.dist-info → nucliadb_dataset-6.10.0.post5677.dist-info}/WHEEL +0 -0
- {nucliadb_dataset-6.4.2.post4378.dist-info → nucliadb_dataset-6.10.0.post5677.dist-info}/entry_points.txt +0 -0
- {nucliadb_dataset-6.4.2.post4378.dist-info → nucliadb_dataset-6.10.0.post5677.dist-info}/top_level.txt +0 -0
nucliadb_dataset/__init__.py
CHANGED
@@ -13,12 +13,11 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict
 
 from nucliadb_dataset.dataset import NucliaDBDataset, Task, download_all_partitions
 from nucliadb_dataset.nuclia import NucliaDriver
 
-NUCLIA_GLOBAL: Dict[str, NucliaDriver] = {}
+NUCLIA_GLOBAL: dict[str, NucliaDriver] = {}
 
 CLIENT_ID = "CLIENT"
 
@@ -29,10 +28,10 @@ class ExportType(str, Enum):
 
 
 __all__ = (
+    "CLIENT_ID",
+    "NUCLIA_GLOBAL",
+    "ExportType",
     "NucliaDBDataset",
     "Task",
     "download_all_partitions",
-    "NUCLIA_GLOBAL",
-    "CLIENT_ID",
-    "ExportType",
 )
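Note on the pattern above, which repeats throughout this release: container annotations move from typing-module aliases to builtin generics (PEP 585), and optionals move to PEP 604 union syntax, which is why the Requires-Python floor rises from 3.9 to 3.10 in METADATA below. A minimal sketch of the old and new spellings side by side, runnable on Python 3.10+:

    from typing import Dict, Optional  # old-style aliases, now redundant

    old_mapping: Dict[str, int] = {}
    new_mapping: dict[str, int] = {}   # builtin generic (PEP 585)

    old_optional: Optional[str] = None
    new_optional: str | None = None    # union syntax (PEP 604, 3.10+)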
nucliadb_dataset/api.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterator, List
+from collections.abc import Iterator
 
 from nucliadb_dataset import CLIENT_ID, NUCLIA_GLOBAL
 from nucliadb_dataset.nuclia import NucliaDriver
@@ -36,26 +36,22 @@ def get_nuclia_client() -> NucliaDriver:
 
 def iterate_sentences(kbid: str, labels: bool, entities: bool, text: bool) -> Iterator[TrainSentence]:
     client = get_nuclia_client()
-    for sentence in client.iterate_sentences(kbid, labels, entities, text):
-        yield sentence
+    yield from client.iterate_sentences(kbid, labels, entities, text)
 
 
 def iterate_paragraphs(kbid: str, labels: bool, entities: bool, text: bool) -> Iterator[TrainParagraph]:
     client = get_nuclia_client()
-    for sentence in client.iterate_paragraphs(kbid, labels, entities, text):
-        yield sentence
+    yield from client.iterate_paragraphs(kbid, labels, entities, text)
 
 
 def iterate_fields(kbid: str, labels: bool, entities: bool, text: bool) -> Iterator[TrainField]:
     client = get_nuclia_client()
-    for sentence in client.iterate_fields(kbid, labels, entities, text):
-        yield sentence
+    yield from client.iterate_fields(kbid, labels, entities, text)
 
 
 def iterate_resources(kbid: str, labels: bool, entities: bool, text: bool) -> Iterator[TrainResource]:
     client = get_nuclia_client()
-    for sentence in client.iterate_resources(kbid, labels, entities, text):
-        yield sentence
+    yield from client.iterate_resources(kbid, labels, entities, text)
 
 
 def get_labels(kbid: str) -> GetLabelsResponse:
@@ -76,7 +72,7 @@ def get_info(kbid: str) -> TrainInfo:
     return info
 
 
-def get_ontology_count(kbid: str, paragraph_labelsets: List[str], resource_labelset: List[str]):
+def get_ontology_count(kbid: str, paragraph_labelsets: list[str], resource_labelset: list[str]):
     client = get_nuclia_client()
     labels = client.get_ontology_count(kbid, paragraph_labelsets, resource_labelset)
     return labels
nucliadb_dataset/dataset.py
CHANGED
@@ -14,8 +14,9 @@
 
 import logging
 import os
+from collections.abc import Callable, Iterator
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
+from typing import Any
 
 import pyarrow as pa  # type: ignore
 
@@ -29,8 +30,9 @@ from nucliadb_dataset.tasks import (
 from nucliadb_models.entities import KnowledgeBoxEntities
 from nucliadb_models.labels import KnowledgeBoxLabels
 from nucliadb_models.resource import KnowledgeBoxObj
+from nucliadb_models.trainset import TrainSet as TrainSetModel
 from nucliadb_models.trainset import TrainSetPartitions
-from nucliadb_protos.dataset_pb2 import TrainSet
+from nucliadb_protos.dataset_pb2 import TrainSet as TrainSetPB
 from nucliadb_sdk.v2.sdk import NucliaDB
 
 logger = logging.getLogger("nucliadb_dataset")
@@ -41,12 +43,12 @@ CHUNK_SIZE = 5 * 1024 * 1024
 @dataclass
 class LabelSetCount:
     count: int
-    labels: Dict[str, int] = field(default_factory=dict)
+    labels: dict[str, int] = field(default_factory=dict)
 
 
-class NucliaDataset(object):
-    labels: Optional[KnowledgeBoxLabels]
-    entities: Optional[KnowledgeBoxEntities]
+class NucliaDataset:
+    labels: KnowledgeBoxLabels | None
+    entities: KnowledgeBoxEntities | None
 
     def __new__(cls, *args, **kwargs):
         if cls is NucliaDataset:
@@ -55,18 +57,18 @@ class NucliaDataset(object):
 
     def __init__(
         self,
-        base_path: Optional[str] = None,
+        base_path: str | None = None,
     ):
         if base_path is None:
             base_path = os.getcwd()
         self.base_path = base_path
-        self.mappings: List[Callable] = []
+        self.mappings: list[Callable] = []
 
         self.labels = None
         self.entities = None
        self.folder = None
 
-    def iter_all_partitions(self, force=False) -> Iterator[Tuple[str, str]]:
+    def iter_all_partitions(self, force=False) -> Iterator[tuple[str, str]]:
         partitions = self.get_partitions()
         for index, partition in enumerate(partitions):
             logger.info(f"Reading partition {partition} {index}/{len(partitions)}")
@@ -74,7 +76,7 @@ class NucliaDataset(object):
             logger.info(f"Done reading partition {partition}")
             yield partition, filename
 
-    def read_all_partitions(self, force=False, path: Optional[str] = None) -> List[str]:
+    def read_all_partitions(self, force=False, path: str | None = None) -> list[str]:
         partitions = self.get_partitions()
         result = []
         for index, partition in enumerate(partitions):
@@ -90,9 +92,9 @@ class NucliaDataset(object):
     def read_partition(
         self,
         partition_id: str,
-        filename: Optional[str] = None,
+        filename: str | None = None,
         force: bool = False,
-        path: Optional[str] = None,
+        path: str | None = None,
     ):
         raise NotImplementedError()
 
@@ -102,12 +104,12 @@ class NucliaDBDataset(NucliaDataset):
         self,
         sdk: NucliaDB,
         kbid: str,
-        task: Optional[Task] = None,
-        labels: Optional[List[str]] = None,
-        trainset: Optional[TrainSet] = None,
-        base_path: Optional[str] = None,
-        search_sdk: Optional[NucliaDB] = None,
-        reader_sdk: Optional[NucliaDB] = None,
+        task: Task | None = None,
+        labels: list[str] | None = None,
+        trainset: TrainSetPB | TrainSetModel | None = None,
+        base_path: str | None = None,
+        search_sdk: NucliaDB | None = None,
+        reader_sdk: NucliaDB | None = None,
     ):
         super().__init__(base_path)
 
@@ -119,7 +121,7 @@ class NucliaDBDataset(NucliaDataset):
             task_definition = TASK_DEFINITIONS.get(task)
             if task_definition is None:
                 raise KeyError("Not a valid task")
-            trainset = TrainSet(type=task_definition.proto)
+            trainset = TrainSetPB(type=task_definition.proto)
             if task_definition.labels:
                 trainset.filter.labels.extend(labels)
         elif trainset is not None:
@@ -164,13 +166,13 @@ class NucliaDBDataset(NucliaDataset):
         streamer.initialize(partition_id)
         return streamer
 
-    def _set_mappings(self, funcs: List[Callable[[Any, Any], Tuple[Any, Any]]]):
+    def _set_mappings(self, funcs: list[Callable[[Any, Any], tuple[Any, Any]]]):
         self.mappings = funcs
 
     def _set_schema(self, schema: pa.Schema):
         self.schema = schema
 
-    def get_partitions(self) -> List[str]:
+    def get_partitions(self) -> list[str]:
         """
         Get expected number of partitions from a live NucliaDB
         """
@@ -182,9 +184,9 @@ class NucliaDBDataset(NucliaDataset):
     def read_partition(
         self,
         partition_id: str,
-        filename: Optional[str] = None,
+        filename: str | None = None,
         force: bool = False,
-        path: Optional[str] = None,
+        path: str | None = None,
     ):
         """
         Export an arrow partition from a live NucliaDB and store it locally
@@ -219,12 +221,12 @@ class NucliaDBDataset(NucliaDataset):
 
 def download_all_partitions(
     task: str,
-    slug: Optional[str] = None,
-    kbid: Optional[str] = None,
-    nucliadb_base_url: Optional[str] = "http://localhost:8080",
-    path: Optional[str] = None,
-    sdk: Optional[NucliaDB] = None,
-    labels: Optional[List[str]] = None,
+    slug: str | None = None,
+    kbid: str | None = None,
+    nucliadb_base_url: str | None = "http://localhost:8080",
+    path: str | None = None,
+    sdk: NucliaDB | None = None,
+    labels: list[str] | None = None,
 ):
     if sdk is None:
         sdk = NucliaDB(region="on-prem", url=nucliadb_base_url)
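Note: NucliaDBDataset now accepts the trainset as either the protobuf message (TrainSetPB) or the pydantic model (TrainSetModel), and still builds a TrainSetPB itself when only a task is given. A hedged usage sketch assembled from the signatures above; the kbid and labelset values are placeholders, not values from the package:

    from nucliadb_dataset.dataset import NucliaDBDataset, Task
    from nucliadb_sdk.v2.sdk import NucliaDB

    sdk = NucliaDB(region="on-prem", url="http://localhost:8080")
    dataset = NucliaDBDataset(
        sdk=sdk,
        kbid="00000000-0000-0000-0000-000000000000",  # placeholder KB id
        task=Task.PARAGRAPH_CLASSIFICATION,
        labels=["my-labelset"],  # placeholder labelset
    )
    # iter_all_partitions() yields (partition, filename) pairs as each
    # partition is exported to a local arrow file.
    for partition, filename in dataset.iter_all_partitions():
        print(partition, filename)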
nucliadb_dataset/nuclia.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Iterator, List
+from collections.abc import Iterator
 
 import grpc
 
@@ -54,8 +54,7 @@ class NucliaDriver:
         request.metadata.labels = labels
         request.metadata.entities = entities
         request.metadata.text = text
-        for sentence in self.stub.GetSentences(request):
-            yield sentence
+        yield from self.stub.GetSentences(request)
 
     def iterate_paragraphs(
         self, kbid: str, labels: bool, entities: bool, text: bool
@@ -65,8 +64,7 @@ class NucliaDriver:
         request.metadata.labels = labels
         request.metadata.entities = entities
         request.metadata.text = text
-        for paragraph in self.stub.GetParagraphs(request):
-            yield paragraph
+        yield from self.stub.GetParagraphs(request)
 
     def iterate_resources(
         self, kbid: str, labels: bool, entities: bool, text: bool
@@ -76,8 +74,7 @@ class NucliaDriver:
         request.metadata.labels = labels
         request.metadata.entities = entities
         request.metadata.text = text
-        for resource in self.stub.GetResources(request):
-            yield resource
+        yield from self.stub.GetResources(request)
 
     def iterate_fields(
         self, kbid: str, labels: bool, entities: bool, text: bool
@@ -87,8 +84,7 @@ class NucliaDriver:
         request.metadata.labels = labels
         request.metadata.entities = entities
         request.metadata.text = text
-        for field in self.stub.GetFields(request):
-            yield field
+        yield from self.stub.GetFields(request)
 
     def get_labels(self, kbid: str) -> GetLabelsResponse:
         request = GetLabelsRequest()
@@ -106,7 +102,7 @@ class NucliaDriver:
         return self.stub.GetInfo(request)
 
     def get_ontology_count(
-        self, kbid: str, paragraph_labelsets: List[str], resource_labelsets: List[str]
+        self, kbid: str, paragraph_labelsets: list[str], resource_labelsets: list[str]
     ) -> LabelsetsCount:
         request = GetLabelsetsCountRequest()
         request.kb.uuid = kbid
nucliadb_dataset/settings.py
CHANGED
@@ -14,7 +14,6 @@
 #
 
 from pathlib import Path
-from typing import Optional
 
 import pydantic
 from pydantic_settings import BaseSettings
@@ -35,20 +34,20 @@ class RunningSettings(BaseSettings):
     download_path: str = pydantic.Field(f"{Path.home()}/.nuclia/download", description="Download path")
     url: str = pydantic.Field(description="KnowledgeBox URL")
     type: Task = pydantic.Field(description="Dataset Type")
-    labelset: Optional[str] = pydantic.Field(
+    labelset: str | None = pydantic.Field(
         None, description="For classification which labelset or families"
     )
 
     datasets_url: str = pydantic.Field(
         "https://europe-1.nuclia.cloud",
-        description="Base url for the Nuclia datasets component (excluding /api/v1)™",
+        description="Base url for the Nuclia datasets component (excluding /api/v1)™",
     )
 
-    apikey: Optional[str] = pydantic.Field(None, description="API key to upload to Nuclia Datasets™")
+    apikey: str | None = pydantic.Field(None, description="API key to upload to Nuclia Datasets™")
 
     environment: str = pydantic.Field("on-prem", description="region or on-prem")
 
-    service_token: Optional[str] = pydantic.Field(
+    service_token: str | None = pydantic.Field(
         None, description="Service account key to access Nuclia Cloud"
     )
 
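Note: with the typing.Optional import dropped, nullable settings are declared as str | None directly on the pydantic fields. A minimal sketch of the same pattern, assuming pydantic v2 with pydantic-settings >= 2.2 (the floor pinned in METADATA below); the class and fields are illustrative, not from the package:

    import pydantic
    from pydantic_settings import BaseSettings

    class ExampleSettings(BaseSettings):
        apikey: str | None = pydantic.Field(None, description="Optional API key")
        environment: str = pydantic.Field("on-prem", description="region or on-prem")

    settings = ExampleSettings()  # reads APIKEY / ENVIRONMENT from the env if set
    print(settings.environment)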
nucliadb_dataset/streamer.py
CHANGED
@@ -13,11 +13,11 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Optional
 
 import requests
 
-from nucliadb_protos.dataset_pb2 import TrainSet
+from nucliadb_models.trainset import TrainSet as TrainSetModel
+from nucliadb_protos.dataset_pb2 import TrainSet as TrainSetPB
 
 logger = logging.getLogger("nucliadb_dataset")
 
@@ -25,12 +25,12 @@ SIZE_BYTES = 4
 
 
 class Streamer:
-    resp: Optional[requests.Response]
+    resp: requests.Response | None
 
     def __init__(
         self,
-        trainset: TrainSet,
-        reader_headers: Dict[str, str],
+        trainset: TrainSetPB | TrainSetModel,
+        reader_headers: dict[str, str],
         base_url: str,
         kbid: str,
     ):
@@ -47,12 +47,23 @@ class Streamer:
     def initialize(self, partition_id: str):
         self.stream_session = requests.Session()
         self.stream_session.headers.update(self.reader_headers)
-        self.resp = self.stream_session.post(
-            f"{self.base_url}/v1/kb/{self.kbid}/trainset/{partition_id}",
-            data=self.trainset.SerializeToString(),
-            stream=True,
-            timeout=None,
-        )
+        if isinstance(self.trainset, TrainSetPB):
+            # Legacy version of the endpoint is passing the protobuffer as bytes in the request content
+            self.resp = self.stream_session.post(
+                f"{self.base_url}/v1/kb/{self.kbid}/trainset/{partition_id}",
+                data=self.trainset.SerializeToString(),
+                stream=True,
+                timeout=None,
+            )
+        elif isinstance(self.trainset, TrainSetModel):
+            self.resp = self.stream_session.post(
+                f"{self.base_url}/v1/kb/{self.kbid}/trainset/{partition_id}",
+                json=self.trainset.model_dump(),
+                stream=True,
+                timeout=None,
+            )
+        else:  # pragma: no cover
+            raise ValueError("Invalid trainset type")
         self.resp.raise_for_status()
 
     def finalize(self):
@@ -63,16 +74,16 @@ class Streamer:
     def __iter__(self):
         return self
 
-    def read(self) -> Optional[bytes]:
+    def read(self) -> bytes | None:
         assert self.resp is not None, "Streamer not initialized"
         header = self.resp.raw.read(4, decode_content=True)
         if header == b"":
             return None
-        payload_size = int.from_bytes(header, byteorder="big", signed=False)
+        payload_size = int.from_bytes(header, byteorder="big", signed=False)
         data = self.resp.raw.read(payload_size)
         return data
 
-    def __next__(self) -> Optional[bytes]:
+    def __next__(self) -> bytes | None:
         payload = self.read()
         if payload in [None, b""]:
             logger.info("Streamer finished reading")
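Note: initialize() now branches on the trainset type: a protobuf TrainSetPB is serialized into the raw request body (the legacy wire format), while a pydantic TrainSetModel is posted as JSON. The response framing consumed by read() is unchanged: a 4-byte big-endian unsigned length prefix followed by that many payload bytes, repeated until the stream ends. A self-contained sketch of that framing:

    import io

    def iter_frames(raw):
        # Mirrors Streamer.read(): 4-byte big-endian size header, then payload.
        while True:
            header = raw.read(4)
            if header == b"":
                return
            size = int.from_bytes(header, byteorder="big", signed=False)
            yield raw.read(size)

    payload = b"hello"
    stream = io.BytesIO(len(payload).to_bytes(4, "big") + payload)
    assert list(iter_frames(stream)) == [payload]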
nucliadb_dataset/tasks.py
CHANGED
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Callable
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List
+from typing import TYPE_CHECKING, Any
 
 import pyarrow as pa  # type: ignore
 
@@ -64,10 +65,10 @@ class TaskDefinition:
     schema: pa.schema
     proto: Any
     labels: bool
-    mapping: List[Callable]
+    mapping: list[Callable]
 
 
-TASK_DEFINITIONS: Dict[Task, TaskDefinition] = {
+TASK_DEFINITIONS: dict[Task, TaskDefinition] = {
     Task.PARAGRAPH_CLASSIFICATION: TaskDefinition(
         schema=pa.schema(
             [
@@ -190,4 +191,4 @@ TASK_DEFINITIONS: Dict[Task, TaskDefinition] = {
     ),
 }
 
-TASK_DEFINITIONS_REVERSE = {task.proto: task for task in TASK_DEFINITIONS.values()}
\ No newline at end of file
+TASK_DEFINITIONS_REVERSE = {task.proto: task for task in TASK_DEFINITIONS.values()}
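Note: TASK_DEFINITIONS maps each Task to its definition, and TASK_DEFINITIONS_REVERSE inverts that map on the protobuf type for lookups in the opposite direction. A simplified, self-contained sketch of the pattern (fields reduced for illustration):

    from dataclasses import dataclass
    from enum import Enum

    class Task(Enum):
        PARAGRAPH_CLASSIFICATION = "PARAGRAPH_CLASSIFICATION"

    @dataclass
    class TaskDefinition:
        proto: int  # stand-in for the protobuf trainset type
        labels: bool

    TASK_DEFINITIONS: dict[Task, TaskDefinition] = {
        Task.PARAGRAPH_CLASSIFICATION: TaskDefinition(proto=0, labels=True),
    }
    TASK_DEFINITIONS_REVERSE = {d.proto: d for d in TASK_DEFINITIONS.values()}

    assert TASK_DEFINITIONS_REVERSE[0] is TASK_DEFINITIONS[Task.PARAGRAPH_CLASSIFICATION]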
nucliadb_dataset/tests/fixtures.py
CHANGED
@@ -14,7 +14,8 @@
 
 import re
 import tempfile
-from typing import AsyncIterator, Iterator, Optional
+from collections.abc import AsyncIterator, Iterator
+from typing import TYPE_CHECKING
 
 import docker  # type: ignore
 import grpc
@@ -22,7 +23,6 @@ import pytest
 from grpc import aio
 
 from nucliadb_models.common import UserClassification
-from nucliadb_models.entities import CreateEntitiesGroupPayload, Entity
 from nucliadb_models.labels import Label, LabelSet, LabelSetKind
 from nucliadb_models.metadata import UserMetadata
 from nucliadb_models.resource import KnowledgeBoxObj
@@ -33,7 +33,10 @@ from nucliadb_protos.writer_pb2_grpc import WriterStub
 from nucliadb_sdk.v2.sdk import NucliaDB
 
 DOCKER_ENV_GROUPS = re.search(r"//([^:]+)", docker.from_env().api.base_url)
-DOCKER_HOST: Optional[str] = DOCKER_ENV_GROUPS.group(1) if DOCKER_ENV_GROUPS else None
+DOCKER_HOST: str | None = DOCKER_ENV_GROUPS.group(1) if DOCKER_ENV_GROUPS else None
+
+if TYPE_CHECKING:
+    from nucliadb_protos.writer_pb2_grpc import WriterAsyncStub
 
 
 @pytest.fixture(scope="function")
@@ -156,70 +159,6 @@ def upload_data_paragraph_classification(sdk: NucliaDB, kb: KnowledgeBoxObj):
     return kb
 
 
-@pytest.fixture(scope="function")
-def upload_data_token_classification(sdk: NucliaDB, kb: KnowledgeBoxObj):
-    sdk.create_entitygroup(
-        kbid=kb.uuid,
-        content=CreateEntitiesGroupPayload(
-            group="PERSON",
-            entities={
-                "ramon": Entity(value="Ramon"),
-                "carmen": Entity(value="Carmen Iniesta"),
-                "eudald": Entity(value="Eudald Camprubi"),
-            },
-            title="Animals",
-            color="black",
-        ),
-    )
-
-    sdk.create_entitygroup(
-        kbid=kb.uuid,
-        content=CreateEntitiesGroupPayload(
-            group="ANIMAL",
-            entities={
-                "cheetah": Entity(value="cheetah"),
-                "tiger": Entity(value="tiger"),
-                "lion": Entity(value="lion"),
-            },
-            title="Animals",
-            color="black",
-        ),
-    )
-    sdk.create_resource(
-        kbid=kb.uuid,
-        content=CreateResourcePayload(
-            slug=SlugString("doc1"),
-            texts={FieldIdString("text"): TextField(body="Ramon This is my lovely text")},
-        ),
-    )
-
-    sdk.create_resource(
-        kbid=kb.uuid,
-        content=CreateResourcePayload(
-            slug=SlugString("doc2"),
-            texts={
-                FieldIdString("text"): TextField(
-                    body="Carmen Iniesta shows an amazing classifier to Eudald Camprubi"
-                )
-            },
-        ),
-    )
-
-    sdk.create_resource(
-        kbid=kb.uuid,
-        content=CreateResourcePayload(
-            slug=SlugString("doc3"),
-            texts={
-                FieldIdString("text"): TextField(
-                    body="Which is the fastest animal, a lion, a tiger or a cheetah?"
-                )
-            },
-        ),
-    )
-
-    return kb
-
-
 @pytest.fixture(scope="function")
 def text_editors_kb(sdk: NucliaDB, kb: KnowledgeBoxObj):
     sdk.create_resource(
@@ -285,10 +224,10 @@ def temp_folder():
 
 
 @pytest.fixture
-async def ingest_stub(nucliadb) -> AsyncIterator[WriterStub]:
+async def ingest_stub(nucliadb) -> AsyncIterator["WriterAsyncStub"]:
     channel = aio.insecure_channel(f"{nucliadb.host}:{nucliadb.grpc}")
     stub = WriterStub(channel)
-    yield stub
+    yield stub  # type: ignore
     await channel.close(grace=True)
 
 
@@ -296,5 +235,5 @@ async def ingest_stub(nucliadb) -> AsyncIterator[WriterStub]:
 def ingest_stub_sync(nucliadb) -> Iterator[WriterStub]:
     channel = grpc.insecure_channel(f"{nucliadb.host}:{nucliadb.grpc}")
     stub = WriterStub(channel)
-    yield stub
+    yield stub  # type: ignore
     channel.close()
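Note: WriterAsyncStub is imported under if TYPE_CHECKING, so it exists for type checkers only and the annotation must be the string "WriterAsyncStub"; at runtime both fixtures still build a WriterStub, hence the # type: ignore on the yields. A generic sketch of the guard, using a stand-in module so it runs on its own:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers, never at runtime.
        from fractions import Fraction  # stand-in for a typing-only import

    def make_ratio(numerator: int, denominator: int) -> "Fraction":
        from fractions import Fraction  # runtime import where actually needed
        return Fraction(numerator, denominator)

    print(make_ratio(1, 2))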
{nucliadb_dataset-6.4.2.post4378.dist-info → nucliadb_dataset-6.10.0.post5677.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: nucliadb_dataset
-Version: 6.4.2.post4378
+Version: 6.10.0.post5677
 Summary: NucliaDB Train Python client
 Author-email: Nuclia <nucliadb@nuclia.com>
 License-Expression: Apache-2.0
@@ -10,12 +10,11 @@ Classifier: Development Status :: 3 - Alpha
 Classifier: Intended Audience :: Developers
 Classifier: Intended Audience :: Information Technology
 Classifier: Programming Language :: Python
-Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3 :: Only
-Requires-Python: <4,>=3.9
+Requires-Python: <4,>=3.10
 Description-Content-Type: text/markdown
 Requires-Dist: protobuf
 Requires-Dist: types-protobuf
@@ -25,9 +24,9 @@ Requires-Dist: aiohttp
 Requires-Dist: argdantic
 Requires-Dist: pydantic-settings>=2.2
 Requires-Dist: pyarrow
-Requires-Dist: nucliadb-protos>=6.4.2.post4378
-Requires-Dist: nucliadb-sdk>=6.4.2.post4378
-Requires-Dist: nucliadb-models>=6.4.2.post4378
+Requires-Dist: nucliadb-protos>=6.10.0.post5677
+Requires-Dist: nucliadb-sdk>=6.10.0.post5677
+Requires-Dist: nucliadb-models>=6.10.0.post5677
 
 # NUCLIADB TRAIN CLIENT
 
nucliadb_dataset-6.10.0.post5677.dist-info/RECORD
ADDED
@@ -0,0 +1,18 @@
+nucliadb_dataset/__init__.py,sha256=1lvjYSji93zdS2zZGnanh9TeunJcaInv-CBEoEfdAL0,1030
+nucliadb_dataset/api.py,sha256=i8OlwF1ly6lkrEMb9Ffoc0GSXB-7zGaOtXMN51wDyO8,2644
+nucliadb_dataset/dataset.py,sha256=VeJH71WWNJPFO6UIKRP8iSz5O6FDXwOUaQdS52VuTNQ,7978
+nucliadb_dataset/export.py,sha256=FJjmg1GA0fhzxEZVgMbrqjbcLTi2v9gw3wI0vhKaDWI,2528
+nucliadb_dataset/mapping.py,sha256=Ayg-dDiGc4P-ctRj2ddIlbgidziuLlZGKeAZ08aHBZU,6495
+nucliadb_dataset/nuclia.py,sha256=eBsuDs_HHZTpbyl35H7nqqu3pgEFHzDo8VygRFjTWDg,3593
+nucliadb_dataset/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+nucliadb_dataset/run.py,sha256=Ktqv6m0f5oCs54RNZh9b3MJtYZ_JmKP1Zp2Uo3VEyF4,2806
+nucliadb_dataset/settings.py,sha256=xRDTuqj4SVyTmliA6RwMRmbvI4aLXTx-CrfUDqMJvbc,1982
+nucliadb_dataset/streamer.py,sha256=EKoasniFuVHSQIPVxZ6i4wQVKsXxDNLtPjpX_J5CLDk,2996
+nucliadb_dataset/tasks.py,sha256=e1pCPzY9T2jQP2WgTNHGJnuBKg0JWg0vonuxQWAnJLs,6234
+nucliadb_dataset/tests/__init__.py,sha256=zG33bUz1rHFPtvqQPWn4rDwBJt3FJodGuQYD45quiQg,583
+nucliadb_dataset/tests/fixtures.py,sha256=KPMO0mYzX3HR41nzkJ8D0TlF7bUekQGBi8uSaMtAmLU,7524
+nucliadb_dataset-6.10.0.post5677.dist-info/METADATA,sha256=FNLxuKl97x-PkZFNhrWfmKdJLLZgSQxfTjO5-02PzIA,1218
+nucliadb_dataset-6.10.0.post5677.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+nucliadb_dataset-6.10.0.post5677.dist-info/entry_points.txt,sha256=ORrWnn6AUFfHGY1fWPRPTz99KV_pXXwttZAopyT8qvQ,60
+nucliadb_dataset-6.10.0.post5677.dist-info/top_level.txt,sha256=aJtDe54tz6060E0uyk1rdTRAU4FPWo5if1fYFQGvdqU,17
+nucliadb_dataset-6.10.0.post5677.dist-info/RECORD,,
nucliadb_dataset-6.4.2.post4378.dist-info/RECORD
DELETED
@@ -1,18 +0,0 @@
-nucliadb_dataset/__init__.py,sha256=I58PAYrrgLvxmkGGHvzKKUwnaZ2ny44hba6AXEYfKOQ,1054
-nucliadb_dataset/api.py,sha256=RDIW23cy12E5_UlwsiOuhdFrr1OHPA4Mj7hZH0BqGgA,2757
-nucliadb_dataset/dataset.py,sha256=E8ARHwZntdZH9g4zKEVhYbqMYxn3hnuPGkuErL0V-KM,7951
-nucliadb_dataset/export.py,sha256=FJjmg1GA0fhzxEZVgMbrqjbcLTi2v9gw3wI0vhKaDWI,2528
-nucliadb_dataset/mapping.py,sha256=Ayg-dDiGc4P-ctRj2ddIlbgidziuLlZGKeAZ08aHBZU,6495
-nucliadb_dataset/nuclia.py,sha256=uXiwJS_GBcN6z9l9hFahD3jLbqhkfpGNWpMR6_8K5k8,3718
-nucliadb_dataset/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-nucliadb_dataset/run.py,sha256=Ktqv6m0f5oCs54RNZh9b3MJtYZ_JmKP1Zp2Uo3VEyF4,2806
-nucliadb_dataset/settings.py,sha256=9NYeJVgIHbLqCenUFRwj6Iz9S7klOROXQfzaUiBUEl0,2027
-nucliadb_dataset/streamer.py,sha256=-v4FeIRhU_uLhic14vKsbBg8sIqDg9O4gtk_JwSqScg,2381
-nucliadb_dataset/tasks.py,sha256=198o37vDlzS7OdXrHYhtwI8kz2WHWJnxpholh-rtTPQ,6227
-nucliadb_dataset/tests/__init__.py,sha256=zG33bUz1rHFPtvqQPWn4rDwBJt3FJodGuQYD45quiQg,583
-nucliadb_dataset/tests/fixtures.py,sha256=Fx82tfwoFx0YyaiC6tg-ThUDvS8pcmFSE0Pm1KpAS0I,9242
-nucliadb_dataset-6.4.2.post4378.dist-info/METADATA,sha256=EQPdSB7KK2231hULhwl21g0vKSmoqwGQ8DIoVkzztJs,1263
-nucliadb_dataset-6.4.2.post4378.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-nucliadb_dataset-6.4.2.post4378.dist-info/entry_points.txt,sha256=ORrWnn6AUFfHGY1fWPRPTz99KV_pXXwttZAopyT8qvQ,60
-nucliadb_dataset-6.4.2.post4378.dist-info/top_level.txt,sha256=aJtDe54tz6060E0uyk1rdTRAU4FPWo5if1fYFQGvdqU,17
-nucliadb_dataset-6.4.2.post4378.dist-info/RECORD,,
{nucliadb_dataset-6.4.2.post4378.dist-info → nucliadb_dataset-6.10.0.post5677.dist-info}/WHEEL
RENAMED
File without changes
{nucliadb_dataset-6.4.2.post4378.dist-info → nucliadb_dataset-6.10.0.post5677.dist-info}/entry_points.txt
RENAMED
File without changes
{nucliadb_dataset-6.4.2.post4378.dist-info → nucliadb_dataset-6.10.0.post5677.dist-info}/top_level.txt
RENAMED
File without changes