cognite-toolkit 0.5.60__py3-none-any.whl → 0.5.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,7 +1,12 @@
1
- from collections.abc import Callable
1
+ from abc import ABC, abstractmethod
2
+ from collections.abc import Callable, Mapping
2
3
  from concurrent.futures import ThreadPoolExecutor, as_completed
4
+ from functools import cached_property
5
+ from typing import ClassVar, Literal, TypeAlias, overload
3
6
 
4
7
  from cognite.client.exceptions import CogniteException
8
+ from rich import box
9
+ from rich.console import Console
5
10
  from rich.live import Live
6
11
  from rich.spinner import Spinner
7
12
  from rich.table import Table
@@ -23,108 +28,212 @@ from cognite_toolkit._cdf_tk.utils.aggregators import (
23
28
  from ._base import ToolkitCommand
24
29
 
25
30
 
26
- class ProfileCommand(ToolkitCommand):
27
- class Columns:
28
- Resource = "Resource"
29
- Count = "Count"
30
- MetadataKeyCount = "Metadata Key Count"
31
- LabelCount = "Label Count"
32
- Transformation = "Transformations"
31
+ class WaitingAPICallClass:
32
+ def __bool__(self) -> bool:
33
+ return False
33
34
 
34
- columns = (
35
- Columns.Resource,
36
- Columns.Count,
37
- Columns.MetadataKeyCount,
38
- Columns.LabelCount,
39
- Columns.Transformation,
40
- )
41
- spinner_speed = 1.0
42
35
 
43
- @classmethod
44
- def asset_centric(
45
- cls,
46
- client: ToolkitClient,
47
- verbose: bool = False,
48
- ) -> list[dict[str, str]]:
49
- aggregators: list[AssetCentricAggregator] = [
50
- AssetAggregator(client),
51
- EventAggregator(client),
52
- FileAggregator(client),
53
- TimeSeriesAggregator(client),
54
- SequenceAggregator(client),
55
- RelationshipAggregator(client),
56
- LabelCountAggregator(client),
57
- ]
58
- results, api_calls = cls._create_initial_table(aggregators)
59
- with Live(cls.create_profile_table(results), refresh_per_second=4) as live:
60
- with ThreadPoolExecutor(max_workers=8) as executor:
61
- future_to_cell = {
62
- executor.submit(api_calls[(index, col)]): (index, col)
63
- for index in range(len(aggregators))
64
- for col in cls.columns
65
- if (index, col) in api_calls
36
+ WaitingAPICall = WaitingAPICallClass()
37
+
38
+ PendingCellValue: TypeAlias = int | float | str | bool | None | WaitingAPICallClass
39
+ CellValue: TypeAlias = int | float | str | bool | None
40
+ PendingTable: TypeAlias = dict[tuple[str, str], PendingCellValue]
41
+
42
+
43
+ class ProfileCommand(ToolkitCommand, ABC):
44
+ def __init__(self, print_warning: bool = True, skip_tracking: bool = False, silent: bool = False) -> None:
45
+ super().__init__(print_warning, skip_tracking, silent)
46
+ self.table_title = self.__class__.__name__.removesuffix("Command")
47
+
48
+ class Columns: # Placeholder for columns, subclasses should define their own Columns class
49
+ ...
50
+
51
+ spinner_args: ClassVar[Mapping] = dict(name="arc", text="loading...", style="bold green", speed=1.0)
52
+
53
+ max_workers = 8
54
+ is_dynamic_table = False
55
+
56
+ @cached_property
57
+ def columns(self) -> tuple[str, ...]:
58
+ return (
59
+ tuple([attr for attr in self.Columns.__dict__.keys() if not attr.startswith("_")])
60
+ if hasattr(self, "Columns")
61
+ else tuple()
62
+ )
63
+
64
+ def create_profile_table(self, client: ToolkitClient) -> list[dict[str, CellValue]]:
65
+ console = Console()
66
+ with console.status("Setting up", spinner="aesthetic", speed=0.4) as _:
67
+ table = self.create_initial_table(client)
68
+ with (
69
+ Live(self.draw_table(table), refresh_per_second=4, console=console) as live,
70
+ ThreadPoolExecutor(max_workers=self.max_workers) as executor,
71
+ ):
72
+ while True:
73
+ current_calls = {
74
+ executor.submit(self.call_api(row, col, client)): (row, col)
75
+ for (row, col), cell in table.items()
76
+ if cell is WaitingAPICall
66
77
  }
67
- for future in as_completed(future_to_cell):
68
- index, col = future_to_cell[future]
69
- results[index][col] = future.result()
70
- live.update(cls.create_profile_table(results))
71
- return [{col: str(value) for col, value in row.items()} for row in results]
78
+ if not current_calls:
79
+ break
80
+ for future in as_completed(current_calls):
81
+ row, col = current_calls[future]
82
+ try:
83
+ result = future.result()
84
+ except CogniteException as e:
85
+ result = type(e).__name__
86
+ table[(row, col)] = self.format_result(result, row, col)
87
+ if self.is_dynamic_table:
88
+ table = self.update_table(table, result, row, col)
89
+ live.update(self.draw_table(table))
90
+ return self.as_record_format(table, allow_waiting_api_call=False)
72
91
 
73
- @classmethod
74
- def _create_initial_table(
75
- cls, aggregators: list[AssetCentricAggregator]
76
- ) -> tuple[list[dict[str, str | Spinner]], dict[tuple[int, str], Callable[[], str]]]:
77
- rows: list[dict[str, str | Spinner]] = []
78
- api_calls: dict[tuple[int, str], Callable[[], str]] = {}
79
- for index, aggregator in enumerate(aggregators):
80
- row: dict[str, str | Spinner] = {
81
- cls.Columns.Resource: aggregator.display_name,
82
- cls.Columns.Count: Spinner("arc", text="loading...", style="bold green", speed=cls.spinner_speed),
83
- }
84
- api_calls[(index, cls.Columns.Count)] = cls._call_api(aggregator.count)
85
- count: str | Spinner = "-"
86
- if isinstance(aggregator, MetadataAggregator):
87
- count = Spinner("arc", text="loading...", style="bold green", speed=cls.spinner_speed)
88
- api_calls[(index, cls.Columns.MetadataKeyCount)] = cls._call_api(aggregator.metadata_key_count)
89
- row[cls.Columns.MetadataKeyCount] = count
92
+ @abstractmethod
93
+ def create_initial_table(self, client: ToolkitClient) -> PendingTable:
94
+ """
95
+ Create the initial table with placeholders for API calls.
96
+ Each cell that requires an API call should be initialized with WaitingAPICall.
97
+ """
98
+ raise NotImplementedError("Subclasses must implement create_initial_table.")
90
99
 
91
- count = "-"
92
- if isinstance(aggregator, LabelAggregator):
93
- count = Spinner("arc", text="loading...", style="bold green", speed=cls.spinner_speed)
94
- api_calls[(index, cls.Columns.LabelCount)] = cls._call_api(aggregator.label_count)
95
- row[cls.Columns.LabelCount] = count
100
+ @abstractmethod
101
+ def call_api(self, row: str, col: str, client: ToolkitClient) -> Callable:
102
+ raise NotImplementedError("Subclasses must implement call_api.")
96
103
 
97
- row[cls.Columns.Transformation] = Spinner(
98
- "arc", text="loading...", style="bold green", speed=cls.spinner_speed
99
- )
100
- api_calls[(index, cls.Columns.Transformation)] = cls._call_api(aggregator.transformation_count)
104
+ def format_result(self, result: object, row: str, col: str) -> CellValue:
105
+ """
106
+ Format the result of an API call for display in the table.
107
+ This can be overridden by subclasses to customize formatting.
108
+ """
109
+ if isinstance(result, int | float | bool | str):
110
+ return result
111
+ raise NotImplementedError("Subclasses must implement format_result.")
101
112
 
102
- rows.append(row)
103
- return rows, api_calls
113
+ def update_table(
114
+ self,
115
+ current_table: PendingTable,
116
+ result: object,
117
+ row: str,
118
+ col: str,
119
+ ) -> PendingTable:
120
+ raise NotImplementedError("Subclasses must implement update_table.")
104
121
 
105
- @classmethod
106
- def create_profile_table(cls, rows: list[dict[str, str | Spinner]]) -> Table:
107
- table = Table(
108
- title="Asset Centric Profile",
122
+ def draw_table(self, table: PendingTable) -> Table:
123
+ rich_table = Table(
124
+ title=self.table_title,
109
125
  title_justify="left",
110
126
  show_header=True,
111
127
  header_style="bold magenta",
128
+ box=box.MINIMAL,
112
129
  )
113
- for col in cls.columns:
114
- table.add_column(col)
130
+ for col in self.columns:
131
+ rich_table.add_column(col)
132
+
133
+ rows = self.as_record_format(table)
115
134
 
116
135
  for row in rows:
117
- table.add_row(*row.values())
118
- return table
136
+ rich_table.add_row(*[self._as_cell(value) for value in row.values()])
137
+ return rich_table
138
+
139
+ @classmethod
140
+ @overload
141
+ def as_record_format(
142
+ cls, table: PendingTable, allow_waiting_api_call: Literal[True] = True
143
+ ) -> list[dict[str, PendingCellValue]]: ...
144
+
145
+ @classmethod
146
+ @overload
147
+ def as_record_format(
148
+ cls,
149
+ table: PendingTable,
150
+ allow_waiting_api_call: Literal[False],
151
+ ) -> list[dict[str, CellValue]]: ...
119
152
 
120
- @staticmethod
121
- def _call_api(call_fun: Callable[[], int]) -> Callable[[], str]:
122
- def styled_callable() -> str:
123
- try:
124
- value = call_fun()
125
- except CogniteException as e:
126
- return type(e).__name__
153
+ @classmethod
154
+ def as_record_format(
155
+ cls,
156
+ table: PendingTable,
157
+ allow_waiting_api_call: bool = True,
158
+ ) -> list[dict[str, PendingCellValue]] | list[dict[str, CellValue]]:
159
+ rows: list[dict[str, PendingCellValue]] = []
160
+ row_indices: dict[str, int] = {}
161
+ for (row, col), value in table.items():
162
+ if value is WaitingAPICall and not allow_waiting_api_call:
163
+ value = None
164
+ if row not in row_indices:
165
+ row_indices[row] = len(rows)
166
+ rows.append({col: value})
127
167
  else:
128
- return f"{value:,}"
168
+ rows[row_indices[row]][col] = value
169
+ return rows
170
+
171
+ def _as_cell(self, value: PendingCellValue) -> str | Spinner:
172
+ if isinstance(value, WaitingAPICallClass):
173
+ return Spinner(**self.spinner_args)
174
+ elif isinstance(value, int):
175
+ return f"{value:,}"
176
+ elif isinstance(value, float):
177
+ return f"{value:.2f}"
178
+ elif value is None:
179
+ return "-"
180
+ return str(value)
181
+
182
+
183
+ class ProfileAssetCentricCommand(ProfileCommand):
184
+ def __init__(self, print_warning: bool = True, skip_tracking: bool = False, silent: bool = False) -> None:
185
+ super().__init__(print_warning, skip_tracking, silent)
186
+ self.table_title = "Asset Centric Profile"
187
+ self.aggregators: dict[str, AssetCentricAggregator] = {}
188
+
189
+ class Columns:
190
+ Resource = "Resource"
191
+ Count = "Count"
192
+ MetadataKeyCount = "Metadata Key Count"
193
+ LabelCount = "Label Count"
194
+ Transformation = "Transformations"
195
+
196
+ def asset_centric(self, client: ToolkitClient, verbose: bool = False) -> list[dict[str, CellValue]]:
197
+ self.aggregators.update(
198
+ {
199
+ agg.display_name: agg
200
+ for agg in [
201
+ AssetAggregator(client),
202
+ EventAggregator(client),
203
+ FileAggregator(client),
204
+ TimeSeriesAggregator(client),
205
+ SequenceAggregator(client),
206
+ RelationshipAggregator(client),
207
+ LabelCountAggregator(client),
208
+ ]
209
+ }
210
+ )
211
+ return self.create_profile_table(client)
212
+
213
+ def create_initial_table(self, client: ToolkitClient) -> PendingTable:
214
+ table: dict[tuple[str, str], str | int | float | bool | None | WaitingAPICallClass] = {}
215
+ for index, aggregator in self.aggregators.items():
216
+ table[(index, self.Columns.Resource)] = aggregator.display_name
217
+ table[(index, self.Columns.Count)] = WaitingAPICall
218
+ if isinstance(aggregator, MetadataAggregator):
219
+ table[(index, self.Columns.MetadataKeyCount)] = WaitingAPICall
220
+ else:
221
+ table[(index, self.Columns.MetadataKeyCount)] = None
222
+ if isinstance(aggregator, LabelAggregator):
223
+ table[(index, self.Columns.LabelCount)] = WaitingAPICall
224
+ else:
225
+ table[(index, self.Columns.LabelCount)] = None
226
+ table[(index, self.Columns.Transformation)] = WaitingAPICall
227
+ return table
129
228
 
130
- return styled_callable
229
+ def call_api(self, row: str, col: str, client: ToolkitClient) -> Callable:
230
+ aggregator = self.aggregators[row]
231
+ if col == self.Columns.Count:
232
+ return aggregator.count
233
+ elif col == self.Columns.MetadataKeyCount and isinstance(aggregator, MetadataAggregator):
234
+ return aggregator.metadata_key_count
235
+ elif col == self.Columns.LabelCount and isinstance(aggregator, LabelAggregator):
236
+ return aggregator.label_count
237
+ elif col == self.Columns.Transformation:
238
+ return aggregator.transformation_count
239
+ raise ValueError(f"Unknown column: {col} for row: {row}")
@@ -1,12 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import shutil
4
+ import tempfile
5
+ import zipfile
4
6
  from collections import Counter
7
+ from hashlib import sha256
5
8
  from importlib import resources
6
9
  from pathlib import Path
10
+ from types import TracebackType
7
11
  from typing import Any, Literal, Optional
8
12
 
9
13
  import questionary
14
+ import requests
10
15
  import typer
11
16
  from packaging.version import Version
12
17
  from packaging.version import parse as parse_version
@@ -14,7 +19,7 @@ from rich import print
14
19
  from rich.markdown import Markdown
15
20
  from rich.padding import Padding
16
21
  from rich.panel import Panel
17
- from rich.progress import track
22
+ from rich.progress import Progress, track
18
23
  from rich.rule import Rule
19
24
  from rich.table import Table
20
25
  from rich.tree import Tree
@@ -47,7 +52,8 @@ from cognite_toolkit._cdf_tk.data_classes import (
47
52
  Package,
48
53
  Packages,
49
54
  )
50
- from cognite_toolkit._cdf_tk.exceptions import ToolkitRequiredValueError, ToolkitValueError
55
+ from cognite_toolkit._cdf_tk.exceptions import ToolkitError, ToolkitRequiredValueError, ToolkitValueError
56
+ from cognite_toolkit._cdf_tk.feature_flags import Flags
51
57
  from cognite_toolkit._cdf_tk.hints import verify_module_directory
52
58
  from cognite_toolkit._cdf_tk.tk_warnings import MediumSeverityWarning
53
59
  from cognite_toolkit._cdf_tk.utils import humanize_collection, read_yaml_file
@@ -85,6 +91,28 @@ class ModulesCommand(ToolkitCommand):
85
91
  def __init__(self, print_warning: bool = True, skip_tracking: bool = False, silent: bool = False):
86
92
  super().__init__(print_warning, skip_tracking, silent)
87
93
  self._builtin_modules_path = Path(resources.files(cognite_toolkit.__name__)) / BUILTIN_MODULES # type: ignore [arg-type]
94
+ self._temp_download_dir = Path(tempfile.gettempdir()) / "library_downloads"
95
+ if not self._temp_download_dir.exists():
96
+ self._temp_download_dir.mkdir(parents=True, exist_ok=True)
97
+
98
+ def __enter__(self) -> ModulesCommand:
99
+ """
100
+ Context manager to ensure the temporary download directory is cleaned up after use. It requires the command to be used in a `with` block.
101
+ """
102
+ return self
103
+
104
+ def __exit__(
105
+ self,
106
+ exc_type: type[BaseException] | None, # Type of the exception
107
+ exc_value: BaseException | None, # Exception instance
108
+ traceback: TracebackType | None, # Traceback object
109
+ ) -> None:
110
+ """
111
+ Clean up the temporary download directory.
112
+ """
113
+
114
+ if self._temp_download_dir.exists():
115
+ safe_rmtree(self._temp_download_dir)
88
116
 
89
117
  @classmethod
90
118
  def _create_tree(cls, item: Packages) -> Tree:
@@ -128,6 +156,7 @@ class ModulesCommand(ToolkitCommand):
128
156
  downloader_by_repo: dict[str, FileDownloader] = {}
129
157
 
130
158
  extra_resources: set[Path] = set()
159
+
131
160
  for package_name, package in selected_packages.items():
132
161
  print(f"{INDENT}[{'yellow' if mode == 'clean' else 'green'}]Creating {package_name}[/]")
133
162
 
@@ -280,7 +309,8 @@ default_organization_dir = "{organization_dir.name}"''',
280
309
  organization_dir = Path(organization_dir_raw.strip())
281
310
 
282
311
  modules_root_dir = organization_dir / MODULES
283
- packages = Packages().load(self._builtin_modules_path)
312
+
313
+ packages = self._get_available_packages()
284
314
 
285
315
  if select_all:
286
316
  print(Panel("Instantiating all available modules"))
@@ -680,9 +710,102 @@ default_organization_dir = "{organization_dir.name}"''',
680
710
  build_env = default.environment.validation_type
681
711
 
682
712
  existing_module_names = [module.name for module in ModuleResources(organization_dir, build_env).list()]
683
- available_packages = Packages().load(self._builtin_modules_path)
684
-
713
+ available_packages = self._get_available_packages()
685
714
  added_packages = self._select_packages(available_packages, existing_module_names)
686
715
 
687
716
  download_data = self._get_download_data(added_packages)
688
717
  self._create(organization_dir, added_packages, environments, "update", download_data)
718
+
719
+ def _get_available_packages(self) -> Packages:
720
+ """
721
+ Returns a list of available packages, either from the CDF TOML file or from external libraries if the feature flag is enabled.
722
+ If the feature flag is not enabled and no libraries are specified, it returns the built-in modules.
723
+ """
724
+
725
+ cdf_toml = CDFToml.load()
726
+ if Flags.EXTERNAL_LIBRARIES.is_enabled() and cdf_toml.libraries:
727
+ for library_name, library in cdf_toml.libraries.items():
728
+ try:
729
+ print(f"[green]Adding library {library_name}[/]")
730
+ file_path = self._temp_download_dir / f"{library_name}.zip"
731
+ self._download(library.url, file_path)
732
+ self._validate_checksum(library.checksum, file_path)
733
+ self._unpack(file_path)
734
+ return Packages().load(file_path.parent)
735
+ except Exception as e:
736
+ if isinstance(e, ToolkitError):
737
+ raise e
738
+ else:
739
+ raise ToolkitError(
740
+ f"An unexpected error occurred during downloading {library.url} to {file_path}: {e}"
741
+ ) from e
742
+
743
+ raise ToolkitError(f"Failed to add library {library_name}, {e}")
744
+ # If no libraries are specified or the flag is not enabled, load the built-in modules
745
+ raise ValueError("No valid libraries found.")
746
+ else:
747
+ return Packages.load(self._builtin_modules_path)
748
+
749
+ def _download(self, url: str, file_path: Path) -> None:
750
+ """
751
+ Downloads a file from a URL to the specified output path.
752
+ If the file already exists, it skips the download.
753
+ """
754
+ try:
755
+ response = requests.get(url, stream=True)
756
+ response.raise_for_status() # Raise an exception for HTTP errors
757
+
758
+ total_size = int(response.headers.get("content-length", 0))
759
+
760
+ with Progress() as progress:
761
+ task = progress.add_task("Download", total=total_size)
762
+ with open(file_path, "wb") as f:
763
+ for chunk in response.iter_content(chunk_size=8192):
764
+ f.write(chunk)
765
+ progress.update(task, advance=len(chunk))
766
+
767
+ except requests.exceptions.RequestException as e:
768
+ raise ToolkitError(f"Error downloading file from {url}: {e}") from e
769
+
770
+ def _validate_checksum(self, checksum: str, file_path: Path) -> None:
771
+ """
772
+ Compares the checksum of the downloaded file with the expected checksum.
773
+ """
774
+
775
+ if checksum.lower().startswith("sha256:"):
776
+ checksum = checksum[7:]
777
+ else:
778
+ raise ToolkitValueError(f"Unsupported checksum format: {checksum}. Expected 'sha256:' prefix")
779
+
780
+ chunk_size: int = 8192
781
+ sha256_hash = sha256()
782
+ try:
783
+ with open(file_path, "rb") as f:
784
+ # Read the file in chunks to handle large files efficiently
785
+ for chunk in iter(lambda: f.read(chunk_size), b""):
786
+ sha256_hash.update(chunk)
787
+ calculated = sha256_hash.hexdigest()
788
+ if calculated != checksum:
789
+ raise ToolkitError(f"Checksum mismatch. Expected {checksum}, got {calculated}.")
790
+ else:
791
+ print("Checksum verified")
792
+ except Exception as e:
793
+ raise ToolkitError(f"Failed to calculate checksum for {file_path}: {e}") from e
794
+
795
+ def _unpack(self, file_path: Path) -> None:
796
+ """
797
+ Unzips the downloaded file to the specified output path.
798
+ If the file is not a zip file, it raises an error.
799
+ """
800
+ total_size = file_path.stat().st_size if file_path.exists() else 0
801
+
802
+ try:
803
+ with Progress() as progress:
804
+ unzip = progress.add_task("Unzipping", total=total_size)
805
+ with zipfile.ZipFile(file_path, "r") as zip_ref:
806
+ zip_ref.extractall(file_path.parent)
807
+ progress.update(unzip, advance=total_size)
808
+ except zipfile.BadZipFile as e:
809
+ raise ToolkitError(f"Error unpacking zip file {file_path}: {e}") from e
810
+ except Exception as e:
811
+ raise ToolkitError(f"An unexpected error occurred while unpacking {file_path}: {e}") from e
@@ -66,8 +66,8 @@ class Packages(dict, MutableMapping[str, Package]):
66
66
  root_module_dir: The module directories to load the packages from.
67
67
  """
68
68
 
69
- package_definition_path = root_module_dir / "package.toml"
70
- if not package_definition_path.exists():
69
+ package_definition_path = next(root_module_dir.rglob("packages.toml"), None)
70
+ if not package_definition_path or not package_definition_path.exists():
71
71
  raise ToolkitFileNotFoundError(f"Package manifest toml not found at {package_definition_path}")
72
72
  package_definitions = toml.loads(package_definition_path.read_text(encoding="utf-8"))["packages"]
73
73
 
@@ -52,6 +52,10 @@ class Flags(Enum):
52
52
  "visible": True,
53
53
  "description": "Enables the migrate command",
54
54
  }
55
+ EXTERNAL_LIBRARIES: ClassVar[dict[str, Any]] = { # type: ignore[misc]
56
+ "visible": True,
57
+ "description": "Enables the support for external libraries in the config file",
58
+ }
55
59
 
56
60
  def is_enabled(self) -> bool:
57
61
  return FeatureFlag.is_enabled(self)
@@ -48,6 +48,11 @@ class AllScope(Scope):
48
48
  _scope_name = "all"
49
49
 
50
50
 
51
+ class AppConfigScope(Scope):
52
+ _scope_name = "appScope"
53
+ apps: list[Literal["SEARCH"]]
54
+
55
+
51
56
  class CurrentUserScope(Scope):
52
57
  _scope_name = "currentuserscope"
53
58
 
@@ -175,6 +180,12 @@ class AnnotationsAcl(Capability):
175
180
  scope: AllScope
176
181
 
177
182
 
183
+ class AppConfigAcl(Capability):
184
+ _capability_name = "appConfigAcl"
185
+ actions: list[Literal["READ", "WRITE"]]
186
+ scope: AllScope | AppConfigScope
187
+
188
+
178
189
  class AssetsAcl(Capability):
179
190
  _capability_name = "assetsAcl"
180
191
  actions: list[Literal["READ", "WRITE"]]