data-designer 0.3.8rc1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff compares two package versions as publicly released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.
Files changed (166)
  1. data_designer/cli/commands/__init__.py +1 -1
  2. data_designer/interface/__init__.py +21 -1
  3. data_designer/{_version.py → interface/_version.py} +2 -2
  4. data_designer/interface/data_designer.py +8 -11
  5. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/METADATA +10 -42
  6. data_designer-0.4.0.dist-info/RECORD +39 -0
  7. data_designer/__init__.py +0 -17
  8. data_designer/config/__init__.py +0 -2
  9. data_designer/config/analysis/__init__.py +0 -2
  10. data_designer/config/analysis/column_profilers.py +0 -159
  11. data_designer/config/analysis/column_statistics.py +0 -421
  12. data_designer/config/analysis/dataset_profiler.py +0 -84
  13. data_designer/config/analysis/utils/errors.py +0 -10
  14. data_designer/config/analysis/utils/reporting.py +0 -192
  15. data_designer/config/base.py +0 -69
  16. data_designer/config/column_configs.py +0 -470
  17. data_designer/config/column_types.py +0 -141
  18. data_designer/config/config_builder.py +0 -595
  19. data_designer/config/data_designer_config.py +0 -40
  20. data_designer/config/dataset_builders.py +0 -13
  21. data_designer/config/dataset_metadata.py +0 -18
  22. data_designer/config/default_model_settings.py +0 -121
  23. data_designer/config/errors.py +0 -24
  24. data_designer/config/exports.py +0 -145
  25. data_designer/config/interface.py +0 -55
  26. data_designer/config/models.py +0 -455
  27. data_designer/config/preview_results.py +0 -41
  28. data_designer/config/processors.py +0 -148
  29. data_designer/config/run_config.py +0 -48
  30. data_designer/config/sampler_constraints.py +0 -52
  31. data_designer/config/sampler_params.py +0 -639
  32. data_designer/config/seed.py +0 -116
  33. data_designer/config/seed_source.py +0 -84
  34. data_designer/config/seed_source_types.py +0 -19
  35. data_designer/config/utils/code_lang.py +0 -82
  36. data_designer/config/utils/constants.py +0 -363
  37. data_designer/config/utils/errors.py +0 -21
  38. data_designer/config/utils/info.py +0 -94
  39. data_designer/config/utils/io_helpers.py +0 -258
  40. data_designer/config/utils/misc.py +0 -78
  41. data_designer/config/utils/numerical_helpers.py +0 -30
  42. data_designer/config/utils/type_helpers.py +0 -106
  43. data_designer/config/utils/visualization.py +0 -482
  44. data_designer/config/validator_params.py +0 -94
  45. data_designer/engine/__init__.py +0 -2
  46. data_designer/engine/analysis/column_profilers/base.py +0 -49
  47. data_designer/engine/analysis/column_profilers/judge_score_profiler.py +0 -153
  48. data_designer/engine/analysis/column_profilers/registry.py +0 -22
  49. data_designer/engine/analysis/column_statistics.py +0 -145
  50. data_designer/engine/analysis/dataset_profiler.py +0 -149
  51. data_designer/engine/analysis/errors.py +0 -9
  52. data_designer/engine/analysis/utils/column_statistics_calculations.py +0 -234
  53. data_designer/engine/analysis/utils/judge_score_processing.py +0 -132
  54. data_designer/engine/column_generators/__init__.py +0 -2
  55. data_designer/engine/column_generators/generators/__init__.py +0 -2
  56. data_designer/engine/column_generators/generators/base.py +0 -122
  57. data_designer/engine/column_generators/generators/embedding.py +0 -35
  58. data_designer/engine/column_generators/generators/expression.py +0 -55
  59. data_designer/engine/column_generators/generators/llm_completion.py +0 -113
  60. data_designer/engine/column_generators/generators/samplers.py +0 -69
  61. data_designer/engine/column_generators/generators/seed_dataset.py +0 -144
  62. data_designer/engine/column_generators/generators/validation.py +0 -140
  63. data_designer/engine/column_generators/registry.py +0 -60
  64. data_designer/engine/column_generators/utils/errors.py +0 -15
  65. data_designer/engine/column_generators/utils/generator_classification.py +0 -43
  66. data_designer/engine/column_generators/utils/judge_score_factory.py +0 -58
  67. data_designer/engine/column_generators/utils/prompt_renderer.py +0 -100
  68. data_designer/engine/compiler.py +0 -97
  69. data_designer/engine/configurable_task.py +0 -71
  70. data_designer/engine/dataset_builders/artifact_storage.py +0 -283
  71. data_designer/engine/dataset_builders/column_wise_builder.py +0 -338
  72. data_designer/engine/dataset_builders/errors.py +0 -15
  73. data_designer/engine/dataset_builders/multi_column_configs.py +0 -46
  74. data_designer/engine/dataset_builders/utils/__init__.py +0 -2
  75. data_designer/engine/dataset_builders/utils/concurrency.py +0 -215
  76. data_designer/engine/dataset_builders/utils/config_compiler.py +0 -62
  77. data_designer/engine/dataset_builders/utils/dag.py +0 -62
  78. data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +0 -200
  79. data_designer/engine/dataset_builders/utils/errors.py +0 -15
  80. data_designer/engine/errors.py +0 -51
  81. data_designer/engine/model_provider.py +0 -77
  82. data_designer/engine/models/__init__.py +0 -2
  83. data_designer/engine/models/errors.py +0 -300
  84. data_designer/engine/models/facade.py +0 -287
  85. data_designer/engine/models/factory.py +0 -42
  86. data_designer/engine/models/litellm_overrides.py +0 -179
  87. data_designer/engine/models/parsers/__init__.py +0 -2
  88. data_designer/engine/models/parsers/errors.py +0 -34
  89. data_designer/engine/models/parsers/parser.py +0 -235
  90. data_designer/engine/models/parsers/postprocessors.py +0 -93
  91. data_designer/engine/models/parsers/tag_parsers.py +0 -62
  92. data_designer/engine/models/parsers/types.py +0 -84
  93. data_designer/engine/models/recipes/base.py +0 -81
  94. data_designer/engine/models/recipes/response_recipes.py +0 -293
  95. data_designer/engine/models/registry.py +0 -146
  96. data_designer/engine/models/telemetry.py +0 -359
  97. data_designer/engine/models/usage.py +0 -73
  98. data_designer/engine/models/utils.py +0 -38
  99. data_designer/engine/processing/ginja/__init__.py +0 -2
  100. data_designer/engine/processing/ginja/ast.py +0 -65
  101. data_designer/engine/processing/ginja/environment.py +0 -463
  102. data_designer/engine/processing/ginja/exceptions.py +0 -56
  103. data_designer/engine/processing/ginja/record.py +0 -32
  104. data_designer/engine/processing/gsonschema/__init__.py +0 -2
  105. data_designer/engine/processing/gsonschema/exceptions.py +0 -15
  106. data_designer/engine/processing/gsonschema/schema_transformers.py +0 -83
  107. data_designer/engine/processing/gsonschema/types.py +0 -10
  108. data_designer/engine/processing/gsonschema/validators.py +0 -202
  109. data_designer/engine/processing/processors/base.py +0 -13
  110. data_designer/engine/processing/processors/drop_columns.py +0 -42
  111. data_designer/engine/processing/processors/registry.py +0 -25
  112. data_designer/engine/processing/processors/schema_transform.py +0 -49
  113. data_designer/engine/processing/utils.py +0 -169
  114. data_designer/engine/registry/base.py +0 -99
  115. data_designer/engine/registry/data_designer_registry.py +0 -39
  116. data_designer/engine/registry/errors.py +0 -12
  117. data_designer/engine/resources/managed_dataset_generator.py +0 -39
  118. data_designer/engine/resources/managed_dataset_repository.py +0 -197
  119. data_designer/engine/resources/managed_storage.py +0 -65
  120. data_designer/engine/resources/resource_provider.py +0 -77
  121. data_designer/engine/resources/seed_reader.py +0 -154
  122. data_designer/engine/sampling_gen/column.py +0 -91
  123. data_designer/engine/sampling_gen/constraints.py +0 -100
  124. data_designer/engine/sampling_gen/data_sources/base.py +0 -217
  125. data_designer/engine/sampling_gen/data_sources/errors.py +0 -12
  126. data_designer/engine/sampling_gen/data_sources/sources.py +0 -347
  127. data_designer/engine/sampling_gen/entities/__init__.py +0 -2
  128. data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
  129. data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +0 -86
  130. data_designer/engine/sampling_gen/entities/email_address_utils.py +0 -171
  131. data_designer/engine/sampling_gen/entities/errors.py +0 -10
  132. data_designer/engine/sampling_gen/entities/national_id_utils.py +0 -102
  133. data_designer/engine/sampling_gen/entities/person.py +0 -144
  134. data_designer/engine/sampling_gen/entities/phone_number.py +0 -128
  135. data_designer/engine/sampling_gen/errors.py +0 -26
  136. data_designer/engine/sampling_gen/generator.py +0 -122
  137. data_designer/engine/sampling_gen/jinja_utils.py +0 -64
  138. data_designer/engine/sampling_gen/people_gen.py +0 -199
  139. data_designer/engine/sampling_gen/person_constants.py +0 -56
  140. data_designer/engine/sampling_gen/schema.py +0 -147
  141. data_designer/engine/sampling_gen/schema_builder.py +0 -61
  142. data_designer/engine/sampling_gen/utils.py +0 -46
  143. data_designer/engine/secret_resolver.py +0 -82
  144. data_designer/engine/validation.py +0 -367
  145. data_designer/engine/validators/__init__.py +0 -19
  146. data_designer/engine/validators/base.py +0 -38
  147. data_designer/engine/validators/local_callable.py +0 -39
  148. data_designer/engine/validators/python.py +0 -254
  149. data_designer/engine/validators/remote.py +0 -89
  150. data_designer/engine/validators/sql.py +0 -65
  151. data_designer/errors.py +0 -7
  152. data_designer/essentials/__init__.py +0 -33
  153. data_designer/lazy_heavy_imports.py +0 -54
  154. data_designer/logging.py +0 -163
  155. data_designer/plugin_manager.py +0 -78
  156. data_designer/plugins/__init__.py +0 -8
  157. data_designer/plugins/errors.py +0 -15
  158. data_designer/plugins/plugin.py +0 -141
  159. data_designer/plugins/registry.py +0 -88
  160. data_designer/plugins/testing/__init__.py +0 -10
  161. data_designer/plugins/testing/stubs.py +0 -116
  162. data_designer/plugins/testing/utils.py +0 -20
  163. data_designer-0.3.8rc1.dist-info/RECORD +0 -196
  164. data_designer-0.3.8rc1.dist-info/licenses/LICENSE +0 -201
  165. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/WHEEL +0 -0
  166. {data_designer-0.3.8rc1.dist-info → data_designer-0.4.0.dist-info}/entry_points.txt +0 -0
@@ -1,100 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import json
- import logging
-
- from data_designer.config.column_configs import SingleColumnConfig
- from data_designer.config.column_types import DataDesignerColumnType
- from data_designer.config.models import ModelConfig
- from data_designer.config.utils.code_lang import CodeLang
- from data_designer.config.utils.misc import extract_keywords_from_jinja2_template
- from data_designer.config.utils.type_helpers import StrEnum
- from data_designer.engine.column_generators.utils.errors import PromptTemplateRenderError
- from data_designer.engine.column_generators.utils.judge_score_factory import (
-     create_judge_response_model,
-     create_judge_structured_output_model,
- )
- from data_designer.engine.models.recipes.base import ResponseRecipe
- from data_designer.engine.models.recipes.response_recipes import (
-     CodeResponseRecipe,
-     PydanticResponseRecipe,
-     StructuredResponseRecipe,
-     TextResponseRecipe,
- )
- from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
- from data_designer.engine.processing.ginja.exceptions import UserTemplateError, UserTemplateUnsupportedFiltersError
-
- logger = logging.getLogger(__name__)
-
-
- class PromptType(StrEnum):
-     SYSTEM_PROMPT = "system_prompt"
-     USER_PROMPT = "user_prompt"
-
-
- class RecordBasedPromptRenderer(WithJinja2UserTemplateRendering):
-     def __init__(self, response_recipe: ResponseRecipe, *, error_message_context: dict[str, str] | None = None):
-         self.response_recipe = response_recipe
-         self._error_message_context = error_message_context
-
-     def render(self, *, prompt_template: str | None, record: dict, prompt_type: PromptType) -> str | None:
-         self._prepare_environment(prompt_template=prompt_template, record=record, prompt_type=prompt_type)
-         rendered_prompt = self.render_multi_template(prompt_type, record) if prompt_template else ""
-         recipe_applicator = (
-             self.response_recipe.apply_recipe_to_user_prompt
-             if prompt_type == PromptType.USER_PROMPT
-             else self.response_recipe.apply_recipe_to_system_prompt
-         )
-         return recipe_applicator(rendered_prompt)
-
-     def _prepare_environment(self, *, prompt_template: str | None, record: dict, prompt_type: PromptType) -> None:
-         try:
-             self.prepare_jinja2_multi_template_renderer(
-                 template_name=prompt_type.value,
-                 prompt_template=prompt_template,
-                 dataset_variables=list(record.keys()),
-             )
-         except (UserTemplateUnsupportedFiltersError, UserTemplateError) as exc:
-             template_variables = extract_keywords_from_jinja2_template(prompt_template)
-             missing_columns = list(set(template_variables) - set(record.keys()))
-
-             error_msg = (
-                 f"There was an error preparing the {prompt_type.value.replace('_', ' ')} "
-                 "template. Please double check that the template is valid Jinja2 syntax, that all "
-                 "referenced variables are defined, and that any filters you are using are supported."
-             )
-             if len(missing_columns) > 0:
-                 error_msg += f"\nThe following {missing_columns} columns are missing!"
-             if self._error_message_context is not None:
-                 error_msg += f"\n{json.dumps(self._error_message_context, indent=2)}"
-             logger.error(f"🛑 {error_msg}")
-             raise PromptTemplateRenderError(f"{exc!s} {error_msg}")
-
-
- def create_response_recipe(
-     column_config: SingleColumnConfig, model_config: ModelConfig | None = None
- ) -> ResponseRecipe:
-     if model_config and column_config.model_alias != model_config.alias:
-         raise ValueError(
-             f"Column config model alias {column_config.model_alias} does not match model config alias {model_config.alias}"
-         )
-     if column_config.column_type == DataDesignerColumnType.LLM_TEXT:
-         return TextResponseRecipe()
-     if column_config.column_type == DataDesignerColumnType.LLM_CODE:
-         return CodeResponseRecipe(
-             syntax=CodeLang.parse_lang(column_config.code_lang),
-         )
-     if column_config.column_type == DataDesignerColumnType.LLM_STRUCTURED:
-         return StructuredResponseRecipe(
-             json_schema=column_config.output_format,
-         )
-     if column_config.column_type == DataDesignerColumnType.LLM_JUDGE:
-         return PydanticResponseRecipe(
-             data_type=create_judge_structured_output_model(
-                 [create_judge_response_model(s) for s in column_config.scores]
-             ),
-         )
-     raise ValueError(f"No response recipe found for column type: {column_config.column_type}")
@@ -1,97 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import logging
-
- from data_designer.config.column_configs import SamplerColumnConfig, SeedDatasetColumnConfig
- from data_designer.config.data_designer_config import DataDesignerConfig
- from data_designer.config.errors import InvalidConfigError
- from data_designer.config.sampler_params import UUIDSamplerParams
- from data_designer.engine.resources.resource_provider import ResourceProvider
- from data_designer.engine.resources.seed_reader import SeedReader
- from data_designer.engine.validation import ViolationLevel, rich_print_violations, validate_data_designer_config
-
- logger = logging.getLogger(__name__)
-
-
- def compile_data_designer_config(config: DataDesignerConfig, resource_provider: ResourceProvider) -> DataDesignerConfig:
-     _resolve_and_add_seed_columns(config, resource_provider.seed_reader)
-     _add_internal_row_id_column_if_needed(config)
-     _validate(config)
-     return config
-
-
- def _resolve_and_add_seed_columns(config: DataDesignerConfig, seed_reader: SeedReader | None) -> None:
-     """Fetches the seed dataset column names, ensures there are no conflicts
-     with other columns, and adds seed column configs to the DataDesignerConfig.
-     """
-
-     if not seed_reader:
-         return
-
-     seed_col_names = seed_reader.get_column_names()
-     existing_columns = {column.name for column in config.columns}
-     colliding_columns = {name for name in seed_col_names if name in existing_columns}
-     if colliding_columns:
-         raise InvalidConfigError(
-             f"🛑 Seed dataset column(s) {colliding_columns} collide with existing column(s). "
-             "Please remove the conflicting columns or use a seed dataset with different column names."
-         )
-
-     config.columns.extend([SeedDatasetColumnConfig(name=col_name) for col_name in seed_col_names])
-
-
- def _add_internal_row_id_column_if_needed(config: DataDesignerConfig) -> None:
-     """Adds a UUID sampler column named '_internal_row_id' (set to drop) if needed to enable generation.
-
-     Generation requires either:
-     - At least one sampler column (which can generate data from scratch), OR
-     - A seed dataset (which provides initial data rows)
-
-     If neither exists, a UUID sampler column '_internal_row_id' is automatically added and marked for drop
-     to enable the generation process to start.
-
-     Args:
-         config: The DataDesigner configuration to potentially modify.
-     """
-     has_sampler_column = any(isinstance(col, SamplerColumnConfig) for col in config.columns)
-     has_seed_dataset_column = any(isinstance(col, SeedDatasetColumnConfig) for col in config.columns)
-
-     if not has_sampler_column and not has_seed_dataset_column:
-         logger.warning(
-             "🔔 No sampler column or seed dataset detected. Adding UUID column '_internal_row_id' (marked for drop) to enable generation."
-         )
-         id_column = SamplerColumnConfig(
-             name="_internal_row_id",
-             sampler_type="uuid",
-             params=UUIDSamplerParams(),
-             drop=True,
-         )
-         config.columns.insert(0, id_column)
-
-
- def _validate(config: DataDesignerConfig) -> None:
-     allowed_references = _get_allowed_references(config)
-     violations = validate_data_designer_config(
-         columns=config.columns,
-         processor_configs=config.processors or [],
-         allowed_references=allowed_references,
-     )
-     rich_print_violations(violations)
-     if len([v for v in violations if v.level == ViolationLevel.ERROR]) > 0:
-         raise InvalidConfigError(
-             "🛑 Your configuration contains validation errors. Please address the indicated issues and try again."
-         )
-     if len(violations) == 0:
-         logger.info("✅ Validation passed")
-
-
- def _get_allowed_references(config: DataDesignerConfig) -> list[str]:
-     refs = set[str]()
-     for column_config in config.columns:
-         refs.add(column_config.name)
-         for side_effect_column in column_config.side_effect_columns:
-             refs.add(side_effect_column)
-     return list(refs)
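
The removed compile step performs three passes in order: merge seed-dataset columns while rejecting name collisions, inject a droppable '_internal_row_id' UUID sampler when no other row source exists, then validate. Below is a self-contained sketch of that flow, using stand-in dataclasses rather than the package's config types.

    from dataclasses import dataclass, field


    @dataclass
    class Column:
        name: str
        kind: str = "llm"  # one of: "sampler", "seed", "llm"
        drop: bool = False


    @dataclass
    class Config:
        columns: list[Column] = field(default_factory=list)


    def compile_config(config: Config, seed_column_names: list[str] | None = None) -> Config:
        # 1. Merge seed-dataset columns, rejecting name collisions.
        if seed_column_names:
            existing = {c.name for c in config.columns}
            clashes = existing.intersection(seed_column_names)
            if clashes:
                raise ValueError(f"Seed dataset column(s) {clashes} collide with existing column(s).")
            config.columns.extend(Column(name=n, kind="seed") for n in seed_column_names)
        # 2. Guarantee at least one row source; otherwise bootstrap with a droppable
        #    sampler column, mirroring _add_internal_row_id_column_if_needed above.
        if not any(c.kind in ("sampler", "seed") for c in config.columns):
            config.columns.insert(0, Column(name="_internal_row_id", kind="sampler", drop=True))
        return config


    cfg = compile_config(Config(columns=[Column(name="summary")]))
    assert cfg.columns[0].name == "_internal_row_id"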
@@ -1,71 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- from abc import ABC
- from pathlib import Path
- from typing import TYPE_CHECKING, Generic, TypeVar, get_origin
-
- from data_designer.config.base import ConfigBase
- from data_designer.engine.dataset_builders.artifact_storage import ArtifactStorage
- from data_designer.engine.resources.resource_provider import ResourceProvider
- from data_designer.lazy_heavy_imports import pd
-
- if TYPE_CHECKING:
-     import pandas as pd
-
- DataT = TypeVar("DataT", dict, pd.DataFrame)
- TaskConfigT = TypeVar("ConfigT", bound=ConfigBase)
-
-
- class ConfigurableTask(ABC, Generic[TaskConfigT]):
-     def __init__(self, config: TaskConfigT, resource_provider: ResourceProvider):
-         self._config = self.get_config_type().model_validate(config)
-         self._resource_provider = resource_provider
-         self._validate()
-         self._initialize()
-
-     @classmethod
-     def get_config_type(cls) -> type[TaskConfigT]:
-         for base in cls.__orig_bases__:
-             if hasattr(base, "__args__") and len(base.__args__) == 1:
-                 arg = base.__args__[0]
-                 origin = get_origin(arg) or arg
-                 if isinstance(origin, type) and issubclass(origin, ConfigBase):
-                     return base.__args__[0]
-         raise TypeError(
-             f"Could not determine config type for `{cls.__name__}`. Please ensure that the "
-             "`ConfigurableTask` is defined with a generic type argument, where the type argument "
-             "is a subclass of `ConfigBase`."
-         )
-
-     @property
-     def artifact_path(self) -> Path:
-         return self.artifact_storage.artifact_path
-
-     @property
-     def artifact_storage(self) -> ArtifactStorage:
-         return self.resource_provider.artifact_storage
-
-     @property
-     def base_dataset_path(self) -> Path:
-         return self.artifact_storage.base_dataset_path
-
-     @property
-     def config(self) -> TaskConfigT:
-         return self._config
-
-     @property
-     def name(self) -> str:
-         return self.__class__.__name__
-
-     @property
-     def resource_provider(self) -> ResourceProvider:
-         return self._resource_provider
-
-     def _initialize(self) -> None:
-         """An internal method for custom initialization logic, which will be called in the constructor."""
-
-     def _validate(self) -> None:
-         """An internal method for custom validation logic, which will be called in the constructor."""
@@ -1,283 +0,0 @@
- # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- # SPDX-License-Identifier: Apache-2.0
-
- from __future__ import annotations
-
- import json
- import logging
- import shutil
- from datetime import datetime
- from functools import cached_property
- from pathlib import Path
- from typing import TYPE_CHECKING
-
- from pydantic import BaseModel, field_validator, model_validator
-
- from data_designer.config.utils.io_helpers import read_parquet_dataset
- from data_designer.config.utils.type_helpers import StrEnum, resolve_string_enum
- from data_designer.engine.dataset_builders.errors import ArtifactStorageError
- from data_designer.lazy_heavy_imports import pd
-
- if TYPE_CHECKING:
-     import pandas as pd
-
- logger = logging.getLogger(__name__)
-
- BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"
- SDG_CONFIG_FILENAME = "sdg.json"
-
-
- class BatchStage(StrEnum):
-     PARTIAL_RESULT = "partial_results_path"
-     FINAL_RESULT = "final_dataset_path"
-     DROPPED_COLUMNS = "dropped_columns_dataset_path"
-     PROCESSORS_OUTPUTS = "processors_outputs_path"
-
-
- class ArtifactStorage(BaseModel):
-     artifact_path: Path | str
-     dataset_name: str = "dataset"
-     final_dataset_folder_name: str = "parquet-files"
-     partial_results_folder_name: str = "tmp-partial-parquet-files"
-     dropped_columns_folder_name: str = "dropped-columns-parquet-files"
-     processors_outputs_folder_name: str = "processors-files"
-
-     @property
-     def artifact_path_exists(self) -> bool:
-         return self.artifact_path.exists()
-
-     @cached_property
-     def resolved_dataset_name(self) -> str:
-         dataset_path = self.artifact_path / self.dataset_name
-         if dataset_path.exists() and len(list(dataset_path.iterdir())) > 0:
-             new_dataset_name = f"{self.dataset_name}_{datetime.now().strftime('%m-%d-%Y_%H%M%S')}"
-             logger.info(
-                 f"📂 Dataset path {str(dataset_path)!r} already exists. Dataset from this session"
-                 f"\n\t\t will be saved to {str(self.artifact_path / new_dataset_name)!r} instead."
-             )
-             return new_dataset_name
-         return self.dataset_name
-
-     @property
-     def base_dataset_path(self) -> Path:
-         return self.artifact_path / self.resolved_dataset_name
-
-     @property
-     def dropped_columns_dataset_path(self) -> Path:
-         return self.base_dataset_path / self.dropped_columns_folder_name
-
-     @property
-     def final_dataset_path(self) -> Path:
-         return self.base_dataset_path / self.final_dataset_folder_name
-
-     @property
-     def metadata_file_path(self) -> Path:
-         return self.base_dataset_path / "metadata.json"
-
-     @property
-     def partial_results_path(self) -> Path:
-         return self.base_dataset_path / self.partial_results_folder_name
-
-     @property
-     def processors_outputs_path(self) -> Path:
-         return self.base_dataset_path / self.processors_outputs_folder_name
-
-     @field_validator("artifact_path")
-     def validate_artifact_path(cls, v: Path | str) -> Path:
-         v = Path(v)
-         if not v.is_dir():
-             raise ArtifactStorageError("Artifact path must exist and be a directory")
-         return v
-
-     @model_validator(mode="after")
-     def validate_folder_names(self):
-         folder_names = [
-             self.dataset_name,
-             self.final_dataset_folder_name,
-             self.partial_results_folder_name,
-             self.dropped_columns_folder_name,
-             self.processors_outputs_folder_name,
-         ]
-
-         for name in folder_names:
-             if len(name) == 0:
-                 raise ArtifactStorageError("🛑 Directory names must be non-empty strings.")
-
-         if len(set(folder_names)) != len(folder_names):
-             raise ArtifactStorageError("🛑 Folder names must be unique (no collisions allowed).")
-
-         invalid_chars = {"<", ">", ":", '"', "/", "\\", "|", "?", "*"}
-         for name in folder_names:
-             if any(char in invalid_chars for char in name):
-                 raise ArtifactStorageError(f"🛑 Directory name '{name}' contains invalid characters.")
-
-         return self
-
-     @staticmethod
-     def mkdir_if_needed(path: Path | str) -> Path:
-         """Create the directory if it does not exist."""
-         path = Path(path)
-         if not path.exists():
-             logger.debug(f"📁 Creating directory: {path}")
-             path.mkdir(parents=True, exist_ok=True)
-         return path
-
-     @staticmethod
-     def read_parquet_files(path: Path) -> pd.DataFrame:
-         return read_parquet_dataset(path)
-
-     def create_batch_file_path(
-         self,
-         batch_number: int,
-         batch_stage: BatchStage,
-     ) -> Path:
-         if batch_number < 0:
-             raise ArtifactStorageError("🛑 Batch number must be non-negative.")
-         return self._get_stage_path(batch_stage) / BATCH_FILE_NAME_FORMAT.format(batch_number=batch_number)
-
-     def load_dataset(self, batch_stage: BatchStage = BatchStage.FINAL_RESULT) -> pd.DataFrame:
-         return read_parquet_dataset(self._get_stage_path(batch_stage))
-
-     def load_dataset_with_dropped_columns(self) -> pd.DataFrame:
-         # The pyarrow backend has better support for nested data types.
-         df = self.load_dataset()
-         if (
-             self.dropped_columns_dataset_path.exists()
-             and self.create_batch_file_path(0, BatchStage.DROPPED_COLUMNS).is_file()
-         ):
-             logger.debug("Concatenating dropped columns to the final dataset.")
-             df_dropped = self.load_dataset(batch_stage=BatchStage.DROPPED_COLUMNS)
-             if len(df_dropped) != len(df):
-                 raise ArtifactStorageError(
-                     "🛑 The dropped-columns dataset has a different number of rows than the main dataset. "
-                     "Something unexpected must have happened to the dataset builder's artifacts."
-                 )
-             # To ensure indexes are aligned and avoid silent misalignment (which would introduce NaNs),
-             # check that the indexes are identical before concatenation.
-             if not df.index.equals(df_dropped.index):
-                 raise ArtifactStorageError(
-                     "🛑 The indexes of the main and dropped columns DataFrames are not aligned. "
-                     "Something unexpected must have happened to the dataset builder's artifacts."
-                 )
-             df = pd.concat([df, df_dropped], axis=1)
-         return df
-
-     def move_partial_result_to_final_file_path(self, batch_number: int) -> Path:
-         partial_result_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.PARTIAL_RESULT)
-         if not partial_result_path.exists():
-             raise ArtifactStorageError("🛑 Partial result file not found.")
-         self.mkdir_if_needed(self._get_stage_path(BatchStage.FINAL_RESULT))
-         final_file_path = self.create_batch_file_path(batch_number, batch_stage=BatchStage.FINAL_RESULT)
-         shutil.move(partial_result_path, final_file_path)
-         return final_file_path
-
-     def write_batch_to_parquet_file(
-         self,
-         batch_number: int,
-         dataframe: pd.DataFrame,
-         batch_stage: BatchStage,
-         subfolder: str | None = None,
-     ) -> Path:
-         file_path = self.create_batch_file_path(batch_number, batch_stage=batch_stage)
-         self.write_parquet_file(file_path.name, dataframe, batch_stage, subfolder=subfolder)
-         return file_path
-
-     def write_parquet_file(
-         self,
-         parquet_file_name: str,
-         dataframe: pd.DataFrame,
-         batch_stage: BatchStage,
-         subfolder: str | None = None,
-     ) -> Path:
-         subfolder = subfolder or ""
-         self.mkdir_if_needed(self._get_stage_path(batch_stage) / subfolder)
-         file_path = self._get_stage_path(batch_stage) / subfolder / parquet_file_name
-         dataframe.to_parquet(file_path, index=False)
-         return file_path
-
-     def get_parquet_file_paths(self) -> list[str]:
-         """Get list of parquet file paths relative to base_dataset_path.
-
-         Returns:
-             List of relative paths to parquet files in the final dataset folder.
-         """
-         return [str(f.relative_to(self.base_dataset_path)) for f in sorted(self.final_dataset_path.glob("*.parquet"))]
-
-     def get_processor_file_paths(self) -> dict[str, list[str]]:
-         """Get processor output files organized by processor name.
-
-         Returns:
-             Dictionary mapping processor names to lists of relative file paths.
-         """
-         processor_files: dict[str, list[str]] = {}
-         if self.processors_outputs_path.exists():
-             for processor_dir in sorted(self.processors_outputs_path.iterdir()):
-                 if processor_dir.is_dir():
-                     processor_name = processor_dir.name
-                     processor_files[processor_name] = [
-                         str(f.relative_to(self.base_dataset_path))
-                         for f in sorted(processor_dir.rglob("*"))
-                         if f.is_file()
-                     ]
-         return processor_files
-
-     def get_file_paths(self) -> dict[str, list[str] | dict[str, list[str]]]:
-         """Get all file paths organized by type.
-
-         Returns:
-             Dictionary with 'parquet-files' and 'processor-files' keys.
-         """
-         file_paths = {
-             "parquet-files": self.get_parquet_file_paths(),
-         }
-         processor_file_paths = self.get_processor_file_paths()
-         if processor_file_paths:
-             file_paths["processor-files"] = processor_file_paths
-
-         return file_paths
-
-     def read_metadata(self) -> dict:
-         """Read metadata from the metadata.json file.
-
-         Returns:
-             Dictionary containing the metadata.
-
-         Raises:
-             FileNotFoundError: If metadata file doesn't exist.
-         """
-         with open(self.metadata_file_path, "r") as file:
-             return json.load(file)
-
-     def write_metadata(self, metadata: dict) -> Path:
-         """Write metadata to the metadata.json file.
-
-         Args:
-             metadata: Dictionary containing metadata to write.
-
-         Returns:
-             Path to the written metadata file.
-         """
-         self.mkdir_if_needed(self.base_dataset_path)
-         with open(self.metadata_file_path, "w") as file:
-             json.dump(metadata, file, indent=4, sort_keys=True)
-         return self.metadata_file_path
-
-     def update_metadata(self, updates: dict) -> Path:
-         """Update existing metadata with new fields.
-
-         Args:
-             updates: Dictionary of fields to add/update in metadata.
-
-         Returns:
-             Path to the updated metadata file.
-         """
-         try:
-             existing_metadata = self.read_metadata()
-         except FileNotFoundError:
-             existing_metadata = {}
-
-         existing_metadata.update(updates)
-         return self.write_metadata(existing_metadata)
-
-     def _get_stage_path(self, stage: BatchStage) -> Path:
-         return getattr(self, resolve_string_enum(stage, BatchStage).value)
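
Taken together, the folder-name fields and BATCH_FILE_NAME_FORMAT above define the on-disk layout: each batch stage is a subfolder of the dataset directory holding zero-padded parquet batch files. A small sketch of that path scheme follows; the helper function is illustrative, not part of the package.

    from pathlib import Path

    BATCH_FILE_NAME_FORMAT = "batch_{batch_number:05d}.parquet"


    def batch_file_path(base_dataset_path: Path, stage_folder: str, batch_number: int) -> Path:
        # Mirrors create_batch_file_path: reject negative batch numbers, then
        # join <dataset>/<stage folder>/batch_XXXXX.parquet.
        if batch_number < 0:
            raise ValueError("Batch number must be non-negative.")
        return base_dataset_path / stage_folder / BATCH_FILE_NAME_FORMAT.format(batch_number=batch_number)


    # Batch 3 of the final dataset lands at: dataset/parquet-files/batch_00003.parquet
    print(batch_file_path(Path("dataset"), "parquet-files", 3))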