graphatoms 1.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphatoms-1.1.0/PKG-INFO +26 -0
- graphatoms-1.1.0/README.md +3 -0
- graphatoms-1.1.0/pyproject.toml +125 -0
- graphatoms-1.1.0/src/graphatoms/__init__.py +8 -0
- graphatoms-1.1.0/src/graphatoms/dataclasses/__init__.py +13 -0
- graphatoms-1.1.0/src/graphatoms/dataclasses/_numpydantic.py +62 -0
- graphatoms-1.1.0/src/graphatoms/dataclasses/_pydantic2pyarrow.py +402 -0
- graphatoms-1.1.0/src/graphatoms/dataclasses/_pydanticMixin.py +449 -0
- graphatoms-1.1.0/src/graphatoms/dataclasses/_pydanticModel.py +160 -0
- graphatoms-1.1.0/src/graphatoms/dataclasses/test_dataclasses.py +64 -0
- graphatoms-1.1.0/src/graphatoms/geometry/__init__.py +11 -0
- graphatoms-1.1.0/src/graphatoms/geometry/_bond_list.py +168 -0
- graphatoms-1.1.0/src/graphatoms/geometry/_distance_pairs.py +306 -0
- graphatoms-1.1.0/src/graphatoms/geometry/_inner_outer.py +91 -0
- graphatoms-1.1.0/src/graphatoms/geometry/_neighbor_list.py +284 -0
- graphatoms-1.1.0/src/graphatoms/geometry/rotation.py +151 -0
- graphatoms-1.1.0/src/graphatoms/geometry/sample.py +79 -0
- graphatoms-1.1.0/src/graphatoms/geometry/test_geometry.py +84 -0
- graphatoms-1.1.0/src/graphatoms/reaction/__init__.py +0 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/__init__.py +3 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/_abc.py +12 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/_base.py +135 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/_moveFB.py +179 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/_moveSP.py +119 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/_moveSwap.py +119 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/_simNVT.py +179 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/test_mmc.py +57 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/test_move.py +23 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_amove/zzz.py +103 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_event/__init__.py +10 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_event/_base.py +203 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_event/_createMixin.py +14 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_event/_ioMixin.py +10 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_event/event.py +13 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_network/__init__.py +0 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_network/_asedb.py +140 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_network/_base.py +61 -0
- graphatoms-1.1.0/src/graphatoms/reaction/_network/_h5db.py +140 -0
- graphatoms-1.1.0/src/graphatoms/system/__init__.py +25 -0
- graphatoms-1.1.0/src/graphatoms/system/_atoms/__init__.py +5 -0
- graphatoms-1.1.0/src/graphatoms/system/_atoms/_box.py +85 -0
- graphatoms-1.1.0/src/graphatoms/system/_atoms/_eng.py +268 -0
- graphatoms-1.1.0/src/graphatoms/system/_atoms/_struct.py +167 -0
- graphatoms-1.1.0/src/graphatoms/system/_graph/__init__.py +4 -0
- graphatoms-1.1.0/src/graphatoms/system/_graph/_bonds.py +558 -0
- graphatoms-1.1.0/src/graphatoms/system/_graph/_gasMixin.py +45 -0
- graphatoms-1.1.0/src/graphatoms/system/_graph/_sysGraph.py +520 -0
- graphatoms-1.1.0/src/graphatoms/system/_other/_system.py +202 -0
- graphatoms-1.1.0/src/graphatoms/system/_sys/__init__.py +4 -0
- graphatoms-1.1.0/src/graphatoms/system/_sys/_sysCluster.py +124 -0
- graphatoms-1.1.0/src/graphatoms/system/_sys/_sysGas.py +181 -0
- graphatoms-1.1.0/src/graphatoms/system/_test_speed.py +17 -0
- graphatoms-1.1.0/src/graphatoms/system/test_system.py +310 -0
- graphatoms-1.1.0/src/graphatoms/utils/__init__.py +1 -0
- graphatoms-1.1.0/src/graphatoms/utils/bytestool.py +194 -0
- graphatoms-1.1.0/src/graphatoms/utils/logger.py +69 -0
- graphatoms-1.1.0/src/graphatoms/utils/parser.py +25 -0
- graphatoms-1.1.0/src/graphatoms/utils/rdutils.py +377 -0
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: graphatoms
|
|
3
|
+
Version: 1.1.0
|
|
4
|
+
Summary: The Chemical Core Class for Graph Theory Analysis.
|
|
5
|
+
Author: LiuGaoyong
|
|
6
|
+
Author-email: LiuGaoyong <liugaoyong_88@163.com>
|
|
7
|
+
Requires-Dist: ase
|
|
8
|
+
Requires-Dist: pymatgen>2023.6
|
|
9
|
+
Requires-Dist: rdkit>=2025
|
|
10
|
+
Requires-Dist: scikit-learn>=1.5
|
|
11
|
+
Requires-Dist: pyarrow
|
|
12
|
+
Requires-Dist: igraph>=0.11
|
|
13
|
+
Requires-Dist: hydra-core
|
|
14
|
+
Requires-Dist: numpydantic
|
|
15
|
+
Requires-Dist: pydantic>=2.11
|
|
16
|
+
Requires-Dist: python-snappy>=0.7.3
|
|
17
|
+
Requires-Dist: tomli ; python_full_version < '3.11'
|
|
18
|
+
Requires-Dist: tomli-w
|
|
19
|
+
Requires-Dist: typing-extensions
|
|
20
|
+
Requires-Dist: loguru
|
|
21
|
+
Requires-Python: >=3.10
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
|
|
24
|
+
# graphatoms
|
|
25
|
+
|
|
26
|
+
The Chemical Core Class for Graph Theory Analysis & Graph Neural Network.
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["uv_build>=0.11.15,<0.12"]
|
|
3
|
+
build-backend = "uv_build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "graphatoms"
|
|
7
|
+
version = "1.1.0"
|
|
8
|
+
readme = "README.md"
|
|
9
|
+
description = "The Chemical Core Class for Graph Theory Analysis."
|
|
10
|
+
authors = [{ name = "LiuGaoyong", email = "liugaoyong_88@163.com" }]
|
|
11
|
+
requires-python = ">=3.10"
|
|
12
|
+
dependencies = [
|
|
13
|
+
# Chemoinformatics
|
|
14
|
+
"ase",
|
|
15
|
+
"pymatgen>2023.6",
|
|
16
|
+
"rdkit>=2025",
|
|
17
|
+
|
|
18
|
+
# Scientific Computing
|
|
19
|
+
# numpy, pandas, scipy
|
|
20
|
+
"scikit-learn>=1.5",
|
|
21
|
+
"pyarrow",
|
|
22
|
+
|
|
23
|
+
# Graph Theory
|
|
24
|
+
"igraph>=0.11", # high performance graph library based on C
|
|
25
|
+
|
|
26
|
+
# Data Model
|
|
27
|
+
'hydra-core',
|
|
28
|
+
"numpydantic",
|
|
29
|
+
"pydantic>=2.11",
|
|
30
|
+
"python-snappy>=0.7.3", # faster data compression
|
|
31
|
+
"tomli; python_version < '3.11'", # use tomllib in python>=3.11
|
|
32
|
+
"tomli-w", # TOML writer
|
|
33
|
+
"typing-extensions",
|
|
34
|
+
'loguru',
|
|
35
|
+
]
|
|
36
|
+
[dependency-groups]
|
|
37
|
+
dev = [
|
|
38
|
+
"ruff>=0.11",
|
|
39
|
+
"pytest>=8.3",
|
|
40
|
+
"pytest-xdist",
|
|
41
|
+
"jupyter",
|
|
42
|
+
"sqlmodel", # SQL + pydantic
|
|
43
|
+
'openbabel-wheel',
|
|
44
|
+
]
|
|
45
|
+
|
|
46
|
+
# uv config
|
|
47
|
+
[tool.uv]
|
|
48
|
+
package = true
|
|
49
|
+
[[tool.uv.index]]
|
|
50
|
+
url = "https://mirrors.cernet.edu.cn/pypi/web/simple" # CERNET
|
|
51
|
+
default = true
|
|
52
|
+
# [[tool.uv.index]]
|
|
53
|
+
# name = "pytorch-cpu"
|
|
54
|
+
# url = "https://mirrors.nju.edu.cn/pytorch/whl/cpu/"
|
|
55
|
+
# explicit = true
|
|
56
|
+
# format = "flat"
|
|
57
|
+
# [tool.uv.sources]
|
|
58
|
+
# torch = [{ index = "pytorch-cpu" }]
|
|
59
|
+
|
|
60
|
+
# pytest config
|
|
61
|
+
[tool.pytest.ini_options]
|
|
62
|
+
filterwarnings = [
|
|
63
|
+
"error", # All other warnings are transformed into errors.
|
|
64
|
+
# ignore the following warnings, matched by regex
|
|
65
|
+
# example: 'ignore:function ham\(\) is deprecated:DeprecationWarning',
|
|
66
|
+
"ignore::RuntimeWarning",
|
|
67
|
+
|
|
68
|
+
]
|
|
69
|
+
addopts = '--maxfail=1 -rf' # stop after the first failure; list failed tests in the summary
|
|
70
|
+
testpaths = [
|
|
71
|
+
"src/graphatoms/dataclasses",
|
|
72
|
+
"src/graphatoms/geometry",
|
|
73
|
+
'src/graphatoms/system',
|
|
74
|
+
'src/tests',
|
|
75
|
+
]
|
|
76
|
+
|
|
77
|
+
# ruff config
|
|
78
|
+
[tool.ruff]
|
|
79
|
+
line-length = 80
|
|
80
|
+
indent-width = 4
|
|
81
|
+
select = ["F", "E", "W", "UP", "D"]
|
|
82
|
+
ignore = ["F722", "D100"]
|
|
83
|
+
[tool.ruff.lint.pydocstyle]
|
|
84
|
+
convention = "google"
|
|
85
|
+
|
|
86
|
+
[tool.pixi.package]
|
|
87
|
+
name = 'graphatoms'
|
|
88
|
+
version = '1.1.0'
|
|
89
|
+
|
|
90
|
+
[tool.pixi.package.build]
|
|
91
|
+
backend = { name = "pixi-build-python", version = "*" }
|
|
92
|
+
channels = ["https://prefix.dev/conda-forge"]
|
|
93
|
+
|
|
94
|
+
[tool.pixi.workspace]
|
|
95
|
+
preview = ["pixi-build"]
|
|
96
|
+
channels = ["conda-forge"]
|
|
97
|
+
platforms = ["win-64", "linux-64", "osx-arm64"]
|
|
98
|
+
|
|
99
|
+
[tool.pixi.pypi-dependencies]
|
|
100
|
+
graphatoms = { path = ".", editable = true }
|
|
101
|
+
|
|
102
|
+
[tool.pixi.environments]
|
|
103
|
+
default = { solve-group = "default" }
|
|
104
|
+
dev = { features = ["dev"], solve-group = "default" }
|
|
105
|
+
|
|
106
|
+
[tool.pixi.tasks]
|
|
107
|
+
test = "pytest -s -vv"
|
|
108
|
+
benchmark = 'pytest -s graphatoms/system/_test_speed.py'
|
|
109
|
+
|
|
110
|
+
[tool.pixi.dependencies]
|
|
111
|
+
python = ">=3.12"
|
|
112
|
+
pytest = ">=9.0.3,<10"
|
|
113
|
+
ase = ">=3.28.0,<4"
|
|
114
|
+
pymatgen = ">=2026.5.4,<2027"
|
|
115
|
+
rdkit = ">=2025.9.5,<2027"
|
|
116
|
+
openbabel = ">=3.1.1,<4"
|
|
117
|
+
pyarrow = ">=21.0.0,<25"
|
|
118
|
+
scikit-learn = ">=1.8.0,<2"
|
|
119
|
+
python-igraph = ">=0.11.9,<0.12"
|
|
120
|
+
numpydantic = ">=1.8.1,<2"
|
|
121
|
+
python-snappy = ">=0.7.3,<0.8"
|
|
122
|
+
h5py = ">=3.16.0,<4"
|
|
123
|
+
hydra-core = ">=1.3.2,<2"
|
|
124
|
+
tomli-w = ">=1.2.0,<2"
|
|
125
|
+
loguru = ">=0.7.3,<0.8"
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""The extended dataclasses by pydantic & numpydantic."""
|
|
2
|
+
|
|
3
|
+
from ._numpydantic import NDArray, numpy_validator
|
|
4
|
+
from ._pydantic2pyarrow import get_pyarrow_schema
|
|
5
|
+
from ._pydanticModel import OurBaseModel, OurFrozenModel
|
|
6
|
+
|
|
7
|
+
# Public API of this package; every name is imported above from one of the
# private submodules (_numpydantic, _pydantic2pyarrow, _pydanticModel).
__all__ = [
    "NDArray",
    "numpy_validator",
    "get_pyarrow_schema",
    "OurFrozenModel",
    "OurBaseModel",
]
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from functools import partial
|
|
3
|
+
from typing import TYPE_CHECKING, Annotated, Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from numpydantic import NDArray as _NDArray
|
|
7
|
+
from pydantic import (
|
|
8
|
+
BeforeValidator,
|
|
9
|
+
PlainSerializer,
|
|
10
|
+
WithJsonSchema,
|
|
11
|
+
validate_call,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
__all__ = ["NDArray", "numpy_validator"]
# latin-1 maps every byte value 0-255 to exactly one code point, so
# ``raw.decode(_ENCODING).encode(_ENCODING) == raw`` for any byte string.
_ENCODING = "latin1"

# Runtime annotation for numpy array fields: validated by numpydantic and
# serialized as the array's raw buffer (``tobytes()``) decoded to ``str``.
# Only the buffer is emitted -- dtype and shape are not -- so rebuilding the
# array needs them supplied separately, e.g. via ``numpy_validator``.
NDArray = Annotated[
    _NDArray,
    PlainSerializer(lambda x: x.tobytes().decode(_ENCODING), return_type=str),
    # Serialize to str rather than bytes for pyarrow compatibility.
    # NOTE: "type" deliberately holds the Python class ``str`` (not the
    # JSON-schema string "string"): get_pyarrow_schema reads this class to
    # choose the arrow column type. TODO(review): confirm that
    # model_json_schema() still works with a class object here.
    WithJsonSchema({"type": str}, mode="serialization"),
]
if TYPE_CHECKING:
    # Static type checkers see numpy's own NDArray alias instead.
    from numpy.typing import NDArray
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@validate_call
def __convert2numpy(
    x: Any,
    dtype: Any = "float",
    shape: Sequence[int] = (-1,),
) -> _NDArray:
    """Rebuild a numpy array from serialized text/bytes or an array-like.

    Text is first encoded with ``_ENCODING``. Text and bytes are then read as
    the raw buffer of ``dtype`` (``np.frombuffer``). Anything else goes through
    ``np.asarray``. The result is always reshaped to ``shape``.

    Args:
        x: A ``str``/``bytes`` buffer or an array-like object.
        dtype: Anything accepted by ``np.dtype``.
        shape: Target shape; the default ``(-1,)`` flattens to 1-D.

    Returns:
        The reconstructed array. Arrays read from a buffer are read-only
        views of that buffer.

    Raises:
        KeyError: If ``x`` is a scalar according to ``np.isscalar``.
            NOTE(review): pydantic only turns ValueError/AssertionError raised
            inside validators into ValidationError, so this KeyError escapes
            model validation unwrapped -- confirm that is intended.
    """
    if isinstance(x, str):
        x = x.encode(_ENCODING)
    if isinstance(x, bytes):
        buffer_dtype = np.dtype(dtype)
        return np.frombuffer(x, buffer_dtype).reshape(shape)
    if np.isscalar(x):
        raise KeyError(
            f"Invalid input value: {x}({type(x)})."
            " The scalar value is not supported."
        )
    return np.asarray(x, dtype=dtype).reshape(shape)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def numpy_validator(
    dtype: Any = "float",
    shape: Sequence[int] = (-1,),
) -> BeforeValidator:
    """Build a ``before``-mode pydantic validator that coerces to numpy.

    Args:
        dtype: dtype of the resulting array (anything ``np.dtype`` accepts).
        shape: Shape the array is reshaped to; ``(-1,)`` flattens it.

    Returns:
        A ``BeforeValidator`` that runs ``__convert2numpy`` with ``dtype`` and
        ``shape`` bound, for use inside ``Annotated[...]``.
    """
    converter = partial(__convert2numpy, dtype=dtype, shape=shape)
    return BeforeValidator(converter)
|
|
@@ -0,0 +1,402 @@
|
|
|
1
|
+
# ruff: noqa: E501
|
|
2
|
+
import datetime
|
|
3
|
+
import types
|
|
4
|
+
import uuid
|
|
5
|
+
from decimal import Decimal
|
|
6
|
+
from enum import EnumMeta
|
|
7
|
+
from typing import (
|
|
8
|
+
Annotated,
|
|
9
|
+
Any,
|
|
10
|
+
Literal,
|
|
11
|
+
NamedTuple,
|
|
12
|
+
TypeVar,
|
|
13
|
+
Union,
|
|
14
|
+
cast,
|
|
15
|
+
get_args,
|
|
16
|
+
get_origin,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
import pyarrow as pa # type: ignore
|
|
20
|
+
from annotated_types import Ge, Gt
|
|
21
|
+
from pyarrow import Schema, field, schema
|
|
22
|
+
from pydantic import AwareDatetime, BaseModel, NaiveDatetime, WithJsonSchema
|
|
23
|
+
from pydantic.fields import FieldInfo
|
|
24
|
+
|
|
25
|
+
# Type variable for the pydantic model classes accepted by the schema builders.
BaseModelType = TypeVar("BaseModelType", bound=BaseModel)
# Type variable bound to enum classes.
# NOTE(review): not referenced anywhere else in this module.
EnumType = TypeVar("EnumType", bound=EnumMeta)
# Only the public entry point is exported; the helpers are private.
__all__ = ["get_pyarrow_schema"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class SchemaCreationError(Exception):
    """Raised when a pydantic model cannot be turned into a pyarrow schema."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Settings(NamedTuple):
    """Conversion options passed through every schema-building helper."""

    # Map timezone-aware datetime types to a tz-less timestamp instead of
    # raising SchemaCreationError.
    allow_losing_tz: bool
    # Name columns after the field's serialization alias when one is set.
    by_alias: bool
    # Skip fields declared with ``Field(exclude=True)``.
    exclude_fields: bool
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Annotations that map to one fixed pyarrow type, independent of metadata.
# Lookup is by exact key, so subclasses of these types are not matched here.
# ``int`` and ``Decimal`` are absent because their arrow type depends on the
# field metadata (see TYPES_WITH_METADATA). Timezone-aware datetime types live
# in LOSING_TZ_TYPES. Naive datetimes map to millisecond timestamps.
FIELD_MAP = {
    str: pa.string(),
    bytes: pa.binary(),
    bool: pa.bool_(),
    float: pa.float64(),
    datetime.date: pa.date32(),
    NaiveDatetime: pa.timestamp("ms", tz=None),
    datetime.time: pa.time64("us"),
}
|
|
49
|
+
|
|
50
|
+
# Timezone aware datetimes will lose their timezone information
# (but be correctly converted to UTC) when converted to pyarrow.
# Pyarrow does support having an entire column in a single timezone,
# but these bare types cannot guarantee that.
# These entries are used only when Settings.allow_losing_tz is True;
# otherwise _get_pyarrow_type raises SchemaCreationError for them.
# Plain ``datetime.datetime`` is presumably listed because its values may
# be timezone-aware.
LOSING_TZ_TYPES = {
    datetime.datetime: pa.timestamp("ms", tz=None),
    AwareDatetime: pa.timestamp("ms", tz=None),
}
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def _get_int_type(metadata: list[Any]) -> pa.DataType:
    """Choose ``uint64`` or ``int64`` for an ``int`` field.

    ``uint64`` is used only when a ``Gt``/``Ge`` lower bound shows that valid
    values are never negative. When several bounds are present, the strictest
    (largest) one decides, whatever order they appear in. Bounds set to
    ``None`` are ignored.

    Args:
        metadata: Field metadata, e.g. the annotated-types constraints that
            ``Field(gt=..., ge=...)`` produces.

    Returns:
        ``pa.uint64()`` for a non-negative lower bound, else ``pa.int64()``.

    Raises:
        SchemaCreationError: If a ``Gt``/``Ge`` bound is set but not an int.
    """
    lower_bounds: list[int] = []
    for el in metadata:
        if isinstance(el, Gt):
            if el.gt is None:
                continue
            if not isinstance(el.gt, int):
                raise SchemaCreationError("Gt metadata must be int")
            lower_bounds.append(el.gt)
        elif isinstance(el, Ge):
            if el.ge is None:
                continue
            if not isinstance(el.ge, int):
                raise SchemaCreationError("Ge metadata must be int")
            lower_bounds.append(el.ge)

    # ``x > n`` or ``x >= n`` with ``n >= 0`` rules out negative values.
    if lower_bounds and max(lower_bounds) >= 0:
        return pa.uint64()
    return pa.int64()
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _get_decimal_type(metadata: list[Any]) -> pa.DataType:
    """Map a ``Decimal`` field to ``pa.decimal128(precision, scale)``.

    Precision and scale come from the metadata item carrying both
    ``max_digits`` and ``decimal_places``. If several do, the last one wins.

    Raises:
        SchemaCreationError: If no metadata item provides both attributes.
    """
    candidates = [
        item
        for item in metadata
        if hasattr(item, "max_digits") and hasattr(item, "decimal_places")
    ]
    if not candidates:
        raise SchemaCreationError(
            "Decimal type needs annotation setting max_digits and decimal_places"
        )
    chosen = candidates[-1]
    return pa.decimal128(chosen.max_digits, chosen.decimal_places)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
# Types whose arrow type depends on the field's constraints. Each handler is
# called with the field's metadata list: sign bounds decide the int width,
# and max_digits/decimal_places decide the Decimal precision and scale.
TYPES_WITH_METADATA = {
    Decimal: _get_decimal_type,
    int: _get_int_type,
}
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _get_literal_type(
    field_type: type[Any],
    _metadata: list[Any],
    _settings: Settings,
) -> pa.DataType:
    """Map a ``Literal[...]`` annotation to a pyarrow type.

    All-string literals become a dictionary-encoded string column, and all-int
    literals become ``int64``. ``bool`` values count as ints here because
    ``bool`` subclasses ``int``. ``_metadata`` and ``_settings`` are accepted
    only to match the ``FIELD_TYPES`` handler signature.

    Raises:
        SchemaCreationError: If the values are neither all str nor all int.
    """
    choices = get_args(field_type)

    def _every(kind: type) -> bool:
        return all(isinstance(choice, kind) for choice in choices)

    if _every(str):
        return pa.dictionary(pa.int32(), pa.string())
    if _every(int):
        # Dictionary of (int, int) is converted to just int when
        # written into parquet.
        return pa.int64()
    raise SchemaCreationError(
        "Literal type is only supported with all int or string values. "
    )
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _get_list_type(
    field_type: type[Any],
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType:
    """Map ``list[X]`` to ``pa.list_`` of the pyarrow type for ``X``.

    ``list[X | None]`` maps like ``list[X]`` because pyarrow list elements may
    be null. If the element union has several non-None members, the first one
    in declaration order is used.

    Args:
        field_type: A parameterised ``list`` annotation.
        metadata: Field metadata, passed on to the element type.
        settings: Conversion options passed through.

    Raises:
        SchemaCreationError: If the list has no element type (bare ``List``).
    """
    args = get_args(field_type)
    if not args:
        raise SchemaCreationError(
            f"List type {field_type} needs an element type"
        )
    sub_type = args[0]
    if _is_optional(sub_type):
        # Declaration order rather than a set keeps the choice deterministic.
        sub_type = next(
            arg for arg in get_args(sub_type) if arg is not type(None)
        )
    return pa.list_(_get_pyarrow_type(sub_type, metadata, settings))
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _get_annotated_type(
    field_type: type[Any],
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType:
    """Convert ``Annotated[T, *extras]`` by mapping ``T`` with the extras.

    An extra that has a ``.metadata`` attribute (e.g. a ``FieldInfo`` built by
    ``Field(...)``) contributes the items of that attribute. Any other extra
    is used as a metadata item itself. The incoming ``metadata`` argument is
    replaced, not merged. TODO(review): confirm dropping the outer metadata
    is intended.
    """
    base_type, *extras = get_args(field_type)
    collected: list[Any] = []
    for extra in extras:
        if hasattr(extra, "metadata"):
            collected.extend(extra.metadata)
        else:
            collected.append(extra)
    return _get_pyarrow_type(cast(type[Any], base_type), collected, settings)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _get_dict_type(
    field_type: type[Any],
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType:
    """Map ``dict[K, V]`` to ``pa.map_`` of the mapped key and value types.

    ``dict[K, V | None]`` maps like ``dict[K, V]`` because pyarrow map values
    may be null; the first non-None member in declaration order is used.
    Map keys may not be null in pyarrow, so an optional key type is still
    rejected by ``_get_pyarrow_type``.

    Args:
        field_type: A parameterised ``dict`` annotation.
        metadata: Field metadata, applied to both key and value types.
        settings: Conversion options passed through.

    Raises:
        SchemaCreationError: If the dict lacks key and value type arguments.
    """
    args = get_args(field_type)
    if len(args) != 2:
        raise SchemaCreationError(
            f"Dict type {field_type} needs key and value types"
        )
    key_type, value_type = args
    if _is_optional(value_type):
        value_type = next(
            arg for arg in get_args(value_type) if arg is not type(None)
        )
    return pa.map_(
        _get_pyarrow_type(key_type, metadata, settings),
        _get_pyarrow_type(value_type, metadata, settings),
    )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# Handlers for generic aliases, keyed by ``typing.get_origin(annotation)``.
# Each is called as ``handler(field_type, metadata, settings)``.
FIELD_TYPES = {
    Literal: _get_literal_type,
    list: _get_list_type,
    Annotated: _get_annotated_type,
    dict: _get_dict_type,
}
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _get_enum_type(field_type: type[Any]) -> pa.DataType:
    """Map an ``Enum`` class to a pyarrow type based on its member values.

    String-valued enums become a dictionary-encoded string column, and
    int-valued enums become ``int64``.

    Raises:
        SchemaCreationError: If member values are neither all str nor all int.
    """
    member_values = [member.value for member in field_type]  # type: ignore
    if all(isinstance(value, str) for value in member_values):
        return pa.dictionary(pa.int32(), pa.string())
    if all(isinstance(value, int) for value in member_values):
        return pa.int64()
    msg = "Enums only allowed if all str or all int"
    raise SchemaCreationError(msg)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _get_uuid_type() -> pa.DataType:
    """Return pyarrow's UUID type, which only newer pyarrow releases provide.

    Which branch runs depends on the installed pyarrow version, so neither
    branch is counted for coverage.

    Raises:
        SchemaCreationError: If this pyarrow has no ``pa.uuid``.
    """
    if not hasattr(pa, "uuid"):  # pragma: no cover
        raise SchemaCreationError(
            f"pyarrow version {pa.__version__} does not support pa.uuid() type, "
            "needs version 18.0 or higher"
        )
    return pa.uuid()  # pragma: no cover
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _is_optional(field_type: type[Any]) -> bool:
|
|
194
|
+
origin = get_origin(field_type)
|
|
195
|
+
is_python_39_union = origin is Union
|
|
196
|
+
is_python_310_union = (
|
|
197
|
+
hasattr(types, "UnionType") and origin is types.UnionType
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
if not is_python_39_union and not is_python_310_union:
|
|
201
|
+
return False
|
|
202
|
+
|
|
203
|
+
return type(None) in get_args(field_type)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
# PLR0911 (too many return statements) is silenced on the def line below
# until a refactoring can reduce the number of return statements.
def _get_pyarrow_type(  # noqa: PLR0911
    field_type: type[Any],
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType:
    """Map one type annotation to a pyarrow data type.

    Checks run in this order and the first match wins:
    1. plain types in ``FIELD_MAP``;
    2. ``uuid.UUID``;
    3. timezone-aware datetime types, allowed only when
       ``settings.allow_losing_tz``;
    4. enum classes;
    5. ``int``/``Decimal``, which need ``metadata``;
    6. generic aliases, dispatched on ``get_origin`` via ``FIELD_TYPES``;
    7. nested pydantic models, which become a ``pa.struct``.

    Union / ``Optional`` annotations have no branch here and end in
    "Unknown type". The schema builders and ``_get_list_type`` strip ``None``
    before calling.

    Args:
        field_type: The annotation to convert.
        metadata: Constraint objects attached to the field.
        settings: Conversion options.

    Returns:
        The pyarrow data type for ``field_type``.

    Raises:
        SchemaCreationError: For unsupported types, or for timezone-aware
            datetime types when ``settings.allow_losing_tz`` is False.
            Helpers may raise further errors for malformed annotations.
    """
    if field_type in FIELD_MAP:
        return FIELD_MAP[field_type]

    if field_type is uuid.UUID:
        return _get_uuid_type()

    if settings.allow_losing_tz and field_type in LOSING_TZ_TYPES:
        return LOSING_TZ_TYPES[field_type]

    if not settings.allow_losing_tz and field_type in LOSING_TZ_TYPES:
        raise SchemaCreationError(
            f"{field_type} only allowed if ok losing timezone information"
        )

    if isinstance(field_type, EnumMeta):
        return _get_enum_type(field_type)

    if field_type in TYPES_WITH_METADATA:
        return TYPES_WITH_METADATA[field_type](metadata)

    if get_origin(field_type) in FIELD_TYPES:
        return FIELD_TYPES[get_origin(field_type)](  # type: ignore
            field_type,
            metadata,
            settings,
        )

    # isinstance(field_type, type) checks whether it's a class;
    # otherwise e.g. Deque[int] would cause an exception in issubclass.
    if isinstance(field_type, type) and issubclass(field_type, BaseModel):
        return _get_pyarrow_schema(field_type, settings, as_schema=False)

    raise SchemaCreationError(f"Unknown type: {field_type}")
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _json_schema_override_type(
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType | None:
    """Return the arrow type forced by a serialization ``WithJsonSchema``.

    Only the first ``WithJsonSchema`` in ``metadata`` is considered. If it is
    in ``"serialization"`` mode and its ``"type"`` entry is a Python class
    (e.g. ``{"type": str}``), that class is mapped with ``_get_pyarrow_type``.
    Returns ``None`` when there is no such override or it cannot be mapped,
    so the caller falls back to the field annotation.
    """
    marker = next(
        (item for item in metadata if isinstance(item, WithJsonSchema)),
        None,
    )
    if (
        marker is None
        or marker.mode != "serialization"
        or marker.json_schema is None
    ):
        return None
    target = marker.json_schema.get("type", None)
    if target is None or type(target) is not type:
        return None
    try:
        return _get_pyarrow_type(target, [], settings)
    except Exception:  # noqa: BLE001 - best effort; caller falls back
        return None


def _get_pyarrow_schema(
    pydantic_class: type[BaseModelType],
    settings: Settings,
    as_schema: bool = True,
) -> pa.Schema | pa.DataType:
    """Build the pyarrow schema, or struct type, for a pydantic model class.

    Args:
        pydantic_class: The pydantic model class to convert.
        settings: Conversion options.
        as_schema: If True, return a ``pa.schema``. If False, return a
            ``pa.struct``, which is how nested models are embedded.

    Returns:
        A ``pa.Schema`` or a ``pa.StructType`` with one field per model field.

    Raises:
        SchemaCreationError: If any field cannot be converted.
    """
    fields = []
    for name, field_info in pydantic_class.model_fields.items():
        if field_info.exclude and settings.exclude_fields:
            continue
        field_type = field_info.annotation
        metadata = field_info.metadata

        if field_type is None:
            # Not sure how to get here through pydantic, hence nocover
            raise SchemaCreationError(
                f"Missing type for field {name}"
            )  # pragma: no cover

        try:
            nullable = False
            if _is_optional(field_type):
                nullable = True
                # First non-None member in declaration order; a set made the
                # choice arbitrary for unions such as ``int | str | None``.
                field_type = cast(
                    type[Any],
                    next(
                        arg
                        for arg in get_args(field_type)
                        if arg is not type(None)
                    ),
                )

            # Honour the same WithJsonSchema serialization override as
            # get_pyarrow_schema, so nested models holding e.g. NDArray fields
            # convert too. For an Annotated left after stripping Optional, the
            # override lives in that Annotated's own metadata.
            override_metadata = metadata
            if get_origin(field_type) is Annotated:
                override_metadata = FieldInfo.from_annotation(
                    field_type
                ).metadata
            pa_field = _json_schema_override_type(override_metadata, settings)
            if pa_field is None:
                pa_field = _get_pyarrow_type(field_type, metadata, settings)
        except Exception as err:  # noqa: BLE001 - ignore blind exception
            raise SchemaCreationError(
                f"Error processing field {name}: {field_type}, {err}"
            ) from err

        serialized_name = name
        if settings.by_alias and field_info.serialization_alias is not None:
            serialized_name = field_info.serialization_alias
        fields.append(pa.field(serialized_name, pa_field, nullable=nullable))

    if as_schema:
        return pa.schema(fields)
    return pa.struct(fields)
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
# def get_pyarrow_schema(
|
|
293
|
+
# pydantic_class: Type[BaseModelType],
|
|
294
|
+
# allow_losing_tz: bool = False,
|
|
295
|
+
# exclude_fields: bool = False,
|
|
296
|
+
# by_alias: bool = False,
|
|
297
|
+
# ) -> pa.Schema:
|
|
298
|
+
# """
|
|
299
|
+
# Converts a Pydantic model into a PyArrow schema.
|
|
300
|
+
#
|
|
301
|
+
# Args:
|
|
302
|
+
# pydantic_class (Type[BaseModelType]): The Pydantic model class to convert.
|
|
303
|
+
# allow_losing_tz (bool, optional): Whether to allow losing timezone information
|
|
304
|
+
# when converting datetime fields. Defaults to False.
|
|
305
|
+
# exclude_fields (bool, optional): If True, will exclude fields in the pydantic
|
|
306
|
+
# model that have `Field(exclude=True)`. Defaults to False.
|
|
307
|
+
# by_alias (bool, optional): If True, will create the pyarrow schema using the
|
|
308
|
+
# (serialization) alias in the pydantic model. Defaults to False.
|
|
309
|
+
#
|
|
310
|
+
# Returns:
|
|
311
|
+
# pa.Schema: The PyArrow schema representing the Pydantic model.
|
|
312
|
+
# """
|
|
313
|
+
# settings = Settings(
|
|
314
|
+
# allow_losing_tz=allow_losing_tz,
|
|
315
|
+
# by_alias=by_alias,
|
|
316
|
+
# exclude_fields=exclude_fields,
|
|
317
|
+
# )
|
|
318
|
+
# return _get_pyarrow_schema(pydantic_class, settings)
|
|
319
|
+
############################################################################
|
|
320
|
+
# The code above is copied from https://github.com/simw/pydantic-to-pyarrow.
|
|
321
|
+
############################################################################
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def get_pyarrow_schema(
    pydantic_class: type[BaseModel],
    allow_losing_tz: bool = False,
    exclude_fields: bool = False,
    by_alias: bool = False,
) -> Schema:
    """Convert a Pydantic model to a PyArrow schema.

    Some fields carry ``WithJsonSchema(..., mode="serialization")`` metadata
    whose ``"type"`` entry is a Python class (``NDArray`` uses ``str``). Such
    a field is typed from that class. If that mapping fails, the field
    annotation is used instead. A ``X | None`` field becomes a nullable column
    typed from its first non-None member in declaration order.

    Args:
        pydantic_class: Pydantic model class to convert.
        allow_losing_tz: Allow losing timezone info in datetime fields.
        exclude_fields: Exclude fields with Field(exclude=True).
        by_alias: Use serialization alias for field names.

    Returns:
        PyArrow schema representing the Pydantic model.

    Raises:
        SchemaCreationError: If a field's type cannot be mapped to pyarrow.
    """
    settings = Settings(
        allow_losing_tz=allow_losing_tz,
        by_alias=by_alias,
        exclude_fields=exclude_fields,
    )
    fields = []
    for name, field_info in pydantic_class.model_fields.items():
        if field_info.exclude and settings.exclude_fields:
            continue
        field_type = field_info.annotation
        metadata = field_info.metadata
        if field_type is None:
            # Not sure how to get here through pydantic, hence nocover
            raise SchemaCreationError(
                f"Missing type for field {name}"
            )  # pragma: no cover
        serialized_name = name
        if settings.by_alias and field_info.serialization_alias is not None:
            serialized_name = field_info.serialization_alias

        nullable = False
        if _is_optional(field_type):
            nullable = True
            # Declaration order, not a set, so the chosen member is stable.
            field_type = cast(
                type[Any],
                next(
                    arg for arg in get_args(field_type) if arg is not type(None)
                ),
            )
        if get_origin(field_type) is Annotated:
            # Unwrap an Annotated left after stripping Optional so its
            # metadata (e.g. a WithJsonSchema) is visible below.
            inner = FieldInfo.from_annotation(field_type)
            field_type, metadata = inner.annotation, inner.metadata

        override = next(
            (item for item in metadata if isinstance(item, WithJsonSchema)),
            None,
        )
        if (
            override is not None
            and override.mode == "serialization"
            and override.json_schema is not None
        ):
            target = override.json_schema.get("type", None)
            if target is not None and type(target) is type:
                try:
                    override_field = field(
                        serialized_name,
                        _get_pyarrow_type(target, [], settings),
                        nullable=nullable,
                    )
                except Exception:  # noqa: BLE001
                    # Best effort: fall back to the annotation below.
                    pass
                else:
                    fields.append(override_field)
                    continue

        try:
            pa_type = _get_pyarrow_type(
                field_type=field_type,  # type: ignore
                settings=settings,
                metadata=metadata,
            )
            pa_field = field(serialized_name, pa_type, nullable=nullable)
        except Exception as err:  # noqa: BLE001 - ignore blind exception
            raise SchemaCreationError(
                f"Error processing field {name}: {field_type}, {err}"
            ) from err
        fields.append(pa_field)
    return schema(fields)
|