data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
import random
|
|
8
|
+
from typing import Union
|
|
9
|
+
|
|
10
|
+
from data_designer.config.analysis.column_profilers import (
|
|
11
|
+
JudgeScoreProfilerConfig,
|
|
12
|
+
JudgeScoreProfilerResults,
|
|
13
|
+
JudgeScoreSample,
|
|
14
|
+
JudgeScoreSummary,
|
|
15
|
+
)
|
|
16
|
+
from data_designer.config.analysis.column_statistics import (
|
|
17
|
+
CategoricalDistribution,
|
|
18
|
+
CategoricalHistogramData,
|
|
19
|
+
ColumnDistributionType,
|
|
20
|
+
MissingValue,
|
|
21
|
+
NumericalDistribution,
|
|
22
|
+
)
|
|
23
|
+
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP, DataDesignerColumnType
|
|
24
|
+
from data_designer.engine.analysis.column_profilers.base import (
|
|
25
|
+
ColumnConfigWithDataFrame,
|
|
26
|
+
ColumnProfiler,
|
|
27
|
+
ColumnProfilerMetadata,
|
|
28
|
+
)
|
|
29
|
+
from data_designer.engine.analysis.utils.judge_score_processing import (
|
|
30
|
+
extract_judge_score_distributions,
|
|
31
|
+
sample_scores_and_reasoning,
|
|
32
|
+
)
|
|
33
|
+
from data_designer.engine.models.facade import ModelFacade
|
|
34
|
+
from data_designer.engine.models.recipes.response_recipes import TextResponseRecipe
|
|
35
|
+
from data_designer.engine.resources.resource_provider import ResourceType
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)


class JudgeScoreProfiler(ColumnProfiler[JudgeScoreProfilerConfig]):
    """Column profiler that summarizes LLM-as-judge score columns.

    For every score defined on an LLM-judge column, the profiler samples scores with
    their reasoning and asks the model configured by ``config.model_alias`` to write a
    short natural-language summary of the evaluator feedback.
    """

    @staticmethod
    def metadata() -> ColumnProfilerMetadata:
        """Return the static metadata (name, requirements, applicable types) of this profiler."""
        return ColumnProfilerMetadata(
            name="judge_score_profiler",
            description="Analyzes LLM-as-judge score distributions in a Data Designer dataset.",
            required_resources=[ResourceType.MODEL_REGISTRY],
            applicable_column_types=[DataDesignerColumnType.LLM_JUDGE],
        )

    def get_model(self, model_alias: str) -> ModelFacade:
        """Look up a model facade by alias in the resource provider's model registry."""
        return self.resource_provider.model_registry.get_model(model_alias=model_alias)

    def profile(self, column_config_with_df: ColumnConfigWithDataFrame) -> JudgeScoreProfilerResults:
        """Profile a single LLM-judge column.

        Args:
            column_config_with_df: The judge column configuration together with the dataset.

        Returns:
            The extracted score distributions plus per-score summaries keyed by the
            lower-cased score name. Summaries are empty when no
            ``summary_score_sample_size`` is configured or when the score distributions
            could not be extracted.
        """
        column_config, df = column_config_with_df.as_tuple()

        logger.info(
            f"{COLUMN_TYPE_EMOJI_MAP[column_config.column_type]} Analyzing LLM-as-judge "
            f"scores for column: '{column_config.name}'"
        )

        score_distributions = extract_judge_score_distributions(column_config, df)

        if self.config.summary_score_sample_size is None or isinstance(score_distributions, MissingValue):
            return JudgeScoreProfilerResults(
                summaries={},
                column_name=column_config.name,
                score_distributions=score_distributions,
            )

        score_summaries: dict[str, JudgeScoreSummary] = {}
        for score in column_config.scores:
            score_name = score.name.lower()
            logger.info(f"{random.choice(['👩⚖️', '👨⚖️'])} Summarizing LLM-as-judge score: '{score_name}'")
            score_sample = sample_scores_and_reasoning(
                scores=score_distributions.scores[score_name],
                reasoning=score_distributions.reasoning[score_name],
                num_samples=self.config.summary_score_sample_size,
            )

            score_summaries[score_name] = self._summarize_score_sample(
                name=score_name,
                sample=score_sample,
                histogram=score_distributions.histograms[score_name],
                distribution=score_distributions.distributions[score_name],
                distribution_type=score_distributions.distribution_types[score_name],
            )

        return JudgeScoreProfilerResults(
            column_name=column_config.name,
            summaries=score_summaries,
            score_distributions=score_distributions,
        )

    def _summarize_score_sample(
        self,
        name: str,
        sample: list[JudgeScoreSample],
        histogram: CategoricalHistogramData,
        distribution: Union[CategoricalDistribution, NumericalDistribution, MissingValue],
        distribution_type: ColumnDistributionType,
    ) -> JudgeScoreSummary:
        """Ask the configured model to summarize the sampled reasoning for one score.

        Args:
            name: Lower-cased score name.
            sample: Sampled scores with their reasoning.
            histogram: Per-category counts of the score values.
            distribution: Summary distribution of the score (or MissingValue).
            distribution_type: Whether ``distribution`` is categorical or numerical.

        Returns:
            A JudgeScoreSummary. Model lookup or generation failures are caught, logged,
            and reported in the summary text instead of being raised.
        """
        if isinstance(distribution, MissingValue) or not sample:
            return JudgeScoreSummary(
                score_name=name,
                summary="No judge score information available to summarize.",
                score_samples=sample,
            )

        total_count = sum(histogram.counts)
        category_info = []
        for cat, count in zip(histogram.categories, histogram.counts):
            # Guard against an all-zero histogram so formatting never divides by zero.
            percentage = (count / total_count) * 100 if total_count else 0.0
            category_info.append(f"{cat}: {count} records ({percentage:.1f}%)")

        distribution_context = f"Score distribution - {', '.join(category_info)}, "
        if distribution_type == ColumnDistributionType.CATEGORICAL:
            distribution_context += f"Most common value: {distribution.most_common_value}. "
        elif distribution_type == ColumnDistributionType.NUMERICAL:
            distribution_context += f"Mean score: {distribution.mean:.2f}. "

        logger.info(f" |-- number of score samples: {len(sample)}")
        logger.info(f" |-- {distribution_context.lower()}")

        combined_reasoning = "\n".join([r.reasoning for r in sample])
        prompt = (
            f"Based on the following evaluator reasoning for the '{name}' criterion, "
            "provide a concise summary that captures both the strengths and areas for improvement mentioned. "
            "Be specific about what worked well and what needs improvement.\n\n"
            f"Overall distribution of scores: {distribution_context}"
            f"\nA sample of reasoning:\n{combined_reasoning}\n\n"
            "Do not include any titles like `Summary` or `Summary:`. "
            "Do not wrap the summary in quotation marks. "
            "YOU WILL PRODUCE LESS THAN 75 WORDS in a readable sentence format. "
            "No need to use bullets or headers. Write naturally."
        )

        system_prompt = (
            "You are an expert at distilling complex feedback into concise summaries. "
            "Focus on specificity and balance, incorporating both the distribution context and individual reasoning examples."
        )

        try:
            model = self.get_model(self.config.model_alias)
            recipe = TextResponseRecipe()
            summary, _ = model.generate(
                prompt=recipe.apply_recipe_to_user_prompt(prompt),
                system_prompt=recipe.apply_recipe_to_system_prompt(system_prompt),
                parser=recipe.parse,
            )
            return JudgeScoreSummary(
                score_name=name,
                summary=summary.strip(),
                score_samples=sample,
            )
        except Exception as e:
            # Summarization is best-effort: a model failure must not abort profiling,
            # but it should be visible in the logs and not only inside the result.
            logger.warning(f"⚠️ Failed to summarize LLM-as-judge score '{name}': {e}")
            return JudgeScoreSummary(
                score_name=name,
                summary=f"Score summarization failed: {e}",
                score_samples=sample,
            )
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from data_designer.config.analysis.column_profilers import ColumnProfilerType
|
|
5
|
+
from data_designer.config.base import ConfigBase
|
|
6
|
+
from data_designer.engine.analysis.column_profilers.base import ColumnProfiler
|
|
7
|
+
from data_designer.engine.analysis.column_profilers.judge_score_profiler import (
|
|
8
|
+
JudgeScoreProfiler,
|
|
9
|
+
JudgeScoreProfilerConfig,
|
|
10
|
+
)
|
|
11
|
+
from data_designer.engine.registry.base import TaskRegistry
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ColumnProfilerRegistry(TaskRegistry[ColumnProfilerType, ColumnProfiler, ConfigBase]):
    """Task registry keyed by ``ColumnProfilerType`` for column profilers and their configs."""
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def create_default_column_profiler_registry() -> ColumnProfilerRegistry:
    """Build a column profiler registry populated with the built-in profilers.

    Currently only the LLM-as-judge score profiler is registered.
    """
    default_registry = ColumnProfilerRegistry()
    # The trailing positional flag is forwarded unchanged to TaskRegistry.register;
    # its meaning is defined there (not visible in this module).
    default_registry.register(ColumnProfilerType.JUDGE_SCORE, JudgeScoreProfiler, JudgeScoreProfilerConfig, False)
    return default_registry
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from typing import Any, Type, TypeAlias, Union
|
|
8
|
+
|
|
9
|
+
import pandas as pd
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
from typing_extensions import Self
|
|
12
|
+
|
|
13
|
+
from data_designer.config.analysis.column_statistics import (
|
|
14
|
+
DEFAULT_COLUMN_STATISTICS_MAP,
|
|
15
|
+
ColumnStatisticsT,
|
|
16
|
+
GeneralColumnStatistics,
|
|
17
|
+
)
|
|
18
|
+
from data_designer.config.column_types import ColumnConfigT, DataDesignerColumnType
|
|
19
|
+
from data_designer.config.sampler_params import SamplerType, is_numerical_sampler_type
|
|
20
|
+
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame
|
|
21
|
+
from data_designer.engine.analysis.utils.column_statistics_calculations import (
|
|
22
|
+
ColumnDistributionType,
|
|
23
|
+
calculate_column_distribution,
|
|
24
|
+
calculate_general_column_info,
|
|
25
|
+
calculate_token_stats,
|
|
26
|
+
calculate_validation_column_info,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)


class GeneralColumnStatisticsCalculator(BaseModel):
    """Compute statistics for one column of a generated dataset.

    Subclasses add statistics by defining methods whose names start with
    ``calculate_``. Each such method returns a dict of fields, and :meth:`calculate`
    merges them into a single statistics object.
    """

    column_config_with_df: ColumnConfigWithDataFrame

    @property
    def column_config(self) -> ColumnConfigT:
        """The configuration of the column being profiled."""
        return self.column_config_with_df.column_config

    @property
    def df(self) -> pd.DataFrame:
        """The dataset the statistics are computed from."""
        return self.column_config_with_df.df

    @property
    def column_statistics_type(self) -> Type[ColumnStatisticsT]:
        """Statistics model for this column type, defaulting to GeneralColumnStatistics."""
        return DEFAULT_COLUMN_STATISTICS_MAP.get(self.column_config.column_type, GeneralColumnStatistics)

    def calculate(self) -> ColumnStatisticsT:
        """Calculate all statistics fields for the configured column.

        Every callable attribute whose name starts with ``calculate_`` is invoked, in
        ``dir()`` (alphabetical) order. Their result dicts are merged, with later keys
        overriding earlier ones, and passed to :attr:`column_statistics_type`.

        Returns:
            An instance of :attr:`column_statistics_type` (a column statistics model,
            not this calculator).
        """
        calculate_methods = [
            name for name in dir(self) if name.startswith("calculate_") and callable(getattr(self, name))
        ]
        return self.column_statistics_type(
            column_name=self.column_config.name,
            **{k: v for name in calculate_methods for k, v in getattr(self, name)().items()},
        )

    def calculate_general_column_info(self) -> dict[str, Any]:
        """Return the general column info fields shared by every column type."""
        return calculate_general_column_info(self.column_config, self.df)

    def __repr__(self) -> str:
        params = []
        for field, value in self.model_dump(mode="json").items():
            params.append(f" {field}: {value}")
        params_str = "\n".join(params)
        return f"{self.__class__.__name__}(\n{params_str}\n)"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class LLMTextColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
    """Statistics calculator for LLM text columns; adds token statistics."""

    def calculate_token_stats(self) -> dict[str, Any]:
        """Return the token statistics fields for this column."""
        column, frame = self.column_config, self.df
        return calculate_token_stats(column, frame)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class LLMCodeColumnStatisticsCalculator(LLMTextColumnStatisticsCalculator):
    """Statistics calculator for LLM code columns (same statistics as LLM text columns)."""
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class LLMStructuredColumnStatisticsCalculator(LLMTextColumnStatisticsCalculator):
    """Statistics calculator for LLM structured columns (same statistics as LLM text columns)."""
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class LLMJudgedColumnStatisticsCalculator(LLMTextColumnStatisticsCalculator):
    """Statistics calculator for LLM-judge columns (same statistics as LLM text columns)."""
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class SamplerColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
    """Statistics calculator for sampler columns; adds a sampler-aware value distribution."""

    def calculate_sampler_distribution(self) -> dict[str, Any]:
        """Return the sampler type and, when supported, its value distribution.

        Category and subcategory samplers get a categorical distribution, and numerical
        samplers get a numerical one. Any other sampler reports distribution type OTHER
        with no distribution.
        """
        sampler_type = self.column_config.sampler_type
        if sampler_type in [SamplerType.CATEGORY, SamplerType.SUBCATEGORY]:
            distribution_type = ColumnDistributionType.CATEGORICAL
        elif is_numerical_sampler_type(sampler_type):
            distribution_type = ColumnDistributionType.NUMERICAL
        else:
            return {
                "sampler_type": SamplerType(sampler_type),
                "distribution_type": ColumnDistributionType.OTHER,
                "distribution": None,
            }
        return {
            "sampler_type": SamplerType(sampler_type),
            **calculate_column_distribution(self.column_config, self.df, distribution_type),
        }
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class SeedDatasetColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
    """Statistics calculator for seed-dataset columns (general statistics only)."""
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class ValidationColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
    """Statistics calculator for validation columns; adds validation result info."""

    def calculate_validation_column_info(self) -> dict[str, Any]:
        """Return the validation-specific fields for this column."""
        config, frame = self.column_config, self.df
        return calculate_validation_column_info(config, frame)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
class ExpressionColumnStatisticsCalculator(GeneralColumnStatisticsCalculator):
    """Statistics calculator for expression columns (general statistics only)."""
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# Union of every column statistics calculator class defined in this module.
ColumnStatisticsCalculatorT: TypeAlias = Union[
    ExpressionColumnStatisticsCalculator,
    ValidationColumnStatisticsCalculator,
    GeneralColumnStatisticsCalculator,
    LLMCodeColumnStatisticsCalculator,
    LLMJudgedColumnStatisticsCalculator,
    LLMStructuredColumnStatisticsCalculator,
    LLMTextColumnStatisticsCalculator,
    SamplerColumnStatisticsCalculator,
    SeedDatasetColumnStatisticsCalculator,
]
# Column type -> calculator class used to compute that column's statistics.
# Column types not listed here fall back to GeneralColumnStatisticsCalculator
# in get_column_statistics_calculator.
DEFAULT_COLUMN_STATISTICS_CALCULATOR_MAP = {
    DataDesignerColumnType.EXPRESSION: ExpressionColumnStatisticsCalculator,
    DataDesignerColumnType.VALIDATION: ValidationColumnStatisticsCalculator,
    DataDesignerColumnType.LLM_CODE: LLMCodeColumnStatisticsCalculator,
    DataDesignerColumnType.LLM_JUDGE: LLMJudgedColumnStatisticsCalculator,
    DataDesignerColumnType.LLM_STRUCTURED: LLMStructuredColumnStatisticsCalculator,
    DataDesignerColumnType.LLM_TEXT: LLMTextColumnStatisticsCalculator,
    DataDesignerColumnType.SAMPLER: SamplerColumnStatisticsCalculator,
    DataDesignerColumnType.SEED_DATASET: SeedDatasetColumnStatisticsCalculator,
}
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def get_column_statistics_calculator(column_type: DataDesignerColumnType) -> Type[ColumnStatisticsCalculatorT]:
    """Return the statistics calculator class registered for ``column_type``.

    Column types without a dedicated calculator fall back to
    GeneralColumnStatisticsCalculator. The class itself is returned (hence the
    ``Type[...]`` annotation); callers instantiate it with ``column_config_with_df``
    and call ``calculate()``.
    """
    return DEFAULT_COLUMN_STATISTICS_CALCULATOR_MAP.get(column_type, GeneralColumnStatisticsCalculator)
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from functools import cached_property
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
from pydantic import Field, field_validator
|
|
10
|
+
|
|
11
|
+
from data_designer.config.analysis.column_profilers import ColumnProfilerConfigT
|
|
12
|
+
from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
|
|
13
|
+
from data_designer.config.base import ConfigBase
|
|
14
|
+
from data_designer.config.column_configs import SingleColumnConfig
|
|
15
|
+
from data_designer.config.column_types import (
|
|
16
|
+
COLUMN_TYPE_EMOJI_MAP,
|
|
17
|
+
ColumnConfigT,
|
|
18
|
+
)
|
|
19
|
+
from data_designer.engine.analysis.column_profilers.base import ColumnConfigWithDataFrame, ColumnProfiler
|
|
20
|
+
from data_designer.engine.analysis.column_statistics import get_column_statistics_calculator
|
|
21
|
+
from data_designer.engine.analysis.errors import DatasetProfilerConfigurationError
|
|
22
|
+
from data_designer.engine.dataset_builders.multi_column_configs import (
|
|
23
|
+
DatasetBuilderColumnConfigT,
|
|
24
|
+
MultiColumnConfig,
|
|
25
|
+
)
|
|
26
|
+
from data_designer.engine.registry.data_designer_registry import DataDesignerRegistry
|
|
27
|
+
from data_designer.engine.resources.resource_provider import ResourceProvider
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)


class DatasetProfilerConfig(ConfigBase):
    """Configuration for profiling a generated dataset.

    ``column_configs`` is flattened on validation: multi-column configs are expanded into
    their member columns, and columns marked ``drop`` are removed.
    """

    column_configs: Sequence[DatasetBuilderColumnConfigT] = Field(..., min_length=1)
    column_profiler_configs: Sequence[ColumnProfilerConfigT] | None = None

    @field_validator("column_configs")
    def flatten_and_validate_column_configs(cls, v: list[DatasetBuilderColumnConfigT]) -> list[ColumnConfigT]:
        """Flatten multi-column configs and drop columns flagged with ``drop``.

        Raises:
            DatasetProfilerConfigurationError: If no columns remain after dropping.
        """
        kept: list[ColumnConfigT] = []
        for entry in v:
            if isinstance(entry, SingleColumnConfig) and not entry.drop:
                kept.append(entry)
            elif isinstance(entry, MultiColumnConfig):
                kept.extend(member for member in entry.columns if not member.drop)
        if not kept:
            raise DatasetProfilerConfigurationError("All columns were dropped!")
        return kept
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class DataDesignerDatasetProfiler:
    """Profile a generated dataset: per-column statistics plus optional column profilers."""

    def __init__(self, config: DatasetProfilerConfig, resource_provider: ResourceProvider):
        """Store the configuration and validate profiler prerequisites.

        Raises:
            DatasetProfilerConfigurationError: If column profilers are configured but the
                model registry is missing, or a column references an unknown model alias.
        """
        self.config = config
        self.resource_provider = resource_provider
        self._validate_column_profiler_configs()

    @cached_property
    def column_names_from_configs(self) -> list[str]:
        """Names of the (already flattened, non-dropped) configured columns."""
        return [c.name for c in self.config.column_configs]

    @cached_property
    def registry(self) -> DataDesignerRegistry:
        return DataDesignerRegistry()

    def profile_dataset(
        self,
        target_num_records: int,
        dataset: pd.DataFrame,
    ) -> DatasetProfilerResults:
        """Compute column statistics and run every configured column profiler.

        Args:
            target_num_records: The number of records that was requested.
            dataset: The generated dataset to profile.

        Returns:
            DatasetProfilerResults. ``side_effect_column_names`` lists dataset columns that
            are not configured columns, in dataset column order. ``column_profiles`` is
            None when no profiler produced results.

        Raises:
            DatasetProfilerConfigurationError: If a configured column is missing from the dataset.
        """
        logger.info("📐 Measuring dataset column statistics:")

        self._validate_schema_consistency(list(dataset.columns))

        column_statistics = []
        for c in self.config.column_configs:
            logger.info(f" |-- {COLUMN_TYPE_EMOJI_MAP[c.column_type]} column: '{c.name}'")
            calculator_cls = get_column_statistics_calculator(c.column_type)
            calculator = calculator_cls(column_config_with_df=ColumnConfigWithDataFrame(column_config=c, df=dataset))
            column_statistics.append(calculator.calculate())

        column_profiles = []
        for profiler_config in self.config.column_profiler_configs or []:
            profiler = self._create_column_profiler(profiler_config)
            metadata = profiler.metadata()
            applicable_column_types = metadata.applicable_column_types
            num_profiled = 0
            for c in self.config.column_configs:
                if c.column_type in applicable_column_types:
                    params = ColumnConfigWithDataFrame(column_config=c, df=dataset)
                    column_profiles.append(profiler.profile(params))
                    num_profiled += 1
            # Track matches per profiler: results from an earlier profiler must not hide
            # that this one found no applicable columns.
            if num_profiled == 0:
                logger.warning(
                    f"⚠️ No applicable column types found for the '{metadata.name}' profiler. "
                    f"This profiler is applicable to the following column types: {applicable_column_types}"
                )

        configured_names = set(self.column_names_from_configs)
        # Keep dataset column order (deduplicated) so the result is deterministic;
        # a plain set difference would yield an arbitrary order.
        side_effect_column_names = list(dict.fromkeys(name for name in dataset.columns if name not in configured_names))

        return DatasetProfilerResults(
            num_records=len(dataset),
            target_num_records=target_num_records,
            side_effect_column_names=side_effect_column_names,
            column_statistics=column_statistics,
            column_profiles=column_profiles if column_profiles else None,
        )

    def _create_column_profiler(self, profiler_config: ColumnProfilerConfigT) -> ColumnProfiler:
        """Instantiate the profiler class registered for the type of ``profiler_config``."""
        return self.registry.column_profilers.get_for_config_type(type(profiler_config))(
            config=profiler_config, resource_provider=self.resource_provider
        )

    def _validate_column_profiler_configs(self) -> None:
        """Ensure a model registry is available whenever column profilers are configured."""
        if self.config.column_profiler_configs:
            if self.resource_provider.model_registry is None:
                raise DatasetProfilerConfigurationError("Model registry is required for column profiler configs")
            self._validate_model_configs()

    def _validate_model_configs(self) -> None:
        """Ensure every column that declares a ``model_alias`` refers to a known model config."""
        # NOTE(review): only column configs are checked here; the profiler configs'
        # own model aliases are not validated up front. Confirm whether that is intended.
        aliases = set(self.resource_provider.model_registry.model_configs.keys())
        for column_config in self.config.column_configs:
            if hasattr(column_config, "model_alias") and column_config.model_alias not in aliases:
                raise DatasetProfilerConfigurationError(
                    f"Model config '{column_config.model_alias}' not found in model configs"
                )

    def _validate_schema_consistency(self, dataset_column_names: list[str]) -> None:
        """Raise for the first configured column that is absent from the dataset."""
        available = set(dataset_column_names)
        for column_name in self.column_names_from_configs:
            if column_name not in available:
                raise DatasetProfilerConfigurationError(f"Column '{column_name}' not found in dataset")
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from __future__ import annotations
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
from numbers import Number
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import pyarrow as pa
|
|
13
|
+
import tiktoken
|
|
14
|
+
|
|
15
|
+
from data_designer.config.analysis.column_statistics import (
|
|
16
|
+
CategoricalDistribution,
|
|
17
|
+
ColumnDistributionType,
|
|
18
|
+
MissingValue,
|
|
19
|
+
NumericalDistribution,
|
|
20
|
+
)
|
|
21
|
+
from data_designer.config.column_configs import (
|
|
22
|
+
LLMTextColumnConfig,
|
|
23
|
+
SingleColumnConfig,
|
|
24
|
+
ValidationColumnConfig,
|
|
25
|
+
)
|
|
26
|
+
from data_designer.engine.column_generators.generators.llm_generators import (
|
|
27
|
+
PromptType,
|
|
28
|
+
RecordBasedPromptRenderer,
|
|
29
|
+
create_response_recipe,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger(__name__)

# Seed passed to DataFrame.sample so the prompt-token sample is reproducible.
RANDOM_SEED = 42
# Maximum number of records whose prompts are rendered for prompt-token statistics.
MAX_PROMPT_SAMPLE_SIZE = 1000
# tiktoken encoding shared by every token count in this module (loaded once, at import time).
TOKENIZER = tiktoken.get_encoding("cl100k_base")
# Prefix for the warnings logged when a column statistic cannot be computed.
WARNING_PREFIX = "⚠️ Error during column profile calculation: "
# NOTE(review): not referenced in this part of the file; presumably a whitespace-density
# heuristic for classifying text fields — confirm against its users.
TEXT_FIELD_AVG_SPACE_COUNT_THRESHOLD = 0.1
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def calculate_column_distribution(
    column_config: SingleColumnConfig, df: pd.DataFrame, distribution_type: ColumnDistributionType
) -> dict[str, CategoricalDistribution | NumericalDistribution | MissingValue | None]:
    """Compute the value distribution of a single column.

    Args:
        column_config: Config whose ``name`` selects the column in ``df``.
        df: Dataset containing the column.
        distribution_type: Requested distribution kind; it is coerced with
            ``ColumnDistributionType(...)``.

    Returns:
        A dict with ``"distribution_type"`` and ``"distribution"`` keys:

        - For CATEGORICAL or NUMERICAL, the distribution is built from the column.
        - For any other type, nothing is computed and ``"distribution"`` is None.
          Previously this case fell through and returned None instead of a dict.
        - If building the distribution raises, a warning is logged and the result is
          ``UNKNOWN`` / ``MissingValue.CALCULATION_FAILED``.

    Raises:
        Whatever ``ColumnDistributionType(distribution_type)`` raises for an invalid
        value. The coercion happens outside the error-handling block.
    """
    distribution_type = ColumnDistributionType(distribution_type)
    distribution_builders = {
        ColumnDistributionType.CATEGORICAL: CategoricalDistribution.from_series,
        ColumnDistributionType.NUMERICAL: NumericalDistribution.from_series,
    }
    builder = distribution_builders.get(distribution_type)
    if builder is None:
        # No distribution is defined for this type. Still return the documented dict
        # shape so callers can unpack it.
        return {"distribution_type": distribution_type, "distribution": None}

    try:
        distribution = builder(df[column_config.name])
    except Exception as e:
        logger.warning(f"{WARNING_PREFIX} failed to calculate column distribution for '{column_config.name}': {e}")
        return {
            "distribution_type": ColumnDistributionType.UNKNOWN,
            "distribution": MissingValue.CALCULATION_FAILED,
        }
    return {"distribution_type": distribution_type, "distribution": distribution}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def calculate_general_column_info(column_config: SingleColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
    """Compute dtype, record-count, null-count and distinct-count info for one column.

    Args:
        column_config: Config whose ``name`` selects the column in ``df``.
        df: Dataset containing the column. The column's ``dtype.pyarrow_dtype`` is
            read, so any column without that attribute fails the calculation.

    Returns:
        A dict with ``pyarrow_dtype``, ``simple_dtype``, ``num_records``, ``num_null``
        and ``num_unique``. If anything raises, a warning is logged and every value is
        ``MissingValue.CALCULATION_FAILED``.
    """
    try:
        series = df[column_config.name]
        pyarrow_dtype = series.dtype.pyarrow_dtype
        # Count nulls on the raw column. ensure_hashable passes through only None and
        # numbers; any other missing marker (e.g. pandas' pd.NA) would be turned into
        # the string "<NA>" and no longer register as null.
        is_null = series.isnull()
        # Lists/dicts/arrays are unhashable, so distinct values are counted on their
        # hashable representations. Nulls are excluded, as nunique() does by default.
        hashable_values = series[~is_null].apply(ensure_hashable)
        return {
            "pyarrow_dtype": str(pyarrow_dtype),
            "simple_dtype": convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype),
            "num_records": len(series),
            "num_null": is_null.sum(),
            "num_unique": hashable_values.nunique(),
        }
    except Exception as e:
        logger.warning(f"{WARNING_PREFIX} failed to calculate general column info for '{column_config.name}': {e}")
        return {
            "pyarrow_dtype": MissingValue.CALCULATION_FAILED,
            "simple_dtype": MissingValue.CALCULATION_FAILED,
            "num_records": MissingValue.CALCULATION_FAILED,
            "num_null": MissingValue.CALCULATION_FAILED,
            "num_unique": MissingValue.CALCULATION_FAILED,
        }
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def calculate_prompt_token_stats(
    column_config: LLMTextColumnConfig, df: pd.DataFrame
) -> dict[str, float | MissingValue]:
    """Estimate prompt-token statistics for an LLM text column.

    Up to ``MAX_PROMPT_SAMPLE_SIZE`` records are sampled, seeded with ``RANDOM_SEED``.
    For each sampled record, the system and user prompt templates are rendered, joined
    with a blank line, and tokenized with ``TOKENIZER``.

    Returns:
        ``prompt_tokens_mean``, ``prompt_tokens_median`` and ``prompt_tokens_stddev``.

        - The standard deviation is the sample (ddof=1) estimate. This matches the
          pandas ``Series.std`` used by ``calculate_completion_token_stats``.
        - With no records, every statistic is NaN.
        - On any error, a warning is logged and every value is
          ``MissingValue.CALCULATION_FAILED``.
    """
    try:
        num_samples = min(MAX_PROMPT_SAMPLE_SIZE, len(df))
        renderer = RecordBasedPromptRenderer(response_recipe=create_response_recipe(column_config))
        num_tokens = []
        for record in df.sample(num_samples, random_state=RANDOM_SEED).to_dict(orient="records"):
            system_prompt = renderer.render(
                prompt_template=column_config.system_prompt, record=record, prompt_type=PromptType.SYSTEM_PROMPT
            )
            prompt = renderer.render(
                prompt_template=column_config.prompt, record=record, prompt_type=PromptType.USER_PROMPT
            )
            # NOTE(review): whether render() can return None (e.g. with no system prompt)
            # is not visible here. A None part is skipped rather than failing the whole
            # calculation on `None + str`. For two strings this is exactly
            # `system_prompt + "\n\n" + prompt`.
            concatenated_prompt = "\n\n".join(str(part) for part in (system_prompt, prompt) if part is not None)
            num_tokens.append(len(TOKENIZER.encode(concatenated_prompt, disallowed_special=())))
    except Exception as e:
        logger.warning(
            f"{WARNING_PREFIX} failed to calculate prompt token stats for column {column_config.name!r}: {e}"
        )
        return {
            "prompt_tokens_mean": MissingValue.CALCULATION_FAILED,
            "prompt_tokens_median": MissingValue.CALCULATION_FAILED,
            "prompt_tokens_stddev": MissingValue.CALCULATION_FAILED,
        }
    # pandas reductions return NaN for an empty sample without numpy's RuntimeWarnings.
    token_counts = pd.Series(num_tokens, dtype="float64")
    return {
        "prompt_tokens_mean": token_counts.mean(),
        "prompt_tokens_median": token_counts.median(),
        "prompt_tokens_stddev": token_counts.std(),
    }
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def calculate_completion_token_stats(
    column_config: LLMTextColumnConfig, df: pd.DataFrame
) -> dict[str, float | MissingValue]:
    """Compute completion-token statistics over a generated LLM text column.

    Each non-missing value is converted with ``str()`` and tokenized with ``TOKENIZER``.

    Returns:
        ``completion_tokens_mean``, ``completion_tokens_median`` and
        ``completion_tokens_stddev`` (pandas ``Series.std``, ddof=1). On any error, a
        warning is logged and every value is ``MissingValue.CALCULATION_FAILED``.
    """
    try:
        # Skip missing completions. Otherwise str() turns them into placeholder text
        # such as "None" (or "<NA>" for pd.NA), which would be counted as tokens.
        completions = df[column_config.name].dropna()
        tokens_per_record = completions.apply(
            lambda value: len(TOKENIZER.encode(str(value), disallowed_special=()))
        )
        return {
            "completion_tokens_mean": tokens_per_record.mean(),
            "completion_tokens_median": tokens_per_record.median(),
            "completion_tokens_stddev": tokens_per_record.std(),
        }
    except Exception as e:
        logger.warning(
            f"{WARNING_PREFIX} failed to calculate completion token stats for column {column_config.name}: {e}"
        )
        return {
            "completion_tokens_mean": MissingValue.CALCULATION_FAILED,
            "completion_tokens_median": MissingValue.CALCULATION_FAILED,
            "completion_tokens_stddev": MissingValue.CALCULATION_FAILED,
        }
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def calculate_token_stats(column_config: LLMTextColumnConfig, df: pd.DataFrame) -> dict[str, float | MissingValue]:
    """Combine prompt and completion token statistics for an LLM text column.

    The prompt statistics are computed first, then the completion statistics. Both
    sets of keys are merged into one dict.
    """
    token_stats = dict(calculate_prompt_token_stats(column_config, df))
    token_stats.update(calculate_completion_token_stats(column_config, df))
    return token_stats
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def calculate_validation_column_info(column_config: ValidationColumnConfig, df: pd.DataFrame) -> dict[str, Any]:
    """Count the records whose validation result is flagged as valid.

    Each cell of the validation column is indexed with ``"is_valid"``. Presumably the
    cell is a dict-like validation result — confirm against the validation generator.
    The flag is coerced with ``ensure_boolean`` before summing.

    Returns:
        ``{"num_valid_records": count}``. If anything raises, a warning is logged and
        the count is ``MissingValue.CALCULATION_FAILED``.
    """

    def _is_valid(result: Any) -> bool:
        return ensure_boolean(result["is_valid"])

    try:
        num_valid_records = df[column_config.name].apply(_is_valid).sum()
    except Exception as e:
        logger.warning(
            f"{WARNING_PREFIX} failed to calculate code validation column info for column {column_config.name}: {e}"
        )
        return {"num_valid_records": MissingValue.CALCULATION_FAILED}
    return {"num_valid_records": num_valid_records}
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _pa_type_matches(pyarrow_dtype: pa.DataType, *predicate_names: str) -> bool:
    """Return True if any of the named ``pyarrow.types`` predicates accepts the dtype.

    Predicates missing from the installed pyarrow version (e.g. ``is_string_view`` or
    ``is_list_view`` on older releases) count as non-matching.
    """
    for predicate_name in predicate_names:
        predicate = getattr(pa.types, predicate_name, None)
        if predicate is not None and predicate(pyarrow_dtype):
            return True
    return False


def convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype: pa.DataType) -> str:
    """Map a pyarrow data type to a short, human-readable dtype name.

    Types are classified with ``pyarrow.types`` predicates rather than substring
    matching on ``str(dtype)``. Substring matching misclassified types whose names
    merely contain "int"/"string", such as ``large_list<item: string>``,
    ``dictionary<values=string, indices=int32>`` and ``month_day_nano_interval``.

    Returns:
        - ``"list[<inner>]"`` for list-like types, recursing into the element type.
        - ``"dict"`` for struct and map types.
        - For dictionary-encoded types, the simple name of the value type.
        - ``"int"``, ``"float"``, ``"string"``, ``"timestamp"``, ``"time"`` or
          ``"date"`` for the corresponding families.
        - Otherwise, ``str(pyarrow_dtype)`` unchanged.
    """
    if _pa_type_matches(
        pyarrow_dtype, "is_list", "is_large_list", "is_fixed_size_list", "is_list_view", "is_large_list_view"
    ):
        return f"list[{convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)}]"
    if _pa_type_matches(pyarrow_dtype, "is_struct", "is_map"):
        return "dict"
    if _pa_type_matches(pyarrow_dtype, "is_dictionary"):
        # Dictionary encoding is a storage detail; describe the decoded values.
        return convert_pyarrow_dtype_to_simple_dtype(pyarrow_dtype.value_type)
    if _pa_type_matches(pyarrow_dtype, "is_integer"):
        return "int"
    if _pa_type_matches(pyarrow_dtype, "is_floating"):
        return "float"
    if _pa_type_matches(pyarrow_dtype, "is_string", "is_large_string", "is_string_view"):
        return "string"
    if _pa_type_matches(pyarrow_dtype, "is_timestamp"):
        return "timestamp"
    if _pa_type_matches(pyarrow_dtype, "is_time"):
        return "time"
    if _pa_type_matches(pyarrow_dtype, "is_date"):
        return "date"
    return str(pyarrow_dtype)
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def _sorted_canonically(items: list[Any]) -> list[Any]:
    """Sort ``items`` naturally, falling back to a type-aware key if types are mixed.

    The natural ordering is tried first, so every input that used to sort keeps exactly
    the same output. The fallback key ``(type name, repr)`` is always comparable, so
    mixed inputs such as ``[1, "a", None]`` still get a deterministic order instead of
    raising TypeError.
    """
    try:
        return sorted(items)
    except TypeError:
        return sorted(items, key=lambda item: (type(item).__name__, repr(item)))


def ensure_hashable(x: Any) -> Any:
    """Make a best-effort hashable representation of ``x``.

    - Numbers (including bools) and None are returned unchanged.
    - Dicts become the string form of their ``(str(key), value)`` pairs, sorted.
    - Lists, tuples, sets and numpy arrays become the string form of their elements,
      sorted.
    - In both cases nested values are converted recursively.
    - Anything else is returned as ``str(x)``.

    Sorting makes the result independent of element order (e.g. ``[1, 2]`` and
    ``[2, 1]`` map to the same string). Mixed element types are ordered with a
    type-aware fallback instead of raising.
    """
    if isinstance(x, (Number, bool)) or x is None:
        return x

    if isinstance(x, dict):
        # Key/value pairs with stringified keys, in canonical order.
        return str(_sorted_canonically([(str(k), ensure_hashable(v)) for k, v in x.items()]))

    if isinstance(x, (list, tuple, set, np.ndarray)):
        # Recursively make all elements hashable, then order them canonically.
        return str(_sorted_canonically([ensure_hashable(e) for e in x]))

    return str(x)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def ensure_boolean(v: bool | str | int | None) -> bool:
    """Coerce a loosely typed truth value to a Python ``bool``.

    Accepted inputs:

    - ``None`` gives ``False``.
    - Python/NumPy booleans give their own value.
    - Python/NumPy ints and floats equal to 0 or 1 give ``False``/``True``.
    - The strings ``"true"``/``"false"``, in any case, give ``True``/``False``.

    Raises:
        ValueError: For any other value.
    """
    if v is None:
        return False
    if isinstance(v, (bool, np.bool_)):
        return bool(v)
    is_numeric = isinstance(v, (int, float, np.integer, np.floating))
    if is_numeric and v in (0, 1):
        return bool(v)
    if isinstance(v, (str, np.str_)):
        parsed = {"true": True, "false": False}.get(v.lower())
        if parsed is not None:
            return parsed
    raise ValueError(f"Invalid boolean value: {v}")
|