graphatoms 1.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. graphatoms-1.1.0/PKG-INFO +26 -0
  2. graphatoms-1.1.0/README.md +3 -0
  3. graphatoms-1.1.0/pyproject.toml +125 -0
  4. graphatoms-1.1.0/src/graphatoms/__init__.py +8 -0
  5. graphatoms-1.1.0/src/graphatoms/dataclasses/__init__.py +13 -0
  6. graphatoms-1.1.0/src/graphatoms/dataclasses/_numpydantic.py +62 -0
  7. graphatoms-1.1.0/src/graphatoms/dataclasses/_pydantic2pyarrow.py +402 -0
  8. graphatoms-1.1.0/src/graphatoms/dataclasses/_pydanticMixin.py +449 -0
  9. graphatoms-1.1.0/src/graphatoms/dataclasses/_pydanticModel.py +160 -0
  10. graphatoms-1.1.0/src/graphatoms/dataclasses/test_dataclasses.py +64 -0
  11. graphatoms-1.1.0/src/graphatoms/geometry/__init__.py +11 -0
  12. graphatoms-1.1.0/src/graphatoms/geometry/_bond_list.py +168 -0
  13. graphatoms-1.1.0/src/graphatoms/geometry/_distance_pairs.py +306 -0
  14. graphatoms-1.1.0/src/graphatoms/geometry/_inner_outer.py +91 -0
  15. graphatoms-1.1.0/src/graphatoms/geometry/_neighbor_list.py +284 -0
  16. graphatoms-1.1.0/src/graphatoms/geometry/rotation.py +151 -0
  17. graphatoms-1.1.0/src/graphatoms/geometry/sample.py +79 -0
  18. graphatoms-1.1.0/src/graphatoms/geometry/test_geometry.py +84 -0
  19. graphatoms-1.1.0/src/graphatoms/reaction/__init__.py +0 -0
  20. graphatoms-1.1.0/src/graphatoms/reaction/_amove/__init__.py +3 -0
  21. graphatoms-1.1.0/src/graphatoms/reaction/_amove/_abc.py +12 -0
  22. graphatoms-1.1.0/src/graphatoms/reaction/_amove/_base.py +135 -0
  23. graphatoms-1.1.0/src/graphatoms/reaction/_amove/_moveFB.py +179 -0
  24. graphatoms-1.1.0/src/graphatoms/reaction/_amove/_moveSP.py +119 -0
  25. graphatoms-1.1.0/src/graphatoms/reaction/_amove/_moveSwap.py +119 -0
  26. graphatoms-1.1.0/src/graphatoms/reaction/_amove/_simNVT.py +179 -0
  27. graphatoms-1.1.0/src/graphatoms/reaction/_amove/test_mmc.py +57 -0
  28. graphatoms-1.1.0/src/graphatoms/reaction/_amove/test_move.py +23 -0
  29. graphatoms-1.1.0/src/graphatoms/reaction/_amove/zzz.py +103 -0
  30. graphatoms-1.1.0/src/graphatoms/reaction/_event/__init__.py +10 -0
  31. graphatoms-1.1.0/src/graphatoms/reaction/_event/_base.py +203 -0
  32. graphatoms-1.1.0/src/graphatoms/reaction/_event/_createMixin.py +14 -0
  33. graphatoms-1.1.0/src/graphatoms/reaction/_event/_ioMixin.py +10 -0
  34. graphatoms-1.1.0/src/graphatoms/reaction/_event/event.py +13 -0
  35. graphatoms-1.1.0/src/graphatoms/reaction/_network/__init__.py +0 -0
  36. graphatoms-1.1.0/src/graphatoms/reaction/_network/_asedb.py +140 -0
  37. graphatoms-1.1.0/src/graphatoms/reaction/_network/_base.py +61 -0
  38. graphatoms-1.1.0/src/graphatoms/reaction/_network/_h5db.py +140 -0
  39. graphatoms-1.1.0/src/graphatoms/system/__init__.py +25 -0
  40. graphatoms-1.1.0/src/graphatoms/system/_atoms/__init__.py +5 -0
  41. graphatoms-1.1.0/src/graphatoms/system/_atoms/_box.py +85 -0
  42. graphatoms-1.1.0/src/graphatoms/system/_atoms/_eng.py +268 -0
  43. graphatoms-1.1.0/src/graphatoms/system/_atoms/_struct.py +167 -0
  44. graphatoms-1.1.0/src/graphatoms/system/_graph/__init__.py +4 -0
  45. graphatoms-1.1.0/src/graphatoms/system/_graph/_bonds.py +558 -0
  46. graphatoms-1.1.0/src/graphatoms/system/_graph/_gasMixin.py +45 -0
  47. graphatoms-1.1.0/src/graphatoms/system/_graph/_sysGraph.py +520 -0
  48. graphatoms-1.1.0/src/graphatoms/system/_other/_system.py +202 -0
  49. graphatoms-1.1.0/src/graphatoms/system/_sys/__init__.py +4 -0
  50. graphatoms-1.1.0/src/graphatoms/system/_sys/_sysCluster.py +124 -0
  51. graphatoms-1.1.0/src/graphatoms/system/_sys/_sysGas.py +181 -0
  52. graphatoms-1.1.0/src/graphatoms/system/_test_speed.py +17 -0
  53. graphatoms-1.1.0/src/graphatoms/system/test_system.py +310 -0
  54. graphatoms-1.1.0/src/graphatoms/utils/__init__.py +1 -0
  55. graphatoms-1.1.0/src/graphatoms/utils/bytestool.py +194 -0
  56. graphatoms-1.1.0/src/graphatoms/utils/logger.py +69 -0
  57. graphatoms-1.1.0/src/graphatoms/utils/parser.py +25 -0
  58. graphatoms-1.1.0/src/graphatoms/utils/rdutils.py +377 -0
@@ -0,0 +1,26 @@
1
+ Metadata-Version: 2.3
2
+ Name: graphatoms
3
+ Version: 1.1.0
4
+ Summary: The Chemical Core Class for Graph Theory Analysis.
5
+ Author: LiuGaoyong
6
+ Author-email: LiuGaoyong <liugaoyong_88@163.com>
7
+ Requires-Dist: ase
8
+ Requires-Dist: pymatgen>2023.6
9
+ Requires-Dist: rdkit>=2025
10
+ Requires-Dist: scikit-learn>=1.5
11
+ Requires-Dist: pyarrow
12
+ Requires-Dist: igraph>=0.11
13
+ Requires-Dist: hydra-core
14
+ Requires-Dist: numpydantic
15
+ Requires-Dist: pydantic>=2.11
16
+ Requires-Dist: python-snappy>=0.7.3
17
+ Requires-Dist: tomli ; python_full_version < '3.11'
18
+ Requires-Dist: tomli-w
19
+ Requires-Dist: typing-extensions
20
+ Requires-Dist: loguru
21
+ Requires-Python: >=3.10
22
+ Description-Content-Type: text/markdown
23
+
24
+ # graphatoms
25
+
26
+ The Chemical Core Class for Graph Theory Analysis & Graph Neural Network.
@@ -0,0 +1,3 @@
1
+ # graphatoms
2
+
3
+ The Chemical Core Class for Graph Theory Analysis & Graph Neural Network.
@@ -0,0 +1,125 @@
1
+ [build-system]
2
+ requires = ["uv_build>=0.11.15,<0.12"]
3
+ build-backend = "uv_build"
4
+
5
+ [project]
6
+ name = "graphatoms"
7
+ version = "1.1.0"
8
+ readme = "README.md"
9
+ description = "The Chemical Core Class for Graph Theory Analysis."
10
+ authors = [{ name = "LiuGaoyong", email = "liugaoyong_88@163.com" }]
11
+ requires-python = ">=3.10"
12
+ dependencies = [
13
+ # Chemoinformatics
14
+ "ase",
15
+ "pymatgen>2023.6",
16
+ "rdkit>=2025",
17
+
18
+ # Scientific Computing
19
+ # numpy pandas, scipy
20
+ "scikit-learn>=1.5",
21
+ "pyarrow",
22
+
23
+ # Graph Theory
24
+ "igraph>=0.11", # high performance graph library based on C
25
+
26
+ # Data Model
27
+ 'hydra-core',
28
+ "numpydantic",
29
+ "pydantic>=2.11",
30
+ "python-snappy>=0.7.3", # Faster data compression
31
+ "tomli; python_version < '3.11'", # use tomllib in python>=3.11
32
+ "tomli-w", # TOML writer
33
+ "typing-extensions",
34
+ 'loguru',
35
+ ]
36
+ [dependency-groups]
37
+ dev = [
38
+ "ruff>=0.11",
39
+ "pytest>=8.3",
40
+ "pytest-xdist",
41
+ "jupyter",
42
+ "sqlmodel", # SQL + pydantic
43
+ 'openbabel-wheel',
44
+ ]
45
+
46
+ # uv config
47
+ [tool.uv]
48
+ package = true
49
+ [[tool.uv.index]]
50
+ url = "https://mirrors.cernet.edu.cn/pypi/web/simple" # CERNET
51
+ default = true
52
+ # [[tool.uv.index]]
53
+ # name = "pytorch-cpu"
54
+ # url = "https://mirrors.nju.edu.cn/pytorch/whl/cpu/"
55
+ # explicit = true
56
+ # format = "flat"
57
+ # [tool.uv.sources]
58
+ # torch = [{ index = "pytorch-cpu" }]
59
+
60
+ # pytest config
61
+ [tool.pytest.ini_options]
62
+ filterwarnings = [
63
+ "error", # All other warnings are transformed into errors.
64
+ # ignore the following warnings that match a regex
65
+ # example: 'ignore:function ham\(\) is deprecated:DeprecationWarning',
66
+ "ignore::RuntimeWarning",
67
+
68
+ ]
69
+ addopts = '--maxfail=1 -rf' # exit after 1 failures, report fail info
70
+ testpaths = [
71
+ "src/graphatoms/dataclasses",
72
+ "src/graphatoms/geometry",
73
+ 'src/graphatoms/system',
74
+ 'src/tests',
75
+ ]
76
+
77
+ # ruff config
78
+ [tool.ruff]
79
+ line-length = 80
80
+ indent-width = 4
81
+ select = ["F", "E", "W", "UP", "D"]
82
+ ignore = ["F722", "D100"]
83
+ [tool.ruff.lint.pydocstyle]
84
+ convention = "google"
85
+
86
+ [tool.pixi.package]
87
+ name = 'graphatoms'
88
+ version = '1.1.0'
89
+
90
+ [tool.pixi.package.build]
91
+ backend = { name = "pixi-build-python", version = "*" }
92
+ channels = ["https://prefix.dev/conda-forge"]
93
+
94
+ [tool.pixi.workspace]
95
+ preview = ["pixi-build"]
96
+ channels = ["conda-forge"]
97
+ platforms = ["win-64", "linux-64", "osx-arm64"]
98
+
99
+ [tool.pixi.pypi-dependencies]
100
+ graphatoms = { path = ".", editable = true }
101
+
102
+ [tool.pixi.environments]
103
+ default = { solve-group = "default" }
104
+ dev = { features = ["dev"], solve-group = "default" }
105
+
106
+ [tool.pixi.tasks]
107
+ test = "pytest -s -vv"
108
+ benchmark = 'pytest -s graphatoms/system/_test_speed.py'
109
+
110
+ [tool.pixi.dependencies]
111
+ python = ">=3.12"
112
+ pytest = ">=9.0.3,<10"
113
+ ase = ">=3.28.0,<4"
114
+ pymatgen = ">=2026.5.4,<2027"
115
+ rdkit = ">=2025.9.5,<2027"
116
+ openbabel = ">=3.1.1,<4"
117
+ pyarrow = ">=21.0.0,<25"
118
+ scikit-learn = ">=1.8.0,<2"
119
+ python-igraph = ">=0.11.9,<0.12"
120
+ numpydantic = ">=1.8.1,<2"
121
+ python-snappy = ">=0.7.3,<0.8"
122
+ h5py = ">=3.16.0,<4"
123
+ hydra-core = ">=1.3.2,<2"
124
+ tomli-w = ">=1.2.0,<2"
125
+ loguru = ">=0.7.3,<0.8"
@@ -0,0 +1,8 @@
1
+ # ruff: noqa
2
+ # from .system import Cluster, Gas, System
3
+
4
+ # __all__ = [
5
+ # "Cluster",
6
+ # "Gas",
7
+ # "System",
8
+ # ]
@@ -0,0 +1,13 @@
1
+ """The extended dataclasses by pydantic & numpydantic."""
2
+
3
+ from ._numpydantic import NDArray, numpy_validator
4
+ from ._pydantic2pyarrow import get_pyarrow_schema
5
+ from ._pydanticModel import OurBaseModel, OurFrozenModel
6
+
7
+ __all__ = [
8
+ "NDArray",
9
+ "numpy_validator",
10
+ "get_pyarrow_schema",
11
+ "OurFrozenModel",
12
+ "OurBaseModel",
13
+ ]
@@ -0,0 +1,62 @@
1
+ from collections.abc import Sequence
2
+ from functools import partial
3
+ from typing import TYPE_CHECKING, Annotated, Any
4
+
5
+ import numpy as np
6
+ from numpydantic import NDArray as _NDArray
7
+ from pydantic import (
8
+ BeforeValidator,
9
+ PlainSerializer,
10
+ WithJsonSchema,
11
+ validate_call,
12
+ )
13
+
14
__all__ = ["NDArray", "numpy_validator"]
# Single-byte encoding used to move raw array bytes through `str`: the
# serializer below decodes with it and `__convert2numpy` encodes with it.
_ENCODING = "latin1"

# numpydantic array type extended with a string serializer.
NDArray = Annotated[
    _NDArray,
    # Dump the raw buffer (`tobytes`) and decode it into a `str`.
    PlainSerializer(lambda x: x.tobytes().decode(_ENCODING), return_type=str),
    # for pyarrow compatibility, we need to convert bytes to str
    # for numpy.ndarray before serialization.
    # NOTE: the "type" entry is the Python class `str`, not a JSON schema
    # type name; `get_pyarrow_schema` looks for a Python type there.
    WithJsonSchema({"type": str}, mode="serialization"),
]
if TYPE_CHECKING:
    # Static type checkers see numpy's own NDArray instead of the alias.
    from numpy.typing import NDArray
26
+
27
+
28
@validate_call
def __convert2numpy(
    x: Any,
    dtype: Any = "float",
    shape: Sequence[int] = (-1,),
) -> _NDArray:
    """Build a numpy array from text, raw bytes or an array-like value.

    Text is first turned back into bytes with the module encoding; bytes
    are then reinterpreted as a buffer of ``dtype`` items. Any other
    non-scalar value goes through ``np.asarray``. The result is always
    reshaped to ``shape``.

    Args:
        x: A ``str``, ``bytes`` or array-like object.
        dtype: Data type specification passed on to numpy.
        shape: Target shape handed to ``reshape``.

    Returns:
        The converted array, reshaped to ``shape``.

    Raises:
        KeyError: If the input value is a scalar number.
    """
    if isinstance(x, str):
        # Undo the str serialization to get the raw buffer back.
        x = x.encode(_ENCODING)
    if isinstance(x, bytes):
        return np.frombuffer(x, np.dtype(dtype)).reshape(shape)
    if np.isscalar(x):
        raise KeyError(
            f"Invalid input value: {x}({type(x)})."
            " The scalar value is not supported."
        )
    return np.asarray(x, dtype=dtype).reshape(shape)
49
+
50
+
51
def numpy_validator(
    dtype: Any = "float",
    shape: Sequence[int] = (-1,),
) -> BeforeValidator:
    """Create pydantic validator in `before` mode for numpy array.

    Args:
        dtype: Data type bound to the underlying converter.
        shape: Shape bound to the underlying converter.

    Returns:
        A ``BeforeValidator`` wrapping the converter with ``dtype`` and
        ``shape`` pre-filled, so only the raw value is left to supply.
    """
    converter = partial(__convert2numpy, dtype=dtype, shape=shape)
    return BeforeValidator(converter)
@@ -0,0 +1,402 @@
1
+ # ruff: noqa: E501
2
+ import datetime
3
+ import types
4
+ import uuid
5
+ from decimal import Decimal
6
+ from enum import EnumMeta
7
+ from typing import (
8
+ Annotated,
9
+ Any,
10
+ Literal,
11
+ NamedTuple,
12
+ TypeVar,
13
+ Union,
14
+ cast,
15
+ get_args,
16
+ get_origin,
17
+ )
18
+
19
+ import pyarrow as pa # type: ignore
20
+ from annotated_types import Ge, Gt
21
+ from pyarrow import Schema, field, schema
22
+ from pydantic import AwareDatetime, BaseModel, NaiveDatetime, WithJsonSchema
23
+ from pydantic.fields import FieldInfo
24
+
25
+ BaseModelType = TypeVar("BaseModelType", bound=BaseModel)
26
+ EnumType = TypeVar("EnumType", bound=EnumMeta)
27
+ __all__ = ["get_pyarrow_schema"]
28
+
29
+
30
class SchemaCreationError(Exception):
    """Raised when a pydantic model cannot be mapped to a pyarrow schema."""
32
+
33
+
34
class Settings(NamedTuple):
    """Options controlling how a pydantic model is mapped to pyarrow."""

    # When true, types listed in LOSING_TZ_TYPES are mapped to a
    # timezone-less timestamp; when false they raise SchemaCreationError.
    allow_losing_tz: bool
    # When true, a field's serialization alias (if set) names the column.
    by_alias: bool
    # When true, fields whose `exclude` flag is set are skipped.
    exclude_fields: bool
38
+
39
+
40
# Types with a fixed pyarrow equivalent; `_get_pyarrow_type` checks this
# table before anything else.
FIELD_MAP = {
    str: pa.string(),
    bytes: pa.binary(),
    bool: pa.bool_(),
    float: pa.float64(),
    datetime.date: pa.date32(),
    NaiveDatetime: pa.timestamp("ms", tz=None),
    datetime.time: pa.time64("us"),
}

# Timezone aware datetimes will lose their timezone information
# (but be correctly converted to UTC) when converted to pyarrow.
# Pyarrow does support having an entire column in a single timezone,
# but these bare types cannot guarantee that.
# These types are only accepted when `Settings.allow_losing_tz` is true.
LOSING_TZ_TYPES = {
    datetime.datetime: pa.timestamp("ms", tz=None),
    AwareDatetime: pa.timestamp("ms", tz=None),
}
58
+
59
+
60
def _get_int_type(metadata: list[Any]) -> pa.DataType:
    """Pick the pyarrow integer type for an ``int`` field.

    The lower bound is read from ``Gt`` / ``Ge`` constraints in
    ``metadata``; when several are present the last one wins.

    Args:
        metadata: Constraint objects attached to the field.

    Returns:
        ``pa.uint64()`` when a lower bound of zero or more was found,
        otherwise ``pa.int64()``.

    Raises:
        SchemaCreationError: If a ``Gt`` / ``Ge`` bound is not an int.
    """
    lower_bound: int | None = None
    for constraint in metadata:
        if isinstance(constraint, Gt):
            bound = constraint.gt
            problem = "Gt metadata must be int"
        elif isinstance(constraint, Ge):
            bound = constraint.ge
            problem = "Ge metadata must be int"
        else:
            # Unrelated metadata does not influence the integer width.
            continue
        if bound is not None and not isinstance(bound, int):
            raise SchemaCreationError(problem)
        lower_bound = bound

    if lower_bound is None or lower_bound < 0:
        return pa.int64()
    return pa.uint64()
75
+
76
+
77
def _get_decimal_type(metadata: list[Any]) -> pa.DataType:
    """Pick the pyarrow decimal type for a ``Decimal`` field.

    Args:
        metadata: Constraint objects attached to the field. The last item
            exposing both ``max_digits`` and ``decimal_places`` is used.

    Returns:
        ``pa.decimal128(max_digits, decimal_places)``.

    Raises:
        SchemaCreationError: If no metadata item carries both attributes.
    """
    # Scan backwards so the first hit is the last matching item.
    precision_info = next(
        (
            item
            for item in reversed(metadata)
            if hasattr(item, "max_digits") and hasattr(item, "decimal_places")
        ),
        None,
    )
    if precision_info is None:
        raise SchemaCreationError(
            "Decimal type needs annotation setting max_digits and decimal_places"
        )
    return pa.decimal128(
        precision_info.max_digits,
        precision_info.decimal_places,
    )
90
+
91
+
92
# Types whose pyarrow type depends on the field's constraint metadata;
# each resolver is called with the metadata list only.
TYPES_WITH_METADATA = {
    Decimal: _get_decimal_type,
    int: _get_int_type,
}
96
+
97
+
98
def _get_literal_type(
    field_type: type[Any],
    _metadata: list[Any],
    _settings: Settings,
) -> pa.DataType:
    """Map a ``Literal[...]`` annotation to a pyarrow type.

    Args:
        field_type: The ``Literal`` annotation.
        _metadata: Unused; present for the common resolver signature.
        _settings: Unused; present for the common resolver signature.

    Returns:
        A string dictionary type when every value is a ``str``,
        ``pa.int64()`` when every value is an ``int``.

    Raises:
        SchemaCreationError: If the values are mixed or of another type.
    """
    options = get_args(field_type)
    if all(isinstance(option, str) for option in options):
        return pa.dictionary(pa.int32(), pa.string())
    if not all(isinstance(option, int) for option in options):
        msg = "Literal type is only supported with all int or string values. "
        raise SchemaCreationError(msg)
    # Dictionary of (int, int) is converted to just int when
    # written into parquet.
    return pa.int64()
113
+
114
+
115
def _get_list_type(
    field_type: type[Any],
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType:
    """Map a ``list[X]`` annotation to a pyarrow list type.

    Args:
        field_type: The parameterised ``list`` annotation.
        metadata: Field metadata, forwarded to the element resolver.
        settings: Conversion settings, forwarded to the element resolver.

    Returns:
        ``pa.list_`` of the resolved element type.
    """
    sub_type = get_args(field_type)[0]
    if _is_optional(sub_type):
        # pyarrow lists can have null elements in them, so only the
        # non-None member matters. Take the first one in declaration
        # order: going through a set made the pick depend on set
        # iteration order when the union had several non-None members.
        sub_type = next(
            arg for arg in get_args(sub_type) if arg is not type(None)
        )
    return pa.list_(_get_pyarrow_type(sub_type, metadata, settings))
125
+
126
+
127
def _get_annotated_type(
    field_type: type[Any],
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType:
    """Resolve an ``Annotated[X, ...]`` annotation through its base type.

    The incoming ``metadata`` is discarded and rebuilt from the
    annotation's own extras: an extra exposing a ``metadata`` attribute
    contributes that attribute's items, any other extra contributes
    itself.

    Args:
        field_type: The ``Annotated`` annotation.
        metadata: Ignored (replaced by the annotation's extras).
        settings: Conversion settings, forwarded to the base resolver.

    Returns:
        The pyarrow type of the base type under the rebuilt metadata.
    """
    # ???: fix / clean up / understand why / if this works in all cases
    base_type, *extras = get_args(field_type)
    flattened: list[Any] = []
    for extra in extras:
        if hasattr(extra, "metadata"):
            flattened.extend(extra.metadata)
        else:
            flattened.append(extra)
    return _get_pyarrow_type(cast(type[Any], base_type), flattened, settings)
140
+
141
+
142
def _get_dict_type(
    field_type: type[Any],
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType:
    """Map a ``dict[K, V]`` annotation to a pyarrow map type.

    Args:
        field_type: The parameterised ``dict`` annotation.
        metadata: Field metadata, forwarded to both resolvers.
        settings: Conversion settings, forwarded to both resolvers.

    Returns:
        ``pa.map_`` of the resolved key and value types.
    """
    key_type, value_type = get_args(field_type)
    pa_key = _get_pyarrow_type(key_type, metadata, settings)
    pa_value = _get_pyarrow_type(value_type, metadata, settings)
    return pa.map_(pa_key, pa_value)
152
+
153
+
154
# Resolvers keyed by the origin of a parameterised annotation (the value
# of `get_origin`); each is called with (field_type, metadata, settings).
FIELD_TYPES = {
    Literal: _get_literal_type,
    list: _get_list_type,
    Annotated: _get_annotated_type,
    dict: _get_dict_type,
}
160
+
161
+
162
def _get_enum_type(field_type: type[Any]) -> pa.DataType:
    """Map an enum class to a pyarrow type based on its member values.

    Args:
        field_type: The enum class.

    Returns:
        A string dictionary type when every member value is a ``str``
        (this is also the result for an enum without members),
        ``pa.int64()`` when every member value is an ``int``.

    Raises:
        SchemaCreationError: If the member values are mixed or of
            another type.
    """
    member_values = [member.value for member in field_type]  # type: ignore
    if all(isinstance(value, str) for value in member_values):
        return pa.dictionary(pa.int32(), pa.string())
    if all(isinstance(value, int) for value in member_values):
        return pa.int64()
    raise SchemaCreationError("Enums only allowed if all str or all int")
179
+
180
+
181
def _get_uuid_type() -> pa.DataType:
    """Return the pyarrow UUID type.

    Returns:
        ``pa.uuid()``.

    Raises:
        SchemaCreationError: If the installed pyarrow has no ``uuid``
            attribute.
    """
    # Different branches will execute depending on the pyarrow version
    # This is tested through nox and python versions, but each one
    # won't cover both branches. Hence, excluding from coverage.
    if not hasattr(pa, "uuid"):  # pragma: no cover
        raise SchemaCreationError(
            f"pyarrow version {pa.__version__} does not support pa.uuid() type, "
            "needs version 18.0 or higher"
        )
    return pa.uuid()  # pragma: no cover
191
+
192
+
193
+ def _is_optional(field_type: type[Any]) -> bool:
194
+ origin = get_origin(field_type)
195
+ is_python_39_union = origin is Union
196
+ is_python_310_union = (
197
+ hasattr(types, "UnionType") and origin is types.UnionType
198
+ )
199
+
200
+ if not is_python_39_union and not is_python_310_union:
201
+ return False
202
+
203
+ return type(None) in get_args(field_type)
204
+
205
+
206
# noqa: PLR0911 - ignore until a refactoring can reduce the number of
# return statements.
def _get_pyarrow_type(  # noqa: PLR0911
    field_type: type[Any],
    metadata: list[Any],
    settings: Settings,
) -> pa.DataType:
    """Resolve one annotation to its pyarrow data type.

    Checks run in a fixed order: fixed-type table, UUID,
    timezone-losing datetimes, enums, metadata-dependent types,
    parameterised annotations, nested pydantic models.

    Args:
        field_type: The annotation to resolve.
        metadata: Constraint metadata attached to the field.
        settings: Conversion settings.

    Returns:
        The matching pyarrow data type; a struct type for nested models.

    Raises:
        SchemaCreationError: If the type is unknown, or is a
            timezone-losing type while ``allow_losing_tz`` is false.
    """
    if field_type in FIELD_MAP:
        return FIELD_MAP[field_type]

    if field_type is uuid.UUID:
        return _get_uuid_type()

    if field_type in LOSING_TZ_TYPES:
        if settings.allow_losing_tz:
            return LOSING_TZ_TYPES[field_type]
        raise SchemaCreationError(
            f"{field_type} only allowed if ok losing timezone information"
        )

    if isinstance(field_type, EnumMeta):
        return _get_enum_type(field_type)

    if field_type in TYPES_WITH_METADATA:
        return TYPES_WITH_METADATA[field_type](metadata)

    origin = get_origin(field_type)
    if origin in FIELD_TYPES:
        resolver = FIELD_TYPES[origin]
        return resolver(field_type, metadata, settings)  # type: ignore

    # isinstance(field_type, type) checks whether it's a class
    # otherwise eg Deque[int] would cause an exception on issubclass
    is_model = isinstance(field_type, type) and issubclass(
        field_type, BaseModel
    )
    if is_model:
        # Nested models become a struct rather than a standalone schema.
        return _get_pyarrow_schema(field_type, settings, as_schema=False)

    raise SchemaCreationError(f"Unknown type: {field_type}")
246
+
247
+
248
def _get_pyarrow_schema(
    pydantic_class: type[BaseModelType],
    settings: Settings,
    as_schema: bool = True,
) -> pa.Schema:
    """Build the pyarrow schema (or struct type) of a pydantic model.

    Args:
        pydantic_class: The model class whose fields are converted.
        settings: Conversion settings.
        as_schema: When true return ``pa.schema(fields)``; when false
            return ``pa.struct(fields)``, the form used for nested models.

    Returns:
        A pyarrow schema, or a struct type when ``as_schema`` is false.

    Raises:
        SchemaCreationError: If a field has no annotation or its type
            cannot be resolved; the underlying error is chained.
    """
    fields = []
    for name, field_info in pydantic_class.model_fields.items():
        if field_info.exclude and settings.exclude_fields:
            continue
        field_type = field_info.annotation
        metadata = field_info.metadata

        if field_type is None:
            # Not sure how to get here through pydantic, hence nocover
            raise SchemaCreationError(
                f"Missing type for field {name}"
            )  # pragma: no cover

        try:
            nullable = False
            if _is_optional(field_type):
                nullable = True
                # Keep declaration order so the pick is deterministic;
                # going through a set made it depend on set iteration
                # order for unions with several non-None members.
                types_under_union = [
                    arg
                    for arg in get_args(field_type)
                    if arg is not type(None)
                ]
                # mypy infers field_type as Type[Any] | None here, hence casting
                field_type = cast(type[Any], types_under_union[0])

            pa_field = _get_pyarrow_type(field_type, metadata, settings)
        except Exception as err:  # noqa: BLE001 - ignore blind exception
            raise SchemaCreationError(
                f"Error processing field {name}: {field_type}, {err}"
            ) from err

        serialized_name = name
        if settings.by_alias and field_info.serialization_alias is not None:
            serialized_name = field_info.serialization_alias
        fields.append(pa.field(serialized_name, pa_field, nullable=nullable))

    if as_schema:
        return pa.schema(fields)
    return pa.struct(fields)
290
+
291
+
292
+ # def get_pyarrow_schema(
293
+ # pydantic_class: Type[BaseModelType],
294
+ # allow_losing_tz: bool = False,
295
+ # exclude_fields: bool = False,
296
+ # by_alias: bool = False,
297
+ # ) -> pa.Schema:
298
+ # """
299
+ # Converts a Pydantic model into a PyArrow schema.
300
+ #
301
+ # Args:
302
+ # pydantic_class (Type[BaseModelType]): The Pydantic model class to convert.
303
+ # allow_losing_tz (bool, optional): Whether to allow losing timezone information
304
+ # when converting datetime fields. Defaults to False.
305
+ # exclude_fields (bool, optional): If True, will exclude fields in the pydantic
306
+ # model that have `Field(exclude=True)`. Defaults to False.
307
+ # by_alias (bool, optional): If True, will create the pyarrow schema using the
308
+ # (serialization) alias in the pydantic model. Defaults to False.
309
+ #
310
+ # Returns:
311
+ # pa.Schema: The PyArrow schema representing the Pydantic model.
312
+ # """
313
+ # settings = Settings(
314
+ # allow_losing_tz=allow_losing_tz,
315
+ # by_alias=by_alias,
316
+ # exclude_fields=exclude_fields,
317
+ # )
318
+ # return _get_pyarrow_schema(pydantic_class, settings)
319
+ ############################################################################
320
+ # The code above is copied from https://github.com/simw/pydantic-to-pyarrow.
321
+ ############################################################################
322
+
323
+
324
def get_pyarrow_schema(
    pydantic_class: type[BaseModel],
    allow_losing_tz: bool = False,
    exclude_fields: bool = False,
    by_alias: bool = False,
) -> Schema:
    """Convert a Pydantic model to a PyArrow schema.

    A field whose metadata carries a serialization-mode
    ``WithJsonSchema`` with a plain Python class under the ``"type"`` key
    is mapped from that class; if that class cannot be resolved, the
    field falls back to its own annotation.

    Args:
        pydantic_class: Pydantic model class to convert.
        allow_losing_tz: Allow losing timezone info in datetime fields.
        exclude_fields: Exclude fields with Field(exclude=True).
        by_alias: Use serialization alias for field names.

    Returns:
        PyArrow schema representing the Pydantic model.

    Raises:
        SchemaCreationError: If a field has no annotation or its type
            cannot be resolved; the underlying error is chained.
    """
    settings = Settings(
        allow_losing_tz=allow_losing_tz,
        by_alias=by_alias,
        exclude_fields=exclude_fields,
    )
    none_type = type(None)
    fields = []
    for name, field_info in pydantic_class.model_fields.items():
        if field_info.exclude and settings.exclude_fields:
            continue
        field_type = field_info.annotation
        metadata = field_info.metadata
        if field_type is None:
            # Not sure how to get here through pydantic, hence nocover
            raise SchemaCreationError(
                f"Missing type for field {name}"
            )  # pragma: no cover
        serialized_name = name
        if settings.by_alias and field_info.serialization_alias is not None:
            serialized_name = field_info.serialization_alias

        nullable = _is_optional(field_type)
        if nullable:
            # Keep declaration order so the pick is deterministic; going
            # through a set made it depend on set iteration order for
            # unions with several non-None members.
            members = [
                arg for arg in get_args(field_type) if arg is not none_type
            ]
            # mypy infers field_type as Type[Any] | None here, hence casting
            field_type = cast(type[Any], members[0])
        if get_origin(field_type) is Annotated:
            # e.g. Optional[Annotated[X, ...]]: take the metadata from the
            # inner annotation instead of the field's own.
            inner = FieldInfo.from_annotation(field_type)
            field_type, metadata = inner.annotation, inner.metadata

        # First WithJsonSchema entry, if any, may override the type.
        override = next(
            (item for item in metadata if isinstance(item, WithJsonSchema)),
            None,
        )
        if (
            override is not None
            and override.mode == "serialization"
            and override.json_schema is not None
        ):
            declared = override.json_schema.get("type", None)
            # Only plain classes (metaclass exactly `type`) are honoured;
            # JSON schema type names such as "string" are ignored here.
            if type(declared) is type:
                try:
                    pa_type = _get_pyarrow_type(declared, [], settings)
                except SchemaCreationError:
                    # Best effort: fall back to the field's annotation.
                    pass
                else:
                    fields.append(
                        field(serialized_name, pa_type, nullable=nullable)
                    )
                    continue
        try:
            pa_field = _get_pyarrow_type(
                field_type=field_type,  # type: ignore
                settings=settings,
                metadata=metadata,
            )
            fields.append(field(serialized_name, pa_field, nullable=nullable))
        except Exception as err:  # noqa: BLE001 - ignore blind exception
            raise SchemaCreationError(
                f"Error processing field {name}: {field_type}, {err}"
            ) from err
    return schema(fields)