data_designer_config-0.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/config/__init__.py +149 -0
- data_designer/config/_version.py +34 -0
- data_designer/config/analysis/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +159 -0
- data_designer/config/analysis/column_statistics.py +421 -0
- data_designer/config/analysis/dataset_profiler.py +84 -0
- data_designer/config/analysis/utils/errors.py +10 -0
- data_designer/config/analysis/utils/reporting.py +192 -0
- data_designer/config/base.py +69 -0
- data_designer/config/column_configs.py +476 -0
- data_designer/config/column_types.py +141 -0
- data_designer/config/config_builder.py +595 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +13 -0
- data_designer/config/dataset_metadata.py +18 -0
- data_designer/config/default_model_settings.py +129 -0
- data_designer/config/errors.py +24 -0
- data_designer/config/interface.py +55 -0
- data_designer/config/models.py +486 -0
- data_designer/config/preview_results.py +41 -0
- data_designer/config/processors.py +148 -0
- data_designer/config/run_config.py +56 -0
- data_designer/config/sampler_constraints.py +52 -0
- data_designer/config/sampler_params.py +639 -0
- data_designer/config/seed.py +116 -0
- data_designer/config/seed_source.py +84 -0
- data_designer/config/seed_source_types.py +19 -0
- data_designer/config/testing/__init__.py +6 -0
- data_designer/config/testing/fixtures.py +308 -0
- data_designer/config/utils/code_lang.py +93 -0
- data_designer/config/utils/constants.py +365 -0
- data_designer/config/utils/errors.py +21 -0
- data_designer/config/utils/info.py +94 -0
- data_designer/config/utils/io_helpers.py +258 -0
- data_designer/config/utils/misc.py +78 -0
- data_designer/config/utils/numerical_helpers.py +30 -0
- data_designer/config/utils/type_helpers.py +106 -0
- data_designer/config/utils/visualization.py +482 -0
- data_designer/config/validator_params.py +94 -0
- data_designer/errors.py +7 -0
- data_designer/lazy_heavy_imports.py +56 -0
- data_designer/logging.py +180 -0
- data_designer/plugin_manager.py +78 -0
- data_designer/plugins/__init__.py +8 -0
- data_designer/plugins/errors.py +15 -0
- data_designer/plugins/plugin.py +141 -0
- data_designer/plugins/registry.py +88 -0
- data_designer_config-0.4.0.dist-info/METADATA +75 -0
- data_designer_config-0.4.0.dist-info/RECORD +50 -0
- data_designer_config-0.4.0.dist-info/WHEEL +4 -0

data_designer/config/seed.py
@@ -0,0 +1,116 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from enum import Enum
+
+from pydantic import Field, model_validator
+from typing_extensions import Self
+
+from data_designer.config.base import ConfigBase
+from data_designer.config.seed_source_types import SeedSourceT
+
+
+class SamplingStrategy(str, Enum):
+    ORDERED = "ordered"
+    SHUFFLE = "shuffle"
+
+
+class IndexRange(ConfigBase):
+    start: int = Field(ge=0, description="The start index of the index range (inclusive)")
+    end: int = Field(ge=0, description="The end index of the index range (inclusive)")
+
+    @model_validator(mode="after")
+    def _validate_index_range(self) -> Self:
+        if self.start > self.end:
+            raise ValueError("'start' index must be less than or equal to 'end' index")
+        return self
+
+    @property
+    def size(self) -> int:
+        return self.end - self.start + 1
+
+
+class PartitionBlock(ConfigBase):
+    index: int = Field(default=0, ge=0, description="The index of the partition to sample from")
+    num_partitions: int = Field(default=1, ge=1, description="The total number of partitions in the dataset")
+
+    @model_validator(mode="after")
+    def _validate_partition_block(self) -> Self:
+        if self.index >= self.num_partitions:
+            raise ValueError("'index' must be less than 'num_partitions'")
+        return self
+
+    def to_index_range(self, dataset_size: int) -> IndexRange:
+        partition_size = dataset_size // self.num_partitions
+        start = self.index * partition_size
+
+        # For the last partition, extend to the end of the dataset to include remainder rows
+        if self.index == self.num_partitions - 1:
+            end = dataset_size - 1
+        else:
+            end = ((self.index + 1) * partition_size) - 1
+        return IndexRange(start=start, end=end)
+
+
+class SeedConfig(ConfigBase):
+    """Configuration for sampling data from a seed dataset.
+
+    Args:
+        source: A SeedSource defining where the seed data exists.
+        sampling_strategy: Strategy for how to sample rows from the dataset.
+            - ORDERED: Read rows sequentially in their original order.
+            - SHUFFLE: Randomly shuffle rows before sampling. When used with
+              selection_strategy, shuffling occurs within the selected range/partition.
+        selection_strategy: Optional strategy to select a subset of the dataset.
+            - IndexRange: Select a specific range of indices (e.g., rows 100-200).
+            - PartitionBlock: Select a partition by splitting the dataset into N equal parts.
+              Partition indices are zero-based (index=0 is the first partition, index=1 is
+              the second, etc.).
+
+    Examples:
+        Read rows sequentially from start to end:
+            SeedConfig(
+                source=LocalFileSeedSource(path="my_data.parquet"),
+                sampling_strategy=SamplingStrategy.ORDERED
+            )
+
+        Read rows in random order:
+            SeedConfig(
+                source=LocalFileSeedSource(path="my_data.parquet"),
+                sampling_strategy=SamplingStrategy.SHUFFLE
+            )
+
+        Read a specific index range (rows 100-199):
+            SeedConfig(
+                source=LocalFileSeedSource(path="my_data.parquet"),
+                sampling_strategy=SamplingStrategy.ORDERED,
+                selection_strategy=IndexRange(start=100, end=199)
+            )
+
+        Read random rows from a specific index range (shuffles within rows 100-199):
+            SeedConfig(
+                source=LocalFileSeedSource(path="my_data.parquet"),
+                sampling_strategy=SamplingStrategy.SHUFFLE,
+                selection_strategy=IndexRange(start=100, end=199)
+            )
+
+        Read from partition 2 (3rd partition, zero-based) of 5 partitions (20% of dataset):
+            SeedConfig(
+                source=LocalFileSeedSource(path="my_data.parquet"),
+                sampling_strategy=SamplingStrategy.ORDERED,
+                selection_strategy=PartitionBlock(index=2, num_partitions=5)
+            )
+
+        Read shuffled rows from partition 0 of 10 partitions (shuffles within the partition):
+            SeedConfig(
+                source=LocalFileSeedSource(path="my_data.parquet"),
+                sampling_strategy=SamplingStrategy.SHUFFLE,
+                selection_strategy=PartitionBlock(index=0, num_partitions=10)
+            )
+    """
+
+    source: SeedSourceT
+    sampling_strategy: SamplingStrategy = SamplingStrategy.ORDERED
+    selection_strategy: IndexRange | PartitionBlock | None = None
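
The partition arithmetic in `PartitionBlock.to_index_range` is easiest to see with concrete numbers. A minimal sketch, not part of the package, assuming these classes import from `data_designer.config.seed` as the hunk above suggests; the values are worked out by hand from the code:

    from data_designer.config.seed import PartitionBlock

    # 103 rows across 5 partitions -> 103 // 5 = 20 rows per partition
    rng = PartitionBlock(index=2, num_partitions=5).to_index_range(dataset_size=103)
    assert (rng.start, rng.end, rng.size) == (40, 59, 20)

    # The last partition absorbs the remainder rows: 103 - 4 * 20 = 23
    last = PartitionBlock(index=4, num_partitions=5).to_index_range(dataset_size=103)
    assert (last.start, last.end, last.size) == (80, 102, 23)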

data_designer/config/seed_source.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from abc import ABC
+from typing import TYPE_CHECKING, Literal
+
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic.json_schema import SkipJsonSchema
+from typing_extensions import Self
+
+from data_designer.config.utils.io_helpers import (
+    VALID_DATASET_FILE_EXTENSIONS,
+    validate_dataset_file_path,
+    validate_path_contains_files_of_type,
+)
+from data_designer.lazy_heavy_imports import pd
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+class SeedSource(BaseModel, ABC):
+    """Base class for seed dataset configurations.
+
+    All subclasses must define a `seed_type` field with a Literal value,
+    which serves as the discriminator for the discriminated union.
+    """
+
+    seed_type: str
+
+
+class LocalFileSeedSource(SeedSource):
+    seed_type: Literal["local"] = "local"
+
+    path: str
+
+    @field_validator("path", mode="after")
+    def validate_path(cls, v: str) -> str:
+        valid_wild_card_versions = {f"*{ext}" for ext in VALID_DATASET_FILE_EXTENSIONS}
+        if any(v.endswith(wildcard) for wildcard in valid_wild_card_versions):
+            parts = v.split("*.")
+            file_path = parts[0]
+            file_extension = parts[-1]
+            validate_path_contains_files_of_type(file_path, file_extension)
+        else:
+            validate_dataset_file_path(v)
+        return v
+
+    @classmethod
+    def from_dataframe(cls, df: pd.DataFrame, path: str) -> Self:
+        df.to_parquet(path, index=False)
+        return cls(path=path)
+
+
+class HuggingFaceSeedSource(SeedSource):
+    seed_type: Literal["hf"] = "hf"
+
+    path: str = Field(
+        ...,
+        description=(
+            "Path to the seed data in HuggingFace. Wildcards are allowed. Examples include "
+            "'datasets/my-username/my-dataset/data/000_00000.parquet', 'datasets/my-username/my-dataset/data/*.parquet', "
+            "and 'datasets/my-username/my-dataset/**/*.parquet'"
+        ),
+    )
+    token: str | None = None
+    endpoint: str = "https://huggingface.co"
+
+
+class DataFrameSeedSource(SeedSource):
+    seed_type: Literal["df"] = "df"
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    df: SkipJsonSchema[pd.DataFrame] = Field(
+        ...,
+        exclude=True,
+        description=(
+            "DataFrame to use directly as the seed dataset. NOTE: if you need to write a Data Designer config, "
+            "you must use `LocalFileSeedSource` instead, since DataFrame objects are not serializable."
+        ),
+    )
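
`LocalFileSeedSource.from_dataframe` gives a serializable alternative to `DataFrameSeedSource`: it writes the frame to disk first, so the path validator sees an existing file. A hedged usage sketch ("seed.parquet" is an arbitrary example path, not a package default, and this assumes the validator accepts an existing .parquet file):

    import pandas as pd

    from data_designer.config.seed_source import LocalFileSeedSource

    df = pd.DataFrame({"topic": ["a", "b"]})
    # Writes the parquet file, then wraps its path in a serializable source
    source = LocalFileSeedSource.from_dataframe(df, path="seed.parquet")
    assert source.seed_type == "local" and source.path == "seed.parquet"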

data_designer/config/seed_source_types.py
@@ -0,0 +1,19 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from typing import Annotated
+
+from pydantic import Field
+from typing_extensions import TypeAlias
+
+from data_designer.config.seed_source import DataFrameSeedSource, HuggingFaceSeedSource, LocalFileSeedSource
+from data_designer.plugin_manager import PluginManager
+
+plugin_manager = PluginManager()
+
+_SeedSourceT: TypeAlias = LocalFileSeedSource | HuggingFaceSeedSource | DataFrameSeedSource
+_SeedSourceT = plugin_manager.inject_into_seed_source_type_union(_SeedSourceT)
+
+SeedSourceT = Annotated[_SeedSourceT, Field(discriminator="seed_type")]
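
Because `SeedSourceT` is an `Annotated` union with `discriminator="seed_type"`, pydantic dispatches on that key when validating raw data. A minimal sketch, assuming pydantic v2's `TypeAdapter`:

    from pydantic import TypeAdapter

    from data_designer.config.seed_source import HuggingFaceSeedSource
    from data_designer.config.seed_source_types import SeedSourceT

    adapter = TypeAdapter(SeedSourceT)
    # The "seed_type" key selects which SeedSource subclass gets validated
    source = adapter.validate_python({"seed_type": "hf", "path": "datasets/my-username/my-dataset/data/*.parquet"})
    assert isinstance(source, HuggingFaceSeedSource)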

data_designer/config/testing/fixtures.py
@@ -0,0 +1,308 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""Pytest fixtures for config testing."""
+
+from __future__ import annotations
+
+import os
+import tarfile
+import tempfile
+import textwrap
+from typing import TYPE_CHECKING
+
+import pytest
+import yaml
+
+from data_designer.config.analysis.column_statistics import GeneralColumnStatistics
+from data_designer.config.analysis.dataset_profiler import DatasetProfilerResults
+from data_designer.config.column_configs import SamplerColumnConfig
+from data_designer.config.config_builder import DataDesignerConfigBuilder
+from data_designer.config.data_designer_config import DataDesignerConfig
+from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig, ModelProvider
+from data_designer.lazy_heavy_imports import pd
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+@pytest.fixture
+def stub_data_designer_config_str() -> str:
+    return """
+model_configs:
+  - alias: my_own_code_model
+    model: openai/meta/llama-3.3-70b-instruct
+    inference_parameters:
+      temperature:
+        distribution_type: uniform
+        params:
+          low: 0.5
+          high: 0.9
+      top_p:
+        distribution_type: manual
+        params:
+          values: [0.1, 0.2, 0.33]
+          weights: [0.3, 0.2, 0.50]
+
+seed_config:
+  source:
+    seed_type: hf
+    path: datasets/test-repo/testing/data.csv
+  sampling_strategy: shuffle
+
+columns:
+  - name: code_id
+    sampler_type: uuid
+    column_type: sampler
+    params:
+      prefix: code_
+      short_form: true
+      uppercase: true
+  - name: age
+    sampler_type: uniform
+    column_type: sampler
+    params:
+      low: 35
+      high: 88
+  - name: domain
+    sampler_type: category
+    column_type: sampler
+    params:
+      values: [Healthcare, Finance, Education, Government]
+  - name: topic
+    sampler_type: category
+    column_type: sampler
+    params:
+      values: [Web Development, Data Science, Machine Learning, Cloud Computing]
+  - name: text
+    column_type: llm-text
+    prompt: Write a description of python code in topic {topic} and domain {domain}
+    model_alias: my_own_code_model
+  - name: code
+    column_type: llm-code
+    prompt: Write Python code that will be paired with the following prompt {text}
+    model_alias: my_own_code_model
+    code_lang: python
+  - name: code_validation_result
+    column_type: validation
+    target_columns:
+      - code
+    validator_type: code
+    validator_params:
+      code_lang: python
+  - name: code_judge_result
+    model_alias: my_own_code_model
+    column_type: llm-judge
+    prompt: You are an expert in Python programming and make appropriate judgement on the quality of the code.
+    scores:
+      - name: Pythonic
+        description: Pythonic Code and Best Practices (Does the code follow Python conventions and best practices?)
+        options:
+          "4": The code exemplifies Pythonic principles, making excellent use of Python-specific constructs, standard library modules and programming idioms; follows all relevant PEPs.
+          "3": The code closely follows Python conventions and adheres to many best practices; good use of Python-specific constructs, standard library modules and programming idioms.
+          "2": The code generally follows Python conventions but has room for better alignment with Pythonic practices.
+          "1": The code loosely follows Python conventions, with several deviations from best practices.
+          "0": The code does not follow Python conventions or best practices, using non-Pythonic approaches.
+      - name: Readability
+        description: Readability and Maintainability (Is the Python code easy to understand and maintain?)
+        options:
+          "4": The code is excellently formatted, follows PEP 8 guidelines, is elegantly concise and clear, uses meaningful variable names, ensuring high readability and ease of maintenance; organizes complex logic well. Docstrings are given in a Google Docstring format.
+          "3": The code is well-formatted in the sense of code-as-documentation, making it relatively easy to understand and maintain; uses descriptive names and organizes logic clearly.
+          "2": The code is somewhat readable with basic formatting and some comments, but improvements are needed; needs better use of descriptive names and organization.
+          "1": The code has minimal formatting, making it hard to understand; lacks meaningful names and organization.
+          "0": The code is unreadable, with no attempt at formatting or description.
+
+constraints:
+  - target_column: age
+    operator: "lt"
+    rhs: 65
+"""
+
+
+@pytest.fixture
+def stub_data_designer_builder_config_str(stub_data_designer_config_str: str) -> str:
+    return f"""
+data_designer:
+{textwrap.indent(stub_data_designer_config_str, prefix=" ")}
+"""
+
+
+@pytest.fixture
+def stub_data_designer_config(stub_data_designer_config_str: str) -> DataDesignerConfig:
+    json_config = yaml.safe_load(stub_data_designer_config_str)
+    return DataDesignerConfig.model_validate(json_config)
+
+
+@pytest.fixture
+def stub_model_configs() -> list[ModelConfig]:
+    return [
+        ModelConfig(
+            alias="stub-model",
+            model="stub-model",
+            inference_parameters=ChatCompletionInferenceParams(
+                temperature=0.9,
+                top_p=0.9,
+                max_tokens=2048,
+            ),
+        )
+    ]
+
+
+@pytest.fixture
+def stub_model_providers() -> list[ModelProvider]:
+    return [
+        ModelProvider(
+            name="provider-1",
+            endpoint="https://api.provider-1.com/v1",
+            api_key="PROVIDER_1_API_KEY",
+        )
+    ]
+
+
+@pytest.fixture
+def stub_empty_builder(stub_model_configs: list[ModelConfig]) -> DataDesignerConfigBuilder:
+    """Test builder with model configs."""
+    return DataDesignerConfigBuilder(model_configs=stub_model_configs)
+
+
+@pytest.fixture
+def stub_complete_builder(stub_data_designer_builder_config_str: str) -> DataDesignerConfigBuilder:
+    return DataDesignerConfigBuilder.from_config(config=stub_data_designer_builder_config_str)
+
+
+@pytest.fixture
+def stub_dataframe() -> pd.DataFrame:
+    return pd.DataFrame(
+        {
+            "name": ["John", "Jane", "Jim", "Jill", "Mike", "Mary", "Mark", "Martha", "Alex", "Alice", "Bob", "Bella"],
+            "age": [25, 30, 35, 40, 45, 50, 55, 60, 22, 28, 65, 38],
+            "city": [
+                "New York",
+                "Los Angeles",
+                "Chicago",
+                "Houston",
+                "Miami",
+                "Seattle",
+                "San Francisco",
+                "Boston",
+                "Denver",
+                "Austin",
+                "Portland",
+                "Atlanta",
+            ],
+            "state": ["NY", "CA", "IL", "TX", "FL", "WA", "CA", "MA", "CO", "TX", "OR", "GA"],
+            "zip": [
+                "10001",
+                "90001",
+                "60601",
+                "77001",
+                "33101",
+                "98101",
+                "94101",
+                "02101",
+                "80201",
+                "73301",
+                "97201",
+                "30301",
+            ],
+            "email": [
+                "john@example.com",
+                "jane@example.com",
+                "jim@example.com",
+                "jill@example.com",
+                "mike@example.com",
+                "mary@example.com",
+                "mark@example.com",
+                "martha@example.com",
+                "alex.smith@example.co.uk",
+                "alice.wu@example.ca",
+                "bob.martin@example.org",
+                "bella.rossi@example.it",
+            ],
+            "phone": [
+                "123-456-7890",
+                "213-555-1234",
+                "312-222-3333",
+                "713-444-5555",
+                "305-888-9999",
+                "206-777-8888",
+                "415-999-0000",
+                "617-111-2222",
+                "+44 20 7946 0958",
+                "+1-416-555-0199",
+                "+39 06 6982 1234",
+                "+49 30 123456",
+            ],
+            "address": [
+                "123 Main St",
+                "456 Oak Ave",
+                "789 Pine Rd",
+                "101 Maple Blvd",
+                "202 Elm St",
+                "303 Cedar Ave",
+                "404 Spruce Dr",
+                "505 Birch Ln",
+                "12 Baker St",
+                "88 King St W",
+                "Via Roma 1",
+                "Unter den Linden 5",
+            ],
+        }
+    )
+
+
+@pytest.fixture
+def stub_dataset_tar_file():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create valid parquet files with actual data
+        df1 = pd.DataFrame({"id": ["1", "2"], "name": ["test", "sample"]})
+        df2 = pd.DataFrame({"id": ["3", "4"], "name": ["data", "example"]})
+
+        # Write parquet files
+        os.makedirs(temp_dir + "/dataset", exist_ok=True)
+        df1.to_parquet(temp_dir + "/dataset/dataset-001.parquet", index=False)
+        df2.to_parquet(temp_dir + "/dataset/dataset-002.parquet", index=False)
+
+        # Create tar file
+        tar_path = temp_dir + "/dataset.tar"
+        with tarfile.open(tar_path, "w:gz") as tar:
+            tar.add(temp_dir + "/dataset/dataset-001.parquet", arcname="dataset/dataset-001.parquet")
+            tar.add(temp_dir + "/dataset/dataset-002.parquet", arcname="dataset/dataset-002.parquet")
+        with open(tar_path, "rb") as tar_file:
+            yield tar_file
+
+
+@pytest.fixture
+def stub_dataset_profiler_results() -> DatasetProfilerResults:
+    stub_column_statistics = GeneralColumnStatistics(
+        column_name="some",
+        num_records=1,
+        num_unique=1,
+        num_null=0,
+        pyarrow_dtype="string",
+        simple_dtype="string",
+    )
+    return DatasetProfilerResults(
+        num_records=1,
+        target_num_records=100,
+        column_statistics=[stub_column_statistics],
+        side_effect_column_names=None,
+        column_profiles=None,
+    )
+
+
+@pytest.fixture
+def stub_sampler_only_config_builder(stub_model_configs: list[ModelConfig]) -> DataDesignerConfigBuilder:
+    config_builder = DataDesignerConfigBuilder(model_configs=stub_model_configs)
+    config_builder.add_column(
+        SamplerColumnConfig(
+            name="uuid", sampler_type="uuid", params={"prefix": "code_", "short_form": True, "uppercase": True}
+        )
+    )
+    config_builder.add_column(
+        SamplerColumnConfig(name="category", sampler_type="category", params={"values": ["a", "b", "c"]})
+    )
+    config_builder.add_column(
+        SamplerColumnConfig(name="uniform", sampler_type="uniform", params={"low": 1, "high": 100})
+    )
+    return config_builder
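
These fixtures are importable by downstream test suites. A hedged sketch of wiring them into a conftest and consuming one; the `columns` and `seed_config` attribute names are read off the YAML stub above (eight column entries), and the wildcard re-export is a common pytest pattern rather than a documented entry point of this package:

    # conftest.py
    from data_designer.config.testing.fixtures import *  # noqa: F401,F403

    # test_config.py
    def test_stub_config_parses(stub_data_designer_config):
        # Field names assumed from the YAML keys in the stub config string
        assert len(stub_data_designer_config.columns) == 8
        assert stub_data_designer_config.seed_config is not None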

data_designer/config/utils/code_lang.py
@@ -0,0 +1,93 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+from enum import Enum
+
+
+class CodeLang(str, Enum):
+    BASH = "bash"
+    C = "c"
+    COBOL = "cobol"
+    CPP = "cpp"
+    CSHARP = "csharp"
+    GO = "go"
+    JAVA = "java"
+    JAVASCRIPT = "javascript"
+    KOTLIN = "kotlin"
+    PYTHON = "python"
+    RUBY = "ruby"
+    RUST = "rust"
+    SCALA = "scala"
+    SWIFT = "swift"
+    TYPESCRIPT = "typescript"
+    SQL_SQLITE = "sql:sqlite"
+    SQL_TSQL = "sql:tsql"
+    SQL_BIGQUERY = "sql:bigquery"
+    SQL_MYSQL = "sql:mysql"
+    SQL_POSTGRES = "sql:postgres"
+    SQL_ANSI = "sql:ansi"
+
+    @staticmethod
+    def parse(value: str | CodeLang) -> tuple[str, str | None]:
+        value = value.value if isinstance(value, CodeLang) else value
+        split_vals = value.split(":")
+        return (split_vals[0], split_vals[1] if len(split_vals) > 1 else None)
+
+    @staticmethod
+    def parse_lang(value: str | CodeLang) -> str:
+        return CodeLang.parse(value)[0]
+
+    @staticmethod
+    def parse_dialect(value: str | CodeLang) -> str | None:
+        return CodeLang.parse(value)[1]
+
+    @staticmethod
+    def supported_values() -> set[str]:
+        return {lang.value for lang in CodeLang}
+
+
+SQL_DIALECTS: set[CodeLang] = {
+    CodeLang.SQL_SQLITE,
+    CodeLang.SQL_TSQL,
+    CodeLang.SQL_BIGQUERY,
+    CodeLang.SQL_MYSQL,
+    CodeLang.SQL_POSTGRES,
+    CodeLang.SQL_ANSI,
+}
+
+##########################################################
+# Helper functions
+##########################################################
+
+
+def code_lang_to_syntax_lexer(code_lang: CodeLang | str) -> str:
+    """Convert the code language to a syntax lexer for Pygments.
+
+    Reference: https://pygments.org/docs/lexers/
+    """
+    code_lang_to_lexer = {
+        CodeLang.BASH: "bash",
+        CodeLang.C: "c",
+        CodeLang.COBOL: "cobol",
+        CodeLang.CPP: "cpp",
+        CodeLang.CSHARP: "csharp",
+        CodeLang.GO: "golang",
+        CodeLang.JAVA: "java",
+        CodeLang.JAVASCRIPT: "javascript",
+        CodeLang.KOTLIN: "kotlin",
+        CodeLang.PYTHON: "python",
+        CodeLang.RUBY: "ruby",
+        CodeLang.RUST: "rust",
+        CodeLang.SCALA: "scala",
+        CodeLang.SWIFT: "swift",
+        CodeLang.TYPESCRIPT: "typescript",
+        CodeLang.SQL_SQLITE: "sql",
+        CodeLang.SQL_ANSI: "sql",
+        CodeLang.SQL_TSQL: "tsql",
+        CodeLang.SQL_BIGQUERY: "sql",
+        CodeLang.SQL_MYSQL: "mysql",
+        CodeLang.SQL_POSTGRES: "postgres",
+    }
+    return code_lang_to_lexer.get(code_lang, code_lang)
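
The `lang:dialect` convention splits the enum value on `:`, so plain languages have no dialect. A short sketch of the helpers above, not part of the package, with outputs worked out from the code:

    from data_designer.config.utils.code_lang import CodeLang, code_lang_to_syntax_lexer

    assert CodeLang.parse("sql:postgres") == ("sql", "postgres")
    assert CodeLang.parse_lang(CodeLang.SQL_TSQL) == "sql"
    assert CodeLang.parse_dialect("python") is None  # no ":" means no dialect
    assert code_lang_to_syntax_lexer(CodeLang.GO) == "golang"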