esgpull 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- esgpull/__init__.py +12 -0
- esgpull/auth.py +181 -0
- esgpull/cli/__init__.py +73 -0
- esgpull/cli/add.py +103 -0
- esgpull/cli/autoremove.py +38 -0
- esgpull/cli/config.py +116 -0
- esgpull/cli/convert.py +285 -0
- esgpull/cli/decorators.py +342 -0
- esgpull/cli/download.py +74 -0
- esgpull/cli/facet.py +23 -0
- esgpull/cli/get.py +28 -0
- esgpull/cli/install.py +85 -0
- esgpull/cli/link.py +105 -0
- esgpull/cli/login.py +56 -0
- esgpull/cli/remove.py +73 -0
- esgpull/cli/retry.py +43 -0
- esgpull/cli/search.py +201 -0
- esgpull/cli/self.py +238 -0
- esgpull/cli/show.py +66 -0
- esgpull/cli/status.py +67 -0
- esgpull/cli/track.py +87 -0
- esgpull/cli/update.py +184 -0
- esgpull/cli/utils.py +247 -0
- esgpull/config.py +410 -0
- esgpull/constants.py +56 -0
- esgpull/context.py +724 -0
- esgpull/database.py +161 -0
- esgpull/download.py +162 -0
- esgpull/esgpull.py +447 -0
- esgpull/exceptions.py +167 -0
- esgpull/fs.py +253 -0
- esgpull/graph.py +460 -0
- esgpull/install_config.py +185 -0
- esgpull/migrations/README +1 -0
- esgpull/migrations/env.py +82 -0
- esgpull/migrations/script.py.mako +24 -0
- esgpull/migrations/versions/0.3.0_update_tables.py +170 -0
- esgpull/migrations/versions/0.3.1_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.2_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.3_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.4_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.5_update_tables.py +25 -0
- esgpull/migrations/versions/0.3.6_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.7_update_tables.py +26 -0
- esgpull/migrations/versions/0.3.8_update_tables.py +26 -0
- esgpull/migrations/versions/0.4.0_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.0_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.1_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.2_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.3_update_tables.py +26 -0
- esgpull/migrations/versions/0.5.4_update_tables.py +25 -0
- esgpull/migrations/versions/0.5.5_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.0_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.1_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.2_update_tables.py +25 -0
- esgpull/migrations/versions/0.6.3_update_tables.py +25 -0
- esgpull/models/__init__.py +31 -0
- esgpull/models/base.py +50 -0
- esgpull/models/dataset.py +34 -0
- esgpull/models/facet.py +18 -0
- esgpull/models/file.py +65 -0
- esgpull/models/options.py +164 -0
- esgpull/models/query.py +481 -0
- esgpull/models/selection.py +201 -0
- esgpull/models/sql.py +258 -0
- esgpull/models/synda_file.py +85 -0
- esgpull/models/tag.py +19 -0
- esgpull/models/utils.py +54 -0
- esgpull/presets.py +13 -0
- esgpull/processor.py +172 -0
- esgpull/py.typed +0 -0
- esgpull/result.py +53 -0
- esgpull/tui.py +346 -0
- esgpull/utils.py +54 -0
- esgpull/version.py +1 -0
- esgpull-0.6.3.dist-info/METADATA +110 -0
- esgpull-0.6.3.dist-info/RECORD +80 -0
- esgpull-0.6.3.dist-info/WHEEL +4 -0
- esgpull-0.6.3.dist-info/entry_points.txt +3 -0
- esgpull-0.6.3.dist-info/licenses/LICENSE +28 -0
esgpull/context.py
ADDED
@@ -0,0 +1,724 @@
from __future__ import annotations

import asyncio
import sys
from collections.abc import AsyncIterator, Callable, Coroutine, Sequence
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, TypeAlias, TypeVar

if sys.version_info < (3, 11):
    from exceptiongroup import BaseExceptionGroup

from httpx import AsyncClient, HTTPError, Request
from rich.pretty import pretty_repr

from esgpull.config import Config
from esgpull.exceptions import SolrUnstableQueryError
from esgpull.models import Dataset, File, Query
from esgpull.tui import logger
from esgpull.utils import format_date, index2url, sync

# workaround for notebooks with running event loop
if asyncio.get_event_loop().is_running():
    import nest_asyncio

    nest_asyncio.apply()


T = TypeVar("T")
RT = TypeVar("RT", bound="Result")
HintsDict: TypeAlias = dict[str, dict[str, int]]
DangerousFacets = {
    "instance_id",
    "dataset_id",
    "master_id",
    "tracking_id",
    "url",
}


@dataclass
class Result:
    query: Query
    file: bool
    request: Request = field(init=False, repr=False)
    json: dict[str, Any] = field(init=False, repr=False)
    exc: BaseException | None = field(init=False, default=None, repr=False)
    processed: bool = field(init=False, default=False, repr=False)

    @property
    def success(self) -> bool:
        return self.exc is None

    def prepare(
        self,
        index_node: str,
        offset: int = 0,
        page_limit: int = 50,
        index_url: str | None = None,
        fields_param: list[str] | None = None,
        facets_param: list[str] | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> None:
        params: dict[str, str | int | bool] = {
            "type": "File" if self.file else "Dataset",
            "offset": offset,
            "limit": page_limit,
            "format": "application/solr+json",
            # "from": self.since,
        }
        if index_url is None:
            index_url = index2url(index_node)
        if fields_param is not None:
            params["fields"] = ",".join(fields_param)
        else:
            params["fields"] = "instance_id"
        if date_from is not None:
            params["from"] = format_date(date_from)
        if date_to is not None:
            params["to"] = format_date(date_to)
        if facets_param is not None:
            if len(set(facets_param) & DangerousFacets) > 0:
                raise SolrUnstableQueryError(pretty_repr(self.query))
            facets_param_str = ",".join(facets_param)
            facets_star = "*" in facets_param_str
            params["facets"] = facets_param_str
        else:
            facets_star = False
        # [?]TODO: add nominal temporal constraints `to`
        # if "start" in facets:
        #     query["start"] = format_date(str(facets.pop("start")))
        # if "end" in facets:
        #     query["end"] = format_date(str(facets.pop("end")))
        solr_terms: list[str] = []
        for name, values in self.query.selection.items():
            value_term = " ".join(values)
            if name == "query":  # freetext case
                solr_terms.append(value_term)
            else:
                if len(values) > 1:
                    value_term = f"({value_term})"
                solr_terms.append(f"{name}:{value_term}")
        if solr_terms:
            params["query"] = " AND ".join(solr_terms)
        for name, option in self.query.options.items(use_default=True):
            if option.is_bool():
                params[name] = option.name
        if params.get("distrib") == "true" and facets_star:
            raise SolrUnstableQueryError(pretty_repr(self.query))
        self.request = Request("GET", index_url, params=params)

    def to(self, subtype: type[RT]) -> RT:
        result: RT = subtype(self.query, self.file)
        result.request = self.request
        result.exc = self.exc
        if result.success:
            result.json = self.json
        return result


@dataclass
class ResultHits(Result):
    data: int = field(init=False, repr=False)

    def process(self) -> None:
        if self.success:
            self.data = self.json["response"]["numFound"]
            self.processed = True
        else:
            self.data = 0


@dataclass
class ResultHints(Result):
    data: HintsDict = field(init=False, repr=False)

    def process(self) -> None:
        self.data = {}
        if self.success:
            facet_fields = self.json["facet_counts"]["facet_fields"]
            for name, value_count in facet_fields.items():
                if len(value_count) == 0:
                    continue
                values: list[str] = value_count[::2]
                counts: list[int] = value_count[1::2]
                self.data[name] = dict(zip(values, counts))
            self.processed = True


@dataclass
class ResultSearch(Result):
    data: Sequence[File | Dataset] = field(init=False, repr=False)

    def process(self) -> None:
        raise NotImplementedError


@dataclass
class ResultDatasets(Result):
    data: Sequence[Dataset] = field(init=False, repr=False)

    def process(self) -> None:
        self.data = []
        if self.success:
            for doc in self.json["response"]["docs"]:
                try:
                    dataset = Dataset.serialize(doc)
                    self.data.append(dataset)
                except KeyError as exc:
                    logger.exception(exc)
            self.processed = True


@dataclass
class ResultFiles(Result):
    data: Sequence[File] = field(init=False, repr=False)

    def process(self) -> None:
        self.data = []
        if self.success:
            for doc in self.json["response"]["docs"]:
                try:
                    file = File.serialize(doc)
                    self.data.append(file)
                except KeyError as exc:
                    logger.exception(exc)
                    fid = doc["instance_id"]
                    logger.warning(f"File {fid} has invalid metadata")
                    logger.debug(pretty_repr(doc))
            self.processed = True


@dataclass
class ResultSearchAsQueries(Result):
    data: Sequence[Query] = field(init=False, repr=False)

    def process(self) -> None:
        self.data = []
        sha = "FILE" if self.file else "DATASET"
        if self.success:
            for doc in self.json["response"]["docs"]:
                query = Query._from_detailed_dict(doc)
                query.sha = f"{sha}:{query.sha}"
                self.data.append(query)
            self.processed = True


def _distribute_hits_impl(hits: list[int], max_hits: int) -> list[int]:
    i = total = 0
    N = len(hits)
    accs = [0.0 for _ in range(N)]
    result = [0 for _ in range(N)]
    steps = [h / (sum(hits) or 1) for h in hits]
    max_hits = min(max_hits, sum(hits))
    while True:
        accs[i] += steps[i]
        step = int(accs[i])
        if total + step >= max_hits:
            result[i] += max_hits - total
            break
        total += step
        accs[i] -= step
        result[i] += step
        i = (i + 1) % N
    return result


def _distribute_hits(
    hits: list[int],
    offset: int,
    max_hits: int | None,
    page_limit: int,
) -> list[list[slice]]:
    offsets = _distribute_hits_impl(hits, offset)
    hits_with_offset = [h - o for h, o in zip(hits, offsets)]
    hits = hits[:]
    if max_hits is not None:
        hits = _distribute_hits_impl(hits_with_offset, max_hits)
    result: list[list[slice]] = []
    for i, hit in enumerate(hits):
        slices = []
        offset = offsets[i]
        fullstop = hit + offset
        for start in range(offset, fullstop, page_limit):
            stop = start + min(page_limit, fullstop - start)
            slices.append(slice(start, stop))
        result.append(slices)
    return result


FileFieldParams = ["*"]
DatasetFieldParams = [
    "instance_id",
    "data_node",
    "size",
    "number_of_files",
]


@dataclass
class Context:
    config: Config = field(default_factory=Config.default)
    client: AsyncClient = field(
        init=False,
        repr=False,
    )
    semaphores: dict[str, asyncio.Semaphore] = field(
        init=False,
        repr=False,
        default_factory=dict,
    )
    noraise: bool = False

    # def __init__(
    #     self,
    #     config: Config | None = None,
    #     *,
    #     # since: str | datetime | None = None,
    # ):
    #     # if since is None:
    #     #     self.since = since
    #     # else:
    #     #     self.since = format_date(since)

    async def __aenter__(self) -> Context:
        if hasattr(self, "client"):
            raise Exception("Context is already initialized.")
        self.client = AsyncClient(timeout=self.config.api.http_timeout)
        return self

    async def __aexit__(self, *exc) -> None:
        if not hasattr(self, "client"):
            raise Exception("Context is not initialized.")
        await self.client.aclose()
        del self.client

    def prepare_hits(
        self,
        *queries: Query,
        file: bool,
        index_url: str | None = None,
        index_node: str | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> list[ResultHits]:
        results = []
        for i, query in enumerate(queries):
            result = ResultHits(query, file)
            result.prepare(
                index_node=index_node or self.config.api.index_node,
                page_limit=0,
                index_url=index_url,
                date_from=date_from,
                date_to=date_to,
            )
            results.append(result)
        return results

    def prepare_hints(
        self,
        *queries: Query,
        file: bool,
        facets: list[str],
        index_url: str | None = None,
        index_node: str | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> list[ResultHints]:
        results = []
        for i, query in enumerate(queries):
            result = ResultHints(query, file)
            result.prepare(
                index_node=index_node or self.config.api.index_node,
                page_limit=0,
                facets_param=facets,
                index_url=index_url,
                date_from=date_from,
                date_to=date_to,
            )
            results.append(result)
        return results

    def prepare_search(
        self,
        *queries: Query,
        file: bool,
        hits: list[int],
        offset: int = 0,
        max_hits: int | None = 200,
        page_limit: int | None = None,
        index_url: str | None = None,
        index_node: str | None = None,
        fields_param: list[str] | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> list[ResultSearch]:
        if page_limit is None:
            page_limit = self.config.api.page_limit
        if fields_param is None:
            if file:
                fields_param = FileFieldParams
            else:
                fields_param = DatasetFieldParams
        slices = _distribute_hits(
            hits=hits,
            offset=offset,
            max_hits=max_hits,
            page_limit=page_limit,
        )
        results = []
        for query, query_slices in zip(queries, slices):
            for sl in query_slices:
                result = ResultSearch(query, file=file)
                result.prepare(
                    index_node=index_node or self.config.api.index_node,
                    offset=sl.start,
                    page_limit=sl.stop - sl.start,
                    fields_param=fields_param,
                    index_url=index_url,
                    date_from=date_from,
                    date_to=date_to,
                )
                results.append(result)
        return results

    def prepare_search_distributed(
        self,
        *queries: Query,
        file: bool,
        hints: list[HintsDict],
        offset: int = 0,
        max_hits: int | None = 200,
        page_limit: int | None = None,
        fields_param: list[str] | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> list[ResultSearch]:
        if page_limit is None:
            page_limit = self.config.api.page_limit
        if fields_param is None:
            if file:
                fields_param = FileFieldParams
            else:
                fields_param = DatasetFieldParams
        hits = self.hits_from_hints(*hints)
        if max_hits is not None:
            hits = _distribute_hits_impl(hits, max_hits)
        results = []
        not_distrib = Query(options=dict(distrib=False))
        for query, query_hints, query_max_hits in zip(queries, hints, hits):
            nodes = query_hints["index_node"]
            nodes_hits = [nodes[node] for node in nodes]
            slices = _distribute_hits(
                hits=nodes_hits,
                offset=offset,
                max_hits=query_max_hits,
                page_limit=page_limit,
            )
            for node, node_slices in zip(nodes, slices):
                for sl in node_slices:
                    result = ResultSearch(query << not_distrib, file=file)
                    result.prepare(
                        index_node=node,
                        offset=sl.start,
                        page_limit=sl.stop - sl.start,
                        fields_param=fields_param,
                        date_from=date_from,
                        date_to=date_to,
                    )
                    results.append(result)
        return results

    async def _fetch_one(self, result: RT) -> RT:
        host = result.request.url.host
        if host not in self.semaphores:
            max_concurrent = self.config.api.max_concurrent
            self.semaphores[host] = asyncio.Semaphore(max_concurrent)
        async with self.semaphores[host]:
            logger.debug(f"GET {host} params={result.request.url.params}")
            try:
                resp = await self.client.send(result.request)
                resp.raise_for_status()
                result.json = resp.json()
                logger.info(f"✓ Fetched in {resp.elapsed}s {resp.url}")
            except HTTPError as exc:
                result.exc = exc
            except (Exception, asyncio.CancelledError) as exc:
                result.exc = exc
        return result

    async def _fetch(self, *in_results: RT) -> AsyncIterator[RT]:
        tasks = [
            asyncio.create_task(self._fetch_one(result))
            for result in in_results
        ]
        excs = []
        for task in tasks:
            result = await task
            yield result
            if result.exc is not None:
                excs.append(result.exc)
        if excs:
            group = BaseExceptionGroup("fetch", excs)
            if self.noraise:
                logger.exception(group)
            else:
                raise group

    async def _hits(self, *results: ResultHits) -> list[int]:
        hits = []
        async for result in self._fetch(*results):
            result.process()
            if result.processed:
                hits.append(result.data)
        return hits

    async def _hints(self, *results: ResultHints) -> list[HintsDict]:
        hints: list[HintsDict] = []
        async for result in self._fetch(*results):
            result.process()
            if result.processed:
                hints.append(result.data)
        return hints

    async def _datasets(
        self,
        *results: ResultSearch,
        keep_duplicates: bool,
    ) -> list[Dataset]:
        datasets: list[Dataset] = []
        ids: set[str] = set()
        async for result in self._fetch(*results):
            dataset_result = result.to(ResultDatasets)
            dataset_result.process()
            if dataset_result.processed:
                for d in dataset_result.data:
                    if not keep_duplicates and d.dataset_id in ids:
                        logger.warning(f"Duplicate dataset {d.dataset_id}")
                    else:
                        datasets.append(d)
                        ids.add(d.dataset_id)
        return datasets

    async def _files(
        self,
        *results: ResultSearch,
        keep_duplicates: bool,
    ) -> list[File]:
        files: list[File] = []
        shas: set[str] = set()
        async for result in self._fetch(*results):
            files_result = result.to(ResultFiles)
            files_result.process()
            if files_result.processed:
                for file in files_result.data:
                    if not keep_duplicates and file.sha in shas:
                        logger.warning(f"Duplicate file {file.file_id}")
                    else:
                        files.append(file)
                        shas.add(file.sha)
        return files

    async def _search_as_queries(
        self,
        *results: ResultSearch,
        keep_duplicates: bool,
    ) -> list[Query]:
        queries: list[Query] = []
        async for result in self._fetch(*results):
            queries_result = result.to(ResultSearchAsQueries)
            queries_result.process()
            if queries_result.processed:
                for query in queries_result.data:
                    queries.append(query)
        return queries

    async def _with_client(self, coro: Coroutine[None, None, T]) -> T:
        """
        Async wrapper to create client before await future.
        This is required since asyncio does not provide a way
        to enter an async context in a sync function.
        """
        async with self:
            return await coro

    def free_semaphores(self) -> None:
        self.semaphores = {}

    def _sync(self, coro: Coroutine[None, None, T]) -> T:
        """
        Reset semaphore to ensure none is bound to an expired event loop.
        Run through `_with_client` wrapper to use `async with` synchronously.
        """
        self.free_semaphores()
        return sync(self._with_client(coro))

    async def _gather(self, *coros: Coroutine[None, None, T]) -> list[T]:
        return await asyncio.gather(*coros)

    def sync_gather(self, *coros: Coroutine[None, None, T]) -> list[T]:
        return self._sync(self._gather(*coros))

    def hits(
        self,
        *queries: Query,
        file: bool,
        index_url: str | None = None,
        index_node: str | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> list[int]:
        results = self.prepare_hits(
            *queries,
            file=file,
            index_url=index_url,
            index_node=index_node,
            date_from=date_from,
            date_to=date_to,
        )
        return self._sync(self._hits(*results))

    def hits_from_hints(self, *hints: HintsDict) -> list[int]:
        result: list[int] = []
        for hint in hints:
            if len(hint) > 0:
                key = next(iter(hint))
                num = sum(hint[key].values())
            else:
                num = 0
            result.append(num)
        return result

    def hints(
        self,
        *queries: Query,
        file: bool,
        facets: list[str],
        index_url: str | None = None,
        index_node: str | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
    ) -> list[HintsDict]:
        results = self.prepare_hints(
            *queries,
            file=file,
            facets=facets,
            index_url=index_url,
            index_node=index_node,
            date_from=date_from,
            date_to=date_to,
        )
        return self._sync(self._hints(*results))

    def datasets(
        self,
        *queries: Query,
        hits: list[int] | None = None,
        offset: int = 0,
        max_hits: int | None = 200,
        page_limit: int | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
        keep_duplicates: bool = True,
    ) -> list[Dataset]:
        if hits is None:
            hits = self.hits(*queries, file=False)
        results = self.prepare_search(
            *queries,
            file=False,
            hits=hits,
            offset=offset,
            page_limit=page_limit,
            max_hits=max_hits,
            date_from=date_from,
            date_to=date_to,
        )
        coro = self._datasets(*results, keep_duplicates=keep_duplicates)
        return self._sync(coro)

    def files(
        self,
        *queries: Query,
        hits: list[int] | None = None,
        offset: int = 0,
        max_hits: int | None = 200,
        page_limit: int | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
        keep_duplicates: bool = True,
    ) -> list[File]:
        if hits is None:
            hits = self.hits(*queries, file=True)
        results = self.prepare_search(
            *queries,
            file=True,
            hits=hits,
            offset=offset,
            page_limit=page_limit,
            max_hits=max_hits,
            date_from=date_from,
            date_to=date_to,
        )
        coro = self._files(*results, keep_duplicates=keep_duplicates)
        return self._sync(coro)

    def search_as_queries(
        self,
        *queries: Query,
        file: bool,
        hits: list[int] | None = None,
        offset: int = 0,
        max_hits: int | None = 1,
        page_limit: int | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
        keep_duplicates: bool = True,
    ) -> Sequence[Query]:
        if hits is None:
            hits = self.hits(*queries, file=file)
        results = self.prepare_search(
            *queries,
            file=file,
            hits=hits,
            offset=offset,
            page_limit=page_limit,
            max_hits=max_hits,
            date_from=date_from,
            date_to=date_to,
            fields_param=["*"],
        )
        coro = self._search_as_queries(
            *results,
            keep_duplicates=keep_duplicates,
        )
        return self._sync(coro)

    def search(
        self,
        *queries: Query,
        file: bool,
        hits: list[int] | None = None,
        offset: int = 0,
        max_hits: int | None = 200,
        page_limit: int | None = None,
        date_from: datetime | None = None,
        date_to: datetime | None = None,
        keep_duplicates: bool = True,
    ) -> Sequence[File | Dataset]:
        fun: Callable[..., Sequence[File | Dataset]]
        if file:
            fun = self.files
        else:
            fun = self.datasets
        return fun(
            *queries,
            hits=hits,
            offset=offset,
            max_hits=max_hits,
            page_limit=page_limit,
            date_from=date_from,
            date_to=date_to,
            keep_duplicates=keep_duplicates,
        )