industrial-model 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,121 @@
1
+ import inspect
2
+ from collections import defaultdict
3
+ from types import UnionType
4
+ from typing import (
5
+ Any,
6
+ TypeVar,
7
+ Union,
8
+ get_args,
9
+ get_origin,
10
+ get_type_hints,
11
+ )
12
+
13
+ from pydantic import BaseModel
14
+
15
+ from industrial_model.constants import NESTED_SEP
16
+
17
TBaseModel = TypeVar("TBaseModel", bound=BaseModel)


def get_schema_properties(
    cls: type[TBaseModel],
    nested_separator: str = NESTED_SEP,
    prefix: str | None = None,
) -> list[str]:
    """Return the flattened property paths of a pydantic model class.

    Fields whose type resolves to another pydantic model are expanded
    recursively (see ``_get_type_properties``) and every level is emitted
    as a ``parent<sep>child`` path (see ``_flatten_dict_keys``).

    Args:
        cls: Model class to inspect.
        nested_separator: String inserted between path segments.
        prefix: When truthy, prepended (followed by the separator) to every
            returned path.

    Returns:
        The list of property paths.
    """
    # Per-traversal expansion counters; a missing class starts at 0.
    counters: defaultdict[type, int] = defaultdict(int)
    properties = _get_type_properties(cls, counters) or {}
    paths = _flatten_dict_keys(properties, None, nested_separator)
    if prefix:
        return [nested_separator.join((prefix, path)) for path in paths]
    return paths
31
+
32
+
33
def _get_type_properties(
    cls: type[BaseModel], visited_count: defaultdict[type, int]
) -> dict[str, Any] | None:
    """Map each field of *cls* (by alias when set) to its nested properties.

    For a field whose annotation resolves to a pydantic model, the value is
    the recursively computed mapping for that model; otherwise it is
    ``None``.

    Args:
        cls: Model class to inspect.
        visited_count: Counter shared across the whole traversal recording
            how many times each model class has been expanded; it bounds
            recursion for self-referential models.

    Returns:
        The field mapping, or ``None`` once *cls* has already been expanded
        three times in this traversal.
    """
    if visited_count[cls] > 2:
        return None
    # Exactly one increment per expansion, so the ``> 2`` guard allows three
    # expansions of the same model (``+= count + 1`` grew it 0 -> 1 -> 3).
    visited_count[cls] += 1

    hints = get_type_hints(cls)
    properties: dict[str, Any] = {}
    # Resolve only real model fields: ClassVars and private attributes also
    # appear in the type hints and would otherwise be traversed, bumping the
    # shared counters and truncating later real fields.
    for name, field_info in cls.model_fields.items():
        if name not in hints:
            continue
        _, nested = _get_field_type(hints[name], visited_count)
        properties[field_info.alias or name] = nested
    return properties
51
+
52
+
53
def _get_field_type(
    type_hint: type, visited_count: defaultdict[type, int]
) -> tuple[bool, dict[str, Any] | None]:
    """Resolve a field annotation to ``(is_model, nested_properties)``.

    ``list[...]`` and union annotations are unwrapped. Their arguments are
    scanned in order, and the first argument that is itself a list or union
    is resolved recursively, short-circuiting the rest. Otherwise the first
    argument that is a pydantic model class (if any) is expanded via
    ``_get_field_relations``.
    """
    if _type_is_list_or_union(type_hint):
        members = get_args(type_hint)
    else:
        # A plain annotation is treated as a single candidate member.
        members = (type_hint,)

    model_candidates: list[type[BaseModel] | None] = []
    for member in members:
        if _type_is_list_or_union(member):
            return _get_field_type(member, visited_count)
        model_candidates.append(_cast_base_model(member))
    return _get_field_relations(model_candidates, visited_count)
70
+
71
+
72
def _get_field_relations(
    entries: list[type[TBaseModel] | None],
    visited_count: defaultdict[type, int],
) -> tuple[bool, dict[str, Any] | None]:
    """Expand the first non-``None`` model class in *entries*.

    Returns:
        ``(True, properties)`` where ``properties`` comes from
        ``_get_type_properties`` (may be ``None`` when the recursion limit
        is hit), or ``(False, None)`` when *entries* holds no model class.
    """
    for candidate in entries:
        if candidate is not None:
            return True, _get_type_properties(candidate, visited_count)
    return False, None
84
+
85
+
86
+ def _type_is_list_or_union(entry: type) -> bool:
87
+ origin = get_origin(entry)
88
+ is_union = origin in (Union, UnionType)
89
+ is_list = origin in (list, list)
90
+
91
+ return is_union or is_list
92
+
93
+
94
def _cast_base_model(entry: type) -> type[TBaseModel] | None:
    """Return *entry* when it is a pydantic ``BaseModel`` subclass, else None.

    Non-class annotations (generic aliases, literals, ``NoneType``) yield
    ``None``.
    """
    if not inspect.isclass(entry):
        return None
    if entry is type(None):
        return None
    if not issubclass(entry, BaseModel):
        return None
    return entry
101
+
102
+
103
+ def _flatten_dict_keys(
104
+ data: dict[str, Any], parent_key: str | None, nested_separator: str
105
+ ) -> list[str]:
106
+ paths: set[str] = set()
107
+ for key, value in data.items():
108
+ full_key = (
109
+ f"{parent_key}{nested_separator}{key}" if parent_key else key
110
+ )
111
+ paths.add(full_key)
112
+ if isinstance(value, dict) and value:
113
+ paths.update(_flatten_dict_keys(value, full_key, nested_separator))
114
+ elif isinstance(value, str):
115
+ paths.add(f"{full_key}{nested_separator}{value}")
116
+ elif isinstance(value, list | set):
117
+ paths.update(
118
+ [f"{full_key}{nested_separator}{item}" for item in value]
119
+ )
120
+
121
+ return list(paths)
@@ -0,0 +1,21 @@
1
+ from cognite.client.data_classes.data_modeling import InstanceSort, View
2
+
3
+ from industrial_model.cognite_adapters.utils import get_property_ref
4
+ from industrial_model.constants import SORT_DIRECTION
5
+ from industrial_model.statements.expressions import Column
6
+
7
+
8
class SortMapper:
    """Translate statement sort clauses into Cognite ``InstanceSort`` objects."""

    def map(
        self,
        sort_clauses: list[tuple[Column, SORT_DIRECTION]],
        root_view: View,
    ) -> list[InstanceSort]:
        """Build one ``InstanceSort`` per ``(column, direction)`` clause.

        Each column's property is resolved against *root_view* through
        ``get_property_ref``. Descending sorts use ``nulls_first=True`` and
        ascending sorts use ``nulls_first=False``.
        """
        instance_sorts: list[InstanceSort] = []
        for column, direction in sort_clauses:
            is_descending = direction == "descending"
            instance_sorts.append(
                InstanceSort(
                    property=get_property_ref(column.property, root_view),
                    direction=direction,
                    nulls_first=is_descending,
                )
            )
        return instance_sorts
@@ -0,0 +1,33 @@
1
+ from typing import Literal
2
+
3
+ from cognite.client.data_classes.data_modeling import (
4
+ View,
5
+ ViewId,
6
+ )
7
+
8
+ from industrial_model.models import TViewInstance
9
+
10
# Instance-level property names. ``get_property_ref`` references these as
# ``(instance_type, property)`` instead of resolving them through a view.
NODE_PROPERTIES = {"externalId", "space", "createdTime", "deletedTime"}
# Kind of instance a property reference points to.
INSTANCE_TYPE = Literal["node", "edge"]
# Backward-compatible alias for the original, misspelled public name.
INTANCE_TYPE = INSTANCE_TYPE
12
+
13
+
14
def get_property_ref(
    property: str, view: View | ViewId, instance_type: INTANCE_TYPE = "node"
) -> tuple[str, str, str] | tuple[str, str]:
    """Build a property reference for Cognite filters and sorts.

    Args:
        property: Property identifier.
        view: View (or view id) that owns *property*; consulted only when
            *property* is not one of ``NODE_PROPERTIES``.
        instance_type: First element of the reference for instance-level
            properties.

    Returns:
        ``(instance_type, property)`` for names in ``NODE_PROPERTIES``,
        otherwise ``view.as_property_ref(property)``.
    """
    if property in NODE_PROPERTIES:
        return (instance_type, property)
    return view.as_property_ref(property)
22
+
23
+
24
def get_cognite_instance_ids(
    instance_ids: list[TViewInstance],
) -> list[dict[str, str]]:
    """Convert instances into ``{"space": ..., "externalId": ...}`` dicts.

    The output preserves the order of *instance_ids*.
    """
    return list(map(get_cognite_instance_id, instance_ids))
30
+
31
+
32
def get_cognite_instance_id(instance_id: TViewInstance) -> dict[str, str]:
    """Return the ``space``/``externalId`` identifier dict of *instance_id*."""
    return dict(space=instance_id.space, externalId=instance_id.external_id)
@@ -0,0 +1,16 @@
1
+ from cognite.client.data_classes.data_modeling import (
2
+ View,
3
+ )
4
+
5
+
6
class ViewMapper:
    """Look up data-model views by external id."""

    def __init__(self, views: list[View]):
        # Index by external id; a later view with the same id overwrites
        # an earlier one.
        self._views_as_dict = {item.external_id: item for item in views}

    def get_view(self, view_external_id: str) -> View:
        """Return the view registered under *view_external_id*.

        Raises:
            ValueError: If no view with that external id was supplied.
        """
        view = self._views_as_dict.get(view_external_id)
        if view is None:
            raise ValueError(
                f"View {view_external_id} is not available in data model"
            )
        return view
@@ -0,0 +1,10 @@
1
+ from industrial_model.models import RootModel
2
+
3
+
4
# Identifier of a data model (space, external id and version). Comments are
# used instead of a class docstring because pydantic copies docstrings into
# the generated JSON schema.
class DataModelId(RootModel):
    external_id: str
    space: str
    version: str

    def as_tuple(self) -> tuple[str, str, str]:
        """Return the id as ``(space, external_id, version)``."""
        space, external_id, version = self.space, self.external_id, self.version
        return space, external_id, version
@@ -0,0 +1,24 @@
1
+ from typing import Literal
2
+
3
# Separator between segments of flattened nested property paths.
NESTED_SEP = "|"
# Marker string related to edges (its use is not visible in this module).
EDGE_MARKER = "<EdgeMarker>"

# Result-size limits. ``Statement.limit_`` defaults to DEFAULT_LIMIT;
# MAX_LIMIT is presumably the largest page size allowed — TODO confirm.
MAX_LIMIT = 10_000
DEFAULT_LIMIT = 5_000

# Sort order accepted by ``Statement.sort`` and ``SortParam``.
SORT_DIRECTION = Literal["ascending", "descending"]
# Edge traversal direction (semantics defined by callers not shown here).
EDGE_DIRECTION = Literal["outwards", "inwards"]

# Operators combining sub-expressions.
BOOL_EXPRESSION_OPERATORS = Literal["not", "and", "or"]
# Operators of leaf (single-property) expressions, e.g. used by QueryParam.
LEAF_EXPRESSION_OPERATORS = Literal[
    # comparisons
    "==", "in", ">=", ">", "<=", "<",
    # structural / collection operators
    "nested", "exists", "prefix", "containsAll", "containsAny",
]
@@ -0,0 +1,4 @@
1
+ from .async_engine import AsyncEngine
2
+ from .engine import Engine
3
+
4
# Public API of the engines package.
__all__ = [
    "Engine",
    "AsyncEngine",
]
@@ -0,0 +1,37 @@
1
+ from cognite.client import CogniteClient
2
+
3
+ from industrial_model.config import DataModelId
4
+ from industrial_model.models import (
5
+ PaginatedResult,
6
+ TViewInstance,
7
+ ValidationMode,
8
+ )
9
+ from industrial_model.statements import Statement
10
+ from industrial_model.utils import run_async
11
+
12
+ from .engine import Engine
13
+
14
+
15
class AsyncEngine:
    """Async facade over ``Engine``; each call is delegated via ``run_async``."""

    def __init__(
        self,
        cognite_client: CogniteClient,
        data_model_id: DataModelId,
    ):
        self._engine = Engine(cognite_client, data_model_id)

    async def query_async(
        self,
        statement: Statement[TViewInstance],
        validation_mode: ValidationMode = "raiseOnError",
    ) -> PaginatedResult[TViewInstance]:
        """Await ``Engine.query`` (a single page) through ``run_async``."""
        sync_query = self._engine.query
        return await run_async(sync_query, statement, validation_mode)

    async def query_all_pages_async(
        self,
        statement: Statement[TViewInstance],
        validation_mode: ValidationMode = "raiseOnError",
    ) -> list[TViewInstance]:
        """Await ``Engine.query_all_pages`` through ``run_async``."""
        sync_query_all = self._engine.query_all_pages
        return await run_async(sync_query_all, statement, validation_mode)
@@ -0,0 +1,62 @@
1
+ from typing import Any
2
+
3
+ from cognite.client import CogniteClient
4
+
5
+ from industrial_model.cognite_adapters import CogniteAdapter
6
+ from industrial_model.config import DataModelId
7
+ from industrial_model.models import (
8
+ PaginatedResult,
9
+ TViewInstance,
10
+ ValidationMode,
11
+ )
12
+ from industrial_model.statements import Statement
13
+
14
+
15
class Engine:
    """Run statements through ``CogniteAdapter`` and validate returned rows."""

    def __init__(
        self,
        cognite_client: CogniteClient,
        data_model_id: DataModelId,
    ):
        self._cognite_adapter = CogniteAdapter(cognite_client, data_model_id)

    def query(
        self,
        statement: Statement[TViewInstance],
        validation_mode: ValidationMode = "raiseOnError",
    ) -> PaginatedResult[TViewInstance]:
        """Fetch one page and wrap it in a ``PaginatedResult``.

        ``has_next_page`` is True exactly when the adapter returned a
        non-``None`` cursor.
        """
        rows, next_cursor = self._cognite_adapter.query(statement, False)
        instances = self._validate_data(statement.entity, rows, validation_mode)
        return PaginatedResult(
            data=instances,
            next_cursor=next_cursor,
            has_next_page=next_cursor is not None,
        )

    def query_all_pages(
        self,
        statement: Statement[TViewInstance],
        validation_mode: ValidationMode = "raiseOnError",
    ) -> list[TViewInstance]:
        """Fetch all pages for *statement* and return the validated rows.

        Raises:
            ValueError: If the statement already carries a cursor.
        """
        if statement.cursor_:
            raise ValueError("Cursor should be none when querying all pages")
        rows, _ = self._cognite_adapter.query(statement, True)
        return self._validate_data(statement.entity, rows, validation_mode)

    def _validate_data(
        self,
        entity: type[TViewInstance],
        data: list[dict[str, Any]],
        validation_mode: ValidationMode,
    ) -> list[TViewInstance]:
        """Validate each row into *entity*.

        With ``"ignoreOnError"`` rows that fail validation are skipped;
        with any other mode the first failure is re-raised.
        """
        skip_invalid = validation_mode == "ignoreOnError"
        validated: list[TViewInstance] = []
        for row in data:
            try:
                instance = entity.model_validate(row)
            except Exception:
                if skip_invalid:
                    continue
                raise
            validated.append(instance)
        return validated
@@ -0,0 +1,19 @@
1
+ from .base import RootModel
2
+ from .entities import (
3
+ InstanceId,
4
+ PaginatedResult,
5
+ TViewInstance,
6
+ ValidationMode,
7
+ ViewInstance,
8
+ ViewInstanceConfig,
9
+ )
10
+
11
# Public API of the models package.
__all__ = [
    "RootModel", "InstanceId", "TViewInstance", "ViewInstance",
    "ValidationMode", "PaginatedResult", "ViewInstanceConfig",
]
@@ -0,0 +1,46 @@
1
+ from typing import (
2
+ Any,
3
+ dataclass_transform,
4
+ )
5
+
6
+ from pydantic import BaseModel, ConfigDict, Field
7
+ from pydantic._internal import _model_construction
8
+ from pydantic.alias_generators import to_camel
9
+
10
+ from industrial_model.statements import Column
11
+
12
+
13
@dataclass_transform(kw_only_default=True, field_specifiers=(Field,))
class DBModelMetaclass(_model_construction.ModelMetaclass):
    """Pydantic model metaclass that exposes fields as query columns.

    Once a class is built, accessing a class attribute that normal lookup
    cannot resolve but that names a model field (``MyView.some_field``)
    returns ``Column(alias or field_name)`` instead of raising
    ``AttributeError``.
    """

    # True while ``__new__`` is building a class. During that window field
    # names are not translated into ``Column`` objects (presumably because
    # ``model_fields`` is not ready yet — confirm against pydantic).
    is_constructing: bool = False

    def __new__(
        mcs,
        cls_name: str,
        bases: tuple[type[Any], ...],
        namespace: dict[str, Any],
        **kwargs: Any,
    ) -> type:
        previous = mcs.is_constructing
        mcs.is_constructing = True
        try:
            return super().__new__(mcs, cls_name, bases, namespace, **kwargs)
        finally:
            # Restore instead of forcing ``False``: a construction that
            # raises can no longer leave the flag stuck at ``True`` (which
            # disabled Column lookups for every model), and a nested class
            # construction no longer clears an outer construction's flag.
            mcs.is_constructing = previous

    def __getattr__(self, key: str) -> Any:
        if self.is_constructing:
            return super().__getattr__(key)  # type: ignore

        try:
            return super().__getattr__(key)  # type: ignore
        except AttributeError:
            if key in self.model_fields:
                return Column(self.model_fields[key].alias or key)
            raise
39
+
40
+
41
# Shared pydantic base: field aliases are generated in camelCase, models can
# be populated by field name or alias, and validation accepts arbitrary
# objects by reading their attributes. (Comment rather than docstring, since
# pydantic copies class docstrings into the JSON schema.)
class RootModel(BaseModel):
    model_config = ConfigDict(
        from_attributes=True,
        populate_by_name=True,
        alias_generator=to_camel,
    )
@@ -0,0 +1,55 @@
1
+ from typing import (
2
+ Any,
3
+ ClassVar,
4
+ Generic,
5
+ Literal,
6
+ TypedDict,
7
+ TypeVar,
8
+ )
9
+
10
+ from .base import DBModelMetaclass, RootModel
11
+
12
+
13
# Identifier of an instance (``external_id`` + ``space``). Hashable and
# compared by value so it can be used as a dict key or set member.
# (Comment rather than docstring: pydantic copies docstrings into schemas.)
class InstanceId(RootModel):
    external_id: str
    space: str

    def __hash__(self) -> int:
        return hash((self.external_id, self.space))

    def __eq__(self, other: Any) -> bool:
        """Compare ``external_id`` and ``space`` with another ``InstanceId``.

        Subclass instances compare equal to a plain ``InstanceId`` carrying
        the same identifiers. Non-``InstanceId`` operands return
        ``NotImplemented`` so Python can try the reflected comparison; with
        no other handler the result is the identity check, i.e. ``False``.
        """
        if not isinstance(other, InstanceId):
            return NotImplemented
        return (
            self.external_id == other.external_id
            and self.space == other.space
        )

    def as_tuple(self) -> tuple[str, str]:
        """Return the id as ``(space, external_id)``."""
        return (self.space, self.external_id)
30
+
31
+
32
+ class ViewInstanceConfig(TypedDict, total=False):
33
+ view_external_id: str | None
34
+ instance_spaces: list[str] | None
35
+ instance_spaces_prefix: str | None
36
+
37
+
38
# Base class for models backed by a view. Its metaclass returns a query
# ``Column`` when a field name is accessed on the class and normal attribute
# lookup fails. (Comment rather than docstring: pydantic copies docstrings
# into the JSON schema.)
class ViewInstance(InstanceId, metaclass=DBModelMetaclass):
    view_config: ClassVar[ViewInstanceConfig] = ViewInstanceConfig()

    @classmethod
    def get_view_external_id(cls) -> str:
        """Return the configured view external id, or the class name.

        Falls back to ``cls.__name__`` when ``view_config`` has no truthy
        ``view_external_id``.
        """
        configured = cls.view_config.get("view_external_id")
        if configured:
            return configured
        return cls.__name__


# Type variable for APIs generic over concrete ``ViewInstance`` subclasses.
TViewInstance = TypeVar("TViewInstance", bound=ViewInstance)
47
+
48
+
49
# One page of query results. ``next_cursor`` fetches the following page;
# ``Engine.query`` sets ``has_next_page`` to ``next_cursor is not None``.
# (Comment rather than docstring: pydantic copies docstrings into schemas.)
class PaginatedResult(RootModel, Generic[TViewInstance]):
    data: list[TViewInstance]
    has_next_page: bool
    next_cursor: str | None


# How row validation failures are handled: re-raise, or skip the row.
ValidationMode = Literal["raiseOnError", "ignoreOnError"]
File without changes
@@ -0,0 +1,10 @@
1
+ from .models import BasePaginatedQuery, BaseQuery
2
+ from .params import NestedQueryParam, QueryParam, SortParam
3
+
4
# Public API of the queries package.
__all__ = [
    "BaseQuery", "BasePaginatedQuery",
    "SortParam", "QueryParam", "NestedQueryParam",
]
@@ -0,0 +1,37 @@
1
+ from industrial_model.models import RootModel, TViewInstance
2
+ from industrial_model.statements import Statement
3
+
4
+ from .params import NestedQueryParam, QueryParam, SortParam
5
+
6
+
7
def _is_unset_query_value(value: object) -> bool:
    """Return True when a query field value should contribute no clause.

    ``None`` and empty strings/bytes/collections count as "not provided".
    Every other value, including ``False`` and ``0``, is kept so it can be
    used as a filter value.
    """
    if value is None:
        return True
    if isinstance(value, str | bytes | list | tuple | set | frozenset | dict):
        return len(value) == 0
    return False


# Base for query models whose fields become ``Statement`` clauses via
# ``SortParam`` / ``QueryParam`` / ``NestedQueryParam`` markers found in each
# field's metadata (e.g. attached with ``Annotated``). Comment rather than a
# docstring because pydantic copies class docstrings into the JSON schema.
class BaseQuery(RootModel):
    def to_statement(
        self, entity: type[TViewInstance]
    ) -> Statement[TViewInstance]:
        """Build a ``Statement`` over *entity* from this query's fields.

        For every field that has a value (see ``_is_unset_query_value``):

        - a ``SortParam`` marker adds a sort on the value with the marker's
          direction;
        - a ``QueryParam`` / ``NestedQueryParam`` marker adds a ``where``
          clause built by ``marker.to_expression(value)``.
        """
        statement = Statement(entity)

        for key, item in self.__class__.model_fields.items():
            values = getattr(self, key)
            # ``if not values`` used to drop False/0 as well, silently
            # discarding legitimate boolean/numeric filters.
            if _is_unset_query_value(values):
                continue
            for metadata_item in item.metadata:
                if isinstance(metadata_item, SortParam):
                    statement.sort(values, metadata_item.direction)
                elif isinstance(metadata_item, QueryParam | NestedQueryParam):
                    statement.where(metadata_item.to_expression(values))

        return statement
24
+
25
+
26
# Query model that also carries pagination settings (page size and cursor).
# Comment rather than a docstring: pydantic copies docstrings into schemas.
class BasePaginatedQuery(BaseQuery):
    limit: int = 1000
    cursor: str | None = None

    def to_statement(
        self, entity: type[TViewInstance]
    ) -> Statement[TViewInstance]:
        """Build the base statement, then apply ``limit`` and ``cursor``.

        ``Statement.limit`` and ``Statement.cursor`` return the statement
        itself, so the calls are chained.
        """
        base_statement = super().to_statement(entity)
        return base_statement.limit(self.limit).cursor(self.cursor)
@@ -0,0 +1,42 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any
3
+
4
+ from industrial_model.constants import (
5
+ LEAF_EXPRESSION_OPERATORS,
6
+ SORT_DIRECTION,
7
+ )
8
+ from industrial_model.statements import LeafExpression
9
+
10
+
11
@dataclass
class QueryParam:
    """Field marker turning a query value into a leaf filter expression.

    Attributes:
        property: Property to filter on.
        operator: Leaf operator to apply. ``"nested"`` is rejected here;
            use ``NestedQueryParam`` for nested filters.
    """

    property: str
    operator: LEAF_EXPRESSION_OPERATORS

    def to_expression(self, value: Any) -> LeafExpression:
        """Return ``LeafExpression(property, operator, value)``.

        Raises:
            ValueError: If ``operator`` is ``"nested"``.
        """
        if self.operator == "nested":
            # Message previously misspelled the class name ("QuertParam").
            raise ValueError("Can not have nested operator on QueryParam")

        return LeafExpression(
            property=self.property,
            operator=self.operator,
            value=value,
        )
25
+
26
+
27
@dataclass
class NestedQueryParam:
    """Field marker wrapping a ``QueryParam`` under a ``"nested"`` expression.

    Attributes:
        property: Property through which the inner filter is applied.
        value: Inner filter applied to the query value.
    """

    property: str
    value: QueryParam

    def to_expression(self, value: Any) -> LeafExpression:
        """Return a ``"nested"`` leaf whose value is the inner expression."""
        inner_expression = self.value.to_expression(value)
        return LeafExpression(
            property=self.property,
            operator="nested",
            value=inner_expression,
        )
38
+
39
+
40
@dataclass
class SortParam:
    """Field marker: sort the statement by the field's value.

    Attributes:
        direction: ``"ascending"`` or ``"descending"``.
    """

    direction: SORT_DIRECTION
@@ -0,0 +1,70 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Generic, Self, TypeVar
3
+
4
+ from industrial_model.constants import DEFAULT_LIMIT, SORT_DIRECTION
5
+
6
+ from .expressions import (
7
+ BoolExpression,
8
+ Column,
9
+ Expression,
10
+ LeafExpression,
11
+ and_,
12
+ col,
13
+ or_,
14
+ )
15
+
16
T = TypeVar("T")


@dataclass
class Statement(Generic[T]):
    """Fluent description of a query over ``entity``.

    Builder methods mutate the statement and return ``self`` so calls can be
    chained, e.g. ``select(Model).where(...).asc(...).limit(10)``.

    Attributes:
        entity: Model class the statement targets.
        where_clauses: Filter expressions added through ``where``.
        sort_clauses: ``(Column, direction)`` pairs, in insertion order.
        limit_: Requested result limit (defaults to ``DEFAULT_LIMIT``).
        cursor_: Pagination cursor, or ``None``.
    """

    entity: type[T] = field(init=True)
    where_clauses: list[Expression] = field(init=False, default_factory=list)
    sort_clauses: list[tuple[Column, SORT_DIRECTION]] = field(
        init=False, default_factory=list
    )
    limit_: int = field(init=False, default=DEFAULT_LIMIT)
    cursor_: str | None = field(init=False, default=None)

    def where(self, *expressions: bool | Expression) -> Self:
        """Append filter expressions to ``where_clauses``.

        ``bool`` appears in the annotation presumably so that comparisons
        such as ``Model.field == value`` type-check; at runtime every
        argument must be an ``Expression``.

        Raises:
            TypeError: If any argument is not an ``Expression``. All
                arguments are checked before anything is appended, so a
                failed call leaves the statement unchanged.
        """
        checked = [expr for expr in expressions if isinstance(expr, Expression)]
        if len(checked) != len(expressions):
            # Explicit check instead of ``assert``: asserts are stripped
            # under ``python -O``, which would let non-expressions through.
            offending = ", ".join(
                type(expr).__name__
                for expr in expressions
                if not isinstance(expr, Expression)
            )
            raise TypeError(
                f"where() expects Expression arguments, got: {offending}"
            )
        self.where_clauses.extend(checked)
        return self

    def asc(self, property: Any) -> Self:
        """Add an ascending sort on *property*."""
        return self.sort(property, "ascending")

    def desc(self, property: Any) -> Self:
        """Add a descending sort on *property*."""
        return self.sort(property, "descending")

    def sort(self, property: Any, direction: SORT_DIRECTION) -> Self:
        """Add a sort on ``Column(property)`` in *direction*."""
        self.sort_clauses.append((Column(property), direction))
        return self

    def limit(self, limit: int) -> Self:
        """Set the result limit."""
        self.limit_ = limit
        return self

    def cursor(self, cursor: str | None) -> Self:
        """Set (or clear, with ``None``) the pagination cursor."""
        self.cursor_ = cursor
        return self
54
+
55
+
56
def select(entity: type[T]) -> Statement[T]:
    """Start a new ``Statement`` over *entity* (same as ``Statement(entity)``)."""
    new_statement = Statement(entity)
    return new_statement
58
+
59
+
60
# Public API of the statements package.
__all__ = [
    "Statement", "select",
    "Column", "col",
    "Expression", "LeafExpression", "BoolExpression",
    "and_", "or_",
]