cognite-toolkit 0.6.87__py3-none-any.whl → 0.6.89__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. cognite_toolkit/_cdf_tk/cdf_toml.py +13 -9
  2. cognite_toolkit/_cdf_tk/commands/_base.py +2 -1
  3. cognite_toolkit/_cdf_tk/commands/_migrate/canvas.py +60 -5
  4. cognite_toolkit/_cdf_tk/commands/_migrate/command.py +4 -2
  5. cognite_toolkit/_cdf_tk/commands/_migrate/conversion.py +161 -44
  6. cognite_toolkit/_cdf_tk/commands/_migrate/data_classes.py +10 -10
  7. cognite_toolkit/_cdf_tk/commands/_migrate/data_mapper.py +7 -3
  8. cognite_toolkit/_cdf_tk/commands/_migrate/migration_io.py +8 -10
  9. cognite_toolkit/_cdf_tk/commands/build_cmd.py +5 -7
  10. cognite_toolkit/_cdf_tk/commands/deploy.py +11 -0
  11. cognite_toolkit/_cdf_tk/commands/modules.py +16 -12
  12. cognite_toolkit/_cdf_tk/data_classes/__init__.py +2 -0
  13. cognite_toolkit/_cdf_tk/data_classes/_config_yaml.py +7 -1
  14. cognite_toolkit/_cdf_tk/data_classes/_module_directories.py +8 -0
  15. cognite_toolkit/_cdf_tk/data_classes/_tracking_info.py +43 -0
  16. cognite_toolkit/_cdf_tk/storageio/__init__.py +2 -0
  17. cognite_toolkit/_cdf_tk/storageio/_annotations.py +102 -0
  18. cognite_toolkit/_cdf_tk/tracker.py +9 -19
  19. cognite_toolkit/_cdf_tk/utils/fileio/_readers.py +90 -44
  20. cognite_toolkit/_cdf_tk/utils/http_client/_client.py +6 -4
  21. cognite_toolkit/_cdf_tk/utils/http_client/_data_classes.py +2 -0
  22. cognite_toolkit/_cdf_tk/utils/useful_types.py +7 -4
  23. cognite_toolkit/_repo_files/GitHub/.github/workflows/deploy.yaml +1 -1
  24. cognite_toolkit/_repo_files/GitHub/.github/workflows/dry-run.yaml +1 -1
  25. cognite_toolkit/_resources/cdf.toml +1 -1
  26. cognite_toolkit/_version.py +1 -1
  27. {cognite_toolkit-0.6.87.dist-info → cognite_toolkit-0.6.89.dist-info}/METADATA +1 -1
  28. {cognite_toolkit-0.6.87.dist-info → cognite_toolkit-0.6.89.dist-info}/RECORD +31 -30
  29. cognite_toolkit/_cdf_tk/commands/_migrate/base.py +0 -106
  30. {cognite_toolkit-0.6.87.dist-info → cognite_toolkit-0.6.89.dist-info}/WHEEL +0 -0
  31. {cognite_toolkit-0.6.87.dist-info → cognite_toolkit-0.6.89.dist-info}/entry_points.txt +0 -0
  32. {cognite_toolkit-0.6.87.dist-info → cognite_toolkit-0.6.89.dist-info}/licenses/LICENSE +0 -0

cognite_toolkit/_cdf_tk/commands/modules.py
@@ -95,11 +95,14 @@ class ModulesCommand(ToolkitCommand):
         print_warning: bool = True,
         skip_tracking: bool = False,
         silent: bool = False,
+        temp_dir_suffix: str | None = None,
         module_source_dir: Path | None = None,
     ):
         super().__init__(print_warning, skip_tracking, silent)
         self._module_source_dir: Path | None = module_source_dir
-        self._temp_download_dir = Path(tempfile.gettempdir()) / MODULES
+        # Use suffix to make temp directory unique (useful for parallel test execution)
+        modules_dir_name = f"{MODULES}.{temp_dir_suffix}" if temp_dir_suffix else MODULES
+        self._temp_download_dir = Path(tempfile.gettempdir()) / modules_dir_name
         if not self._temp_download_dir.exists():
             self._temp_download_dir.mkdir(parents=True, exist_ok=True)
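
The new temp_dir_suffix argument keeps the temporary download directory unique per process, which matters when tests drive ModulesCommand in parallel. A minimal sketch of the resolution logic, assuming the MODULES constant is the string "modules":

import tempfile
from pathlib import Path

MODULES = "modules"  # assumption: mirrors the toolkit's MODULES constant

def temp_download_dir(temp_dir_suffix: str | None = None) -> Path:
    # Same resolution as ModulesCommand.__init__ above.
    modules_dir_name = f"{MODULES}.{temp_dir_suffix}" if temp_dir_suffix else MODULES
    return Path(tempfile.gettempdir()) / modules_dir_name

print(temp_download_dir())       # e.g. /tmp/modules
print(temp_download_dir("gw1"))  # e.g. /tmp/modules.gw1, one directory per pytest-xdist worker
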
@@ -170,6 +173,11 @@ class ModulesCommand(ToolkitCommand):
             print(f"{INDENT}[{'yellow' if mode == 'clean' else 'green'}]Creating {package_name}[/]")

             for module in package.modules:
+                if module.module_id:
+                    self._additional_tracking_info.installed_module_ids.add(module.module_id)
+                if module.package_id:
+                    self._additional_tracking_info.installed_package_ids.add(module.package_id)
+
                 if module.dir in seen_modules:
                     # A module can be part of multiple packages
                     continue
@@ -769,7 +777,6 @@ default_organization_dir = "{organization_dir.name}"''',
     def _get_available_packages(self, user_library: Library | None = None) -> tuple[Packages, Path]:
         """
         Returns a list of available packages, either from the CDF TOML file or from external libraries if the feature flag is enabled.
-        If the feature flag is not enabled and no libraries are specified, it returns the built-in modules.
         """

         cdf_toml = CDFToml.load()
@@ -778,9 +785,8 @@ default_organization_dir = "{organization_dir.name}"''',

             for library_name, library in libraries.items():
                 try:
-                    additional_tracking_info = self._additional_tracking_info.setdefault("downloadedLibraryIds", [])
-                    if library_name not in additional_tracking_info:
-                        additional_tracking_info.append(library_name)
+                    if library_name:
+                        self._additional_tracking_info.downloaded_library_ids.add(library_name)

                     print(f"[green]Adding library {library_name} from {library.url}[/]")
                     # Extract filename from URL, fallback to library_name.zip if no filename found
@@ -802,14 +808,12 @@ default_organization_dir = "{organization_dir.name}"''',

                     # Track deployment pack download for each package and module
                     for package in packages.values():
-                        downloaded_package_ids = self._additional_tracking_info.setdefault("downloadedPackageIds", [])
-                        if package.id and package.id not in downloaded_package_ids:
-                            downloaded_package_ids.append(package.id)
+                        if package.id:
+                            self._additional_tracking_info.downloaded_package_ids.add(package.id)

-                        downloaded_module_ids = self._additional_tracking_info.setdefault("downloadedModuleIds", [])
                         for module in package.modules:
-                            if module.module_id and module.module_id not in downloaded_module_ids:
-                                downloaded_module_ids.append(module.module_id)
+                            if module.module_id:
+                                self._additional_tracking_info.downloaded_module_ids.add(module.module_id)

                     return packages, file_path.parent
                 except Exception as e:
@@ -821,7 +825,7 @@ default_organization_dir = "{organization_dir.name}"''',
                     ) from e

                     raise ToolkitError(f"Failed to add library {library_name}, {e}")
-            # If no libraries are specified or the flag is not enabled, load the built-in modules
+            # If no libraries are specified or the flag is not enabled, raise an error
             raise ValueError("No valid libraries found.")
         else:
             if user_library:
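
The tracking refactor across these hunks replaces the dict-of-lists pattern (setdefault plus an explicit membership check) with typed set fields on CommandTrackingInfo, where add() is idempotent. A small plain-Python illustration of the two patterns, with made-up IDs:

# Up to 0.6.87: a dict of lists with manual de-duplication.
tracking: dict[str, list[str]] = {}
ids = tracking.setdefault("downloadedPackageIds", [])
for package_id in ["pkg-a", "pkg-a", "pkg-b"]:
    if package_id not in ids:
        ids.append(package_id)

# From 0.6.89: a set field, so duplicates are dropped automatically.
downloaded_package_ids: set[str] = set()
for package_id in ["pkg-a", "pkg-a", "pkg-b"]:
    downloaded_package_ids.add(package_id)

assert ids == ["pkg-a", "pkg-b"]
assert downloaded_package_ids == {"pkg-a", "pkg-b"}
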
cognite_toolkit/_cdf_tk/data_classes/__init__.py
@@ -32,6 +32,7 @@ from ._deploy_results import (
 from ._module_directories import ModuleDirectories, ModuleLocation
 from ._module_resources import ModuleResources
 from ._packages import Package, Packages
+from ._tracking_info import CommandTrackingInfo
 from ._yaml_comments import YAMLComments

 __all__ = [
@@ -47,6 +48,7 @@ __all__ = [
     "BuiltResource",
     "BuiltResourceFull",
     "BuiltResourceList",
+    "CommandTrackingInfo",
     "ConfigEntry",
     "ConfigYAMLs",
     "DatapointDeployResult",

cognite_toolkit/_cdf_tk/data_classes/_config_yaml.py
@@ -496,7 +496,9 @@ class InitConfigYAML(YAMLWithComments[tuple[str, ...], ConfigEntry], ConfigYAMLC
         adds them to the config.yaml file.

         Args:
-            cognite_root_module: The root module for all cognite modules.
+            cognite_root_module: Path to the root directory containing all Cognite modules.
+            defaults_files: List of paths to default.config.yaml files to load.
+            ignore_patterns: Optional list of tuples containing patterns to ignore when loading defaults.

         Returns:
             self
@@ -509,6 +511,10 @@
             raw_file = safe_read(default_config)
             file_comments = self._extract_comments(raw_file, key_prefix=tuple(parts))
             file_data = cast(dict, read_yaml_content(raw_file))
+
+            # a file may exist, but contain just comments, thus the file_data is None
+            if file_data is None:
+                continue
             for key, value in file_data.items():
                 if len(parts) >= 1 and parts[0] in ROOT_MODULES:
                     key_path = (self._variables, *parts, key)
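
The new guard handles a default.config.yaml that exists but contains only comments: PyYAML parses such a document to None rather than an empty mapping, so calling .items() on the result would fail. A quick standalone check (plain PyYAML standing in for read_yaml_content):

import yaml

raw_file = "# my_variable: value\n# another_variable: 42\n"  # hypothetical comment-only config

file_data = yaml.safe_load(raw_file)
print(file_data)  # None, not {}

# Hence the guard added above before iterating file_data.items().
if file_data is None:
    print("skipping comment-only defaults file")
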
cognite_toolkit/_cdf_tk/data_classes/_module_directories.py
@@ -164,6 +164,8 @@ class ModuleLocation:
         return ReadModule(
             dir=self.dir,
             resource_directories=tuple(self.resource_directories),
+            module_id=self.module_id,
+            package_id=self.package_id,
         )


@@ -178,6 +180,8 @@ class ReadModule:

     dir: Path
     resource_directories: tuple[str, ...]
+    module_id: str | None
+    package_id: str | None

     def resource_dir_path(self, resource_folder: str) -> Path | None:
         """Returns the path to a resource in the module.
@@ -198,12 +202,16 @@
         return cls(
             dir=Path(data["dir"]),
             resource_directories=tuple(data["resource_directories"]),
+            module_id=data.get("module_id"),
+            package_id=data.get("package_id"),
         )

     def dump(self) -> dict[str, Any]:
         return {
             "dir": self.dir.as_posix(),
             "resource_directories": list(self.resource_directories),
+            "module_id": self.module_id,
+            "package_id": self.package_id,
         }


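
Because the loader uses data.get(...), dumps written by older toolkit versions (without the two new keys) still load, with the IDs falling back to None. A hedged round-trip sketch using a stand-in dataclass with the same dump/load shape as ReadModule:

from dataclasses import dataclass
from pathlib import Path
from typing import Any

@dataclass
class ReadModuleSketch:  # stand-in for ReadModule; same dump/load shape as above
    dir: Path
    resource_directories: tuple[str, ...]
    module_id: str | None
    package_id: str | None

    def dump(self) -> dict[str, Any]:
        return {
            "dir": self.dir.as_posix(),
            "resource_directories": list(self.resource_directories),
            "module_id": self.module_id,
            "package_id": self.package_id,
        }

    @classmethod
    def load(cls, data: dict[str, Any]) -> "ReadModuleSketch":
        return cls(
            dir=Path(data["dir"]),
            resource_directories=tuple(data["resource_directories"]),
            module_id=data.get("module_id"),  # None for dumps written before 0.6.89
            package_id=data.get("package_id"),
        )

old_dump = {"dir": "modules/my_module", "resource_directories": ["data_sets"]}  # pre-0.6.89 shape
module = ReadModuleSketch.load(old_dump)
assert module.module_id is None and module.package_id is None
assert ReadModuleSketch.load(module.dump()) == module
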
cognite_toolkit/_cdf_tk/data_classes/_tracking_info.py
@@ -0,0 +1,43 @@
+"""Data class for command tracking information."""
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+
+class CommandTrackingInfo(BaseModel):
+    """Structured tracking information for CLI commands.
+
+    This model provides type-safe tracking information that can be collected
+    during command execution and sent to Mixpanel for analytics.
+
+    Attributes:
+        project: The CDF project name.
+        cluster: The CDF cluster name.
+        module_ids: List of module IDs that were deployed or built.
+        package_ids: List of package IDs that were deployed or built.
+        installed_module_ids: List of module IDs that were installed.
+        installed_package_ids: List of package IDs that were installed.
+        downloaded_library_ids: List of library IDs that were downloaded.
+        downloaded_package_ids: List of package IDs that were downloaded.
+        downloaded_module_ids: List of module IDs that were downloaded.
+    """
+
+    project: str | None = Field(default=None)
+    cluster: str | None = Field(default=None)
+    module_ids: set[str] = Field(default_factory=set, alias="moduleIds")
+    package_ids: set[str] = Field(default_factory=set, alias="packageIds")
+    installed_module_ids: set[str] = Field(default_factory=set, alias="installedModuleIds")
+    installed_package_ids: set[str] = Field(default_factory=set, alias="installedPackageIds")
+    downloaded_library_ids: set[str] = Field(default_factory=set, alias="downloadedLibraryIds")
+    downloaded_package_ids: set[str] = Field(default_factory=set, alias="downloadedPackageIds")
+    downloaded_module_ids: set[str] = Field(default_factory=set, alias="downloadedModuleIds")
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert the tracking info to a dictionary for Mixpanel.
+
+        Returns:
+            A dictionary with camelCase keys matching Mixpanel's expected format.
+            Default values are excluded.
+        """
+        return self.model_dump(by_alias=True, exclude_defaults=True)
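
A rough usage sketch, assuming Pydantic v2 (which provides model_dump): commands mutate the set fields in place while they run, and to_dict() emits only the camelCase keys that were actually populated.

info = CommandTrackingInfo(project="my-project", cluster="westeurope-1")

# Commands accumulate IDs into the set fields; set.add() de-duplicates for free.
info.downloaded_library_ids.add("my-library")
info.downloaded_module_ids.add("cdf_ingestion")
info.downloaded_module_ids.add("cdf_ingestion")  # no effect, already present

print(info.to_dict())
# {'project': 'my-project', 'cluster': 'westeurope-1',
#  'downloadedLibraryIds': {'my-library'}, 'downloadedModuleIds': {'cdf_ingestion'}}
# Fields still at their defaults (e.g. installedModuleIds) are left out by exclude_defaults.
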
cognite_toolkit/_cdf_tk/storageio/__init__.py
@@ -3,6 +3,7 @@ from pathlib import Path
 from cognite_toolkit._cdf_tk.utils._auxiliary import get_concrete_subclasses
 from cognite_toolkit._cdf_tk.utils.fileio import COMPRESSION_BY_SUFFIX

+from ._annotations import FileAnnotationIO
 from ._applications import CanvasIO, ChartIO
 from ._asset_centric import AssetIO, BaseAssetCentricIO, EventIO, FileMetadataIO, HierarchyIO, TimeSeriesIO
 from ._base import (
@@ -50,6 +51,7 @@ __all__ = [
     "ChartIO",
     "ConfigurableStorageIO",
     "EventIO",
+    "FileAnnotationIO",
     "FileMetadataIO",
     "HierarchyIO",
     "InstanceIO",

cognite_toolkit/_cdf_tk/storageio/_annotations.py
@@ -0,0 +1,102 @@
+from collections.abc import Iterable, Sequence
+from typing import Any
+
+from cognite.client.data_classes import Annotation, AnnotationFilter
+
+from cognite_toolkit._cdf_tk.utils.collection import chunker_sequence
+from cognite_toolkit._cdf_tk.utils.useful_types import JsonVal
+
+from ._asset_centric import FileMetadataIO
+from ._base import Page, StorageIO
+from .selectors import AssetCentricSelector
+
+
+class FileAnnotationIO(StorageIO[AssetCentricSelector, Annotation]):
+    SUPPORTED_DOWNLOAD_FORMATS = frozenset({".ndjson"})
+    SUPPORTED_COMPRESSIONS = frozenset({".gz"})
+    CHUNK_SIZE = 1000
+    BASE_SELECTOR = AssetCentricSelector
+
+    MISSING_ID = "<MISSING_RESOURCE_ID>"
+
+    def as_id(self, item: Annotation) -> str:
+        project = item._cognite_client.config.project
+        return f"INTERNAL_ID_project_{project}_{item.id!s}"
+
+    def stream_data(self, selector: AssetCentricSelector, limit: int | None = None) -> Iterable[Page]:
+        total = 0
+        for file_chunk in FileMetadataIO(self.client).stream_data(selector, None):
+            # Todo Support pagination. This is missing in the SDK.
+            results = self.client.annotations.list(
+                filter=AnnotationFilter(
+                    annotated_resource_type="file",
+                    annotated_resource_ids=[{"id": file_metadata.id} for file_metadata in file_chunk.items],
+                )
+            )
+            if limit is not None and total + len(results) > limit:
+                results = results[: limit - total]
+
+            for chunk in chunker_sequence(results, self.CHUNK_SIZE):
+                yield Page(worker_id="main", items=chunk)
+                total += len(chunk)
+                if limit is not None and total >= limit:
+                    break
+
+    def count(self, selector: AssetCentricSelector) -> int | None:
+        """There is no efficient way to count annotations in CDF."""
+        return None
+
+    def data_to_json_chunk(
+        self, data_chunk: Sequence[Annotation], selector: AssetCentricSelector | None = None
+    ) -> list[dict[str, JsonVal]]:
+        files_ids: set[int] = set()
+        for item in data_chunk:
+            if item.annotated_resource_type == "file" and item.annotated_resource_id is not None:
+                files_ids.add(item.annotated_resource_id)
+            if file_id := self._get_file_id(item.data):
+                files_ids.add(file_id)
+        self.client.lookup.files.external_id(list(files_ids))  # Preload file external IDs
+        asset_ids = {asset_id for item in data_chunk if (asset_id := self._get_asset_id(item.data))}
+        self.client.lookup.assets.external_id(list(asset_ids))  # Preload asset external IDs
+        return [self.dump_annotation_to_json(item) for item in data_chunk]
+
+    def dump_annotation_to_json(self, annotation: Annotation) -> dict[str, JsonVal]:
+        """Dump annotations to a list of JSON serializable dictionaries.
+
+        Args:
+            annotation: The annotations to dump.
+
+        Returns:
+            A list of JSON serializable dictionaries representing the annotations.
+        """
+        dumped = annotation.as_write().dump()
+        if isinstance(annotated_resource_id := dumped.pop("annotatedResourceId", None), int):
+            external_id = self.client.lookup.files.external_id(annotated_resource_id)
+            dumped["annotatedResourceExternalId"] = self.MISSING_ID if external_id is None else external_id
+
+        if isinstance(data := dumped.get("data"), dict):
+            if isinstance(file_ref := data.get("fileRef"), dict) and isinstance(file_ref.get("id"), int):
+                external_id = self.client.lookup.files.external_id(file_ref.pop("id"))
+                file_ref["externalId"] = self.MISSING_ID if external_id is None else external_id
+            if isinstance(asset_ref := data.get("assetRef"), dict) and isinstance(asset_ref.get("id"), int):
+                external_id = self.client.lookup.assets.external_id(asset_ref.pop("id"))
+                asset_ref["externalId"] = self.MISSING_ID if external_id is None else external_id
+        return dumped
+
+    @classmethod
+    def _get_file_id(cls, data: dict[str, Any]) -> int | None:
+        file_ref = data.get("fileRef")
+        if isinstance(file_ref, dict):
+            id_ = file_ref.get("id")
+            if isinstance(id_, int):
+                return id_
+        return None
+
+    @classmethod
+    def _get_asset_id(cls, data: dict[str, Any]) -> int | None:
+        asset_ref = data.get("assetRef")
+        if isinstance(asset_ref, dict):
+            id_ = asset_ref.get("id")
+            if isinstance(id_, int):
+                return id_
+        return None
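
A rough usage sketch. It assumes a configured ToolkitClient and that StorageIO subclasses are constructed with that client (as the FileMetadataIO(self.client) call above suggests); client and selector construction are left as placeholders because their signatures are not part of this diff.

from cognite_toolkit._cdf_tk.storageio import FileAnnotationIO

client = ...    # a configured ToolkitClient for your CDF project (placeholder)
selector = ...  # an AssetCentricSelector scoping which files to read annotations for (placeholder)

io = FileAnnotationIO(client)
for page in io.stream_data(selector, limit=100):
    for row in io.data_to_json_chunk(page.items, selector):
        # Internal numeric IDs are replaced with external IDs (or "<MISSING_RESOURCE_ID>").
        print(row.get("annotatedResourceExternalId"))
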
cognite_toolkit/_cdf_tk/tracker.py
@@ -14,7 +14,7 @@ from mixpanel import Consumer, Mixpanel, MixpanelException

 from cognite_toolkit._cdf_tk.cdf_toml import CDFToml
 from cognite_toolkit._cdf_tk.constants import IN_BROWSER
-from cognite_toolkit._cdf_tk.data_classes._built_modules import BuiltModule
+from cognite_toolkit._cdf_tk.data_classes import CommandTrackingInfo
 from cognite_toolkit._cdf_tk.tk_warnings import ToolkitWarning, WarningList
 from cognite_toolkit._cdf_tk.utils import get_cicd_environment
 from cognite_toolkit._version import __version__
@@ -49,7 +49,7 @@ class Tracker:
         warning_list: WarningList[ToolkitWarning],
         result: str | Exception,
         cmd: str,
-        additional_tracking_info: dict[str, Any] | None = None,
+        additional_tracking_info: CommandTrackingInfo | None = None,
     ) -> bool:
         warning_count = Counter([type(w).__name__ for w in warning_list])

@@ -58,7 +58,7 @@ class Tracker:
             warning_details[f"warningMostCommon{no}Count"] = count
             warning_details[f"warningMostCommon{no}Name"] = warning

-        positional_args, optional_args = self._parse_sys_args()
+        subcommands, optional_args = self._parse_sys_args()
         event_information = {
             "userInput": self.user_command,
             "toolkitVersion": __version__,
@@ -69,27 +69,17 @@ class Tracker:
             **warning_details,
             "result": type(result).__name__ if isinstance(result, Exception) else result,
             "error": str(result) if isinstance(result, Exception) else "",
-            **positional_args,
+            "subcommands": subcommands,
             **optional_args,
             "alphaFlags": [name for name, value in self._cdf_toml.alpha_flags.items() if value],
             "plugins": [name for name, value in self._cdf_toml.plugins.items() if value],
         }

         if additional_tracking_info:
-            event_information.update(additional_tracking_info)
+            event_information.update(additional_tracking_info.to_dict())

         return self._track(f"command{cmd.capitalize()}", event_information)

-    def track_module_build(self, module: BuiltModule) -> bool:
-        event_information = {
-            "module": module.name,
-            "location_path": module.location.path.as_posix(),
-            "warning_count": module.warning_count,
-            "status": module.status,
-            **{resource_type: len(resource_build) for resource_type, resource_build in module.resources.items()},
-        }
-        return self._track("moduleBuild", event_information)
-
     def _track(self, event_name: str, event_information: dict[str, Any]) -> bool:
         if self.skip_tracking or not self.opted_in or "PYTEST_CURRENT_TEST" in os.environ:
             return False
@@ -138,9 +128,9 @@ class Tracker:
         return distinct_id

     @staticmethod
-    def _parse_sys_args() -> tuple[dict[str, str], dict[str, str | bool]]:
+    def _parse_sys_args() -> tuple[list[str], dict[str, str | bool]]:
         optional_args: dict[str, str | bool] = {}
-        positional_args: dict[str, str] = {}
+        subcommands: list[str] = []
         last_key: str | None = None
         if sys.argv and len(sys.argv) > 1:
             for arg in sys.argv[1:]:
@@ -157,11 +147,11 @@ class Tracker:
                     optional_args[last_key] = arg
                     last_key = None
                 else:
-                    positional_args[f"positionalArg{len(positional_args)}"] = arg
+                    subcommands.append(arg)

         if last_key:
             optional_args[last_key] = True
-        return positional_args, optional_args
+        return subcommands, optional_args

     @property
     def _cicd(self) -> str:
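
The practical effect for analytics is that positional CLI tokens now arrive as a single list-valued "subcommands" property instead of numbered positionalArgN keys. A rough before/after sketch for a hypothetical invocation like "cdf modules add my_org" (the real parser also pairs --flag values into optional_args, which is elided here):

argv = ["modules", "add", "my_org"]  # sys.argv[1:]

# Up to 0.6.87: one event property per positional token.
positional_args = {f"positionalArg{i}": arg for i, arg in enumerate(argv)}
print(positional_args)  # {'positionalArg0': 'modules', 'positionalArg1': 'add', 'positionalArg2': 'my_org'}

# From 0.6.89: a single list property.
subcommands = list(argv)
print(subcommands)  # ['modules', 'add', 'my_org']
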
cognite_toolkit/_cdf_tk/utils/fileio/_readers.py
@@ -7,6 +7,7 @@ from dataclasses import dataclass
 from functools import partial
 from io import TextIOWrapper
 from pathlib import Path
+from typing import Any

 import yaml

@@ -87,26 +88,20 @@ class FailedParsing:
     error: str


-class TableReader(FileReader, ABC): ...
-
-
-class CSVReader(TableReader):
-    """Reads CSV files and yields each row as a dictionary.
+class TableReader(FileReader, ABC):
+    """Reads table-like files and yields each row as a dictionary.

     Args:
-        input_file (Path): The path to the CSV file to read.
+        input_file (Path): The path to the table file to read.
         sniff_rows (int | None): Optional number of rows to sniff for
             schema detection. If None, no schema is detected. If a schema is sniffed
-            from the first `sniff_rows` rows, it will be used to parse the CSV.
+            from the first `sniff_rows` rows, it will be used to parse the table.
         schema (Sequence[SchemaColumn] | None): Optional schema to use for parsing.
             You can either provide a schema or use `sniff_rows` to detect it.
         keep_failed_cells (bool): If True, failed cells will be kept in the
             `failed_cell` attribute. If False, they will be ignored.
-
     """

-    format = ".csv"
-
     def __init__(
         self,
         input_file: Path,
@@ -152,18 +147,19 @@ class CSVReader(TableReader):
     @classmethod
     def sniff_schema(cls, input_file: Path, sniff_rows: int = 100) -> list[SchemaColumn]:
         """
-        Sniff the schema from the first `sniff_rows` rows of the CSV file.
+        Sniff the schema from the first `sniff_rows` rows of the file.

         Args:
-            input_file (Path): The path to the CSV file.
+            input_file (Path): The path to the tabular file.
             sniff_rows (int): The number of rows to read for sniffing the schema.

         Returns:
             list[SchemaColumn]: The inferred schema as a list of SchemaColumn objects.
+
         Raises:
             ValueError: If `sniff_rows` is not a positive integer.
             ToolkitFileNotFoundError: If the file does not exist.
-            ToolkitValueError: If the file is not a CSV file or if there are issues with the content.
+            ToolkitValueError: If the file is not the correct format or if there are issues with the content.

         """
         if sniff_rows <= 0:
@@ -171,43 +167,50 @@

         if not input_file.exists():
             raise ToolkitFileNotFoundError(f"File not found: {input_file.as_posix()!r}.")
-        if input_file.suffix != ".csv":
-            raise ToolkitValueError(f"Expected a .csv file got a {input_file.suffix!r} file instead.")
+        if input_file.suffix != cls.format:
+            raise ToolkitValueError(f"Expected a {cls.format} file got a {input_file.suffix!r} file instead.")

-        with input_file.open("r", encoding="utf-8-sig") as file:
-            reader = csv.DictReader(file)
-            column_names = Counter(reader.fieldnames)
-            if duplicated := [name for name, count in column_names.items() if count > 1]:
-                raise ToolkitValueError(f"CSV file contains duplicate headers: {humanize_collection(duplicated)}")
-            sample_rows: list[dict[str, str]] = []
-            for no, row in enumerate(reader):
-                if no >= sniff_rows:
-                    break
-                sample_rows.append(row)
+        column_names, sample_rows = cls._read_sample_rows(input_file, sniff_rows)
+        cls._check_column_names(column_names)
+        return cls._infer_schema(sample_rows, column_names)

-            if not sample_rows:
-                raise ToolkitValueError(f"No data found in the file: {input_file.as_posix()!r}.")
+    @classmethod
+    @abstractmethod
+    def _read_sample_rows(cls, input_file: Path, sniff_rows: int) -> tuple[Sequence[str], list[dict[str, str]]]: ...

-            schema = []
-            for column_name in reader.fieldnames or []:
-                sample_values = [row[column_name] for row in sample_rows if column_name in row]
-                if not sample_values:
-                    column = SchemaColumn(name=column_name, type="string")
+    @classmethod
+    def _infer_schema(cls, sample_rows: list[dict[str, Any]], column_names: Sequence[str]) -> list[SchemaColumn]:
+        schema: list[SchemaColumn] = []
+        for column_name in column_names:
+            sample_values = [row[column_name] for row in sample_rows if column_name in row]
+            if not sample_values:
+                column = SchemaColumn(name=column_name, type="string")
+            else:
+                data_types = Counter(
+                    infer_data_type_from_value(value, dtype="Json")[0] for value in sample_values if value is not None
+                )
+                if not data_types:
+                    inferred_type = "string"
                 else:
-                    data_types = Counter(
-                        infer_data_type_from_value(value, dtype="Json")[0]
-                        for value in sample_values
-                        if value is not None
-                    )
-                    if not data_types:
-                        inferred_type = "string"
-                    else:
-                        inferred_type = data_types.most_common()[0][0]
-                    # Json dtype is a subset of Datatype that SchemaColumn accepts
-                    column = SchemaColumn(name=column_name, type=inferred_type)  # type: ignore[arg-type]
-                schema.append(column)
+                    inferred_type = data_types.most_common()[0][0]
+                # Json dtype is a subset of Datatype that SchemaColumn accepts
+                column = SchemaColumn(name=column_name, type=inferred_type)  # type: ignore[arg-type]
+            schema.append(column)
         return schema

+    @classmethod
+    def _check_column_names(cls, column_names: Sequence[str]) -> None:
+        """Check for duplicate column names."""
+        duplicates = [name for name, count in Counter(column_names).items() if count > 1]
+        if duplicates:
+            raise ToolkitValueError(f"Duplicate column names found: {humanize_collection(duplicates)}.")
+
+
+class CSVReader(TableReader):
+    """Reads CSV files and yields each row as a dictionary."""
+
+    format = ".csv"
+
     def _read_chunks_from_file(self, file: TextIOWrapper) -> Iterator[dict[str, JsonVal]]:
         if self.keep_failed_cells and self.failed_cell:
             self.failed_cell.clear()
@@ -231,10 +234,31 @@ class CSVReader(TableReader):
         with compression.open("r") as file:
             yield from csv.DictReader(file)

+    @classmethod
+    def _read_sample_rows(cls, input_file: Path, sniff_rows: int) -> tuple[Sequence[str], list[dict[str, str]]]:
+        column_names: Sequence[str] = []
+        compression = Compression.from_filepath(input_file)
+        with compression.open("r") as file:
+            reader = csv.DictReader(file)
+            column_names = reader.fieldnames or []
+            sample_rows: list[dict[str, str]] = []
+            for no, row in enumerate(reader):
+                if no >= sniff_rows:
+                    break
+                sample_rows.append(row)
+
+        if not sample_rows:
+            raise ToolkitValueError(f"No data found in the file: {input_file.as_posix()!r}.")
+        return column_names, sample_rows
+

 class ParquetReader(TableReader):
     format = ".parquet"

+    def __init__(self, input_file: Path) -> None:
+        # Parquet files have their own schema, so we don't need to sniff or provide one.
+        super().__init__(input_file, sniff_rows=None, schema=None, keep_failed_cells=False)
+
     def read_chunks(self) -> Iterator[dict[str, JsonVal]]:
         import pyarrow.parquet as pq

@@ -258,6 +282,28 @@ class ParquetReader(TableReader):
                 return value
         return value

+    @classmethod
+    def _read_sample_rows(cls, input_file: Path, sniff_rows: int) -> tuple[Sequence[str], list[dict[str, str]]]:
+        import pyarrow.parquet as pq
+
+        column_names: Sequence[str] = []
+        sample_rows: list[dict[str, str]] = []
+        with pq.ParquetFile(input_file) as parquet_file:
+            column_names = parquet_file.schema.names
+            row_count = min(sniff_rows, parquet_file.metadata.num_rows)
+            row_iter = parquet_file.iter_batches(batch_size=row_count)
+            try:
+                batch = next(row_iter)
+                for row in batch.to_pylist():
+                    str_row = {key: (str(value) if value is not None else "") for key, value in row.items()}
+                    sample_rows.append(str_row)
+            except StopIteration:
+                pass
+
+        if not sample_rows:
+            raise ToolkitValueError(f"No data found in the file: {input_file.as_posix()!r}.")
+        return column_names, sample_rows
+

 FILE_READ_CLS_BY_FORMAT: Mapping[str, type[FileReader]] = {}
 TABLE_READ_CLS_BY_FORMAT: Mapping[str, type[TableReader]] = {}
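
With sniff_schema hoisted to TableReader and only _read_sample_rows left format-specific, the same schema-sniffing call now works for CSV and Parquet alike. A small CSV sketch (the Parquet path is analogous); the file content is made up, and the import assumes CSVReader is re-exported from the fileio package:

from pathlib import Path

from cognite_toolkit._cdf_tk.utils.fileio import CSVReader  # assumption: package-level re-export

sample = Path("assets.csv")  # hypothetical sample file
sample.write_text("externalId,name,siteCode\npump-01,Main pump,ABC\npump-02,Backup pump,DEF\n")

# The schema is inferred from up to the first 100 rows by default.
for column in CSVReader.sniff_schema(sample, sniff_rows=100):
    print(column.name, column.type)  # e.g. externalId string, name string, siteCode string
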
cognite_toolkit/_cdf_tk/utils/http_client/_client.py
@@ -147,13 +147,15 @@ class HTTPClient:
             timeout=self.config.timeout,
         )

-    def _create_headers(self, api_version: str | None = None) -> MutableMapping[str, str]:
+    def _create_headers(
+        self, api_version: str | None = None, content_type: str = "application/json", accept: str = "application/json"
+    ) -> MutableMapping[str, str]:
         headers: MutableMapping[str, str] = {}
         headers["User-Agent"] = f"httpx/{httpx.__version__} {get_user_agent()}"
         auth_name, auth_value = self.config.credentials.authorization_header()
         headers[auth_name] = auth_value
-        headers["content-type"] = "application/json"
-        headers["accept"] = "application/json"
+        headers["content-type"] = content_type
+        headers["accept"] = accept
         headers["x-cdp-sdk"] = f"CogniteToolkit:{get_current_toolkit_version()}"
         headers["x-cdp-app"] = self.config.client_name
         headers["cdf-version"] = api_version or self.config.api_subversion
@@ -162,7 +164,7 @@
         return headers

     def _make_request(self, item: RequestMessage) -> httpx.Response:
-        headers = self._create_headers(item.api_version)
+        headers = self._create_headers(item.api_version, item.content_type, item.accept)
         params: dict[str, PrimitiveType] | None = None
         if isinstance(item, ParamRequest):
             params = item.parameters

cognite_toolkit/_cdf_tk/utils/http_client/_data_classes.py
@@ -92,6 +92,8 @@ class RequestMessage(HTTPMessage):
     read_attempt: int = 0
     status_attempt: int = 0
     api_version: str | None = None
+    content_type: str = "application/json"
+    accept: str = "application/json"

     @property
     def total_attempts(self) -> int:
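
Together with the _client.py change above, a request can now carry its own media types while JSON stays the default. A minimal stand-in sketch of how the two new fields flow into headers (RequestMessage itself has further required fields not shown in this diff):

from dataclasses import dataclass

@dataclass
class RequestSketch:  # stand-in carrying only the two new fields
    content_type: str = "application/json"
    accept: str = "application/json"

def build_headers(item: RequestSketch) -> dict[str, str]:
    # Mirrors the relevant part of HTTPClient._create_headers / _make_request.
    return {"content-type": item.content_type, "accept": item.accept}

print(build_headers(RequestSketch()))                         # JSON defaults
print(build_headers(RequestSketch(content_type="text/csv")))  # per-request override
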