nucliadb-dataset 6.4.2.post4389__py3-none-any.whl → 6.5.0.post4404__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,7 @@
15
15
  import logging
16
16
  import os
17
17
  from dataclasses import dataclass, field
18
- from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
18
+ from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
19
19
 
20
20
  import pyarrow as pa # type: ignore
21
21
 
@@ -29,8 +29,9 @@ from nucliadb_dataset.tasks import (
29
29
  from nucliadb_models.entities import KnowledgeBoxEntities
30
30
  from nucliadb_models.labels import KnowledgeBoxLabels
31
31
  from nucliadb_models.resource import KnowledgeBoxObj
32
+ from nucliadb_models.trainset import TrainSet as TrainSetModel
32
33
  from nucliadb_models.trainset import TrainSetPartitions
33
- from nucliadb_protos.dataset_pb2 import TrainSet
34
+ from nucliadb_protos.dataset_pb2 import TrainSet as TrainSetPB
34
35
  from nucliadb_sdk.v2.sdk import NucliaDB
35
36
 
36
37
  logger = logging.getLogger("nucliadb_dataset")
@@ -104,7 +105,7 @@ class NucliaDBDataset(NucliaDataset):
104
105
  kbid: str,
105
106
  task: Optional[Task] = None,
106
107
  labels: Optional[List[str]] = None,
107
- trainset: Optional[TrainSet] = None,
108
+ trainset: Optional[Union[TrainSetPB, TrainSetModel]] = None,
108
109
  base_path: Optional[str] = None,
109
110
  search_sdk: Optional[NucliaDB] = None,
110
111
  reader_sdk: Optional[NucliaDB] = None,
@@ -119,7 +120,7 @@ class NucliaDBDataset(NucliaDataset):
119
120
  task_definition = TASK_DEFINITIONS.get(task)
120
121
  if task_definition is None:
121
122
  raise KeyError("Not a valid task")
122
- trainset = TrainSet(type=task_definition.proto)
123
+ trainset = TrainSetPB(type=task_definition.proto)
123
124
  if task_definition.labels:
124
125
  trainset.filter.labels.extend(labels)
125
126
  elif trainset is not None:
@@ -13,11 +13,12 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import logging
16
- from typing import Dict, Optional
16
+ from typing import Dict, Optional, Union
17
17
 
18
18
  import requests
19
19
 
20
- from nucliadb_protos.dataset_pb2 import TrainSet
20
+ from nucliadb_models.trainset import TrainSet as TrainSetModel
21
+ from nucliadb_protos.dataset_pb2 import TrainSet as TrainSetPB
21
22
 
22
23
  logger = logging.getLogger("nucliadb_dataset")
23
24
 
@@ -29,7 +30,7 @@ class Streamer:
29
30
 
30
31
  def __init__(
31
32
  self,
32
- trainset: TrainSet,
33
+ trainset: Union[TrainSetPB, TrainSetModel],
33
34
  reader_headers: Dict[str, str],
34
35
  base_url: str,
35
36
  kbid: str,
@@ -47,12 +48,23 @@ class Streamer:
47
48
  def initialize(self, partition_id: str):
48
49
  self.stream_session = requests.Session()
49
50
  self.stream_session.headers.update(self.reader_headers)
50
- self.resp = self.stream_session.post(
51
- f"{self.base_url}/v1/kb/{self.kbid}/trainset/{partition_id}",
52
- data=self.trainset.SerializeToString(),
53
- stream=True,
54
- timeout=None,
55
- )
51
+ if isinstance(self.trainset, TrainSetPB):
52
+ # Legacy version of the endpoint is passing the protobuffer as bytes in the request content
53
+ self.resp = self.stream_session.post(
54
+ f"{self.base_url}/v1/kb/{self.kbid}/trainset/{partition_id}",
55
+ data=self.trainset.SerializeToString(),
56
+ stream=True,
57
+ timeout=None,
58
+ )
59
+ elif isinstance(self.trainset, TrainSetModel):
60
+ self.resp = self.stream_session.post(
61
+ f"{self.base_url}/v1/kb/{self.kbid}/trainset/{partition_id}",
62
+ json=self.trainset.model_dump(),
63
+ stream=True,
64
+ timeout=None,
65
+ )
66
+ else: # pragma: no cover
67
+ raise ValueError("Invalid trainset type")
56
68
  self.resp.raise_for_status()
57
69
 
58
70
  def finalize(self):
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nucliadb_dataset
3
- Version: 6.4.2.post4389
3
+ Version: 6.5.0.post4404
4
4
  Summary: NucliaDB Train Python client
5
5
  Author-email: Nuclia <nucliadb@nuclia.com>
6
6
  License-Expression: Apache-2.0
@@ -25,9 +25,9 @@ Requires-Dist: aiohttp
25
25
  Requires-Dist: argdantic
26
26
  Requires-Dist: pydantic-settings>=2.2
27
27
  Requires-Dist: pyarrow
28
- Requires-Dist: nucliadb-protos>=6.4.2.post4389
29
- Requires-Dist: nucliadb-sdk>=6.4.2.post4389
30
- Requires-Dist: nucliadb-models>=6.4.2.post4389
28
+ Requires-Dist: nucliadb-protos>=6.5.0.post4404
29
+ Requires-Dist: nucliadb-sdk>=6.5.0.post4404
30
+ Requires-Dist: nucliadb-models>=6.5.0.post4404
31
31
 
32
32
  # NUCLIADB TRAIN CLIENT
33
33
 
@@ -1,18 +1,18 @@
1
1
  nucliadb_dataset/__init__.py,sha256=I58PAYrrgLvxmkGGHvzKKUwnaZ2ny44hba6AXEYfKOQ,1054
2
2
  nucliadb_dataset/api.py,sha256=RDIW23cy12E5_UlwsiOuhdFrr1OHPA4Mj7hZH0BqGgA,2757
3
- nucliadb_dataset/dataset.py,sha256=E8ARHwZntdZH9g4zKEVhYbqMYxn3hnuPGkuErL0V-KM,7951
3
+ nucliadb_dataset/dataset.py,sha256=11XoRslzMQQHwJA5MrdAQ30eLP8DoNTv5EmuzL2mln0,8061
4
4
  nucliadb_dataset/export.py,sha256=FJjmg1GA0fhzxEZVgMbrqjbcLTi2v9gw3wI0vhKaDWI,2528
5
5
  nucliadb_dataset/mapping.py,sha256=Ayg-dDiGc4P-ctRj2ddIlbgidziuLlZGKeAZ08aHBZU,6495
6
6
  nucliadb_dataset/nuclia.py,sha256=uXiwJS_GBcN6z9l9hFahD3jLbqhkfpGNWpMR6_8K5k8,3718
7
7
  nucliadb_dataset/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  nucliadb_dataset/run.py,sha256=Ktqv6m0f5oCs54RNZh9b3MJtYZ_JmKP1Zp2Uo3VEyF4,2806
9
9
  nucliadb_dataset/settings.py,sha256=9NYeJVgIHbLqCenUFRwj6Iz9S7klOROXQfzaUiBUEl0,2027
10
- nucliadb_dataset/streamer.py,sha256=-v4FeIRhU_uLhic14vKsbBg8sIqDg9O4gtk_JwSqScg,2381
10
+ nucliadb_dataset/streamer.py,sha256=aBzYWNQQVWM5qA14f8p0YQdshzC4KpOQFt3J0TYe1uk,3060
11
11
  nucliadb_dataset/tasks.py,sha256=198o37vDlzS7OdXrHYhtwI8kz2WHWJnxpholh-rtTPQ,6227
12
12
  nucliadb_dataset/tests/__init__.py,sha256=zG33bUz1rHFPtvqQPWn4rDwBJt3FJodGuQYD45quiQg,583
13
13
  nucliadb_dataset/tests/fixtures.py,sha256=Fx82tfwoFx0YyaiC6tg-ThUDvS8pcmFSE0Pm1KpAS0I,9242
14
- nucliadb_dataset-6.4.2.post4389.dist-info/METADATA,sha256=wB4H4rkEt3JZQ4Ec3yxLVAHgTGUOQOHGqLG9LyKDEh4,1263
15
- nucliadb_dataset-6.4.2.post4389.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
- nucliadb_dataset-6.4.2.post4389.dist-info/entry_points.txt,sha256=ORrWnn6AUFfHGY1fWPRPTz99KV_pXXwttZAopyT8qvQ,60
17
- nucliadb_dataset-6.4.2.post4389.dist-info/top_level.txt,sha256=aJtDe54tz6060E0uyk1rdTRAU4FPWo5if1fYFQGvdqU,17
18
- nucliadb_dataset-6.4.2.post4389.dist-info/RECORD,,
14
+ nucliadb_dataset-6.5.0.post4404.dist-info/METADATA,sha256=wr3_UeG9eunFrhgNc3GRmBFZ9x_EoFhYwm48T5TmaSg,1263
15
+ nucliadb_dataset-6.5.0.post4404.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
16
+ nucliadb_dataset-6.5.0.post4404.dist-info/entry_points.txt,sha256=ORrWnn6AUFfHGY1fWPRPTz99KV_pXXwttZAopyT8qvQ,60
17
+ nucliadb_dataset-6.5.0.post4404.dist-info/top_level.txt,sha256=aJtDe54tz6060E0uyk1rdTRAU4FPWo5if1fYFQGvdqU,17
18
+ nucliadb_dataset-6.5.0.post4404.dist-info/RECORD,,