cognite-toolkit 0.6.88__py3-none-any.whl → 0.6.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This release of cognite-toolkit has been flagged as potentially problematic.

Files changed (29)
  1. cognite_toolkit/_cdf_tk/commands/_migrate/canvas.py +60 -5
  2. cognite_toolkit/_cdf_tk/commands/_migrate/command.py +4 -2
  3. cognite_toolkit/_cdf_tk/commands/_migrate/conversion.py +161 -44
  4. cognite_toolkit/_cdf_tk/commands/_migrate/data_classes.py +10 -10
  5. cognite_toolkit/_cdf_tk/commands/_migrate/data_mapper.py +7 -3
  6. cognite_toolkit/_cdf_tk/commands/_migrate/migration_io.py +8 -10
  7. cognite_toolkit/_cdf_tk/commands/build_cmd.py +1 -1
  8. cognite_toolkit/_cdf_tk/commands/pull.py +6 -5
  9. cognite_toolkit/_cdf_tk/data_classes/_build_variables.py +120 -14
  10. cognite_toolkit/_cdf_tk/data_classes/_built_resources.py +1 -1
  11. cognite_toolkit/_cdf_tk/resource_classes/agent.py +1 -0
  12. cognite_toolkit/_cdf_tk/resource_classes/infield_cdmv1.py +92 -0
  13. cognite_toolkit/_cdf_tk/storageio/__init__.py +2 -0
  14. cognite_toolkit/_cdf_tk/storageio/_annotations.py +102 -0
  15. cognite_toolkit/_cdf_tk/tracker.py +6 -6
  16. cognite_toolkit/_cdf_tk/utils/fileio/_readers.py +90 -44
  17. cognite_toolkit/_cdf_tk/utils/http_client/_client.py +6 -4
  18. cognite_toolkit/_cdf_tk/utils/http_client/_data_classes.py +2 -0
  19. cognite_toolkit/_cdf_tk/utils/useful_types.py +7 -4
  20. cognite_toolkit/_repo_files/GitHub/.github/workflows/deploy.yaml +1 -1
  21. cognite_toolkit/_repo_files/GitHub/.github/workflows/dry-run.yaml +1 -1
  22. cognite_toolkit/_resources/cdf.toml +1 -1
  23. cognite_toolkit/_version.py +1 -1
  24. {cognite_toolkit-0.6.88.dist-info → cognite_toolkit-0.6.90.dist-info}/METADATA +1 -1
  25. {cognite_toolkit-0.6.88.dist-info → cognite_toolkit-0.6.90.dist-info}/RECORD +28 -27
  26. cognite_toolkit/_cdf_tk/commands/_migrate/base.py +0 -106
  27. {cognite_toolkit-0.6.88.dist-info → cognite_toolkit-0.6.90.dist-info}/WHEEL +0 -0
  28. {cognite_toolkit-0.6.88.dist-info → cognite_toolkit-0.6.90.dist-info}/entry_points.txt +0 -0
  29. {cognite_toolkit-0.6.88.dist-info → cognite_toolkit-0.6.90.dist-info}/licenses/LICENSE +0 -0
@@ -8,11 +8,11 @@ from functools import cached_property
  from pathlib import Path
  from typing import Any, Literal, SupportsIndex, overload

+ from cognite_toolkit._cdf_tk.cruds._resource_cruds.transformation import TransformationCRUD
+ from cognite_toolkit._cdf_tk.data_classes._module_directories import ModuleLocation
  from cognite_toolkit._cdf_tk.exceptions import ToolkitValueError
  from cognite_toolkit._cdf_tk.feature_flags import Flags

- from ._module_directories import ModuleLocation
-
  if sys.version_info >= (3, 11):
  from typing import Self
  else:
@@ -161,16 +161,19 @@ class BuildVariables(tuple, Sequence[BuildVariable]):
  ]

  @overload
- def replace(self, content: str, file_suffix: str = ".yaml", use_placeholder: Literal[False] = False) -> str: ...
+ def replace(self, content: str, file_path: Path | None = None, use_placeholder: Literal[False] = False) -> str: ...

  @overload
  def replace(
- self, content: str, file_suffix: str = ".yaml", use_placeholder: Literal[True] = True
+ self, content: str, file_path: Path | None = None, use_placeholder: Literal[True] = True
  ) -> tuple[str, dict[str, BuildVariable]]: ...

  def replace(
- self, content: str, file_suffix: str = ".yaml", use_placeholder: bool = False
+ self, content: str, file_path: Path | None = None, use_placeholder: bool = False
  ) -> str | tuple[str, dict[str, BuildVariable]]:
+ # Extract file suffix from path, default to .yaml if not provided
+ file_suffix = file_path.suffix if file_path and file_path.suffix else ".yaml"
+
  variable_by_placeholder: dict[str, BuildVariable] = {}
  for variable in self:
  if not use_placeholder:
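As a quick illustration of the suffix defaulting added in this hunk, the small standalone snippet below applies the same rule with pathlib only; the sample paths are hypothetical:

from pathlib import Path

def suffix_or_yaml(file_path: Path | None) -> str:
    # Same rule as the added line above: use the path's suffix, fall back to ".yaml".
    return file_path.suffix if file_path and file_path.suffix else ".yaml"

print(suffix_or_yaml(Path("modules/transformations/my.Transformation.yaml")))  # .yaml
print(suffix_or_yaml(Path("modules/transformations/query.sql")))               # .sql
print(suffix_or_yaml(None))                                                    # .yaml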
@@ -180,22 +183,125 @@ class BuildVariables(tuple, Sequence[BuildVariable]):
  variable_by_placeholder[replace] = variable

  _core_pattern = rf"{{{{\s*{variable.key}\s*}}}}"
- if file_suffix in {".yaml", ".yml", ".json"}:
- # Preserve data types
- pattern = _core_pattern
- if isinstance(replace, str) and (replace.isdigit() or replace.endswith(":")):
- replace = f'"{replace}"'
- pattern = rf"'{_core_pattern}'|{_core_pattern}|" + rf'"{_core_pattern}"'
- elif replace is None:
- replace = "null"
- content = re.sub(pattern, str(replace), content)
+ if file_suffix == ".sql":
+ # For SQL files, convert lists to SQL-style tuples
+ if isinstance(replace, list):
+ replace = self._format_list_as_sql_tuple(replace)
+ content = re.sub(_core_pattern, str(replace), content)
+ elif file_suffix in {".yaml", ".yml", ".json"}:
+ # Check if this is a transformation file (ends with Transformation.yaml/yml)
+ is_transformation_file = file_path is not None and f".{TransformationCRUD.kind}." in file_path.name
+ # Check if variable is within a query field (SQL context)
+ is_in_query_field = self._is_in_query_field(content, variable.key)
+
+ # For lists in query fields, use SQL-style tuples
+ # For transformation files, ensure SQL conversion is applied to query property variables
+ if is_transformation_file and is_in_query_field and isinstance(replace, list):
+ replace = self._format_list_as_sql_tuple(replace)
+ # Use simple pattern for SQL context (no YAML quoting needed)
+ content = re.sub(_core_pattern, str(replace), content)
+ else:
+ # Preserve data types for YAML
+ pattern = _core_pattern
+ if isinstance(replace, str) and (replace.isdigit() or replace.endswith(":")):
+ replace = f'"{replace}"'
+ pattern = rf"'{_core_pattern}'|{_core_pattern}|" + rf'"{_core_pattern}"'
+ elif replace is None:
+ replace = "null"
+ content = re.sub(pattern, str(replace), content)
  else:
+ # For other file types, use simple string replacement
  content = re.sub(_core_pattern, str(replace), content)
  if use_placeholder:
  return content, variable_by_placeholder
  else:
  return content

+ @staticmethod
+ def _is_transformation_file(file_path: Path) -> bool:
+ """Check if the file path indicates a transformation YAML file.
+
+ Transformation files are YAML files in the "transformations" folder.
+
+ Args:
+ file_path: The file path to check
+
+ Returns:
+ True if the file is a transformation YAML file
+ """
+ # Check if path contains "transformations" folder and ends with .yaml/.yml
+ path_str = file_path.as_posix().lower()
+ return "transformations" in path_str and file_path.suffix.lower() in {".yaml", ".yml"}
+
+ @staticmethod
+ def _format_list_as_sql_tuple(replace: list[Any]) -> str:
+ """Format a list as a SQL-style tuple string.
+
+ Args:
+ replace: The list to format
+
+ Returns:
+ SQL tuple string, e.g., "('A', 'B', 'C')" or "()" for empty lists
+ """
+ if not replace:
+ # Empty list becomes empty SQL tuple
+ return "()"
+ else:
+ # Format list as SQL tuple: ('A', 'B', 'C')
+ formatted_items = []
+ for item in replace:
+ if item is None:
+ formatted_items.append("NULL")
+ elif isinstance(item, str):
+ formatted_items.append(f"'{item}'")
+ else:
+ formatted_items.append(str(item))
+ return f"({', '.join(formatted_items)})"
+
+ @staticmethod
+ def _is_in_query_field(content: str, variable_key: str) -> bool:
+ """Check if a variable is within a query field in YAML.
+
+ Assumes query is a top-level property. This detects various YAML formats:
+ - query: >-
+ - query: |
+ - query: "..."
+ - query: ...
+ """
+ lines = content.split("\n")
+ variable_pattern = rf"{{{{\s*{re.escape(variable_key)}\s*}}}}"
+ in_query_field = False
+
+ for line in lines:
+ # Check if this line starts a top-level query field
+ query_match = re.match(r"^query\s*:\s*(.*)$", line)
+ if query_match:
+ in_query_field = True
+ query_content_start = query_match.group(1).strip()
+
+ # Check if variable is on the same line as query: declaration
+ if re.search(variable_pattern, line):
+ return True
+
+ # If query content starts on same line (not a block scalar), check it
+ if query_content_start and not query_content_start.startswith(("|", ">", "|-", ">-", "|+", ">+")):
+ if re.search(variable_pattern, query_content_start):
+ return True
+ continue
+
+ # Check if we're still in the query field
+ if in_query_field:
+ # If we hit another top-level property, we've exited the query field
+ if re.match(r"^\w+\s*:", line):
+ in_query_field = False
+ continue
+
+ # We're still in the query field, check for variable
+ if re.search(variable_pattern, line):
+ return True
+
+ return False
+
  # Implemented to get correct type hints
  def __iter__(self) -> Iterator[BuildVariable]:
  return super().__iter__()
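To make the new .sql handling concrete, here is a self-contained sketch that mirrors the list-to-SQL-tuple formatting and placeholder substitution introduced above; the function name format_list_as_sql_tuple and the {{ sites }} variable are illustrative stand-ins, not Toolkit API:

import re

def format_list_as_sql_tuple(values: list) -> str:
    # Mirrors BuildVariables._format_list_as_sql_tuple from the hunk above:
    # empty list -> "()", strings quoted, None -> NULL, other values via str().
    if not values:
        return "()"
    items = []
    for item in values:
        if item is None:
            items.append("NULL")
        elif isinstance(item, str):
            items.append(f"'{item}'")
        else:
            items.append(str(item))
    return f"({', '.join(items)})"

# Hypothetical .sql template: a {{ sites }} list variable becomes a SQL tuple.
sql = "SELECT * FROM assets WHERE site IN {{ sites }}"
pattern = r"\{\{\s*sites\s*\}\}"
print(re.sub(pattern, format_list_as_sql_tuple(["OSLO", "BERGEN"]), sql))
# SELECT * FROM assets WHERE site IN ('OSLO', 'BERGEN')
print(format_list_as_sql_tuple([]))         # ()
print(format_list_as_sql_tuple([1, None]))  # (1, NULL)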
@@ -158,7 +158,7 @@ class BuiltResourceFull(BuiltResource[T_ID]):
  def load_resource_dict(
  self, environment_variables: dict[str, str | None], validate: bool = False
  ) -> dict[str, Any]:
- content = self.build_variables.replace(safe_read(self.source.path))
+ content = self.build_variables.replace(safe_read(self.source.path), self.source.path)
  loader = cast(ResourceCRUD, get_crud(self.resource_dir, self.kind))
  raw = load_yaml_inject_variables(
  content,
@@ -55,3 +55,4 @@ class AgentYAML(ToolkitResource):
  "azure/gpt-4o-mini", description="The name of the model to use. Defaults to your CDF project's default model."
  )
  tools: list[AgentTool] | None = Field(None, description="A list of tools available to the agent.", max_length=20)
+ runtime_version: str | None = Field(None, description="The runtime version")
@@ -0,0 +1,92 @@
+ from typing import Any
+
+ from .base import BaseModelResource, ToolkitResource
+
+
+ class ObservationFeatureToggles(BaseModelResource):
+ """Feature toggles for observations."""
+
+ is_enabled: bool | None = None
+ is_write_back_enabled: bool | None = None
+ notifications_endpoint_external_id: str | None = None
+ attachments_endpoint_external_id: str | None = None
+
+
+ class FeatureToggles(BaseModelResource):
+ """Feature toggles for InField location configuration."""
+
+ three_d: bool | None = None
+ trends: bool | None = None
+ documents: bool | None = None
+ workorders: bool | None = None
+ notifications: bool | None = None
+ media: bool | None = None
+ template_checklist_flow: bool | None = None
+ workorder_checklist_flow: bool | None = None
+ observations: ObservationFeatureToggles | None = None
+
+
+ class AccessManagement(BaseModelResource):
+ """Access management configuration."""
+
+ template_admins: list[str] | None = None # list of CDF group external IDs
+ checklist_admins: list[str] | None = None # list of CDF group external IDs
+
+
+ class ResourceFilters(BaseModelResource):
+ """Resource filters."""
+
+ spaces: list[str] | None = None
+
+
+ class RootLocationDataFilters(BaseModelResource):
+ """Data filters for root location."""
+
+ general: ResourceFilters | None = None
+ assets: ResourceFilters | None = None
+ files: ResourceFilters | None = None
+ timeseries: ResourceFilters | None = None
+
+
+ class DataExplorationConfig(BaseModelResource):
+ """Properties for DataExplorationConfig node.
+
+ Contains configuration for data exploration features:
+ - observations: Observations feature configuration
+ - activities: Activities configuration
+ - documents: Document configuration
+ - notifications: Notifications configuration
+ - assets: Asset page configuration
+ """
+
+ external_id: str
+
+ observations: dict[str, Any] | None = None # ObservationsConfigFeature
+ activities: dict[str, Any] | None = None # ActivitiesConfiguration
+ documents: dict[str, Any] | None = None # DocumentConfiguration
+ notifications: dict[str, Any] | None = None # NotificationsConfiguration
+ assets: dict[str, Any] | None = None # AssetPageConfiguration
+
+
+ class InfieldLocationConfigYAML(ToolkitResource):
+ """Properties for InFieldLocationConfig node.
+
+ Currently migrated fields:
+ - root_location_external_id: Reference to the LocationFilterDTO external ID
+ - feature_toggles: Feature toggles migrated from old configuration
+ - rootAsset: Direct relation to the root asset (space and externalId)
+ - app_instance_space: Application instance space from appDataInstanceSpace
+ - access_management: Template and checklist admin groups (from templateAdmins and checklistAdmins)
+ - disciplines: List of disciplines (from disciplines in FeatureConfiguration)
+ - data_filters: Data filters for general, assets, files, and timeseries (from dataFilters in old configuration)
+ - data_exploration_config: Direct relation to the DataExplorationConfig node (shared across all locations)
+ """
+
+ external_id: str
+
+ root_location_external_id: str | None = None
+ feature_toggles: FeatureToggles | None = None
+ app_instance_space: str | None = None
+ access_management: AccessManagement | None = None
+ data_filters: RootLocationDataFilters | None = None
+ data_exploration_config: DataExplorationConfig | None = None
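For orientation, a hypothetical payload lining up with the new InfieldLocationConfigYAML fields could look like the Python dict below. The field names come from the class above; every value is invented for illustration and not taken from any real project:

# Sketch only: values are made up, field names mirror InfieldLocationConfigYAML.
infield_location_config = {
    "external_id": "ilc_oslo_plant",
    "root_location_external_id": "loc_oslo_plant",
    "app_instance_space": "infield_app_data",
    "feature_toggles": {
        "three_d": True,
        "observations": {"is_enabled": True, "is_write_back_enabled": False},
    },
    "access_management": {"template_admins": ["gp_infield_template_admins"]},
    "data_filters": {"assets": {"spaces": ["sp_asset_data"]}},
}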
@@ -3,6 +3,7 @@ from pathlib import Path
  from cognite_toolkit._cdf_tk.utils._auxiliary import get_concrete_subclasses
  from cognite_toolkit._cdf_tk.utils.fileio import COMPRESSION_BY_SUFFIX

+ from ._annotations import FileAnnotationIO
  from ._applications import CanvasIO, ChartIO
  from ._asset_centric import AssetIO, BaseAssetCentricIO, EventIO, FileMetadataIO, HierarchyIO, TimeSeriesIO
  from ._base import (
@@ -50,6 +51,7 @@ __all__ = [
  "ChartIO",
  "ConfigurableStorageIO",
  "EventIO",
+ "FileAnnotationIO",
  "FileMetadataIO",
  "HierarchyIO",
  "InstanceIO",
@@ -0,0 +1,102 @@
+ from collections.abc import Iterable, Sequence
+ from typing import Any
+
+ from cognite.client.data_classes import Annotation, AnnotationFilter
+
+ from cognite_toolkit._cdf_tk.utils.collection import chunker_sequence
+ from cognite_toolkit._cdf_tk.utils.useful_types import JsonVal
+
+ from ._asset_centric import FileMetadataIO
+ from ._base import Page, StorageIO
+ from .selectors import AssetCentricSelector
+
+
+ class FileAnnotationIO(StorageIO[AssetCentricSelector, Annotation]):
+ SUPPORTED_DOWNLOAD_FORMATS = frozenset({".ndjson"})
+ SUPPORTED_COMPRESSIONS = frozenset({".gz"})
+ CHUNK_SIZE = 1000
+ BASE_SELECTOR = AssetCentricSelector
+
+ MISSING_ID = "<MISSING_RESOURCE_ID>"
+
+ def as_id(self, item: Annotation) -> str:
+ project = item._cognite_client.config.project
+ return f"INTERNAL_ID_project_{project}_{item.id!s}"
+
+ def stream_data(self, selector: AssetCentricSelector, limit: int | None = None) -> Iterable[Page]:
+ total = 0
+ for file_chunk in FileMetadataIO(self.client).stream_data(selector, None):
+ # Todo Support pagination. This is missing in the SDK.
+ results = self.client.annotations.list(
+ filter=AnnotationFilter(
+ annotated_resource_type="file",
+ annotated_resource_ids=[{"id": file_metadata.id} for file_metadata in file_chunk.items],
+ )
+ )
+ if limit is not None and total + len(results) > limit:
+ results = results[: limit - total]
+
+ for chunk in chunker_sequence(results, self.CHUNK_SIZE):
+ yield Page(worker_id="main", items=chunk)
+ total += len(chunk)
+ if limit is not None and total >= limit:
+ break
+
+ def count(self, selector: AssetCentricSelector) -> int | None:
+ """There is no efficient way to count annotations in CDF."""
+ return None
+
+ def data_to_json_chunk(
+ self, data_chunk: Sequence[Annotation], selector: AssetCentricSelector | None = None
+ ) -> list[dict[str, JsonVal]]:
+ files_ids: set[int] = set()
+ for item in data_chunk:
+ if item.annotated_resource_type == "file" and item.annotated_resource_id is not None:
+ files_ids.add(item.annotated_resource_id)
+ if file_id := self._get_file_id(item.data):
+ files_ids.add(file_id)
+ self.client.lookup.files.external_id(list(files_ids)) # Preload file external IDs
+ asset_ids = {asset_id for item in data_chunk if (asset_id := self._get_asset_id(item.data))}
+ self.client.lookup.assets.external_id(list(asset_ids)) # Preload asset external IDs
+ return [self.dump_annotation_to_json(item) for item in data_chunk]
+
+ def dump_annotation_to_json(self, annotation: Annotation) -> dict[str, JsonVal]:
+ """Dump annotations to a list of JSON serializable dictionaries.
+
+ Args:
+ annotation: The annotations to dump.
+
+ Returns:
+ A list of JSON serializable dictionaries representing the annotations.
+ """
+ dumped = annotation.as_write().dump()
+ if isinstance(annotated_resource_id := dumped.pop("annotatedResourceId", None), int):
+ external_id = self.client.lookup.files.external_id(annotated_resource_id)
+ dumped["annotatedResourceExternalId"] = self.MISSING_ID if external_id is None else external_id
+
+ if isinstance(data := dumped.get("data"), dict):
+ if isinstance(file_ref := data.get("fileRef"), dict) and isinstance(file_ref.get("id"), int):
+ external_id = self.client.lookup.files.external_id(file_ref.pop("id"))
+ file_ref["externalId"] = self.MISSING_ID if external_id is None else external_id
+ if isinstance(asset_ref := data.get("assetRef"), dict) and isinstance(asset_ref.get("id"), int):
+ external_id = self.client.lookup.assets.external_id(asset_ref.pop("id"))
+ asset_ref["externalId"] = self.MISSING_ID if external_id is None else external_id
+ return dumped
+
+ @classmethod
+ def _get_file_id(cls, data: dict[str, Any]) -> int | None:
+ file_ref = data.get("fileRef")
+ if isinstance(file_ref, dict):
+ id_ = file_ref.get("id")
+ if isinstance(id_, int):
+ return id_
+ return None
+
+ @classmethod
+ def _get_asset_id(cls, data: dict[str, Any]) -> int | None:
+ asset_ref = data.get("assetRef")
+ if isinstance(asset_ref, dict):
+ id_ = asset_ref.get("id")
+ if isinstance(id_, int):
+ return id_
+ return None
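As a rough sketch of the id-to-external-id remapping that dump_annotation_to_json performs above, the helper below works on a plain dict and a pre-built lookup table, standing in for the Toolkit's client.lookup services; names and sample values are invented:

MISSING_ID = "<MISSING_RESOURCE_ID>"

def remap_file_ref(dumped: dict, file_xid_by_id: dict) -> dict:
    # Replace the numeric annotatedResourceId with an external ID, falling back
    # to a placeholder when the file cannot be resolved, as in the hunk above.
    resource_id = dumped.pop("annotatedResourceId", None)
    if isinstance(resource_id, int):
        dumped["annotatedResourceExternalId"] = file_xid_by_id.get(resource_id, MISSING_ID)
    data = dumped.get("data")
    if isinstance(data, dict):
        file_ref = data.get("fileRef")
        if isinstance(file_ref, dict) and isinstance(file_ref.get("id"), int):
            file_ref["externalId"] = file_xid_by_id.get(file_ref.pop("id"), MISSING_ID)
    return dumped

annotation = {
    "annotatedResourceType": "file",
    "annotatedResourceId": 123,
    "data": {"fileRef": {"id": 456}},
}
print(remap_file_ref(annotation, {123: "pnid-001.pdf", 456: "pnid-002.pdf"}))
# annotatedResourceId 123 -> 'pnid-001.pdf', fileRef id 456 -> 'pnid-002.pdf'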
@@ -58,7 +58,7 @@ class Tracker:
  warning_details[f"warningMostCommon{no}Count"] = count
  warning_details[f"warningMostCommon{no}Name"] = warning

- positional_args, optional_args = self._parse_sys_args()
+ subcommands, optional_args = self._parse_sys_args()
  event_information = {
  "userInput": self.user_command,
  "toolkitVersion": __version__,
@@ -69,7 +69,7 @@ class Tracker:
  **warning_details,
  "result": type(result).__name__ if isinstance(result, Exception) else result,
  "error": str(result) if isinstance(result, Exception) else "",
- **positional_args,
+ "subcommands": subcommands,
  **optional_args,
  "alphaFlags": [name for name, value in self._cdf_toml.alpha_flags.items() if value],
  "plugins": [name for name, value in self._cdf_toml.plugins.items() if value],
@@ -128,9 +128,9 @@ class Tracker:
  return distinct_id

  @staticmethod
- def _parse_sys_args() -> tuple[dict[str, str], dict[str, str | bool]]:
+ def _parse_sys_args() -> tuple[list[str], dict[str, str | bool]]:
  optional_args: dict[str, str | bool] = {}
- positional_args: dict[str, str] = {}
+ subcommands: list[str] = []
  last_key: str | None = None
  if sys.argv and len(sys.argv) > 1:
  for arg in sys.argv[1:]:
@@ -147,11 +147,11 @@ class Tracker:
  optional_args[last_key] = arg
  last_key = None
  else:
- positional_args[f"positionalArg{len(positional_args)}"] = arg
+ subcommands.append(arg)

  if last_key:
  optional_args[last_key] = True
- return positional_args, optional_args
+ return subcommands, optional_args

  @property
  def _cicd(self) -> str:
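The Tracker hunks above replace the numbered positionalArg keys with a single subcommands list in the telemetry payload. A minimal sketch of that split follows, assuming flags start with "-"; the actual flag-detection logic lives outside this hunk, so treat this as an approximation only:

def split_args(argv: list[str]) -> tuple[list[str], dict[str, str | bool]]:
    # Illustrative approximation: positional tokens become subcommands,
    # "--flag value" pairs and bare "--flag" switches become optional_args.
    subcommands: list[str] = []
    optional_args: dict[str, str | bool] = {}
    last_key: str | None = None
    for arg in argv:
        if arg.startswith("-"):
            if last_key:
                optional_args[last_key] = True
            last_key = arg.lstrip("-")
        elif last_key:
            optional_args[last_key] = arg
            last_key = None
        else:
            subcommands.append(arg)
    if last_key:
        optional_args[last_key] = True
    return subcommands, optional_args

print(split_args(["deploy", "modules", "--dry-run", "--env", "dev"]))
# (['deploy', 'modules'], {'dry-run': True, 'env': 'dev'})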
@@ -7,6 +7,7 @@ from dataclasses import dataclass
  from functools import partial
  from io import TextIOWrapper
  from pathlib import Path
+ from typing import Any

  import yaml

@@ -87,26 +88,20 @@ class FailedParsing:
  error: str


- class TableReader(FileReader, ABC): ...
-
-
- class CSVReader(TableReader):
- """Reads CSV files and yields each row as a dictionary.
+ class TableReader(FileReader, ABC):
+ """Reads table-like files and yields each row as a dictionary.

  Args:
- input_file (Path): The path to the CSV file to read.
+ input_file (Path): The path to the table file to read.
  sniff_rows (int | None): Optional number of rows to sniff for
  schema detection. If None, no schema is detected. If a schema is sniffed
- from the first `sniff_rows` rows, it will be used to parse the CSV.
+ from the first `sniff_rows` rows, it will be used to parse the table.
  schema (Sequence[SchemaColumn] | None): Optional schema to use for parsing.
  You can either provide a schema or use `sniff_rows` to detect it.
  keep_failed_cells (bool): If True, failed cells will be kept in the
  `failed_cell` attribute. If False, they will be ignored.
-
  """

- format = ".csv"
-
  def __init__(
  self,
  input_file: Path,
@@ -152,18 +147,19 @@ class CSVReader(TableReader):
  @classmethod
  def sniff_schema(cls, input_file: Path, sniff_rows: int = 100) -> list[SchemaColumn]:
  """
- Sniff the schema from the first `sniff_rows` rows of the CSV file.
+ Sniff the schema from the first `sniff_rows` rows of the file.

  Args:
- input_file (Path): The path to the CSV file.
+ input_file (Path): The path to the tabular file.
  sniff_rows (int): The number of rows to read for sniffing the schema.

  Returns:
  list[SchemaColumn]: The inferred schema as a list of SchemaColumn objects.
+
  Raises:
  ValueError: If `sniff_rows` is not a positive integer.
  ToolkitFileNotFoundError: If the file does not exist.
- ToolkitValueError: If the file is not a CSV file or if there are issues with the content.
+ ToolkitValueError: If the file is not the correct format or if there are issues with the content.

  """
  if sniff_rows <= 0:
@@ -171,43 +167,50 @@ class CSVReader(TableReader):

  if not input_file.exists():
  raise ToolkitFileNotFoundError(f"File not found: {input_file.as_posix()!r}.")
- if input_file.suffix != ".csv":
- raise ToolkitValueError(f"Expected a .csv file got a {input_file.suffix!r} file instead.")
+ if input_file.suffix != cls.format:
+ raise ToolkitValueError(f"Expected a {cls.format} file got a {input_file.suffix!r} file instead.")

- with input_file.open("r", encoding="utf-8-sig") as file:
- reader = csv.DictReader(file)
- column_names = Counter(reader.fieldnames)
- if duplicated := [name for name, count in column_names.items() if count > 1]:
- raise ToolkitValueError(f"CSV file contains duplicate headers: {humanize_collection(duplicated)}")
- sample_rows: list[dict[str, str]] = []
- for no, row in enumerate(reader):
- if no >= sniff_rows:
- break
- sample_rows.append(row)
+ column_names, sample_rows = cls._read_sample_rows(input_file, sniff_rows)
+ cls._check_column_names(column_names)
+ return cls._infer_schema(sample_rows, column_names)

- if not sample_rows:
- raise ToolkitValueError(f"No data found in the file: {input_file.as_posix()!r}.")
+ @classmethod
+ @abstractmethod
+ def _read_sample_rows(cls, input_file: Path, sniff_rows: int) -> tuple[Sequence[str], list[dict[str, str]]]: ...

- schema = []
- for column_name in reader.fieldnames or []:
- sample_values = [row[column_name] for row in sample_rows if column_name in row]
- if not sample_values:
- column = SchemaColumn(name=column_name, type="string")
+ @classmethod
+ def _infer_schema(cls, sample_rows: list[dict[str, Any]], column_names: Sequence[str]) -> list[SchemaColumn]:
+ schema: list[SchemaColumn] = []
+ for column_name in column_names:
+ sample_values = [row[column_name] for row in sample_rows if column_name in row]
+ if not sample_values:
+ column = SchemaColumn(name=column_name, type="string")
+ else:
+ data_types = Counter(
+ infer_data_type_from_value(value, dtype="Json")[0] for value in sample_values if value is not None
+ )
+ if not data_types:
+ inferred_type = "string"
  else:
- data_types = Counter(
- infer_data_type_from_value(value, dtype="Json")[0]
- for value in sample_values
- if value is not None
- )
- if not data_types:
- inferred_type = "string"
- else:
- inferred_type = data_types.most_common()[0][0]
- # Json dtype is a subset of Datatype that SchemaColumn accepts
- column = SchemaColumn(name=column_name, type=inferred_type) # type: ignore[arg-type]
- schema.append(column)
+ inferred_type = data_types.most_common()[0][0]
+ # Json dtype is a subset of Datatype that SchemaColumn accepts
+ column = SchemaColumn(name=column_name, type=inferred_type) # type: ignore[arg-type]
+ schema.append(column)
  return schema

+ @classmethod
+ def _check_column_names(cls, column_names: Sequence[str]) -> None:
+ """Check for duplicate column names."""
+ duplicates = [name for name, count in Counter(column_names).items() if count > 1]
+ if duplicates:
+ raise ToolkitValueError(f"Duplicate column names found: {humanize_collection(duplicates)}.")
+
+
+ class CSVReader(TableReader):
+ """Reads CSV files and yields each row as a dictionary."""
+
+ format = ".csv"
+
  def _read_chunks_from_file(self, file: TextIOWrapper) -> Iterator[dict[str, JsonVal]]:
  if self.keep_failed_cells and self.failed_cell:
  self.failed_cell.clear()
@@ -231,10 +234,31 @@ class CSVReader(TableReader):
  with compression.open("r") as file:
  yield from csv.DictReader(file)

+ @classmethod
+ def _read_sample_rows(cls, input_file: Path, sniff_rows: int) -> tuple[Sequence[str], list[dict[str, str]]]:
+ column_names: Sequence[str] = []
+ compression = Compression.from_filepath(input_file)
+ with compression.open("r") as file:
+ reader = csv.DictReader(file)
+ column_names = reader.fieldnames or []
+ sample_rows: list[dict[str, str]] = []
+ for no, row in enumerate(reader):
+ if no >= sniff_rows:
+ break
+ sample_rows.append(row)
+
+ if not sample_rows:
+ raise ToolkitValueError(f"No data found in the file: {input_file.as_posix()!r}.")
+ return column_names, sample_rows
+

  class ParquetReader(TableReader):
  format = ".parquet"

+ def __init__(self, input_file: Path) -> None:
+ # Parquet files have their own schema, so we don't need to sniff or provide one.
+ super().__init__(input_file, sniff_rows=None, schema=None, keep_failed_cells=False)
+
  def read_chunks(self) -> Iterator[dict[str, JsonVal]]:
  import pyarrow.parquet as pq

@@ -258,6 +282,28 @@ class ParquetReader(TableReader):
  return value
  return value

+ @classmethod
+ def _read_sample_rows(cls, input_file: Path, sniff_rows: int) -> tuple[Sequence[str], list[dict[str, str]]]:
+ import pyarrow.parquet as pq
+
+ column_names: Sequence[str] = []
+ sample_rows: list[dict[str, str]] = []
+ with pq.ParquetFile(input_file) as parquet_file:
+ column_names = parquet_file.schema.names
+ row_count = min(sniff_rows, parquet_file.metadata.num_rows)
+ row_iter = parquet_file.iter_batches(batch_size=row_count)
+ try:
+ batch = next(row_iter)
+ for row in batch.to_pylist():
+ str_row = {key: (str(value) if value is not None else "") for key, value in row.items()}
+ sample_rows.append(str_row)
+ except StopIteration:
+ pass
+
+ if not sample_rows:
+ raise ToolkitValueError(f"No data found in the file: {input_file.as_posix()!r}.")
+ return column_names, sample_rows
+

  FILE_READ_CLS_BY_FORMAT: Mapping[str, type[FileReader]] = {}
  TABLE_READ_CLS_BY_FORMAT: Mapping[str, type[TableReader]] = {}
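Taken together, the reader refactor splits sniff_schema into three steps: a format-specific _read_sample_rows, a duplicate-column check, and type inference over the sampled values. The sketch below shows the inference step in isolation, with a simplified infer_type standing in for the Toolkit's infer_data_type_from_value and column tuples standing in for SchemaColumn objects; the sample rows are hypothetical:

from collections import Counter

def infer_type(value: str) -> str:
    # Simplified stand-in for the Toolkit's infer_data_type_from_value.
    if value.lstrip("-").isdigit():
        return "integer"
    try:
        float(value)
        return "float"
    except ValueError:
        return "string"

def infer_schema(sample_rows: list[dict[str, str]], column_names: list[str]) -> list[tuple[str, str]]:
    # For each column, pick the most common inferred type among non-empty samples,
    # defaulting to "string" when no usable samples exist (mirrors _infer_schema above).
    schema: list[tuple[str, str]] = []
    for name in column_names:
        samples = [row[name] for row in sample_rows if row.get(name)]
        counts = Counter(infer_type(value) for value in samples)
        schema.append((name, counts.most_common(1)[0][0] if counts else "string"))
    return schema

rows = [{"id": "1", "name": "pump", "weight": "2.5"}, {"id": "2", "name": "valve", "weight": "1.75"}]
print(infer_schema(rows, ["id", "name", "weight"]))
# [('id', 'integer'), ('name', 'string'), ('weight', 'float')]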