mp-api 0.46.2rc2__tar.gz → 0.46.2rc4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/PKG-INFO +3 -3
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/_test_utils.py +54 -7
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/core/client.py +306 -57
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/core/schemas.py +7 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/core/utils.py +12 -5
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/mprester.py +104 -77
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/doi.py +1 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/electrodes.py +19 -6
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/electronic_structure.py +97 -34
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/eos.py +24 -8
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/grain_boundaries.py +1 -0
- mp_api-0.46.2rc4/mp_api/client/routes/materials/phonon.py +234 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/similarity.py +1 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/substrates.py +1 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/surface_properties.py +1 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/synthesis.py +1 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/tasks.py +13 -15
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/thermo.py +73 -34
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/xas.py +39 -7
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/molecules/jcesr.py +1 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/molecules/molecules.py +1 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/molecules/summary.py +1 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api.egg-info/PKG-INFO +3 -3
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api.egg-info/requires.txt +2 -2
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/pyproject.toml +2 -2
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/requirements/requirements-ubuntu-latest_py3.11.txt +16 -16
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/requirements/requirements-ubuntu-latest_py3.11_extras.txt +96 -67
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/requirements/requirements-ubuntu-latest_py3.12.txt +16 -16
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/requirements/requirements-ubuntu-latest_py3.12_extras.txt +97 -68
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/requirements/requirements-ubuntu-latest_py3.13.txt +16 -16
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/requirements/requirements-ubuntu-latest_py3.13_extras.txt +97 -68
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/requirements/requirements-ubuntu-latest_py3.14.txt +16 -16
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/requirements/requirements-ubuntu-latest_py3.14_extras.txt +97 -68
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_chemenv.py +2 -3
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_electrodes.py +8 -12
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_electronic_structure.py +7 -5
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_eos.py +15 -3
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_phonon.py +7 -4
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_provenance.py +1 -1
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_summary.py +1 -1
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_thermo.py +3 -3
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_xas.py +13 -10
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/molecules/test_jcesr.py +12 -8
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/molecules/test_summary.py +3 -4
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/test_client.py +5 -1
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/test_mprester.py +131 -116
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/mcp/test_tools.py +1 -1
- mp_api-0.46.2rc2/mp_api/client/routes/materials/phonon.py +0 -156
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/.coveragerc +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/.github/workflows/lint.yml +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/.github/workflows/release.yml +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/.github/workflows/testing.yml +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/.github/workflows/upgrade_dependencies.yml +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/.gitignore +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/.pre-commit-config.yaml +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/CODE_OF_CONDUCT.md +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/LICENSE +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/README.md +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/dev/generate_mcp_tools.py +1 -1
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/dev/inspect_mcp.sh +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/docs/Makefile +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/docs/_templates/custom-class-template.rst +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/docs/_templates/custom-module-template.rst +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/docs/conf.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/docs/index.rst +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/docs/make.bat +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/docs/modules.rst +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/_server_utils.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/contribs/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/contribs/_logger.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/contribs/_types.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/contribs/_units.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/contribs/client.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/contribs/schemas.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/contribs/settings.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/contribs/utils.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/core/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/core/_oxygen_evolution.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/core/exceptions.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/core/settings.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/_server.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/absorption.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/alloys.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/bonds.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/chemenv.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/dielectric.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/elasticity.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/magnetism.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/materials.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/oxidation_states.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/piezo.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/provenance.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/robocrys.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/materials/summary.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/client/routes/molecules/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/mcp/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/mcp/_schemas.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/mcp/mp_mcp.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/mcp/server.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/mcp/tools.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/mcp/utils.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api/py.typed +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api.egg-info/SOURCES.txt +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api.egg-info/dependency_links.txt +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api.egg-info/entry_points.txt +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/mp_api.egg-info/top_level.txt +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/setup.cfg +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/test_files/Si_mp_149.cif +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/contribs/conftest.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/contribs/test_client.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/contribs/test_contribs_schemas.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/contribs/test_contribs_utils.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/contribs/test_types.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/core/test_oxygen_evolution.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/core/test_schemas.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/core/test_utils.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_absorption.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_alloys.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_bonds.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_dielectric.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_doi.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_elasticity.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_grain_boundary.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_magnetism.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_materials.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_oxidation_states.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_piezo.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_robocrys.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_similarity.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_substrates.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_surface_properties.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_synthesis.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/materials/test_tasks.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/molecules/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/molecules/test_molecules.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/test_core_client.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/client/test_heartbeat.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/mcp/__init__.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/mcp/test_server.py +0 -0
- {mp_api-0.46.2rc2 → mp_api-0.46.2rc4}/tests/mcp/test_utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mp-api
|
|
3
|
-
Version: 0.46.
|
|
3
|
+
Version: 0.46.2rc4
|
|
4
4
|
Summary: API Client for the Materials Project
|
|
5
5
|
Author-email: The Materials Project <feedback@materialsproject.org>
|
|
6
6
|
License-Expression: BSD-3-Clause-LBNL
|
|
@@ -17,7 +17,7 @@ Requires-Dist: pymatgen>2024.2.20
|
|
|
17
17
|
Requires-Dist: typing-extensions>=3.7.4.1
|
|
18
18
|
Requires-Dist: requests>=2.23.0
|
|
19
19
|
Requires-Dist: monty>=2024.12.10
|
|
20
|
-
Requires-Dist: emmet-core<0.
|
|
20
|
+
Requires-Dist: emmet-core<0.87.2,>=0.87.0rc2
|
|
21
21
|
Requires-Dist: boto3
|
|
22
22
|
Requires-Dist: orjson<4,>=3.10
|
|
23
23
|
Requires-Dist: pyarrow>=20.0.0
|
|
@@ -39,7 +39,7 @@ Requires-Dist: swagger-spec-validator; extra == "contribs"
|
|
|
39
39
|
Requires-Dist: cachetools; extra == "contribs"
|
|
40
40
|
Provides-Extra: all
|
|
41
41
|
Requires-Dist: custodian; extra == "all"
|
|
42
|
-
Requires-Dist: emmet-core[all]<0.
|
|
42
|
+
Requires-Dist: emmet-core[all]<0.87.2,>=0.87.0rc2; extra == "all"
|
|
43
43
|
Requires-Dist: fastmcp; extra == "all"
|
|
44
44
|
Requires-Dist: flask; extra == "all"
|
|
45
45
|
Provides-Extra: test
|
|
@@ -4,6 +4,8 @@
|
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
+
from enum import Enum
|
|
8
|
+
|
|
7
9
|
try:
|
|
8
10
|
import pytest
|
|
9
11
|
except ImportError as exc:
|
|
@@ -86,19 +88,64 @@ def client_search_testing(
|
|
|
86
88
|
assert doc[alt_name_dict.get(param, param)] is not None
|
|
87
89
|
|
|
88
90
|
|
|
89
|
-
def client_pagination(
|
|
90
|
-
|
|
91
|
-
|
|
91
|
+
def client_pagination(
|
|
92
|
+
search_method: Callable, id_name: str, additional_fields: list[str] | None = None
|
|
93
|
+
) -> None:
|
|
94
|
+
"""Test pagination on an endpoint.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
search_method (Callable) : Client search method to use
|
|
98
|
+
id_name (str) : the name of a field which uniquely indexes a series of documents
|
|
99
|
+
additional_fields (list of str) : Optional other fields to retrieve.
|
|
100
|
+
|
|
101
|
+
Raises:
|
|
102
|
+
AssertionError if pagination does not result in unique sets of documents
|
|
103
|
+
"""
|
|
104
|
+
fields = [id_name, *(additional_fields or [])]
|
|
105
|
+
page_1 = search_method(_page=1, chunk_size=NUM_DOCS, fields=fields)
|
|
106
|
+
page_2 = search_method(_page=2, chunk_size=NUM_DOCS, fields=fields)
|
|
92
107
|
assert all(len(results) == NUM_DOCS for results in (page_1, page_2))
|
|
93
108
|
assert {str(getattr(doc, id_name)) for doc in page_1}.intersection(
|
|
94
109
|
{str(getattr(doc, id_name)) for doc in page_2}
|
|
95
110
|
) == set()
|
|
96
111
|
|
|
97
112
|
|
|
98
|
-
def client_sort(
|
|
113
|
+
def client_sort(
|
|
114
|
+
search_method: Callable,
|
|
115
|
+
sort_fields: str | Sequence[str],
|
|
116
|
+
aux_query: dict[str, Any] | None = None,
|
|
117
|
+
default_fields: tuple[str, ...] = ("deprecated", "material_id"),
|
|
118
|
+
):
|
|
119
|
+
"""Test sorting on an endpoint.
|
|
120
|
+
|
|
121
|
+
Args:
|
|
122
|
+
search_method (Callable) : Client search method to use
|
|
123
|
+
sort_fields (str or Sequence of str) : fields to sort on
|
|
124
|
+
aux_query (dict) : auxiliary query needed to filter documents
|
|
125
|
+
default_fields (tuple of str) : default fields to return
|
|
126
|
+
|
|
127
|
+
Raises:
|
|
128
|
+
AssertionError if sorting in ascending or descending order does not work.
|
|
129
|
+
"""
|
|
130
|
+
|
|
131
|
+
def _normalize(doc, field: str):
|
|
132
|
+
v = getattr(doc, field)
|
|
133
|
+
# serialize enums
|
|
134
|
+
return v.value if isinstance(v, Enum) else v
|
|
135
|
+
|
|
136
|
+
user_query = {
|
|
137
|
+
k: v
|
|
138
|
+
for k, v in (aux_query or {}).items()
|
|
139
|
+
if k not in ("_page", "_sort_fields", "chunk_size", "fields")
|
|
140
|
+
}
|
|
99
141
|
for sort_field in [sort_fields] if isinstance(sort_fields, str) else sort_fields:
|
|
142
|
+
|
|
100
143
|
asc = search_method(
|
|
101
|
-
_page=1,
|
|
144
|
+
_page=1,
|
|
145
|
+
_sort_fields=sort_field,
|
|
146
|
+
chunk_size=NUM_DOCS,
|
|
147
|
+
fields=[sort_field, *default_fields],
|
|
148
|
+
**user_query,
|
|
102
149
|
)
|
|
103
150
|
desc = search_method(
|
|
104
151
|
_page=1,
|
|
@@ -108,12 +155,12 @@ def client_sort(search_method: Callable, sort_fields: str | Sequence[str]):
|
|
|
108
155
|
)
|
|
109
156
|
|
|
110
157
|
idxs = list(range(NUM_DOCS))
|
|
111
|
-
assert sorted(idxs, key=lambda idx:
|
|
158
|
+
assert sorted(idxs, key=lambda idx: _normalize(asc[idx], sort_field)) == idxs
|
|
112
159
|
|
|
113
160
|
assert (
|
|
114
161
|
sorted(
|
|
115
162
|
idxs,
|
|
116
|
-
key=lambda idx:
|
|
163
|
+
key=lambda idx: _normalize(desc[idx], sort_field),
|
|
117
164
|
reverse=True,
|
|
118
165
|
)
|
|
119
166
|
== idxs
|
|
@@ -42,6 +42,7 @@ from requests.exceptions import RequestException
|
|
|
42
42
|
from tqdm.auto import tqdm
|
|
43
43
|
from urllib3.util.retry import Retry
|
|
44
44
|
|
|
45
|
+
from mp_api.client._server_utils import get_consumer, get_user_api_key, is_dev_env
|
|
45
46
|
from mp_api.client.core.exceptions import (
|
|
46
47
|
MPRestError,
|
|
47
48
|
MPRestWarning,
|
|
@@ -52,7 +53,6 @@ from mp_api.client.core.settings import MAPI_CLIENT_SETTINGS
|
|
|
52
53
|
from mp_api.client.core.utils import (
|
|
53
54
|
MPDataset,
|
|
54
55
|
load_json,
|
|
55
|
-
validate_api_key,
|
|
56
56
|
validate_endpoint,
|
|
57
57
|
validate_ids,
|
|
58
58
|
)
|
|
@@ -68,6 +68,17 @@ try:
|
|
|
68
68
|
except PackageNotFoundError: # pragma: no cover
|
|
69
69
|
__version__ = os.getenv("SETUPTOOLS_SCM_PRETEND_VERSION", "")
|
|
70
70
|
|
|
71
|
+
STATIC_COLLECTIONS = [
|
|
72
|
+
"eos",
|
|
73
|
+
"grain_boundaries",
|
|
74
|
+
"jcesr",
|
|
75
|
+
"molecules",
|
|
76
|
+
"phonon",
|
|
77
|
+
"snls",
|
|
78
|
+
"surface-properties",
|
|
79
|
+
"synth-descriptions",
|
|
80
|
+
"xas",
|
|
81
|
+
]
|
|
71
82
|
|
|
72
83
|
hdlr = logging.StreamHandler()
|
|
73
84
|
fmt = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
|
|
@@ -86,33 +97,52 @@ def _batched(iterable: Iterable, n: int) -> Iterator:
|
|
|
86
97
|
yield batch
|
|
87
98
|
|
|
88
99
|
|
|
89
|
-
class
|
|
90
|
-
"""Base client class with core stubs."""
|
|
100
|
+
class QueryBuilderWithCache(QueryBuilder):
|
|
91
101
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
102
|
+
def __init__(self) -> None:
|
|
103
|
+
"""Extend deltalake.QueryBuilder with stored DeltaTables.
|
|
104
|
+
|
|
105
|
+
The deltalake.QueryBuilder class does not permit introspection
|
|
106
|
+
of registered DeltaTables through the python API.
|
|
107
|
+
|
|
108
|
+
Re-registering a DeltaTable
|
|
109
|
+
(1) wastes time by reading its metadata
|
|
110
|
+
(2) raises an exception because a table is already registered
|
|
111
|
+
|
|
112
|
+
This class simply allows for caching the DeltaTable instances
|
|
113
|
+
and table names on the QueryBuilder class.
|
|
114
|
+
"""
|
|
115
|
+
# Dict of table names (labels) to DeltaTable instances
|
|
116
|
+
self._delta_tables: dict[str, DeltaTable] = {}
|
|
117
|
+
super().__init__()
|
|
118
|
+
|
|
119
|
+
def register(self, table_name: str, delta_table: DeltaTable) -> QueryBuilder:
|
|
120
|
+
"""Register and cache a DeltaTable."""
|
|
121
|
+
self._delta_tables[table_name] = delta_table
|
|
122
|
+
return super().register(table_name, delta_table)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class _Rester:
|
|
126
|
+
"""Define base attributes of a REST client."""
|
|
96
127
|
|
|
97
128
|
def __init__(
|
|
98
129
|
self,
|
|
99
130
|
api_key: str | None = None,
|
|
100
131
|
endpoint: str | None = None,
|
|
101
132
|
include_user_agent: bool = True,
|
|
102
|
-
session: requests.Session | None = None,
|
|
103
|
-
s3_client: Any | None = None,
|
|
104
|
-
debug: bool = False,
|
|
105
133
|
use_document_model: bool = True,
|
|
106
|
-
|
|
134
|
+
session: requests.Session | None = None,
|
|
107
135
|
headers: dict | None = None,
|
|
108
136
|
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
|
|
137
|
+
db_version: str | None = None,
|
|
109
138
|
local_dataset_cache: (
|
|
110
139
|
str | os.PathLike
|
|
111
140
|
) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE,
|
|
112
141
|
force_renew: bool = False,
|
|
142
|
+
query_builder: QueryBuilderWithCache | None = None,
|
|
113
143
|
**kwargs,
|
|
114
|
-
):
|
|
115
|
-
"""Initialize
|
|
144
|
+
) -> None:
|
|
145
|
+
"""Initialize a RESTer.
|
|
116
146
|
|
|
117
147
|
Arguments:
|
|
118
148
|
api_key: A String API key for accessing the MaterialsProject
|
|
@@ -131,49 +161,56 @@ class BaseRester:
|
|
|
131
161
|
making the API request. This helps MP support pymatgen users, and
|
|
132
162
|
is similar to what most web browsers send with each page request.
|
|
133
163
|
Set to False to disable the user agent.
|
|
134
|
-
session: requests Session object with which to connect to the API, for
|
|
135
|
-
advanced usage only.
|
|
136
|
-
s3_client: boto3 S3 client object with which to connect to the object stores.ct to the object stores.ct to the object stores.
|
|
137
|
-
debug: if True, print the URL for every request
|
|
138
164
|
use_document_model: If False, skip creating the document model and return data
|
|
139
165
|
as a dictionary. This can be simpler to work with but bypasses data validation
|
|
140
166
|
and will not give auto-complete for available fields.
|
|
141
|
-
|
|
167
|
+
session: requests Session object with which to connect to the API, for
|
|
168
|
+
advanced usage only.
|
|
142
169
|
headers: Custom headers for localhost connections.
|
|
143
170
|
mute_progress_bars: Whether to disable progress bars.
|
|
171
|
+
db_version (str) : EXPERIMENTAL, allows for accessing a different version of the database
|
|
172
|
+
than what is currently deployed. The Materials Project cannot guarantee that all
|
|
173
|
+
features will still work.
|
|
144
174
|
local_dataset_cache: Target directory for downloading full datasets. Defaults
|
|
145
175
|
to 'mp_datasets' in the user's home directory
|
|
146
176
|
force_renew: Option to overwrite existing local dataset
|
|
177
|
+
query_builder : Instance of QueryBuilderWithCache to use in querying delta tables
|
|
178
|
+
NOTE: Must be a QueryBuilderWithCache, a deltalake.QueryBuilder will be ignored.
|
|
147
179
|
**kwargs: access to legacy kwargs that may be in the process of being deprecated
|
|
148
180
|
"""
|
|
149
|
-
self.api_key =
|
|
150
|
-
self.
|
|
151
|
-
self.endpoint = validate_endpoint(endpoint, suffix=self.suffix)
|
|
181
|
+
self.api_key = get_user_api_key(api_key=api_key)
|
|
182
|
+
self.endpoint = validate_endpoint(endpoint)
|
|
152
183
|
|
|
153
|
-
self.debug = debug
|
|
154
184
|
self.include_user_agent = include_user_agent
|
|
155
185
|
self.use_document_model = use_document_model
|
|
156
|
-
self.timeout = timeout
|
|
157
|
-
self.headers = headers or {}
|
|
158
|
-
self.mute_progress_bars = mute_progress_bars
|
|
159
186
|
|
|
160
|
-
(
|
|
161
|
-
|
|
162
|
-
self.
|
|
163
|
-
|
|
187
|
+
self.headers = headers or get_consumer()
|
|
188
|
+
self._session = session or _Rester._create_session(
|
|
189
|
+
api_key=self.api_key,
|
|
190
|
+
include_user_agent=self.include_user_agent,
|
|
191
|
+
headers=self.headers,
|
|
192
|
+
)
|
|
164
193
|
|
|
165
|
-
|
|
166
|
-
|
|
194
|
+
if is_dev_env():
|
|
195
|
+
self._session.headers["x-api-key"] = self.api_key or ""
|
|
167
196
|
|
|
168
|
-
self.
|
|
169
|
-
self.
|
|
197
|
+
self.use_document_model = use_document_model
|
|
198
|
+
self.mute_progress_bars = mute_progress_bars
|
|
199
|
+
self.db_version: str = db_version or ""
|
|
200
|
+
self.local_dataset_cache = Path(local_dataset_cache)
|
|
201
|
+
self.force_renew = force_renew
|
|
202
|
+
self._query_builder = (
|
|
203
|
+
query_builder if isinstance(query_builder, QueryBuilderWithCache) else None
|
|
204
|
+
)
|
|
170
205
|
|
|
171
206
|
if "monty_decode" in kwargs:
|
|
207
|
+
# Pop to not repeatedly trigger warning to the user
|
|
208
|
+
kwargs.pop("monty_decode", None)
|
|
172
209
|
warnings.warn(
|
|
173
210
|
"Ignoring `monty_decode`, as it is no longer a supported option in `mp_api`."
|
|
174
211
|
"The client by default returns results consistent with `monty_decode=True`.",
|
|
175
|
-
category=MPRestWarning,
|
|
176
212
|
stacklevel=2,
|
|
213
|
+
category=MPRestWarning,
|
|
177
214
|
)
|
|
178
215
|
|
|
179
216
|
@property
|
|
@@ -185,13 +222,10 @@ class BaseRester:
|
|
|
185
222
|
return self._session
|
|
186
223
|
|
|
187
224
|
@property
|
|
188
|
-
def
|
|
189
|
-
if not self.
|
|
190
|
-
self.
|
|
191
|
-
|
|
192
|
-
config=Config(signature_version=UNSIGNED), # type: ignore
|
|
193
|
-
)
|
|
194
|
-
return self._s3_client
|
|
225
|
+
def query_builder(self):
|
|
226
|
+
if not self._query_builder:
|
|
227
|
+
self._query_builder = QueryBuilderWithCache()
|
|
228
|
+
return self._query_builder
|
|
195
229
|
|
|
196
230
|
@staticmethod
|
|
197
231
|
def _create_session(api_key, include_user_agent, headers):
|
|
@@ -270,6 +304,112 @@ class BaseRester:
|
|
|
270
304
|
response = get_resp.json()
|
|
271
305
|
return response["db_version"], response["access_controlled_batch_ids"]
|
|
272
306
|
|
|
307
|
+
|
|
308
|
+
class BaseRester(_Rester):
|
|
309
|
+
"""Base client class with core stubs."""
|
|
310
|
+
|
|
311
|
+
suffix: str = ""
|
|
312
|
+
document_model: type[BaseModel] = _DictLikeAccess
|
|
313
|
+
primary_key: str = "material_id"
|
|
314
|
+
delta_backed: bool = True
|
|
315
|
+
|
|
316
|
+
def __init__(
|
|
317
|
+
self,
|
|
318
|
+
api_key: str | None = None,
|
|
319
|
+
endpoint: str | None = None,
|
|
320
|
+
include_user_agent: bool = True,
|
|
321
|
+
use_document_model: bool = True,
|
|
322
|
+
session: requests.Session | None = None,
|
|
323
|
+
headers: dict | None = None,
|
|
324
|
+
mute_progress_bars: bool = MAPI_CLIENT_SETTINGS.MUTE_PROGRESS_BARS,
|
|
325
|
+
db_version: str | None = None,
|
|
326
|
+
local_dataset_cache: (
|
|
327
|
+
str | os.PathLike
|
|
328
|
+
) = MAPI_CLIENT_SETTINGS.LOCAL_DATASET_CACHE,
|
|
329
|
+
force_renew: bool = False,
|
|
330
|
+
query_builder: QueryBuilderWithCache | None = None,
|
|
331
|
+
s3_client: Any | None = None,
|
|
332
|
+
timeout: int = 20,
|
|
333
|
+
**kwargs,
|
|
334
|
+
):
|
|
335
|
+
"""Initialize the REST API helper class.
|
|
336
|
+
|
|
337
|
+
s3_client: boto3 S3 client object with which to connect to the object stores.
|
|
338
|
+
timeout: Time in seconds to wait until a request timeout error is thrown
|
|
339
|
+
|
|
340
|
+
Arguments:
|
|
341
|
+
api_key: A String API key for accessing the MaterialsProject
|
|
342
|
+
REST interface. Please obtain your API key at
|
|
343
|
+
https://www.materialsproject.org/dashboard. If this is None,
|
|
344
|
+
the code will check if there is a "PMG_MAPI_KEY" setting.
|
|
345
|
+
If so, it will use that environment variable. This makes it
|
|
346
|
+
easier for heavy users to simply add this environment variable to
|
|
347
|
+
their setups and MPRester can then be called without any arguments.
|
|
348
|
+
endpoint: Url of endpoint to access the MaterialsProject REST
|
|
349
|
+
interface. Defaults to the standard Materials Project REST
|
|
350
|
+
address at "https://api.materialsproject.org", but
|
|
351
|
+
can be changed to other urls implementing a similar interface.
|
|
352
|
+
include_user_agent: If True, will include a user agent with the
|
|
353
|
+
HTTP request including information on pymatgen and system version
|
|
354
|
+
making the API request. This helps MP support pymatgen users, and
|
|
355
|
+
is similar to what most web browsers send with each page request.
|
|
356
|
+
Set to False to disable the user agent.
|
|
357
|
+
session: requests Session object with which to connect to the API, for
|
|
358
|
+
advanced usage only.
|
|
359
|
+
use_document_model: If False, skip creating the document model and return data
|
|
360
|
+
as a dictionary. This can be simpler to work with but bypasses data validation
|
|
361
|
+
and will not give auto-complete for available fields.
|
|
362
|
+
headers: Custom headers for localhost connections.
|
|
363
|
+
mute_progress_bars: Whether to disable progress bars.
|
|
364
|
+
db_version (str) : EXPERIMENTAL, allows for accessing a different version of the database
|
|
365
|
+
than what is currently deployed. The Materials Project cannot guarantee that all
|
|
366
|
+
features will still work.
|
|
367
|
+
local_dataset_cache: Target directory for downloading full datasets. Defaults
|
|
368
|
+
to 'mp_datasets' in the user's home directory
|
|
369
|
+
force_renew: Option to overwrite existing local dataset
|
|
370
|
+
query_builder : Instance of QueryBuilderWithCache to use in querying delta tables
|
|
371
|
+
NOTE: Must be a QueryBuilderWithCache, a deltalake.QueryBuilder will be ignored.
|
|
372
|
+
s3_client: boto3 S3 client object with which to connect to the object stores.
|
|
373
|
+
timeout: Time in seconds to wait until a request timeout error is thrown
|
|
374
|
+
**kwargs: access to legacy kwargs that may be in the process of being deprecated
|
|
375
|
+
"""
|
|
376
|
+
super().__init__(
|
|
377
|
+
api_key=api_key,
|
|
378
|
+
endpoint=endpoint,
|
|
379
|
+
include_user_agent=include_user_agent,
|
|
380
|
+
use_document_model=use_document_model,
|
|
381
|
+
session=session,
|
|
382
|
+
headers=headers,
|
|
383
|
+
mute_progress_bars=mute_progress_bars,
|
|
384
|
+
db_version=db_version,
|
|
385
|
+
local_dataset_cache=local_dataset_cache,
|
|
386
|
+
force_renew=force_renew,
|
|
387
|
+
query_builder=query_builder,
|
|
388
|
+
**kwargs,
|
|
389
|
+
)
|
|
390
|
+
|
|
391
|
+
self.base_endpoint = validate_endpoint(endpoint)
|
|
392
|
+
self.endpoint = validate_endpoint(endpoint, suffix=self.suffix)
|
|
393
|
+
|
|
394
|
+
(
|
|
395
|
+
hb_db_version,
|
|
396
|
+
self.access_controlled_batch_ids,
|
|
397
|
+
) = self._get_heartbeat_info(self.base_endpoint)
|
|
398
|
+
if not self.db_version:
|
|
399
|
+
self.db_version = hb_db_version
|
|
400
|
+
|
|
401
|
+
self.timeout = timeout
|
|
402
|
+
self._s3_client = s3_client
|
|
403
|
+
|
|
404
|
+
@property
|
|
405
|
+
def s3_client(self):
|
|
406
|
+
if not self._s3_client:
|
|
407
|
+
self._s3_client = boto3.client(
|
|
408
|
+
"s3",
|
|
409
|
+
config=Config(signature_version=UNSIGNED), # type: ignore
|
|
410
|
+
)
|
|
411
|
+
return self._s3_client
|
|
412
|
+
|
|
273
413
|
def _post_resource(
|
|
274
414
|
self,
|
|
275
415
|
body: dict | None = None,
|
|
@@ -440,18 +580,120 @@ class BaseRester:
|
|
|
440
580
|
|
|
441
581
|
return decoded_data, len(decoded_data) # type: ignore
|
|
442
582
|
|
|
583
|
+
def _get_delta_table(
|
|
584
|
+
self,
|
|
585
|
+
bucket: str,
|
|
586
|
+
prefix: str,
|
|
587
|
+
connector: str = "s3a",
|
|
588
|
+
label: str | None = None,
|
|
589
|
+
) -> tuple[str, DeltaTable]:
|
|
590
|
+
"""Either create a new DeltaTable, or retrieve a cached one.
|
|
591
|
+
|
|
592
|
+
If creating a new DeltaTable, will also register in self.query_builder
|
|
593
|
+
|
|
594
|
+
Args:
|
|
595
|
+
bucket (str) : name of the bucket in S3
|
|
596
|
+
prefix (str) : name of the prefix in S3
|
|
597
|
+
connector (str) : s3, s3n, s3a (default), or other
|
|
598
|
+
valid Hadoop connector string.
|
|
599
|
+
label (str or None) : optional label for the table in the
|
|
600
|
+
cached query builder
|
|
601
|
+
If `None`, will be gleaned from the URI
|
|
602
|
+
|
|
603
|
+
Returns:
|
|
604
|
+
str : the table name in the stored query builder
|
|
605
|
+
DeltaTable : If one exists at the specified bucket / prefix,
|
|
606
|
+
will retrieve the cached instance.
|
|
607
|
+
"""
|
|
608
|
+
delta_timeout = f"{self.timeout}s"
|
|
609
|
+
full_key = f"{bucket}/{prefix}"
|
|
610
|
+
qb_label = label or full_key.replace("/", "_").replace("-", "_")
|
|
611
|
+
|
|
612
|
+
uri = f"{connector}://{full_key}"
|
|
613
|
+
if not uri.endswith("/"):
|
|
614
|
+
uri += "/"
|
|
615
|
+
|
|
616
|
+
try:
|
|
617
|
+
stored_label, delta_table = next(
|
|
618
|
+
(_label, _table)
|
|
619
|
+
for _label, _table in self.query_builder._delta_tables.items()
|
|
620
|
+
if _table.table_uri == uri
|
|
621
|
+
)
|
|
622
|
+
except StopIteration:
|
|
623
|
+
stored_label = None
|
|
624
|
+
|
|
625
|
+
if stored_label is None:
|
|
626
|
+
delta_table = DeltaTable(
|
|
627
|
+
uri,
|
|
628
|
+
storage_options={
|
|
629
|
+
"AWS_SKIP_SIGNATURE": "true",
|
|
630
|
+
"AWS_REGION": "us-east-1",
|
|
631
|
+
"timeout": delta_timeout,
|
|
632
|
+
"connect_timeout": delta_timeout,
|
|
633
|
+
"retry_delay": "3",
|
|
634
|
+
"max_retries": f"{MAPI_CLIENT_SETTINGS.MAX_RETRIES}",
|
|
635
|
+
},
|
|
636
|
+
)
|
|
637
|
+
self.query_builder.register(qb_label, delta_table)
|
|
638
|
+
|
|
639
|
+
elif stored_label != qb_label:
|
|
640
|
+
warnings.warn(
|
|
641
|
+
f"DeltaTable with URI {uri} already found with different label: "
|
|
642
|
+
f"Stored label = {stored_label}; submitted label {qb_label}. "
|
|
643
|
+
"Using stored DeltaTable.",
|
|
644
|
+
category=MPRestWarning,
|
|
645
|
+
stacklevel=2,
|
|
646
|
+
)
|
|
647
|
+
return stored_label, delta_table
|
|
648
|
+
|
|
649
|
+
return qb_label, delta_table
|
|
650
|
+
|
|
651
|
+
def _query_delta_single(self, query: str) -> pa.Table:
|
|
652
|
+
"""Execute a SQL query against a registered Delta table.
|
|
653
|
+
|
|
654
|
+
Wraps the query execution in a try/except to provide a more
|
|
655
|
+
actionable error message when the underlying Delta query engine
|
|
656
|
+
fails (e.g., due to network timeouts, missing tables, or
|
|
657
|
+
malformed queries).
|
|
658
|
+
|
|
659
|
+
Args:
|
|
660
|
+
query (str): A SQL query string compatible with the
|
|
661
|
+
QueryBuilder engine.
|
|
662
|
+
|
|
663
|
+
Returns:
|
|
664
|
+
pa.Table: The query result as a PyArrow Table.
|
|
665
|
+
|
|
666
|
+
Raises:
|
|
667
|
+
MPRestError: If query execution fails for any reason,
|
|
668
|
+
including network timeouts, connectivity issues, or
|
|
669
|
+
invalid queries. Inspect the chained exception for
|
|
670
|
+
the underlying cause.
|
|
671
|
+
"""
|
|
672
|
+
try:
|
|
673
|
+
return pa.table(self.query_builder.execute(query).read_all())
|
|
674
|
+
except Exception as e:
|
|
675
|
+
raise MPRestError(
|
|
676
|
+
f"Failed to retrieve object due to: {e}. "
|
|
677
|
+
f"If this is a timeout error, try increasing the 'timeout' "
|
|
678
|
+
f"parameter on MPRester (current value: {self.timeout}s)."
|
|
679
|
+
) from e
|
|
680
|
+
|
|
443
681
|
def _query_delta_backed(
|
|
444
682
|
self,
|
|
445
683
|
bucket: str,
|
|
446
684
|
prefix: str,
|
|
685
|
+
access_controlled: bool = True,
|
|
447
686
|
timeout: int | None = None,
|
|
687
|
+
label: str | None = None,
|
|
448
688
|
) -> dict[str, Any]:
|
|
449
689
|
"""Retrieve data from S3 backed by a DeltaTable.
|
|
450
690
|
|
|
451
691
|
Args:
|
|
452
692
|
bucket (str) : S3 OpenData bucket
|
|
453
693
|
prefix (str) : S3 object prefix
|
|
694
|
+
access_controlled (bool): whether or not table has access controlled data
|
|
454
695
|
timeout (int or None) : timeout on getting access-controlled groups
|
|
696
|
+
label (str or None) : label of the table in QueryBuilder
|
|
455
697
|
|
|
456
698
|
Returns:
|
|
457
699
|
dict of str to Any
|
|
@@ -508,13 +750,7 @@ class BaseRester:
|
|
|
508
750
|
)
|
|
509
751
|
}
|
|
510
752
|
|
|
511
|
-
tbl =
|
|
512
|
-
f"s3a://{bucket}/{prefix}",
|
|
513
|
-
storage_options={
|
|
514
|
-
"AWS_SKIP_SIGNATURE": "true",
|
|
515
|
-
"AWS_REGION": "us-east-1",
|
|
516
|
-
},
|
|
517
|
-
)
|
|
753
|
+
tbl_lbl, tbl = self._get_delta_table(bucket, prefix, label=label)
|
|
518
754
|
|
|
519
755
|
controlled_batch_str = ",".join(
|
|
520
756
|
[f"'{tag}'" for tag in self.access_controlled_batch_ids]
|
|
@@ -522,19 +758,23 @@ class BaseRester:
|
|
|
522
758
|
|
|
523
759
|
predicate = (
|
|
524
760
|
f"WHERE batch_id NOT IN ({controlled_batch_str})"
|
|
525
|
-
if not has_gnome_access
|
|
761
|
+
if not has_gnome_access and controlled_batch_str and access_controlled
|
|
526
762
|
else ""
|
|
527
763
|
)
|
|
528
|
-
|
|
529
|
-
|
|
764
|
+
# TODO: do we need something like this?
|
|
765
|
+
# predicate += f"{' AND ' if predicate else 'WHERE '}version='{self.db_version}'"
|
|
530
766
|
|
|
531
767
|
# Setup progress bar
|
|
532
768
|
num_docs_needed: int = tbl.count()
|
|
533
769
|
|
|
534
770
|
if not has_gnome_access:
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
771
|
+
try:
|
|
772
|
+
num_docs_needed = self.count(
|
|
773
|
+
{"batch_id_neq_any": self.access_controlled_batch_ids}
|
|
774
|
+
)
|
|
775
|
+
except MPRestError:
|
|
776
|
+
# batch_id isn't a valid field
|
|
777
|
+
num_docs_needed = self.count()
|
|
538
778
|
|
|
539
779
|
pbar = (
|
|
540
780
|
tqdm(
|
|
@@ -549,7 +789,7 @@ class BaseRester:
|
|
|
549
789
|
else None
|
|
550
790
|
)
|
|
551
791
|
|
|
552
|
-
iterator =
|
|
792
|
+
iterator = self.query_builder.execute(f"SELECT * FROM {tbl_lbl} {predicate}")
|
|
553
793
|
|
|
554
794
|
file_options = ds.ParquetFileFormat().make_write_options(compression="zstd")
|
|
555
795
|
|
|
@@ -695,14 +935,21 @@ class BaseRester:
|
|
|
695
935
|
|
|
696
936
|
if "tasks" in suffix:
|
|
697
937
|
bucket_suffix, prefix = ("parsed", "core/tasks/")
|
|
938
|
+
elif suffix in STATIC_COLLECTIONS:
|
|
939
|
+
bucket_suffix = "build"
|
|
940
|
+
prefix = f"static-collections/{suffix}"
|
|
698
941
|
else:
|
|
942
|
+
# TODO: remove once all collections are migrated to delta-backed format
|
|
699
943
|
bucket_suffix = "build"
|
|
700
|
-
prefix = f"collections/{
|
|
944
|
+
prefix = f"collections/{suffix}"
|
|
701
945
|
|
|
702
946
|
bucket = f"materialsproject-{bucket_suffix}"
|
|
703
947
|
|
|
704
948
|
if self.delta_backed:
|
|
705
|
-
|
|
949
|
+
access_controlled = suffix not in STATIC_COLLECTIONS
|
|
950
|
+
return self._query_delta_backed(
|
|
951
|
+
bucket, prefix, access_controlled, timeout=timeout
|
|
952
|
+
)
|
|
706
953
|
|
|
707
954
|
# Paginate over all entries in the bucket.
|
|
708
955
|
# TODO: change when a subset of entries needed from DB
|
|
@@ -1448,8 +1695,10 @@ class CoreRester(BaseRester):
|
|
|
1448
1695
|
use_document_model=self.use_document_model,
|
|
1449
1696
|
headers=self.headers,
|
|
1450
1697
|
mute_progress_bars=self.mute_progress_bars,
|
|
1698
|
+
db_version=self.db_version,
|
|
1451
1699
|
local_dataset_cache=self.local_dataset_cache,
|
|
1452
1700
|
force_renew=self.force_renew,
|
|
1701
|
+
query_builder=self._query_builder,
|
|
1453
1702
|
)
|
|
1454
1703
|
return self.sub_resters[v]
|
|
1455
1704
|
raise AttributeError(f"{self.__class__} has no attribute {v}")
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
|
+
from functools import cached_property
|
|
5
6
|
from importlib import import_module
|
|
6
7
|
from itertools import chain
|
|
7
8
|
from typing import TYPE_CHECKING, ForwardRef, get_args
|
|
@@ -166,6 +167,12 @@ def _generate_returned_model(
|
|
|
166
167
|
data_model.__getattr__ = new_getattr
|
|
167
168
|
data_model.dict = new_dict
|
|
168
169
|
|
|
170
|
+
for attr in dir(document_model):
|
|
171
|
+
if isinstance(
|
|
172
|
+
prop_method := getattr(document_model, attr), property | cached_property
|
|
173
|
+
):
|
|
174
|
+
setattr(data_model, attr, prop_method)
|
|
175
|
+
|
|
169
176
|
return data_model, set_fields, fields_not_requested
|
|
170
177
|
|
|
171
178
|
|