data-designer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data_designer/__init__.py +15 -0
- data_designer/_version.py +34 -0
- data_designer/cli/README.md +236 -0
- data_designer/cli/__init__.py +6 -0
- data_designer/cli/commands/__init__.py +2 -0
- data_designer/cli/commands/list.py +130 -0
- data_designer/cli/commands/models.py +10 -0
- data_designer/cli/commands/providers.py +11 -0
- data_designer/cli/commands/reset.py +100 -0
- data_designer/cli/controllers/__init__.py +7 -0
- data_designer/cli/controllers/model_controller.py +246 -0
- data_designer/cli/controllers/provider_controller.py +317 -0
- data_designer/cli/forms/__init__.py +20 -0
- data_designer/cli/forms/builder.py +51 -0
- data_designer/cli/forms/field.py +180 -0
- data_designer/cli/forms/form.py +59 -0
- data_designer/cli/forms/model_builder.py +125 -0
- data_designer/cli/forms/provider_builder.py +76 -0
- data_designer/cli/main.py +44 -0
- data_designer/cli/repositories/__init__.py +8 -0
- data_designer/cli/repositories/base.py +39 -0
- data_designer/cli/repositories/model_repository.py +42 -0
- data_designer/cli/repositories/provider_repository.py +43 -0
- data_designer/cli/services/__init__.py +7 -0
- data_designer/cli/services/model_service.py +116 -0
- data_designer/cli/services/provider_service.py +111 -0
- data_designer/cli/ui.py +448 -0
- data_designer/cli/utils.py +47 -0
- data_designer/config/__init__.py +2 -0
- data_designer/config/analysis/column_profilers.py +89 -0
- data_designer/config/analysis/column_statistics.py +274 -0
- data_designer/config/analysis/dataset_profiler.py +60 -0
- data_designer/config/analysis/utils/errors.py +8 -0
- data_designer/config/analysis/utils/reporting.py +188 -0
- data_designer/config/base.py +68 -0
- data_designer/config/column_configs.py +354 -0
- data_designer/config/column_types.py +168 -0
- data_designer/config/config_builder.py +660 -0
- data_designer/config/data_designer_config.py +40 -0
- data_designer/config/dataset_builders.py +11 -0
- data_designer/config/datastore.py +151 -0
- data_designer/config/default_model_settings.py +123 -0
- data_designer/config/errors.py +19 -0
- data_designer/config/interface.py +54 -0
- data_designer/config/models.py +231 -0
- data_designer/config/preview_results.py +32 -0
- data_designer/config/processors.py +41 -0
- data_designer/config/sampler_constraints.py +51 -0
- data_designer/config/sampler_params.py +604 -0
- data_designer/config/seed.py +145 -0
- data_designer/config/utils/code_lang.py +83 -0
- data_designer/config/utils/constants.py +313 -0
- data_designer/config/utils/errors.py +19 -0
- data_designer/config/utils/info.py +88 -0
- data_designer/config/utils/io_helpers.py +273 -0
- data_designer/config/utils/misc.py +81 -0
- data_designer/config/utils/numerical_helpers.py +28 -0
- data_designer/config/utils/type_helpers.py +100 -0
- data_designer/config/utils/validation.py +336 -0
- data_designer/config/utils/visualization.py +427 -0
- data_designer/config/validator_params.py +96 -0
- data_designer/engine/__init__.py +2 -0
- data_designer/engine/analysis/column_profilers/base.py +55 -0
- data_designer/engine/analysis/column_profilers/judge_score_profiler.py +160 -0
- data_designer/engine/analysis/column_profilers/registry.py +20 -0
- data_designer/engine/analysis/column_statistics.py +142 -0
- data_designer/engine/analysis/dataset_profiler.py +125 -0
- data_designer/engine/analysis/errors.py +7 -0
- data_designer/engine/analysis/utils/column_statistics_calculations.py +209 -0
- data_designer/engine/analysis/utils/judge_score_processing.py +128 -0
- data_designer/engine/column_generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/__init__.py +2 -0
- data_designer/engine/column_generators/generators/base.py +61 -0
- data_designer/engine/column_generators/generators/expression.py +63 -0
- data_designer/engine/column_generators/generators/llm_generators.py +172 -0
- data_designer/engine/column_generators/generators/samplers.py +75 -0
- data_designer/engine/column_generators/generators/seed_dataset.py +149 -0
- data_designer/engine/column_generators/generators/validation.py +147 -0
- data_designer/engine/column_generators/registry.py +56 -0
- data_designer/engine/column_generators/utils/errors.py +13 -0
- data_designer/engine/column_generators/utils/judge_score_factory.py +57 -0
- data_designer/engine/column_generators/utils/prompt_renderer.py +98 -0
- data_designer/engine/configurable_task.py +82 -0
- data_designer/engine/dataset_builders/artifact_storage.py +181 -0
- data_designer/engine/dataset_builders/column_wise_builder.py +287 -0
- data_designer/engine/dataset_builders/errors.py +13 -0
- data_designer/engine/dataset_builders/multi_column_configs.py +44 -0
- data_designer/engine/dataset_builders/utils/__init__.py +2 -0
- data_designer/engine/dataset_builders/utils/concurrency.py +184 -0
- data_designer/engine/dataset_builders/utils/config_compiler.py +60 -0
- data_designer/engine/dataset_builders/utils/dag.py +56 -0
- data_designer/engine/dataset_builders/utils/dataset_batch_manager.py +190 -0
- data_designer/engine/dataset_builders/utils/errors.py +13 -0
- data_designer/engine/errors.py +49 -0
- data_designer/engine/model_provider.py +75 -0
- data_designer/engine/models/__init__.py +2 -0
- data_designer/engine/models/errors.py +308 -0
- data_designer/engine/models/facade.py +225 -0
- data_designer/engine/models/litellm_overrides.py +162 -0
- data_designer/engine/models/parsers/__init__.py +2 -0
- data_designer/engine/models/parsers/errors.py +34 -0
- data_designer/engine/models/parsers/parser.py +236 -0
- data_designer/engine/models/parsers/postprocessors.py +93 -0
- data_designer/engine/models/parsers/tag_parsers.py +60 -0
- data_designer/engine/models/parsers/types.py +82 -0
- data_designer/engine/models/recipes/base.py +79 -0
- data_designer/engine/models/recipes/response_recipes.py +291 -0
- data_designer/engine/models/registry.py +118 -0
- data_designer/engine/models/usage.py +75 -0
- data_designer/engine/models/utils.py +38 -0
- data_designer/engine/processing/ginja/__init__.py +2 -0
- data_designer/engine/processing/ginja/ast.py +64 -0
- data_designer/engine/processing/ginja/environment.py +461 -0
- data_designer/engine/processing/ginja/exceptions.py +54 -0
- data_designer/engine/processing/ginja/record.py +30 -0
- data_designer/engine/processing/gsonschema/__init__.py +2 -0
- data_designer/engine/processing/gsonschema/exceptions.py +8 -0
- data_designer/engine/processing/gsonschema/schema_transformers.py +81 -0
- data_designer/engine/processing/gsonschema/types.py +8 -0
- data_designer/engine/processing/gsonschema/validators.py +143 -0
- data_designer/engine/processing/processors/base.py +15 -0
- data_designer/engine/processing/processors/drop_columns.py +46 -0
- data_designer/engine/processing/processors/registry.py +20 -0
- data_designer/engine/processing/utils.py +120 -0
- data_designer/engine/registry/base.py +97 -0
- data_designer/engine/registry/data_designer_registry.py +37 -0
- data_designer/engine/registry/errors.py +10 -0
- data_designer/engine/resources/managed_dataset_generator.py +35 -0
- data_designer/engine/resources/managed_dataset_repository.py +194 -0
- data_designer/engine/resources/managed_storage.py +63 -0
- data_designer/engine/resources/resource_provider.py +46 -0
- data_designer/engine/resources/seed_dataset_data_store.py +66 -0
- data_designer/engine/sampling_gen/column.py +89 -0
- data_designer/engine/sampling_gen/constraints.py +95 -0
- data_designer/engine/sampling_gen/data_sources/base.py +214 -0
- data_designer/engine/sampling_gen/data_sources/errors.py +10 -0
- data_designer/engine/sampling_gen/data_sources/sources.py +342 -0
- data_designer/engine/sampling_gen/entities/__init__.py +2 -0
- data_designer/engine/sampling_gen/entities/assets/zip_area_code_map.parquet +0 -0
- data_designer/engine/sampling_gen/entities/dataset_based_person_fields.py +64 -0
- data_designer/engine/sampling_gen/entities/email_address_utils.py +169 -0
- data_designer/engine/sampling_gen/entities/errors.py +8 -0
- data_designer/engine/sampling_gen/entities/national_id_utils.py +100 -0
- data_designer/engine/sampling_gen/entities/person.py +142 -0
- data_designer/engine/sampling_gen/entities/phone_number.py +122 -0
- data_designer/engine/sampling_gen/errors.py +24 -0
- data_designer/engine/sampling_gen/generator.py +121 -0
- data_designer/engine/sampling_gen/jinja_utils.py +60 -0
- data_designer/engine/sampling_gen/people_gen.py +203 -0
- data_designer/engine/sampling_gen/person_constants.py +54 -0
- data_designer/engine/sampling_gen/schema.py +143 -0
- data_designer/engine/sampling_gen/schema_builder.py +59 -0
- data_designer/engine/sampling_gen/utils.py +40 -0
- data_designer/engine/secret_resolver.py +80 -0
- data_designer/engine/validators/__init__.py +17 -0
- data_designer/engine/validators/base.py +36 -0
- data_designer/engine/validators/local_callable.py +34 -0
- data_designer/engine/validators/python.py +245 -0
- data_designer/engine/validators/remote.py +83 -0
- data_designer/engine/validators/sql.py +60 -0
- data_designer/errors.py +5 -0
- data_designer/essentials/__init__.py +137 -0
- data_designer/interface/__init__.py +2 -0
- data_designer/interface/data_designer.py +351 -0
- data_designer/interface/errors.py +16 -0
- data_designer/interface/results.py +55 -0
- data_designer/logging.py +161 -0
- data_designer/plugin_manager.py +83 -0
- data_designer/plugins/__init__.py +6 -0
- data_designer/plugins/errors.py +10 -0
- data_designer/plugins/plugin.py +69 -0
- data_designer/plugins/registry.py +86 -0
- data_designer-0.1.0.dist-info/METADATA +173 -0
- data_designer-0.1.0.dist-info/RECORD +177 -0
- data_designer-0.1.0.dist-info/WHEEL +4 -0
- data_designer-0.1.0.dist-info/entry_points.txt +2 -0
- data_designer-0.1.0.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
import logging
|
|
6
|
+
from typing import Any, Optional, Union
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from data_designer.config.analysis.column_profilers import JudgeScoreDistributions, JudgeScoreSample
|
|
11
|
+
from data_designer.config.analysis.column_statistics import (
|
|
12
|
+
CategoricalDistribution,
|
|
13
|
+
ColumnDistributionType,
|
|
14
|
+
MissingValue,
|
|
15
|
+
NumericalDistribution,
|
|
16
|
+
)
|
|
17
|
+
from data_designer.config.column_configs import LLMJudgeColumnConfig
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def extract_judge_score_distributions(
    column_config: LLMJudgeColumnConfig, df: pd.DataFrame
) -> Union[JudgeScoreDistributions, MissingValue]:
    """Extract per-rubric score/reasoning lists and distribution summaries for an LLM-judge column.

    Iterates each configured score (rubric) and each judge result in `df[column_config.name]`,
    collecting raw score values and reasoning strings, then computes a categorical histogram for
    every score plus either a numerical or categorical distribution depending on whether all
    observed values were integer-convertible.

    Args:
        column_config: Judge column config; `column_config.scores` names the rubrics and
            `column_config.name` is the dataframe column holding per-record judge results
            (assumed to be dicts keyed by lower-cased rubric name — structure inferred from
            usage here; confirm against the judge output format).
        df: Dataset containing the judge results column.

    Returns:
        A JudgeScoreDistributions summary, or MissingValue.OUTPUT_FORMAT_ERROR when any
        record's judge output cannot be read.
    """
    scores = defaultdict(list)
    reasoning = defaultdict(list)

    # Aggregate results as dicts of form {score_name: <result>}.
    histograms = {}
    distributions = {}
    distribution_types = {}

    # NOTE: previously the inner loop reassigned the outer loop variable `score`,
    # shadowing the score-config object; distinct names make the two roles explicit.
    for score_config in column_config.scores:
        is_numerical = True
        name = score_config.name.lower()
        for results in df[column_config.name]:
            try:
                score_value = results[name].get("score", None)

                if _can_be_converted_to_int(score_value):
                    score_value = int(score_value)
                else:
                    # A single non-integer value makes the whole rubric categorical.
                    score_value = str(score_value)
                    is_numerical = False

                scores[name].append(score_value)
                reasoning[name].append(results[name].get("reasoning", "No reasoning provided"))
            except Exception as e:
                logger.warning(f"⚠️ Failed to extract judge score for '{name}': {e}")
                return MissingValue.OUTPUT_FORMAT_ERROR

        try:
            series = pd.Series(scores[name], name=name)
            cat_dist = CategoricalDistribution.from_series(series)

            # For judge scores, build a categorical histogram, since numerical scores are integers.
            histograms[name] = cat_dist.histogram

            if is_numerical:
                distribution_types[name] = ColumnDistributionType.NUMERICAL
                distributions[name] = NumericalDistribution.from_series(series)
            else:
                distribution_types[name] = ColumnDistributionType.CATEGORICAL
                distributions[name] = cat_dist

        except Exception as e:
            # Distribution failure for one rubric is non-fatal: mark it and keep going.
            logger.warning(f"⚠️ Failed to calculate judge score distribution for '{name}': {e}")
            distribution_types[name] = ColumnDistributionType.UNKNOWN
            distributions[name] = MissingValue.CALCULATION_FAILED
            histograms[name] = MissingValue.CALCULATION_FAILED

    return JudgeScoreDistributions(
        scores=dict(scores),
        reasoning=dict(reasoning),
        distribution_types=distribution_types,
        distributions=distributions,
        histograms=histograms,
    )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def sample_scores_and_reasoning(
    scores: list[Union[int, str]],
    reasoning: list[str],
    num_samples: int,
    random_seed: Optional[int] = None,
) -> list[JudgeScoreSample]:
    """Sample up to `num_samples` (score, reasoning) pairs, preserving score-category balance.

    Args:
        scores: Score values, one per record (ints for numeric rubrics, strings otherwise).
        reasoning: Reasoning strings aligned index-wise with `scores`.
        num_samples: Number of samples to return (must be > 0).
        random_seed: Optional seed forwarded to pandas sampling for reproducibility.

    Returns:
        A list of JudgeScoreSample. If there are no more records than requested, all
        records are returned in their original order.

    Raises:
        ValueError: If the inputs are mismatched in length, empty, or `num_samples` <= 0.
    """
    if len(scores) != len(reasoning):
        raise ValueError("scores and reasoning must have the same length")

    if len(scores) == 0:
        raise ValueError("scores and reasoning must not be empty")

    if num_samples <= 0:
        raise ValueError("num_samples must be greater than 0")

    # Early exit before building the DataFrame: previously it was constructed even
    # when this branch made it unused.
    if len(scores) <= num_samples:
        return [JudgeScoreSample(score=s, reasoning=r) for s, r in zip(scores, reasoning)]

    df_samples = pd.DataFrame({"score": scores, "reasoning": reasoning})

    # Sample maintaining original proportions from each category (int or str)
    # Calculate the frequency of each score category
    score_category_counts = df_samples["score"].value_counts()

    # If more categories than samples, pick one sample from each of the most frequent categories
    if len(score_category_counts) >= num_samples:
        top_categories = score_category_counts.head(num_samples).index
        samples = pd.concat(
            [df_samples[df_samples["score"] == cat].sample(n=1, random_state=random_seed) for cat in top_categories],
            ignore_index=True,
        )
    else:
        # Sample proportionally to maintain original category ratios
        # Create weights based on the original frequency of each score
        weights = df_samples["score"].map(score_category_counts)
        samples = df_samples.sample(n=num_samples, weights=weights, random_state=random_seed)

    return [
        JudgeScoreSample(score=row["score"], reasoning=row["reasoning"]) for row in samples.to_dict(orient="records")
    ]
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _can_be_converted_to_int(value: Any) -> bool:
|
|
124
|
+
try:
|
|
125
|
+
int(value)
|
|
126
|
+
return True
|
|
127
|
+
except (ValueError, TypeError):
|
|
128
|
+
return False
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from typing import overload
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from data_designer.config.utils.type_helpers import StrEnum
|
|
10
|
+
from data_designer.engine.configurable_task import ConfigurableTask, ConfigurableTaskMetadata, DataT, TaskConfigT
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class GenerationStrategy(StrEnum):
    """How a generator consumes its input: one record at a time, or a whole DataFrame at once."""

    # Generator receives/returns a single record dict per call.
    CELL_BY_CELL = "cell_by_cell"
    # Generator receives/returns the full DataFrame in one call.
    FULL_COLUMN = "full_column"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GeneratorMetadata(ConfigurableTaskMetadata):
    """Task metadata extended with the dataset-traversal strategy of a column generator."""

    # Declared statically per concrete generator via its `metadata()` staticmethod.
    generation_strategy: GenerationStrategy
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC):
    """Abstract base class for tasks that generate a single dataset column."""

    @property
    def can_generate_from_scratch(self) -> bool:
        # Base generators require existing input data; FromScratchColumnGenerator overrides this.
        return False

    @property
    def generation_strategy(self) -> GenerationStrategy:
        # Strategy is declared statically on the concrete class through `metadata()`.
        return self.metadata().generation_strategy

    @staticmethod
    @abstractmethod
    def metadata() -> GeneratorMetadata: ...

    # `generate` is overloaded on the data type: dict in/dict out for cell-by-cell
    # generators, DataFrame in/DataFrame out for full-column generators.
    @overload
    @abstractmethod
    def generate(self, data: dict) -> dict: ...

    @overload
    @abstractmethod
    def generate(self, data: pd.DataFrame) -> pd.DataFrame: ...

    @abstractmethod
    def generate(self, data: DataT) -> DataT: ...

    def log_pre_generation(self) -> None:
        """A shared method to log info before the generator's `generate` method is called.

        The idea is for dataset builders to call this method for all generators before calling their
        `generate` method. This is to avoid logging the same information multiple times when running
        generators in parallel.
        """
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class FromScratchColumnGenerator(ColumnGenerator[TaskConfigT], ABC):
    """Column generator that can also create records without any pre-existing input data."""

    @property
    def can_generate_from_scratch(self) -> bool:
        # Overrides the base default (False): these generators can seed a dataset themselves.
        return True

    @abstractmethod
    def generate_from_scratch(self, num_records: int) -> pd.DataFrame: ...
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import logging
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
|
|
8
|
+
from data_designer.config.column_configs import ExpressionColumnConfig
|
|
9
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
10
|
+
ColumnGenerator,
|
|
11
|
+
GenerationStrategy,
|
|
12
|
+
GeneratorMetadata,
|
|
13
|
+
)
|
|
14
|
+
from data_designer.engine.column_generators.utils.errors import ExpressionTemplateRenderError
|
|
15
|
+
from data_designer.engine.processing.ginja.environment import WithJinja2UserTemplateRendering
|
|
16
|
+
from data_designer.engine.processing.utils import deserialize_json_values
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ExpressionColumnGenerator(WithJinja2UserTemplateRendering, ColumnGenerator[ExpressionColumnConfig]):
    """Generates a column by rendering a jinja2 expression against every record."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="expression_generator",
            description="Generate a column from a jinja2 expression.",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            required_resources=None,
        )

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Render the configured expression per record and append the result as a new column."""
        logger.info(f"🧩 Generating column `{self.config.name}` from expression")

        missing_columns = list(set(self.config.required_columns) - set(data.columns))
        if missing_columns:
            raise ExpressionTemplateRenderError(
                f"There was an error preparing the Jinja2 expression template. "
                f"The following columns {missing_columns} are missing!"
            )

        self.prepare_jinja2_template_renderer(self.config.expr, data.columns.to_list())

        output_records = []
        for row in data.to_dict(orient="records"):
            rendered = self.render_template(deserialize_json_values(row))
            row[self.config.name] = self._cast_type(rendered)
            output_records.append(row)

        return pd.DataFrame(output_records)

    def _cast_type(self, value: str) -> str | float | int | bool:
        """Cast the rendered template string to the dtype declared in the column config."""
        dtype = self.config.dtype
        if dtype == "str":
            return value
        if dtype == "float":
            return float(value)
        if dtype == "int":
            # Route through float so values like "3.0" are accepted.
            return int(float(value))
        if dtype == "bool":
            try:
                # Numeric strings map via truthiness of the number.
                return bool(int(float(value)))
            except ValueError:
                # Non-numeric strings: only (case-insensitive) "true" is truthy.
                return bool(f"{value}".lower() == "true")
        raise ValueError(f"Invalid dtype: {self.config.dtype}")
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import functools
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from data_designer.config.column_configs import (
|
|
8
|
+
LLMCodeColumnConfig,
|
|
9
|
+
LLMJudgeColumnConfig,
|
|
10
|
+
LLMStructuredColumnConfig,
|
|
11
|
+
LLMTextColumnConfig,
|
|
12
|
+
)
|
|
13
|
+
from data_designer.config.column_types import COLUMN_TYPE_EMOJI_MAP
|
|
14
|
+
from data_designer.config.models import InferenceParameters, ModelConfig
|
|
15
|
+
from data_designer.config.utils.constants import REASONING_TRACE_COLUMN_POSTFIX
|
|
16
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
17
|
+
ColumnGenerator,
|
|
18
|
+
GenerationStrategy,
|
|
19
|
+
GeneratorMetadata,
|
|
20
|
+
)
|
|
21
|
+
from data_designer.engine.column_generators.utils.prompt_renderer import (
|
|
22
|
+
PromptType,
|
|
23
|
+
RecordBasedPromptRenderer,
|
|
24
|
+
create_response_recipe,
|
|
25
|
+
)
|
|
26
|
+
from data_designer.engine.models.facade import ModelFacade
|
|
27
|
+
from data_designer.engine.models.recipes.base import ResponseRecipe
|
|
28
|
+
from data_designer.engine.processing.utils import deserialize_json_values
|
|
29
|
+
from data_designer.engine.resources.resource_provider import ResourceType
|
|
30
|
+
|
|
31
|
+
DEFAULT_MAX_CONVERSATION_RESTARTS = 5
|
|
32
|
+
DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS = 0
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
logger = logging.getLogger(__name__)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class WithLLMGeneration:
    """Mixin providing model-backed, cell-by-cell generation for LLM column generators.

    Expects the host class to provide `self.config` (an LLM column config) and
    `self.resource_provider` (exposing a model registry) — supplied by the
    ColumnGenerator subclasses this is mixed into.
    """

    @functools.cached_property
    def model(self) -> ModelFacade:
        return self.resource_provider.model_registry.get_model(model_alias=self.config.model_alias)

    @functools.cached_property
    def model_config(self) -> ModelConfig:
        return self.resource_provider.model_registry.get_model_config(model_alias=self.config.model_alias)

    @functools.cached_property
    def inference_parameters(self) -> InferenceParameters:
        return self.model_config.inference_parameters

    @functools.cached_property
    def prompt_renderer(self) -> RecordBasedPromptRenderer:
        # Error context identifies the column and model in rendering/parsing failures.
        return RecordBasedPromptRenderer(
            response_recipe=self.response_recipe,
            error_message_context={
                "column_name": self.config.name,
                "column_type": self.config.column_type,
                "model_alias": self.config.model_alias,
            },
        )

    @functools.cached_property
    def response_recipe(self) -> ResponseRecipe:
        return create_response_recipe(self.config, self.model_config)

    @property
    def max_conversation_correction_steps(self) -> int:
        # Subclasses may override; default allows no in-conversation correction turns.
        return DEFAULT_MAX_CONVERSATION_CORRECTION_STEPS

    @property
    def max_conversation_restarts(self) -> int:
        # Subclasses may override (e.g. the judge generator doubles this).
        return DEFAULT_MAX_CONVERSATION_RESTARTS

    def generate(self, data: dict) -> dict:
        """Generate this column's value for a single record, mutating and returning `data`.

        Renders user/system prompts from the record, calls the model with the recipe's
        parser, and writes the serialized response (plus any reasoning trace) back
        into the record.
        """
        deserialized_record = deserialize_json_values(data)

        multi_modal_context = None
        if self.config.multi_modal_context is not None and len(self.config.multi_modal_context) > 0:
            multi_modal_context = [
                context.get_context(deserialized_record) for context in self.config.multi_modal_context
            ]

        response, reasoning_trace = self.model.generate(
            prompt=self.prompt_renderer.render(
                record=deserialized_record,
                prompt_template=self.config.prompt,
                prompt_type=PromptType.USER_PROMPT,
            ),
            system_prompt=self.prompt_renderer.render(
                record=deserialized_record,
                prompt_template=self.config.system_prompt,
                prompt_type=PromptType.SYSTEM_PROMPT,
            ),
            parser=self.response_recipe.parse,
            multi_modal_context=multi_modal_context,
            max_correction_steps=self.max_conversation_correction_steps,
            max_conversation_restarts=self.max_conversation_restarts,
            purpose=f"running generation for column '{self.config.name}'",
            **self.inference_parameters.generate_kwargs,
        )

        data[self.config.name] = deserialize_json_values(self.response_recipe.serialize_output(response))

        if reasoning_trace:
            # Reasoning traces are stored in a sibling column with a dedicated postfix.
            data[self.config.name + REASONING_TRACE_COLUMN_POSTFIX] = reasoning_trace

        return data

    def log_pre_generation(self) -> None:
        """Log the column/model configuration once before generation begins."""
        emoji = COLUMN_TYPE_EMOJI_MAP[self.config.column_type]
        logger.info(f"{emoji} Preparing {self.config.column_type} column generation")
        logger.info(f" |-- column name: {self.config.name!r}")
        logger.info(f" |-- model config:\n{self.model_config.model_dump_json(indent=4)}")
        if self.model_config.provider is None:
            logger.info(f" |-- default model provider: {self._get_provider_name()!r}")

    def _get_provider_name(self) -> str:
        # Resolve which default provider the registry selected for this model alias.
        model_alias = self.model_config.alias
        provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias)
        return provider.name
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class LLMTextCellGenerator(WithLLMGeneration, ColumnGenerator[LLMTextColumnConfig]):
    """Cell-by-cell generator for free-text columns; behavior comes from WithLLMGeneration."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="llm_text_generator",
            description="Generate a new dataset cell from a prompt template",
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
            required_resources=[ResourceType.MODEL_REGISTRY],
        )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class LLMCodeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMCodeColumnConfig]):
    """Cell-by-cell generator for code columns; behavior comes from WithLLMGeneration."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="llm_code_generator",
            description="Generate a new dataset cell from a prompt template",
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
            required_resources=[ResourceType.MODEL_REGISTRY],
        )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
class LLMStructuredCellGenerator(WithLLMGeneration, ColumnGenerator[LLMStructuredColumnConfig]):
    """Cell-by-cell generator for structured-output columns; behavior comes from WithLLMGeneration."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="llm_structured_generator",
            description="Generate a new dataset cell from a prompt template",
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
            required_resources=[ResourceType.MODEL_REGISTRY],
        )
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class LLMJudgeCellGenerator(WithLLMGeneration, ColumnGenerator[LLMJudgeColumnConfig]):
    """Cell-by-cell generator that judges records against configured rubrics with an LLM."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="llm_judge_generator",
            description="Judge a new dataset cell based on a set of rubrics",
            generation_strategy=GenerationStrategy.CELL_BY_CELL,
            required_resources=[ResourceType.MODEL_REGISTRY],
        )

    # NOTE: the previous `max_conversation_correction_steps` override only restated
    # the WithLLMGeneration default, so it was removed; the inherited value is identical.

    @property
    def max_conversation_restarts(self) -> int:
        # Judge output must parse into strict per-rubric structure, so allow twice the
        # default number of conversation restarts before giving up.
        return 2 * DEFAULT_MAX_CONVERSATION_RESTARTS
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
from functools import partial
|
|
5
|
+
import logging
|
|
6
|
+
import random
|
|
7
|
+
from typing import Callable
|
|
8
|
+
|
|
9
|
+
import pandas as pd
|
|
10
|
+
|
|
11
|
+
from data_designer.config.utils.constants import LOCALES_WITH_MANAGED_DATASETS
|
|
12
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
13
|
+
FromScratchColumnGenerator,
|
|
14
|
+
GenerationStrategy,
|
|
15
|
+
GeneratorMetadata,
|
|
16
|
+
)
|
|
17
|
+
from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig
|
|
18
|
+
from data_designer.engine.processing.utils import concat_datasets
|
|
19
|
+
from data_designer.engine.resources.managed_dataset_generator import ManagedDatasetGenerator
|
|
20
|
+
from data_designer.engine.resources.resource_provider import ResourceType
|
|
21
|
+
from data_designer.engine.sampling_gen.data_sources.sources import SamplerType
|
|
22
|
+
from data_designer.engine.sampling_gen.entities.person import load_person_data_sampler
|
|
23
|
+
from data_designer.engine.sampling_gen.generator import DatasetGenerator as SamplingDatasetGenerator
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class SamplerColumnGenerator(FromScratchColumnGenerator[SamplerMultiColumnConfig]):
    """Generates one or more columns using non-LLM, sampling-based data sources."""

    @staticmethod
    def metadata() -> GeneratorMetadata:
        return GeneratorMetadata(
            name="sampler_column_generator",
            description="Generate columns using sampling-based method.",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            required_resources=[ResourceType.BLOB_STORAGE],
        )

    def generate(self, data: pd.DataFrame) -> pd.DataFrame:
        """Generate sampler columns for an existing dataset and concatenate them onto it."""
        df_samplers = self.generate_from_scratch(len(data))
        return concat_datasets([data, df_samplers])

    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Generate `num_records` records containing only the configured sampler columns."""
        sampling_generator = self._prepare_for_generation(num_records)
        return sampling_generator.generate(num_records)

    @property
    def _person_columns(self) -> list:
        # Sampler columns backed by the person entity sampler. Extracted here because
        # this filter was previously duplicated in two methods below.
        return [c for c in self.config.columns if c.sampler_type == SamplerType.PERSON]

    @property
    def _needs_person_generator(self) -> bool:
        # Managed person datasets only exist for certain locales.
        return any(c.params.locale in LOCALES_WITH_MANAGED_DATASETS for c in self._person_columns)

    @property
    def _person_generator_loader(self) -> Callable[[bool], ManagedDatasetGenerator]:
        # Bind blob storage now; the sampling generator decides if/when to actually load.
        return partial(load_person_data_sampler, blob_storage=self.resource_provider.blob_storage)

    def _create_sampling_dataset_generator(self) -> SamplingDatasetGenerator:
        return SamplingDatasetGenerator(
            sampler_columns=self.config,
            person_generator_loader=(self._person_generator_loader if self._needs_person_generator else None),
        )

    def _log_person_generation_if_needed(self) -> None:
        # Purely cosmetic logging; no effect on generation.
        if self._needs_person_generator:
            emoji = random.choice(["🧑🎨", "🙋♂️", "🙋♀️", "🧑🚀", "👩🎤", "👨🍳", "👩🔬", "👨💻", "👩💼"])
            log_msg = f"🎲 {emoji} Initializing person generation"
            if any(c.params.with_synthetic_personas for c in self._person_columns):
                log_msg += " ⚡️ with synthetic personas ⚡️"
            logger.info(log_msg)

    def _prepare_for_generation(self, num_records: int) -> SamplingDatasetGenerator:
        logger.info(
            f"🎲 Preparing samplers to generate {num_records} records across {len(self.config.columns)} columns"
        )
        self._log_person_generation_if_needed()
        return self._create_sampling_dataset_generator()
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import functools
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
import duckdb
|
|
8
|
+
import pandas as pd
|
|
9
|
+
|
|
10
|
+
from data_designer.config.seed import IndexRange, PartitionBlock, SamplingStrategy
|
|
11
|
+
from data_designer.engine.column_generators.generators.base import (
|
|
12
|
+
FromScratchColumnGenerator,
|
|
13
|
+
GenerationStrategy,
|
|
14
|
+
GeneratorMetadata,
|
|
15
|
+
)
|
|
16
|
+
from data_designer.engine.column_generators.utils.errors import SeedDatasetError
|
|
17
|
+
from data_designer.engine.dataset_builders.multi_column_configs import SeedDatasetMultiColumnConfig
|
|
18
|
+
from data_designer.engine.processing.utils import concat_datasets
|
|
19
|
+
from data_designer.engine.resources.resource_provider import ResourceType
|
|
20
|
+
|
|
21
|
+
# Guard against a wedged datastore read loop: a `_sample_records` call raises
# once it has accumulated more than MAX_ZERO_RECORD_RESPONSE_FACTOR * num_records
# empty batches from the batch reader.
MAX_ZERO_RECORD_RESPONSE_FACTOR = 2

# Module-level logger for seed dataset sampling messages.
logger = logging.getLogger(__name__)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class SeedDatasetColumnGenerator(FromScratchColumnGenerator[SeedDatasetMultiColumnConfig]):
    """Sample columns from a seed dataset stored in the datastore.

    Rows are streamed through a DuckDB record-batch reader rather than loading
    the whole seed dataset into memory. Leftover rows from a batch are kept in
    `self._df_remaining` so consecutive calls continue where the previous call
    left off, and the reader wraps around (via `_reset_batch_reader`) when the
    seed dataset is exhausted before the request is satisfied.
    """

    @staticmethod
    def metadata() -> GeneratorMetadata:
        """Describe this generator: name, generation strategy, required resources."""
        return GeneratorMetadata(
            name="seed_dataset_column_generator",
            description="Sample columns from a seed dataset.",
            generation_strategy=GenerationStrategy.FULL_COLUMN,
            required_resources=[ResourceType.DATASTORE],
        )

    @property
    def num_records_sampled(self) -> int:
        # Running total across generate_from_scratch calls; reset in _initialize.
        return self._num_records_sampled

    @functools.cached_property
    def duckdb_conn(self) -> duckdb.DuckDBPyConnection:
        # Cached so the COUNT(*) probe and all batch reads share one connection.
        return self.resource_provider.datastore.create_duckdb_connection()

    def generate(self, dataset: pd.DataFrame) -> pd.DataFrame:
        """Prepend sampled seed columns to `dataset`, one seed row per input row.

        Note: the sampled seed columns come first in the concatenation order.
        An empty `dataset` raises ValueError via `generate_from_scratch`.
        """
        return concat_datasets([self.generate_from_scratch(len(dataset)), dataset])

    def generate_from_scratch(self, num_records: int) -> pd.DataFrame:
        """Sample `num_records` rows from the seed dataset.

        Raises:
            ValueError: If `num_records` is not positive.
        """
        if num_records <= 0:
            raise ValueError("🛑 `num_records` must be positive.")

        # Lazily create the batch reader on first use; later calls reuse it so
        # sampling continues from the current read position.
        if self._batch_reader is None:
            self._reset_batch_reader(num_records)

        return self._sample_records(num_records)

    def _initialize(self) -> None:
        """Reset sampling state and resolve the seed dataset's location and size."""
        self._num_records_sampled = 0
        self._batch_reader = None
        self._df_remaining = None
        self._dataset_uri = self.resource_provider.datastore.get_dataset_uri(self.config.dataset)
        # Probe the dataset size once; used for selection-strategy validation and logging.
        self._seed_dataset_size = self.duckdb_conn.execute(f"SELECT COUNT(*) FROM '{self._dataset_uri}'").fetchone()[0]
        self._index_range = self._resolve_index_range()

    def _validate_selection_strategy(self) -> None:
        """Raise SeedDatasetError if the configured selection strategy exceeds the dataset bounds."""
        err_msg = None
        if self.config.selection_strategy is not None:
            # IndexRange: 'end' is an inclusive 0-based index, so it must be < dataset size.
            if (
                isinstance(self.config.selection_strategy, IndexRange)
                and self.config.selection_strategy.end >= self._seed_dataset_size
            ):
                err_msg = f"Selection strategy 'end' index {self.config.selection_strategy.end} is out of bounds for dataset size {self._seed_dataset_size}"
            # PartitionBlock: cannot split the dataset into more partitions than it has rows.
            elif (
                isinstance(self.config.selection_strategy, PartitionBlock)
                and self.config.selection_strategy.num_partitions > self._seed_dataset_size
            ):
                err_msg = f"Selection strategy 'num_partitions' {self.config.selection_strategy.num_partitions} is out of bounds for dataset size {self._seed_dataset_size}"
        if err_msg is not None:
            raise SeedDatasetError(err_msg)

    def _resolve_index_range(self) -> IndexRange | None:
        """Normalize the selection strategy to an IndexRange (or None when unset)."""
        self._validate_selection_strategy()
        index_range = None
        if self.config.selection_strategy is not None:
            if isinstance(self.config.selection_strategy, IndexRange):
                index_range = self.config.selection_strategy
            elif isinstance(self.config.selection_strategy, PartitionBlock):
                # Convert the partition description into the concrete row range it covers.
                index_range = self.config.selection_strategy.to_index_range(self._seed_dataset_size)
        return index_range

    def _reset_batch_reader(self, num_records: int) -> None:
        """(Re)create the record-batch reader over the (optionally restricted, optionally shuffled) dataset.

        Called initially and again whenever the reader is exhausted, which makes
        sampling wrap around to the start of the seed dataset.
        """
        shuffle = self.config.sampling_strategy == SamplingStrategy.SHUFFLE
        shuffle_query = " ORDER BY RANDOM()" if shuffle else ""

        if self._index_range is not None:
            # Use LIMIT and OFFSET for efficient index range filtering
            # IndexRange uses 0-based indexing [start, end] inclusive
            # OFFSET skips the first 'start' rows (0-based)
            # LIMIT takes 'end - start + 1' rows to include both start and end (inclusive)
            offset_value = self._index_range.start
            limit_value = self._index_range.end - self._index_range.start + 1
            read_query = f"""
                SELECT * FROM '{self._dataset_uri}'
                LIMIT {limit_value} OFFSET {offset_value}
            """

            # Shuffle (if requested) is applied over the restricted subquery, not the full dataset.
            read_query = f"SELECT * FROM ({read_query}){shuffle_query}"
        else:
            read_query = f"SELECT * FROM '{self._dataset_uri}'{shuffle_query}"
        self._batch_reader = self.duckdb_conn.query(read_query).record_batch(batch_size=num_records)

    def _sample_records(self, num_records: int) -> pd.DataFrame:
        """Accumulate `num_records` rows from the batch reader, carrying leftovers between calls."""
        logger.info(f"🌱 Sampling {num_records} records from seed dataset")
        logger.info(f" |-- seed dataset size: {self._seed_dataset_size} records")
        logger.info(f" |-- sampling strategy: {self.config.sampling_strategy}")
        if self._index_range is not None:
            if isinstance(self.config.selection_strategy, IndexRange):
                logger.info(f" |-- selection: rows [{self._index_range.start} to {self._index_range.end}] inclusive")
            else:
                # _resolve_index_range only produces a range for IndexRange or PartitionBlock,
                # so this branch is the PartitionBlock case.
                logger.info(
                    f" |-- selection: partition {self.config.selection_strategy.index + 1} of {self.config.selection_strategy.num_partitions}"
                )
            logger.info(f" |-- seed dataset size after selection: {self._index_range.size} records")
        df_batch = pd.DataFrame()
        # Start from rows left over by the previous call, if any.
        df_sample = pd.DataFrame() if self._df_remaining is None else self._df_remaining
        num_zero_record_responses = 0

        while len(df_sample) < num_records:
            try:
                df_batch = self._batch_reader.read_next_batch().to_pandas()
                df_sample = pd.concat([df_sample, df_batch], ignore_index=True)
            except StopIteration:
                # Reader exhausted: recreate it so sampling wraps around the dataset.
                self._reset_batch_reader(num_records)

            # Guard against an endless loop of empty reads (e.g. a broken datastore
            # connection): bail out after MAX_ZERO_RECORD_RESPONSE_FACTOR * num_records
            # empty batches.
            if len(df_batch) == 0:
                num_zero_record_responses += 1
                if num_zero_record_responses > MAX_ZERO_RECORD_RESPONSE_FACTOR * num_records:
                    raise RuntimeError(
                        "🛑 Something went wrong while reading from the datastore. "
                        "Please check your connection and try again. "
                        "If the issue persists, please contact support."
                    )

        # Stash any surplus rows for the next call and trim the sample to size.
        self._df_remaining = None
        if len(df_sample) > num_records:
            self._df_remaining = df_sample.iloc[num_records:].reset_index(drop=True)
            df_sample = df_sample.iloc[:num_records]
        self._num_records_sampled += len(df_sample)

        return df_sample
|